Whisper Static Cache #30760

Open · wants to merge 2 commits into main
Conversation

@huseinzol05 (Contributor) commented May 11, 2024

Static Cache for Whisper

This enables torch.compile for Whisper generation, making generation faster. Example: https://gist.github.com/huseinzol05/9aff34ec1427ee8c92240cb4f3cc0c88

The compiled static cache achieves 186.26 it/s, while the non-compiled version gets 150.20 it/s.

Still a work in progress:

  1. The current fork only works with the static cache; it needs to follow the same caching steps as Llama.
  2. There are many conditions that need to be fulfilled first.
  3. It only works on PyTorch 2.4.0.dev20240508+cu121, a nightly not yet released as stable, which is needed for reduce-overhead torch.compile of custom functions.
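
For readers new to the technique, here is a minimal sketch of the idea behind a static k/v cache: preallocate fixed-shape buffers, mark their addresses as static, and update them in place so torch.compile with CUDA graphs can reuse one recorded graph. The class name, shapes, and update signature below are illustrative assumptions, not this PR's actual code.

```python
import torch

# Illustrative only: a minimal static k/v cache in the spirit of this PR.
class StaticKV:
    def __init__(self, num_layers, batch, num_heads, max_len, head_dim,
                 dtype=torch.float16, device="cuda"):
        shape = (batch, num_heads, max_len, head_dim)
        self.keys = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
        self.values = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
        for k, v in zip(self.keys, self.values):
            # Fixed addresses let torch.compile / CUDA graphs reuse the same buffers.
            torch._dynamo.mark_static_address(k)
            torch._dynamo.mark_static_address(v)

    def update(self, layer_idx, positions, new_keys, new_values):
        # In-place writes keep the static addresses intact (no reallocation).
        self.keys[layer_idx].index_copy_(2, positions, new_keys)
        self.values[layer_idx].index_copy_(2, positions, new_values)
        return self.keys[layer_idx], self.values[layer_idx]
```

A decode step that reads and writes such buffers can then be wrapped with `torch.compile(step, mode="reduce-overhead", fullgraph=True)`, which is where the speed-up quoted above comes from.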

@mobicham (Contributor)

Thank you very much @huseinzol05 for the work.
Here's a version with HQQ 4-bit using the torchao backend. As expected, there's a good speed-up with the static cache and fullgraph compilation: https://gist.github.com/mobicham/ecfe09a48efb11e4014386901a5c6cce

GPU: 4090
orig - no compile : 48 it/sec
orig + compiled   : 227 it/sec

hqq - no compile  : 42 it/sec
hqq + compile     : 308 it/sec

@kadirnar (Contributor)

Will it be merged? @younesbelkada

@amyeroberts (Collaborator)

cc @sanchit-gandhi

@huseinzol05 (Contributor, Author)

@kadirnar, this PR is not ready to merge; you're welcome to continue working on it to address points 1, 2, and 3. If you want to use it as-is, you have to split the audio into 30-second chunks with overlap and feed them through the encoder-decoder process; feel free to add temperature and top_k sampling like https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L52 (a rough sketch of both follows below).
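
For illustration, a rough sketch of the chunking and gpt-fast-style sampling mentioned above; the 30 s / 5 s window sizes and the `transcribe_chunk` helper are assumptions, not part of this PR:

```python
import numpy as np
import torch

SAMPLE_RATE = 16_000          # Whisper expects 16 kHz audio
CHUNK_S, OVERLAP_S = 30, 5    # assumed window and overlap lengths

def iter_chunks(audio: np.ndarray):
    """Yield 30 s windows of a 16 kHz waveform with 5 s overlap."""
    size = CHUNK_S * SAMPLE_RATE
    step = (CHUNK_S - OVERLAP_S) * SAMPLE_RATE
    for start in range(0, max(len(audio) - OVERLAP_S * SAMPLE_RATE, 1), step):
        yield audio[start:start + size]

def sample_top_k(logits: torch.Tensor, temperature: float = 0.7, top_k: int = 50):
    """Temperature + top-k sampling, in the style of gpt-fast's generate.py."""
    logits = logits / max(temperature, 1e-5)
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits = logits.masked_fill(logits < v[..., -1:], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Each chunk would then go through the encoder-decoder loop, e.g.:
# texts = [transcribe_chunk(chunk) for chunk in iter_chunks(audio)]  # hypothetical helper
```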

@kadirnar (Contributor) commented May 16, 2024

Will it work if I run this code? Also, do I need to make any changes to the gpt-fast library?

https://gist.github.com/huseinzol05/9aff34ec1427ee8c92240cb4f3cc0c88

@huseinzol05 (Contributor, Author)

Yeah, it should work; I use it in production. But don't forget to warm up the static cache multiple times first.
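
For context, the warm-up just means running the compiled decode path a few times on dummy inputs before timing or serving it, so compilation and CUDA-graph capture happen up front. A minimal sketch, where `compiled_decode_step` stands in for whatever function was wrapped with `torch.compile`:

```python
import torch

def warmup(compiled_decode_step, dummy_token, dummy_position, n_iters: int = 3):
    # The first call(s) trigger tracing / CUDA-graph capture; subsequent calls
    # hit the cached graph, so steady-state latency is reached after a few runs.
    for _ in range(n_iters):
        compiled_decode_step(dummy_token, dummy_position)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
```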

@kadirnar (Contributor)

> Yeah, it should work; I use it in production. But don't forget to warm up the static cache multiple times first.

I ran the notebook file. It gives this error:

File /usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py:484, in WhisperStaticCache.__init__(self, config, dtype, device, existing_cache, batch_size)
    482 torch._dynamo.mark_static_address(e_key_cache)
    483 torch._dynamo.mark_static_address(e_value_cache)
--> 484 e_key_cache[:, :, :, :] = existing_cache[k][2].clone()
    485 e_value_cache[:, :, :, :] = existing_cache[k][3].clone()
    486 self.key_cache.append(new_layer_key_cache)

IndexError: tuple index out of range
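
One hedged reading of this traceback: `existing_cache[k][2]` assumes the legacy Whisper `past_key_values` layout of four tensors per layer (self-attention key/value plus cross-attention key/value), so the error suggests the cache being converted only held two tensors per layer in that run. A quick check along these lines (reusing the gist's `existing_cache` variable) would confirm:

```python
# Expected per layer for Whisper's legacy cache:
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
for layer_idx, layer_cache in enumerate(existing_cache):
    print(layer_idx, len(layer_cache), [tuple(t.shape) for t in layer_cache])
```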

@huseinzol05 (Contributor, Author)

I just reran it and there's no issue, super weird. Which line gives you the error?

@@ -448,3 +448,109 @@ def reset(self):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

class WhisperStaticCache(Cache):
Contributor:
Thanks for this great first start @huseinzol05! With @gante, we were discussing how the design of the static k/v cache should look for encoder-decoder models, and we distilled the design options down to two possibilities:

  1. Hold a tuple of StaticCache caches, e.g. as proposed here
  2. Add a new Cache class specific to encoder-decoder models, e.g. one with the attributes:
    • key_cache (same as decoder-only self-attn)
    • value_cache (same as decoder-only self-attn)
    • cross_key_cache (new for enc-dec cross-attn)
    • cross_value_cache (new for enc-dec cross-attn)

Option 1 doesn't require any new Cache classes, so it should be easier to maintain! Thus, we were thinking this would be the best design option for Whisper (and other encoder-decoder models in the library, such as BART). Would be curious to hear your opinions here, having had a go at option 2!
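
To make the two layouts concrete, a purely illustrative sketch (not the transformers API): option 1 carries two existing Cache objects as a tuple, while option 2 would add a class with the four attributes listed above.

```python
import torch
from dataclasses import dataclass, field
from typing import List

# Option 1: keep single-purpose cache classes and pass two of them around, e.g.
#   past_key_values = (decoder_self_attn_cache, cross_attn_cache)
# Any Cache implementation (static, quantized, ...) can fill either slot.

# Option 2: a dedicated encoder-decoder cache class (attributes as listed above).
@dataclass
class EncoderDecoderCacheSketch:
    key_cache: List[torch.Tensor] = field(default_factory=list)         # decoder self-attn keys
    value_cache: List[torch.Tensor] = field(default_factory=list)       # decoder self-attn values
    cross_key_cache: List[torch.Tensor] = field(default_factory=list)   # enc-dec cross-attn keys
    cross_value_cache: List[torch.Tensor] = field(default_factory=list) # enc-dec cross-attn values
```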

Member:

@huseinzol05 this is great work!

I'm heavily biased towards option 1, especially now that we are seeing more cache types. For instance, we could easily plug in the quantized cache as the decoder cache with 0 code overhead, if we design Whisper to support a tuple of Cache objects through past_key_values 🤗

Contributor (Author):

I'm good with either option.
