Auto compile when static cache #34247
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
On principle this LGTM, however in practice I am not seeing any speedup (rather degraded performance) until the number of new tokens is quite high, on the order of ~2500-3000 (quick test with Llama 3.1 8B).
Not sure if it comes only from compilation time and warmup, or from some graph breaks somewhere.
Did you compare performance a bit?
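For reference, a rough way to run such a comparison (a sketch only; the checkpoint, prompt, and token counts below are assumptions, not numbers from this thread):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B"  # assumed checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)
inputs = tok("The quick brown fox", return_tensors="pt").to(model.device)

# Time dynamic vs. static cache across increasing generation lengths; with the static
# cache the first call at each length includes compilation/warmup overhead.
for cache in (None, "static"):
    for n in (256, 1024, 3000):
        start = time.perf_counter()
        model.generate(
            **inputs, max_new_tokens=n, min_new_tokens=n, cache_implementation=cache
        )
        print(cache or "dynamic", n, f"{time.perf_counter() - start:.1f}s")
```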
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
@@ -3222,6 +3223,16 @@ def _sample(
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

def model_forward(model, *args, **kwargs):
    return model.forward(*args, **kwargs)
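For context, this free function is what ends up being compiled so the same graph can be reused across decoding steps. A sketch of the idea (not the exact PR code; the compile options are assumptions):

```python
import torch

def model_forward(model, *args, **kwargs):
    return model.forward(*args, **kwargs)

# Sketch: compile the wrapper once, then call it for every decoding step after the
# first, so the static-cache graph is captured once and reused instead of re-traced.
compiled_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
# outputs = compiled_forward(model, **model_inputs, return_dict=True)
```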
@ArthurZucker This PR breaks some tests on PEFT :(
I checked why that is, and found that it has nothing to do with the compilation. The sole reason is that on the 2nd iteration, we use this function, which effectively calls:
self.forward(**model_inputs, return_dict=True)
whereas on the first iteration (and before the PR), we would call:
self(**model_inputs, return_dict=True)
Is there any specific reason why forward is used? Using __call__ looks correct to me.
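A minimal sketch of why this matters, using a plain nn.Module rather than PEFT code: hooks are dispatched by __call__, so invoking .forward() directly silently skips them, which would explain how wrappers that rely on hooks (as some PEFT features do) end up bypassed.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
model.register_forward_pre_hook(lambda module, args: print("hook fired"))

x = torch.randn(1, 4)
model(x)          # goes through __call__, so the pre-forward hook fires
model.forward(x)  # bypasses __call__, so the hook is silently skipped
```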
Should be solved with #34923 🤗
Nice, thanks for the quick reply.
* generate with compile
* nits
* simple
* generate with compile
* nits
* simple
* safe
* style
* Update src/transformers/generation/utils.py
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
* remove TOKENIZER forked warning

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
What does this PR do?
Adds automatic compile for the static cache. This can be tested with:
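A minimal sketch of such a test, assuming a Llama 3.1 8B checkpoint (the prompt and token count are illustrative, not taken from the PR):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint; any decoder-only model should behave the same way.
model_id = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

# With cache_implementation="static", generate() now compiles the decoding forward
# automatically: the first call pays the compilation cost, later calls reuse the graph.
for i in range(3):
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, cache_implementation="static")
    print(f"run {i}: {time.perf_counter() - start:.1f}s")
```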
which gives 15 seconds with the dynamic cache, 30 seconds for the first generate call with the static cache (compilation and warmup), then 4 seconds for each subsequent call.