99from utils .audio import save_wav
1010from utils .hparams import set_hparams , hparams
1111
12+ import acoustic .dfs_models as adm
13+
1214import torch
15+ import numpy as np
1316
1417from utils .text_encoder import TokenTextEncoder
1518from usr .diffsinger_task import DIFF_DECODERS
2629]
2730
2831
29- class GaussianDiffusionWrap (GaussianDiffusion ):
30- def forward (self , txt_tokens , mel2ph ,
32+ class GaussianDiffusionWrap (adm . GaussianDiffusionFS ):
33+ def forward (self , txt_tokens ,
3134 # Wrapped Arguments
3235 spk_id ,
3336 pitch_midi ,
3437 midi_dur ,
3538 is_slur ,
39+ mel2ph ,
3640 ):
3741
42+ print (f"txt_tokens: { txt_tokens } " )
43+ print (f"spk_id: { spk_id } " )
44+ print (f"pitch_midi: { pitch_midi } " )
45+ print (f"midi_dur: { midi_dur } " )
46+ print (f"is_slur: { is_slur } " )
47+ print (f"mel2ph: { mel2ph } " )
48+
49+ if (mel2ph [0 ].item () == 0 ):
50+ mel2ph = None
51+ else :
52+ mel2ph = mel2ph [1 ].item ()
53+
3854 if (torch .numel (txt_tokens ) == 0 ):
3955 txt_tokens = None
40- if (torch .numel (mel2ph ) == 0 ):
41- mel2ph = None
4256 if (torch .numel (spk_id ) == 0 ):
4357 spk_id = None
4458 if (torch .numel (pitch_midi ) == 0 ):
@@ -57,7 +71,8 @@ class DFSInferWrapped(e2e.DiffSingerE2EInfer):
5771 def build_model (self ):
5872 model = GaussianDiffusionWrap (
5973 phone_encoder = self .ph_encoder ,
60- out_dims = hparams ['audio_num_mel_bins' ], denoise_fn = DIFF_DECODERS [hparams ['diff_decoder_type' ]](hparams ),
74+ out_dims = hparams ['audio_num_mel_bins' ], denoise_fn = DIFF_DECODERS [hparams ['diff_decoder_type' ]](
75+ hparams ),
6176 timesteps = hparams ['timesteps' ],
6277 K_step = hparams ['K_step' ],
6378 loss_type = hparams ['diff_loss_type' ],
@@ -71,9 +86,33 @@ def build_model(self):
7186 self .pe = PitchExtractor ().to (self .device )
7287 load_ckpt (self .pe , hparams ['pe_ckpt' ], 'model' , strict = True )
7388 self .pe .eval ()
74-
89+
7590 return model
7691
92+
class DFSInferWrapped2(e2e.DiffSingerE2EInfer):
    """Inference wrapper that builds the denoiser-only diffusion model
    (``adm.GaussianDiffusionDenoise``) so the per-step denoise network can
    be exported to ONNX separately from the feed-forward front end.
    """

    def build_model(self):
        """Instantiate the denoise-step model from ``hparams``, load the
        trained checkpoint from ``hparams['work_dir']``, and optionally set
        up the pitch extractor.

        Returns:
            The checkpoint-loaded model, already switched to eval mode.
        """
        denoiser = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
        model = adm.GaussianDiffusionDenoise(
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=denoiser,
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )

        # eval() before load_ckpt is harmless: loading a state dict does not
        # change train/eval mode.
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')

        # NOTE(review): this pe_enable branch mirrors DFSInferWrapped.build_model;
        # presumably self.pe is consumed downstream during inference — confirm.
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            self.pe = PitchExtractor().to(self.device)
            load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

        return model
114+
115+
77116if __name__ == '__main__' :
78117
79118 inp = {
@@ -90,25 +129,43 @@ def build_model(self):
90129 infer_ins = DFSInferWrapped (hparams )
91130 infer_ins .model .to (dev )
92131
132+ infer_ins2 = DFSInferWrapped2 (hparams )
133+ infer_ins2 .model .to (dev )
134+
135+ adm .device = dev
136+
93137 with torch .no_grad ():
94- inp = infer_ins .preprocess_input (inp , input_type = inp ['input_type' ] if inp .get ('input_type' ) else 'word' )
138+ inp = infer_ins .preprocess_input (
139+ inp , input_type = inp ['input_type' ] if inp .get ('input_type' ) else 'word' )
95140 sample = infer_ins .input_to_batch (inp )
96141 txt_tokens = sample ['txt_tokens' ] # [B, T_t]
97142 spk_id = sample .get ('spk_ids' )
98143
144+ print (txt_tokens )
145+ print (spk_id )
146+ print (sample ['pitch_midi' ])
147+ print (sample ['midi_dur' ])
148+ print (sample ['is_slur' ])
149+ print (sample ['mel2ph' ])
150+
99151 torch .onnx .export (
100152 infer_ins .model ,
101153 (
102154 txt_tokens .to (dev ),
103- {
104- 'spk_id' : spk_id .to (dev ),
105- 'pitch_midi' : sample ['pitch_midi' ].to (dev ),
106- 'midi_dur' : sample ['midi_dur' ].to (dev ),
107- 'is_slur' : spk_id .to (dev ),
108- 'mel2ph' : spk_id .to (dev )
109- }
155+ # {
156+ # 'spk_id': spk_id.to(dev),
157+ # 'pitch_midi': sample['pitch_midi'].to(dev),
158+ # 'midi_dur': sample['midi_dur'].to(dev),
159+ # 'is_slur': spk_id.to(dev),
160+ # 'mel2ph': spk_id.to(dev)
161+ # }
162+ spk_id .to (dev ),
163+ sample ['pitch_midi' ].to (dev ),
164+ sample ['midi_dur' ].to (dev ),
165+ sample ['is_slur' ].to (dev ),
166+ torch .from_numpy (np .array ([0 , 0 ]).astype (np .int64 )).to (dev ),
110167 ),
111- "singer .onnx" ,
168+ "singer_fs .onnx" ,
112169 # verbose=True,
113170 input_names = ["txt_tokens" , "spk_id" ,
114171 "pitch_midi" , "midi_dur" , "is_slur" , "mel2ph" ],
@@ -132,10 +189,41 @@ def build_model(self):
132189 "is_slur" : {
133190 0 : "a" ,
134191 1 : "b" ,
192+ }
193+ },
194+ opset_version = 11
195+ )
196+
197+ # fs_res = infer_ins.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
198+ # pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
199+ # is_slur=sample['is_slur'], mel2ph=sample['mel2ph'])
200+ # cond = fs_res.transpose(1, 2)
201+ # shape = (cond.shape[0], 1, infer_ins.model.mel_bins, cond.shape[2])
202+ # x = torch.randn(shape, device=dev)
203+
204+ torch .onnx .export (
205+ infer_ins2 .model ,
206+ (
207+ torch .rand (1 , 1 , 80 , 967 ).to (dev ),
208+ torch .full ((1 ,), 1 , dtype = torch .long ).to (dev ),
209+ torch .rand (1 , 256 , 967 ).to (dev ),
210+ ),
211+ "singer_denoise.onnx" ,
212+ input_names = [
213+ "x" ,
214+ "t" ,
215+ "cond" ,
216+ ],
217+ dynamic_axes = {
218+ "x" : {
219+ 0 : "batch_size" ,
220+ 2 : "num_mel_bin" ,
221+ 3 : "frames" ,
135222 },
136- "mel2ph" : {
137- 0 : "a" ,
138- 1 : "b" ,
223+ "cond" : {
224+ 0 : "batch_size" ,
225+ 1 : "what" ,
226+ 2 : "frames" ,
139227 }
140228 },
141229 opset_version = 11
0 commit comments