@@ -105,6 +105,7 @@ def plot_dataset_rmsd(
105105 filtered_ids_to_keep_file : Optional [str ] = None ,
106106 filtered_ids_to_skip : Optional [Set [str ]] = None ,
107107 is_casp_dataset : bool = False ,
108+ public_plots : bool = True ,
108109 accurate_rmsd_threshold : float = 4.0 ,
109110 accurate_tm_score_threshold : float = 0.7 ,
110111):
@@ -119,6 +120,7 @@ def plot_dataset_rmsd(
119120 :param filtered_ids_to_keep_file: File containing IDs of sequences to keep.
120121 :param filtered_ids_to_skip: Set of IDs of sequences to skip.
121122 :param is_casp_dataset: Whether the dataset is a CASP dataset.
123+ :param public_plots: Whether to save the public versions of the plots.
122124 :param accurate_rmsd_threshold: RMSD threshold for accurate predictions.
123125 :param accurate_tm_score_threshold: TM-score threshold for accurate predictions.
124126 """
@@ -135,7 +137,12 @@ def plot_dataset_rmsd(
135137
136138 dataset_rows = []
137139
138- for pred_pdb_file in tqdm (os .listdir (pred_pdb_dir ), desc = f"Plotting RMSD for { dataset_name } " ):
140+ dataset_suffix = " (Public)" if is_casp_dataset and public_plots else ""
141+
142+ for pred_pdb_file in tqdm (
143+ os .listdir (pred_pdb_dir ),
144+ desc = f"Plotting RMSD for { dataset_name } { dataset_suffix } " ,
145+ ):
139146 pdb_id = os .path .splitext (os .path .basename (pred_pdb_file ))[0 ].split ("_holo" )[0 ]
140147
141148 if filtered_ids_to_keep is not None and pdb_id not in filtered_ids_to_keep :
@@ -193,10 +200,10 @@ def plot_dataset_rmsd(
193200 / dataset_df .shape [0 ]
194201 )
195202 logging .info (
196- f"For the { dataset_name } dataset, { accurate_predictions_percent * 100 :.2f} % of the predictions have RMSD < { accurate_rmsd_threshold } and TM-score > { accurate_tm_score_threshold } ."
203+ f"For the { dataset_name } { dataset_suffix } dataset, { accurate_predictions_percent * 100 :.2f} % of the predictions have RMSD < { accurate_rmsd_threshold } and TM-score > { accurate_tm_score_threshold } ."
197204 )
198205
199- plot_dir = Path (output_dir ) / ("public_plots" if is_casp_dataset else "plots" )
206+ plot_dir = Path (output_dir ) / ("public_plots" if is_casp_dataset and public_plots else "plots" )
200207 plot_dir .mkdir (exist_ok = True )
201208
202209 plt .clf ()
@@ -280,11 +287,32 @@ def main(cfg: DictConfig):
280287 ),
281288 usalign_exec_path = cfg .usalign_exec_path ,
282289 filtered_ids_to_skip = {
283- "T1170"
284- }, # NOTE: We don't score this target due to CASP internal parsing issues
290+ "T1127v2" ,
291+ "T1146" ,
292+ "T1170" ,
293+ "T1181" ,
294+ "T1186" ,
295+ }, # NOTE: We don't score `T1170` due to CASP internal parsing issues
285296 is_casp_dataset = True ,
297+ public_plots = True ,
286298 )
287299
300+ # plot_dataset_rmsd(
301+ # "CASP15 Set",
302+ # os.path.join(cfg.data_dir, "casp15_set", "predicted_structures"),
303+ # os.path.join(cfg.data_dir, "casp15_set", "targets"),
304+ # os.path.join(
305+ # cfg.data_dir,
306+ # "casp15_set",
307+ # ),
308+ # usalign_exec_path=cfg.usalign_exec_path,
309+ # filtered_ids_to_skip={
310+ # "T1170",
311+ # }, # NOTE: We don't score `T1170` due to CASP internal parsing issues
312+ # is_casp_dataset=True,
313+ # public_plots=False,
314+ # )
315+
288316
289317if __name__ == "__main__" :
290318 main ()
0 commit comments