Skip to content

Commit 460aee4

Browse files
authored
support ".pth" model
1 parent 50bd624 commit 460aee4

1 file changed

Lines changed: 13 additions & 9 deletions

File tree

inference/svs/base_svs_infer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,19 @@ def forward_model(self, inp):
4747
def build_vocoder(self):
4848
base_dir = hparams['vocoder_ckpt']
4949
config_path = f'{base_dir}/config.yaml'
50-
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
51-
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x.replace('\\','/'))[0]))[-1]
52-
print('| load HifiGAN: ', ckpt)
53-
ckpt_dict = torch.load(ckpt, map_location="cpu")
54-
config = set_hparams(config_path, global_hparams=False)
55-
state = ckpt_dict["state_dict"]["model_gen"]
56-
vocoder = HifiGanGenerator(config)
57-
vocoder.load_state_dict(state, strict=True)
58-
vocoder.remove_weight_norm()
50+
file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
51+
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\','/'))[0]))[-1]
52+
print('| load HifiGAN: ', file_path)
53+
ext = os.path.splitext(file_path)[-1]
54+
if ext == '.pth':
55+
vocoder = torch.load(file_path, map_location="cpu")
56+
elif ext == '.ckpt':
57+
ckpt_dict = torch.load(file_path, map_location="cpu")
58+
config = set_hparams(config_path, global_hparams=False)
59+
state = ckpt_dict["state_dict"]["model_gen"]
60+
vocoder = HifiGanGenerator(config)
61+
vocoder.load_state_dict(state, strict=True)
62+
vocoder.remove_weight_norm()
5963
vocoder = vocoder.eval().to(self.device)
6064
return vocoder
6165

0 commit comments

Comments
 (0)