@@ -489,10 +489,12 @@ def main():
     logger.info('Loading InternVLChatModel...')
     config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
     config.vision_config.drop_path_rate = model_args.drop_path_rate
-    if 'internlm' in model_args.model_name_or_path.lower():
+    if config.llm_config.model_type == 'internlm2':
         config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+        logger.info('Using flash_attention_2 for InternLM')
     else:
         config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+        logger.info('Using flash_attention_2 for LLaMA')
     config.template = data_args.conv_style
     config.select_layer = model_args.vision_select_layer
     config.dynamic_image_size = data_args.dynamic_image_size
@@ -510,10 +512,12 @@ def main():
         model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
     logger.info('Loading LLaMA...')
     llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
-    if 'internlm' in model_args.llm_path.lower():
+    if llm_config.model_type == 'internlm2':
         llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+        logger.info('Using flash_attention_2 for InternLM')
     else:
         llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+        logger.info('Using flash_attention_2 for LLaMA')
     llm = AutoModelForCausalLM.from_pretrained(
         model_args.llm_path, torch_dtype=torch.bfloat16,
         config=llm_config, trust_remote_code=True)