
CLAP Fine-tuning has run into a problem #30795

Open
2 of 4 tasks
ScottishFold007 opened this issue May 14, 2024 · 5 comments

ScottishFold007 (Contributor) commented May 14, 2024

System Info

  • transformers version: 4.39.3
  • Platform: Linux-4.19.91-014-kangaroo.2.10.13.5c249cdaf.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sanchit-gandhi @ylacombe @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to fine-tune CLAP, but I've run into some problems with it. I previously referenced the solution in #26864.
Here is my code:

Load data:

import re
import glob
import numpy as np
from datasets import load_dataset, load_metric

files= glob.glob("genshin-voice-v3.5-mandarin/data/*.parquet")


dataset = load_dataset(
    "parquet",
    data_files= files
     )["train"].select(range(500))


split_dataset = dataset.train_test_split(test_size=0.01)

# The resulting DatasetDict (after the preprocessing below):
DatasetDict({
    train: Dataset({
        features: ['input_features', 'is_longer', 'input_ids'],
        num_rows: 495
    })
    test: Dataset({
        features: ['input_features', 'is_longer', 'input_ids'],
        num_rows: 5
    })
})

Load model:

import torch
from datasets import Audio
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments
from transformers import ClapProcessor
from transformers import ClapModel
from transformers import ClapConfig


model_path= "./laion_clap-htsat-fused"
config= ClapConfig.from_pretrained(model_path)

config.audio_config.enable_fusion= False

#config.audio_config.enable_patch_fusion = False
#config.audio_config.enable_patch_layer_norm = False


processor = ClapProcessor.from_pretrained(model_path)

model = ClapModel.from_pretrained(
                                model_path, 
                                #config= config
                                #torch_dtype =torch.bfloat16,
                                   
            )

Process data:

def remove_special_characters(example):
    #try:
    chars_to_remove_regex = r'[\「\」\,\?\.\!\-\;\:\"\“\%\‘\”\�\…\。\、\?\!\,\']'
    example["text"] = re.sub(chars_to_remove_regex, '', example["text"]).lower()
    return example
    #except:
    #    print(batch)
    #    print("-----"*6)
        

def text_length(batch):
    text = str(batch["text"])
    batch["text_length"] = len(text)
    return batch

def prepare_dataset(batch):
    # load the audio data (cast to 48 kHz via cast_column before mapping)
    audio = batch["audio"]

    # compute input length
    batch["input_length"] = len(batch["audio"])

    # compute log-Mel input features from input audio array 
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["is_longer"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors= "pt").is_longer[0][0]
    # encode target text to label ids 
    text_inputs = processor.tokenizer(batch["text"])
    batch["input_ids"] = text_inputs.input_ids
    #batch["attention_mask"] = text_inputs.attention_mask

    # compute labels length
    batch["labels_length"] = len(processor.tokenizer(batch["text"], add_special_tokens=False).input_ids)
    return batch


def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length


def filter_labels(labels_length):
    """Filter label sequences longer than max length (448)"""
    return labels_length < max_label_length

split_dataset= dataset.train_test_split(test_size=0.01) 

max_label_length= 448
MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 48000

split_dataset = split_dataset.filter(lambda x: x["text"] != None, num_proc= 32)
split_dataset = split_dataset.map(remove_special_characters, num_proc= 32, batched= False)
split_dataset = split_dataset.cast_column("audio", Audio(sampling_rate=48000))
split_dataset = split_dataset.map(prepare_dataset, remove_columns= split_dataset["train"].column_names, num_proc= 32)

split_dataset = split_dataset.filter(filter_inputs, input_columns=["input_length"])
split_dataset = split_dataset.filter(filter_labels, input_columns=["labels_length"])                                     
split_dataset = split_dataset.remove_columns(['input_length', 'labels_length'])
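
For context, a quick check like the following (an illustrative sketch using the field names from the pipeline above) shows how the mapped dataset hands these values back: after .map(), is_longer is stored as a plain scalar rather than a torch tensor, which is why the collator below ends up receiving a Python list of booleans.

sample = split_dataset["train"][0]
print(type(sample["is_longer"]))    # expect a plain Python bool, not a torch tensor
print(type(sample["input_ids"]))    # expect a plain Python list of token ids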

    
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
       
        is_longer_features = [feature["is_longer"] for feature in features] 
        #is_longer_batch = self.processor.tokenizer.pad(is_longer_features, return_tensors="pt")
    
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["input_ids"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
       
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["input_ids"] = labels
        batch["is_longer"] = is_longer_features
        batch["return_loss"]= True
        return batch

 
from transformers import Trainer
from transformers import TrainingArguments


training_args = TrainingArguments(
    output_dir= "./openai_whisper-large-v3_ft", # change to a repo name of your choice
    do_train= True,
    do_eval= True,
    evaluation_strategy= "steps",
    per_device_train_batch_size= 16,
    per_device_eval_batch_size= 16,
    gradient_accumulation_steps= 2, # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    #max_steps=4000,
    num_train_epochs= 3,
    #gradient_checkpointing=True,
    bf16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=25,
    #report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)


trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset= split_dataset["train"],
    eval_dataset= split_dataset["test"],
    data_collator= data_collator,
    #data_collator=collate_fn,
    #tokenizer= processor.feature_extractor,
)


trainer.train()

Then the following error occurs:

AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 2094, in forward
audio_outputs = self.audio_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 1742, in forward
return self.audio_encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 913, in forward
is_longer_list = is_longer.to(input_features.device)
AttributeError: 'list' object has no attribute 'to'

I'm having a lot of problems with the enable_fusion=True mode, and I don't think I have a good grasp of how the is_longer input should be handled, so I'd appreciate some pointers on this. Thanks!

Expected behavior

The model should train normally.

ylacombe (Collaborator) commented:

Hey @ScottishFold007, in your data collator, is_longer_features is a list, but it should be a torch tensor instead. Note that you've commented out the following line:

        #is_longer_batch = self.processor.tokenizer.pad(is_longer_features, return_tensors="pt")

This line should probably not be commented out!
I hope it helps!

ScottishFold007 (Contributor, Author) commented:

When this line of code is not commented out, this is the error that appears:

is_longer_batch = self.processor.tokenizer.pad(is_longer_features, return_tensors="pt")

/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
52 else:
53 data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)

/tmp/ipykernel_1338324/128232537.py in __call__(self, features)
15
16 is_longer_features = [feature["is_longer"] for feature in features]
---> 17 is_longer_batch = self.processor.tokenizer.pad(is_longer_features, return_tensors="pt")
18
19 # get the tokenized label sequences

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in pad(self, encoded_inputs, padding, max_length, pad_to_multiple_of, return_attention_mask, return_tensors, verbose)
3288 raise ValueError(
3289 "You should supply an encoding or a list of encodings to this method "
-> 3290 f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
3291 )
3292

AttributeError: 'list' object has no attribute 'keys'

ylacombe (Collaborator) commented:

According to CLAP docs, is_longer should be a tensor of shape (batch_size, 1).

What you can do is probably something like this:

        is_longer_features = [feature["is_longer"] for feature in features]
        is_longer_features = torch.tensor(is_longer_features)[...,None] 

which basically creates a tensor and adds an extra dimension!
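
As a quick illustration of that suggestion (the boolean values below are made up), stacking the per-example flags and adding a trailing dimension yields the (batch_size, 1) shape:

import torch

is_longer_features = [True, False, False, True]           # one flag per example in the batch (example values)
is_longer = torch.tensor(is_longer_features)[..., None]   # add a trailing dimension

print(is_longer.shape)  # torch.Size([4, 1])
print(is_longer.dtype)  # torch.bool

The resulting tensor can then be assigned to batch["is_longer"] in the collator in place of the raw list.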

ScottishFold007 (Contributor, Author) commented May 17, 2024

When I made the changes as suggested, the following error was reported again:

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 2102, in forward
text_outputs = self.text_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 1892, in forward
encoder_outputs = self.encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 1602, in forward
layer_outputs = layer_module(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 1491, in forward
self_attention_outputs = self.attention(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 1418, in forward
self_outputs = self.self(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clap/modeling_clap.py", line 1356, in forward
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

ylacombe (Collaborator) commented:

Hey @ScottishFold007, could you run with CUDA_LAUNCH_BLOCKING=1 so we can get more details on why your training didn't work?
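
For reference, one way to do that (a sketch; exporting the variable from the shell before launching the script works just as well) is to set it before CUDA is initialized, e.g. at the very top of the notebook or script:

import os

# Make CUDA kernel launches synchronous so the device-side assert is reported
# at the call that actually triggered it.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # importing torch after setting the variable is the safest order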
