Skip to content

Commit 75a6828

Browse files
authored
vowel-alignment algorithm
1 parent 8ecbda4 commit 75a6828

1 file changed

Lines changed: 36 additions & 1 deletion

File tree

modules/diffsinger_midi/fs2.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from utils.hparams import hparams
77
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
88
from modules.fastspeech.fs2 import FastSpeech2
9+
from utils.text_encoder import TokenTextEncoder
10+
from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
11+
from torch.nn import functional as F
12+
import torch
913

1014

1115
class FastspeechMIDIEncoder(FastspeechEncoder):
@@ -47,10 +51,41 @@ class FastSpeech2MIDI(FastSpeech2):
4751
def __init__(self, dictionary, out_dims=None):
    """Build the MIDI-aware FastSpeech2 variant.

    Replaces the base encoder with a MIDI-conditioned one and adds
    embeddings for MIDI pitch, MIDI duration and slur flags, plus the
    vowel-token table used by the vowel-alignment pass in ``add_dur``.
    """
    super().__init__(dictionary, out_dims)
    # Drop the encoder built by the parent and rebuild it so it receives
    # the MIDI-aware encoder class selected via hparams.
    del self.encoder
    self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
    self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx)
    self.midi_dur_layer = Linear(1, self.hidden_size)
    self.is_slur_embed = Embedding(2, self.hidden_size)
    # Phonemes treated as "vowels" for note alignment: silences plus all
    # Mandarin finals except 'ng' (assumes ALL_YUNMU contains 'ng' --
    # remove() would raise otherwise; TODO confirm).
    vowel_phonemes = ['AP', 'SP'] + ALL_YUNMU
    vowel_phonemes.remove('ng')
    self.vowel_tokens = [dictionary.encode(phoneme)[0] for phoneme in vowel_phonemes]
62+
63+
def add_dur(self, dur_input, mel2ph, txt_tokens, ret, midi_dur=None):
    """Predict phoneme durations and expand them into a mel2ph alignment.

    Args:
        dur_input: encoder features fed to the duration predictor
            (presumably [B, T_txt, H] -- TODO confirm against caller).
        mel2ph: ground-truth phoneme-to-frame alignment, or None at
            inference time (then it is computed here).
        txt_tokens: phoneme token ids, [B, T_txt]; 0 marks padding.
        ret: dict collecting intermediate outputs; this method writes
            'dur', 'dur_choice' (inference only) and 'mel2ph'.
        midi_dur: per-phoneme MIDI note durations in mel frames (callers
            pass midi_dur * sample_rate / hop_size). Used only when
            mel2ph is None; when omitted, the vowel-alignment pass is
            skipped and the raw predicted durations are used.

    Returns:
        The mel2ph alignment tensor (detached when computed here).
    """
    src_padding = txt_tokens == 0
    # Straight-through-style gradient scaling: the forward value is
    # unchanged, but the gradient flowing back into dur_input is scaled
    # by hparams['predictor_grad'].
    dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
    if mel2ph is None:
        # Inference: predict durations, then force-align vowels to MIDI notes.
        dur, xs = self.dur_predictor.inference(dur_input, src_padding)
        ret['dur'] = xs
        # The predictor outputs log(dur + 1); invert to get frame counts.
        dur = xs.squeeze(-1).exp() - 1.0
        # FIX: the original dereferenced midi_dur unconditionally, so the
        # default midi_dur=None crashed as soon as a vowel token appeared.
        # Without MIDI durations we now fall back to predicted durations.
        if midi_dur is not None:
            for i in range(len(dur)):
                for j in range(len(dur[i])):
                    if txt_tokens[i, j] not in self.vowel_tokens:
                        continue
                    if j < len(dur[i]) - 1 and txt_tokens[i, j + 1] not in self.vowel_tokens:
                        # Vowel followed by a consonant: the consonant keeps its
                        # predicted length, the vowel absorbs the rest of the note.
                        dur[i, j] = midi_dur[i, j] - dur[i, j + 1]
                        if dur[i, j] < 0:
                            # Consonant longer than the note: vowel gets 0 frames
                            # and the consonant is clipped to the note length.
                            dur[i, j] = 0
                            dur[i, j + 1] = midi_dur[i, j]
                    else:
                        # Trailing vowel (or vowel-vowel pair): the vowel takes
                        # the whole note duration.
                        dur[i, j] = midi_dur[i, j]
        # Round cumulative sums instead of individual durations so per-phoneme
        # rounding errors do not accumulate; +0.5 offsets the first phoneme
        # before rounding.
        dur[:, 0] = dur[:, 0] + 0.5
        dur_acc = F.pad(torch.round(torch.cumsum(dur, dim=1)), (1, 0))
        dur = torch.clamp(dur_acc[:, 1:] - dur_acc[:, :-1], min=0).long()
        ret['dur_choice'] = dur
        mel2ph = self.length_regulator(dur, src_padding).detach()
    else:
        # Training: ground-truth mel2ph is provided; only record the
        # predictor output for the duration loss.
        ret['dur'] = self.dur_predictor(dur_input, src_padding)
    ret['mel2ph'] = mel2ph
    return mel2ph
5489

5590
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
5691
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
@@ -92,7 +127,7 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
92127
# add dur
93128
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
94129

95-
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
130+
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret, midi_dur=kwargs['midi_dur']*hparams['audio_sample_rate']/hparams['hop_size'])
96131

97132
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
98133

0 commit comments

Comments
 (0)