1414from vocoders .vocoder_utils import denoise
1515
1616
17- def load_model (config_path , checkpoint_path ):
17+ def load_model (config_path , file_path ):
1818 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
19- ckpt_dict = torch .load (checkpoint_path , map_location = "cpu" )
20- if '.yaml' in config_path :
21- config = set_hparams (config_path , global_hparams = False )
22- state = ckpt_dict ["state_dict" ]["model_gen" ]
23- elif '.json' in config_path :
24- config = json .load (open (config_path , 'r' , encoding = 'utf-8' ))
25- state = ckpt_dict ["generator" ]
26-
27- model = HifiGanGenerator (config )
28- model .load_state_dict (state , strict = True )
29- model .remove_weight_norm ()
19+ ext = os .path .splitext (file_path )[- 1 ]
20+ if ext == '.pth' :
21+ if '.yaml' in config_path :
22+ config = set_hparams (config_path , global_hparams = False )
23+ elif '.json' in config_path :
24+ config = json .load (open (config_path , 'r' , encoding = 'utf-8' ))
25+ model = torch .load (file_path , map_location = "cpu" )
26+ elif ext == '.ckpt' :
27+ ckpt_dict = torch .load (file_path , map_location = "cpu" )
28+ if '.yaml' in config_path :
29+ config = set_hparams (config_path , global_hparams = False )
30+ state = ckpt_dict ["state_dict" ]["model_gen" ]
31+ elif '.json' in config_path :
32+ config = json .load (open (config_path , 'r' , encoding = 'utf-8' ))
33+ state = ckpt_dict ["generator" ]
34+ model = HifiGanGenerator (config )
35+ model .load_state_dict (state , strict = True )
36+ model .remove_weight_norm ()
3037 model = model .eval ().to (device )
31- print (f"| Loaded model parameters from { checkpoint_path } ." )
38+ print (f"| Loaded model parameters from { file_path } ." )
3239 print (f"| HifiGAN device: { device } ." )
3340 return model , config , device
3441
@@ -42,15 +49,15 @@ def __init__(self):
4249 base_dir = hparams ['vocoder_ckpt' ]
4350 config_path = f'{ base_dir } /config.yaml'
4451 if os .path .exists (config_path ):
45- ckpt = sorted (glob .glob (f'{ base_dir } /model_ckpt_steps_*.ckpt ' ), key =
46- lambda x : int (re .findall (f'{ base_dir } /model_ckpt_steps_(\d+).ckpt ' , x .replace ('\\ ' ,'/' ))[0 ]))[- 1 ]
47- print ('| load HifiGAN: ' , ckpt )
48- self .model , self .config , self .device = load_model (config_path = config_path , checkpoint_path = ckpt )
52+ file_path = sorted (glob .glob (f'{ base_dir } /model_ckpt_steps_*.* ' ), key =
53+ lambda x : int (re .findall (f'{ base_dir } /model_ckpt_steps_(\d+).* ' , x .replace ('\\ ' ,'/' ))[0 ]))[- 1 ]
54+ print ('| load HifiGAN: ' , file_path )
55+ self .model , self .config , self .device = load_model (config_path = config_path , file_path = file_path )
4956 else :
5057 config_path = f'{ base_dir } /config.json'
5158 ckpt = f'{ base_dir } /generator_v1'
5259 if os .path .exists (config_path ):
53- self .model , self .config , self .device = load_model (config_path = config_path , checkpoint_path = ckpt )
60+ self .model , self .config , self .device = load_model (config_path = config_path , file_path = file_path )
5461
5562 def spec2wav (self , mel , ** kwargs ):
5663 device = self .device
0 commit comments