Skip to content

Commit a17d91a

Browse files
committed
Update the RMSD plotting script
1 parent 60489ab commit a17d91a

1 file changed

Lines changed: 33 additions & 5 deletions

File tree

src/data/components/plot_dataset_rmsd.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def plot_dataset_rmsd(
105105
filtered_ids_to_keep_file: Optional[str] = None,
106106
filtered_ids_to_skip: Optional[Set[str]] = None,
107107
is_casp_dataset: bool = False,
108+
public_plots: bool = True,
108109
accurate_rmsd_threshold: float = 4.0,
109110
accurate_tm_score_threshold: float = 0.7,
110111
):
@@ -119,6 +120,7 @@ def plot_dataset_rmsd(
119120
:param filtered_ids_to_keep_file: File containing IDs of sequences to keep.
120121
:param filtered_ids_to_skip: Set of IDs of sequences to skip.
121122
:param is_casp_dataset: Whether the dataset is a CASP dataset.
123+
:param public_plots: Whether to save the public versions of the plots.
122124
:param accurate_rmsd_threshold: RMSD threshold for accurate predictions.
123125
:param accurate_tm_score_threshold: TM-score threshold for accurate predictions.
124126
"""
@@ -135,7 +137,12 @@ def plot_dataset_rmsd(
135137

136138
dataset_rows = []
137139

138-
for pred_pdb_file in tqdm(os.listdir(pred_pdb_dir), desc=f"Plotting RMSD for {dataset_name}"):
140+
dataset_suffix = " (Public)" if is_casp_dataset and public_plots else ""
141+
142+
for pred_pdb_file in tqdm(
143+
os.listdir(pred_pdb_dir),
144+
desc=f"Plotting RMSD for {dataset_name}{dataset_suffix}",
145+
):
139146
pdb_id = os.path.splitext(os.path.basename(pred_pdb_file))[0].split("_holo")[0]
140147

141148
if filtered_ids_to_keep is not None and pdb_id not in filtered_ids_to_keep:
@@ -193,10 +200,10 @@ def plot_dataset_rmsd(
193200
/ dataset_df.shape[0]
194201
)
195202
logging.info(
196-
f"For the {dataset_name} dataset, {accurate_predictions_percent * 100:.2f}% of the predictions have RMSD < {accurate_rmsd_threshold} and TM-score > {accurate_tm_score_threshold}."
203+
f"For the {dataset_name}{dataset_suffix} dataset, {accurate_predictions_percent * 100:.2f}% of the predictions have RMSD < {accurate_rmsd_threshold} and TM-score > {accurate_tm_score_threshold}."
197204
)
198205

199-
plot_dir = Path(output_dir) / ("public_plots" if is_casp_dataset else "plots")
206+
plot_dir = Path(output_dir) / ("public_plots" if is_casp_dataset and public_plots else "plots")
200207
plot_dir.mkdir(exist_ok=True)
201208

202209
plt.clf()
@@ -280,11 +287,32 @@ def main(cfg: DictConfig):
280287
),
281288
usalign_exec_path=cfg.usalign_exec_path,
282289
filtered_ids_to_skip={
283-
"T1170"
284-
}, # NOTE: We don't score this target due to CASP internal parsing issues
290+
"T1127v2",
291+
"T1146",
292+
"T1170",
293+
"T1181",
294+
"T1186",
295+
}, # NOTE: We don't score `T1170` due to CASP internal parsing issues
285296
is_casp_dataset=True,
297+
public_plots=True,
286298
)
287299

300+
# plot_dataset_rmsd(
301+
# "CASP15 Set",
302+
# os.path.join(cfg.data_dir, "casp15_set", "predicted_structures"),
303+
# os.path.join(cfg.data_dir, "casp15_set", "targets"),
304+
# os.path.join(
305+
# cfg.data_dir,
306+
# "casp15_set",
307+
# ),
308+
# usalign_exec_path=cfg.usalign_exec_path,
309+
# filtered_ids_to_skip={
310+
# "T1170",
311+
# }, # NOTE: We don't score `T1170` due to CASP internal parsing issues
312+
# is_casp_dataset=True,
313+
# public_plots=False,
314+
# )
315+
288316

289317
if __name__ == "__main__":
290318
main()

0 commit comments

Comments
 (0)