Skip to content

Commit a977f60

Browse files
Merge pull request openvpi#69 from luping-liu/master
add the PLMS method
2 parents 303c52e + 33c9ec8 commit a977f60

1 file changed

Lines changed: 49 additions & 6 deletions

File tree

usr/diff/shallow_diffusion_tts.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import random
3+
from collections import deque
34
from functools import partial
45
from inspect import isfunction
56
from pathlib import Path
@@ -15,7 +16,6 @@
1516
from utils.hparams import hparams
1617

1718

18-
1919
def exists(x):
    """Report whether *x* actually holds a value, i.e. is not None."""
    if x is None:
        return False
    return True
2121

@@ -69,7 +69,8 @@ def cosine_beta_schedule(timesteps, s=0.008):
6969

7070
class GaussianDiffusion(nn.Module):
7171
def __init__(self, phone_encoder, out_dims, denoise_fn,
72-
timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None):
72+
timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None,
73+
spec_max=None):
7374
super().__init__()
7475
self.denoise_fn = denoise_fn
7576
if hparams.get('use_midi') is not None and hparams['use_midi']:
@@ -95,6 +96,8 @@ def __init__(self, phone_encoder, out_dims, denoise_fn,
9596
self.K_step = K_step
9697
self.loss_type = loss_type
9798

99+
self.noise_list = deque(maxlen=4)
100+
98101
to_torch = partial(torch.tensor, dtype=torch.float32)
99102

100103
self.register_buffer('betas', to_torch(betas))
@@ -162,6 +165,44 @@ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
162165
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
163166
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
164167

168+
@torch.no_grad()
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
    """
    Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).

    Performs one pseudo linear multi-step (PLMS) update: advances the noisy
    sample ``x`` from timestep ``t`` to ``t - interval``, reusing the history
    of noise predictions stored in ``self.noise_list`` (a deque of maxlen 4).

    :param x: current noisy sample tensor.
    :param t: current timestep index per batch element (long tensor).
    :param interval: number of diffusion timesteps skipped per PLMS update.
    :param cond: conditioning input forwarded to ``self.denoise_fn``.
    :param clip_denoised: unused in this method; presumably kept to mirror
        ``p_sample``'s signature — TODO confirm.
    :param repeat_noise: unused in this method; presumably kept to mirror
        ``p_sample``'s signature — TODO confirm.
    :return: the sample estimated at timestep ``t - interval``.
    """

    def get_x_pred(x, noise_t, t):
        # Pseudo-numerical "transfer" step: given a noise estimate noise_t
        # at timestep t, jump x forward to timestep t - interval.
        a_t = extract(self.alphas_cumprod, t, x.shape)
        if t[0] < interval:
            # Final step would cross timestep 0: treat the target
            # alpha-bar as 1 (i.e. fully denoised data distribution).
            a_prev = torch.ones_like(a_t)
        else:
            a_prev = extract(self.alphas_cumprod, t - interval, x.shape)
        a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

        # Closed-form x update from the PLMS paper (transfer function).
        x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
        x_pred = x + x_delta

        return x_pred

    noise_list = self.noise_list
    noise_pred = self.denoise_fn(x, t, cond=cond)

    # Combine the fresh noise prediction with up to three stored ones.
    # The weights below are the classical 2nd/3rd/4th-order
    # Adams-Bashforth linear multi-step coefficients.
    if len(noise_list) == 0:
        # No history yet: predictor step, re-evaluate noise at the
        # predicted point, then average (Heun-like 2nd-order start).
        x_pred = get_x_pred(x, noise_pred, t)
        noise_pred_prev = self.denoise_fn(x_pred, torch.max(t-interval, torch.zeros_like(t)), cond=cond)
        noise_pred_prime = (noise_pred + noise_pred_prev) / 2
    elif len(noise_list) == 1:
        noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
    elif len(noise_list) == 2:
        noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
    elif len(noise_list) >= 3:
        noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

    x_prev = get_x_pred(x, noise_pred_prime, t)
    # deque(maxlen=4) silently evicts the oldest prediction.
    noise_list.append(noise_pred)

    return x_prev
205+
165206
def q_sample(self, x_start, t, noise=None):
166207
noise = default(noise, lambda: torch.randn_like(x_start))
167208
return (
@@ -216,8 +257,10 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
216257
print('===> gaussion start.')
217258
shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
218259
x = torch.randn(shape, device=device)
219-
for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
220-
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
260+
self.noise_list = deque(maxlen=4)
261+
iteration_interval = 5
262+
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step', total=t):
263+
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval, cond)
221264
x = x[:, 0].transpose(1, 2)
222265
if mel2ph is not None: # for singing
223266
ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
@@ -233,7 +276,7 @@ def denorm_spec(self, x):
233276

234277
def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
    """Delegate CWT-spectrum to normalized-F0 conversion to the wrapped fs2 module."""
    inner = self.fs2
    return inner.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
236-
279+
237280
def out2mel(self, x):
    """Return *x* unchanged — no output-to-mel transform is applied here."""
    mel = x
    return mel
239282

@@ -270,4 +313,4 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
270313
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
271314
x = x[:, 0].transpose(1, 2)
272315
ret['mel_out'] = self.denorm_spec(x)
273-
return ret
316+
return ret

0 commit comments

Comments
 (0)