Skip to content

Commit 19355a9

Browse files
committed
add export singer model onnx
1 parent c8835da commit 19355a9

3 files changed

Lines changed: 319 additions & 18 deletions

File tree

acoustic/dfs_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from usr.diff.shallow_diffusion_tts import GaussianDiffusion
2+
3+
import torch
4+
5+
# Target device for this module. NOTE(review): this is module-level mutable
# config — the export script reassigns it (`adm.device = dev`), so readers
# must access it as `adm.device`, not via `from ... import device`; confirm
# all call sites do so.
device = 'cpu'
class GaussianDiffusionFS(GaussianDiffusion):
    """Front half of the diffusion singer for separate ONNX export.

    Runs the wrapped FS2 module with ``skip_decoder=True`` and returns only
    the ``'decoder_inp'`` tensor from its output dict — presumably the
    conditioning input for the diffusion denoiser (confirm against
    ``GaussianDiffusion``'s own forward).
    """

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
        # Delegate to FS2 exactly as the parent does, but stop before the
        # mel decoder so the exported graph ends at the conditioning tensor.
        fs2_out = self.fs2(
            txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
            skip_decoder=True, infer=infer, **kwargs,
        )
        return fs2_out['decoder_inp']
16+
class GaussianDiffusionDenoise(GaussianDiffusion):
    """Single denoising step of the diffusion model, shaped for ONNX export.

    Applies ``self.p_sample`` once and echoes ``cond`` back alongside the
    result, so the exported graph can be driven step-by-step from outside.
    """

    def forward(self, x, t, cond):
        denoised = self.p_sample(x, t, cond)
        # Return cond unchanged as a second output so the caller can feed it
        # straight into the next step without re-computing it.
        return [denoised, cond]

onnx_export_singer.py

Lines changed: 108 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from utils.audio import save_wav
1010
from utils.hparams import set_hparams, hparams
1111

12+
import acoustic.dfs_models as adm
13+
1214
import torch
15+
import numpy as np
1316

1417
from utils.text_encoder import TokenTextEncoder
1518
from usr.diffsinger_task import DIFF_DECODERS
@@ -26,19 +29,30 @@
2629
]
2730

2831

29-
class GaussianDiffusionWrap(GaussianDiffusion):
30-
def forward(self, txt_tokens, mel2ph,
32+
class GaussianDiffusionWrap(adm.GaussianDiffusionFS):
33+
def forward(self, txt_tokens,
3134
# Wrapped Arguments
3235
spk_id,
3336
pitch_midi,
3437
midi_dur,
3538
is_slur,
39+
mel2ph,
3640
):
3741

42+
print(f"txt_tokens: {txt_tokens}")
43+
print(f"spk_id: {spk_id}")
44+
print(f"pitch_midi: {pitch_midi}")
45+
print(f"midi_dur: {midi_dur}")
46+
print(f"is_slur: {is_slur}")
47+
print(f"mel2ph: {mel2ph}")
48+
49+
if (mel2ph[0].item() == 0):
50+
mel2ph = None
51+
else:
52+
mel2ph = mel2ph[1].item()
53+
3854
if (torch.numel(txt_tokens) == 0):
3955
txt_tokens = None
40-
if (torch.numel(mel2ph) == 0):
41-
mel2ph = None
4256
if (torch.numel(spk_id) == 0):
4357
spk_id = None
4458
if (torch.numel(pitch_midi) == 0):
@@ -57,7 +71,8 @@ class DFSInferWrapped(e2e.DiffSingerE2EInfer):
5771
def build_model(self):
5872
model = GaussianDiffusionWrap(
5973
phone_encoder=self.ph_encoder,
60-
out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
74+
out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](
75+
hparams),
6176
timesteps=hparams['timesteps'],
6277
K_step=hparams['K_step'],
6378
loss_type=hparams['diff_loss_type'],
@@ -71,9 +86,33 @@ def build_model(self):
7186
self.pe = PitchExtractor().to(self.device)
7287
load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
7388
self.pe.eval()
74-
89+
7590
return model
7691

92+
93+
class DFSInferWrapped2(e2e.DiffSingerE2EInfer):
    """Inference wrapper whose model is the single-step denoiser.

    Mirrors the other wrapper's build path but instantiates
    ``adm.GaussianDiffusionDenoise`` so the reverse-diffusion step can be
    exported to ONNX on its own.
    """

    def build_model(self):
        """Build the denoiser from ``hparams``, load its checkpoint, and
        optionally set up the pitch extractor (``self.pe``).

        Returns the model in eval mode.
        """
        denoise_fn = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
        model = adm.GaussianDiffusionDenoise(
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=denoise_fn,
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')

        # Optional neural pitch extractor, loaded from its own checkpoint.
        pe_flag = hparams.get('pe_enable')
        if pe_flag is not None and pe_flag:
            self.pe = PitchExtractor().to(self.device)
            load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

        return model
114+
115+
77116
if __name__ == '__main__':
78117

79118
inp = {
@@ -84,31 +123,51 @@ def build_model(self):
84123
} # user input: Chinese characters
85124

86125
set_hparams(print_hparams=False)
126+
spec_min= torch.FloatTensor(hparams['spec_min'])[None, None, :hparams['keep_bins']]
127+
spec_max= torch.FloatTensor(hparams['spec_max'])[None, None, :hparams['keep_bins']]
87128

88129
dev = 'cuda'
89130

90131
infer_ins = DFSInferWrapped(hparams)
91132
infer_ins.model.to(dev)
92133

134+
infer_ins2 = DFSInferWrapped2(hparams)
135+
infer_ins2.model.to(dev)
136+
137+
adm.device = dev
138+
93139
with torch.no_grad():
94-
inp = infer_ins.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
140+
inp = infer_ins.preprocess_input(
141+
inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
95142
sample = infer_ins.input_to_batch(inp)
96143
txt_tokens = sample['txt_tokens'] # [B, T_t]
97144
spk_id = sample.get('spk_ids')
98145

146+
print(txt_tokens)
147+
print(spk_id)
148+
print(sample['pitch_midi'])
149+
print(sample['midi_dur'])
150+
print(sample['is_slur'])
151+
print(sample['mel2ph'])
152+
99153
torch.onnx.export(
100154
infer_ins.model,
101155
(
102156
txt_tokens.to(dev),
103-
{
104-
'spk_id': spk_id.to(dev),
105-
'pitch_midi': sample['pitch_midi'].to(dev),
106-
'midi_dur': sample['midi_dur'].to(dev),
107-
'is_slur': spk_id.to(dev),
108-
'mel2ph': spk_id.to(dev)
109-
}
157+
# {
158+
# 'spk_id': spk_id.to(dev),
159+
# 'pitch_midi': sample['pitch_midi'].to(dev),
160+
# 'midi_dur': sample['midi_dur'].to(dev),
161+
# 'is_slur': spk_id.to(dev),
162+
# 'mel2ph': spk_id.to(dev)
163+
# }
164+
spk_id.to(dev),
165+
sample['pitch_midi'].to(dev),
166+
sample['midi_dur'].to(dev),
167+
sample['is_slur'].to(dev),
168+
torch.from_numpy(np.array([0, 0]).astype(np.int64)).to(dev),
110169
),
111-
"singer.onnx",
170+
"singer_fs.onnx",
112171
# verbose=True,
113172
input_names=["txt_tokens", "spk_id",
114173
"pitch_midi", "midi_dur", "is_slur", "mel2ph"],
@@ -132,10 +191,41 @@ def build_model(self):
132191
"is_slur": {
133192
0: "a",
134193
1: "b",
194+
}
195+
},
196+
opset_version=11
197+
)
198+
199+
# fs_res = infer_ins.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
200+
# pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
201+
# is_slur=sample['is_slur'], mel2ph=sample['mel2ph'])
202+
# cond = fs_res.transpose(1, 2)
203+
# shape = (cond.shape[0], 1, infer_ins.model.mel_bins, cond.shape[2])
204+
# x = torch.randn(shape, device=dev)
205+
206+
torch.onnx.export(
207+
infer_ins2.model,
208+
(
209+
torch.rand(1, 1, 80, 967).to(dev),
210+
torch.full((1,), 1, dtype=torch.long).to(dev),
211+
torch.rand(1, 256, 967).to(dev),
212+
),
213+
"singer_denoise.onnx",
214+
input_names=[
215+
"x",
216+
"t",
217+
"cond",
218+
],
219+
dynamic_axes={
220+
"x": {
221+
0: "batch_size",
222+
2: "num_mel_bin",
223+
3: "frames",
135224
},
136-
"mel2ph": {
137-
0: "a",
138-
1: "b",
225+
"cond": {
226+
0: "batch_size",
227+
1: "what",
228+
2: "frames",
139229
}
140230
},
141231
opset_version=11

0 commit comments

Comments (0)