Skip to content

Commit 3e696e3

Browse files
authored
fix a device recognization bug
1 parent edd444d commit 3e696e3

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

modules/diffsinger_midi/fs2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,20 @@ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
145145
if f0 is not None:
146146
delta_l = nframes - f0.size(1)
147147
if delta_l > 0:
148-
f0 = torch.cat((f0,torch.FloatTensor([[x[-1]] * delta_l for x in f0]).to(f0.get_device())),1)
148+
f0 = torch.cat((f0,torch.FloatTensor([[x[-1]] * delta_l for x in f0]).to(f0.device)),1)
149149
f0 = f0[:,:nframes]
150150
if uv is not None:
151151
delta_l = nframes - uv.size(1)
152152
if delta_l > 0:
153-
uv = torch.cat((uv,torch.FloatTensor([[x[-1]] * delta_l for x in uv]).to(uv.get_device())),1)
153+
uv = torch.cat((uv,torch.FloatTensor([[x[-1]] * delta_l for x in uv]).to(uv.device)),1)
154154
uv = uv[:,:nframes]
155155
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
156156

157157
if hparams['use_energy_embed']:
158158
if energy is not None:
159159
delta_l = nframes - energy.size(1)
160160
if delta_l > 0:
161-
energy = torch.cat((energy,torch.FloatTensor([[x[-1]] * delta_l for x in energy]).to(energy.get_device())),1)
161+
energy = torch.cat((energy,torch.FloatTensor([[x[-1]] * delta_l for x in energy]).to(energy.device)),1)
162162
energy = energy[:,:nframes]
163163
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
164164

0 commit comments

Comments
 (0)