Speculative sampling does not maintain probability distribution of main model #32867

Closed

dmelcer9 opened this issue Aug 17, 2024 · 4 comments · Fixed by #33534

@dmelcer9

System Info

  • transformers version: 4.44.0
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.10.13
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

In the speculative sampling procedure:

probability_ratio = p_i / q_i

The probability ratio is computed against the output probability q_i of the assistant model.

However, the speculative model is always used greedily:

self.generation_config.do_sample = False

This is equivalent to setting the temperature to zero, so the output probability of the assistant model should always be 1 (for the selected token).

As a more concrete example, if the assistant model outputs [0.51, 0.49], then as long as the main model outputs [x >= 0.51, y <= 0.49], the procedure will always select the first token.
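For illustration, here is a minimal standalone sketch of that acceptance step with the made-up two-token distributions above (not the library's actual code): because the draft token is picked greedily while the ratio still uses the draft's soft probability, acceptance is guaranteed whenever p_i >= q_i.

import torch

q = torch.tensor([0.51, 0.49])  # assistant (draft) distribution
p = torch.tensor([0.60, 0.40])  # main model distribution

draft_token = torch.argmax(q)              # greedy draft: always token 0
ratio = p[draft_token] / q[draft_token]    # 0.60 / 0.51 > 1
accept_prob = torch.clamp(ratio, max=1.0)  # capped at 1 -> token 0 is always accepted
print(accept_prob)                         # tensor(1.)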

This is evident when a model is used as its own assistant, at least for the first 5 tokens proposed by the speculative model (there is still some randomness from the extra token generated by the main model but not by the assistant).

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "openai-community/gpt2-medium"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("public int", return_tensors="pt")

# Greedy
# Always outputs `public int get_current_time()` (and then some)
print(tokenizer.decode(model.generate(**inputs, do_sample=False, max_new_tokens=25)[0]))

# Sampling
# Gives different method names each time
print(tokenizer.decode(model.generate(**inputs, do_sample=True, max_new_tokens=25)[0]))

# Should theoretically be sampling but is not
# Always outputs `public int get_current_time()`
print(tokenizer.decode(model.generate(**inputs, assistant_model=model, do_sample=True, max_new_tokens=25)[0]))

Expected behavior

Assisted decoding with do_sample=True should sample from the main model's output distribution, just like non-assisted sampling.

dmelcer9 added the bug label Aug 17, 2024

llllvvuu commented Aug 18, 2024

It looks like that change (#30778) was made to save time, since if the token were sampled with the correct probability, it would also have to be re-sampled some of the time. I think you are right; that would have to be reverted to restore correctness. Unless "assisted decoding" is meant to behave differently from speculative sampling (#26565 (comment)) and not have the correctness property (Appendix A.1). @gante

Also, I suspect (but haven't done the math) that correctness would technically only hold if the probabilities here were adjusted by temperature/min_p/etc. (they seem to be plain temperature-1 probabilities right now):

q = candidate_logits.softmax(dim=-1)
q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
p = new_logits.softmax(dim=-1)
p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
probability_ratio = p_i / q_i
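For reference, a minimal sketch (plain PyTorch, with dummy shapes and a hypothetical temperature value; min_p and other warpers omitted) of what temperature-adjusted probabilities would look like, so that p and q match the distributions that are actually sampled from:

import torch

temperature = 0.7                            # hypothetical value, would come from the generation config
candidate_logits = torch.randn(1, 5, 50257)  # dummy (batch, candidate_length, vocab) tensors
new_logits = torch.randn(1, 5, 50257)

# Apply the temperature before the softmax, so q and p reflect the
# distributions that sampling would actually draw from.
q = (candidate_logits / temperature).softmax(dim=-1)
p = (new_logits / temperature).softmax(dim=-1)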


gante commented Sep 17, 2024

Hi @dmelcer9 @llllvvuu 👋

I agree that forcing greedy decoding on the assistant has to be reverted: empirically, we can check that sampling with an assistant model has much smaller entropy than sampling without one (and they should be the same). It was a rushed decision on my end; I will open a PR to revert it.


However, I disagree with some of your statements on why it must be done 🤗 You wrote:

This [forcing greedy decoding] is equivalent to setting the temperature to zero, so the output probability of the assistant model should always be 1 (for the selected token).

We have to distinguish two phases of text generation: producing the distribution for the next token, and selecting the next token given that distribution. Under both sampling and greedy decoding, the distribution for the next token is the same. Greedy decoding does not set the temperature to 0, it simply takes the argmax of the distribution instead of sampling. The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding. However, speculative decoding assumes we are sampling from the next-token distribution in the assistant model, which is simply not happening at the moment and results in quasi-deterministic outputs from speculative decoding.
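To make that distinction concrete, here is a minimal sketch with toy logits: both decoding modes start from the same softmax distribution and differ only in how the token is selected.

import torch

logits = torch.tensor([2.0, 1.0, 0.5])       # toy next-token logits
probs = logits.softmax(dim=-1)               # same distribution in both cases

greedy_token = torch.argmax(probs)           # greedy decoding: take the argmax
sampled_token = torch.multinomial(probs, 1)  # sampling: draw from the distribution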

@dmelcer9
Author

@gante Thanks for opening the PR. I'm not quite sure what you mean in the second part:

Greedy decoding does not set the temperature to 0, it simply takes the argmax of the distribution instead of sampling.

I was definitely a bit mathematically imprecise earlier: while greedy decoding of course doesn't involve dividing the logits by 0 before the softmax, the output of the softmax function in the limit as $t \rightarrow 0$ is a one-hot distribution, i.e. the deterministic result of greedy decoding.
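(A tiny standalone sketch of that limit, with made-up logits:)

import torch

logits = torch.tensor([2.0, 1.0, 0.5])
for t in (1.0, 0.1, 0.01):
    print(t, (logits / t).softmax(dim=-1))
# As t -> 0, the softmax output approaches the one-hot vector [1, 0, 0],
# i.e. the token that greedy decoding would pick deterministically.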

What I meant was that, if greedy decoding is used in the assistant model, q "should be" a one-hot vector (so q_i = 1 for the selected token), even though the following line uses the sampling probability as q_i:

probability_ratio = p_i / q_i

The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding.

Not always: in a two-token vocabulary, if the speculative model outputs [0.8, 0.2] and the main model outputs [0.6, 0.4], the output will still be probabilistic but with the wrong distribution. Token 0 will always be proposed by the draft model, and the probability ratio will be calculated as 0.6 / 0.8 = 0.75. In the remaining 25% of cases, the residuals will be calculated as norm([clamp(0.6 - 0.8), clamp(0.4 - 0.2)]) = [0, 1], so token 1 will be sampled. The overall distribution becomes [0.75, 0.25], rather than the ideal [0.6, 0.4] or a deterministic output.
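A quick Monte Carlo check of this two-token case (a standalone sketch, not the library code) reproduces the skewed [0.75, 0.25] distribution:

import torch

q = torch.tensor([0.8, 0.2])  # draft model distribution; token 0 is always picked greedily
p = torch.tensor([0.6, 0.4])  # main model distribution

accept_prob = min(1.0, (p[0] / q[0]).item())  # 0.6 / 0.8 = 0.75
residual = torch.clamp(p - q, min=0.0)
residual = residual / residual.sum()          # [0, 1]

counts = torch.zeros(2)
for _ in range(100_000):
    if torch.rand(1).item() < accept_prob:
        token = 0                                      # greedy draft token accepted
    else:
        token = torch.multinomial(residual, 1).item()  # resample from the residuals
    counts[token] += 1

print(counts / counts.sum())  # ~[0.75, 0.25] instead of the ideal [0.6, 0.4]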

Also note that if q = [1, 0] (the one-hot distribution), we end up with the correct output distribution (though keeping q = [0.8, 0.2] is fine if the speculative model is changed to use sampling instead of greedy decoding).


gante commented Sep 17, 2024

@dmelcer9 agreed with what you wrote 🤗

The probability properties of speculative decoding at a token level do hold even with the probability distributions for the next tokens with greedy decoding.

I was imprecise here indeed. It is missing ", if we were sampling from that distribution to get the next token from the assistant" :)

ArthurZucker pushed a commit that referenced this issue Nov 22, 2024
… like #32867) (#34553)

* Update test_utils.py

* formatting

* Update test_utils.py

* formatting

* formatting

* Update test_utils.py

* formatting

* Update test_utils.py

* formatting

* format

* comments at standard positions