[Feature] Integrated Training and Inference -- Part 1 #532

Open · wants to merge 32 commits into base: main
Changes from 28 commits
Commits (32)
360391a
[Feature] Support Sequence parallel (#456)
HIT-cwh Mar 22, 2024
5fffd8c
integrated norm chat finetune and inference
pppppM Mar 29, 2024
0f31481
Merge branch 'main' into refactor-llm
pppppM Mar 29, 2024
f1111a9
Merge branch 'main' into refactor-llm
pppppM Mar 29, 2024
7edb76d
remove encoder
pppppM Mar 29, 2024
c5f38d2
add open-source dataset convert tool
pppppM Mar 29, 2024
e0ab003
fix shard count
pppppM Mar 29, 2024
b0e71b1
fix dataset bugs
pppppM Mar 29, 2024
5da532c
add alpaca example
pppppM Mar 29, 2024
5cfd71f
refactored the inheritance hierarchy
pppppM Apr 1, 2024
2e1d238
adjust dir structure
pppppM Apr 1, 2024
046e943
add BaseAlogrithm docstrings
pppppM Apr 1, 2024
68afab4
add dataset docstring
pppppM Apr 2, 2024
cf8e8af
add pack dataset docstrings
pppppM Apr 7, 2024
9dc1142
remove old collate fn
pppppM Apr 7, 2024
662cebb
Merge branch 'main' into refactor-llm
pppppM Apr 7, 2024
7ac1e0f
add new chat hook
pppppM Apr 7, 2024
0ad84f2
add gradient disable interface
pppppM Apr 7, 2024
e85d176
add llava dataset example
pppppM Apr 7, 2024
9f39627
batch_infer is no longer an abstract method
pppppM Apr 7, 2024
c0655a1
support auto model
pppppM Apr 9, 2024
fd9ecca
rename
pppppM Apr 9, 2024
15860f9
update auto model
pppppM Apr 11, 2024
7236d40
refactor dataset
pppppM Apr 15, 2024
4db2955
enhance dataset convert
pppppM Apr 15, 2024
aad9ee3
remove useless code
pppppM Apr 15, 2024
8daafcb
diff files support diff sample ratios
pppppM Apr 15, 2024
df60d91
unified naming
pppppM Apr 15, 2024
6185c9b
Merge branch 'main' of github.com:InternLM/xtuner into refactor-llm
HIT-cwh Apr 16, 2024
ca272bf
support sp in TextFinetune
HIT-cwh Apr 16, 2024
b658d76
Merge pull request #2 from HIT-cwh/refactor-llm
pppppM Apr 19, 2024
d2f1002
Merge branch 'main' into refactor-llm
pppppM May 22, 2024
Empty file added xtuner/chat/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions xtuner/chat/backend/__init__.py
@@ -0,0 +1,4 @@
from .huggingface import HFBackend
from .lmdeploy import LMDeployBackend

__all__ = ['HFBackend', 'LMDeployBackend']
30 changes: 30 additions & 0 deletions xtuner/chat/backend/base.py
@@ -0,0 +1,30 @@
from abc import abstractmethod
from typing import List, Optional

from xtuner.chat.streamer import SteamerType
from xtuner.types import (ChatBackendProtocol, ChatMessages, ChatTemplate,
SampleParams)


class BaseBackend(ChatBackendProtocol):

@property
def chat_template(self) -> ChatTemplate:
pass

@abstractmethod
def create_streamer(self, iterable: bool = False) -> SteamerType:
pass

@abstractmethod
def chat(self,
messages: ChatMessages,
sample_params: Optional[SampleParams] = None,
streamer: Optional[SteamerType] = None):
pass

@abstractmethod
def batch_infer(self,
messages: List[ChatMessages],
sample_params: Optional[SampleParams] = None):
pass
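
For orientation, here is a minimal sketch (not part of this PR) of a concrete backend that satisfies the interface declared above. The echo behaviour is purely illustrative; the real backends (HFBackend, LMDeployBackend) run an actual LLM. Only APIs visible in this diff are assumed.

# Hypothetical sketch, not part of this PR: a trivial backend fulfilling the
# BaseBackend contract.
from typing import List, Optional

from xtuner.chat.backend.base import BaseBackend
from xtuner.chat.streamer import SteamerType
from xtuner.types import ChatMessages, ChatTemplate, SampleParams


class EchoBackend(BaseBackend):

    def __init__(self, chat_template: ChatTemplate) -> None:
        super().__init__()
        self._chat_template = chat_template

    @property
    def chat_template(self) -> ChatTemplate:
        return self._chat_template

    def create_streamer(self, iterable: bool = False) -> SteamerType:
        # A real backend would return an HF or LMDeploy streamer here.
        raise NotImplementedError

    def chat(self,
             messages: ChatMessages,
             sample_params: Optional[SampleParams] = None,
             streamer: Optional[SteamerType] = None) -> str:
        # Echo the templated prompt back instead of generating a reply.
        return messages.get_prompt(self.chat_template)

    def batch_infer(self,
                    messages: List[ChatMessages],
                    sample_params: Optional[SampleParams] = None) -> List[str]:
        return [self.chat(m, sample_params) for m in messages]
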
153 changes: 153 additions & 0 deletions xtuner/chat/backend/huggingface.py
@@ -0,0 +1,153 @@
from typing import List, Optional

import torch
from peft import PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)
from transformers import GenerationConfig as HFGenerationConfig
from transformers import PreTrainedModel, PreTrainedTokenizer

from xtuner.chat.streamer import HFTextIteratorStreamer, HFTextStreamer
from xtuner.model.utils import LoadWoInit
from xtuner.tools.utils import get_stop_criteria
from xtuner.types import ChatMessages, ChatTemplate, SampleParams
from .base import BaseBackend


class HFBackend(BaseBackend):

def __init__(
self,
chat_template: ChatTemplate,
llm: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()

self.llm = llm
self.llm.cuda()
Review comment (Collaborator): If the model is quantized, won't calling .cuda() on it directly cause problems?

self.tokenizer = tokenizer

self._chat_template = chat_template

@property
def chat_template(self) -> ChatTemplate:
return self._chat_template

@property
def eos_token_id(self):
return self.tokenizer.eos_token_id

@property
def pad_token_id(self):
# Fall back to the eos token when the tokenizer has no pad token,
# so generation is not handed pad_token_id=None.
if self.tokenizer.pad_token_id is not None:
return self.tokenizer.pad_token_id
else:
return self.tokenizer.eos_token_id

def build_llm_and_tokenizer(self,
model_name_or_path,
adapter=None,
bits=None):

if bits is None:
quantization_config = None
load_in_8bit = False
elif bits == 4:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')
load_in_8bit = False
elif bits == 8:
quantization_config = None
load_in_8bit = True
else:
raise ValueError(
f'Unsupported `bits` value: {bits}, expected None, 4 or 8.')

tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True,
encode_special_tokens=True)

with LoadWoInit():
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map='auto',
load_in_8bit=load_in_8bit,
quantization_config=quantization_config,
trust_remote_code=True,
torch_dtype=torch.float16)

if adapter is not None:
model = PeftModel.from_pretrained(model, adapter)

model.eval()
return model, tokenizer

def create_streamer(self, iterable=False):
if iterable:
return HFTextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
chat_template=self.chat_template)
else:
return HFTextStreamer(
self.tokenizer,
skip_prompt=True,
chat_template=self.chat_template)

def parse_sample_params(self, params: Optional[SampleParams]) -> 'tuple[HFGenerationConfig, List[str]]':

if params is None:
params = SampleParams()

hf_gen_config = HFGenerationConfig(
max_new_tokens=params.max_new_tokens,
do_sample=params.temperature > 0,
temperature=params.temperature,
top_k=params.top_k,
top_p=params.top_p,
repetition_penalty=params.repetition_penalty,
seed=params.seed,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id)

# Copy so the caller's SampleParams.stop_words is not mutated in place.
stop_words = list(params.stop_words)
stop_words.extend(self.chat_template.stop_words)

return hf_gen_config, stop_words

def chat(self,
messages: ChatMessages,
streamer=None,
sample_params: Optional[SampleParams] = None):

prompt = messages.get_prompt(self.chat_template)
ids = self.tokenizer.encode(prompt, return_tensors='pt')

hf_gen_config, stop_words = self.parse_sample_params(sample_params)

stop_criteria = get_stop_criteria(
tokenizer=self.tokenizer, stop_words=stop_words)

generate_output = self.llm.generate(
inputs=ids.cuda(),
streamer=streamer,
generation_config=hf_gen_config,
stopping_criteria=stop_criteria)

output = self.tokenizer.decode(
generate_output[0][len(ids[0]):], skip_special_tokens=True)

for word in stop_words:
# Strip a trailing stop word; str.rstrip would strip characters, not suffixes.
if output.endswith(word):
output = output[:-len(word)]

return output

def batch_infer(self,
messages: List[ChatMessages],
sample_params: Optional[SampleParams] = None):
raise NotImplementedError
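
For context, a rough usage sketch of HFBackend (not part of this PR). The ChatTemplate and ChatMessages constructors live in xtuner.types and are not shown in this diff, so they are left as placeholders, and the checkpoint name is only an example.

# Hypothetical usage sketch. ChatTemplate / ChatMessages construction below is
# a placeholder; their real fields are defined in xtuner.types, outside this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.chat.backend import HFBackend
from xtuner.types import ChatMessages, ChatTemplate

model_path = 'internlm/internlm2-chat-7b'  # any HF causal LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.float16)

chat_template = ChatTemplate(...)  # placeholder: fill in the template fields
messages = ChatMessages(...)       # placeholder: build the conversation

backend = HFBackend(chat_template, llm, tokenizer)
streamer = backend.create_streamer(iterable=False)
print(backend.chat(messages, streamer=streamer))
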
3 changes: 3 additions & 0 deletions xtuner/chat/backend/lmdeploy/__init__.py
@@ -0,0 +1,3 @@
from .backend import LMDeployBackend

__all__ = ['LMDeployBackend']
27 changes: 27 additions & 0 deletions xtuner/chat/backend/lmdeploy/_engine.py
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.serve.async_engine import AsyncEngine

from xtuner.types import ChatMessages, ChatTemplate


class _AsyncEngine(AsyncEngine):
"""Async inference engine."""

def __init__(self, chat_template: ChatTemplate, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
assert self.model_name == 'base'
self.chat_template = chat_template

async def _get_prompt_input(self, prompt: ChatMessages,
do_preprocess: bool, sequence_start: bool):
"""get input_ids, embeddings and offsets."""

decorated = prompt.get_prompt(self.chat_template)

results = {}

input_ids = self.tokenizer.encode(decorated, add_bos=sequence_start)

results['input_ids'] = input_ids
results['prompt'] = decorated
return results
94 changes: 94 additions & 0 deletions xtuner/chat/backend/lmdeploy/backend.py
@@ -0,0 +1,94 @@
import asyncio
import os
from typing import List, Optional, Union

from lmdeploy.utils import get_logger

from xtuner.types import ChatMessages, ChatTemplate, SampleParams
from ...streamer import LMDeployTextIteratorStreamer, LMDeployTextStreamer
from ..base import BaseBackend
from ._engine import _AsyncEngine

os.environ['TM_LOG_LEVEL'] = 'ERROR'
logger = get_logger('lmdeploy')
logger.setLevel('ERROR')

_StreamerType = Union[LMDeployTextStreamer, LMDeployTextIteratorStreamer]


class LMDeployBackend(BaseBackend):

def __init__(self, chat_template, llm_name_or_path) -> None:
super().__init__()

self._engine = _AsyncEngine(
chat_template, model_path=llm_name_or_path, model_name='base')

self._chat_template = chat_template

# Counter used as the session id passed to the engine in chat().
self.session_id = 0

@property
def chat_template(self) -> ChatTemplate:
return self._chat_template

def create_streamer(self, iterable=False):

if iterable:
return LMDeployTextIteratorStreamer()
else:
return LMDeployTextStreamer()

def parse_sample_params(self, params: Optional[SampleParams]):

if params is None:
params = SampleParams()

# Copy so the caller's SampleParams.stop_words is not mutated in place.
stop_words = list(params.stop_words)
stop_words.extend(self.chat_template.stop_words)

from lmdeploy.messages import GenerationConfig as LMDGenerationConfig
lmd_gen_config = LMDGenerationConfig(
max_new_tokens=params.max_new_tokens,
temperature=params.temperature,
top_k=params.top_k,
top_p=params.top_p,
repetition_penalty=params.repetition_penalty,
random_seed=params.seed,
stop_words=stop_words)

return lmd_gen_config

def chat(self,
messages: ChatMessages,
streamer: Optional[_StreamerType] = None,
sample_params: Optional[SampleParams] = None):

lmd_gen_config = self.parse_sample_params(sample_params)

# Each chat call gets a fresh session id (counter initialised in __init__).
self.session_id += 1

generator = self._engine.generate(
messages, self.session_id, gen_config=lmd_gen_config)

async def get_response():
out = ''
async for res in generator:
out += res.response
if streamer:
streamer.put(res.response)
if streamer:
streamer.end()
return out

loop = asyncio.new_event_loop()
try:
response = loop.run_until_complete(get_response())
finally:
# Close the temporary event loop so repeated chat() calls do not leak loops.
loop.close()
return response

def batch_infer(self,
messages: List[ChatMessages],
sample_params: Optional[SampleParams] = None):

lmd_gen_config = self.parse_sample_params(sample_params)

results = self._engine.batch_infer(messages, gen_config=lmd_gen_config)

return [r.text for r in results]
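
Likewise, a rough usage sketch of LMDeployBackend (not part of this PR); as above, the ChatTemplate and ChatMessages constructions are placeholders and the model path is only an example.

# Hypothetical usage sketch. ChatTemplate / ChatMessages construction is a
# placeholder; their fields are defined in xtuner.types, outside this diff.
from xtuner.chat.backend import LMDeployBackend
from xtuner.types import ChatMessages, ChatTemplate

chat_template = ChatTemplate(...)  # placeholder
backend = LMDeployBackend(chat_template, 'internlm/internlm2-chat-7b')

# Single conversation, optionally streamed to stdout.
messages = ChatMessages(...)       # placeholder
streamer = backend.create_streamer(iterable=False)
print(backend.chat(messages, streamer=streamer))

# Batched inference over several conversations.
batch = [ChatMessages(...), ChatMessages(...)]  # placeholders
print(backend.batch_infer(batch))
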
12 changes: 12 additions & 0 deletions xtuner/chat/streamer/__init__.py
@@ -0,0 +1,12 @@
from typing import Union

from .huggingface import HFTextIteratorStreamer, HFTextStreamer
from .lmdeploy import LMDeployTextIteratorStreamer, LMDeployTextStreamer

SteamerType = Union[HFTextIteratorStreamer, HFTextStreamer,
LMDeployTextIteratorStreamer, LMDeployTextStreamer]

__all__ = [
'HFTextIteratorStreamer', 'HFTextStreamer', 'LMDeployTextIteratorStreamer',
'LMDeployTextStreamer'
]
37 changes: 37 additions & 0 deletions xtuner/chat/streamer/huggingface.py
@@ -0,0 +1,37 @@
from transformers import TextIteratorStreamer, TextStreamer
from transformers.models.auto import AutoTokenizer


class HFTextIteratorStreamer(TextIteratorStreamer):

def __init__(self,
tokenizer: AutoTokenizer,
skip_prompt: bool = False,
timeout=None,
chat_template=None,
**decode_kwargs):
super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
self.chat_template = chat_template

def on_finalized_text(self, text: str, stream_end: bool = False):

if self.chat_template is not None:
for word in self.chat_template.stop_words:
# Strip a trailing stop word; rstrip would strip characters, not suffixes.
if text.endswith(word):
text = text[:-len(word)]
super().on_finalized_text(text, stream_end)


class HFTextStreamer(TextStreamer):

def __init__(self,
tokenizer: AutoTokenizer,
skip_prompt: bool = False,
chat_template=None,
**decode_kwargs):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.chat_template = chat_template

def on_finalized_text(self, text: str, stream_end: bool = False):

if self.chat_template is not None:
for word in self.chat_template.stop_words:
# Strip a trailing stop word; rstrip would strip characters, not suffixes.
if text.endswith(word):
text = text[:-len(word)]
super().on_finalized_text(text, stream_end)
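
For reference, a short sketch of how HFTextIteratorStreamer is typically consumed: generation runs in a background thread while the caller iterates decoded chunks, exactly as with transformers' TextIteratorStreamer. The llm, tokenizer, chat_template and prompt below are assumed to exist already, e.g. built as in the HFBackend sketch above.

# Hypothetical sketch, not part of this PR. `llm`, `tokenizer`, `chat_template`
# and `prompt` are assumed to exist already (see the HFBackend sketch above).
from threading import Thread

from xtuner.chat.streamer import HFTextIteratorStreamer

streamer = HFTextIteratorStreamer(
    tokenizer, skip_prompt=True, chat_template=chat_template)
inputs = tokenizer(prompt, return_tensors='pt').to(llm.device)

# Run generation in a worker thread; the streamer yields text as it arrives.
thread = Thread(
    target=llm.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256))
thread.start()
for chunk in streamer:
    print(chunk, end='', flush=True)
thread.join()
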