[Feature] Integrated Training and Inference -- Part 1 #532

Status: Open · wants to merge 31 commits into base: main
Conversation

pppppM (Collaborator) commented Mar 29, 2024:

Model loading & chat example: xtuner/model/auto.py

Training on alpaca:

# Convert the alpaca dataset to OpenAI-format JSON
python xtuner/tools/convert_dataset.py tatsu-lab/alpaca alpaca --save-dir converted_alpaca
xtuner train xtuner/configs/internlm/internlm2_chat_1_8b/example.py
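
For reference, a converted record presumably follows the OpenAI chat-messages layout; the exact field names emitted by convert_dataset.py are an assumption here, sketched as a Python literal:

# Assumed layout of one converted alpaca record (schema not confirmed in this PR)
sample = {
    'messages': [
        {'role': 'user', 'content': 'Give three tips for staying healthy.'},
        {'role': 'assistant', 'content': '1. Eat a balanced diet. ...'},
    ]
}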

HIT-cwh and others added the following commits (beginning March 29, 2024 18:10):
* support sequence

* add configs

* add sp example to custom dataset

* WIP

* add dispatch utils

* delete useless codes

* move xtuner/engine/sequence_parallel to xtuner/parallel/sequence

* fix lint

* fix lint

* add init_dist to xtuner and add trust_remote_code=True to AutoConfig

* add internlm2 custom_dataset sp4 config

* Sequence Parallel doc V1

* Sequence Parallel doc V1

* Sequence Parallel doc V1

* fix bugs in llama_varlen_attn_forward

* rename indexes to position_ids

* add attn_implementation to config

* delete useless codes

* fix lint

* refine default_collate_fn

* refine doc

* refine doc

* refine doc

* delete replace_internlm2_rote

* add repeat_kv_bshd

* fix apply_rotary_pos_emb bug

* add enable_sequence_parallel flag

* refine doc

* assert {'input_ids', 'labels'}.issubset(dataset.column_names)

* refine doc
Resolved review threads: xtuner/model/base.py (5 threads), xtuner/model/text/finetune.py, xtuner/dataset/hybrid/collate.py
attn_kwargs = cls._flash_attn_kwargs(config)
kwargs.update(attn_kwargs)

if torch.cuda.is_bf16_supported():
Collaborator:

If it is written this way, is there no way for the user to change the model type via the config or input arguments?
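
One way to address this, sketched below: expose the dtype as an argument with an 'auto' default so the config or CLI can override it. The resolve_torch_dtype helper and its signature are hypothetical, not part of this PR.

import torch

def resolve_torch_dtype(torch_dtype='auto'):
    # 'auto' keeps the current behaviour; an explicit value supplied via
    # the config or CLI (e.g. torch.float32) takes precedence.
    if torch_dtype == 'auto':
        return (torch.bfloat16 if torch.cuda.is_bf16_supported()
                else torch.float16)
    return torch_dtype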

return model

@staticmethod
def _flash_attn_kwargs(config):
Collaborator:

If a user adds a new LLM of their own, how should this field be modified? Put differently, how would the user even know it needs modifying?

Collaborator:

This is mainly to guarantee the correctness of the attn_mask shape (flash_attn, sdpa, and vanilla attention may expect different attn_mask shapes).
I think we could later move _built_in_flash_attn_1 and _built_in_flash_attn_2 somewhere else, and then publish a doc explaining what needs to be considered when adding a new model.
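
As a rough illustration of the mask-shape concern, here is a sketch of how an attn_implementation kwarg could be chosen per environment. The function name and the fallback order are assumptions; the PR's _built_in_flash_attn_* registries are not reproduced here.

import torch

def pick_attn_implementation():
    # flash_attn, sdpa, and eager attention each accept different
    # attn_mask shapes, so the choice is made explicit up front.
    try:
        import flash_attn  # noqa: F401
        if torch.cuda.is_available():
            return 'flash_attention_2'
    except ImportError:
        pass
    return 'sdpa'  # available in PyTorch >= 2.0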

Resolved review threads: xtuner/model/text/finetune.py, xtuner/dataset/hybrid/dataset.py
from pydantic import BaseModel


class SampleParams(BaseModel):
Collaborator:

Can this object be modified from the config? Also keep in mind that during evaluation, different datasets may need different values for these parameters, so they have to be passed to the model on the fly at evaluation time.
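
For concreteness, a minimal sketch of what such a SampleParams object might hold; every field below is illustrative, not the PR's actual schema:

from typing import List, Optional

from pydantic import BaseModel


class SampleParams(BaseModel):
    # Illustrative generation knobs; evaluation code could construct a
    # fresh instance per dataset and pass it to the model at run time.
    max_new_tokens: int = 512
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 40
    stop_words: List[str] = []
    seed: Optional[int] = None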

Resolved review thread: xtuner/model/auto.py
checkpoint: str,
config: Optional[str] = None,
from_hub: bool = False):
config = Config.fromfile(config)
LZHgrla (Collaborator) commented Apr 7, 2024:

Shouldn't this spot check whether config is None, to match the if-else branch below?
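
A minimal guard along those lines might look like this (a sketch; the None branch is assumed to fall through to the checkpoint/hub logic below):

from mmengine import Config

def load_config(config=None):
    # Only parse the file when a path was actually given; otherwise
    # return None and let the caller's if-else recover it elsewhere.
    if config is not None:
        config = Config.fromfile(config)
    return config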

Resolved review thread: xtuner/model/auto.py
HIT-cwh (Collaborator) commented Apr 8, 2024:

Model loading & chat example: xtuner/model/auto.py

Training on alpaca:

# Convert the alpaca dataset to OpenAI-format JSON
python xtuner/tools/convert_dataset.py tatsu-lab/alpaca alpaca --save-dir converted_alpaca
xtuner train xtuner/configs/internlm/internlm2_chat_1_8b/example.py

If I want to train on Alpaca and Alpaca-zh together, should I convert each dataset separately and then use ConcatDataset, or convert them together in one pass?

position_ids.append(torch.arange(chunk_tokens))
position_ids = torch.cat(position_ids, dim=0).unsqueeze(0)

from mmengine import MessageHub
Collaborator (author):

Code placement.

Resolved review thread: xtuner/dataset/hybrid/_pack.py
def main():
    args = parse_args()

    dataset = load_dataset(path=args.path)
Collaborator (author):

There are many ways to load data.

Collaborator (author):

Support legacy users passing in a config, converting their data into the new format.

else:
    raise RuntimeError

model: BaseAlgorithm = BUILDER.build(config.model)
LZHgrla (Collaborator) commented Apr 8, 2024:

This step automatically downloads the un-finetuned base model; we should find a way to avoid that.
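
One possible way to avoid the download, sketched with HF + accelerate primitives (assumes an HF-style llm inside the config; the repo id is only an example): instantiate the architecture on the meta device, then load real weights from the finetuned checkpoint afterwards.

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained('internlm/internlm2-chat-1_8b',
                                 trust_remote_code=True)
with init_empty_weights():
    # No weight files are fetched here; parameters live on the meta device.
    model = AutoModelForCausalLM.from_config(cfg, trust_remote_code=True)
# The finetuned state_dict from `checkpoint` would then be loaded into `model`.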

    assert eos_token_ids is not None, \
        'Please set eos_token for Qwen tokenizer!'
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
    eos_token_ids = tokenizer.eos_token_id
Collaborator:

Is there any difference between this if branch and the else below?


shard = converted.select(range(begin, end)).to_list()
with open(save_path, 'w') as f:
    json.dump(shard, f)
Collaborator:

Suggested change
- json.dump(shard, f)
+ json.dump(shard, f, indent=2)

chat_template: Union[Dict, ChatTemplate],
sample_ratio: Union[float, List[float]] = 1.0,
max_length: int = 2048,
pack_to_max_length: bool = True,
Collaborator:

Add a shuffle_before_pack parameter?

Collaborator (author):

The current default is already shuffle-before-pack. Is there any scenario where you would not want to shuffle before packing?

Collaborator:

In pretraining scenarios, samples from the same context are often consecutive, and some users will want to keep them adjacent.
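
A sketch of how the flag could thread into the packing step. The greedy packing below is illustrative, not the PR's actual algorithm, and over-long samples are not split:

import random

def pack_samples(samples, max_length, shuffle_before_pack=True, seed=0):
    # `samples` is a list of token-id lists. Skipping the shuffle keeps
    # consecutive pretraining chunks from the same context adjacent.
    if shuffle_before_pack:
        random.Random(seed).shuffle(samples)
    packs, current = [], []
    for ids in samples:
        if current and len(current) + len(ids) > max_length:
            packs.append(current)
            current = []
        current.extend(ids)
    if current:
        packs.append(current)
    return packs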

Comment on lines +190 to +195
if isinstance(sample_ratio, (list, tuple)):
    if len(sample_ratio) != len(data_files):
        raise ValueError(f'The length of `sample_ratio` '
                         f'({len(sample_ratio)}) should be the same '
                         f'as the length of `data_files` '
                         f'({len(data_files)})')
Collaborator:

When data_files is None and the data is passed via data_dir, this line raises an error. Consider converting data_dir into data_files before this point.
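
For example, the resolution could happen once, up front (a sketch; the helper name and the .json filter are assumptions):

import os

def resolve_data_files(data_files=None, data_dir=None):
    # Turn data_dir into an explicit, sorted data_files list so the
    # later sample_ratio length check always sees a list.
    if data_files is None and data_dir is not None:
        data_files = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir) if name.endswith('.json'))
    return data_files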

return dataset

def filter_non_labels_data(self, dataset: List[dict]) -> List[dict]:
"""Filter the data which all labels are ignore.
LZHgrla (Collaborator) commented Apr 16, 2024:

Suggested change
- """Filter the data which all labels are ignore.
+ """Filter out data that do not contain valid labels.

Comment on lines +447 to +448
f'Filtered {ori_samples - new_samples} samples '
'(all labels are ignore)',
Collaborator:

Suggested change
- f'Filtered {ori_samples - new_samples} samples '
- '(all labels are ignore)',
+ f'Filtered {ori_samples - new_samples} samples '
+ 'that do not contain valid labels.',

Comment on lines +224 to +227
if torch.cuda.is_bf16_supported():
    kwargs.update(torch_dtype=torch.bfloat16)
else:
    kwargs.update(torch_dtype=torch.float16)
LZHgrla (Collaborator) commented Apr 16, 2024:

If DeepSpeed is not used and the plain AMP optimizer is used directly, it errors out:

# with bf16
RuntimeError: "_amp_foreach_non_finite_check_and_unscale_cuda" not implemented for 'BFloat16'
# with fp16
ValueError: Attempting to unscale FP16 gradients.
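
A defensive sketch of the dtype choice (the use_deepspeed switch is hypothetical): keep the weights in fp32 when the native AMP optim wrapper is in charge, since its GradScaler can only unscale fp32 gradients, and reserve bf16/fp16 weights for DeepSpeed.

import torch

def resolve_training_dtype(use_deepspeed: bool):
    if use_deepspeed:
        return (torch.bfloat16 if torch.cuda.is_bf16_supported()
                else torch.float16)
    # Native AMP: fp32 weights + autocast avoids both errors quoted above.
    return torch.float32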

runner.logger.info(f'(ChatHook {position}){answer}')

def before_train(self, runner: Union[Runner, FlexibleRunner]):
    runner.logger.info('before_train in EvaluateChatHook.')
Collaborator:

Suggested change
- runner.logger.info('before_train in EvaluateChatHook.')
+ runner.logger.info('before_train in ChatHook.')

LZHgrla (Collaborator) commented Apr 16, 2024:

Training cannot be launched with the config saved in work_dirs; I am currently stuck at TypeError: collate_fn should be a dict or callable object, but got xtuner.model.TextFinetune.dataloader_collate_fn
Is it worth fixing? (It seems hard to fix, since this logic belongs to the dataloader and is presumably wrapped inside mmengine.)
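
One conceivable workaround (a sketch): make the collate function a module-level callable so the dumped config references an importable path instead of a bound method. The function name below is hypothetical, and the config-side wiring is assumed:

from xtuner.dataset.hybrid.collate import text_collate_fn  # hypothetical name

train_dataloader = dict(
    batch_size=1,
    dataset=...,  # unchanged
    collate_fn=dict(type=text_collate_fn),
)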

super().__init__()

self.llm = llm
self.llm.cuda()
Collaborator:

For a quantized model, wouldn't calling .cuda() directly cause problems?
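
A guard along these lines might sidestep it (a sketch; is_loaded_in_4bit/is_loaded_in_8bit are the attributes transformers sets on bitsandbytes-quantized models):

def maybe_cuda(llm):
    quantized = (getattr(llm, 'is_loaded_in_4bit', False)
                 or getattr(llm, 'is_loaded_in_8bit', False))
    # bitsandbytes has already dispatched quantized weights to their
    # device; calling .cuda() on them can error or break the layout.
    return llm if quantized else llm.cuda()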

# PART 2 Model & Tokenizer #
#######################################################################
model = dict(
    type=TextFinetune,
Collaborator:

Pass in use_varlen_attn.
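
That is, something like the following in the config (a sketch; it assumes TextFinetune accepts a use_varlen_attn argument):

from xtuner.model import TextFinetune

use_varlen_attn = True

model = dict(
    type=TextFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=...,  # unchanged
)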

HIT-cwh (Collaborator) commented Apr 16, 2024:

The changes from PR 567 need to be synced into this PR.
