
Auto compile when static cache #34247

Merged: ArthurZucker merged 12 commits into main from generate-compile on Nov 22, 2024

Conversation

@ArthurZucker (Collaborator) commented on Oct 18, 2024

What does this PR do?

Adds automatic compilation of the forward pass when the static cache is used.
This can be tested with:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

device = "cuda"
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

sequence = "Hey what's the plan"

inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0

# Baseline: default (dynamic) cache
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500)
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}', out)

# First static-cache run: pays the compilation/warmup cost
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}', out)

# Second static-cache run: reuses the compiled forward
t0 = time.time()
out = model.generate(inputs, do_sample=False, max_new_tokens=500, cache_implementation="static")
out = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
dt = time.time() - t0
print(f'dt: {dt}', out)

This gives ~15 s with the dynamic cache, ~30 s for the first static-cache generate (compilation and warmup), then ~4 s for the next one.
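For intuition, here is a rough sketch of what the automatic compilation amounts to; this is illustrative only (the helper name and compile flags are assumptions), not the actual implementation in generate():

import torch

def get_decoding_forward(model, using_static_cache: bool):
    # Hypothetical helper: only compile the decoding forward when the cache
    # has static shapes, so torch.compile sees stable tensor shapes and the
    # compiled graph can be reused across decoding steps and generate() calls.
    if not using_static_cache:
        return model.forward
    if getattr(model, "_compiled_forward", None) is None:
        model._compiled_forward = torch.compile(
            model.forward, mode="reduce-overhead", fullgraph=True
        )
    return model._compiled_forward

The first static-cache call pays the compilation cost (the ~30 s above); later calls reuse the cached graph, which is where the ~4 s runs come from.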

@ArthurZucker requested a review from gante on October 18, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker marked this pull request as ready for review on November 21, 2024
@Cyrilvallez (Member) left a comment:

In principle this LGTM; in practice, however, I am not seeing any speedup (rather degraded performance) until the number of new tokens is quite high, on the order of ~2500-3000 (quick test with Llama 3.1 8B).
Not sure whether this comes only from compilation time and warmup, or from some graph breaks somewhere.
Did you compare performance a bit?
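As an aside, one quick way to check whether the regression comes from graph breaks rather than just compilation time and warmup is PyTorch's dynamo tooling. A sketch, assuming PyTorch 2.1+ and reusing the model and inputs from the snippet above:

# Option 1: log graph breaks for a whole run (script name is a placeholder)
#   TORCH_LOGS="graph_breaks" python your_benchmark.py
# Option 2: inspect a single forward call programmatically
import torch

explanation = torch._dynamo.explain(model.forward)(inputs)
print(explanation.graph_count, explanation.graph_break_count)
print(explanation.break_reasons)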

src/transformers/generation/utils.py (review comment outdated, resolved)
ArthurZucker and others added 2 commits November 22, 2024 13:07
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
@ArthurZucker changed the title from "Draft compile decoding" to "Auto compile when static cache" on Nov 22, 2024
@ArthurZucker merged commit 597efd2 into main on Nov 22, 2024
21 of 25 checks passed
@ArthurZucker deleted the generate-compile branch on November 22, 2024
@@ -3222,6 +3223,16 @@ def _sample(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

def model_forward(model, *args, **kwargs):
return model.forward(*args, **kwargs)

@ArthurZucker This PR breaks some tests on PEFT :(

I checked exactly why, and found that it has nothing to do with the compilation. The sole reason is that on the 2nd iteration we use this function, which effectively calls:

self.forward(**model_inputs, return_dict=True)

whereas on the first iteration (and before the PR), we would call:

self(**model_inputs, return_dict=True)

Is there any specific reason why forward is used? Using __call__ looks correct to me.
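For context, a minimal standalone illustration (a toy module, not transformers code) of why this matters: forward hooks and other machinery dispatched through __call__, which libraries such as PEFT and accelerate depend on, only run when the module is invoked via __call__, so calling forward directly bypasses them:

import torch
from torch import nn

class Toy(nn.Module):
    def forward(self, x):
        return x + 1

m = Toy()
# A forward hook that rescales the output, standing in for adapter/bookkeeping logic.
m.register_forward_hook(lambda module, args, output: output * 10)

x = torch.tensor([1.0])
print(m(x))          # tensor([20.]) -> hook ran via __call__
print(m.forward(x))  # tensor([2.])  -> hook bypassed

This is why switching from self(...) to self.forward(...) can break downstream libraries even though the compilation itself is not involved.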


Should be solved with #34923 🤗


Nice, thanks for the quick reply.

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024:

* generate with compile
* nits
* simple
* generate with compile
* nits
* simple
* safe
* style
* Update src/transformers/generation/utils.py

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* remove TOKENIZER forked warning

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request on Dec 6, 2024