-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcluster_seed_script.py
More file actions
55 lines (43 loc) · 1.82 KB
/
cluster_seed_script.py
File metadata and controls
55 lines (43 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import copy
import glob
import os
import shutil

import numpy as np
import pandas as pd
import yaml

import testBinary
import train

# Root of the ResNet50_Classifier project. Defaults to the original author's
# checkout; set the RESNET50_CLASSIFIER_DIR environment variable to run the
# script from another machine without editing the source.
_PROJECT_DIR = os.environ.get(
    "RESNET50_CLASSIFIER_DIR", "/home/ocaki13/ResNet50_Classifier"
)
# YAML configuration loaded by the __main__ entry point.
CONFIG_PATH = f"{_PROJECT_DIR}/config.yml"
# Path of a batch-sample image; not referenced in this part of the file.
BATCHSAMPLE = f"{_PROJECT_DIR}/batch_sample_train.png"
def _cleanupSeedDir(seedDir, deleteImages, deleteNonBestModels):
    """Remove per-seed artefacts that are no longer needed after evaluation.

    Args:
        seedDir: Output directory of a single seed run.
        deleteImages: If true, remove ``<seedDir>/images`` when it exists.
        deleteNonBestModels: If true, remove every file matching ``*epoch*``
            inside ``<seedDir>/models``.
    """
    if deleteImages:
        imagesDir = os.path.join(seedDir, "images")
        # A missing images dir must not abort the remaining seeds.
        if os.path.isdir(imagesDir):
            shutil.rmtree(imagesDir)
    if deleteNonBestModels:
        modelsDir = os.path.join(seedDir, "models")
        # Escape the directory part so glob metacharacters ([, *, ?) in the
        # save path are matched literally; only the *epoch* part is a pattern.
        pattern = os.path.join(glob.escape(modelsDir), "*epoch*")
        for file_path in glob.glob(pattern):
            try:
                os.remove(file_path)  # Delete the file
            except OSError as e:
                # Best effort: report and keep going with the other files.
                print(f"Error deleting {file_path}: {e}")


def _writeResults(resultsDict, csvPath):
    """Write ``{seed: results}`` to ``csvPath`` (one row per seed) and return the DataFrame."""
    results_df = pd.DataFrame.from_dict(resultsDict, orient='index')
    results_df.to_csv(csvPath, index_label="Seed")
    return results_df


def seedTrain(cfg, seedList, deleteImages=True, deleteNonBestModels=True):
    """Train and evaluate one model per seed and collect the results in a CSV.

    For every seed, a private copy of ``cfg`` is made with these changes:

    - ``cfg['dataset_config']['save_dir']`` points to ``<save_dir>/seed_<seed>``.
    - ``cfg['train_config']['seed']`` is set to that seed.

    ``train.main`` is run with that copy, then ``testBinary.main`` evaluates
    ``<seed dir>/models/best.pt``.

    Args:
        cfg: Parsed configuration mapping. It must contain
            ``cfg['dataset_config']['save_dir']`` and a ``cfg['train_config']``
            mapping. It is not modified.
        seedList: Seeds to run, in order.
        deleteImages: If true, remove ``<seed dir>/images`` after evaluation.
        deleteNonBestModels: If true, remove files matching ``*epoch*`` in
            ``<seed dir>/models`` after evaluation.

    Returns:
        The results DataFrame, indexed by seed, that was written to
        ``<save_dir>/results.csv``.
    """
    resultsPath = cfg['dataset_config']['save_dir']
    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(resultsPath, exist_ok=True)
    csvPath = os.path.join(resultsPath, "results.csv")

    resultsDict = {}
    for seed in seedList:
        # Work on a deep copy so the caller's cfg is left untouched and any
        # keys train.main adds to the config cannot leak into the next seed.
        seedCfg = copy.deepcopy(cfg)
        currentSaveDir = os.path.join(resultsPath, 'seed_' + str(seed))
        seedCfg['dataset_config']['save_dir'] = currentSaveDir
        seedCfg['train_config']['seed'] = seed
        best_path = os.path.join(currentSaveDir, 'models/best.pt')

        train.main(seedCfg)
        # NOTE(review): assumes testBinary.main returns a flat mapping of
        # metric -> value, so each seed becomes one CSV row — confirm.
        resultsDict[seed] = testBinary.main(seedCfg, best_path)

        _cleanupSeedDir(currentSaveDir, deleteImages, deleteNonBestModels)

        # Persist after every seed so a crash in a later seed does not lose
        # the results already computed.
        _writeResults(resultsDict, csvPath)

    # Final write also covers an empty seedList (an empty CSV, as before).
    return _writeResults(resultsDict, csvPath)
if __name__ == "__main__":
    # Absolute paths are recommended, especially when running under WSL.
    # Explicit UTF-8 so the YAML config decodes the same regardless of the
    # system locale.
    with open(CONFIG_PATH, "r", encoding="utf-8") as ymlfile:
        # safe_load builds only plain Python data (no arbitrary objects).
        cfg = yaml.safe_load(ymlfile)
    # One full train + evaluate run is performed per seed.
    seedList = [35, 1063, 306]
    seedTrain(cfg, seedList)