
Fix test_eager_matches_sdpa_inference for XPU backend #34889

Merged: 3 commits, Dec 2, 2024

Conversation

@dvrogozh (Contributor) commented on Nov 22, 2024

Included fixes:

  • Use torch.nn.attention.sdpa_kernel instead of the deprecated torch.backends.cuda.sdp_kernel
  • Use torch.amp.autocast instead of the deprecated torch.cuda.amp.autocast in nemotron
  • Reuse the CUDA MATH thresholds for XPU (as of PyTorch 2.5 the XPU backend supports only torch.nn.attention.SDPBackend.MATH)

Fixes: #34888
CC: @amyeroberts @ydshieh
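
For reference, a minimal sketch of the two API migrations (illustrative only, not the test code itself; shapes and dtypes are arbitrary):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)

# Old (deprecated): torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
# New: select the allowed backends explicitly via the context manager.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Old (deprecated): torch.cuda.amp.autocast(dtype=torch.bfloat16)
# New: device-agnostic autocast with an explicit device_type.
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    pass  # the model forward would run here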

@dvrogozh (Contributor, PR author) commented:

Looks like Friday evening is not the best time to run CI. I pushed the same code 3 times and saw different errors on each run :). I don't think they are related to this change. Will continue on Monday :).

@dvrogozh (Contributor, PR author) commented on Nov 25, 2024

Clarified the XPU backend behavior around torch.backends.cuda.sdp_kernel. As of PyTorch 2.5 (and today's PyTorch main), the XPU backend supports only torch.nn.attention.SDPBackend.MATH, whose implementation is device agnostic up to the implementation of each individual aten operator. So we can reuse the CUDA (or CPU) MATH thresholds for XPU, in all cases irrespective of the SDP backend being tested. That's the change I made in the latest version. @faaany, @EikanWang: FYI.
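
To make "device agnostic" concrete, here is a small hedged sketch (not part of the PR) that compares the MATH backend's output across devices; it assumes a second device such as "xpu" is available and builds the inputs on CPU so both runs see identical values:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def math_sdpa(device: str) -> torch.Tensor:
    # Build inputs on CPU so every device starts from identical values.
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 4, 8, 32) for _ in range(3))
    with sdpa_kernel(SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(
            q.to(device), k.to(device), v.to(device)
        )
    return out.cpu()

# Expectation per the reasoning above, e.g.:
# torch.testing.assert_close(math_sdpa("cpu"), math_sdpa("xpu"))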

Comment on lines +657 to +660
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]

Collaborator:
nit: although XPU only supports MATH, does that mean the results from XPU will be the same as, say, CUDA's? Otherwise, I don't see the reason to reuse the CUDA thresholds.

@dvrogozh (Contributor, PR author) replied:

In my understanding, at least for recent versions of PyTorch (2.5 and the upcoming 2.6), MATH should give identical results on any hardware, including different GPU devices and CPU, because the algorithm is implemented at the torch level and is device agnostic (well, up to the implementation of each aten operator, which is device specific, but those should still give the same results). The only exception here is MPS, which has a separate branch in the code, though it is also implemented at the torch level. Here are the relevant places in the sources:

@dvrogozh (Contributor, PR author) added:

So, per the above, I think the current reuse of the CUDA thresholds is reasonable... That being said, the question is whether this is sustainable in the longer term, or whether we will soon need to adjust the thresholds for XPU. Well, we might. It's likely that upstream PyTorch XPU will get an implementation of one or both of the fused attention algorithms, and I am not sure those will behave the same as CUDA's. Plus, there is also IPEX for XPU, which might behave differently. And all of that might be version dependent...

Anyhow, whether we reuse the CUDA thresholds or not is a call we need to make here. I would in any case start by copying the CUDA thresholds to an XPU-specific location (see the sketch below). Note that we might need a few iterations to settle everything down.
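
For illustration, a hedged sketch of that starting point (the helper name and the dtype list are assumptions; the (device, enable_kernels, dtype) key layout follows the snippet quoted above, and this is not the actual test code):

import torch

def seed_xpu_tolerances(tols: dict) -> dict:
    # Copy the CUDA MATH tolerances into XPU-specific entries so they can be
    # tuned independently later; enable_kernels=False selects the MATH entry,
    # which is the only path XPU currently runs.
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        if ("cuda", False, dtype) in tols:
            tols[("xpu", False, dtype)] = tols[("cuda", False, dtype)]
    return tols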

Collaborator:

Sounds reasonable to reuse the (MATH) thresholds for now. Thank you for explaining!

@ydshieh (Collaborator) commented on Nov 26, 2024

Thanks @dvrogozh! LGTM.

Have you run (some of) the relevant tests on a GPU (cuda) and XPU (cc @faaany for this part maybe) machine?

@dvrogozh (Contributor, PR author) replied:

Have you run (some of) the relevant tests on a GPU (cuda) and XPU (cc @faaany for this part maybe) machine?

I ran python3 -m pytest --pspec -k test_eager_matches_sdpa_inference tests/models:

  • On Nvidia A10, CUDA: passing
  • On Intel PVC, upstream PyTorch XPU (without IPEX): passing
  • On Intel PVC, with IPEX 2.3.110+xpu (PyTorch 2.3.1): passing

I also think that the Transformers CI executes these CUDA tests and they are passing.

That being said, we synced with @faaany offline and she tried this PR on her side. I believe she saw the tests passing without IPEX, but she mentioned that some tests are failing for her with IPEX. The latter point differs from the test results I have. I suspect that's due to the different IPEX versions we tried: I think @faaany used a later IPEX version (she is offline now, I will check with her later on that).

Note that without this PR the test_eager_matches_sdpa_inference tests fail for both upstream PyTorch XPU and IPEX. The difference between upstream PyTorch XPU and IPEX is which attention algorithms are supported. I know the current status for upstream PyTorch XPU, since I have already discussed it with the relevant folks, but I have not had such a discussion for IPEX yet; that's what I will need to do. I know that IPEX might register additional operations and algorithms on top of upstream PyTorch XPU, but whether it does so for SDP attention, and in which version, I do not know (though I do have some understanding from the results I have seen so far).

@ydshieh (Collaborator) left a comment

Thank you again!

@faaany Do you have any further comments?

@dvrogozh (Contributor, PR author) commented:

For myself: need to follow up on #34941 (SDPA tests for beit), which is being reviewed in parallel.

@dvrogozh (Contributor, PR author) commented on Nov 26, 2024

FAILED tests/models/xlm/test_modeling_xlm.py::XLMModelTest::test_batching_equivalence - AssertionError: tensor(False) is not true : Batched and Single row outputs are not equal in XLMForQuestionAnswering for key=end_top_index. Difference=1.

The failure on CI seems unrelated. I also see some flakiness in the main CI results. I tried to rebase a couple of times, but this did not help. No changes since the last review.

@ydshieh (Collaborator) commented on Nov 26, 2024

Yes, we do have some flaky tests (we are trying to fix them, but at a slow pace).

Commits

Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

Fix test_eager_matches_sdpa_inference for XPU backend

As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
which is implemented on PyTorch level using aten operators and is device
agnostic with respect to implementation of each aten operator. Thus, we can
reuse CUDA (or CPU) MATH weights for XPU.

Fixes: huggingface#34888
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
@ydshieh (Collaborator) commented on Nov 27, 2024

There is no need to have a green CI if you think some failures are irrelevant. Ping me for a double check, then we can wait for a core maintainer's review.

@dvrogozh (Contributor, PR author) replied:

There is no need to have a green CI if you think some failures are irrelevant. Ping me for a double check, then we can wait for a core maintainer's review.

Yeah, I think that's the case here. I rebased today without any changes, and another test failed this time, which I believe is unrelated as well. Sorry, I am just paranoid about a green CI; I consider it my responsibility to achieve that on my PRs.

FAILED tests/trainer/test_trainer.py::TrainerIntegrationWithHubTester::test_push_to_hub_tags - KeyError: 'url'
FAILED tests/trainer/test_trainer.py::TrainerIntegrationWithHubTester::test_push_to_hub_with_saves_each_epoch - AssertionError: no logs of level WARNING or higher triggered on root

@ArthurZucker (Collaborator) left a comment

Feel free to merge if it's alright with you @ydshieh

@@ -607,7 +607,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
- with torch.backends.cuda.sdp_kernel(
+ with sdpa_kernel(
Collaborator:

thanks for updating!

@ydshieh merged commit 3183047 into huggingface:main on Dec 2, 2024
26 checks passed
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dvrogozh mentioned this pull request on Dec 2, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request on Dec 6, 2024
Labels: bug, Tests (Related to tests)
Projects: none yet
Development

Successfully merging this pull request may close these issues.

xpu: test_eager_matches_sdpa_inference tests fail with pytorch XPU backend
7 participants