Skip to content

Commit

Permalink
remove old collate fn
Browse files Browse the repository at this point in the history
  • Loading branch information
pppppM committed Apr 7, 2024
1 parent cf8e8af commit 9dc1142
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 85 deletions.
3 changes: 1 addition & 2 deletions xtuner/dataset/hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .collate import text_collate_fn
# Copyright (c) OpenMMLab. All rights reserved.
from .dataset import TextDataset
from .mappings import map_protocol, map_sequential, openai_to_raw_training

__all__ = [
'text_collate_fn',
'TextDataset',
'map_protocol',
'map_sequential',
Expand Down
54 changes: 0 additions & 54 deletions xtuner/dataset/hybrid/collate.py

This file was deleted.

36 changes: 8 additions & 28 deletions xtuner/dataset/hybrid/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import json
import os
Expand Down Expand Up @@ -270,8 +271,8 @@ def load_dataset(
"""
if self.is_cached(cache_dir):
print_log(
f'{cache_dir} is cached dataset that will be loaded '
'directly; `data_files` and `data_dir` will become'
f'{cache_dir} is a cached dataset that will be loaded '
'directly; `data_files` and `data_dir` will become '
'invalid.',
logger='current')

Expand Down Expand Up @@ -359,7 +360,7 @@ def tokenize_dataset(self, dataset: List[dict]) -> List[dict]:
`labels` and `num_tokens`.
`input_ids` and `labels` are lists of int, and they should
have equal lengths.
`num_tokens` is an integerthe length of `input_ids`.
`num_tokens` is an integer, the length of `input_ids`.
"""

def openai_to_raw_training(item: dict) -> Dict:
Expand Down Expand Up @@ -574,48 +575,27 @@ def __getitem__(self, item: int) -> Dict[str, List]:
stop_words=['<|im_end|>'],
)

from xtuner.dataset.hybrid.mappings import openai_to_raw_training

data_dir = './llava_data/LLaVA-Instruct-150K/'
image_dir = './llava_data/llava_images/'
data_files = 'llava_v1_5_mix665k.json'

dataset = TextDataset(
'internlm/internlm2-chat-1_8b',
chat_template,
sample_ratio=1,
max_length=32 * 1024,
data_dir=data_dir,
data_files=data_files,
data_dir='converted_alpaca',
cache_dir='cached_alpaca',
pack_to_max_length=True,
mappings=[openai_to_raw_training],
num_proc=4)

print(dataset[0])

dataset.cache('cached_llava')
dataset = TextDataset(
'internlm/internlm2-chat-1_8b',
chat_template,
sample_ratio=1,
max_length=32 * 1024,
cache_dir='cached_llava',
pack_to_max_length=True,
mappings=[
openai_to_raw_training,
],
num_proc=4)
print(dataset[0])

from mmengine.dataset import DefaultSampler
from torch.utils.data import DataLoader

from xtuner.dataset.hybrid.collate import text_collate_fn
from xtuner.model import TextFinetune
loader = DataLoader(
dataset,
4,
num_workers=0,
collate_fn=text_collate_fn,
collate_fn=TextFinetune.dataloader_collate_fn,
sampler=DefaultSampler(dataset, shuffle=True))

for data in tqdm(loader):
Expand Down
1 change: 1 addition & 0 deletions xtuner/model/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .finetune import TextFinetune

__all__ = ['TextFinetune']
1 change: 0 additions & 1 deletion xtuner/model/text/finetune.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.

from collections import OrderedDict
from typing import Dict, List, Optional, Union

Expand Down

0 comments on commit 9dc1142

Please sign in to comment.