Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MLLM #529

Open
wants to merge 132 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
132 commits
Select commit Hold shift + click to select a range
90222e0
refactor llava
hhaAndroid Mar 29, 2024
4dd223e
fix
hhaAndroid Mar 29, 2024
d2428af
fix
hhaAndroid Mar 29, 2024
39add6b
update
hhaAndroid Mar 29, 2024
f154bb9
update
hhaAndroid Mar 29, 2024
2feb0e3
fix ddp
hhaAndroid Apr 1, 2024
cd0a01b
add config
hhaAndroid Apr 1, 2024
36053f6
add config
hhaAndroid Apr 1, 2024
211c33a
add config
hhaAndroid Apr 1, 2024
b38f453
fix disp
hhaAndroid Apr 1, 2024
a868151
fix test
hhaAndroid Apr 1, 2024
7f70c56
add dataset
hhaAndroid Apr 1, 2024
04f0ac1
fix eval dataset
hhaAndroid Apr 1, 2024
43c2bab
Merge branch 'main' into refactor_llava
LZHgrla Apr 2, 2024
8b44a9e
update config
hhaAndroid Apr 3, 2024
3a994dd
Merge branch 'refactor_llava' of github.com:hhaAndroid/xtuner into re…
hhaAndroid Apr 3, 2024
05534c9
fix
hhaAndroid Apr 3, 2024
262b636
fix
hhaAndroid Apr 3, 2024
a22f033
update design
hhaAndroid Apr 7, 2024
94d8bc9
update cfg
hhaAndroid Apr 7, 2024
2efb456
add anyres
hhaAndroid Apr 8, 2024
9f8d2b3
add gqa
hhaAndroid Apr 8, 2024
350f030
fix
hhaAndroid Apr 8, 2024
286b653
remove
hhaAndroid Apr 8, 2024
6dbf0db
add cfg
hhaAndroid Apr 8, 2024
27ce878
add any res
hhaAndroid Apr 9, 2024
cdac294
update
hhaAndroid Apr 9, 2024
e25b70c
fix
hhaAndroid Apr 9, 2024
622250c
fix
hhaAndroid Apr 9, 2024
c450da7
add pretrain
hhaAndroid Apr 9, 2024
3043bb7
add pretrain
hhaAndroid Apr 9, 2024
14f4528
addcomment
hhaAndroid Apr 9, 2024
8bffc5a
fix path
hhaAndroid Apr 9, 2024
bf6e5e5
fix
hhaAndroid Apr 9, 2024
8f7a2f7
fix
hhaAndroid Apr 9, 2024
a151111
update
hhaAndroid Apr 9, 2024
6753fde
fix
hhaAndroid Apr 9, 2024
6f5a66d
fix
hhaAndroid Apr 9, 2024
5ec58f9
fix bug
hhaAndroid Apr 10, 2024
02974d6
token_merge_ratio (#2)
LZHgrla Apr 11, 2024
fa2948a
add config
hhaAndroid Apr 11, 2024
10e24fe
fix bug
hhaAndroid Apr 11, 2024
7d59d82
add mini-geminie
hhaAndroid Apr 12, 2024
6afaa55
fix bug
hhaAndroid Apr 12, 2024
cef1cb7
fix bug
hhaAndroid Apr 12, 2024
5c95d52
fix bug
hhaAndroid Apr 12, 2024
736eba7
fix bug
hhaAndroid Apr 12, 2024
b5ec232
fix bug
hhaAndroid Apr 12, 2024
2b4b353
add finetune
hhaAndroid Apr 12, 2024
f706c46
add mmstar 和 vqav2
hhaAndroid Apr 12, 2024
91b718e
update
hhaAndroid Apr 12, 2024
ca5e7fc
fix bug
hhaAndroid Apr 12, 2024
1131f37
support s2+siglip
hhaAndroid Apr 18, 2024
b9cb7b9
fix
hhaAndroid Apr 18, 2024
c7b0829
Merge branch 'main' of github.com:InternLM/xtuner into refactor_llava
hhaAndroid Apr 18, 2024
c9a3ffc
update config
hhaAndroid Apr 18, 2024
bd01320
add llama3
hhaAndroid Apr 19, 2024
d302134
fix temp
hhaAndroid Apr 19, 2024
b71b237
update temp
hhaAndroid Apr 19, 2024
9e2c2b2
add finetune config
hhaAndroid Apr 19, 2024
a5ac45b
add internvl config
hhaAndroid Apr 19, 2024
e8f26e4
merge
hhaAndroid Apr 23, 2024
9458ad2
updata
hhaAndroid Apr 23, 2024
e79b7aa
add chartqa
hhaAndroid Apr 23, 2024
dc15317
fix
hhaAndroid Apr 23, 2024
d9de303
update
hhaAndroid Apr 25, 2024
82e8901
merge
hhaAndroid Apr 25, 2024
c18a81d
add phi3 pretrain config
hhaAndroid Apr 25, 2024
2d0bad5
update
hhaAndroid Apr 25, 2024
2d43f20
update
hhaAndroid Apr 25, 2024
f6a688b
fix
hhaAndroid Apr 25, 2024
dd0365d
update
hhaAndroid Apr 25, 2024
643c6f6
fix mmmu results
hhaAndroid Apr 25, 2024
797af9a
fix mmmu results
hhaAndroid Apr 25, 2024
e24b6dc
fix mmmu results
hhaAndroid Apr 25, 2024
1ced889
fix mmmu results
hhaAndroid Apr 25, 2024
3594582
update
hhaAndroid Apr 25, 2024
2f4055c
update infovqa
hhaAndroid Apr 25, 2024
666beed
fix
hhaAndroid Apr 26, 2024
13b499c
fix
hhaAndroid Apr 26, 2024
2cbb29b
add any res
hhaAndroid Apr 26, 2024
42d08d7
fix
hhaAndroid Apr 26, 2024
0340daa
fix
hhaAndroid Apr 26, 2024
cb1f29c
fix
hhaAndroid Apr 26, 2024
a19d597
fix
hhaAndroid Apr 26, 2024
7db1352
update file
hhaAndroid Apr 26, 2024
6522375
update file
hhaAndroid Apr 26, 2024
256c0f5
fix
hhaAndroid Apr 26, 2024
1f56691
fix
hhaAndroid Apr 26, 2024
d4ef310
fix
hhaAndroid Apr 26, 2024
7f62009
add config
hhaAndroid Apr 26, 2024
1bd9be1
update
hhaAndroid Apr 26, 2024
f6abf85
add 70b finetune
hhaAndroid Apr 28, 2024
d732b58
add internvl 1.5 pretrain
hhaAndroid Apr 28, 2024
6f8d2fb
add internvl 1.5 finetune
hhaAndroid Apr 28, 2024
ab0b003
update
hhaAndroid Apr 28, 2024
f47d06d
update
hhaAndroid Apr 28, 2024
323dfbb
add layer-wise learning rate (LLDR)
hhaAndroid Apr 29, 2024
e605a73
update config
hhaAndroid Apr 29, 2024
ed1a836
fix
hhaAndroid Apr 29, 2024
f5a1922
update
hhaAndroid Apr 29, 2024
cfd8d4d
fix
hhaAndroid Apr 29, 2024
98e6ac9
update config
hhaAndroid May 6, 2024
8bf0f3e
Merge branch 'main' of github.com:InternLM/xtuner into refactor_llava
hhaAndroid May 6, 2024
1706628
update config
hhaAndroid May 7, 2024
38c8c27
add test
hhaAndroid May 7, 2024
55f01aa
fix
hhaAndroid May 7, 2024
1c5de9d
add allava
hhaAndroid May 10, 2024
c361adc
fix
hhaAndroid May 10, 2024
2409629
add finetune
hhaAndroid May 10, 2024
43b27d0
add finetune1
hhaAndroid May 10, 2024
28796c2
add config
hhaAndroid May 11, 2024
aae7b46
updata
hhaAndroid May 11, 2024
85a62f5
updata
hhaAndroid May 11, 2024
f998ae5
update
hhaAndroid May 13, 2024
6278913
add patch select
hhaAndroid May 16, 2024
98f1f58
fix
hhaAndroid May 16, 2024
bd4bf22
update
hhaAndroid May 17, 2024
32723b1
fix bug
hhaAndroid May 17, 2024
1eac3a0
update
hhaAndroid May 20, 2024
5137d54
add new config
hhaAndroid May 20, 2024
e0dbf4f
update config
hhaAndroid May 20, 2024
8939442
update config
hhaAndroid May 21, 2024
5c744e7
update 1.8
hhaAndroid May 30, 2024
a374004
update 1.8
hhaAndroid May 30, 2024
0e8febb
fix
hhaAndroid May 30, 2024
0c09d9c
fix
hhaAndroid May 30, 2024
7b46ab3
fix eval
hhaAndroid May 31, 2024
41da1b1
fix
hhaAndroid May 31, 2024
44b8d1c
fix
hhaAndroid May 31, 2024
673ce12
fix
hhaAndroid May 31, 2024
17ab71c
fix
hhaAndroid May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
SiglipImageProcessor, SiglipVisionModel, BitsAndBytesConfig)
from peft import LoraConfig
from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.evaluation import MMELLaVADataset, MultipleChoiceLLaVADataset
from xtuner.dataset import ConcatDataset
from xtuner.engine.runner import TrainLoop, ValLoop, TestLoop
from mmengine.dataset import DefaultSampler
import torch
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
llm_name_or_path = 'microsoft/phi-2'
visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
# Specify the pretrained pth
pretrained_pth = 'work_dirs/llava_phi2_2_7b_siglip_so400m_p14_384_e1_gpu8_pretrain/iter_2181.pth'

# Data
data_root = '/mnt/petrelfs/share_data/huanghaian/llava_data/'
data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
image_folder = data_root + 'llava_images'
prompt_template = PROMPT_TEMPLATE.vicuna
max_length = int(2048 - (384 // 14) ** 2)

# Scheduler & Optimizer
batch_size = 16 # per_device
accumulative_counts = 1
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
lr = 2e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = ''
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']

#######################################################################
# PART 2 Model & Tokenizer & Image Processor #
#######################################################################
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=llm_name_or_path,
trust_remote_code=True,
padding_side='right')

image_processor = dict(
type=SiglipImageProcessor.from_pretrained,
pretrained_model_name_or_path=visual_encoder_name_or_path,
trust_remote_code=True)

model = dict(
type=LLaVAModel,
freeze_llm=True,
freeze_visual_encoder=True,
pretrained_pth=pretrained_pth,
tokenizer=tokenizer,
template=prompt_template,
image_processor=image_processor,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=llm_name_or_path,
trust_remote_code=True,
torch_dtype=torch.float16,
quantization_config=dict(
type=BitsAndBytesConfig,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')),
llm_lora=dict(
type=LoraConfig,
r=512,
lora_alpha=256,
lora_dropout=0.05,
bias='none',
task_type='CAUSAL_LM'),
visual_encoder=dict(
type=SiglipVisionModel.from_pretrained,
pretrained_model_name_or_path=visual_encoder_name_or_path))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
llava_dataset = dict(
type=LLaVADataset,
offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/phi2_2_7b_finetune',
data_path=data_path,
image_folder=image_folder,
tokenizer=tokenizer,
image_processor=image_processor,
dataset_map_fn=llava_map_fn,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=max_length,
pad_image_to_square=True)

train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
pin_memory=True,
dataset=llava_dataset,
sampler=dict(
type=LengthGroupedSampler,
length_property='modality_length',
per_device_batch_size=batch_size * accumulative_counts),
collate_fn=dict(type=default_collate_fn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=warmup_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
begin=warmup_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs, val_interval=save_steps)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
dict(type=DatasetInfoHook, tokenizer=tokenizer),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
image_processor=image_processor,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
evaluation_images=evaluation_images,
system=SYSTEM,
prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
save_optimizer=False, # can save disk memory mmengine >=0.10.3
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)

# ==================== val and test cfg =======================
val_dataset = [
dict(
type=MMELLaVADataset,
data_file='/mnt/petrelfs/huanghaian/code/xtuner/LMUData/MME.tsv',
image_folder='/mnt/petrelfs/share_data/duanhaodong/data/mme/MME_Benchmark_release',
prompt_template=PROMPT_TEMPLATE.vicuna,
tokenizer=tokenizer,
image_processor=image_processor,
pad_image_to_square=True),
# dict(
# type=MultipleChoiceLLaVADataset,
# data_file='/mnt/petrelfs/huanghaian/code/xtuner/LMUData/MMBench_DEV_EN.tsv',
# prompt_template=PROMPT_TEMPLATE.vicuna,
# tokenizer=tokenizer,
# image_processor=image_processor,
# pad_image_to_square=True)
]

test_dataset = [
dict(
type=MMELLaVADataset,
data_file='/mnt/petrelfs/huanghaian/code/xtuner/LMUData/MME.tsv',
image_folder='/mnt/petrelfs/share_data/duanhaodong/data/mme/MME_Benchmark_release',
prompt_template=PROMPT_TEMPLATE.vicuna,
tokenizer=tokenizer,
image_processor=image_processor,
pad_image_to_square=True),
dict(
type=MultipleChoiceLLaVADataset,
data_file='/mnt/petrelfs/huanghaian/code/xtuner/LMUData/MMBench_DEV_EN.tsv',
prompt_template=PROMPT_TEMPLATE.vicuna,
tokenizer=tokenizer,
image_processor=image_processor,
pad_image_to_square=True)
]

# TODO: We are not currently using val_evaluator
# Don't support num_workers > 0
val_dataloader = dict(
batch_size=1,
num_workers=0,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(type=ConcatDataset, datasets=val_dataset))
val_evaluator = dict()
val_cfg = dict(type=ValLoop)

# TODO: We are not currently using test_evaluator
test_dataloader = dict(
batch_size=1,
num_workers=0,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(type=ConcatDataset, datasets=test_dataset))
test_evaluator = val_evaluator
test_cfg = dict(type=TestLoop, select_metric='first')