@@ -47,15 +47,19 @@ def forward_model(self, inp):
4747 def build_vocoder (self ):
4848 base_dir = hparams ['vocoder_ckpt' ]
4949 config_path = f'{ base_dir } /config.yaml'
50- ckpt = sorted (glob .glob (f'{ base_dir } /model_ckpt_steps_*.ckpt' ), key =
51- lambda x : int (re .findall (f'{ base_dir } /model_ckpt_steps_(\d+).ckpt' , x .replace ('\\ ' ,'/' ))[0 ]))[- 1 ]
52- print ('| load HifiGAN: ' , ckpt )
53- ckpt_dict = torch .load (ckpt , map_location = "cpu" )
54- config = set_hparams (config_path , global_hparams = False )
55- state = ckpt_dict ["state_dict" ]["model_gen" ]
56- vocoder = HifiGanGenerator (config )
57- vocoder .load_state_dict (state , strict = True )
58- vocoder .remove_weight_norm ()
50+ file_path = sorted (glob .glob (f'{ base_dir } /model_ckpt_steps_*.*' ), key =
51+ lambda x : int (re .findall (f'{ base_dir } /model_ckpt_steps_(\d+).*' , x .replace ('\\ ' ,'/' ))[0 ]))[- 1 ]
52+ print ('| load HifiGAN: ' , file_path )
53+ ext = os .path .splitext (file_path )[- 1 ]
54+ if ext == '.pth' :
55+ vocoder = torch .load (file_path , map_location = "cpu" )
56+ elif ext == '.ckpt' :
57+ ckpt_dict = torch .load (file_path , map_location = "cpu" )
58+ config = set_hparams (config_path , global_hparams = False )
59+ state = ckpt_dict ["state_dict" ]["model_gen" ]
60+ vocoder = HifiGanGenerator (config )
61+ vocoder .load_state_dict (state , strict = True )
62+ vocoder .remove_weight_norm ()
5963 vocoder = vocoder .eval ().to (self .device )
6064 return vocoder
6165
0 commit comments