-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathseed_script.py
More file actions
87 lines (65 loc) · 2.25 KB
/
seed_script.py
File metadata and controls
87 lines (65 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import re
import shutil

import numpy as np
import yaml

import train
import testBinary
# Absolute paths are recommended, especially when running under WSL.

# YAML configuration file loaded once when the script starts.
CONFIG_PATH = "/home/rauzen/Projects/eye_project/classifier/config.yml"
# Archive root: each run is copied into a "seed<N>" sub-directory here.
OUTPUT_PATH = "/mnt/c/Users/mbomt/Downloads/deneme"
# Working results directory; it is archived and then deleted after every seed.
RESULT_PATH = "/home/rauzen/Projects/eye_project/classifier/results"

# Seeds that are trained one after another, in this order.
SEED_LIST = [
    35, 1063, 306, 629, 1940,
    288, 399, 1215, 187, 1636,
]

# Delete RESULT_PATH/images before archiving a run.
DELETE_IMAGES = True
# Keep only the highest-numbered epoch checkpoint when archiving a run.
DELETE_NON_BEST_MODELS = True
def sort_filenames(l):
    """Return checkpoint filenames ordered by epoch number.

    Names of the form ``epoch<N>.pt`` are sorted numerically by ``N`` (so
    ``epoch10.pt`` comes after ``epoch2.pt``). If ``last_epoch.pt`` is
    present it is appended at the very end. The returned names are exactly
    the names found in ``l``; they are never rebuilt, so zero-padded names
    such as ``epoch007.pt`` are preserved. Any other entry (stray files,
    sub-directories, ...) is ignored instead of raising or being renamed.

    Args:
        l: Iterable of filenames, e.g. the output of ``os.listdir``. It is
            not modified.

    Returns:
        A new list: the ``epoch<N>.pt`` names in ascending epoch order,
        followed by ``"last_epoch.pt"`` when it was in ``l``.
    """
    names = list(l)
    epoch_pattern = re.compile(r"epoch(\d+)\.pt")

    numbered = []
    for name in names:
        match = epoch_pattern.fullmatch(name)
        if match is not None:
            numbered.append((int(match.group(1)), name))
    # Sort on (epoch, name) so ties such as "epoch1.pt"/"epoch01.pt" are
    # ordered deterministically.
    numbered.sort()

    ordered = [name for _, name in numbered]
    if "last_epoch.pt" in names:
        ordered.append("last_epoch.pt")
    return ordered
def train_one_seed(cfg, seed):
    """Train with ``seed``, then evaluate the highest-numbered checkpoint.

    Writes ``seed`` into ``cfg['train_config']['seed']`` (mutating ``cfg`` in
    place) and calls ``train.main(cfg)``. It then looks in
    ``RESULT_PATH/models`` for ``epoch<N>.pt`` checkpoints and passes the one
    with the largest ``N`` to ``testBinary.main(cfg, path)``.

    Args:
        cfg: Parsed configuration dict; modified in place.
        seed: Random seed to store in the training configuration.

    Raises:
        FileNotFoundError: If no ``epoch<N>.pt`` checkpoint exists after
            training.
    """
    cfg['train_config']['seed'] = seed
    train.main(cfg)

    models_dir = os.path.join(RESULT_PATH, "models")
    # The script treats the highest-numbered epoch checkpoint as the best one
    # (presumably training only saves on improvement -- confirm). Exclude the
    # rolling "last_epoch.pt" by name rather than by position: picking index
    # -2 silently chose the wrong file whenever last_epoch.pt was missing.
    checkpoints = [
        name
        for name in sort_filenames(os.listdir(models_dir))
        if name != "last_epoch.pt"
    ]
    if not checkpoints:
        raise FileNotFoundError(f"no epoch checkpoints found in {models_dir}")

    best_path = os.path.join(models_dir, checkpoints[-1])
    testBinary.main(cfg, best_path)
def save_results(cfg, seed):
    """Prune, archive and clear the results of one training run.

    Steps, in order:

    1. Store ``seed`` in ``cfg['train_config']['seed']`` (in place).
    2. If ``DELETE_IMAGES`` is set, delete ``RESULT_PATH/images``.
    3. If ``DELETE_NON_BEST_MODELS`` is set, delete ``last_epoch.pt`` and
       every ``epoch<N>.pt`` except the highest-numbered one.
    4. Copy ``RESULT_PATH`` to ``OUTPUT_PATH/seed<seed>/results``.
    5. Move ``batch_sample_train.png`` from the parent of ``RESULT_PATH``
       into that seed directory.
    6. Delete ``RESULT_PATH``.
    7. Dump ``cfg`` to ``used_config.yml`` in the seed directory.

    Args:
        cfg: Configuration dict that was used for the run; modified in place.
        seed: Seed of the run; it names the output sub-directory.

    Raises:
        FileExistsError: If ``OUTPUT_PATH/seed<seed>`` already exists. This
            deliberately refuses to overwrite an earlier run.
    """
    cfg['train_config']['seed'] = seed
    models_dir = os.path.join(RESULT_PATH, "models")

    if DELETE_IMAGES:
        images_dir = os.path.join(RESULT_PATH, "images")
        if os.path.isdir(images_dir):
            shutil.rmtree(images_dir)

    if DELETE_NON_BEST_MODELS:
        last_epoch = os.path.join(models_dir, "last_epoch.pt")
        if os.path.exists(last_epoch):
            os.remove(last_epoch)
        # Everything but the highest-numbered epoch checkpoint is discarded.
        for name in sort_filenames(os.listdir(models_dir))[:-1]:
            os.remove(os.path.join(models_dir, name))

    seeddir = os.path.join(OUTPUT_PATH, "seed" + str(seed))
    # makedirs also creates OUTPUT_PATH on first use; exist_ok stays False so
    # an existing seed directory is never overwritten.
    os.makedirs(seeddir)
    shutil.copytree(RESULT_PATH, os.path.join(seeddir, "results"))

    batch_sample = os.path.join(RESULT_PATH, os.pardir, "batch_sample_train.png")
    shutil.copy2(batch_sample, seeddir)
    os.remove(batch_sample)
    shutil.rmtree(RESULT_PATH)

    # Use a context manager so the YAML is flushed and the handle closed.
    with open(os.path.join(seeddir, "used_config.yml"), "w") as f:
        yaml.dump(cfg, f)
if __name__ == "__main__":
    # Load the configuration once; each seed overwrites
    # cfg['train_config']['seed'] before training and before archiving.
    with open(CONFIG_PATH, "r") as config_file:
        cfg = yaml.safe_load(config_file)

    for current_seed in SEED_LIST:
        train_one_seed(cfg, current_seed)
        save_results(cfg, current_seed)