|
8 | 8 | import torch |
9 | 9 |
|
10 | 10 | from crossfade import cross_fade |
| 11 | +from inference.svs.ds_cascade import DiffSingerCascadeInfer |
11 | 12 | from inference.svs.ds_e2e import DiffSingerE2EInfer |
12 | 13 | from utils.audio import save_wav |
13 | 14 | from utils.hparams import set_hparams, hparams |
|
20 | 21 | parser.add_argument('--exp', type=str, required=False, help='Selection of model') |
21 | 22 | parser.add_argument('--out', type=str, required=False, help='Path of the output folder') |
22 | 23 | parser.add_argument('--title', type=str, required=False, help='Title of output file') |
23 | | -parser.add_argument('--num', type=int, default=1, help='Number of runs') |
24 | | -parser.add_argument('--seed', type=int, help='Random seed of the inference') |
25 | | -parser.add_argument('--speedup', type=int, default=0, help='PNDM speed-up ratio') |
| 24 | +parser.add_argument('--num', type=int, required=False, default=1, help='Number of runs') |
| 25 | +parser.add_argument('--seed', type=int, required=False, help='Random seed of the inference') |
| 26 | +parser.add_argument('--speedup', type=int, required=False, default=0, help='PNDM speed-up ratio') |
| 27 | +parser.add_argument('--pitch', action='store_true', required=False, default=False, help='Enable manual pitch mode') |
26 | 28 | args = parser.parse_args() |
27 | 29 |
|
28 | 30 | name = os.path.basename(args.proj).split('.')[0] if not args.title else args.title |
29 | 31 | exp = args.exp |
30 | 32 | if not exp: |
31 | | - if os.path.exists(os.path.join(root_dir, 'checkpoints/0814_opencpop_ds_rhythm_fix')): |
| 33 | + if args.pitch: |
| 34 | + exp = '0909_opencpop_ds100_pitchcontrol' |
| 35 | + elif os.path.exists(os.path.join(root_dir, 'checkpoints/0814_opencpop_ds_rhythm_fix')): |
32 | 36 | exp = '0814_opencpop_ds_rhythm_fix' |
33 | 37 | else: |
34 | 38 | exp = '0814_opencpop_500k(修复无参音素)' |
35 | 39 | out = args.out |
36 | 40 | if not out: |
37 | 41 | out = os.path.dirname(os.path.abspath(args.proj)) |
38 | 42 |
|
39 | | -with open(args.proj, 'r', encoding='utf-8') as f: |
40 | | - params = json.load(f) |
41 | | - |
42 | 43 | sys.argv = [ |
43 | | - f'{root_dir}/inference/svs/ds_e2e.py', |
| 44 | + f'{root_dir}/inference/svs/ds_e2e.py' if not args.pitch else f'{root_dir}/inference/svs/ds_cascade.py', |
44 | 45 | '--config', |
45 | | - f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml', |
| 46 | + f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml' if not args.pitch else f'{root_dir}/usr/configs/midi/cascade/opencs/ds100_rel.yaml', |
46 | 47 | '--exp_name', |
47 | 48 | exp |
48 | 49 | ] |
49 | 50 |
|
50 | 51 | if args.speedup > 0: |
51 | 52 | sys.argv += ['--hparams', f'pndm_speedup={args.speedup}'] |
52 | 53 |
|
| 54 | +with open(args.proj, 'r', encoding='utf-8') as f: |
| 55 | + params = json.load(f) |
| 56 | + |
53 | 57 | if not isinstance(params, list): |
54 | 58 | params = [params] |
55 | 59 |
|
|
58 | 62 |
|
59 | 63 | infer_ins = None |
60 | 64 | if len(params) > 0: |
61 | | - infer_ins = DiffSingerE2EInfer(hparams) |
| 65 | + infer_ins = DiffSingerE2EInfer(hparams) if not args.pitch else DiffSingerCascadeInfer(hparams) |
62 | 66 |
|
63 | 67 |
|
64 | 68 | def infer_once(path: str): |
|
0 commit comments