Speculative decoding: Test the _target distribution (to prevent issues like #32867) #34553

Merged

keyboardAnt
Contributor

@keyboardAnt keyboardAnt commented Nov 1, 2024

What does this PR do?

This PR adds a test for speculative decoding that checks the target distribution is preserved, guarding against regressions like #32867. The new test (test_speculative_sampling__target_distribution) verifies that tokens are sampled with the likelihoods defined by the logits, so that speculative decoding produces the same distribution as sampling from the target model directly. It is also a first step toward supporting more advanced speculative decoding algorithms, such as token-tree-based rejection sampling.

Motivation and Context

Speculative decoding has previously had bugs where the target distribution was not preserved (e.g., issues #32867 and #33534). This PR adds a test that guards against such regressions by verifying that:

  • The most likely tokens are chosen more frequently than less probable ones.
  • Tokens are selected in alignment with the predefined candidate and new logits.

The test makes speculative sampling more reliable by enforcing distributional accuracy, and it prepares the ground for more advanced speculative decoding techniques, such as token-tree-based sampling.
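To illustrate what "preserving the target distribution" means, here is a minimal, standard-library sketch of single-token speculative sampling followed by an empirical check. It is not the test added in this PR and does not use the transformers API; the function names, the four-token vocabulary, and the probabilities are all made up for illustration.

```python
import random
from collections import Counter


def speculative_sample(p, q, rng):
    """Return one token index distributed according to p.

    p: target-model probabilities, q: draft-model probabilities (same vocab).
    """
    vocab = range(len(p))
    # 1. The draft model proposes a token.
    proposed = rng.choices(vocab, weights=q)[0]
    # 2. Accept the proposal with probability min(1, p / q).
    if rng.random() < min(1.0, p[proposed] / q[proposed]):
        return proposed
    # 3. On rejection, resample from the residual max(0, p - q).
    #    random.choices normalizes the weights itself.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]


def empirical_distribution(p, q, n_samples, seed=0):
    """Sample n_samples tokens and return the observed frequency of each index."""
    rng = random.Random(seed)
    counts = Counter(speculative_sample(p, q, rng) for _ in range(n_samples))
    return [counts[i] / n_samples for i in range(len(p))]


# Hypothetical distributions: token 0 is impossible under the target model,
# even though the draft model sometimes proposes it.
target_probs = [0.0, 0.5, 0.3, 0.2]
draft_probs = [0.1, 0.2, 0.6, 0.1]

freqs = empirical_distribution(target_probs, draft_probs, n_samples=20_000)
for token, (freq, prob) in enumerate(zip(freqs, target_probs)):
    print(f"token {token}: observed {freq:.3f}, target {prob:.3f}")
```

With enough samples the observed frequencies should approach the target probabilities, and a token with zero target probability should never be emitted, whatever the draft model proposes. A bug in the acceptance or resampling step would show up as a systematic gap between the two columns.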

This PR is an initial step toward advancements in Universal Assisted Generation. In collaboration with @orenpereg, @danielkorat, @mosheber, @jmamou, and @MosheWasserb, we're preparing a new speculative decoding function, and this test will verify that it is lossless, i.e., that it preserves the target distribution.

Dependencies

No additional dependencies are required.

Linked Issues

#32867, #33534

Before Submitting Checklist

  • I have read the contributor guidelines.
  • Documentation updates are not needed as this is a test enhancement.
  • New test coverage has been added to verify the speculative sampling behavior.

Who can review?

@gante

@keyboardAnt keyboardAnt force-pushed the test-speculative-sampling-distribution branch from 1be059c to 5522333 Compare November 1, 2024 00:14
Member

@gante gante left a comment

Thank you for working on improving our tests 💛

A question: is this test somewhat fast to run (<5s)? If yes, amazing! If no, let's either a) reduce the number in range or b) tag the test as @slow [note: tests with @slow are usually run daily, so bad commits may squeeze in]

Comment on lines 2474 to 2486
[
-inf,
2.0,
-inf,
1.0,
-inf,
-inf,
-inf,
-0.01,
2.0,
-inf,
], # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value
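The ordering stated in that inline comment can be checked by applying a softmax to the ten logits. The sketch below uses only the standard library rather than torch, and the helper and variable names are illustrative, not taken from the test file.

```python
import math

NEG_INF = float("-inf")

# The ten logits from the snippet above, written on one line.
logits = [NEG_INF, 2.0, NEG_INF, 1.0, NEG_INF, NEG_INF, NEG_INF, -0.01, 2.0, NEG_INF]


def softmax(values):
    """Numerically stable softmax; -inf entries map to probability 0."""
    largest = max(values)
    exps = [math.exp(v - largest) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)

# Token indices sorted from most to least likely (ties keep index order).
ranking = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

for index in ranking[:4]:
    print(f"token {index}: p = {probs[index]:.3f}")
```

Tokens 1 and 8 each get roughly 0.400, token 3 roughly 0.147, and token 7 roughly 0.054. Every other index has probability exactly 0, which is why the comment says no other value should ever be sampled.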
Member

@gante gante Nov 4, 2024

Nit: let's make it in one line, so we can quickly compare indexes with other tensors.

(you'll have to remove the comma after the last -inf, otherwise the make fixup command will make it revert back to this format)

Contributor Author

@keyboardAnt keyboardAnt Nov 5, 2024

I changed the formatting as requested, but ruff's formatting check then failed the CI. (make fixup still reformats it into a column, even after removing the last comma you mentioned)

Collaborator

you can use # fmt: off and # fmt: on

Contributor Author

Thanks @ArthurZucker. I changed all these inline comments to block comments, and it solved the issue while keeping the ruff checks on. 👍
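For reference, the two options discussed in this thread could look roughly like the sketch below. It shows the comment moved to its own line above the list (the block-comment approach that was ultimately used) together with the `# fmt: off` / `# fmt: on` pragmas suggested in review. The variable name `new_logits` and the exact layout are assumptions, not the merged code. With the comment on its own line and no trailing comma, the pragmas are probably redundant, which matches the report above that the ruff checks stayed on.

```python
inf = float("inf")

# fmt: off
# most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value
new_logits = [-inf, 2.0, -inf, 1.0, -inf, -inf, -inf, -0.01, 2.0, -inf]
# fmt: on

# With the list on one line, positions are easy to compare against other tensors.
finite_positions = [i for i, value in enumerate(new_logits) if value != -inf]
print(finite_positions)
```

A long trailing inline comment can push the line past the configured length limit, and a trailing comma after the last element asks the formatter to keep one element per line. Either of these may explain why the list kept being expanded into a column.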

@gante gante requested a review from ArthurZucker November 4, 2024 11:11
@keyboardAnt
Contributor Author

> Thank you for working on improving our tests 💛
>
> A question: is this test somewhat fast to run (<5s)? If yes, amazing! If no, let's either a) reduce the number in range or b) tag the test as @slow [note: tests with @slow are usually run daily, so bad commits may squeeze in]

The test itself takes 1.89 s, and a full pytest invocation (pytest tests/generation/test_utils.py::UtilsFunctionsTest::test_speculative_sampling__target_distribution) takes no more than 2.57 s. I only measured this locally on my laptop, but I believe it's safe to keep it with the rest of the <5s tests.

@keyboardAnt keyboardAnt force-pushed the test-speculative-sampling-distribution branch from 3a05527 to dc3be00 Compare November 5, 2024 02:19
Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks

@keyboardAnt
Contributor Author

All checks have successfully passed (screenshot below). Are there any additional workflows to run before merging?

[Screenshot: all checks passed]

@keyboardAnt
Contributor Author

@gante @ArthurZucker
Would appreciate your help with finalizing the review

@ArthurZucker
Collaborator

Yep sorry for the delay, merging!

@ArthurZucker ArthurZucker merged commit 42b36d7 into huggingface:main Nov 22, 2024
22 checks passed
@keyboardAnt keyboardAnt deleted the test-speculative-sampling-distribution branch November 22, 2024 15:36
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
… like huggingface#32867) (huggingface#34553)

* Update test_utils.py

* formatting

* Update test_utils.py

* formatting

* formatting

* Update test_utils.py

* formatting

* Update test_utils.py

* formatting

* format

* comments at standard positions
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
… like huggingface#32867) (huggingface#34553)

* Update test_utils.py

* formatting

* Update test_utils.py

* formatting

* formatting

* Update test_utils.py

* formatting

* Update test_utils.py

* formatting

* format

* comments at standard positions