Skip to content

Commit c37c561

Browse files
committed
replace torch.chunk with torch.split
1 parent 75a6828 commit c37c561

5 files changed

Lines changed: 13 additions & 8 deletions

File tree

acoustic/dfs_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
1616
class GaussianDiffusionDenoise(GaussianDiffusion):
1717
def forward(self, x, t, cond):
1818
x = self.p_sample(x, t, cond)
19-
return [x, cond]
19+
return x

onnx_export_singer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def forward(self, txt_tokens,
5050
mel2ph = None
5151
else:
5252
mel2ph = mel2ph[1].item()
53-
53+
5454
if (torch.numel(txt_tokens) == 0):
5555
txt_tokens = None
5656
if (torch.numel(spk_id) == 0):
@@ -166,7 +166,7 @@ def build_model(self):
166166
torch.from_numpy(np.array([0, 0]).astype(np.int64)).to(dev),
167167
),
168168
"singer_fs.onnx",
169-
# verbose=True,
169+
verbose=True,
170170
input_names=["txt_tokens", "spk_id",
171171
"pitch_midi", "midi_dur", "is_slur", "mel2ph"],
172172
dynamic_axes={
@@ -209,6 +209,7 @@ def build_model(self):
209209
torch.rand(1, 256, 967).to(dev),
210210
),
211211
"singer_denoise.onnx",
212+
verbose=True,
212213
input_names=[
213214
"x",
214215
"t",
@@ -224,7 +225,7 @@ def build_model(self):
224225
0: "batch_size",
225226
1: "what",
226227
2: "frames",
227-
}
228+
},
228229
},
229230
opset_version=11
230231
)

onnx_test_singer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def forward_model(self, inp):
142142
}
143143
)
144144
x = torch.from_numpy(res2[0])
145-
cond = torch.from_numpy(res2[1])
146145

147146
x = x[:, 0].transpose(1, 2)
148147

onnx_test_singer_gpu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def forward_model(self, inp):
142142
}
143143
)
144144
x = torch.from_numpy(res2[0])
145-
cond = torch.from_numpy(res2[1])
146145

147146
x = x[:, 0].transpose(1, 2)
148147

usr/diff/net.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,17 @@ def forward(self, x, conditioner, diffusion_step):
7070

7171
y = self.dilated_conv(y) + conditioner
7272

73-
gate, filter = torch.chunk(y, 2, dim=1)
73+
# gate, filter = torch.chunk(y, 2, dim=1)
74+
# Using torch.split instead of torch.chunk to avoid using onnx::Slice
75+
gate, filter = torch.split(y, torch.div(y.shape[1], 2), dim=1)
76+
7477
y = torch.sigmoid(gate) * torch.tanh(filter)
7578

7679
y = self.output_projection(y)
77-
residual, skip = torch.chunk(y, 2, dim=1)
80+
# residual, skip = torch.chunk(y, 2, dim=1)
81+
# Using torch.split instead of torch.chunk to avoid using onnx::Slice
82+
residual, skip = torch.split(y, torch.div(y.shape[1], 2), dim=1)
83+
7884
return (x + residual) / sqrt(2.0), skip
7985

8086

0 commit comments

Comments
 (0)