Skip to content

Commit 5912626

Browse files
committed
add speed-up ratio to main.py; fix KeyError in hparams.py
1 parent d805cd0 commit 5912626

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
parser.add_argument('--title', type=str, required=False, help='Title of output file')
2323
parser.add_argument('--num', type=int, default=1, help='Number of runs')
2424
parser.add_argument('--seed', type=int, help='Random seed of the inference')
25+
parser.add_argument('--speedup', type=int, default=0, help='PNDM speed-up ratio')
2526
args = parser.parse_args()
2627

2728
name = os.path.basename(args.proj).split('.')[0] if not args.title else args.title
@@ -46,6 +47,9 @@
4647
exp
4748
]
4849

50+
if args.speedup > 0:
51+
sys.argv += ['--hparams', f'pndm_speedup={args.speedup}']
52+
4953
if not isinstance(params, list):
5054
params = [params]
5155

utils/hparams.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def load_config(config_fn): # deep first
8989
if args.hparams != "":
9090
for new_hparam in args.hparams.split(","):
9191
k, v = new_hparam.split("=")
92+
if k not in hparams_:
93+
hparams_[k] = eval(v)
9294
if v in ['True', 'False'] or type(hparams_[k]) == bool:
9395
hparams_[k] = eval(v)
9496
else:

0 commit comments

Comments
 (0)