Skip to content

Commit 161177e

Browse files
authored
Update qm9_mol_gen_ddpm.py
1 parent ce235be commit 161177e

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/models/qm9_mol_gen_ddpm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,9 +863,11 @@ def sample_and_save(
863863
num_nodes: Optional[TensorType["batch_size"]] = None,
864864
node_mask: Optional[TensorType["batch_num_nodes"]] = None,
865865
context: Optional[TensorType["batch_size", "num_context_features"]] = None,
866+
num_timesteps: Optional[int] = None,
866867
id_from: int = 0,
867868
name: str = "molecule",
868-
sampling_output_dir: Optional[Path] = None
869+
sampling_output_dir: Optional[Path] = None,
870+
norm_with_original_timesteps: bool = False,
869871
):
870872
# node count-conditioning
871873
if num_nodes is None:
@@ -890,7 +892,9 @@ def sample_and_save(
890892
num_nodes=num_nodes,
891893
node_mask=node_mask,
892894
context=context,
893-
device=self.device
895+
device=self.device,
896+
num_timesteps=num_timesteps,
897+
norm_with_original_timesteps=norm_with_original_timesteps,
894898
)
895899

896900
x = xh[:, :self.num_x_dims]

0 commit comments

Comments
 (0)