|
| 1 | +# coding=utf8 |
| 2 | + |
| 3 | +import os |
| 4 | +import sys |
| 5 | +import inference.svs.ds_e2e as e2e |
| 6 | +from modules.fastspeech.pe import PitchExtractor |
| 7 | +from usr.diff.shallow_diffusion_tts import GaussianDiffusion |
| 8 | +from utils import load_ckpt |
| 9 | +from utils.audio import save_wav |
| 10 | +from utils.hparams import set_hparams, hparams |
| 11 | + |
| 12 | +import torch |
| 13 | + |
| 14 | +from utils.text_encoder import TokenTextEncoder |
| 15 | +from usr.diffsinger_task import DIFF_DECODERS |
| 16 | + |
# Make the project root visible to child processes and mimic the CLI
# invocation that utils.hparams.set_hparams() later parses from sys.argv.
root_dir = os.path.dirname(os.path.abspath(__file__))
# NOTE: assigning PYTHONPATH here only affects subprocesses, not this
# interpreter's sys.path. The original wrapped the path in literal double
# quotes (f'"{root_dir}"'), which corrupts the search path — PYTHONPATH
# entries are not shell-quoted. Export the bare path instead.
os.environ['PYTHONPATH'] = root_dir

# Fake command-line arguments so set_hparams() loads the e2e OpenCPOP
# config and the pretrained experiment name.
sys.argv = [
    f'{root_dir}/inference/svs/ds_e2e.py',
    '--config',
    f'{root_dir}/usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
    '--exp_name',
    '0228_opencpop_ds100_rel'
]
| 27 | + |
| 28 | + |
class GaussianDiffusionWrap(GaussianDiffusion):
    """ONNX-export-friendly wrapper around GaussianDiffusion.

    torch.onnx.export cannot trace ``None`` tensor inputs, so callers pass
    empty tensors as stand-ins for absent arguments. This wrapper maps each
    empty tensor back to ``None`` and then delegates to the parent
    ``forward`` in inference mode with no reference mels.
    """

    def forward(self, txt_tokens, mel2ph,
                # Wrapped Arguments
                spk_id,
                pitch_midi,
                midi_dur,
                is_slur,
                ):
        def _or_none(tensor):
            # Empty tensor == "argument absent" in the export convention.
            return None if torch.numel(tensor) == 0 else tensor

        txt_tokens = _or_none(txt_tokens)
        mel2ph = _or_none(mel2ph)
        spk_id = _or_none(spk_id)
        pitch_midi = _or_none(pitch_midi)
        midi_dur = _or_none(midi_dur)
        is_slur = _or_none(is_slur)

        return super().forward(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
                               pitch_midi=pitch_midi, midi_dur=midi_dur,
                               is_slur=is_slur, mel2ph=mel2ph)
| 54 | + |
| 55 | + |
class DFSInferWrapped(e2e.DiffSingerE2EInfer):
    """DiffSingerE2EInfer variant that builds the ONNX-wrappable model."""

    def build_model(self):
        """Construct the wrapped diffusion model from hparams, load its
        checkpoint, and optionally attach a pitch extractor."""
        denoiser = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
        model = GaussianDiffusionWrap(
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=denoiser,
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')

        # Optional neural pitch extractor (used downstream when enabled).
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().to(self.device)
            load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

        return model
| 76 | + |
if __name__ == '__main__':

    # User input: Chinese lyric characters plus per-note pitch and duration.
    inp = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
        'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
        'input_type': 'word'
    }  # user input: Chinese characters

    set_hparams(print_hparams=False)

    # Fall back to CPU when no GPU is available so the export still runs.
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    infer_ins = DFSInferWrapped(hparams)
    infer_ins.model.to(dev)

    with torch.no_grad():
        # `x or 'word'` matches the original default for a missing/falsy key.
        inp = infer_ins.preprocess_input(inp, input_type=inp.get('input_type') or 'word')
        sample = infer_ins.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')

        # BUG FIX: the original fed `spk_id` in place of both `is_slur` and
        # `mel2ph`. `is_slur` comes from the preprocessed sample; `mel2ph`
        # is predicted by the model at inference time, so we pass an empty
        # tensor, which GaussianDiffusionWrap maps back to None.
        torch.onnx.export(
            infer_ins.model,
            (
                txt_tokens.to(dev),
                {
                    'spk_id': spk_id.to(dev),
                    'pitch_midi': sample['pitch_midi'].to(dev),
                    'midi_dur': sample['midi_dur'].to(dev),
                    'is_slur': sample['is_slur'].to(dev),
                    'mel2ph': torch.LongTensor([]).to(dev),
                }
            ),
            "singer.onnx",
            # verbose=True,
            input_names=["txt_tokens", "spk_id",
                         "pitch_midi", "midi_dur", "is_slur", "mel2ph"],
            # Every input gets the same two dynamic axes (batch, time).
            dynamic_axes={
                name: {0: "a", 1: "b"}
                for name in ("txt_tokens", "spk_id", "pitch_midi",
                             "midi_dur", "is_slur", "mel2ph")
            },
            opset_version=11
        )

    print("OK")