11import math
22import random
3+ from collections import deque
34from functools import partial
45from inspect import isfunction
56from pathlib import Path
1516from utils .hparams import hparams
1617
1718
18-
def exists(x):
    """Return True when *x* is anything other than ``None``."""
    if x is None:
        return False
    return True
2121
@@ -69,7 +69,8 @@ def cosine_beta_schedule(timesteps, s=0.008):
6969
7070class GaussianDiffusion (nn .Module ):
7171 def __init__ (self , phone_encoder , out_dims , denoise_fn ,
72- timesteps = 1000 , K_step = 1000 , loss_type = hparams .get ('diff_loss_type' , 'l1' ), betas = None , spec_min = None , spec_max = None ):
72+ timesteps = 1000 , K_step = 1000 , loss_type = hparams .get ('diff_loss_type' , 'l1' ), betas = None , spec_min = None ,
73+ spec_max = None ):
7374 super ().__init__ ()
7475 self .denoise_fn = denoise_fn
7576 if hparams .get ('use_midi' ) is not None and hparams ['use_midi' ]:
@@ -95,6 +96,8 @@ def __init__(self, phone_encoder, out_dims, denoise_fn,
9596 self .K_step = K_step
9697 self .loss_type = loss_type
9798
99+ self .noise_list = deque (maxlen = 4 )
100+
98101 to_torch = partial (torch .tensor , dtype = torch .float32 )
99102
100103 self .register_buffer ('betas' , to_torch (betas ))
@@ -162,6 +165,44 @@ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
162165 nonzero_mask = (1 - (t == 0 ).float ()).reshape (b , * ((1 ,) * (len (x .shape ) - 1 )))
163166 return model_mean + nonzero_mask * (0.5 * model_log_variance ).exp () * noise
164167
168+ @torch .no_grad ()
169+ def p_sample_plms (self , x , t , interval , cond , clip_denoised = True , repeat_noise = False ):
170+ """
171+ Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
172+ """
173+
174+ def get_x_pred (x , noise_t , t ):
175+ a_t = extract (self .alphas_cumprod , t , x .shape )
176+ if t [0 ] < interval :
177+ a_prev = torch .ones_like (a_t )
178+ else :
179+ a_prev = extract (self .alphas_cumprod , t - interval , x .shape )
180+ a_t_sq , a_prev_sq = a_t .sqrt (), a_prev .sqrt ()
181+
182+ x_delta = (a_prev - a_t ) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq ))) * x - 1 / (a_t_sq * (((1 - a_prev ) * a_t ).sqrt () + ((1 - a_t ) * a_prev ).sqrt ())) * noise_t )
183+ x_pred = x + x_delta
184+
185+ return x_pred
186+
187+ noise_list = self .noise_list
188+ noise_pred = self .denoise_fn (x , t , cond = cond )
189+
190+ if len (noise_list ) == 0 :
191+ x_pred = get_x_pred (x , noise_pred , t )
192+ noise_pred_prev = self .denoise_fn (x_pred , torch .max (t - interval , torch .zeros_like (t )), cond = cond )
193+ noise_pred_prime = (noise_pred + noise_pred_prev ) / 2
194+ elif len (noise_list ) == 1 :
195+ noise_pred_prime = (3 * noise_pred - noise_list [- 1 ]) / 2
196+ elif len (noise_list ) == 2 :
197+ noise_pred_prime = (23 * noise_pred - 16 * noise_list [- 1 ] + 5 * noise_list [- 2 ]) / 12
198+ elif len (noise_list ) >= 3 :
199+ noise_pred_prime = (55 * noise_pred - 59 * noise_list [- 1 ] + 37 * noise_list [- 2 ] - 9 * noise_list [- 3 ]) / 24
200+
201+ x_prev = get_x_pred (x , noise_pred_prime , t )
202+ noise_list .append (noise_pred )
203+
204+ return x_prev
205+
165206 def q_sample (self , x_start , t , noise = None ):
166207 noise = default (noise , lambda : torch .randn_like (x_start ))
167208 return (
@@ -216,8 +257,10 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
216257 print ('===> gaussion start.' )
217258 shape = (cond .shape [0 ], 1 , self .mel_bins , cond .shape [2 ])
218259 x = torch .randn (shape , device = device )
219- for i in tqdm (reversed (range (0 , t )), desc = 'sample time step' , total = t ):
220- x = self .p_sample (x , torch .full ((b ,), i , device = device , dtype = torch .long ), cond )
260+ self .noise_list = deque (maxlen = 4 )
261+ iteration_interval = 5
262+ for i in tqdm (reversed (range (0 , t , iteration_interval )), desc = 'sample time step' , total = t ):
263+ x = self .p_sample_plms (x , torch .full ((b ,), i , device = device , dtype = torch .long ), iteration_interval , cond )
221264 x = x [:, 0 ].transpose (1 , 2 )
222265 if mel2ph is not None : # for singing
223266 ret ['mel_out' ] = self .denorm_spec (x ) * ((mel2ph > 0 ).float ()[:, :, None ])
@@ -233,7 +276,7 @@ def denorm_spec(self, x):
233276
234277 def cwt2f0_norm (self , cwt_spec , mean , std , mel2ph ):
235278 return self .fs2 .cwt2f0_norm (cwt_spec , mean , std , mel2ph )
236-
279+
237280 def out2mel (self , x ):
238281 return x
239282
@@ -270,4 +313,4 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
270313 x = self .p_sample (x , torch .full ((b ,), i , device = device , dtype = torch .long ), cond )
271314 x = x [:, 0 ].transpose (1 , 2 )
272315 ret ['mel_out' ] = self .denorm_spec (x )
273- return ret
316+ return ret
0 commit comments