Skip to content

Commit 8b92432

Browse files
committed
Fix internlm bug
1 parent 1d0eba0 commit 8b92432

3 files changed

Lines changed: 18 additions & 8 deletions

File tree

internvl_chat/internvl/train/internvl_chat_finetune.py
internvl_chat/internvl/train/internvl_chat_pretrain.py
internvl_chat/tools/replace_llm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,12 @@ def main():
489489
logger.info('Loading InternVLChatModel...')
490490
config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
491491
config.vision_config.drop_path_rate = model_args.drop_path_rate
492-
if 'internlm' in model_args.model_name_or_path.lower():
492+
if config.llm_config.model_type == 'internlm2':
493493
config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
494+
logger.info('Using flash_attention_2 for InternLM')
494495
else:
495496
config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
497+
logger.info('Using flash_attention_2 for LLaMA')
496498
config.template = data_args.conv_style
497499
config.select_layer = model_args.vision_select_layer
498500
config.dynamic_image_size = data_args.dynamic_image_size
@@ -510,10 +512,12 @@ def main():
510512
model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
511513
logger.info('Loading LLaMA...')
512514
llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
513-
if 'internlm' in model_args.llm_path.lower():
515+
if llm_config.model_type == 'internlm2':
514516
llm_config.attn_implementation = 'flash_attention_2' # for InternLM
517+
logger.info('Using flash_attention_2 for InternLM')
515518
else:
516519
llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
520+
logger.info('Using flash_attention_2 for LLaMA')
517521
llm = AutoModelForCausalLM.from_pretrained(
518522
model_args.llm_path, torch_dtype=torch.bfloat16,
519523
config=llm_config, trust_remote_code=True)

internvl_chat/internvl/train/internvl_chat_pretrain.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,10 +509,12 @@ def main():
509509
logger.info('Loading InternVLChatModel...')
510510
config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
511511
config.vision_config.drop_path_rate = model_args.drop_path_rate
512-
if 'internlm' in model_args.model_name_or_path.lower():
512+
if config.llm_config.model_type == 'internlm2':
513513
config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
514+
logger.info('Using flash_attention_2 for InternLM')
514515
else:
515516
config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
517+
logger.info('Using flash_attention_2 for LLaMA')
516518
config.template = data_args.conv_style
517519
config.select_layer = model_args.vision_select_layer
518520
config.dynamic_image_size = data_args.dynamic_image_size
@@ -530,10 +532,12 @@ def main():
530532
model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
531533
logger.info('Loading LLaMA...')
532534
llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
533-
if 'internlm' in model_args.llm_path.lower():
535+
if llm_config.model_type == 'internlm2':
534536
llm_config.attn_implementation = 'flash_attention_2' # for InternLM
537+
logger.info('Using flash_attention_2 for InternLM')
535538
else:
536539
llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
540+
logger.info('Using flash_attention_2 for LLaMA')
537541
llm = AutoModelForCausalLM.from_pretrained(
538542
model_args.llm_path, torch_dtype=torch.bfloat16,
539543
config=llm_config, trust_remote_code=True)

internvl_chat/tools/replace_llm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
from internvl.model.internvl_chat import InternVLChatModel
5-
from transformers import LlamaForCausalLM, LlamaTokenizer
5+
from transformers import AutoModel, AutoTokenizer
66

77
argparse = argparse.ArgumentParser()
88
argparse.add_argument('model_path', type=str, default='')
@@ -13,10 +13,12 @@
1313
if args.model_path[-1] == '/':
1414
args.model_path = args.model_path[:-1]
1515

16-
model = InternVLChatModel.from_pretrained(args.model_path)
16+
model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
1717

18-
llm = LlamaForCausalLM.from_pretrained(args.llm_path)
19-
tokenizer = LlamaTokenizer.from_pretrained(args.llm_path)
18+
llm = AutoModel.from_pretrained(
19+
args.llm_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
20+
tokenizer = AutoTokenizer.from_pretrained(
21+
args.llm_path, trust_remote_code=True)
2022
model.language_model = llm
2123
model.config.llm_config = llm.config
2224
model.to(torch.bfloat16)

0 commit comments

Comments (0)