
🔴 🚨 Resizing tokens embeddings: initialize from old embeddings' normal distribution. #33325

Merged: 32 commits into huggingface:main on Oct 4, 2024

Conversation

@abuelnasr0 (Contributor) commented Sep 5, 2024

What does this PR do?

This PR initializes new token embeddings from a normal distribution with the old embeddings' mean and covariance, as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. Thanks to this technique, you can now add new tokens to your model without degrading its generation quality.
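For context, here is a minimal sketch of the technique from the article (illustrative code with made-up names, not the exact implementation in this PR):

```python
import torch

def sample_added_embeddings(old_embeddings: torch.Tensor, num_added_tokens: int) -> torch.Tensor:
    """Sample new embedding rows from N(mean, 1e-5 * covariance) of the old embeddings."""
    old = old_embeddings.float()
    mean = old.mean(dim=0)
    centered = old - mean
    covariance = centered.T @ centered / old.shape[0]
    # The 1e-5 factor keeps the sampled rows close to the mean; the covariance
    # must be positive definite for MultivariateNormal to accept it.
    dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=1e-5 * covariance)
    return dist.sample((num_added_tokens,))
```

The merged code follows the same idea, with extra handling for DeepSpeed ZeRO-3 and the lm_head that is discussed later in this thread.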

Fixes #32948

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

cc: @ArthurZucker @LysandreJik

@abuelnasr0 (Contributor Author)

This gist shows GPT2 generation before and after adding new tokens: https://colab.research.google.com/gist/abuelnasr0/7cc8decff72ccdbcf5073ad8259ee360/embedding_resize.ipynb
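For readers who don't want to open the notebook, the experiment is roughly of this shape (the added tokens and prompt are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Add new tokens and resize the embedding matrix accordingly.
tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])
model.resize_token_embeddings(len(tokenizer))

# With the old random init, generation quality tends to degrade after resizing;
# with the mean/covariance init from this PR it stays close to the original model.
inputs = tokenizer("Hello, my name is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```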

@LysandreJik (Member)

Nice! Thanks @abuelnasr0. cc @Rocketknight1 in case you have some bandwidth to do a first review?

@Rocketknight1 (Member)

Yes, I'm happy to take it, it seems like a great PR!

@Rocketknight1 (Member) left a comment

My overall impression is that this is a great addition that we should definitely merge, because the existing behaviour is highly undesirable, as mentioned in the article. My issues are basically nits:

  • We have a mild preference for avoiding "math variable" naming, though I don't want to bloat the code either. Maybe replace mu and cov with mean_embedding and covariance or covariance_matrix?
  • It'd be great if we could add a small test for this. There is a test_resize_tokens_embedding in tests/test_modeling_common.py, and the test could either be appended to that, or as a new test just below it. The test could check that new tokens are relatively close to the mean of the old tokens, though please set the tolerances quite loose - it'll be a real pain if the test becomes flaky because an outlier value is sampled!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@abuelnasr0 (Contributor Author) commented Sep 5, 2024

Thanks @LysandreJik
And thanks @Rocketknight1 for the review and the comments. I have added a test case that checks whether the new embeddings' mean is close to the original embeddings' mean; I think that will reduce the effect of any outlier values.
Feel free to let me know if there is anything else I should add.
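Roughly, the added check looks like this sketch (names and tolerances here are illustrative, not the exact test that was committed):

```python
import torch

def check_mean_resizing(model, added_tokens: int = 10):
    old_embeddings = model.get_input_embeddings().weight.detach().clone()
    model.resize_token_embeddings(old_embeddings.shape[0] + added_tokens)
    new_rows = model.get_input_embeddings().weight[-added_tokens:].detach()

    # Compare means with a loose tolerance so a sampled outlier can't make the test flaky.
    torch.testing.assert_close(
        new_rows.mean(dim=0), old_embeddings.mean(dim=0), atol=1e-1, rtol=1e-1
    )
```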

@Rocketknight1 (Member) left a comment

LGTM! One final extra-minor nit, but I'm happy with it now. cc @LysandreJik for final review

Review thread on tests/test_modeling_common.py (resolved)
@ArthurZucker (Collaborator) left a comment

Thanks! I think we need a test for test_resize_tokens_embeddings with DeepSpeed and multi-GPU, as we had a lot of breaking changes around it in the past year! #26102 #32192 #21065, etc.

@@ -2162,8 +2162,24 @@ def _get_resized_embeddings(
             dtype=old_embeddings.weight.dtype,
         )

-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
+        # initialize new embeddings (in particular added tokens) if `new_num_tokens` is larger
@ArthurZucker (Collaborator)

One thing we need to make sure of is that this does not have issues with DeepSpeed and multi-node setups! For example, computing the mean would require an all-gather, I think. My intuition is that this computation has to be done on GPU 0.
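For reference, under ZeRO-3 the embedding weight is partitioned across ranks, so the statistics have to be computed inside a gather context. A rough sketch (not the PR's exact code; old_embeddings stands for the nn.Embedding being resized):

```python
import deepspeed

# Gather the partitioned weight on every rank (read-only, modifier_rank=None),
# so the mean/covariance are computed over the full matrix and agree across ranks.
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
    old_weight = old_embeddings.weight.float()
    mean_embedding = old_weight.mean(dim=0)
    centered = old_weight - mean_embedding
    covariance = centered.T @ centered / old_weight.shape[0]
```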

@abuelnasr0 (Contributor Author)

That is absolutely right, thanks for mentioning that. I have added support for Deepspeed and tested it manually. I will add test cases for it tomorrow!

@abuelnasr0 (Contributor Author)

@ArthurZucker I have added tests for DeepSpeed, but they fail because the test env doesn't have deepspeed installed. What could be a solution for this? Should I add the tests to transformers/tests/deepspeed/test_deepspeed.py, for example?

@Rocketknight1 (Member)

@abuelnasr0 when tests depend on a library like deepspeed, tag them with the @require_deepspeed decorator so the test runner knows how to handle them! You can search the codebase for other examples where it's used and just copy the imports/decorator from there.
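The pattern looks roughly like this (the test name and body are placeholders):

```python
from transformers.testing_utils import require_deepspeed, require_torch_multi_gpu

@require_deepspeed
@require_torch_multi_gpu
def test_resize_embeddings_with_deepspeed_multi_gpu(self):
    # Safe to import here: the decorators skip the test when deepspeed
    # or multiple GPUs are not available.
    import deepspeed  # noqa: F401
    ...
```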

@abuelnasr0 (Contributor Author) commented Sep 9, 2024

@Rocketknight1 Thank you for your response. It turns out I was importing DeepSpeed without checking whether it was available; I have fixed that.
@Rocketknight1 @ArthurZucker The code is ready for review now. Could you check it and see if the deepspeed tests are all right?

@ArthurZucker (Collaborator)

Sorry for the delay!

@ArthurZucker (Collaborator)

This is a breaking change so can you add 🔴 🔴 🔴 🔴

@abuelnasr0 (Contributor Author)

@ArthurZucker I have added 🔴 to the PR title. Is that what you meant? Or should I add a logger.warning() to the code describing the new change in how embedding weights are initialized?

@abuelnasr0 (Contributor Author)

> let's make sure we warn users about it!

I didn't see this comment the first time I opened the PR! I will add a warning for users describing the new behaviour.

@ArthurZucker (Collaborator) left a comment

Looks great to me! It's a breaking change, so let's make sure we warn users about it!

@abuelnasr0 abuelnasr0 changed the title Resizing tokens embeddings: initialize from old embeddings' normal distribution. 🔴 🔴 Resizing tokens embeddings: initialize from old embeddings' normal distribution. Sep 13, 2024
@abuelnasr0 abuelnasr0 changed the title 🔴 🔴 Resizing tokens embeddings: initialize from old embeddings' normal distribution. 🔴 🚨 Resizing tokens embeddings: initialize from old embeddings' normal distribution. Sep 13, 2024
@ArthurZucker (Collaborator) left a comment

Okay, looks great! Given how big of a change this is, let's add a flag to _get_resized_embeddings, something like multivariate_resizing. WDYT?

@abuelnasr0 (Contributor Author)

@ArthurZucker Sorry for the delay; last week was very busy for me.
I have added the flag to resize_token_embeddings, _get_resized_embeddings, and _get_resized_lm_head.
_get_resized_lm_head is very important for models with untied weights. This should have been handled from the beginning; I don't know how I missed it.

@abuelnasr0 (Contributor Author) commented Sep 27, 2024

The bias for linear lm_heads is not covered in the article, but I have come to the conclusion that initializing it with zero is the best solution.
Here is the reasoning, following the article:

$p_{\theta'}(w_i \mid w_{1:i-1}) = \frac{\exp(h_{i-1}^{\top} e_{w_i} + b_{w_i})}{Z + \exp(h_{i-1}^{\top} e_{n+1} + b_{n+1})}$
$p_{\theta'}(w_i \mid w_{1:i-1}) = \frac{\exp(h_{i-1}^{\top} e_{w_i} + b_{w_i})}{Z + \exp(h_{i-1}^{\top} e_{n+1}) \exp(b_{n+1})}$

If the new bias equals zero, then $\exp(0) = 1$, and the proof in the article goes through unchanged.

@Rocketknight1 (Member) commented Sep 27, 2024

@abuelnasr0 I haven't dived into the math, but that seems counter-intuitive to me. Most models do not have a bias on the output head, but if they do, I would guess that the bias is usually large and negative for most tokens (I can't test this, though, because I can't find any examples of models with bias on the output head!)

If the average bias is large and negative, then initializing new tokens with the mean embedding but a zero bias will mean that logits for the new token will be very large, right?

UPDATE: I finally found an example of an old model with a bias on the lm_head (Salesforce/ctrl). The mean bias was -0.05, which is a much smaller magnitude than I expected. This probably isn't a big problem - a zero bias initialization is fine!
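The quick check is something like this (it downloads the full CTRL checkpoint, so it is slow):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Salesforce/ctrl")
bias = model.lm_head.bias.detach()
print(bias.mean().item(), bias.std().item())  # mean is about -0.05, as reported above
```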

@abuelnasr0 (Contributor Author) commented Sep 27, 2024

@Rocketknight1 That's right.
I didn't consider that the new value of Z also includes the bias.

After accounting for the new value of Z, the new bias should be initialized from the old bias mean. I have tried to be more careful this time; I hope I am not wrong again 😅

This link contains a proof that should replace the "Averaging bounds the KL-divergence" part of the article: https://imgur.com/a/ZZQ3Mwk

@abuelnasr0 (Contributor Author)

I have initialized the new bias like this:

new_lm_head.bias.data.normal_(mean=bias_mean, std=bias_std * 1e-5)

I multiplied the std by 1e-5 just because the article initializes the embeddings by multiplying the covariance by 1e-5. I don't really know why; maybe to keep the new embeddings a lot closer to the mean.
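In context, the bias handling is roughly the following sketch (variable names are assumed to match _get_resized_lm_head; the indexing of only the added rows follows the later "Fix wrong bias indexing" commit):

```python
# Compute the old bias statistics as plain floats.
bias_mean = old_lm_head.bias.data.float().mean().item()
bias_std = old_lm_head.bias.data.float().std().item()

# Existing bias entries are copied over as-is; only the newly added entries
# are sampled tightly around the old bias mean.
new_lm_head.bias.data[-added_num_tokens:].normal_(mean=bias_mean, std=bias_std * 1e-5)
```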

@ArthurZucker (Collaborator) left a comment

Thanks, this is getting into good shape! Let's help our users a little bit more and make the code more readable!

Review threads on src/transformers/modeling_utils.py (resolved)
"As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html."
)
added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
@ArthurZucker (Collaborator)

All of this can be re-used, no? E.g. as a "self.init_tensor" helper that checks whether DeepSpeed is available, computes the covariance if it is not given, and uses None otherwise.

@abuelnasr0 (Contributor Author) commented Oct 1, 2024

I have introduced three functions:

  • self._init_added_embeddings_weights_with_mean()
  • self._init_added_lm_head_weights_with_mean(), which uses self._init_added_embeddings_weights_with_mean()
  • self._init_added_lm_head_bias_with_mean()

This will improve code usability for our case. What do you think? I am open to any other change.

Also, I think that mean_resizing is more user-friendly and explains the whole point of the new resizing technique.
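With that rename, end-user usage is roughly the following (mean_resizing defaulting to the new behaviour is my reading of the final API, so treat this as a sketch):

```python
# New default: initialize added rows from the old embeddings' normal distribution.
model.resize_token_embeddings(len(tokenizer))

# Opt back into the previous behaviour (plain random init for the added rows).
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
```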

Additional review threads on src/transformers/modeling_utils.py (resolved)
@ArthurZucker (Collaborator) left a comment

Cleaner! Thanks a lot for this update

@ArthurZucker (Collaborator)

The only thing left is to fix the related failing tests (the torch examples, for example) 🤗

@abuelnasr0 (Contributor Author)

@ArthurZucker Thanks for your review.
Regarding tests, all tests have passed after rebasing.

@ArthurZucker (Collaborator)

All good for me! 🤗

@ArthurZucker ArthurZucker merged commit 78ef583 into huggingface:main Oct 4, 2024
24 checks passed
@Rocketknight1 (Member)

Congrats on the PR @abuelnasr0, and thank you!

@abuelnasr0 (Contributor Author)

@Rocketknight1 Thank you for the reviews and the help.

NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
…l distribution. (huggingface#33325)

* intilize new embeddings from normal distrib

* Fix typo in comments

* Fix typo in comments

* Fix style

* Fix variables naming

* Add tests

* Fix style

* code consistency nit

* Add deepspeed support

* Add deepspeed support

* Conver embeddings weights to float32 before computations

* Add deepspeed tests

* Cover when vocab_size is smaller than embedding_size

* Style fix

* Add tests for vocab_size smaller than hiddin_size

* Style fix

* Nits in tests

* Nits in tests

* Check for deepspeed before importing it

* Increase vocab_size for positive definite covariance matrix test

* Add warning

* Add multivariate_resizing flag and implement resizing for lm_heads

* Fix typo

* Fix wrong bias indexing

* Fix bias is zero check

* remove multivariate_resizing flag from tests

* Intialize bias from old bias normal distribution

* Fixup

* Code usability

* Use mean_resizing instead of multivariate_resizing

* Fix up

* Fix comments and docs
github-merge-queue bot pushed a commit to microsoft/DeepSpeed that referenced this pull request Nov 4, 2024
This commit causes breaking changes we need to fix; for now we will pin the version, but we will fix it shortly.

huggingface/transformers#33325
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024

Successfully merging this pull request may close these issues.

resize_token_embeddings in NLLB leading to empty outputs