RecurrentGemma not compatible with autocast / AMP training #30830

xplip opened this issue May 15, 2024 · 0 comments · May be fixed by #30832

System Info

  • transformers version: 4.40.2
  • Platform: Linux-5.15.0-92-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: False

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM


def main():
    V = 288
    B = 16
    T = 300
    device = "cuda"

    config = RecurrentGemmaConfig(
        vocab_size=V,
        num_hidden_layers=12,
        hidden_size=1024,
        num_attention_heads=8,
        intermediate_size=6144,
        attention_window_size=T,
    )
    model = RecurrentGemmaForCausalLM._from_config(config, torch_dtype=torch.float32).to(device)

    autocast_settings = [
        {"dtype": torch.float16, "enabled": True},
        {"dtype": torch.bfloat16, "enabled": True},
        {"enabled": False},
    ]

    for autocast_setting in autocast_settings:
        print(f"\nRunning with autocast setting: {autocast_setting}")
        try:
            with torch.cuda.amp.autocast(**autocast_setting):
                outputs = model(input_ids=torch.randint(0, V, (B, T), device=device))
                print(outputs.logits.shape)
        except RuntimeError as e:
            print(e)
    

if __name__ == "__main__":
    main()

Expected behavior

The script should run without errors when autocast is enabled. As it stands, RecurrentGemma cannot be trained with automatic mixed precision (AMP).

Actual output of the script above:

Running with autocast setting: {'dtype': torch.float16, 'enabled': True}
Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.

Running with autocast setting: {'dtype': torch.bfloat16, 'enabled': True}
Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.

Running with autocast setting: {'enabled': False}
torch.Size([16, 300, 288])

Expected output:

Running with autocast setting: {'dtype': torch.float16, 'enabled': True}
torch.Size([16, 300, 288])

Running with autocast setting: {'dtype': torch.bfloat16, 'enabled': True}
torch.Size([16, 300, 288])

Running with autocast setting: {'enabled': False}
torch.Size([16, 300, 288])
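The error message comes from PyTorch's advanced-indexing assignment (`index_put_`), which, unlike a plain slice copy, does not cast the source to the destination dtype. The sketch below reproduces that mechanism on CPU without transformers. It assumes the failing write is a tensor-indexed update of a buffer allocated in the model's float32 dtype with a half-precision tensor produced under autocast; the helper `write_into_buffer`, the shapes, and the cast-before-write workaround are illustrative assumptions, not the RecurrentGemma code or the change proposed in #30832.

```python
import torch


def write_into_buffer(buffer: torch.Tensor, states: torch.Tensor,
                      positions: torch.Tensor, cast_to_buffer_dtype: bool) -> torch.Tensor:
    """Hypothetical helper: write `states` into `buffer` at `positions` along dim 2.

    Using a tensor index makes this an advanced-indexing assignment, which
    PyTorch dispatches to `index_put_`; that op requires matching dtypes.
    """
    if cast_to_buffer_dtype:
        # Assumed workaround: cast the source to the buffer's dtype before writing.
        states = states.to(buffer.dtype)
    buffer[:, :, positions] = states
    return buffer


def demo() -> tuple[str, torch.dtype, torch.dtype]:
    torch.manual_seed(0)
    # Hypothetical shapes (batch, heads, window, head_dim); buffer kept in float32.
    buffer = torch.zeros(2, 4, 8, 16, dtype=torch.float32)
    proj = torch.nn.Linear(16, 16)  # float32 weights, like the float32 model in the repro
    x = torch.randn(2, 4, 3, 16)
    positions = torch.arange(3)

    # CPU autocast with bfloat16 runs nn.Linear in bfloat16, mirroring the CUDA case.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        states = proj(x)

    try:
        write_into_buffer(buffer, states, positions, cast_to_buffer_dtype=False)
        error = ""
    except RuntimeError as e:
        error = str(e)

    fixed = write_into_buffer(buffer, states, positions, cast_to_buffer_dtype=True)
    return error, states.dtype, fixed.dtype


if __name__ == "__main__":
    error, src_dtype, dst_dtype = demo()
    print(src_dtype)
    print(error)
    print(dst_dtype)
```

A write with only basic slices (e.g. `buffer[:, :, 0:3] = states`) goes through `copy_` and casts implicitly; it is the tensor index that triggers the strict dtype check.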