
Commit

BaseTune and BaseEncoder
pppppM committed Mar 28, 2024
1 parent e571688 commit 654f1b1
Showing 9 changed files with 606 additions and 141 deletions.
59 changes: 30 additions & 29 deletions xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json
@@ -2,61 +2,62 @@
"messages": [
{"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"},
{
"role": "user",
"content": "Please help me process and visualize this dataset.",
"role": "user",
"content": "Please help me process and visualize this dataset.",
"files": [{"path": "data.csv", "size": "10K"}]
},
},
{
"role": "assistant",
"content": "I have processed the data and visualized it for you.",
"role": "assistant",
"content": "I have processed the data and visualized it for you.",
"code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': '\n\nRainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n title='Rainfall vs Wind Direction',\n template='plotly_dark',\n width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='<b>Date: %{text}</b><br>Wind Direction 9am: %{x}<br>Rainfall: %{y}<br>Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```"
},
},
{
"role": "code_interpreter",
"role": "code_interpreter",
"content": "![image](xxx.png)"
},
},
{
"role": "assistant",
"role": "assistant",
"content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background."
},
},
{
"role": "user",
"role": "user",
"content": "I want to know today's weather in Shanghai"
},
{
"role": "assistant",
"content": "Sure, I will search for the weather of Shanghai.",
"role": "assistant",
"content": "Sure, I will search for the weather of Shanghai.",
"function_call": {
"name": "get_current_weather",
"name": "get_current_weather",
"parameters": {"location": "Shanghai"}
}
},
},
{
"role": "function",
"name": "get_current_weather",
"role": "function",
"name": "get_current_weather",
"content": "{'temperature': 22}"
},
},
{
"role": "assistant",
"role": "assistant",
"content": "The weather in Shanghai is 22 celsius"
}
],
],

"functions": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"type": "object",
"properties": {
"location": {
"type": "string",
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
"unit": {"type": "string"}},
"unit": {"type": "string"}},
"required": ["location"]
}
}
}
],

"code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. formats)"}
],

"code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. formats)"
}
26 changes: 15 additions & 11 deletions xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py
@@ -2,28 +2,32 @@

from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages

chat_template = HybridChatTemplate(
    system='<|im_start|>system\n{system}<|im_end|>\n',
    user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
    assistant='{assistant}<|im_end|>\n',
    stop_words=['<|im_end|>'],
    image_token='<image>',
    files='<|im_start|>user name=file\n{files}<|im_end|>\n',
    function_call=
    '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
    function_result=
    '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
    functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n',
    code_interpreter_call=
    '{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
    code_interpreter_result=
    '<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
    code_interpreter=
    '<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n')

agent_data = json.load(open('agent.json'))

msg = TrainingHybridChatMessages.from_dict(agent_data)
print(msg.apply_chat_template(chat_template))

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'internlm/internlm2-chat-7b', trust_remote_code=True)
print(msg.tokenize(tokenizer, chat_template))
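To make the template above concrete, here is a small self-contained sketch (no XTuner imports) that composes one system/user/assistant turn with the same format strings. It assumes the rendered prompt is simply these pieces concatenated in order, which is an illustration rather than a statement about HybridChatTemplate's exact behaviour:

    system = '<|im_start|>system\n{system}<|im_end|>\n'
    user = '<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n'
    assistant = '{assistant}<|im_end|>\n'

    # Render one turn of the agent.json conversation by hand.
    prompt = (
        system.format(system='You are InternLM2-Chat, a harmless AI assistant') +
        user.format(user="I want to know today's weather in Shanghai") +
        assistant.format(assistant='The weather in Shanghai is 22 celsius'))
    print(prompt)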
20 changes: 20 additions & 0 deletions xtuner/model/auto.py
@@ -0,0 +1,20 @@
from mmengine import Config

from xtuner.model.base import BaseTune
from xtuner.registry import BUILDER


class AutoModel:

    @classmethod
    def from_config(cls, config: str):
        config = Config.fromfile(config)
        model: BaseTune = BUILDER.build(config.model)
        return model

    @classmethod
    def from_pretrained(cls, config: str, checkpoint: str):
        config = Config.fromfile(config)
        model: BaseTune = BUILDER.build(config.model)
        model.load_checkpoint(checkpoint)
        return model
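For context, a brief usage sketch of the new AutoModel helper. The config path and checkpoint file below are hypothetical placeholders, and it assumes the config's model section builds a BaseTune subclass, as auto.py expects:

    from xtuner.model.auto import AutoModel

    # Build the model described by an XTuner config file (hypothetical path).
    model = AutoModel.from_config('configs/internlm2_chat_1_8b_hybrid.py')

    # Or build it and restore trained weights in one call (hypothetical checkpoint).
    model = AutoModel.from_pretrained('configs/internlm2_chat_1_8b_hybrid.py',
                                      'work_dirs/iter_5000.pth')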
71 changes: 71 additions & 0 deletions xtuner/model/base.py
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod

from mmengine.model import BaseModel

from xtuner.types import HybridChatMessages, HybridChatTemplate


class BaseTune(BaseModel):

    def __init__(self):
        super().__init__()

    def init_weights(self):
        """Overload the parent class method.

        To avoid overwriting the loaded weights, this is a no-op.
        """
        pass

    def avoid_override_weights(self):
        self._is_init = True

    @property
    @abstractmethod
    def chat_template(self) -> HybridChatTemplate:
        pass

    @property
    @abstractmethod
    def llm(self):
        pass

    @property
    @abstractmethod
    def tokenizer(self):
        pass

    @abstractmethod
    def gradient_checkpointing_enable(self):
        pass

    def forward(self, data, data_samples=None, mode='loss'):
        """Overload the parent class method; only training is supported."""

        if mode == 'loss':
            return self.compute_loss(data)
        else:
            raise NotImplementedError(
                f"{type(self)}'s forward is only supported for use during "
                'training. If you want to get predictions or chat, please '
                "directly use `llm`'s forward.")

    @abstractmethod
    def chat(self, messages: HybridChatMessages, sample_params, streamer):
        pass

    @abstractmethod
    def save_checkpoint(self, *args, **kwargs):
        pass

    @abstractmethod
    def load_checkpoint(self, *args, **kwargs) -> 'BaseTune':
        pass

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
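For orientation, a hypothetical sketch of how a concrete subclass might satisfy this interface. The class name ToyFinetune, the Hugging Face model id, and the simplified checkpoint logic are illustrative assumptions, not code from this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from xtuner.model.base import BaseTune


    class ToyFinetune(BaseTune):
        """Illustrative only: wraps a Hugging Face causal LM in BaseTune."""

        def __init__(self, model_name='internlm/internlm2-chat-1_8b'):
            super().__init__()
            self._llm = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True)
            self._tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True)
            self._chat_template = None  # would normally hold a HybridChatTemplate

        @property
        def chat_template(self):
            return self._chat_template

        @property
        def llm(self):
            return self._llm

        @property
        def tokenizer(self):
            return self._tokenizer

        def gradient_checkpointing_enable(self):
            self._llm.gradient_checkpointing_enable()

        def compute_loss(self, data):
            # `forward(..., mode='loss')` in BaseTune dispatches here.
            return {'loss': self._llm(**data).loss}

        def chat(self, messages, sample_params=None, streamer=None):
            raise NotImplementedError

        def save_checkpoint(self, path, *args, **kwargs):
            torch.save(self._llm.state_dict(), path)

        def load_checkpoint(self, path, *args, **kwargs) -> 'BaseTune':
            self._llm.load_state_dict(torch.load(path, map_location='cpu'))
            return self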
2 changes: 2 additions & 0 deletions xtuner/model/encoders/__init__.py
@@ -0,0 +1,2 @@
from .base import EncoderWrapper
from .llava import LlavaEncoderWrapper
53 changes: 53 additions & 0 deletions xtuner/model/encoders/base.py
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import List, Union

import torch
from PIL import Image
from torch import nn

_ImageType = Union[str, Image.Image]


class EncoderWrapper(nn.Module):

    def __init__(self):
        super().__init__()

    @property
    @abstractmethod
    def encoder(self):
        pass

    @property
    @abstractmethod
    def projector(self):
        pass

    @abstractmethod
    def post_init_proj(self, llm):
        pass

    @abstractmethod
    def preprocess(self, image: _ImageType) -> torch.Tensor:
        pass

    @abstractmethod
    def batch_infer(self, images: List[_ImageType]) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def gradient_checkpointing_enable(self):
        pass

    @abstractmethod
    def save_checkpoint(self, *args, **kwargs):
        pass

    @abstractmethod
    def load_checkpoint(self, *args, **kwargs) -> 'EncoderWrapper':
        pass

    @abstractmethod
    def only_build_processor(self, *args, **kwargs):
        pass
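As a rough illustration of how this interface could be implemented, here is a hedged sketch of a CLIP-based wrapper in the spirit of the LlavaEncoderWrapper imported in __init__.py. The class name ToyClipEncoderWrapper, the CLIP checkpoint id, the single-linear projector, and the checkpoint handling are assumptions for illustration and do not reflect the actual llava.py added in this commit:

    from typing import List

    import torch
    from PIL import Image
    from torch import nn
    from transformers import CLIPImageProcessor, CLIPVisionModel

    from xtuner.model.encoders.base import EncoderWrapper, _ImageType


    class ToyClipEncoderWrapper(EncoderWrapper):
        """Illustrative only: CLIP vision tower plus a linear projector."""

        def __init__(self, clip_name='openai/clip-vit-large-patch14-336',
                     llm_hidden_size=2048):
            super().__init__()
            self._encoder = CLIPVisionModel.from_pretrained(clip_name)
            self._processor = CLIPImageProcessor.from_pretrained(clip_name)
            self._projector = nn.Linear(self._encoder.config.hidden_size,
                                        llm_hidden_size)

        @property
        def encoder(self):
            return self._encoder

        @property
        def projector(self):
            return self._projector

        def post_init_proj(self, llm):
            # Re-build the projector once the LLM hidden size is known.
            self._projector = nn.Linear(self._encoder.config.hidden_size,
                                        llm.config.hidden_size)

        def preprocess(self, image: _ImageType) -> torch.Tensor:
            if isinstance(image, str):
                image = Image.open(image)
            return self._processor(image, return_tensors='pt')['pixel_values'][0]

        def batch_infer(self, images: List[_ImageType]) -> List[torch.Tensor]:
            pixels = torch.stack([self.preprocess(img) for img in images])
            feats = self._encoder(pixels).last_hidden_state
            return list(self._projector(feats))

        def gradient_checkpointing_enable(self):
            self._encoder.gradient_checkpointing_enable()

        def save_checkpoint(self, path, *args, **kwargs):
            torch.save(self.state_dict(), path)

        def load_checkpoint(self, path, *args, **kwargs) -> 'EncoderWrapper':
            self.load_state_dict(torch.load(path, map_location='cpu'))
            return self

        def only_build_processor(self, *args, **kwargs):
            return self._processor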
