[Fix] Dispatch info log once #552

Open · wants to merge 2 commits into main
xtuner/model/modules/dispatch/__init__.py (98 changes: 65 additions & 33 deletions)
@@ -56,27 +56,26 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
                        llama_varlen_attn_forward_legacy)

    print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
+    print_flag = True
    for module in model.modules():
        # Do not need to dispatch if
        # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is
        # required when using sequence parallel
        if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'):
            if use_varlen_attn:
-                print_log('dispatch llama varlen attn forward', 'current')
-                if IS_LOW_VERSION_TRANSFORMERS:
-                    module.forward = types.MethodType(
-                        llama_varlen_attn_forward_legacy, module)
-                else:
-                    module.forward = types.MethodType(
-                        llama_varlen_attn_forward, module)
+                if print_flag:
+                    print_log('dispatch llama varlen attn forward', 'current')
+                forward_func = (
+                    llama_varlen_attn_forward_legacy if
+                    IS_LOW_VERSION_TRANSFORMERS else llama_varlen_attn_forward)
            else:
-                print_log('dispatch llama attn forward', 'current')
-                if IS_LOW_VERSION_TRANSFORMERS:
-                    module.forward = types.MethodType(
-                        llama_attn_forward_legacy, module)
-                else:
-                    module.forward = types.MethodType(llama_attn_forward,
-                                                      module)
+                if print_flag:
+                    print_log('dispatch llama attn forward', 'current')
+                forward_func = (
+                    llama_attn_forward_legacy
+                    if IS_LOW_VERSION_TRANSFORMERS else llama_attn_forward)
+            module.forward = types.MethodType(forward_func, module)
+            print_flag = False


def dispatch_llama_rmsnorm_forward(model):
@@ -85,10 +84,13 @@ def dispatch_llama_rmsnorm_forward(model):

    from .triton_kernels import rms_norm_forward

+    print_flag = True
    for module in model.modules():
        if type(module).__name__ == 'LlamaRMSNorm':
-            print_log('dispatch llama rmsnorm forward', 'current')
+            if print_flag:
+                print_log('dispatch llama rmsnorm forward', 'current')
            module.forward = types.MethodType(rms_norm_forward, module)
+            print_flag = False


def dispatch_internlm_attn_forward(model, use_varlen_attn):
@@ -101,16 +103,21 @@ def dispatch_internlm_attn_forward(model, use_varlen_attn):
    from .internlm import internlm_attn_forward, internlm_varlen_attn_forward

    print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
+    print_flag = True
    for module in model.modules():
        if type(module).__name__ == 'InternLMAttention':
            if use_varlen_attn:
-                print_log('dispatch internlm varlen attn forward', 'current')
+                if print_flag:
+                    print_log('dispatch internlm varlen attn forward',
+                              'current')
                module.forward = types.MethodType(internlm_varlen_attn_forward,
                                                  module)
            else:
-                print_log('dispatch internlm attn forward', 'current')
+                if print_flag:
+                    print_log('dispatch internlm attn forward', 'current')
                module.forward = types.MethodType(internlm_attn_forward,
                                                  module)
+            print_flag = False


def dispatch_internlm2_attn_forward(model, use_varlen_attn):
@@ -124,17 +131,22 @@ def dispatch_internlm2_attn_forward(model, use_varlen_attn):
                            internlm2_varlen_attn_forward)

    print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
+    print_flag = True
    for module in model.modules():
        if type(module).__name__ in ('InternLM2Attention',
                                     'InternLM2FlashAttention2'):
            if use_varlen_attn:
-                print_log('dispatch internlm2 varlen attn forward', 'current')
+                if print_flag:
+                    print_log('dispatch internlm2 varlen attn forward',
+                              'current')
                module.forward = types.MethodType(
                    internlm2_varlen_attn_forward, module)
            else:
-                print_log('dispatch internlm2 attn forward', 'current')
+                if print_flag:
+                    print_log('dispatch internlm2 attn forward', 'current')
                module.forward = types.MethodType(internlm2_attn_forward,
                                                  module)
+            print_flag = False


def dispatch_internlm_rmsnorm_forward(model):
@@ -143,9 +155,12 @@ def dispatch_internlm_rmsnorm_forward(model):

    from .triton_kernels import rms_norm_forward

+    print_flag = True
    for module in model.modules():
        if type(module).__name__ == 'InternLMRMSNorm':
-            print_log('dispatch internlm rmsnorm forward', 'current')
+            if print_flag:
+                print_log('dispatch internlm rmsnorm forward', 'current')
+                print_flag = False
            module.forward = types.MethodType(rms_norm_forward, module)


@@ -155,55 +170,64 @@ def dispatch_internlm2_rmsnorm_forward(model):

    from .triton_kernels import rms_norm_forward

+    print_flag = True
    for module in model.modules():
        if type(module).__name__ == 'InternLM2RMSNorm':
-            print_log('dispatch internlm2 rmsnorm forward', 'current')
+            if print_flag:
+                print_log('dispatch internlm2 rmsnorm forward', 'current')
+                print_flag = False
            module.forward = types.MethodType(rms_norm_forward, module)


def replace_internlm_rote(model):
    from .internlm import InternLMRotaryEmbedding

-    def traverse(module):
+    def traverse(module, print_flag):
        for name, child in module.named_children():
            if type(child).__name__ in (
                    'InternLMRotaryEmbedding',
                    'InternLMDynamicNTKScalingRotaryEmbedding'):
-                print_log('replace internlm rope', 'current')
+                if print_flag[0]:
+                    print_log('replace internlm rope', 'current')
+                    print_flag[0] = False
+
                dim_model = child.inv_freq.shape[0] * 2
                child_new = InternLMRotaryEmbedding(
                    dim_model, child.max_seq_len_cached).to(
                        device=child.inv_freq.device,
                        dtype=child.inv_freq.dtype)
                setattr(module, name, child_new)
+
            else:
                traverse(child)

-    traverse(model)
+    traverse(model, [True])


def replace_internlm2_rote(model):
    from .internlm2 import InternLM2RotaryEmbedding

    rotary_base = model.config.rope_theta

-    def traverse(module):
+    def traverse(module, print_flag):
        for name, child in module.named_children():
            if type(child).__name__ in (
                    'InternLM2RotaryEmbedding',
                    'InternLM2LinearScalingRotaryEmbedding',
                    'InternLM2DynamicNTKScalingRotaryEmbedding'):
-                print_log('replace internlm2 rope', 'current')
+                if print_flag[0]:
+                    print_log('replace internlm2 rope', 'current')
+                    print_flag[0] = False
                dim_model = child.inv_freq.shape[0] * 2
                child_new = InternLM2RotaryEmbedding(
                    dim_model, child.max_position_embeddings, rotary_base).to(
                        device=child.inv_freq.device,
                        dtype=child.inv_freq.dtype)
                setattr(module, name, child_new)
            else:
-                traverse(child)
+                traverse(child, print_flag)

-    traverse(model)
+    traverse(model, [True])


def dispath_baichuan2_norm_head_forward(model):
@@ -258,10 +282,13 @@ def dispatch_mistral_attn_forward(model, use_varlen_attn):
    from .mistral import mistral_varlen_attn_forward

    print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
+    print_flag = True
    for module in model.modules():
        if type(module).__name__ in ('MistralAttention',
                                     'MistralFlashAttention2'):
-            print_log('dispatch mistral varlen attn forward', 'current')
+            if print_flag:
+                print_log('dispatch mistral varlen attn forward', 'current')
+                print_flag = False
            module.forward = types.MethodType(mistral_varlen_attn_forward,
                                              module)

@@ -272,9 +299,12 @@ def dispatch_mistral_rmsnorm_forward(model):

    from .triton_kernels import rms_norm_forward

+    print_flag = True
    for module in model.modules():
        if type(module).__name__ == 'MistralRMSNorm':
-            print_log('dispatch mistral rmsnorm forward', 'current')
+            if print_flag:
+                print_log('dispatch mistral rmsnorm forward', 'current')
+                print_flag = False
            module.forward = types.MethodType(rms_norm_forward, module)


@@ -283,10 +313,12 @@ def replace_mistral_rote(model):

    rotary_base = model.config.rope_theta

-    def traverse(module):
+    def traverse(module, print_flag):
        for name, child in module.named_children():
            if type(child).__name__ == 'MistralRotaryEmbedding':
-                print_log('replace mistral rope', 'current')
+                if print_flag[0]:
+                    print_log('replace mistral rope', 'current')
+                    print_flag[0] = False
                dim_model = child.inv_freq.shape[0] * 2
                child_new = MistralRotaryEmbedding(
                    dim_model, child.max_seq_len_cached, rotary_base).to(
@@ -296,7 +328,7 @@ def traverse(module):
            else:
                traverse(child)

-    traverse(model)
+    traverse(model, [True])


def dispatch_modules(model, use_varlen_attn=False):
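The change applies the same log-once pattern to every dispatch helper: a boolean flag for the flat module loops, and a one-element list threaded through the recursive rope-replacement traversals. Below is a minimal self-contained sketch of that pattern, not xtuner code; ToyRMSNorm, ToyRotaryEmbedding, and the rms_norm_forward stand-in are placeholder names for illustration only.

import types

import torch
from torch import nn


def rms_norm_forward(self, hidden_states):
    # Stand-in replacement forward; the real dispatch binds a fused Triton kernel.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return self.weight * hidden_states * torch.rsqrt(variance + 1e-6)


def dispatch_rmsnorm_forward(model):
    """Patch every matching module, but emit the info log only once."""
    print_flag = True
    for module in model.modules():
        if type(module).__name__ == 'ToyRMSNorm':
            if print_flag:
                print('dispatch rmsnorm forward')  # printed once, not per layer
                print_flag = False
            module.forward = types.MethodType(rms_norm_forward, module)


def replace_rote(model):
    """Recursive variant: a one-element list keeps the flag shared across calls."""

    def traverse(module, print_flag):
        for name, child in module.named_children():
            if type(child).__name__ == 'ToyRotaryEmbedding':
                if print_flag[0]:
                    print('replace rope')
                    print_flag[0] = False
                setattr(module, name, nn.Identity())  # stand-in replacement
            else:
                # The flag rides along with the recursive call.
                traverse(child, print_flag)

    traverse(model, [True])


if __name__ == '__main__':
    # Any nn.Module works; modules whose class names do not match are skipped.
    dispatch_rmsnorm_forward(nn.Sequential(nn.Linear(4, 4)))
    replace_rote(nn.Sequential(nn.Linear(4, 4)))

A nonlocal boolean inside traverse would work equally well; the one-element list mirrors the mutable-container approach taken in the diff above.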