@@ -1290,7 +1290,8 @@ def mol_gen_sample(
12901290 context : Optional [TensorType ["batch_size" , "num_context_features" ]] = None ,
12911291 fix_noise : bool = False ,
12921292 generate_x_only : bool = False ,
1293- fix_self_conditioning_noise : bool = False
1293+ fix_self_conditioning_noise : bool = False ,
1294+ norm_with_original_timesteps : bool = False ,
12941295 ) -> Tuple [
12951296 Union [
12961297 TensorType ["batch_num_nodes" , "num_x_dims_plus_num_node_scalar_features" ],
@@ -1329,13 +1330,13 @@ def mol_gen_sample(
13291330
13301331 # iteratively sample p(z_s | z_t) for `t = 1, ..., T`, with `s = t - 1`.
13311332 self_cond = None
1332- s_array_self_cond = torch .full ((num_samples , 1 ), fill_value = 0 , device = device ) / num_timesteps
1333+ s_array_self_cond = torch .full ((num_samples , 1 ), fill_value = 0 , device = device ) / ( self . T if norm_with_original_timesteps else num_timesteps )
13331334 out = torch .zeros ((return_frames ,) + z .size (), device = device )
13341335 for s in reversed (range (0 , num_timesteps )):
13351336 s_array = torch .full ((num_samples , 1 ), fill_value = s , device = device )
13361337 t_array = s_array + 1
1337- s_array = s_array / num_timesteps
1338- t_array = t_array / num_timesteps
1338+ s_array = s_array / ( self . T if norm_with_original_timesteps else num_timesteps )
1339+ t_array = t_array / ( self . T if norm_with_original_timesteps else num_timesteps )
13391340
13401341 z = self .sample_p_zs_given_zt (
13411342 s = s_array ,
0 commit comments