Commit

Merge branch 'main' into refactor-llm
pppppM committed Mar 29, 2024
2 parents 0f31481 + 0b5708c commit 3724180
Showing 4 changed files with 17 additions and 26 deletions.
5 changes: 4 additions & 1 deletion requirements/runtime.txt
@@ -19,5 +19,8 @@ tiktoken
 torch<=2.1.2
 torchvision<=0.16.2
 # Minimum 4.36.0 to support `Cache` data structure used by KV Cache
-transformers>=4.36.0
+# Registering a causal mask in `LlamaModel` is not friendly for very large
+# `max_position_embeddings`. Refer to
+# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
+transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2
 transformers_stream_generator
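
The new specifier keeps the 4.36 floor required by the `Cache` data structure but skips 4.38.0-4.38.2, where `LlamaModel` registers a causal-mask buffer whose size grows with `max_position_embeddings`. A minimal sketch, not part of this commit, of how the same requirement string could be checked at runtime with `packaging` (assumed to be available, as it normally is alongside transformers):

# Sketch only: verify the installed transformers version against the same
# specifier that requirements/runtime.txt now declares.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

import transformers

SPEC = SpecifierSet('>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2')
installed = Version(transformers.__version__)

if installed not in SPEC:
    raise RuntimeError(
        f'transformers=={installed} is unsupported here; install a version '
        f'matching "{SPEC}" (4.38.0-4.38.2 register a large causal mask '
        'in LlamaModel).')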
4 changes: 2 additions & 2 deletions xtuner/model/modules/dispatch/__init__.py
@@ -305,12 +305,12 @@ def dispatch_modules(model, use_varlen_attn=False):
         dispatch_internlm2_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm2_rmsnorm_forward(model)
-        # replace_internlm2_rote(model)
+        replace_internlm2_rote(model)
     elif 'internlm' in model_name:
         dispatch_internlm_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm_rmsnorm_forward(model)
-        # replace_internlm_rote(model)
+        replace_internlm_rote(model)
     elif 'llama' in model_name:
         dispatch_llama_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
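
With this merge the rotary-embedding replacement for InternLM/InternLM2 is switched back on rather than left commented out. As a rough, hypothetical illustration of the pattern such `replace_*_rote` helpers follow (the real xtuner implementations may differ in names and in the module they install), a generic submodule swap looks like this:

# Hypothetical sketch of a rotary-embedding swap; class and helper names are
# illustrative, not xtuner's actual API.
from typing import Callable

import torch.nn as nn


def replace_child_modules(model: nn.Module, target_cls_name: str,
                          build_replacement: Callable[[nn.Module], nn.Module]):
    """Replace every child whose class name matches `target_cls_name`."""
    for parent in model.modules():
        for name, child in parent.named_children():
            if type(child).__name__ == target_cls_name:
                # Overwrite the attribute on the parent so subsequent forward
                # passes use the new rotary embedding implementation.
                setattr(parent, name, build_replacement(child))

# Illustrative usage:
# replace_child_modules(model, 'InternLM2RotaryEmbedding',
#                       lambda old: MyLongContextRotaryEmbedding(old.dim))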
25 changes: 10 additions & 15 deletions xtuner/model/modules/dispatch/llama.py
@@ -234,26 +234,18 @@ def llama_attn_forward_legacy(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
            Optional[Tuple[torch.Tensor]]]:
-    # LlamaFlashAttention2 attention does not support output_attentions
+    # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501
     if 'padding_mask' in kwargs:
         warnings.warn(
-            'Passing `padding_mask` is deprecated and will be removed in v4.37'
-            ' Please make sure use `attention_mask` instead.`')
-
-        # overwrite attention_mask with padding_mask
-        attention_mask = kwargs.pop('padding_mask')
-
-    output_attentions = False
+            'Passing `padding_mask` is deprecated and will be removed in '
+            'v4.37. Please make sure use `attention_mask` instead.`')
 
     bsz, q_len, _ = hidden_states.size()
 
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)
 
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -263,6 +255,13 @@ def llama_attn_forward_legacy(
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                'The cache structure has changed since version v4.36. '
+                f'If you are using {self.__class__.__name__} '
+                'for auto-regressive decoding with k/v caching, '
+                'please make sure to initialize the attention class '
+                'with a layer index.')
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                        self.layer_idx)
     assert position_ids is not None
@@ -282,10 +281,6 @@ def llama_attn_forward_legacy(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    # repeat kv for sequence parallel
-    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
-    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
     assert SUPPORT_FLASH2
     query_states = query_states.transpose(1, 2)
     key_states = key_states.transpose(1, 2)
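
The second hunk tracks the `Cache` API that transformers introduced in 4.36: `past_key_value` is a cache object rather than a tuple, and each attention layer must know its own `layer_idx` to query or extend it, which is what the new guard enforces. A minimal sketch of that API, assuming transformers>=4.36 (tensor shapes are placeholders and signatures may shift in later releases):

# Sketch of the transformers>=4.36 Cache usage that the new layer_idx guard
# protects; shapes are placeholders.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0  # each attention layer is built with its own index

# (batch, num_kv_heads, seq_len, head_dim) tensors for one decoding step.
key_states = torch.randn(1, 8, 16, 128)
value_states = torch.randn(1, 8, 16, 128)

kv_seq_len = key_states.shape[-2]
# Without a layer index the cache cannot report how many tokens it already
# stores for this layer, which is exactly what the new ValueError rejects.
kv_seq_len += cache.get_usable_length(kv_seq_len, layer_idx)

# Appending new key/value states to the cache also needs the layer index.
key_states, value_states = cache.update(key_states, value_states, layer_idx)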
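
The last hunk removes the extra `repeat_kv_bshd` calls, leaving the standard `repeat_kv` expansion used upstream for grouped-query attention (the removed `_bshd` variant presumably did the same for a batch-seq-heads-dim layout). For reference, a sketch of that expansion in the (batch, num_kv_heads, seq_len, head_dim) layout used before the transposes above, mirroring the upstream helper:

# Sketch of the grouped-query-attention head expansion that repeat_kv
# performs, following the upstream transformers implementation.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim), equivalent to
    torch.repeat_interleave(hidden_states, n_rep, dim=1)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)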
9 changes: 1 addition & 8 deletions xtuner/model/sft.py
@@ -206,13 +206,6 @@ def _build_from_cfg_or_module(self, cfg_or_mod):
             return cfg_or_mod
         elif isinstance(cfg_or_mod, dict):
             traverse_dict(cfg_or_mod)
-            if SUPPORT_FLASH2:
-                cfg_or_mod.torch_dtype = torch.bfloat16 \
-                    if torch.cuda.is_bf16_supported() else torch.float16
-                cfg_or_mod.attn_implementation = 'flash_attention_2'
-            if max_position_embeddings is not None:
-                cfg_or_mod = self._prepare_for_long_context_training(
-                    cfg_or_mod, max_position_embeddings)
             return BUILDER.build(cfg_or_mod)
         else:
             raise NotImplementedError
@@ -265,4 +258,4 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)
         except AttributeError:
-            return getattr(self.llm, name)
\ No newline at end of file
+            return getattr(self.llm, name)
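
The deleted block in `_build_from_cfg_or_module` used to force the dtype and the flash-attention backend onto the HF config before building the LLM; after this merge that choice is presumably handled elsewhere or left to the config itself. For context, a minimal sketch of how those same options are normally passed straight to transformers (the checkpoint name is a placeholder, and flash-attn must be installed for the flash_attention_2 backend):

# Sketch only: the usual way to request bf16/fp16 plus FlashAttention-2 when
# building a causal LM with transformers.
import torch
from transformers import AutoModelForCausalLM

dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

model = AutoModelForCausalLM.from_pretrained(
    'internlm/internlm2-chat-7b',             # placeholder checkpoint
    torch_dtype=dtype,                        # same dtype rule as the removed code
    attn_implementation='flash_attention_2',  # needs flash-attn installed
    trust_remote_code=True,                   # required by some hub checkpoints
)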
