Skip to content

Commit c08f608

Browse files
authored
support ".pth" model
1 parent 460aee4 commit c08f608

1 file changed

Lines changed: 24 additions & 17 deletions

File tree

vocoders/hifigan.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,26 @@
1616

1717
def load_model(config_path, checkpoint_path):
1818
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19-
ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
20-
if '.yaml' in config_path:
21-
config = set_hparams(config_path, global_hparams=False)
22-
state = ckpt_dict["state_dict"]["model_gen"]
23-
elif '.json' in config_path:
24-
config = json.load(open(config_path, 'r', encoding='utf-8'))
25-
state = ckpt_dict["generator"]
26-
27-
model = HifiGanGenerator(config)
28-
model.load_state_dict(state, strict=True)
29-
model.remove_weight_norm()
19+
ext = os.path.splitext(file_path)[-1]
20+
if ext == '.pth':
21+
if '.yaml' in config_path:
22+
config = set_hparams(config_path, global_hparams=False)
23+
elif '.json' in config_path:
24+
config = json.load(open(config_path, 'r', encoding='utf-8'))
25+
model = torch.load(file_path, map_location="cpu")
26+
elif ext == '.ckpt':
27+
ckpt_dict = torch.load(file_path, map_location="cpu")
28+
if '.yaml' in config_path:
29+
config = set_hparams(config_path, global_hparams=False)
30+
state = ckpt_dict["state_dict"]["model_gen"]
31+
elif '.json' in config_path:
32+
config = json.load(open(config_path, 'r', encoding='utf-8'))
33+
state = ckpt_dict["generator"]
34+
model = HifiGanGenerator(config)
35+
model.load_state_dict(state, strict=True)
36+
model.remove_weight_norm()
3037
model = model.eval().to(device)
31-
print(f"| Loaded model parameters from {checkpoint_path}.")
38+
print(f"| Loaded model parameters from {file_path}.")
3239
print(f"| HifiGAN device: {device}.")
3340
return model, config, device
3441

@@ -42,15 +49,15 @@ def __init__(self):
4249
base_dir = hparams['vocoder_ckpt']
4350
config_path = f'{base_dir}/config.yaml'
4451
if os.path.exists(config_path):
45-
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
46-
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x.replace('\\','/'))[0]))[-1]
47-
print('| load HifiGAN: ', ckpt)
48-
self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
52+
file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
53+
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\','/'))[0]))[-1]
54+
print('| load HifiGAN: ', file_path)
55+
self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
4956
else:
5057
config_path = f'{base_dir}/config.json'
5158
ckpt = f'{base_dir}/generator_v1'
5259
if os.path.exists(config_path):
53-
self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
60+
self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
5461

5562
def spec2wav(self, mel, **kwargs):
5663
device = self.device

0 commit comments

Comments
 (0)