Commit a2c630e

Update mol_gen_eval_optimization_qm9.py
1 parent b804331 commit a2c630e

1 file changed: src/mol_gen_eval_optimization_qm9.py (59 additions, 25 deletions)
@@ -9,7 +9,7 @@
 
 import torch.nn as nn
 
-from omegaconf import DictConfig
+from omegaconf import DictConfig, open_dict
 from pathlib import Path
 from pytorch_lightning import LightningDataModule, LightningModule
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -24,7 +24,7 @@
 from src.datamodules.components.edm import check_molecular_stability, get_bond_length_arrays
 from src.datamodules.components.edm.datasets_config import QM9_WITH_H, QM9_WITHOUT_H
 from src.models import PropertiesDistribution, compute_mean_mad
-from src.models.components import load_molecule_xyz
+from src.models.components import load_molecule_xyz, save_xyz_file
 from src.utils.pylogger import get_pylogger
 
 from src import LR_SCHEDULER_MANUAL_INTERPOLATION_HELPER_CONFIG_ITEMS, LR_SCHEDULER_MANUAL_INTERPOLATION_PRIMARY_CONFIG_ITEMS, get_classifier, test_with_property_classifier, utils
@@ -34,6 +34,9 @@
 
 patch_typeguard() # use before @typechecked
 
+import lovely_tensors as lt
+lt.monkey_patch()
+
 
 QM9_OPTIMIZATION_NUM_NODES = 19
 
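The `lovely_tensors` patch above only changes how tensors print. A minimal sketch of the effect (the shapes and values here are illustrative, not from this script):

```python
import torch
import lovely_tensors as lt

lt.monkey_patch()  # replaces Tensor.__repr__ with a compact summary

x = torch.randn(256, 3)
print(x)  # prints shape, dtype, device, and value statistics instead of the raw array
```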
@@ -81,7 +84,9 @@ def __init__(
         iterations: int = 200,
         num_optimization_timesteps: int = 10,
         return_frames: int = 1,
-        unknown_labels: bool = False
+        experiment_name: str = "conditional_diffusion",
+        unknown_labels: bool = False,
+        save_molecules: bool = False,
     ):
         assert iterations > 0, \
             "Optimization requires at least two iterations, " \
@@ -95,9 +100,11 @@ def __init__(
         self.device = device
         self.dataset_info = dataset_info
         self.iterations = iterations
+        self.experiment_name = experiment_name
         self.num_optimization_timesteps = num_optimization_timesteps
         self.return_frames = return_frames
         self.unknown_labels = unknown_labels
+        self.save_molecules = save_molecules
 
         self.samples = self.load_pregenerated_samples(sampling_output_dir)
         self.context = self.props_distr.sample_batch(num_nodes).to(device) # fix context so we can compare between iterations
@@ -126,7 +133,8 @@ def load_pregenerated_samples(
     def optimize_pregenerated_samples(
         self,
         score_initial_samples: bool = False,
-        analyze_initial_samples_stability: bool = True
+        analyze_initial_samples_stability: bool = True,
+        id_from: int = 0
     ) -> Dict[str, Any]:
         if score_initial_samples:
             # evaluate stability of initial molecules
@@ -150,17 +158,24 @@ def optimize_pregenerated_samples(
             )
         else:
             # optimize initial molecules
-            x, one_hot, _, batch_index = self.model.optimize(
+            x, one_hot, charges, batch_index = self.model.optimize(
                 samples=self.samples,
                 num_nodes=self.num_nodes,
                 context=self.context,
                 num_timesteps=self.num_optimization_timesteps,
                 sampling_output_dir=self.sampling_output_dir,
                 optim_property=self.optim_property,
                 iteration_index=self.i - 1,
-                return_frames=self.return_frames
+                return_frames=self.return_frames,
+                norm_with_original_timesteps=False, # NOTE: this is important to ensure the samples are "fully optimized" each iteration
             )
 
+        # iteratively update samples with optimized molecule features
+        self.samples = [
+            (x[(batch_index == sample_idx)].cpu(), one_hot[(batch_index == sample_idx)].cpu())
+            for sample_idx in range(self.num_samples)
+        ]
+
         # evaluate stability of optimized molecules
         num_mols_stable = 0
         for sample_idx in range(self.num_samples):
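The new `self.samples` update regroups the flat per-node tensors returned by the optimizer into one tuple per molecule using `batch_index`. A standalone sketch of that masking pattern with toy shapes (all names and values below are illustrative):

```python
import torch

num_samples = 3
batch_index = torch.tensor([0, 0, 1, 1, 1, 2])     # maps each node to its molecule
x = torch.randn(6, 3)                               # per-node 3D coordinates
one_hot = torch.eye(5)[torch.randint(0, 5, (6,))]   # per-node atom type one-hots

# one (coordinates, types) tuple per molecule, moved off the GPU
samples = [
    (x[batch_index == i].cpu(), one_hot[batch_index == i].cpu())
    for i in range(num_samples)
]
assert [s[0].shape[0] for s in samples] == [2, 3, 1]
```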
@@ -177,6 +192,19 @@ def optimize_pregenerated_samples(
         mols_stable_pct = (num_mols_stable / self.num_samples) * 100
         log.info(f"Percentage of optimized samples that are stable molecules: {mols_stable_pct}%")
 
+        # record optimized molecules as `.xyz` files
+        if self.save_molecules:
+            save_xyz_file(
+                path=f"outputs/{self.experiment_name}/analysis/run{self.i}/",
+                positions=x,
+                one_hot=one_hot,
+                charges=charges,
+                dataset_info=self.dataset_info,
+                id_from=id_from,
+                name="conditional",
+                batch_index=batch_index
+            )
+
         # build dense node mask, coordinates, and one-hot types
         max_num_nodes = self.num_nodes.max().item()
         node_mask_range_tensor = torch.arange(max_num_nodes, device=self.device).unsqueeze(0)
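For reference, the dense node mask mentioned in the trailing context lines can be built by broadcasting an `arange` against the per-sample node counts; a small sketch under assumed shapes:

```python
import torch

num_nodes = torch.tensor([3, 5, 2])  # assumed per-sample node counts
max_num_nodes = num_nodes.max().item()

node_mask_range_tensor = torch.arange(max_num_nodes).unsqueeze(0)  # (1, max_num_nodes)
node_mask = node_mask_range_tensor < num_nodes.unsqueeze(1)        # (num_samples, max_num_nodes)
# node_mask[i, j] is True only when sample i actually contains a j-th node
```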
@@ -243,9 +271,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
     assert (
         os.path.exists(cfg.unconditional_generator_model_filepath) and
         os.path.exists(cfg.conditional_generator_model_filepath) and
-        (os.path.exists(cfg.classifier_model_dir) or cfg.sweep_property_values) and
+        os.path.exists(cfg.classifier_model_dir) and
         cfg.property in cfg.conditional_generator_model_filepath and
-        (cfg.property in cfg.classifier_model_dir or cfg.sweep_property_values)
+        cfg.property in cfg.classifier_model_dir
     )
 
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
@@ -303,18 +331,21 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
     model.sample_and_save(
         num_samples=cfg.num_samples,
         num_nodes=sampling_num_nodes,
-        sampling_output_dir=Path(cfg.sampling_output_dir)
+        sampling_output_dir=Path(cfg.sampling_output_dir),
+        num_timesteps=cfg.num_timesteps,
+        norm_with_original_timesteps=False, # NOTE: this is important to ensure the initial samples are "unoptimized" yet realistic when `num_timesteps << T`
     )
 
     if cfg.generate_molecules_only:
         log.info(f"Done generating {cfg.num_samples} 3D molecules unconditionally! Exiting...")
         exit(0)
 
     log.info("Installing conditional model configuration values!")
-    cfg.model.module_cfg.conditioning = [cfg.property]
-    cfg.model.diffusion_cfg.norm_values = [1.0, 8.0, 1.0]
-    cfg.datamodule.dataloader_cfg.include_charges = False
-    cfg.datamodule.dataloader_cfg.dataset = "QM9_second_half"
+    with open_dict(cfg):
+        cfg.model.module_cfg.conditioning = [cfg.property]
+        cfg.model.diffusion_cfg.norm_values = [1.0, 8.0, 1.0]
+        cfg.datamodule.dataloader_cfg.include_charges = False
+        cfg.datamodule.dataloader_cfg.dataset = "QM9_second_half"
 
     log.info(f"Instantiating conditional generator model <{cfg.model._target_}>")
     model: LightningModule = hydra.utils.instantiate(
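Wrapping the overrides in `open_dict` matters because Hydra hands the script a struct-mode `DictConfig`, which rejects writes to keys it does not already know. A minimal OmegaConf sketch (standalone, not this script's config):

```python
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"model": {"module_cfg": {}}})
OmegaConf.set_struct(cfg, True)  # mimic Hydra's struct-mode configs

with open_dict(cfg):
    # inside the context, struct mode is temporarily lifted,
    # so new or previously absent keys can be assigned
    cfg.model.module_cfg.conditioning = ["alpha"]
```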
@@ -394,24 +425,27 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
         dataset_info=dataset_info,
         iterations=cfg.iterations,
         num_optimization_timesteps=cfg.num_optimization_timesteps,
-        return_frames=cfg.return_frames
+        return_frames=cfg.return_frames,
+        experiment_name=cfg.experiment_name,
+        save_molecules=cfg.save_molecules
     )
 
     log.info("Loading classifier model!")
     classifier = get_classifier(cfg.classifier_model_dir).to(device)
 
     log.info("Evaluating classifier on conditional generator's optimized samples!")
-    loss = test_with_property_classifier(
-        model=classifier,
-        epoch=0,
-        dataloader=optimization_diffusion_dataloader,
-        mean=mean,
-        mad=mad,
-        property=cfg.property,
-        device=device,
-        log_interval=1,
-        debug_break=cfg.debug_break
-    )
+    with torch.no_grad():
+        loss = test_with_property_classifier(
+            model=classifier,
+            epoch=0,
+            dataloader=optimization_diffusion_dataloader,
+            mean=mean,
+            mad=mad,
+            property=cfg.property,
+            device=device,
+            log_interval=1,
+            debug_break=cfg.debug_break
+        )
     log.info("Classifier loss (MAE) on conditional generator's optimized samples: %.4f" % loss)
 
     metric_dict = {}
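Wrapping the classifier pass in `torch.no_grad()` keeps autograd from recording a computation graph during evaluation, which lowers memory use. A tiny standalone sketch (the linear model here is just a stand-in for the property classifier):

```python
import torch
import torch.nn as nn

classifier = nn.Linear(16, 1)  # stand-in for the property classifier
batch = torch.randn(8, 16)

classifier.eval()
with torch.no_grad():  # no computation graph is recorded
    preds = classifier(batch)

assert preds.requires_grad is False
```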
