We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 161177e commit 6c4f0c9Copy full SHA for 6c4f0c9
1 file changed
src/models/geom_mol_gen_ddpm.py
@@ -737,6 +737,7 @@ def sample_and_save(
737
num_samples: int,
738
node_mask: Optional[TensorType["batch_num_nodes"]] = None,
739
context: Optional[TensorType["batch_size", "num_context_features"]] = None,
740
+ num_timesteps: Optional[int] = None,
741
id_from: int = 0,
742
name: str = "molecule"
743
):
@@ -758,7 +759,8 @@ def sample_and_save(
758
759
num_nodes=num_nodes,
760
node_mask=node_mask,
761
context=context,
- device=self.device
762
+ device=self.device,
763
+ num_timesteps=num_timesteps
764
)
765
766
x = xh[:, :self.num_x_dims]
0 commit comments