Skip to content

Commit c912c76

Browse files
authored
Update variational_diffusion.py
1 parent 6c4f0c9 commit c912c76

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/models/components/variational_diffusion.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,8 @@ def mol_gen_sample(
12901290
context: Optional[TensorType["batch_size", "num_context_features"]] = None,
12911291
fix_noise: bool = False,
12921292
generate_x_only: bool = False,
1293-
fix_self_conditioning_noise: bool = False
1293+
fix_self_conditioning_noise: bool = False,
1294+
norm_with_original_timesteps: bool = False,
12941295
) -> Tuple[
12951296
Union[
12961297
TensorType["batch_num_nodes", "num_x_dims_plus_num_node_scalar_features"],
@@ -1329,13 +1330,13 @@ def mol_gen_sample(
13291330

13301331
# iteratively sample p(z_s | z_t) for `t = 1, ..., T`, with `s = t - 1`.
13311332
self_cond = None
1332-
s_array_self_cond = torch.full((num_samples, 1), fill_value=0, device=device) / num_timesteps
1333+
s_array_self_cond = torch.full((num_samples, 1), fill_value=0, device=device) / (self.T if norm_with_original_timesteps else num_timesteps)
13331334
out = torch.zeros((return_frames,) + z.size(), device=device)
13341335
for s in reversed(range(0, num_timesteps)):
13351336
s_array = torch.full((num_samples, 1), fill_value=s, device=device)
13361337
t_array = s_array + 1
1337-
s_array = s_array / num_timesteps
1338-
t_array = t_array / num_timesteps
1338+
s_array = s_array / (self.T if norm_with_original_timesteps else num_timesteps)
1339+
t_array = t_array / (self.T if norm_with_original_timesteps else num_timesteps)
13391340

13401341
z = self.sample_p_zs_given_zt(
13411342
s=s_array,

0 commit comments

Comments
 (0)