@@ -166,17 +166,20 @@ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
166166 return model_mean + nonzero_mask * (0.5 * model_log_variance ).exp () * noise
167167
168168 @torch .no_grad ()
169- def p_sample_plms (self , x , t , cond , clip_denoised = True , repeat_noise = False ):
169+ def p_sample_plms (self , x , t , interval , cond , clip_denoised = True , repeat_noise = False ):
170170 """
171171 Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
172172 """
173173
174- def get_x_pred (x , e_t , t ):
174+ def get_x_pred (x , noise_t , t ):
175175 a_t = extract (self .alphas_cumprod , t , x .shape )
176- a_prev = extract (self .alphas_cumprod_prev , t , x .shape )
176+ if t [0 ] < interval :
177+ a_prev = torch .ones_like (a_t )
178+ else :
179+ a_prev = extract (self .alphas_cumprod , t - interval , x .shape )
177180 a_t_sq , a_prev_sq = a_t .sqrt (), a_prev .sqrt ()
178181
179- x_delta = (a_prev - a_t ) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq ))) * x - 1 / (a_t_sq * (((1 - a_prev ) * a_t ).sqrt () + ((1 - a_t ) * a_prev ).sqrt ())) * e_t )
182+ x_delta = (a_prev - a_t ) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq ))) * x - 1 / (a_t_sq * (((1 - a_prev ) * a_t ).sqrt () + ((1 - a_t ) * a_prev ).sqrt ())) * noise_t )
180183 x_pred = x + x_delta
181184
182185 return x_pred
@@ -186,7 +189,7 @@ def get_x_pred(x, e_t, t):
186189
187190 if len (noise_list ) == 0 :
188191 x_pred = get_x_pred (x , noise_pred , t )
189- noise_pred_prev = self .denoise_fn (x_pred , t + 1 , cond = cond )
192+ noise_pred_prev = self .denoise_fn (x_pred , torch . max ( t - interval , torch . zeros_like ( t )) , cond = cond )
190193 noise_pred_prime = (noise_pred + noise_pred_prev ) / 2
191194 elif len (noise_list ) == 1 :
192195 noise_pred_prime = (3 * noise_pred - noise_list [- 1 ]) / 2
@@ -255,8 +258,9 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
255258 shape = (cond .shape [0 ], 1 , self .mel_bins , cond .shape [2 ])
256259 x = torch .randn (shape , device = device )
257260 self .noise_list = deque (maxlen = 4 )
258- for i in tqdm (reversed (range (0 , t )), desc = 'sample time step' , total = t ):
259- x = self .p_sample_plms (x , torch .full ((b ,), i , device = device , dtype = torch .long ), cond )
261+ iteration_interval = 5
262+ for i in tqdm (reversed (range (0 , t , iteration_interval )), desc = 'sample time step' , total = t ):
263+ x = self .p_sample_plms (x , torch .full ((b ,), i , device = device , dtype = torch .long ), iteration_interval , cond )
260264 x = x [:, 0 ].transpose (1 , 2 )
261265 if mel2ph is not None : # for singing
262266 ret ['mel_out' ] = self .denorm_spec (x ) * ((mel2ph > 0 ).float ()[:, :, None ])
0 commit comments