Skip to content

Commit ee4ffef

Browse files
Merge branch 'openvpi:master' into master
2 parents e6be5bf + 6090ecb commit ee4ffef

5 files changed

Lines changed: 41 additions & 28 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ __pycache__/
44
*.sh
55
local_tools/
66
*.ckpt
7+
*.pth
78
*.wav
89
infer_out/
910
config.yaml

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
33
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
44
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
5+
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kfmZ6Y018c5trSwQAbhdQtZ7Il8W_4BU)
56
| [Interactive🤗 TTS](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
67
| [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)
78

inference/svs/base_svs_infer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,19 @@ def forward_model(self, inp):
4747
def build_vocoder(self):
4848
base_dir = hparams['vocoder_ckpt']
4949
config_path = f'{base_dir}/config.yaml'
50-
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
51-
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x.replace('\\','/'))[0]))[-1]
52-
print('| load HifiGAN: ', ckpt)
53-
ckpt_dict = torch.load(ckpt, map_location="cpu")
54-
config = set_hparams(config_path, global_hparams=False)
55-
state = ckpt_dict["state_dict"]["model_gen"]
56-
vocoder = HifiGanGenerator(config)
57-
vocoder.load_state_dict(state, strict=True)
58-
vocoder.remove_weight_norm()
50+
file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
51+
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\','/'))[0]))[-1]
52+
print('| load HifiGAN: ', file_path)
53+
ext = os.path.splitext(file_path)[-1]
54+
if ext == '.pth':
55+
vocoder = torch.load(file_path, map_location="cpu")
56+
elif ext == '.ckpt':
57+
ckpt_dict = torch.load(file_path, map_location="cpu")
58+
config = set_hparams(config_path, global_hparams=False)
59+
state = ckpt_dict["state_dict"]["model_gen"]
60+
vocoder = HifiGanGenerator(config)
61+
vocoder.load_state_dict(state, strict=True)
62+
vocoder.remove_weight_norm()
5963
vocoder = vocoder.eval().to(self.device)
6064
return vocoder
6165

tasks/base_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def start(cls):
251251
t = datetime.now().strftime('%Y%m%d%H%M%S')
252252
code_dir = f'{work_dir}/codes/{t}'
253253
# TODO: test filesystem calls
254-
os.mkdir(code_dir)
254+
os.makedirs(code_dir, exist_ok=True)
255255
# subprocess.check_call(f'mkdir "{code_dir}"', shell=True)
256256
for c in hparams['save_codes']:
257257
shutil.copytree(c, code_dir, dirs_exist_ok=True)

vocoders/hifigan.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,28 @@
1414
from vocoders.vocoder_utils import denoise
1515

1616

17-
def load_model(config_path, checkpoint_path):
17+
def load_model(config_path, file_path):
1818
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19-
ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
20-
if '.yaml' in config_path:
21-
config = set_hparams(config_path, global_hparams=False)
22-
state = ckpt_dict["state_dict"]["model_gen"]
23-
elif '.json' in config_path:
24-
config = json.load(open(config_path, 'r', encoding='utf-8'))
25-
state = ckpt_dict["generator"]
26-
27-
model = HifiGanGenerator(config)
28-
model.load_state_dict(state, strict=True)
29-
model.remove_weight_norm()
19+
ext = os.path.splitext(file_path)[-1]
20+
if ext == '.pth':
21+
if '.yaml' in config_path:
22+
config = set_hparams(config_path, global_hparams=False)
23+
elif '.json' in config_path:
24+
config = json.load(open(config_path, 'r', encoding='utf-8'))
25+
model = torch.load(file_path, map_location="cpu")
26+
elif ext == '.ckpt':
27+
ckpt_dict = torch.load(file_path, map_location="cpu")
28+
if '.yaml' in config_path:
29+
config = set_hparams(config_path, global_hparams=False)
30+
state = ckpt_dict["state_dict"]["model_gen"]
31+
elif '.json' in config_path:
32+
config = json.load(open(config_path, 'r', encoding='utf-8'))
33+
state = ckpt_dict["generator"]
34+
model = HifiGanGenerator(config)
35+
model.load_state_dict(state, strict=True)
36+
model.remove_weight_norm()
3037
model = model.eval().to(device)
31-
print(f"| Loaded model parameters from {checkpoint_path}.")
38+
print(f"| Loaded model parameters from {file_path}.")
3239
print(f"| HifiGAN device: {device}.")
3340
return model, config, device
3441

@@ -42,15 +49,15 @@ def __init__(self):
4249
base_dir = hparams['vocoder_ckpt']
4350
config_path = f'{base_dir}/config.yaml'
4451
if os.path.exists(config_path):
45-
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
46-
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x.replace('\\','/'))[0]))[-1]
47-
print('| load HifiGAN: ', ckpt)
48-
self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
52+
file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
53+
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\','/'))[0]))[-1]
54+
print('| load HifiGAN: ', file_path)
55+
self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
4956
else:
5057
config_path = f'{base_dir}/config.json'
5158
ckpt = f'{base_dir}/generator_v1'
5259
if os.path.exists(config_path):
53-
self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
60+
self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
5461

5562
def spec2wav(self, mel, **kwargs):
5663
device = self.device

0 commit comments

Comments
 (0)