Plug-in Python module: serially merge .pt weights into a .pth, then convert to HF format #476

Open
wants to merge 26 commits into
base: main
e2992e3
add qwen_1.8b config
xiaohangguo Dec 1, 2023
ed520ba
fix .gitignore
xiaohangguo Dec 1, 2023
6c011c7
Merge branch 'main' into qwen_1_8b_config
xiaohangguo Dec 1, 2023
d058d9e
deepseek config & deepseek templates
xiaohangguo Dec 1, 2023
9b08d54
add qwen_72b config
xiaohangguo Dec 1, 2023
4460f67
remove superfluous config
xiaohangguo Dec 1, 2023
71fb36b
revert qwen-1.8b & deepseekcoder config
xiaohangguo Dec 1, 2023
4d595c9
Merge branch 'main' into qwen_72b_config
xiaohangguo Dec 4, 2023
f1209ed
[Fix] qwen 72b bos & stops word
xiaohangguo Dec 11, 2023
0f15318
Merge branch 'qwen_72b_config' of https://github.com/xiaohangguo/xtun…
xiaohangguo Dec 11, 2023
d3f79fb
fix flake8 conflict
xiaohangguo Dec 11, 2023
3d1f473
fix qwen_1_8b_chat eos & stops word -> <|im_end|>
xiaohangguo Dec 11, 2023
569f5bf
Merge branch 'main' into qwen_72b_config
xiaohangguo Dec 12, 2023
796222a
update qwen_72b warmup
xiaohangguo Dec 21, 2023
893cbde
Merge branch 'main' into qwen_72b_config
xiaohangguo Dec 21, 2023
8f7e2de
Delete .gitignore
xiaohangguo Dec 22, 2023
bdf0231
Restore .gitignore to leave it as is
xiaohangguo Dec 22, 2023
db6271c
Merge branch 'main' of https://github.com/xiaohangguo/xtuner into ens…
xiaohangguo Dec 29, 2023
b3bb678
Packer update init
xiaohangguo Dec 29, 2023
75bf0f3
[Bug] Not as expected pack
xiaohangguo Jan 15, 2024
e34ee7a
Check whether the cut point falls within the input_ids part
xiaohangguo Mar 5, 2024
0664990
Serially merge weights into a pth, then convert to HF format
xiaohangguo Mar 14, 2024
c2620e4
Merge branch 'main' into convert_pt
xiaohangguo Mar 14, 2024
c135233
update
xiaohangguo Mar 14, 2024
0286369
Merge branch 'main' into convert_pt
xiaohangguo Mar 14, 2024
91b3e08
Merge branch 'main' into convert_pt
xiaohangguo Mar 15, 2024
73 changes: 73 additions & 0 deletions xtuner/tools/serial_merge_to_hf.py
@@ -0,0 +1,73 @@
import argparse
import os
import subprocess

import torch
from tqdm import tqdm


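def merge_weights(ckpt_dir, new_ckpt_path):
    # Load the .pt shards one at a time on CPU and fold them into one dict.
    merged_weights = {}

    # Sort so the merge order does not depend on os.listdir ordering.
    pt_files = sorted(f for f in os.listdir(ckpt_dir) if f.endswith('.pt'))

    for filename in tqdm(pt_files, desc='Merging weights'):
        file_path = os.path.join(ckpt_dir, filename)
        weights = torch.load(file_path, map_location='cpu')
        # dict.update: a later shard overwrites any key already present.
        merged_weights.update(weights)

    torch.save(merged_weights, new_ckpt_path)
    return new_ckpt_path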


def convert_to_hf(config_path, ckpt_path, output_dir):
    command = [
        'xtuner',
        'convert',
        'pth_to_hf',
        config_path,
        ckpt_path,
        output_dir,
    ]

    # check=True raises CalledProcessError if the conversion fails.
    subprocess.run(command, check=True)


def process_weights(ckpt_dir, config_path, output_dir):
    # Step 1: Merge weights
    new_ckpt_filename = 'merged_model_states.pth'
    new_ckpt_path = os.path.join(ckpt_dir, new_ckpt_filename)
    merge_weights(ckpt_dir, new_ckpt_path)

    # Step 2: Convert to Hugging Face format
    convert_to_hf(config_path, new_ckpt_path, output_dir)


def main():
    parser = argparse.ArgumentParser(
        description='Serial Merge Weights and Convert to Hugging Face Format')
    parser.add_argument(
        'ckpt_dir',
        type=str,
        help='Directory containing the weight files, for example '
        'work_dirs/**/epoch_3.pth, which holds the '
        'bf16_zero_pp_rank_*.pt files. For QLoRA training, an '
        'xtuner convert merge is required first.')
    parser.add_argument(
        'config_path',
        type=str,
        help='Path to the configuration file used for training')
    parser.add_argument(
        'output_dir', type=str, help='Hugging Face model output directory')

    args = parser.parse_args()
    process_weights(args.ckpt_dir, args.config_path, args.output_dir)


if __name__ == '__main__':
    main()
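
Note on the merge step: `merge_weights` relies on `dict.update`, which silently replaces any key that an earlier shard already supplied. The PR does not say whether the `.pt` shards share top-level keys, so the following is only a minimal sketch of a stricter merge, assuming that a repeated key should be treated as an error. The helper name `merge_state_dicts` is hypothetical and not part of xtuner; it works on plain dicts so it can be tried without loading real checkpoints.

```python
from typing import Dict, Iterable, Mapping


def merge_state_dicts(
        shards: Iterable[Mapping[str, object]]) -> Dict[str, object]:
    """Merge shard dicts one at a time, refusing to overwrite a key.

    Hypothetical helper for illustration; not part of xtuner.
    """
    merged: Dict[str, object] = {}
    for index, shard in enumerate(shards):
        # Keys present both in the merged result and in this shard.
        duplicates = merged.keys() & shard.keys()
        if duplicates:
            raise KeyError(
                f'shard {index} repeats keys: {sorted(duplicates)}')
        merged.update(shard)
    return merged


if __name__ == '__main__':
    # Stand-in values; real shards would hold tensors from torch.load.
    shards = [{'layer0.weight': 1.0}, {'layer1.weight': 2.0}]
    print(merge_state_dicts(shards))
```

In the script above, the same check could run on each `weights` dict before `merged_weights.update(weights)`. Whether raising is the right response depends on what the shards actually contain.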