Training Issue #19

DevonPeroutky opened this issue May 6, 2024 · 1 comment

DevonPeroutky commented May 6, 2024


  • Platform: Debian Linux
  • GPU: A100
  • Torch: '2.1.2+cu121'
  • Transformers: '4.37.2'


I'm seeing random, sudden loss spikes during training. If there is a simpler way to debug this, I'm open to a new approach. So far, I have tried to reproduce the training loop in PyTorch so that I can log abnormal gradients during training and identify any erroneous examples in my training data.

However, I always get AttributeError: 'NoneType' object has no attribute 'device' in the forward pass (full stack trace below).

I built the model exactly as it is done in the linked code, and my training loop looks like this:

from torch.utils.data import DataLoader

# Define a threshold for outlier detection
gradient_threshold = 10.0

# Create a DataLoader for iterating through the dataset.
# data_module is a hypothetical name for the dict holding 'train_dataset';
# model and optimizer are assumed to be built earlier in the script.
train_dataloader = DataLoader(data_module['train_dataset'], batch_size=1, shuffle=True)

for batch_idx, batch in enumerate(train_dataloader):
    input_ids = batch["input_ids"]        # torch.Size([1, 200])
    labels = batch["labels"]              # torch.Size([1, 200])
    image_tensor = batch["image"].half()  # torch.Size([1, 3, 336, 336])

    # Zero the gradients
    optimizer.zero_grad()

    # Always errors out here
    output = model.forward(input_ids=input_ids, images=image_tensor)
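For the original goal of catching batches that produce abnormal gradients, a minimal sketch of the check this loop was building toward might look like the following. It assumes an ordinary backward pass and computes the global L2 norm of all parameter gradients after each batch, recording the batch index whenever that norm exceeds gradient_threshold. The helper names (total_grad_norm, flag_outlier_batches) and the toy linear model are illustrative assumptions, not part of LLaVA or the original script, and no optimizer step is taken, so it only inspects gradients.

```python
import torch
import torch.nn as nn

gradient_threshold = 10.0  # same threshold as in the loop above


def total_grad_norm(model: nn.Module) -> float:
    """Global L2 norm of all parameter gradients; parameters without a grad are skipped."""
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    return torch.norm(torch.stack(norms), 2).item()


def flag_outlier_batches(model, batches, loss_fn, threshold=gradient_threshold):
    """Forward/backward each (inputs, targets) batch and collect gradient outliers.

    Returns (batch_idx, loss, grad_norm) tuples for batches whose gradient norm
    exceeds `threshold`. No optimizer step is taken, so the weights never change.
    """
    flagged = []
    for batch_idx, (inputs, targets) in enumerate(batches):
        model.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        grad_norm = total_grad_norm(model)
        if grad_norm > threshold:
            flagged.append((batch_idx, loss.item(), grad_norm))
    return flagged


# Toy usage: a zero-initialised linear model where batch 3 has an extreme target.
toy_model = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    toy_model.weight.zero_()
x = torch.ones(1, 4)
batches = [(x, torch.full((1, 1), 0.1)) for _ in range(5)]
batches[3] = (x, torch.full((1, 1), 100.0))

flagged = flag_outlier_batches(toy_model, batches, nn.MSELoss())
for batch_idx, loss, grad_norm in flagged:
    print(f"batch {batch_idx}: loss={loss:.1f}, grad_norm={grad_norm:.1f}")
# prints: batch 3: loss=10000.0, grad_norm=400.0
```

In a real training loop the norm would be read right after loss.backward() and before optimizer.step(); torch.nn.utils.clip_grad_norm_ also returns the total (pre-clipping) gradient norm, so it can serve the same purpose if clipping is wanted anyway.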

model.forward always fails with the stack trace below. I've tried the forward pass both with and without labels, with similar results. After the prepare_inputs_labels_for_multimodal call, the inputs look like this:

Input IDs:  None
position_ids:  None
Attention Mask:  None
past_key_values:  None
labels: None
inputs_embeds:  torch.Size([1, 512, 4096])
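This printout is consistent with how LLaVA-style models feed images to the language model: prepare_inputs_labels_for_multimodal replaces the image placeholder token with the projected image features and passes embeddings rather than token IDs, which is why input_ids is None and only inputs_embeds is set. With the CLIP vision tower in the model printout (14x14 patches on a 336x336 image, so 24 x 24 = 576 patch features, each projected to 4096 by mm_projector), one placeholder token becomes 576 embedding rows. The sketch below shows only that splice for a single example; the function name is hypothetical, and it ignores batching, padding, labels, and attention masks. The exact sequence length in the real model also depends on settings, such as the tokenizer's maximum length, that the issue does not show.

```python
import torch


def splice_image_features(text_embeds: torch.Tensor, image_token_pos: int,
                          image_features: torch.Tensor) -> torch.Tensor:
    """Replace one image-placeholder embedding with a block of image patch features.

    text_embeds:    (seq_len, hidden) embeddings of the tokenised prompt
    image_features: (num_patches, hidden) image features after the projector
    returns:        (seq_len - 1 + num_patches, hidden)
    """
    if text_embeds.shape[-1] != image_features.shape[-1]:
        raise ValueError("text and image features must share the hidden size")
    return torch.cat(
        [text_embeds[:image_token_pos],        # tokens before the placeholder
         image_features,                       # all image patches
         text_embeds[image_token_pos + 1:]],   # tokens after the placeholder
        dim=0,
    )


# Toy sizes: 200 text tokens (as in the issue), 576 patches, small hidden size for speed.
hidden = 8
text_embeds = torch.randn(200, hidden)
image_features = torch.randn(576, hidden)
merged = splice_image_features(text_embeds, image_token_pos=5, image_features=image_features)
print(merged.shape)  # torch.Size([775, 8])
```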

Below are the model layers and the full stack trace. What am I missing?


  (base_model): LoraModel(
    (model): LlavaLlamaForCausalLM(
      (model): LlavaLlamaModel(
        (embed_tokens): Embedding(128257, 4096, padding_idx=128256)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaFlashAttention2(
              (q_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=4096, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (k_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=1024, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (v_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=1024, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (o_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=4096, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (rotary_emb): LlamaRotaryEmbedding()
            (mlp): LlamaMLP(
              (gate_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=14336, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (up_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=14336, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (down_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=14336, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=128, bias=False)
                (lora_B): ModuleDict(
                  (default): Linear(in_features=128, out_features=4096, bias=False)
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              (act_fn): SiLU()
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
        (norm): LlamaRMSNorm()
        (vision_tower): CLIPVisionTower(
          (vision_tower): CLIPVisionModel(
            (vision_model): CLIPVisionTransformer(
              (embeddings): CLIPVisionEmbeddings(
                (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
                (position_embedding): Embedding(577, 1024)
              (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (encoder): CLIPEncoder(
                (layers): ModuleList(
                  (0-23): 24 x CLIPEncoderLayer(
                    (self_attn): CLIPAttention(
                      (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                    (mlp): CLIPMLP(
                      (activation_fn): QuickGELUActivation()
                      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                    (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mm_projector): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=4096, out_features=4096, bias=True)
      (lm_head): Linear8bitLt(in_features=4096, out_features=128257, bias=False)
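In the printout above, every attention and MLP projection is a PEFT lora.Linear8bitLt: a frozen 8-bit base_layer plus a rank-128 lora_A/lora_B pair with dropout 0.05. Conceptually the wrapped layer computes base_layer(x) + scaling * lora_B(lora_A(dropout(x))), with scaling = lora_alpha / r. The sketch below reproduces that structure with an ordinary nn.Linear base so it runs without bitsandbytes or a GPU; the class name is hypothetical, and lora_alpha=256 is an assumed default because alpha does not appear in the printout. Note that the trace fails at result = self.base_layer(x, *args, **kwargs) (line 217 of the PEFT frame), that is, inside the frozen 8-bit base layer, before any LoRA-specific computation runs.

```python
import torch
import torch.nn as nn


class LoraLinearSketch(nn.Module):
    """Frozen base linear layer plus a trainable low-rank (LoRA) update.

    Mirrors the submodule names in the printout (base_layer, lora_dropout,
    lora_A, lora_B) but uses a plain nn.Linear instead of an 8-bit layer.
    lora_alpha=256 is an assumption; the printout does not show it.
    """

    def __init__(self, base: nn.Linear, r: int = 128, lora_alpha: float = 256.0,
                 dropout: float = 0.05):
        super().__init__()
        self.base_layer = base
        for p in self.base_layer.parameters():
            p.requires_grad_(False)                                 # base weights stay frozen
        self.lora_dropout = nn.Dropout(dropout)
        self.lora_A = nn.Linear(base.in_features, r, bias=False)   # down-projection
        self.lora_B = nn.Linear(r, base.out_features, bias=False)  # up-projection
        nn.init.zeros_(self.lora_B.weight)                          # update starts at exactly zero
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)  # corresponds to the base_layer call where the trace fails
        return result + self.scaling * self.lora_B(self.lora_A(self.lora_dropout(x)))


# Toy check with small sizes: because lora_B starts at zero, output equals the base output.
torch.manual_seed(0)
base = nn.Linear(16, 8, bias=False)
layer = LoraLinearSketch(base, r=4, lora_alpha=8.0)
layer.eval()  # disable dropout
x = torch.randn(2, 16)
print(torch.allclose(layer(x), base(x)))  # True
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 96 = 16*4 + 4*8
```

At the printout's sizes, the q_proj adapter alone adds 4096 x 128 + 128 x 4096 = 1,048,576 trainable parameters, while the 4096 x 4096 base weight stays frozen.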

Full stack trace

AttributeError                            Traceback (most recent call last)
Cell In[36], line 46
     28 # (_input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels_embeds) = model.prepare_inputs_labels_for_multimodal(input_ids=input_ids, position_ids=None, attention_mask=None, past_key_values=None, labels=labels, images=image_tensor)
     45 # 4
---> 46 output = model.forward(input_ids=input_ids, images=image_tensor, labels=labels)
     47 loss = compute_loss(output.logits, labels)
     48 print("LOSS: ", loss.item())

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1127     with self._enable_peft_forward_hooks(**kwargs):
   1128         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1129         return self.base_model(
   1130             input_ids=input_ids,
   1131             attention_mask=attention_mask,
   1132             inputs_embeds=inputs_embeds,
   1133             labels=labels,
   1134             output_attentions=output_attentions,
   1135             output_hidden_states=output_hidden_states,
   1136             return_dict=return_dict,
   1137             **kwargs,
   1138         )
   1140 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1141 if attention_mask is not None:
   1142     # concat prompt attention mask

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/tuners/, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/LLaVA-pp/LLaVA/llava/model/language_model/, in LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, image_sizes, return_dict, cache_position)
    101     print("inputs_embeds: ", inputs_embeds.shape)
    102 print("labels: ", labels)
--> 103 return super().forward(
    104     input_ids=input_ids,
    105     attention_mask=attention_mask,
    106     position_ids=position_ids,
    107     past_key_values=past_key_values,
    108     inputs_embeds=inputs_embeds,
    109     labels=labels,
    110     use_cache=use_cache,
    111     output_attentions=output_attentions,
    112     output_hidden_states=output_hidden_states,
    113     return_dict=return_dict
    114 )

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1180 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1182 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1183 outputs = self.model(
   1184     input_ids=input_ids,
   1185     attention_mask=attention_mask,
   1186     position_ids=position_ids,
   1187     past_key_values=past_key_values,
   1188     inputs_embeds=inputs_embeds,
   1189     use_cache=use_cache,
   1190     output_attentions=output_attentions,
   1191     output_hidden_states=output_hidden_states,
   1192     return_dict=return_dict,
   1193 )
   1195 hidden_states = outputs[0]
   1196 if self.config.pretraining_tp > 1:

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
   1060     layer_outputs = self._gradient_checkpointing_func(
   1061         decoder_layer.__call__,
   1062         hidden_states,
   1067         use_cache,
   1068     )
   1069 else:
-> 1070     layer_outputs = decoder_layer(
   1071         hidden_states,
   1072         attention_mask=attention_mask,
   1073         position_ids=position_ids,
   1074         past_key_value=past_key_values,
   1075         output_attentions=output_attentions,
   1076         use_cache=use_cache,
   1077     )
   1079 hidden_states = layer_outputs[0]
   1081 if use_cache:

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    795 hidden_states = self.input_layernorm(hidden_states)
    797 # Self Attention
--> 798 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    799     hidden_states=hidden_states,
    800     attention_mask=attention_mask,
    801     position_ids=position_ids,
    802     past_key_value=past_key_value,
    803     output_attentions=output_attentions,
    804     use_cache=use_cache,
    805     **kwargs,
    806 )
    807 hidden_states = residual + hidden_states
    809 # Fully Connected

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/transformers/models/llama/, in LlamaFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    490 output_attentions = False
    492 bsz, q_len, _ = hidden_states.size()
--> 494 query_states = self.q_proj(hidden_states)
    495 key_states = self.k_proj(hidden_states)
    496 value_states = self.v_proj(hidden_states)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/peft/tuners/lora/, in Linear8bitLt.forward(self, x, *args, **kwargs)
    215     result = self.base_layer(x, *args, **kwargs)
    216 else:
--> 217     result = self.base_layer(x, *args, **kwargs)
    218     for active_adapter in self.active_adapters:
    219         if active_adapter not in self.lora_A.keys():

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/accelerate/, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/nn/, in Linear8bitLt.forward(self, x)
    794 if self.bias is not None and self.bias.dtype != x.dtype:
    795 =
--> 797 out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
    799 if not self.state.has_fp16_weights:
    800     if self.state.CB is not None and self.state.CxB is not None:
    801         # we converted 8-bit row major to turing/ampere format in the first inference pass
    802         # we no longer need the row-major weight

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/autograd/, in matmul(A, B, out, state, threshold, bias)
    554 if threshold > 0.0:
    555     state.threshold = threshold
--> 556 return MatMul8bitLt.apply(A, B, out, bias, state)

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/torch/autograd/, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         ""
    547     )

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/autograd/, in MatMul8bitLt.forward(ctx, A, B, out, bias, state)
    331     else:
    332         if state.CxB is None and using_igemmlt:
    333             # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
    334             # we also need to convert it to the turing/ampere format
--> 335             state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
    336 else:
    337     if not state.has_fp16_weights and state.CxB is None and using_igemmlt:

File /opt/conda/envs/llama3-llavapp/lib/python3.10/site-packages/bitsandbytes/, in transform(A, to_order, from_order, out, transpose, state, ld)
   2596 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
-> 2597     prev_device = pre_call(A.device)
   2598     if state is None:
   2599         state = (A.shape, from_order)

AttributeError: 'NoneType' object has no attribute 'device'

mmaaz60 commented May 10, 2024

Hi @DevonPeroutky,

Thank you for your interest in our work. Have you tried upgrading transformers to the latest version? Please note that LLaMA-3-based training is only supported with transformers 4.41 or later, which you can install as follows:

pip install git+

Let me know if it helps. Good Luck!
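To confirm which version is actually installed before retrying (the issue reports 4.37.2), a small standard-library-only check against the 4.41 minimum from the reply above might look like this. The helper names are hypothetical, and the parser only compares the leading numeric release components, so a development build such as 4.41.0.dev0 counts as meeting the minimum.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_TRANSFORMERS = (4, 41)  # minimum stated in the maintainer's reply


def parse_release(ver: str) -> tuple:
    """Leading integer components of a version string, e.g. '4.41.0.dev0' -> (4, 41, 0)."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at non-numeric parts such as 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(ver: str, minimum: tuple = MIN_TRANSFORMERS) -> bool:
    """True if the release components of `ver` are at least `minimum`."""
    return parse_release(ver)[: len(minimum)] >= minimum


if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "OK" if meets_minimum(installed) else "older than 4.41"
        print(f"transformers {installed}: {status}")
```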
