Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FSDP resume Initialization issue #34032

Merged
merged 6 commits into from
Oct 15, 2024

Conversation

Itssshikhar
Copy link
Contributor

Addresses the issue with Fully Sharded Data Parallel (FSDP) initialization when resuming training from a checkpoint. It implements a solution by adding a dummy forward pass during the initialization process.

Fixes #31892

Added tests in the test_trainer.py file to ensure proper FSDP initialization

@muellerzr @SunMarc I am creating a draft PR, let me know if there anymore changes that I can make

@SunMarc
Copy link
Member

SunMarc commented Oct 9, 2024

Thanks for the PR ! Could you explain a bit more why this PR fixes the issue that you linked ? Thanks

@Itssshikhar
Copy link
Contributor Author

Yeah, sure.

There is a similar issue in Pytorch (pytorch/pytorch#113496) which causes the same error. Reason being, initialization error in the forward pass, which causes FSDP to fail.

The Fix seems fairly simple, as we just have to run forward pass once using dummy values, before initializing FSDP.

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fix makes sense to me, thanks for the explanation. Could you document and add the link to that issue on top of the _init_fsdp func so we can fully trace why this is needed?

Also please do pip install -e .[quality]; make fixup and this will fix the quality tests.

@Itssshikhar Itssshikhar marked this pull request as ready for review October 15, 2024 07:04
@Itssshikhar
Copy link
Contributor Author

@muellerzr @SunMarc
All the tests have passed, but one remains tests_non_models that requires to have CUDA.

It would be great if you guys can see to it once and if there's anything from my end that needs to be done?!

Thanks!

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I left a comment to show what to fix in order to pass the CI !

@@ -4911,3 +4911,33 @@ def test_get_optimizer_group(self):
param = next(model.parameters())
group = trainer.get_optimizer_group(param)
self.assertIn(param, group["params"])


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to pass a @require_cuda decorator for this test !

@SunMarc SunMarc merged commit 4de1bdb into huggingface:main Oct 15, 2024
24 of 25 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@ringohoffman ringohoffman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR broke the test in tests/trainer/test_trainer_fsdp.py, which actually tests initializing a trainer using FSDP.

Given that there are also some flaws in the logic of this PR, it might be worth reverting this so it can be properly relanded.

@SunMarc

dtype=torch.long,
device=device,
)
for name in model.forward.__code__.co_varnames
Copy link
Contributor

@ringohoffman ringohoffman Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the variable names inside of forward... not the parameters to forward. I think you probably meant to do something like inspect.signature.

Comment on lines +296 to +301
name: torch.ones(
(1, 512),
dtype=torch.long,
device=device,
)
for name in model.forward.__code__.co_varnames
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not every parameter to forward is a tensor, but you are sending in a tensor for every value.

@Qubitium
Copy link
Contributor

Regression as result of this merge for trl/sft + fsdp training on 2x gpu.

TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'args'

trl 0.11.4
accelerate 1.0.1
transformers 4.46.0.dev
File "/python/ai/train/sft_trainer.py", line 380, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/python/ai/train/sft_trainer.py", line 380, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 401, in __init__
    super().__init__(
  File "/root/miniconda3/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 401, in __init__
    super().__init__(
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 639, in __init__
    self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 165, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 305, in _init_fsdp
    _ = model(**dummy_input)
        ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 639, in __init__
    self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 305, in _init_fsdp
    _ = model(**dummy_input)
        ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'args'
  File "/root/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
TypeError: LlamaForCausalLM.forward() got an unexpected keyword argument 'args'

@SunMarc
Copy link
Member

SunMarc commented Oct 16, 2024

Thanks for the heads-up @Qubitium @ringohoffman ! I will revert this PR !

SunMarc added a commit that referenced this pull request Oct 16, 2024
@Itssshikhar
Copy link
Contributor Author

Thanks for info @Qubitium @ringohoffman on the PR. I'll try to resolve the errors.

muellerzr pushed a commit that referenced this pull request Oct 16, 2024
Revert "Fix FSDP resume Initialization issue (#34032)"

This reverts commit 4de1bdb.
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
* Fix FSDP Initialization for resume training

* Added init_fsdp function to work with dummy values

* Fix FSDP initialization for resuming training

* Added CUDA decorator for tests

* Added torch_gpu decorator to FSDP tests

* Fixup for failing code quality tests
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
Revert "Fix FSDP resume Initialization issue (huggingface#34032)"

This reverts commit 4de1bdb.
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Fix FSDP Initialization for resume training

* Added init_fsdp function to work with dummy values

* Fix FSDP initialization for resuming training

* Added CUDA decorator for tests

* Added torch_gpu decorator to FSDP tests

* Fixup for failing code quality tests
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Revert "Fix FSDP resume Initialization issue (huggingface#34032)"

This reverts commit 4de1bdb.
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* Fix FSDP Initialization for resume training

* Added init_fsdp function to work with dummy values

* Fix FSDP initialization for resuming training

* Added CUDA decorator for tests

* Added torch_gpu decorator to FSDP tests

* Fixup for failing code quality tests
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
Revert "Fix FSDP resume Initialization issue (huggingface#34032)"

This reverts commit 4de1bdb.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Load fsdp+lora checkpoint error
7 participants
  NODES
COMMUNITY 2
innovation 2
Project 5
USERS 1