99
1010import torch .nn as nn
1111
12- from omegaconf import DictConfig
12+ from omegaconf import DictConfig , open_dict
1313from pathlib import Path
1414from pytorch_lightning import LightningDataModule , LightningModule
1515from typing import Any , Dict , List , Optional , Tuple , Union
2424from src .datamodules .components .edm import check_molecular_stability , get_bond_length_arrays
2525from src .datamodules .components .edm .datasets_config import QM9_WITH_H , QM9_WITHOUT_H
2626from src .models import PropertiesDistribution , compute_mean_mad
27- from src .models .components import load_molecule_xyz
27+ from src .models .components import load_molecule_xyz , save_xyz_file
2828from src .utils .pylogger import get_pylogger
2929
3030from src import LR_SCHEDULER_MANUAL_INTERPOLATION_HELPER_CONFIG_ITEMS , LR_SCHEDULER_MANUAL_INTERPOLATION_PRIMARY_CONFIG_ITEMS , get_classifier , test_with_property_classifier , utils
3434
3535patch_typeguard () # use before @typechecked
3636
37+ import lovely_tensors as lt
38+ lt .monkey_patch ()
39+
3740
3841QM9_OPTIMIZATION_NUM_NODES = 19
3942
@@ -81,7 +84,9 @@ def __init__(
8184 iterations : int = 200 ,
8285 num_optimization_timesteps : int = 10 ,
8386 return_frames : int = 1 ,
84- unknown_labels : bool = False
87+ experiment_name : str = "conditional_diffusion" ,
88+ unknown_labels : bool = False ,
89+ save_molecules : bool = False ,
8590 ):
8691 assert iterations > 0 , \
8792 "Optimization requires at least two iterations, " \
@@ -95,9 +100,11 @@ def __init__(
95100 self .device = device
96101 self .dataset_info = dataset_info
97102 self .iterations = iterations
103+ self .experiment_name = experiment_name
98104 self .num_optimization_timesteps = num_optimization_timesteps
99105 self .return_frames = return_frames
100106 self .unknown_labels = unknown_labels
107+ self .save_molecules = save_molecules
101108
102109 self .samples = self .load_pregenerated_samples (sampling_output_dir )
103110 self .context = self .props_distr .sample_batch (num_nodes ).to (device ) # fix context so we can compare between iterations
@@ -126,7 +133,8 @@ def load_pregenerated_samples(
126133 def optimize_pregenerated_samples (
127134 self ,
128135 score_initial_samples : bool = False ,
129- analyze_initial_samples_stability : bool = True
136+ analyze_initial_samples_stability : bool = True ,
137+ id_from : int = 0
130138 ) -> Dict [str , Any ]:
131139 if score_initial_samples :
132140 # evaluate stability of initial molecules
@@ -150,17 +158,24 @@ def optimize_pregenerated_samples(
150158 )
151159 else :
152160 # optimize initial molecules
153- x , one_hot , _ , batch_index = self .model .optimize (
161+ x , one_hot , charges , batch_index = self .model .optimize (
154162 samples = self .samples ,
155163 num_nodes = self .num_nodes ,
156164 context = self .context ,
157165 num_timesteps = self .num_optimization_timesteps ,
158166 sampling_output_dir = self .sampling_output_dir ,
159167 optim_property = self .optim_property ,
160168 iteration_index = self .i - 1 ,
161- return_frames = self .return_frames
169+ return_frames = self .return_frames ,
170+ norm_with_original_timesteps = False , # NOTE: this is important to ensure the samples are "fully optimized" each iteration
162171 )
163172
173+ # iteratively update samples with optimized molecule features
174+ self .samples = [
175+ (x [(batch_index == sample_idx )].cpu (), one_hot [(batch_index == sample_idx )].cpu ())
176+ for sample_idx in range (self .num_samples )
177+ ]
178+
164179 # evaluate stability of optimized molecules
165180 num_mols_stable = 0
166181 for sample_idx in range (self .num_samples ):
@@ -177,6 +192,19 @@ def optimize_pregenerated_samples(
177192 mols_stable_pct = (num_mols_stable / self .num_samples ) * 100
178193 log .info (f"Percentage of optimized samples that are stable molecules: { mols_stable_pct } %" )
179194
195+ # record optimized molecules as `.xyz` files
196+ if self .save_molecules :
197+ save_xyz_file (
198+ path = f"outputs/{ self .experiment_name } /analysis/run{ self .i } /" ,
199+ positions = x ,
200+ one_hot = one_hot ,
201+ charges = charges ,
202+ dataset_info = self .dataset_info ,
203+ id_from = id_from ,
204+ name = "conditional" ,
205+ batch_index = batch_index
206+ )
207+
180208 # build dense node mask, coordinates, and one-hot types
181209 max_num_nodes = self .num_nodes .max ().item ()
182210 node_mask_range_tensor = torch .arange (max_num_nodes , device = self .device ).unsqueeze (0 )
@@ -243,9 +271,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
243271 assert (
244272 os .path .exists (cfg .unconditional_generator_model_filepath ) and
245273 os .path .exists (cfg .conditional_generator_model_filepath ) and
246- ( os .path .exists (cfg .classifier_model_dir ) or cfg . sweep_property_values ) and
274+ os .path .exists (cfg .classifier_model_dir ) and
247275 cfg .property in cfg .conditional_generator_model_filepath and
248- ( cfg .property in cfg .classifier_model_dir or cfg . sweep_property_values )
276+ cfg .property in cfg .classifier_model_dir
249277 )
250278
251279 log .info (f"Instantiating datamodule <{ cfg .datamodule ._target_ } >" )
@@ -303,18 +331,21 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
303331 model .sample_and_save (
304332 num_samples = cfg .num_samples ,
305333 num_nodes = sampling_num_nodes ,
306- sampling_output_dir = Path (cfg .sampling_output_dir )
334+ sampling_output_dir = Path (cfg .sampling_output_dir ),
335+ num_timesteps = cfg .num_timesteps ,
336+ norm_with_original_timesteps = False , # NOTE: this is important to ensure the initial samples are "unoptimized" yet realistic when `num_timesteps << T`
307337 )
308338
309339 if cfg .generate_molecules_only :
310340 log .info (f"Done generating { cfg .num_samples } 3D molecules unconditionally! Exiting..." )
311341 exit (0 )
312342
313343 log .info ("Installing conditional model configuration values!" )
314- cfg .model .module_cfg .conditioning = [cfg .property ]
315- cfg .model .diffusion_cfg .norm_values = [1.0 , 8.0 , 1.0 ]
316- cfg .datamodule .dataloader_cfg .include_charges = False
317- cfg .datamodule .dataloader_cfg .dataset = "QM9_second_half"
344+ with open_dict (cfg ):
345+ cfg .model .module_cfg .conditioning = [cfg .property ]
346+ cfg .model .diffusion_cfg .norm_values = [1.0 , 8.0 , 1.0 ]
347+ cfg .datamodule .dataloader_cfg .include_charges = False
348+ cfg .datamodule .dataloader_cfg .dataset = "QM9_second_half"
318349
319350 log .info (f"Instantiating conditional generator model <{ cfg .model ._target_ } >" )
320351 model : LightningModule = hydra .utils .instantiate (
@@ -394,24 +425,27 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
394425 dataset_info = dataset_info ,
395426 iterations = cfg .iterations ,
396427 num_optimization_timesteps = cfg .num_optimization_timesteps ,
397- return_frames = cfg .return_frames
428+ return_frames = cfg .return_frames ,
429+ experiment_name = cfg .experiment_name ,
430+ save_molecules = cfg .save_molecules
398431 )
399432
400433 log .info ("Loading classifier model!" )
401434 classifier = get_classifier (cfg .classifier_model_dir ).to (device )
402435
403436 log .info ("Evaluating classifier on conditional generator's optimized samples!" )
404- loss = test_with_property_classifier (
405- model = classifier ,
406- epoch = 0 ,
407- dataloader = optimization_diffusion_dataloader ,
408- mean = mean ,
409- mad = mad ,
410- property = cfg .property ,
411- device = device ,
412- log_interval = 1 ,
413- debug_break = cfg .debug_break
414- )
437+ with torch .no_grad ():
438+ loss = test_with_property_classifier (
439+ model = classifier ,
440+ epoch = 0 ,
441+ dataloader = optimization_diffusion_dataloader ,
442+ mean = mean ,
443+ mad = mad ,
444+ property = cfg .property ,
445+ device = device ,
446+ log_interval = 1 ,
447+ debug_break = cfg .debug_break
448+ )
415449 log .info ("Classifier loss (MAE) on conditional generator's optimized samples: %.4f" % loss )
416450
417451 metric_dict = {}
0 commit comments