
Support gradient checkpointing in Qwen2VL ViT #34724

Merged
merged 3 commits into huggingface:main on Nov 19, 2024

Conversation

li-plus
Contributor

@li-plus li-plus commented Nov 14, 2024

What does this PR do?

Support gradient checkpointing for the Qwen2VL ViT part. The current implementation on the main branch only supports gradient checkpointing in the language model. This PR extends checkpointing to the vision encoder as well.
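As a rough illustration of the technique (a hedged sketch, not the PR's actual code), gradient checkpointing in a ViT typically wraps each transformer block in torch.utils.checkpoint during training, so block activations are recomputed in the backward pass instead of being stored. The toy encoder below is purely illustrative; class and attribute names are made up:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyViTEncoder(nn.Module):
    # Toy stand-in for a vision encoder; names are illustrative only.
    def __init__(self, dim=32, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.gradient_checkpointing = False

    def forward(self, x):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Recompute this block's activations during backward,
                # trading compute for activation memory.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


enc = TinyViTEncoder().train()
enc.gradient_checkpointing = True
x = torch.randn(2, 32, requires_grad=True)
enc(x).sum().backward()
print(x.grad.shape)
```

The backward pass still produces correct gradients; only the peak memory profile changes.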

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @ArthurZucker, @amyeroberts, @qubvel

@li-plus li-plus force-pushed the qwen2vl-grad-ckpt branch 2 times, most recently from 9163b7f to cad72d1, on November 14, 2024 01:58
@qubvel
Member

qubvel commented Nov 18, 2024

Hi @li-plus! Thanks for your contribution!

Can you please enable gradient checkpointing tests for this model to make sure it works properly? I see these tests are skipped on main, however, I was able to run them without issues locally.

@li-plus
Contributor Author

li-plus commented Nov 18, 2024

> Hi @li-plus! Thanks for your contribution!
>
> Can you please enable gradient checkpointing tests for this model to make sure it works properly? I see these tests are skipped on main, however, I was able to run them without issues locally.

@qubvel Thanks for the advice. I've re-enabled the gradient checkpointing tests for Qwen2VL in the latest commit. They run fine on my machine.

@qubvel
Member

qubvel commented Nov 18, 2024

Thanks! Can you please also push an empty commit with the message [run-slow] qwen2_vl to trigger all model tests? We should be fine here, because test_training_gradient_checkpointing is not a slow test, but it's good to double-check that everything else is fine 🙂

@li-plus
Contributor Author

li-plus commented Nov 18, 2024

Thanks. Just pushed!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@li-plus
Contributor Author

li-plus commented Nov 18, 2024

@qubvel It seems those failures are not related to this PR. Any idea?

Member

@qubvel qubvel left a comment


No worries, I checked, the same tests fail on main. Thanks for triggering slow tests!

@qubvel qubvel requested a review from ArthurZucker November 18, 2024 21:02
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks, indeed there is supports_gradient_checkpointing set to True. Good catch!

@ArthurZucker ArthurZucker merged commit 0db91c3 into huggingface:main Nov 19, 2024
16 of 18 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Support gradient checkpointing in Qwen2VL ViT

* Enable gradient checkpoint tests for Qwen2VL

* [run-slow] qwen2_vl
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* Support gradient checkpointing in Qwen2VL ViT

* Enable gradient checkpoint tests for Qwen2VL

* [run-slow] qwen2_vl

@ShuaibinQi

@li-plus

Thanks for your commit! When I use gradient_checkpointing for Qwen2VisionTransformerPretrainedModel, I get this error:

AttributeError: 'Qwen2VisionTransformerPretrainedModel' object has no attribute '_gradient_checkpointing_func'. Did you mean: 'gradient_checkpointing'?

Is it necessary to implement _gradient_checkpointing_func in Qwen2VisionTransformerPretrainedModel?

My transformers version is the latest: transformers==4.47.0

@li-plus li-plus deleted the qwen2vl-grad-ckpt branch December 16, 2024 12:36
@li-plus
Contributor Author

li-plus commented Dec 16, 2024

@ShuaibinQi Hi, I could not reproduce this error with transformers 4.47.0 using this demo code:

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)
model.gradient_checkpointing_enable()
model.train()

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

print(f'before forward {torch.cuda.memory_allocated()/1e9=:.3f} GB, {torch.cuda.memory_reserved()/1e9=:.3f} GB')
output = model(**inputs, use_cache=False)
print(f'after forward {torch.cuda.memory_allocated()/1e9=:.3f} GB, {torch.cuda.memory_reserved()/1e9=:.3f} GB')
output.logits.sum().backward()
print(f'after backward {torch.cuda.memory_allocated()/1e9=:.3f} GB, {torch.cuda.memory_reserved()/1e9=:.3f} GB')

Did you use gradient_checkpointing_enable to enable gradient checkpointing?
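To illustrate why the AttributeError above can occur (a toy sketch mirroring the Hugging Face convention, not transformers internals): gradient_checkpointing_enable() is what installs _gradient_checkpointing_func on the module, so setting the gradient_checkpointing flag by hand skips that step and the forward pass then fails. All class and method names below are a mock, not the real Qwen2VL code:

```python
import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyVisionModel(nn.Module):
    # Mock model following the convention that enable() both flips the
    # flag and installs the checkpointing function.
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(8, 8)
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True
        self._gradient_checkpointing_func = functools.partial(
            checkpoint, use_reentrant=False
        )

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            return self._gradient_checkpointing_func(self.block, x)
        return self.block(x)


m = ToyVisionModel().train()
m.gradient_checkpointing = True  # flag only: the func was never installed
try:
    m(torch.randn(1, 8))
except AttributeError as e:
    print(e)  # missing _gradient_checkpointing_func

m.gradient_checkpointing_enable()  # the supported path
out = m(torch.randn(1, 8))
```

In this sketch, calling the enable method (rather than toggling the flag) is what makes the checkpointed forward pass work.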
