[WIP] [WORKING] dbrx (mod) support #625

Draft: Qubitium wants to merge 17 commits into main

Conversation

@Qubitium (Contributor) commented Mar 29, 2024

Attempting to hack a 4-bit quant using modified model code from https://huggingface.co/databricks/dbrx-instruct/discussions/10 written by https://huggingface.co/fahadh4ilyas

As the title implies, this is a pure HACK/mod using a converted model with a different layout of weights.

This PR (now working!) requires a converted v2 model below:

dbrx-base (databricks original)

model: https://huggingface.co/databricks/dbrx-base

# status
inference: OK
quantize: FAIL (important layers skipped)
training bfloat16: FAIL (OOM)

dbrx-base-converted v2

model: https://huggingface.co/LnL-AI/dbrx-base-converted-v2

# status
inference: OK
quantize: SUCCESS! (Marlin validated, GPTQ pending)
training bfloat16: OK (~767GB VRAM required)

converted-v2 4bit quants:

  1. 4bit gptq/marlin: https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin-v2
  2. 4bit gptq/gptq: https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-gptq-v2

Quant Script:

import os


max_threads = str(1)
os.environ['OMP_NUM_THREADS'] = max_threads
os.environ['OPENBLAS_NUM_THREADS'] = max_threads
os.environ['MKL_NUM_THREADS'] = max_threads
os.environ['VECLIB_MAXIMUM_THREADS'] = max_threads
os.environ['NUMEXPR_NUM_THREADS'] = max_threads
os.environ['NUMEXPR_MAX_THREADS'] = max_threads

import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging
from datasets import load_dataset
from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

max_seq_len = 32768
num_samples = 1024

model_id_or_path = "/monster/data/model/dbrx-base-converted-v2/"
quantized_model_dir = os.path.join(model_id_or_path, "quant")
quantized_model_dir = os.path.join(quantized_model_dir, "qubitium-v11-2048-32k-marlin")

os.makedirs(quantized_model_dir, exist_ok=True)

print("pretrained_model_dir", model_id_or_path)
print("quantized_model_dir", quantized_model_dir)

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:10%]")
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
ds = dataset.shuffle().select(range(num_samples))

examples = [
    tokenizer(
        example["text"], padding=False, max_length=max_seq_len, truncation=True,
    ) for example in ds
]

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # desc_act and group size only works on triton
    damp_percent=0.005,
    quant_method="gptq",
    checkpoint_format="marlin",
)

# load un-quantized model, the model will always be force loaded into cpu
model = AutoGPTQForCausalLM.from_pretrained(model_id_or_path, quantize_config, trust_remote_code=True, device_map="auto")

# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize(examples)

# save quantized model using safetensors
model.save_quantized(quantized_model_dir)

print(quantized_model_dir)
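
For a quick sanity check after the quant finishes, a reload sketch along these lines can be used (it reuses the model_id_or_path and quantized_model_dir variables above; depending on the auto_gptq version, an extra flag may be needed to select the marlin kernel explicitly, so treat this as an illustrative snippet rather than part of the PR):

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
# from_quantized reloads the saved 4-bit weights from the output directory
quant_model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir, device="cuda:0", trust_remote_code=True
)
prompt = tokenizer("The capital of France is", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(quant_model.generate(**prompt, max_new_tokens=16)[0]))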

@LaaZa (Contributor) commented Mar 29, 2024

I don't think norm_1 and norm_2 or the router help here. The MoE architecture used is very unusual. https://huggingface.co/databricks/dbrx-instruct/discussions/10

@Qubitium (Contributor Author) commented Mar 29, 2024

@LaaZa Still trying to get quant to complete. Wqkv has massive loss values, so I'm not sure if this layer is even suitable for quantization or if the current method is just not optimized for this Wqkv layer.

Qubitium changed the title from "[WIP] Add dbrx support" to "[WIP] Hack dbrx support" on Mar 29, 2024
@Qubitium (Contributor Author) commented:

Status update: the quantize() stage is currently skipping over the following 3 important layers:

        ["ffn.experts.mlp.w1"],
        ["ffn.experts.mlp.w2"],
        ["ffn.experts.mlp.v1"],

and I am almost out of time today to work on this.

@LaaZa (Contributor) commented Mar 29, 2024

You need to include the index for the experts if you use the modified model. For now, you need to duplicate the mlp lines for each expert from 0-15. Also, the correct order is [w1, v1], [w2], e.g.:

ffn.experts.mlp.0.w1
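
For illustration (not part of this PR), a minimal sketch of how the per-expert entries could be generated instead of writing all 16 out by hand, assuming 16 experts and the converted model's ffn.experts.mlp.{i} naming; attention entries are omitted:

num_experts = 16  # dbrx uses 16 experts per MoE block

# order per LaaZa: [w1, v1] together, then [w2]
expert_mlp_modules = [
    [f"ffn.experts.mlp.{i}.w1" for i in range(num_experts)]
    + [f"ffn.experts.mlp.{i}.v1" for i in range(num_experts)],
    [f"ffn.experts.mlp.{i}.w2" for i in range(num_experts)],
]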

@Qubitium (Contributor Author) commented:

@LaaZa Thanks for the fix. It's now slowly going down the layers. Hopefully this does the trick until the next bug.

@Qubitium (Contributor Author) commented Mar 29, 2024

@LaaZa Btw, why is quantizing the norm_1, norm_2, and router layers not helpful? I have too little experience with the model layer code to infer the reason. Thanks. If we exclude the norm_1, norm_2 layers, are their weights retained in full float in the quantized weights? If that is the case, I might want to test weeding out the Wqkv layer too, due to the massive 300+ loss values on almost every layer.

@LaaZa (Contributor) commented Mar 29, 2024

Well, in general I would skip anything with the wrong shape. Normalization modules are usually skipped, and in this model they have the shape [6144], so they are one-dimensional, while we need both in_features and out_features to be divisible by 32. router.layer has the shape [16, 6144], so the out_features are too small.
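
To make the skip rule concrete, a small illustrative check (not from this PR) for which modules are worth quantizing under these shape constraints:

import torch.nn as nn

def quantizable(module, multiple=32):
    # 1-D norm weights (shape [6144]) are not Linear layers at all, and the
    # router (weight [16, 6144], i.e. out_features=16) fails the divisibility
    # requirement, so both get skipped.
    return (
        isinstance(module, nn.Linear)
        and module.in_features % multiple == 0
        and module.out_features % multiple == 0
    )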

@Qubitium (Contributor Author) commented Mar 29, 2024

Well, in general I would skip anything with the wrong shape. Normalization modules are usually skipped, and in this model they have the shape [6144], so they are one-dimensional, while we need both in_features and out_features to be divisible by 32. router.layer has the shape [16, 6144], so the out_features are too small.

Thanks. That makes perfect sense.

I hacked the exllama (v1) code, which is used in quant, to pad to the correct out_features. AutoGPTQ's exllama v2 already had padding (but likely broken due to non-assignment, #626), so maybe it will work? We shall see.

Anyway, I'm testing 4 quant tasks with various layers disabled at the same time to see if we hit the pot of gold at the end of the rainbow.
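
For context, a rough sketch of the padding idea described above (zero-pad out_features up to the next multiple the kernel accepts and discard the extra outputs afterwards); this is illustrative, not the actual exllama patch:

import torch.nn.functional as F

def pad_out_features(weight, multiple=32):
    # weight: [out_features, in_features]; zero-pad rows so out_features
    # becomes divisible by `multiple`. Callers slice the padded outputs
    # back off after the matmul.
    pad = (-weight.shape[0]) % multiple
    return weight if pad == 0 else F.pad(weight, (0, 0, 0, pad))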

Qubitium changed the title from "[WIP] Hack dbrx support" to "[WIP] [Hackfudgery] dbrx support" on Mar 29, 2024
@fahadh4ilyas commented:

@LaaZa Btw, why is quantizing the norm_1, norm_2, and router layers not helpful? I have too little experience with the model layer code to infer the reason. Thanks. If we exclude the norm_1, norm_2 layers, are their weights retained in full float in the quantized weights? If that is the case, I might want to test weeding out the Wqkv layer too, due to the massive 300+ loss values on almost every layer.

I wonder if splitting the Wqkv layer would help resolve the problem with the massive loss.

@Qubitium (Contributor Author) commented:

So far quant is stuck on packing. It doesn't error but just takes 1 core at 100% cpu in an apparent loop on something. Going to stop testing until there are more updates on the model.

@Qubitium (Contributor Author) commented:

@fahadh4ilyas Will restart and test quantization based on your new split Wqkv layers.

@Qubitium (Contributor Author) commented:

I wonder if splitting the Wqkv layer would help resolve the problem with the massive loss.

Confirmed: the converted v2 has sane/normal losses for the split q, k, v layers! Waiting for it to finish now to check if inference is ok.

@Qubitium (Contributor Author) commented Mar 30, 2024

Update: We have a problem. Packing of the v1, w1, w2 layers is extremely slow:

Packing transformer.blocks.2.ffn.experts.mlp.9.v1...:   7%|████████▌ | 139/2120 [3:14:41<39:18:27, 71.43s/it]

@Qubitium (Contributor Author) commented Mar 31, 2024

I have started a second quant with norm_1/2 + router removed. Will test inference on both to validate.

The first one with all layers is 70% complete.

@Qubitium (Contributor Author) commented:

The test quant for all layers (including norm_1/2 + router) finished the quantize stage but is again stuck on super slow packing:

Packing transformer.blocks.0.ffn.experts.mlp.10.v1...:   2%|██    | 36/2120 [56:34<72:06:44, 124.57s/it]

I have never quantized such a large model and am unsure if the massive slowdown in packing is normal. The problem at the moment is not quantizing, but packing.

@Qubitium (Contributor Author) commented Mar 31, 2024

@LaaZa @fahadh4ilyas @fxmarty @qwopqwop200 The current quant progress of dbrx-base-converted-v2 is hitting a roadblock of an ungodly slow packing stage. I am quanting with 1x 4090 + Zen3 2x 7334 4GHz CPUs + 2TB of RAM + zero swap enabled, so hardware is not an issue here. A little bummed that I can't even get a single quant to complete, let alone test. Why is packing so slow? It's like 100x slower than the quantizing stage. We are so close!

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # desc_act and group size only works on triton
    damp_percent=0.005,
    quant_method="gptq",
    checkpoint_format="gptq",
)

@LaaZa (Contributor) commented Mar 31, 2024

Take a look at #439 if that helps. It's going to be slow though as it happens on the cpu.

@Qubitium (Contributor Author) commented Mar 31, 2024

I have 4 * A800-80GB, 500GB RAM, how can I help?

@Xu-Chen Please help by testing quantization using dbrx-base-converted-v2. You only need 1 GPU for the quantize stage, but make sure the cpu is not used by other processes, since the packing stage is pure cpu.

my current test script:

import os
max_cpu_threads = "24" # change this to 1/2 of number of cores show in os
os.environ["OMP_NUM_THREADS"] = max_cpu_threads
os.environ["OPENBLAS_NUM_THREADS"] = max_cpu_threads
os.environ["MKL_NUM_THREADS"] = max_cpu_threads
os.environ["VECLIB_MAXIMUM_THREADS"] = max_cpu_threads
os.environ["NUMEXPR_NUM_THREADS"] = max_cpu_threads
os.environ["NUMEXPR_MAX_THREADS"] = max_cpu_threads

import numpy as np
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

pretrained_model_dir = "/monster/data/model/dbrx-base-converted-v2/"
quantized_model_dir = os.path.join(pretrained_model_dir, "quant")
quantized_model_dir = os.path.join(quantized_model_dir, "4bit-v10")

os.makedirs(quantized_model_dir, exist_ok=True)

print("pretrained_model_dir", pretrained_model_dir)
print("quantized_model_dir", quantized_model_dir)

def get_wikitext2(nsamples, seed, seqlen, model):
    from datasets import load_dataset

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    import random

    random.seed(seed)
    np.random.seed(0)
    torch.random.manual_seed(0)

    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        traindataset.append({"input_ids": inp, "attention_mask": attention_mask})


    tokenizer.save_pretrained(quantized_model_dir)
    return traindataset, testenc


traindataset, testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
    desc_act=False,  # desc_act and group size only works on triton
    damp_percent=0.005,
    quant_method="gptq",
    checkpoint_format="gptq",
)

# load un-quantized model, the model will always be force loaded into cpu
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, trust_remote_code=True, device_map="auto")

# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
# with value under torch.LongTensor type.
model.quantize(traindataset, use_triton=False)

# save quantized model using safetensors
model.save_quantized(quantized_model_dir)

print(quantized_model_dir)

@Xu-Chen commented Mar 31, 2024

I have 4 * A800-80GB, 500GB RAM, how can I help?

@Xu-Chen Please help by testing quantization using dbrx-base-converted-v2. You only need 1 GPU for the quantize stage, but make sure the cpu is not used by other processes, since the packing stage is pure cpu.


Thank you, I will try. The following is the log of dbrx-instruct-converted-v2. Downloading dbrx-base-converted-v2.
(screenshot of the quantization log omitted)

@Qubitium (Contributor Author) commented Mar 31, 2024

Success! I will push the quantized model for you guys to test. It will need about 68GB of vram to run, so 1x A100 80GB will do.

EDIT: now the issue is that HF has a limit of 50GB per file. The quant tensors are 67GB.

UPDATE: uploading... split into 2 files, with script/shell commands to recombine them.

https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin
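
The split/recombine can be as simple as the sketch below; the chunk size and part naming here are assumptions, not the exact commands used for this upload:

CHUNK = 45 * 1024**3  # stay safely under the 50GB-per-file limit

def split_file(path):
    # writes path.part00, path.part01, ... each at most CHUNK bytes
    # (reads a whole chunk into memory, which is fine on a high-RAM box)
    with open(path, "rb") as src:
        part = 0
        while True:
            data = src.read(CHUNK)
            if not data:
                break
            with open(f"{path}.part{part:02d}", "wb") as dst:
                dst.write(data)
            part += 1

def join_file(path, num_parts):
    # concatenates the parts back into the original file
    with open(path, "wb") as dst:
        for part in range(num_parts):
            with open(f"{path}.part{part:02d}", "rb") as src:
                dst.write(src.read())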

@Xu-Chen commented Mar 31, 2024

Success! I will push the quantized model for you guys to test. It will need about 68GB of vram to run, so 1x A100 80GB will do.

EDIT: now the issue is that HF has a limit of 50GB per file. The quant tensors are 67GB.

UPDATE: uploading... split into 2 files, with script/shell commands to recombine them.

Thank you. I used convert_v2.py to convert dbrx-instruct into dbrx-instruct-converted-v2. But the avg loss is small at the beginning and gradually increases, as shown below. Is there any problem?

2024-03-31 21:58:38 INFO [auto_gptq.quantization.gptq] duration: 3.117948532104492
2024-03-31 21:58:38 INFO [auto_gptq.quantization.gptq] avg loss: 14.791854858398438
INFO - Quantizing ffn.experts.mlp.9.w2 in layer 40/40...
2024-03-31 21:58:41 INFO [auto_gptq.quantization.gptq] duration: 3.198195219039917
2024-03-31 21:58:41 INFO [auto_gptq.quantization.gptq] avg loss: 37.51245880126953
INFO - Quantizing ffn.experts.mlp.10.w2 in layer 40/40...
2024-03-31 21:58:44 INFO [auto_gptq.quantization.gptq] duration: 3.11279034614563
2024-03-31 21:58:44 INFO [auto_gptq.quantization.gptq] avg loss: 28.414779663085938
INFO - Quantizing ffn.experts.mlp.11.w2 in layer 40/40...
2024-03-31 21:58:47 INFO [auto_gptq.quantization.gptq] duration: 3.0984387397766113
2024-03-31 21:58:48 INFO [auto_gptq.quantization.gptq] avg loss: 47.26870346069336
INFO - Quantizing ffn.experts.mlp.12.w2 in layer 40/40...
2024-03-31 21:58:51 INFO [auto_gptq.quantization.gptq] duration: 3.0781986713409424
2024-03-31 21:58:51 INFO [auto_gptq.quantization.gptq] avg loss: 21.394535064697266
INFO - Quantizing ffn.experts.mlp.13.w2 in layer 40/40...
2024-03-31 21:58:54 INFO [auto_gptq.quantization.gptq] duration: 3.0461459159851074
2024-03-31 21:58:54 INFO [auto_gptq.quantization.gptq] avg loss: 76.7130126953125
INFO - Quantizing ffn.experts.mlp.14.w2 in layer 40/40...
2024-03-31 21:58:57 INFO [auto_gptq.quantization.gptq] duration: 3.075993299484253
2024-03-31 21:58:57 INFO [auto_gptq.quantization.gptq] avg loss: 26.178302764892578
INFO - Quantizing ffn.experts.mlp.15.w2 in layer 40/40...
2024-03-31 21:59:00 INFO [auto_gptq.quantization.gptq] duration: 3.0906994342803955
2024-03-31 21:59:00 INFO [auto_gptq.quantization.gptq] avg loss: 19.860462188720703

@Qubitium (Contributor Author) commented:

Success! I will push the quantized model for you guys to test. It will need about 68GB of vram to run, so 1x A100 80GB will do.
EDIT: now the issue is that HF has a limit of 50GB per file. The quant tensors are 67GB.
UPDATE: uploading... split into 2 files, with script/shell commands to recombine them.

Thank you. I used convert_v2.py to convert dbrx-instruct into dbrx-instruct-converted-v2. But the avg loss is small at the beginning and gradually increases, as shown below. Is there any problem?


Yes. The increase in loss as the layers progress is also observed on my end. For now, we don't have much to go on. Calibration may need to be tweaked to optimize for dbrx. Right now generation is different from bfloat16 in my limited testing but still coherent. However, there may be an EOS problem where it is not stopping. These are minor issues that will get ironed out as more users start to quant and find the quirks.

@Qubitium (Contributor Author) commented:

We need to add a feature for save_quantized to write multiple files with a max_split_size, and to load multiple files with a json map, just like the normal from_pretrained api.
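
For reference, this would mirror what transformers already does for regular checkpoints (shards capped by max_shard_size and indexed by model.safetensors.index.json); the max_split_size name above is the proposal for the save_quantized side, not an existing API, and the model id below is purely hypothetical:

from transformers import AutoModelForCausalLM

# existing transformers behaviour that a sharded save_quantized would mirror:
# shards are capped at max_shard_size and indexed by model.safetensors.index.json,
# which from_pretrained later resolves transparently.
hf_model = AutoModelForCausalLM.from_pretrained("some-org/some-float-model")  # hypothetical model id
hf_model.save_pretrained("out_dir", safe_serialization=True, max_shard_size="40GB")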

Qubitium changed the title from "[WIP] [Hackfudgery] dbrx support" to "[WIP] [WORKING] dbrx support" on Mar 31, 2024
Qubitium changed the title from "[WIP] [WORKING] dbrx support" to "[WIP] [WORKING] dbrx (mod) support" on Mar 31, 2024
@Qubitium (Contributor Author) commented:

@LaaZa @fahadh4ilyas @Xu-Chen If you have the vram, please test 4bit marlin inference with https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin and let me know if you are getting coherent responses. Note that the loading time is quite long.

@Qubitium (Contributor Author) commented:

I think the calibration code is broken for MoE-style dbrx, based on the massive escalation of avg loss as we progress through layers 1-40.

INFO - Quantizing ffn.experts.mlp.0.v1 in layer 11/40...
2024-03-31 15:11:09 INFO [auto_gptq.quantization.gptq] duration: 1.6267049312591553
2024-03-31 15:11:09 INFO [auto_gptq.quantization.gptq] avg loss: 28.512256622314453
INFO - Quantizing ffn.experts.mlp.1.w1 in layer 11/40...
2024-03-31 15:11:10 INFO [auto_gptq.quantization.gptq] duration: 1.6261475086212158
2024-03-31 15:11:10 INFO [auto_gptq.quantization.gptq] avg loss: 17.44123077392578
INFO - Quantizing ffn.experts.mlp.1.v1 in layer 11/40...
2024-03-31 15:11:12 INFO [auto_gptq.quantization.gptq] duration: 1.628706932067871
2024-03-31 15:11:12 INFO [auto_gptq.quantization.gptq] avg loss: 15.289453506469727
INFO - Quantizing ffn.experts.mlp.2.w1 in layer 11/40...
2024-03-31 15:11:13 INFO [auto_gptq.quantization.gptq] duration: 1.6164577007293701
2024-03-31 15:11:13 INFO [auto_gptq.quantization.gptq] avg loss: 18.87800407409668

INFO - Quantizing ffn.experts.mlp.14.w1 in layer 23/40...
2024-03-31 15:42:21 INFO [auto_gptq.quantization.gptq] duration: 1.6099011898040771
2024-03-31 15:42:21 INFO [auto_gptq.quantization.gptq] avg loss: 52.00520324707031
INFO - Quantizing ffn.experts.mlp.14.v1 in layer 23/40...
2024-03-31 15:42:22 INFO [auto_gptq.quantization.gptq] duration: 1.6029260158538818
2024-03-31 15:42:22 INFO [auto_gptq.quantization.gptq] avg loss: 45.09771728515625
INFO - Quantizing ffn.experts.mlp.15.w1 in layer 23/40...
2024-03-31 15:42:24 INFO [auto_gptq.quantization.gptq] duration: 1.611437320709228
2024-03-31 16:23:09 INFO [auto_gptq.quantization.gptq] duration: 2.92622447013855
2024-03-31 16:23:09 INFO [auto_gptq.quantization.gptq] avg loss: 159.55865478515625
INFO - Quantizing ffn.experts.mlp.2.w2 in layer 39/40...
2024-03-31 16:23:12 INFO [auto_gptq.quantization.gptq] duration: 2.8120198249816895
2024-03-31 16:23:12 INFO [auto_gptq.quantization.gptq] avg loss: 204.56991577148438
INFO - Quantizing ffn.experts.mlp.3.w2 in layer 39/40...
2024-03-31 16:23:14 INFO [auto_gptq.quantization.gptq] duration: 2.823086977005005
2024-03-31 16:23:14 INFO [auto_gptq.quantization.gptq] avg loss: 162.71250915527344
INFO - Quantizing ffn.experts.mlp.4.w2 in layer 39/40...
2024-03-31 16:23:17 INFO [auto_gptq.quantization.gptq] duration: 2.816202163696289
2024-03-31 16:23:17 INFO [auto_gptq.quantization.gptq] avg loss: 100.16374206542969
INFO - Quantizing ffn.experts.mlp.5.w2 in layer 39/40...
2024-03-31 16:23:20 INFO [auto_gptq.quantization.gptq] duration: 2.8179235458374023
2024-03-31 16:23:20 INFO [auto_gptq.quantization.gptq] avg loss: 126.32321166992188

@Xu-Chen commented Apr 1, 2024

@Qubitium
If we use more samples for quantization, the final loss will decrease.

My current test script:

import os
max_cpu_threads = "24"
os.environ["OMP_NUM_THREADS"] = max_cpu_threads
os.environ["OPENBLAS_NUM_THREADS"] = max_cpu_threads
os.environ["MKL_NUM_THREADS"] = max_cpu_threads
os.environ["VECLIB_MAXIMUM_THREADS"] = max_cpu_threads
os.environ["NUMEXPR_NUM_THREADS"] = max_cpu_threads
os.environ["NUMEXPR_MAX_THREADS"] = max_cpu_threads

import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

import argparse
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str) # input model path
parser.add_argument("--gptq-save-dir", type=str)
parser.add_argument("--channelwise", action="store_true") 
parser.add_argument("--num-samples", type=int, default=1024) # The more samples, the slower the quantization speed
parser.add_argument("--max-seq-len", type=int, default=32768) 

def preprocess(example):
    # relies on the global `tokenizer` created in __main__ before ds.map() is called
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}

if __name__ == "__main__":
    args = parser.parse_args()

    dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:5%]")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    ds = dataset.shuffle().select(range(args.num_samples))
    ds = ds.map(preprocess)

    examples = [
        tokenizer(
            example["text"], padding=False, max_length=args.max_seq_len, truncation=True,
        ) for example in ds
    ]

    if args.channelwise:
        group_size = -1
    else:
        group_size = 128

    quantize_config = BaseQuantizeConfig(
        bits=4,                         # Only support 4 bit
        group_size=group_size,          # Set to g=128 or -1 (for channelwise)
        desc_act=False,                 # Marlin does not support act_order=True
        model_file_base_name="model",    # Name of the model.safetensors when we call save_pretrained
        quant_method="gptq",
        checkpoint_format="gptq"
    )

    model = AutoGPTQForCausalLM.from_pretrained(
        args.model_id,
        quantize_config,
        device_map="auto",
        trust_remote_code=True)
    model.quantize(examples)

    print(f"Saving gptq model to {args.gptq_save_dir}")
    model.save_pretrained(args.gptq_save_dir)
    # model.save_pretrained(args.gptq_save_dir, max_split_size="5GB")
    tokenizer.save_pretrained(args.gptq_save_dir)

@Qubitium (Contributor Author) commented Apr 1, 2024

@Qubitium If we use more samples for quantization, the final loss will decrease.

Nice! Can you post some of the layer losses from the beginning/middle/end? To confirm, you made three changes:

  1. max_sample from 128 -> 1024
  2. max seq len from 2048 -> 32768
  3. wikitext2 to ultrachat

Are you using the values above?

Also, use cores / 3 for max threads for best packing performance; 1/2 may still lead to a slowdown.

@Qubitium (Contributor Author) commented Apr 1, 2024

We need to add a feature to save_quantized into multiple files with max_split_size and load multiple files with a json map just like normal from_pretrained api.

@LaaZa I already made the shard loading code. I just need to add sharding on save.

@LaaZa (Contributor) commented Apr 1, 2024

We need to add a feature for save_quantized to write multiple files with a max_split_size, and to load multiple files with a json map, just like the normal from_pretrained api.

@LaaZa I already made the shard loading code. I just need to add sharding on save.

We need to reimplement #364. Saving works, but it is obviously outdated and there was a weird issue where sharded saving with it breaks the fused modules.

@Xu-Chen commented Apr 1, 2024

@Qubitium If we use more samples for quantization, the final loss will decrease.

Nice! Can you post some of the layers losses from beginning/middle/end? To confirm, you made two changes:

  1. max_sample from 128 -> 1024
  2. max seq len from 2048 -> 32768

Are you using the above two values?

Yes, and I used samples from HuggingFaceH4/ultrachat_200k and applied apply_chat_template, while the model is dbrx-instruct-converted-v2 (converted from dbrx-instruct with convert_v2.py).

The log has been overwritten by new output, but the final avg loss is less than 2.

Using 4096 samples, here is some of the log, but it is too slow and hits OOM.

(screenshots of the quantization logs omitted)

@Qubitium (Contributor Author) commented Apr 1, 2024

@Xu-Chen To anyone trying to quant: do not set more than 8 threads. I see negative returns beyond that. Warning/fix in PR #628.

@Qubitium (Contributor Author) commented Apr 1, 2024

Two new quants have started based on the calibration fix from @Xu-Chen: Marlin + non-Marlin

  1. https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin-v2
  2. https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-gptq-v2 (PENDING)

@Xu-Chen commented Apr 1, 2024

Two new quants have started based on the calibration fix from @Xu-Chen: Marlin + non-Marlin

  1. https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin-v2
  2. https://huggingface.co/LnL-AI/dbrx-base-converted-v2-4bit-gptq-gptq-v2 (PENDING)

The next step should be to deploy the model using vllm

@Qubitium (Contributor Author) commented Apr 1, 2024

The next step should be to deploy the model using vllm

Our vllm PR for dbrx-converted-v2 should be ready soon.

@Qubitium (Contributor Author) commented Apr 5, 2024

Our vllm PR for dbrx-converted-v2 should be ready soon.

Unfortunately dbrx finetuning has failed my internal quality metrics, so I will not be spending time to port dbrx-converted to vllm. But if you plan to do so, there are two paths: 1) reverse the convert_v2.py splits so the existing code works, or 2) use the new inference code but load the weights split 1/tp_size and keep only the tp_rank slice.

@fxmarty (Collaborator) commented Apr 12, 2024

Sorry for the delay, I have been busy.

Why is packing so slow? It's like 100x slower than the quantizing stage. We are so close!

I never had a look at the implementation; it is probably suboptimal. Either I need to implement a CUDA kernel for packing, or there's a better implementation possible in PyTorch. Let me have a look next week.

@Qubitium (Contributor Author) commented:

Proper dbrx support is finally merged into transformers:

huggingface/transformers#29921

The whole quant code path needs to be revisited. Maybe they fixed the two issues that plagued the pre-transformers-merge code: 1) training is not possible due to OOM from the fused layers, 2) quant is not possible/feasible due to the fused layers.

I no longer have the time to work on this. If you want to take over this WIP PR, you can fork it, or let me know and I will add you to the branch permissions.

@LaaZa (Contributor) commented Apr 19, 2024

Proper dbrx support is finally merged into transformers:

huggingface/transformers#29921

The whole quant code path needs to be revisited. Maybe they fixed the two issues that plagued the pre-transformers-merge code: 1) training is not possible due to OOM from the fused layers, 2) quant is not possible/feasible due to the fused layers.

I no longer have the time to work on this. If you want to take over this WIP PR, you can fork it, or let me know and I will add you to the branch permissions.

Well, #623 is meant to support the unmodified model if possible. I was waiting for resolution on the transformers implementation to see if it will be possible then. Do you think the modified version of the model is still needed?
