Skip to content

Commit 7d52da7

Browse files
committed
support pitch control in main.py
1 parent 3e696e3 commit 7d52da7

1 file changed

Lines changed: 14 additions & 10 deletions

File tree

main.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99

1010
from crossfade import cross_fade
11+
from inference.svs.ds_cascade import DiffSingerCascadeInfer
1112
from inference.svs.ds_e2e import DiffSingerE2EInfer
1213
from utils.audio import save_wav
1314
from utils.hparams import set_hparams, hparams
@@ -20,36 +21,39 @@
2021
parser.add_argument('--exp', type=str, required=False, help='Selection of model')
2122
parser.add_argument('--out', type=str, required=False, help='Path of the output folder')
2223
parser.add_argument('--title', type=str, required=False, help='Title of output file')
23-
parser.add_argument('--num', type=int, default=1, help='Number of runs')
24-
parser.add_argument('--seed', type=int, help='Random seed of the inference')
25-
parser.add_argument('--speedup', type=int, default=0, help='PNDM speed-up ratio')
24+
parser.add_argument('--num', type=int, required=False, default=1, help='Number of runs')
25+
parser.add_argument('--seed', type=int, required=False, help='Random seed of the inference')
26+
parser.add_argument('--speedup', type=int, required=False, default=0, help='PNDM speed-up ratio')
27+
parser.add_argument('--pitch', action='store_true', required=False, default=False, help='Enable manual pitch mode')
2628
args = parser.parse_args()
2729

2830
name = os.path.basename(args.proj).split('.')[0] if not args.title else args.title
2931
exp = args.exp
3032
if not exp:
31-
if os.path.exists(os.path.join(root_dir, 'checkpoints/0814_opencpop_ds_rhythm_fix')):
33+
if args.pitch:
34+
exp = '0909_opencpop_ds100_pitchcontrol'
35+
elif os.path.exists(os.path.join(root_dir, 'checkpoints/0814_opencpop_ds_rhythm_fix')):
3236
exp = '0814_opencpop_ds_rhythm_fix'
3337
else:
3438
exp = '0814_opencpop_500k(修复无参音素)'
3539
out = args.out
3640
if not out:
3741
out = os.path.dirname(os.path.abspath(args.proj))
3842

39-
with open(args.proj, 'r', encoding='utf-8') as f:
40-
params = json.load(f)
41-
4243
sys.argv = [
43-
f'{root_dir}/inference/svs/ds_e2e.py',
44+
f'{root_dir}/inference/svs/ds_e2e.py' if not args.pitch else f'{root_dir}/inference/svs/ds_cascade.py',
4445
'--config',
45-
f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
46+
f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml' if not args.pitch else f'{root_dir}/usr/configs/midi/cascade/opencs/ds100_rel.yaml',
4647
'--exp_name',
4748
exp
4849
]
4950

5051
if args.speedup > 0:
5152
sys.argv += ['--hparams', f'pndm_speedup={args.speedup}']
5253

54+
with open(args.proj, 'r', encoding='utf-8') as f:
55+
params = json.load(f)
56+
5357
if not isinstance(params, list):
5458
params = [params]
5559

@@ -58,7 +62,7 @@
5862

5963
infer_ins = None
6064
if len(params) > 0:
61-
infer_ins = DiffSingerE2EInfer(hparams)
65+
infer_ins = DiffSingerE2EInfer(hparams) if not args.pitch else DiffSingerCascadeInfer(hparams)
6266

6367

6468
def infer_once(path: str):

0 commit comments

Comments
 (0)