Whisper Static Cache #30760
base: main
Conversation
Thank you very much @huseinzol05 for the work.
Will it be merged? @younesbelkada
@kadirnar, this PR is not ready to merge, or you can continue to work on it to fulfill items 1, 2 and 3. But if you want to use it as-is, you have to split the audio into 30s chunks with overlap and feed each chunk through the encoder-decoder loop; feel free to add
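A minimal sketch of the chunking described in the comment above; the helper name, the 16 kHz rate, and the 5 s overlap are illustrative assumptions, not part of this PR:

```python
import numpy as np

# Hypothetical helper (not from this PR): split a waveform into 30 s
# windows with overlap so each window fits Whisper's 30 s input limit.
def chunk_audio(waveform: np.ndarray, sampling_rate: int = 16_000,
                chunk_s: float = 30.0, overlap_s: float = 5.0) -> list:
    chunk = int(chunk_s * sampling_rate)
    stride = chunk - int(overlap_s * sampling_rate)
    # Each window is transcribed independently; the overlap lets you
    # stitch the transcripts back together at the boundaries.
    return [waveform[i:i + chunk] for i in range(0, len(waveform), stride)]
```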
Will it work if I run this code? Also, should I make any changes to the gpt-fast library? https://gist.github.com/huseinzol05/9aff34ec1427ee8c92240cb4f3cc0c88
Yeah it should work, I use it in my prod, but don't forget to warm up the static cache multiple times first.
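A sketch of that warmup under stated assumptions: stock `generate` on a tiny checkpoint stands in here for the gist's compiled loop, and three iterations is an illustrative count, not a rule from this PR.

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to("cuda")

# 30 s of silence as a dummy input, just to drive the warmup passes.
dummy = processor(torch.zeros(480_000).numpy(), sampling_rate=16_000,
                  return_tensors="pt").input_features.to("cuda")

# With torch.compile + a static cache, the first calls pay for graph
# capture and compilation; run a few before measuring anything.
for _ in range(3):
    _ = model.generate(dummy, max_new_tokens=32)
```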
I ran the notebook file. It gives this error:

```
File /usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py:484, in WhisperStaticCache.__init__(self, config, dtype, device, existing_cache, batch_size)
    482 torch._dynamo.mark_static_address(e_key_cache)
    483 torch._dynamo.mark_static_address(e_value_cache)
--> 484 e_key_cache[:, :, :, :] = existing_cache[k][2].clone()
    485 e_value_cache[:, :, :, :] = existing_cache[k][3].clone()
    486 self.key_cache.append(new_layer_key_cache)

IndexError: tuple index out of range
```
I just reran and there's no issue, super weird; which line gives you the error?
```diff
@@ -448,3 +448,109 @@ def reset(self):
         # In-place ops prevent breaking the static address
         self.key_cache[layer_idx].zero_()
         self.value_cache[layer_idx].zero_()
+
+class WhisperStaticCache(Cache):
```
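The hunk above only shows the first of the ~109 added lines. As a rough, illustrative reconstruction of the idea (class name, shapes, and config attributes are assumptions inferred from the traceback and `WhisperConfig`, not copied from the diff):

```python
import torch
from transformers import WhisperConfig
from transformers.cache_utils import Cache

class WhisperStaticCacheSketch(Cache):
    """Illustrative only. Pre-allocates per-layer decoder self-attn k/v
    plus fixed-length encoder cross-attn k/v, and marks every buffer as
    a static address so torch.compile can capture CUDA graphs over them."""

    def __init__(self, config: WhisperConfig, batch_size: int,
                 dtype=torch.float32, device="cuda"):
        self.key_cache, self.value_cache = [], []
        head_dim = config.d_model // config.decoder_attention_heads
        self_shape = (batch_size, config.decoder_attention_heads,
                      config.max_target_positions, head_dim)
        cross_shape = (batch_size, config.decoder_attention_heads,
                       config.max_source_positions, head_dim)
        for _ in range(config.decoder_layers):
            for shape in (self_shape, cross_shape):
                k = torch.zeros(shape, dtype=dtype, device=device)
                v = torch.zeros(shape, dtype=dtype, device=device)
                torch._dynamo.mark_static_address(k)
                torch._dynamo.mark_static_address(v)
                self.key_cache.append(k)
                self.value_cache.append(v)
```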
Thanks for this great first start @huseinzol05! With @gante, we were discussing how the design of the static k/v cache should look for encoder-decoder models, and we distilled the design options down to two possibilities:

1. Hold a tuple of `StaticCache` caches, e.g. as proposed here.
2. Add a new Cache class specific to encoder-decoder models, e.g. one with the attributes:
   - `key_cache` (same as decoder-only self-attn)
   - `value_cache` (same as decoder-only self-attn)
   - `cross_key_cache` (new for enc-dec cross-attn)
   - `cross_value_cache` (new for enc-dec cross-attn)

Option 1 doesn't require any new Cache classes, so it should be easier to maintain! Thus, we were thinking this would be the best design option for Whisper (and other encoder-decoder models in the library, such as BART). Would be curious to hear your opinions here, having had a go at option 2.
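For concreteness, a compressed sketch of option 1, assuming a loaded `model` and the `StaticCache` constructor of this era of the library; whether `StaticCache` accepts a Whisper config unmodified, and the pairing of the two caches into a plain tuple, are both assumptions rather than settled API:

```python
import torch
from transformers import StaticCache, WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Option 1: no new class. past_key_values becomes a pair of existing
# StaticCache objects: one sized for decoder self-attention (filled one
# token at a time), one for cross-attention (fixed encoder length).
self_attn_cache = StaticCache(config=model.config, max_batch_size=1,
                              max_cache_len=model.config.max_target_positions,
                              device="cpu", dtype=torch.float32)
cross_attn_cache = StaticCache(config=model.config, max_batch_size=1,
                               max_cache_len=model.config.max_source_positions,
                               device="cpu", dtype=torch.float32)
past_key_values = (self_attn_cache, cross_attn_cache)
```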
@huseinzol05 this is great work!
I'm heavily biased towards option 1, especially now that we are seeing more cache types. For instance, we could easily plug in the quantized cache as the decoder cache with 0 code overhead, if we design Whisper to support a tuple of `Cache` objects through `past_key_values` 🤗
I'm good with anything.
Static Cache for Whisper

This enables the use of `torch.compile` for faster Whisper generation; example: https://gist.github.com/huseinzol05/9aff34ec1427ee8c92240cb4f3cc0c88. The compiled static cache achieves 186.26 it/s, while the non-compiled baseline gets 150.20 it/s.

Still a work in progress:

1. Requires the `2.4.0.dev20240508+cu121` torch nightly, not yet released as stable, for `reduce-overhead` torch.compile of a custom function.
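A minimal sketch of the compile call that the nightly requirement refers to, in the gpt-fast style of the linked gist; `decode_one_token` and its signature are illustrative, not the gist's exact code:

```python
import torch

# Illustrative single-step decode in the gpt-fast style: one token in,
# one token out, with the static cache updated in place.
def decode_one_token(model, decoder_input_ids, encoder_outputs, past_key_values):
    out = model(decoder_input_ids=decoder_input_ids,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                use_cache=True)
    return out.logits[:, -1].argmax(dim=-1, keepdim=True)

# Compiling a custom function like this with CUDA-graph capture
# (mode="reduce-overhead") is what needs the torch nightly noted above.
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
```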