
Using PEFT causes model to not predict EOS #1672

Open
2 of 4 tasks
Km3888 opened this issue Apr 23, 2024 · 12 comments

Comments

@Km3888

Km3888 commented Apr 23, 2024

System Info

peft version: 0.9.0
accelerate version: 0.27.2
transformers version: 4.37.0
trl version: 0.7.12.dev0
base model: openai-community/gpt2
hardware: 2xA100

I'm doing a LoRA PEFT of GPT-2 through TRL and have noticed that my trained model assigns very low probability to the EOS token, which causes it to always generate the maximum number of tokens.

After trying a few different fixes I ran the code without the PEFT option and just used the base model. The problem resolved immediately.

To make the comparison clear, I created a toy case with a dataset that contains the same datapoint ("Hello <|endoftext|>") repeated. I then overfit on this dataset with a small batch size for a few dozen iterations. To see the effect on the probability of generating the eos_token, I inserted the following code fragment into my compute_metrics method:

logits, labels = eval_preds
# positions of the EOS token in the label ids
eos_indices = np.where(labels == tokenizer.eos_token_id)
model_distribution = torch.softmax(torch.tensor(logits), dim=-1).numpy()
# probability at the last vocab index (for GPT-2, the EOS id 50256 is the last entry)
eos_probs = model_distribution[eos_indices[0], eos_indices[1], -1]
eos_probs = [format(x * 100, '.3f') for x in eos_probs.tolist()]
print('eos probs:', eos_probs)

The basic full fine-tuning results in the EOS token probability converging to 1 almost immediately as the model memorizes the locations of the EOS tokens. However, if I use TRL's code for a LoRA PEFT instead, the printed values remain close to zero and don't increase at all.

I've seen some references online suggesting that this could be caused by LoRA not updating the model's embedding matrix. So I added the following change to the peft_config: peft_config.modules_to_save = ["wte"]. This doesn't have any effect on the results. I'm also doubtful this is the cause, as when I run the supervised fine-tuning I don't see any change in the embedding matrix but get the desired results anyway.
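For reference, a minimal sketch of the config change described above (a plain peft LoraConfig with the values from the reproduction below; "wte" is GPT-2's token-embedding module name):

from peft import LoraConfig

# Sketch only: same LoRA hyperparameters as in the ModelConfig below,
# with GPT-2's token embedding ("wte") additionally trained and saved.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    modules_to_save=["wte"],
)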

Any help would be appreciated, as I would like to avoid a full fine-tuning but right now have no way of getting a functional model with a PEFT.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Use the following model_config (note the PEFT parameters) and training arguments:

ModelConfig(model_name_or_path='openai-community/gpt2', model_revision='main', torch_dtype=None, trust_remote_code=False, attn_implementation=None, use_peft=True, lora_r=64, lora_alpha=16, lora_dropout=0.05, lora_target_modules=None, lora_modules_to_save=None, load_in_8bit=False, load_in_4bit=False, bnb_4bit_quant_type='nf4', use_bnb_nested_quant=False)

TrainingArguments(
n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=1,
eval_delay=0,
eval_steps=10,
evaluation_strategy=steps,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={'use_reentrant': False},
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
include_num_input_tokens_seen=False,
include_tokens_per_second=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=1.41e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=0,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=/scratch/km3888/gcode_peft//runs/Apr23_14-18-54_gh004.hpc.nyu.edu,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=1.0,
logging_strategy=steps,
lr_scheduler_kwargs={},
lr_scheduler_type=linear,
max_grad_norm=1.0,
max_steps=20000,
metric_for_best_model=None,
mp_parameters=,
neftune_noise_alpha=None,
no_cuda=False,
num_train_epochs=1.0,
optim=adamw_torch,
optim_args=None,
output_dir=/scratch/km3888/gcode_peft/,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=4,
per_device_train_batch_size=4,
prediction_loss_only=False,
push_to_hub=True,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=['wandb'],
resume_from_checkpoint=None,
run_name=/scratch/km3888/gcode_peft/,
save_on_each_node=False,
save_only_model=False,
save_safetensors=True,
save_steps=500,
save_strategy=steps,
save_total_limit=None,
seed=42,
skip_memory_metrics=True,
split_batches=False,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_mps_device=False,
warmup_ratio=0.0,
warmup_steps=0,
weight_decay=0.0,
)

Create dataset:

import copy
import json

from datasets import load_dataset
from transformers import AutoTokenizer

dummy_data = [{"text": "Hello <|endoftext|>"} for _ in range(1000)]
with open("dummy_data.json", "w") as f:
    json.dump(dummy_data, f)
full_dataset = load_dataset('json', data_files="dummy_data.json", split='train')
full_dataset = full_dataset.map(lambda x: {'text': add_eos(x['text'])})  # add_eos: user-defined helper
split_data = full_dataset.train_test_split(test_size=0.05)
train_dataset = split_data['train'].shuffle()
eval_dataset = copy.deepcopy(train_dataset)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, token=access_token, use_fast=True, add_eos=True)

Set up custom evaluation function:

def compute_metrics(eval_preds):
    metric = evaluate.load("accuracy", training_args.output_dir.split('/')[-1])
    logits, labels = eval_preds
    # positions of the EOS token in the labels
    eos_indices = np.where(labels == tokenizer.eos_token_id)
    model_distribution = torch.softmax(torch.tensor(logits), dim=-1).numpy()
    # probability assigned to the last vocab index (GPT-2's EOS id) at those positions
    eos_probs = model_distribution[eos_indices[0], eos_indices[1], -1]
    eos_probs = [format(x * 100, '.3f') for x in eos_probs.tolist()]
    print('eos probs:', eos_probs)
    predictions = np.argmax(logits, axis=-1)
    predictions = np.reshape(predictions.astype(np.int32), -1)
    labels = np.reshape(labels.astype(np.int32), -1)
    return metric.compute(predictions=predictions, references=labels)

Instantiate and run SFTTrainer:

trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    packing=False,
    peft_config=get_peft_config(model_config),
    compute_metrics=compute_metrics,
    dataset_num_proc=20,
)

trainer.train()

The eos_probs printed in compute_metrics will be near zero.

Expected behavior

I would expect the above code to result in eos_probs values being nearly 1 after a few training iterations.

@BenjaminBossan
Member

Good idea training on that dummy dataset to debug what's happening. Do you see any model progress at all, or is the output basically static? Any logs you can share could be helpful.

Regarding the SFTTrainer, did you by chance try to train this model with vanilla PyTorch? The number of arguments is so big that it's hard for me to tell whether any of them could be causing the issue.
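For illustration, a stripped-down vanilla loop on the dummy data might look like the sketch below (my own minimal version, not the reporter's actual setup; the hyperparameters are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model = get_peft_model(model, LoraConfig(r=64, lora_alpha=16, task_type="CAUSAL_LM"))

# a single repeated batch, mirroring the "Hello <|endoftext|>" toy dataset
batch = tokenizer(["Hello <|endoftext|>"] * 4, return_tensors="pt")
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)

model.train()
for step in range(50):
    out = model(**batch, labels=batch["input_ids"])
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With a working setup, the loss should collapse quickly and the probability of <|endoftext|> after "Hello" should approach 1, which is the behavior the reporter sees with full fine-tuning but not with the LoRA PEFT.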

@Km3888
Author

Km3888 commented Apr 25, 2024

Yes, I should have specified that the training and validation loss both go down when training on my regular dataset (as well as the dummy dataset).

Unfortunately I haven't tried training the model directly without SFTTrainer. I'm also a bit new to this library and don't fully understand its implementation.

@BenjaminBossan
Member

I see. Could you please paste your complete python code so that we can try to replicate the issue?

@younesbelkada Can you see anything wrong with the shown code?

@geronimi73

lora_target_modules=None

This might be the problem. Try "all-linear"
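A sketch of what that change might look like with the TRL ModelConfig shown in the reproduction (assuming lora_target_modules accepts peft's "all-linear" shortcut; values copied from the reproduction above):

from trl import ModelConfig

# Hypothetical: same config as in the reproduction, but targeting all linear layers
model_config = ModelConfig(
    model_name_or_path="openai-community/gpt2",
    use_peft=True,
    lora_r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_modules="all-linear",
)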

@BenjaminBossan
Member

This might be the problem. Try "all-linear"

Note that this is only an issue if the model is not one of the pre-configured models. If there is no matching layer at all, we raise an error, so it shouldn't go unnoticed.

@derekelewis

Running into this issue myself with 'all-linear' and SFTTrainer using peft main, transformers 4.40.2, and trl 0.8.6. Without PEFT, EOS is predicted fine; with PEFT, EOS is not predicted correctly and generation just continues until max_tokens is reached.

@BenjaminBossan
Member

@derekelewis If you have a minimal reproducer to share, that would be great.

@derekelewis

derekelewis commented May 10, 2024

@BenjaminBossan see below. Tried to simplify as much as possible. Also uploaded fine-tuned models to hub w/ PEFT enabled & disabled. TRL seems to be having some issues w/ chat_templates & EOS in general (huggingface/trl#1412, huggingface/trl#1623, huggingface/trl#1578), but I think it is separate from what is going on here.

PEFT enabled: https://huggingface.co/delewis/gemma-2b-peft-eos-issue
PEFT disabled: https://huggingface.co/delewis/gemma-2b-no-peft-eos-issue

Training script:

from typing import Optional
from dataclasses import dataclass, field
from datasets import load_dataset, Features, Value
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, setup_chat_format
import torch
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
)
from peft import LoraConfig
import torch.distributed as dist

import os

model_id = "google/gemma-2b"
dataset_id = "databricks/databricks-dolly-15k"

system_message = "You are a helpful assistant that answers the questions to the best of your ability."


def create_conversation(sample):
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": sample["instruction"]},
            {"role": "assistant", "content": sample["response"]},
        ]
    }


def training_function(training_args):

    dataset = load_dataset(dataset_id)

    dataset = dataset["train"].train_test_split(test_size=0.1)

    train_dataset = dataset["train"]
    val_dataset = dataset["test"]

    train_dataset = train_dataset.map(
        create_conversation,
        batched=False,
        remove_columns=train_dataset.column_names,
    )

    val_dataset = val_dataset.map(
        create_conversation,
        batched=False,
        remove_columns=val_dataset.column_names,
    )

    # print example from train_dataset
    print("Example from train_dataset:")
    print(train_dataset[0])

    # print example from val_dataset
    print("Example from val_dataset:")
    print(val_dataset[0])

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_cache=False,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map=None if training_args.fsdp else "auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    model, tokenizer = setup_chat_format(model, tokenizer)

    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
        modules_to_save=["lm_head", "embed_tokens"],
    )

    max_seq_length = 8192

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        args=training_args,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)


def main():
    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]

    set_seed(training_args.seed)

    training_function(training_args)


if __name__ == "__main__":
    main()
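The script parses its TrainingArguments from the command line via HfArgumentParser, so one plausible invocation (the script name and the exact arguments were not given in the report) would be along the lines of:

python train.py --output_dir ./output --per_device_train_batch_size 2 --num_train_epochs 1 --bf16 True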

Test script:

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
from peft import AutoPeftModelForCausalLM

model_id = "./output"

# model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model = AutoPeftModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = [{"role": "user", "content": "What is the capital of France?"}]

prompt = tokenizer.apply_chat_template(
    prompt, tokenize=False, add_generation_prompt=True
)

print(f"Prompt: {prompt}")

text_pipeline = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, return_full_text=False
)

print(
    text_pipeline(
        prompt,
        max_length=2048,
        do_sample=False,
        num_return_sequences=1,
    )
)

Output of test script w/ PEFT enabled:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.                                                                           
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.                                                                           
Prompt: <|im_start|>user                                                                                                                                                                        
What is the capital of France?<|im_end|>
<|im_start|>assistant
                                                                                                                                                                                                The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].             
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
[{'generated_text': 'Paris is the capital of France.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to tra
vel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions t
o the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quicke
st way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to 
travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast 
Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quic
kest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers 
the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by pla
ne.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is t
he best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the
 Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassi
stant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistan
t that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\n
user\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by
 driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to 
New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San F
rancisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of 
your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to trav
el by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from Sa
n Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsyst
em\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to tr
avel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions 
to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quick
est way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to
 travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast
 Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe qui
ckest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers
 the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by pl
ane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is 
the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on th
e Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nass
istant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assista
nt that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to Ne
w York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\
nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is b
y driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to
 New York?\nassistant\nThe quickest way to travel from San Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a 
helpful assistant that answers the questions to the best of your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to travel from San 
Francisco to New York is by plane.  The quickest way to travel by car is by driving on the Pacific Coast Highway.\nsystem\nYou are a helpful assistant that answers the questions to the best of
 your ability.\nuser\nWhat is the best way to travel from San Francisco to New York?\nassistant\nThe quickest way to'}]

Output w/ PEFT disabled:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Prompt: <|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
[{'generated_text': 'The capital of France is Paris'}]

@BenjaminBossan
Member

@derekelewis Thanks for the script. Unfortunately I could not run it due to memory constraints, but it's still helpful. I can spot 3 potential issues:

  1. You're using a rank of 64 with a LoRA alpha of 16, which is quite low; it is generally recommended to set alpha = 2 * rank.
  2. Check the number of trainable parameters (trainer.model.print_trainable_parameters()); it is quite large. This is because you fully fine-tune the input embeddings and output layer, which are gigantic for Gemma (especially for the 2b model). You could consider applying LoRA fine-tuning to those layers instead (target_modules=["q_proj", "v_proj", "down_proj", "lm_head", "embed_tokens"], modules_to_save=[]).
  3. You're using bfloat16 to load the model. This results in the LoRA weights also being in bfloat16, which can lead to unstable training. When loading a model with bnb (QLoRA), trl will automatically cast the LoRA weights to float32, but not here. What you could try is to add this snippet before calling trainer.train():
# upcast only the trainable (LoRA / modules_to_save) parameters to float32
for p in trainer.model.parameters():
    if p.requires_grad:
        p.data = p.data.float()

Also, ensure that you enable bf16=True for the SFTTrainer so that AMP is used. This will cost more memory overall than before, but might stabilize training. Could you check if that helps?
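For concreteness, a sketch combining these three suggestions, reusing the variable names from the reproducer above (the alpha value and module list are suggestions only, not tested settings):

from peft import LoraConfig

peft_config = LoraConfig(
    r=64,
    lora_alpha=128,  # roughly 2 * r, per point 1
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    # LoRA on attention/MLP projections plus lm_head/embed_tokens,
    # instead of fully fine-tuning them via modules_to_save (point 2)
    target_modules=["q_proj", "v_proj", "down_proj", "lm_head", "embed_tokens"],
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    args=training_args,  # with bf16=True so AMP is used
)

trainer.model.print_trainable_parameters()  # sanity-check the trainable-parameter count

# point 3: upcast the trainable parameters to float32 before training
for p in trainer.model.parameters():
    if p.requires_grad:
        p.data = p.data.float()

trainer.train()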

@AIR-hl

AIR-hl commented May 17, 2024

Hello! Did you solve this problem? I have run into the same problem when doing SFT on phi-1.5!

@Vermeille

Same issue here, it seems. It appears to be due to an update: the same code that was working a few weeks ago (hard to be certain, of course) now has EOS prediction issues.

@BenjaminBossan
Member

@Vermeille If you could share a minimal reproducer, we could take a look, otherwise it's going to be hard for us to help.
