|
6 | 6 | from utils.hparams import hparams |
7 | 7 | from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0 |
8 | 8 | from modules.fastspeech.fs2 import FastSpeech2 |
| 9 | +from utils.text_encoder import TokenTextEncoder |
| 10 | +from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU |
| 11 | +from torch.nn import functional as F |
| 12 | +import torch |
9 | 13 |
|
10 | 14 |
|
11 | 15 | class FastspeechMIDIEncoder(FastspeechEncoder): |
@@ -47,10 +51,41 @@ class FastSpeech2MIDI(FastSpeech2): |
47 | 51 | def __init__(self, dictionary, out_dims=None): |
48 | 52 | super().__init__(dictionary, out_dims) |
49 | 53 | del self.encoder |
| 54 | + |
50 | 55 | self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary) |
51 | 56 | self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx) |
52 | 57 | self.midi_dur_layer = Linear(1, self.hidden_size) |
53 | 58 | self.is_slur_embed = Embedding(2, self.hidden_size) |
| 59 | + yunmu = ['AP', 'SP'] + ALL_YUNMU |
| 60 | + yunmu.remove('ng') |
| 61 | + self.vowel_tokens = [dictionary.encode(ph)[0] for ph in yunmu] |
| 62 | + |
| 63 | + def add_dur(self, dur_input, mel2ph, txt_tokens, ret, midi_dur = None): |
| 64 | + src_padding = txt_tokens == 0 |
| 65 | + dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach()) |
| 66 | + if mel2ph is None: |
| 67 | + dur, xs = self.dur_predictor.inference(dur_input, src_padding) |
| 68 | + ret['dur'] = xs |
| 69 | + dur = xs.squeeze(-1).exp() - 1.0 |
| 70 | + for i in range(len(dur)): |
| 71 | + for j in range(len(dur[i])): |
| 72 | + if txt_tokens[i,j] in self.vowel_tokens: |
| 73 | + if j < len(dur[i])-1 and txt_tokens[i,j+1] not in self.vowel_tokens: |
| 74 | + dur[i,j] = midi_dur[i,j] - dur[i,j+1] |
| 75 | + if dur[i,j] < 0: |
| 76 | + dur[i,j] = 0 |
| 77 | + dur[i,j+1] = midi_dur[i,j] |
| 78 | + else: |
| 79 | + dur[i,j]=midi_dur[i,j] |
| 80 | + dur[:,0] = dur[:,0] + 0.5 |
| 81 | + dur_acc = F.pad(torch.round(torch.cumsum(dur, axis=1)), (1,0)) |
| 82 | + dur = torch.clamp(dur_acc[:,1:]-dur_acc[:,:-1], min=0).long() |
| 83 | + ret['dur_choice'] = dur |
| 84 | + mel2ph = self.length_regulator(dur, src_padding).detach() |
| 85 | + else: |
| 86 | + ret['dur'] = self.dur_predictor(dur_input, src_padding) |
| 87 | + ret['mel2ph'] = mel2ph |
| 88 | + return mel2ph |
54 | 89 |
|
55 | 90 | def forward(self, txt_tokens, mel2ph=None, spk_embed=None, |
56 | 91 | ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False, |
@@ -92,7 +127,7 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None, |
92 | 127 | # add dur |
93 | 128 | dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding |
94 | 129 |
|
95 | | - mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret) |
| 130 | + mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret, midi_dur=kwargs['midi_dur']*hparams['audio_sample_rate']/hparams['hop_size']) |
96 | 131 |
|
97 | 132 | decoder_inp = F.pad(encoder_out, [0, 0, 1, 0]) |
98 | 133 |
|
|
0 commit comments