Skip to content

Commit 33c9ec8

Browse files
committed
add different interval support
1 parent ebedd45 commit 33c9ec8

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

usr/diff/shallow_diffusion_tts.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,20 @@ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
166166
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
167167

168168
@torch.no_grad()
169-
def p_sample_plms(self, x, t, cond, clip_denoised=True, repeat_noise=False):
169+
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
170170
"""
171171
Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
172172
"""
173173

174-
def get_x_pred(x, e_t, t):
174+
def get_x_pred(x, noise_t, t):
175175
a_t = extract(self.alphas_cumprod, t, x.shape)
176-
a_prev = extract(self.alphas_cumprod_prev, t, x.shape)
176+
if t[0] < interval:
177+
a_prev = torch.ones_like(a_t)
178+
else:
179+
a_prev = extract(self.alphas_cumprod, t - interval, x.shape)
177180
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
178181

179-
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * e_t)
182+
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
180183
x_pred = x + x_delta
181184

182185
return x_pred
@@ -186,7 +189,7 @@ def get_x_pred(x, e_t, t):
186189

187190
if len(noise_list) == 0:
188191
x_pred = get_x_pred(x, noise_pred, t)
189-
noise_pred_prev = self.denoise_fn(x_pred, t + 1, cond=cond)
192+
noise_pred_prev = self.denoise_fn(x_pred, torch.max(t-interval, torch.zeros_like(t)), cond=cond)
190193
noise_pred_prime = (noise_pred + noise_pred_prev) / 2
191194
elif len(noise_list) == 1:
192195
noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
@@ -255,8 +258,9 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
255258
shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
256259
x = torch.randn(shape, device=device)
257260
self.noise_list = deque(maxlen=4)
258-
for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
259-
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
261+
iteration_interval = 5
262+
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step', total=t):
263+
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval, cond)
260264
x = x[:, 0].transpose(1, 2)
261265
if mel2ph is not None: # for singing
262266
ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])

0 commit comments

Comments
 (0)