@@ -135,13 +135,31 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
135135 decoder_inp_origin = decoder_inp = torch .gather (decoder_inp , 1 , mel2ph_ ) # [B, T, H]
136136
137137 tgt_nonpadding = (mel2ph > 0 ).float ()[:, :, None ]
138-
138+
139139 # add pitch and energy embed
140140 pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0 ) * tgt_nonpadding
141+ nframes = mel2ph .size (1 )
142+
141143 if hparams ['use_pitch_embed' ]:
142144 pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0 ) * src_nonpadding
145+ if f0 is not None :
146+ delta_l = nframes - f0 .size (1 )
147+ if delta_l > 0 :
148+ f0 = torch .cat ((f0 ,torch .FloatTensor ([[x [- 1 ]] * delta_l for x in f0 ]).to (f0 .get_device ())),1 )
149+ f0 = f0 [:,:nframes ]
150+ if uv is not None :
151+ delta_l = nframes - uv .size (1 )
152+ if delta_l > 0 :
153+ uv = torch .cat ((uv ,torch .FloatTensor ([[x [- 1 ]] * delta_l for x in uv ]).to (uv .get_device ())),1 )
154+ uv = uv [:,:nframes ]
143155 decoder_inp = decoder_inp + self .add_pitch (pitch_inp , f0 , uv , mel2ph , ret , encoder_out = pitch_inp_ph )
156+
144157 if hparams ['use_energy_embed' ]:
158+ if energy is not None :
159+ delta_l = nframes - energy .size (1 )
160+ if delta_l > 0 :
161+ energy = torch .cat ((energy ,torch .FloatTensor ([[x [- 1 ]] * delta_l for x in energy ]).to (energy .get_device ())),1 )
162+ energy = energy [:,:nframes ]
145163 decoder_inp = decoder_inp + self .add_energy (pitch_inp , energy , ret )
146164
147165 ret ['decoder_inp' ] = decoder_inp = (decoder_inp + spk_embed ) * tgt_nonpadding