[BUG] GPTQ Kernels don't work with PEFT #633

Open
achew010 opened this issue Apr 5, 2024 · 0 comments
Labels
bug Something isn't working

Comments

achew010 commented Apr 5, 2024

Expected Behavior

Converging loss curves for all AutoGPTQ Quantized Linear Kernels

Description

CUDA, ExLlama, and Marlin do not work

The loss curves for the CUDA/ExLlama kernels do not converge, and the Marlin loss collapses to zero mid-training.

When I set the self.kernel_switch_threshold attribute to False in the CUDA QLinear to force the use of the kernel, the iteration time increases dramatically and the loss does not converge at all.

Under normal settings, however, my input dimensions (BS * seqlen) are always larger than the default self.kernel_switch_threshold value, so the condition x.shape[0] < self.kernel_switch_threshold is never met and the kernel is never used; the layer falls back to a standard torch.matmul, and the resulting loss convergence looks decent without the kernel.

See these lines for reference:
https://github.com/AutoGPTQ/AutoGPTQ/blob/866b4c8c2cbb893f1156cb6c114625bba2e4d7c5/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py#L348C1-L350C42
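
For context, the gating around the linked lines looks roughly like this (a paraphrased sketch based on the referenced file and the condition quoted above, not a verbatim copy; run_autogptq_cuda_kernel and dequantize_weights are illustrative names, not the actual functions):

# Paraphrased dispatch sketch for qlinear_cuda_old.py (illustrative, not verbatim).
# Setting self.kernel_switch_threshold = False takes the first branch unconditionally;
# with the default threshold, large inputs (BS * seqlen) always fall through to torch.matmul.
if self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold:
    out = run_autogptq_cuda_kernel(x)        # fused CUDA quant-matmul path
else:
    weight = dequantize_weights(self)        # dequantize, then use the standard path
    out = torch.matmul(x.to(weight.dtype), weight)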

TritonV2 Works

The TritonV2 QLinear kernel shows acceptable loss convergence.

I suspect this behaviour is related to the autograd backward function (or the lack of one in the CUDA, ExLlama, and Marlin kernels). In TritonV2, the QuantLinearFunction class wraps the quantized matmul in a torch.autograd.Function with an explicit backward (see below), but no equivalent backward is apparent in the other three kernels.

class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits, ctx.maxq = bits, maxq
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        qweight, scales, qzeros, g_idx = ctx.saved_tensors
        bits, maxq = ctx.bits, ctx.maxq
        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = quant_matmul_248(
                grad_output, qweight, scales, qzeros, g_idx, bits, maxq, transpose=True
            )
        return grad_input, None, None, None, None, None, None
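
As a sanity check (not part of the original report or the AutoGPTQ API), one way to see whether input gradients actually flow through a given QuantLinear implementation is to compare them against a dense reference matmul. A minimal sketch, where quant_linear and w_ref (a dequantized (out_features, in_features) copy of its weight) are assumed to be available:

import torch

def check_input_grad(quant_linear, w_ref, batch=4, seqlen=512):
    # quant_linear: any QuantLinear module; w_ref: dequantized weight copy.
    # Both are assumptions for this sketch, not AutoGPTQ-provided objects.
    x = torch.randn(batch * seqlen, w_ref.shape[1],
                    device=w_ref.device, dtype=w_ref.dtype, requires_grad=True)

    # Dense reference: gradient of sum(x @ W^T) with respect to x.
    (x @ w_ref.t()).sum().backward()
    grad_ref = x.grad.clone()
    x.grad = None

    # Quantized path: if forward is not wrapped in an autograd.Function with
    # a backward(), this either raises or produces a mismatching gradient.
    try:
        quant_linear(x).sum().backward()
    except RuntimeError as e:
        print("backward failed through the quantized kernel:", e)
        return
    if x.grad is None:
        print("no gradient reached the input through the quantized kernel")
    else:
        print("max abs diff vs dense reference:", (x.grad - grad_ref).abs().max().item())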

This problem has also been briefly discussed in issue #530 as well as in Unsloth.

Hardware details

  • AMD EPYC 7763 64-Core Processor
  • CPU RAM 256GB
  • NVIDIA A100 80GB

Software Versions

Python Version = 3.10.8

requirements.txt

auto-gptq==0.8.0.dev0+cu121
accelerate==0.28.0
torch==2.2.0
trl==0.8.1
transformers==4.39.3
optimum==1.18.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105

Reproduce

from peft import (
    get_peft_model,
    prepare_model_for_kbit_training, 
    LoraConfig, TaskType, 
)
import torch
from transformers import GPTQConfig, AutoTokenizer, AutoModelForCausalLM        

args = {
    'batch_size': 4,
    'context_length': 512,
    'gradient_accumulation_steps': 1,
    'num_epochs': 1,
    'dataset': 'alpaca',
    'use_gradient_checkpointing': True,
    'precision': torch.float16,
    'model_name': 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ', 
    'output_dir': 'output',
    'lora_rank': 16,
    'lora_alpha': 16,
    'lora_dropout': 0,
    'lora_target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    'lr': 0.0002,
    'grad_norm': 0.3,
    'optimizer': 'adamw_torch',
    'lr_scheduler': 'linear',
    'seed': 42,
    'checkpoint_steps': 500,
    'logging_steps': 5,
    'load_in_4bit': True,
    'weight_decay': 0.01,
    'max_steps': 1000,
    'warmup_steps': 10,
    'use_unsloth': False,
    'use_gptq': True,
    'gptq_kernel': None,
    'log_to': 'stdout'
}

def load_optimum_gptq_model(args, num_bits=4):
    tokenizer = AutoTokenizer.from_pretrained(args['model_name'])
    tokenizer.add_special_tokens({"pad_token": tokenizer.unk_token})
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = 'right'
   
    quantization_config = GPTQConfig(
        bits=num_bits, 
        max_input_length = args['context_length'],
        use_exllama = True if args['gptq_kernel'] == 'exllama' else False,
        exllama_config = {'version': 2} if args['gptq_kernel'] == 'exllama' else None,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args['model_name'], 
        device_map="auto", 
        torch_dtype = args['precision'],
        quantization_config = quantization_config,
    )

    # update() modifies the config in place and returns None
    model.config.update({"pad_token_id": tokenizer.unk_token_id})

    # prepare for quantized training
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing = args['use_gradient_checkpointing'],
    )

    # install adapters
    lora_config = LoraConfig(
        r              = args['lora_rank'],
        lora_alpha     = args['lora_alpha'],
        target_modules = args['lora_target_modules'],
        lora_dropout   = args['lora_dropout'],
        bias           = "none",
        task_type      = TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

    return model, tokenizer

exllama_model, exllama_tokenizer = load_optimum_gptq_model(args)

from trl import SFTTrainer
from datasets import load_dataset
from transformers import TrainingArguments

training_args = TrainingArguments(
        per_device_train_batch_size = args['batch_size'],
        gradient_accumulation_steps = args['gradient_accumulation_steps'],
        gradient_checkpointing=args['use_gradient_checkpointing'],
        warmup_steps = args['warmup_steps'],
        max_steps = args['max_steps'],
        learning_rate = args['lr'],
        logging_strategy = 'steps',
        logging_steps = args['logging_steps'],
        output_dir = args['output_dir'],
        optim = args['optimizer'],
        weight_decay = args['weight_decay'],
        lr_scheduler_type = args['lr_scheduler'],
        seed = args['seed'],
        include_tokens_per_second = True,
        max_grad_norm = args['grad_norm'],
        remove_unused_columns=False,
        adam_epsilon=1e-4,
    )

RESPONSE_TEMPLATE = "### Response:"
INSTRUCTION_TEMPLATE = "### Instruction:"

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n\n"
    ),
}

dataset = load_dataset("yahma/alpaca-cleaned", split="train")

def formatting_prompts_func(example):
    if example.get("input", "") == "":
        prompt = PROMPT_DICT["prompt_no_input"].format_map(example)
    else:
        prompt = PROMPT_DICT["prompt_input"].format_map(example)
    return prompt + example["output"]

trainer = SFTTrainer(
    model = exllama_model,
    tokenizer = exllama_tokenizer,
    train_dataset = dataset,
    max_seq_length = args['context_length'],
    args = training_args,
    formatting_func=formatting_prompts_func,
    packing=True,
)

trainer_stats = trainer.train()
