Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Feature] DPO #434

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)

from xtuner.dataset import DPODataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import ultra_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import DPO
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm2-chat-1_8b'
use_varlen_attn = False

# Data
ultra_path = 'HuggingFaceH4/ultrachat_200k'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
'请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')

model = dict(
type=DPO, # TODO
use_varlen_attn=use_varlen_attn,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.float16,
quantization_config=dict(
type=BitsAndBytesConfig,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')),
lora=dict(
type=LoraConfig,
r=64,
lora_alpha=16,
lora_dropout=0.1,
bias='none',
task_type='CAUSAL_LM'),
beta=0.1)

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
ultra = dict(
type=DPODataset, # TODO
data_path=ultra_path,
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=ultra_map_fn,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length,
use_varlen_attn=use_varlen_attn)

train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=ultra,
sampler=dict(type=DefaultSampler, shuffle=True),
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type=LinearLR,
start_factor=1e-5,
by_epoch=True,
begin=0,
end=warmup_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=0.0,
by_epoch=True,
begin=warmup_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
dict(type=DatasetInfoHook, tokenizer=tokenizer),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
system=SYSTEM,
prompt_template=prompt_template)
]

if use_varlen_attn:
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
36 changes: 36 additions & 0 deletions xtuner/dataset/dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset

from xtuner.registry import BUILDER
from .huggingface import process_hf_dataset
from .utils import expand2square


class DPODataset(Dataset):

def __init__(self,
data_path,
tokenizer,
max_dataset_length=None,
dataset_map_fn=None,
template_map_fn=None,
max_length=2048):
super().__init__()
# TODO
pass

def __len__(self):
# TODO
pass

def __getitem__(self, index):
# TODO
pass
136 changes: 136 additions & 0 deletions xtuner/model/dpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from peft import get_peft_model, prepare_model_for_kbit_training
from torch import nn

from xtuner.registry import BUILDER
from .modules import dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, make_inputs_require_grad,
traverse_dict)


class DPO(BaseModel):

def __init__(self,
llm,
ref_llm=None,
lora=None,
peft_model=None,
use_activation_checkpointing=True,
use_varlen_attn=False):
super().__init__()
with LoadWoInit():
self.llm = self._build_from_cfg_or_module(llm)
self.llm.config.use_cache = False
dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)

if use_activation_checkpointing:
# For backward compatibility
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
else:
self.llm.get_input_embeddings().register_forward_hook(
make_inputs_require_grad)

# enable gradient checkpointing for memory efficiency
self.gradient_checkpointing_enable()

if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
lora, ConfigDict):
self.lora = BUILDER.build(lora)
else:
self.lora = lora
self.peft_model = peft_model
self.use_lora = lora is not None
if self.use_lora:
self._prepare_for_lora(peft_model, use_activation_checkpointing)

self._is_init = True
# Determines whether to calculate attention based on the
# seq_len dimension (use_varlen_attn = False) or the actual length of
# the sequence.
self.use_varlen_attn = use_varlen_attn

# TODO: Add ref model and ref model config
self.ref_llm = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ref_llm, 也支持 api model


def gradient_checkpointing_enable(self):
self.activation_checkpointing_enable()

def activation_checkpointing_enable(self):
self.llm.gradient_checkpointing_enable()

def gradient_checkpointing_disable(self):
self.activation_checkpointing_disable()

def activation_checkpointing_disable(self):
self.llm.gradient_checkpointing_disable()

def _prepare_for_lora(self,
peft_model=None,
use_activation_checkpointing=True):
self.llm = prepare_model_for_kbit_training(
self.llm, use_activation_checkpointing)
if self.lora.target_modules is None:
modules = find_all_linear_names(self.llm)
self.lora.target_modules = modules

self.llm = get_peft_model(self.llm, self.lora)
if peft_model is not None:
_ = load_checkpoint(self, peft_model)

def init_weights(self):
pass

def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError

def forward(self, data, data_samples=None, mode='loss'):

if mode == 'loss':
return self.compute_loss(data, data_samples)
elif mode == 'predict':
return self.predict(data, data_samples)
elif mode == 'tensor':
return self._forward(data, data_samples)
else:
raise NotImplementedError

def _forward(self, data, data_samples=None):

outputs = self.llm(**data)

return outputs

def predict(self, data, data_samples=None):
outputs = self.llm(**data)
logits_dict = [{'logits': logits} for logits in outputs.logits]
return logits_dict

def compute_loss(self, data, data_samples=None):
# TODO
pass

def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
if not self.use_lora:
return state_dict
to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
return OrderedDict(to_return)

def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.llm, name)