[Feature] add HFCheckpointHook to auto save hf model after the whole training phase #621

Merged · 6 commits · May 10, 2024
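
For reference, enabling the new hook from a training config might look like the sketch below. This assumes the standard MMEngine `custom_hooks` mechanism that xtuner's other hooks already use; the `out_dir` value is purely illustrative, and if it is omitted the hook falls back to `<work_dir>/hf_model` (see the implementation further down).

# Sketch only: registering HFCheckpointHook in an xtuner/MMEngine config.
# `out_dir` is illustrative; when it is None the hook writes to
# osp.join(runner.work_dir, 'hf_model').
from xtuner.engine.hooks import HFCheckpointHook

custom_hooks = [
    dict(type=HFCheckpointHook, out_dir='./work_dirs/demo/hf_model'),
]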
xtuner/engine/hooks/__init__.py (2 additions, 1 deletion)
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dataset_info_hook import DatasetInfoHook
 from .evaluate_chat_hook import EvaluateChatHook
+from .hf_checkpoint_hook import HFCheckpointHook
 from .throughput_hook import ThroughputHook
 from .varlen_attn_args_to_messagehub_hook import VarlenAttnArgsToMessageHubHook
 
 __all__ = [
     'EvaluateChatHook', 'DatasetInfoHook', 'ThroughputHook',
-    'VarlenAttnArgsToMessageHubHook'
+    'VarlenAttnArgsToMessageHubHook', 'HFCheckpointHook'
 ]
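
With the export above in place, the hook is importable straight from xtuner.engine.hooks. A quick sanity check, assuming xtuner is installed with this change:

# Sketch: confirm the new hook is exported alongside the existing hooks.
from xtuner.engine.hooks import HFCheckpointHook

print(HFCheckpointHook.priority)  # 95, i.e. lower priority than MMEngine's CheckpointHook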
xtuner/engine/hooks/hf_checkpoint_hook.py (new file, 53 additions)
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from pathlib import Path
from typing import Optional, Union

import torch.distributed as dist
from mmengine._strategy import DeepSpeedStrategy
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import FlexibleRunner

DATA_BATCH = Optional[Union[dict, tuple, list]]


class HFCheckpointHook(Hook):
    """Save the trained LLM in HuggingFace format once training has finished."""

    priority = 95  # lower priority than MMEngine's CheckpointHook, so it runs after it

def __init__(self, out_dir: Optional[Union[str, Path]] = None) -> None:
self.out_dir = out_dir

def after_run(self, runner) -> None:
assert isinstance(runner,
FlexibleRunner), 'Runner should be `FlexibleRunner`'
assert isinstance(
runner.strategy,
DeepSpeedStrategy), 'Strategy should be `DeepSpeedStrategy`'

if self.out_dir is None:
self.out_dir = osp.join(runner.work_dir, 'hf_model')

wrapped_model = runner.strategy.model
        if wrapped_model.zero_optimization_partition_weights():
            # ZeRO-3 shards weights across ranks, so gather them into a full
            # 16-bit state dict before saving.
            assert wrapped_model.zero_gather_16bit_weights_on_model_save(), \
                ('Please set `gather_16bit_weights_on_model_save=True` '
                 'in your DeepSpeed config.')
            state_dict = wrapped_model._zero3_consolidated_16bit_state_dict()
else:
state_dict = wrapped_model.module_state_dict(
exclude_frozen_parameters=runner.strategy.
exclude_frozen_parameters)

model = runner.model
if is_model_wrapper(model):
model = model.module
llm = model.llm
if (not dist.is_initialized()) or dist.get_rank() == 0:
# keys in state_dict are prefixed with 'llm.'
keys = list(state_dict.keys())
for k in keys:
val = state_dict.pop(k)
state_dict[k[4:]] = val
llm.save_pretrained(self.out_dir, state_dict=state_dict)
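
Since the hook only calls llm.save_pretrained, the exported directory holds the LLM weights and config in HuggingFace format but no tokenizer, and under ZeRO-3 it additionally requires gather_16bit_weights_on_model_save to be enabled in the DeepSpeed config, as the assertion above enforces. A rough sketch of consuming the output (the path is hypothetical, and the tokenizer is assumed to have been saved separately or taken from the base model):

# Sketch: loading the directory written by HFCheckpointHook after training.
# The path is hypothetical; HFCheckpointHook itself does not save a tokenizer,
# so load it from the original base model or save it separately.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('work_dirs/demo/hf_model')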