
change apply_rotary_pos_emb of Glmmodel for GLM-Edge Series model #34629

Merged: 28 commits merged into huggingface:main from glm-4-1108 on Nov 26, 2024

Conversation

@zRzRzRzRzRzRzR (Contributor) commented Nov 6, 2024

What does this PR do?

This PR allows the new version of the GLM-4 model (the GLM-Edge series) to use a different rotary_pos_emb.
I am still researching how to modify modular_glm.py so that modeling_glm.py can be automatically generated with the additional parameter for apply_rotary_pos_emb.
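For context, here is a minimal sketch of what the partial rotary split means (illustrative names and shapes only, not this PR's code): only the first `head_dim * partial_rotary_factor` channels of each query/key head get rotated, and the remaining channels pass through unchanged. GLM-4-9B uses a factor of 0.5, while the GLM-Edge series uses 1.0, so its pass-through slice is empty.

```python
import torch

# Illustrative shapes only; the real values come from the model config.
batch, heads, seq, head_dim = 1, 2, 4, 8
partial_rotary_factor = 0.5                          # GLM-4-9B style; GLM-Edge would use 1.0
rotary_dim = int(head_dim * partial_rotary_factor)

q = torch.randn(batch, heads, seq, head_dim)
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
print(q_rot.shape, q_pass.shape)  # [1, 2, 4, 4] and [1, 2, 4, 4]; q_pass would be empty with a factor of 1.0
```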

Who can review?

This PR may need @Cyrilvallez to help review it.

@zRzRzRzRzRzRzR changed the title from "change apply_rotary_pos_emb" to "change apply_rotary_pos_emb of Glmmodel(Draft)" on Nov 6, 2024
@zRzRzRzRzRzRzR (Contributor, Author) commented Nov 20, 2024

This implementation is now compatible with both the GLM-Edge and GLM-4 models. @Cyrilvallez, I would like to know how to modify modular_glm.py so it generates these updates automatically, because some parts of the implementation in modeling_glm.py need new parameters to work properly.

@Cyrilvallez (Member) left a comment

Hey! Sorry for the delay; as I've said, we were all in Martinique for our offsite the past week!
You can check my comments, but unless I'm very much mistaken or missing something, most of the changes you propose are no-ops, and you only need to change the q, q_pass / k, k_pass split in apply_rotary_pos_emb.
BTW, you tagged the wrong Cyril in the PR 🤣

Comment on lines 79 to 93
```diff
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, rotary_percent=0.5, device=None):
         super().__init__()

-        self.dim = dim
+        self.rotary_percent = rotary_percent
+        self.dim = dim * rotary_percent
         self.max_position_embeddings = max_position_embeddings
         self.base = base

-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def forward(self, x, position_ids=None):
+        batch_size, seq_len, head_dim = x.shape
+        device = x.device
+        dtype = x.dtype
+
+        seq_idx = torch.arange(0, self.max_position_embeddings, device=device).float()
+        idx_theta = torch.outer(seq_idx, self.inv_freq)
+
+        if position_ids is not None:
+            idx_theta = idx_theta[position_ids[0]]
+        else:
+            idx_theta = idx_theta[:seq_len]
+        if self.rotary_percent == 0.5:
+            idx_theta = torch.cat([idx_theta, idx_theta], dim=-1)  # for glm-4-9b
+
+        device_type = device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
+        cos = torch.cos(idx_theta).to(dtype=dtype)
+        sin = torch.sin(idx_theta).to(dtype=dtype)
+
+        cos = cos[None, :, :].expand(batch_size, seq_len, -1)
+        sin = sin[None, :, :].expand(batch_size, seq_len, -1)
```

@Cyrilvallez (Member):

I don't understand why you modified the RotaryEmbedding class here.

Comment on lines 169 to 145
```diff
-    cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
-    sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
+    cos = cos[..., : int(cos.shape[-1] * rotary_percent)].repeat_interleave(2, dim=-1)
+    sin = sin[..., : int(sin.shape[-1] * rotary_percent)].repeat_interleave(2, dim=-1)
```
@Cyrilvallez (Member):

I don't get it here either. This is exactly the same as before (without modifying the RotaryEmbedding class), but it will only work with rotary_percent=0.5 or rotary_percent=1, and it is much more confusing IMO.

@zRzRzRzRzRzRzR (Contributor, Author):

Because there are two different models, one using 0.5 and one using 1. In the config, the GLM-Edge series model needs this set to 1.
https://huggingface.co/ZP2HF/glm-edge-4b-chat/blob/6a5e92d0092bba5f94abd471720238b6dda8f9de/config.json#L11
I have added annotations here.

Comment on lines 147 to 150
```python
    # Keep rotary_percent(half or not) for later concatenation
    rotary_dim = int(q.shape[-1] * rotary_percent)
    q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
```
@Cyrilvallez (Member):

Indeed, this is for me the only part that should need to be modified. The rest should not need any modification.
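As a rough sketch of that suggestion (following the hunks quoted in this review; the parameter name and default mirror the hunk above, so this is an illustration rather than the final merged code):

```python
import torch


def rotate_half(x):
    # Interleaved variant, matching the torch.stack((-x2, x1), dim=-1).flatten(-2) shown in the diffs below.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1, rotary_percent=0.5):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Interleave cos/sin to match rotate_half above (unchanged from the existing function).
    cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
    sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

    # The only new part: split q/k into a rotated slice and a pass-through slice.
    rotary_dim = int(q.shape[-1] * rotary_percent)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

    # Concatenating the untouched channels back is a no-op when rotary_percent == 1.0.
    return torch.cat([q_embed, q_pass], dim=-1), torch.cat([k_embed, k_pass], dim=-1)


# Quick smoke test with illustrative shapes: head_dim=8, half-rotary, so cos/sin have width 4.
q = k = torch.randn(1, 2, 5, 8)
cos, sin = torch.randn(1, 5, 4), torch.randn(1, 5, 4)
q2, k2 = apply_rotary_pos_emb(q, k, cos, sin, rotary_percent=0.5)
print(q2.shape)  # torch.Size([1, 2, 5, 8])
```

With rotary_percent=1.0 (GLM-Edge), q_pass and k_pass are empty, so GLM-4 and GLM-Edge can share the same code path.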

src/transformers/models/glm/modeling_glm.py (outdated comment, resolved)
@zRzRzRzRzRzRzR (Contributor, Author):

I referred to your plan and modified it to look like this. I believe the mathematical logic of this implementation is equivalent.

@Cyrilvallez (Member) left a comment

It's getting better! Still some issues in the rotary though, I think.
Then, once we agree on the changes, you'll need to apply them in the modular file instead of the modeling file 🤗

src/transformers/models/glm/modeling_glm.py (outdated comment, resolved)
Comment on lines 176 to 142
```diff
-    # Apply rotary embeddings on the first half
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    # Apply rotary embeddings to the rotary portion
+    q = (q * cos[..., :rotary_dim]) + (rotate_half(q) * sin[..., :rotary_dim])
+    k = (k * cos[..., :rotary_dim]) + (rotate_half(k) * sin[..., :rotary_dim])
```
@Cyrilvallez (Member):

Here you don't need to modify anything; you are basically slicing up to the full length, which is useless.

src/transformers/models/glm/modeling_glm.py (outdated comment, resolved)
@zRzRzRzRzRzRzR changed the title from "change apply_rotary_pos_emb of Glmmodel(Draft)" to "change apply_rotary_pos_emb of Glmmodel for GLM-Edge-Model" on Nov 21, 2024
@zRzRzRzRzRzRzR changed the title from "change apply_rotary_pos_emb of Glmmodel for GLM-Edge-Model" to "change apply_rotary_pos_emb of Glmmodel for GLM-Edge Series model" on Nov 21, 2024
@zRzRzRzRzRzRzR (Contributor, Author):

I would like to know if any improvements are needed for this version, and also whether @Cyrilvallez could guide me on how to modify modular_glm.py so that changes like

```python
query_states, key_states = apply_rotary_pos_emb(
    query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
)
```

are generated automatically into modeling_glm.py.

@ArthurZucker (Collaborator):

cc @Cyrilvallez but I think it would help to remove unrelated changes! 🤗

@Cyrilvallez (Member) left a comment

Nice! That's it! 🤗
Final comment: we don't even need to change the signature of apply_rotary_pos_emb, since we can retrieve the rotary_dim from cos and sin; I forgot before, sorry! That way, we don't even have to modify the attention implementation, which is a big win for the modular file!

Also, please remove the unrelated notebook change you added (I assume by mistake) 🤗

Comment on lines 173 to 174
```python
    # Keep half or full tensor for later concatenation
    rotary_dim = int(q.shape[-1] * partial_rotary_factor)
```
@Cyrilvallez (Member):

Suggested change

```diff
     # Keep half or full tensor for later concatenation
-    rotary_dim = int(q.shape[-1] * partial_rotary_factor)
+    rotary_dim = cos.shape[-1]
```

We actually don't need to pass the rotary_factor as an argument to the function!
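A small illustration of why the argument is redundant (assumed names and shapes, not the actual module): the rotary embedding is already constructed with dim = int(head_dim * partial_rotary_factor), so the last dimension of cos/sin carries the rotary width, and the interleave step above keeps that width unchanged.

```python
import torch

head_dim, partial_rotary_factor, seq_len = 8, 0.5, 4
rotary_dim = int(head_dim * partial_rotary_factor)

# Stand-ins for the cos/sin that a rotary embedding built with dim=rotary_dim would return.
cos = torch.ones(1, seq_len, rotary_dim)
sin = torch.zeros(1, seq_len, rotary_dim)

# Inside apply_rotary_pos_emb, the same number can be recovered without any extra argument.
assert cos.shape[-1] == int(head_dim * partial_rotary_factor)
```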

```diff
@@ -142,7 +142,7 @@ def rotate_half(x):
     return torch.stack((-x2, x1), dim=-1).flatten(-2)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, partial_rotary_factor=0.5):
```
@Cyrilvallez (Member):

Suggested change

```diff
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, partial_rotary_factor=0.5):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
```

We actually don't need to pass the rotary_factor as an argument to the function! See the next comment; that way we don't even have to modify the modular file for the Attentions!

Comment on lines 246 to 248
```python
query_states, key_states = apply_rotary_pos_emb(
    query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
)
```
@Cyrilvallez (Member):

No need to pass the extra arg! See above

Comment on lines 328 to 330
```python
query_states, key_states = apply_rotary_pos_emb(
    query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
)
```
@Cyrilvallez (Member):

Same

Comment on lines 442 to 444
```python
query_states, key_states = apply_rotary_pos_emb(
    query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
)
```
@Cyrilvallez (Member):

Same

```diff
@@ -68,7 +68,7 @@ def rotate_half(x):
     return torch.stack((-x2, x1), dim=-1).flatten(-2)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, partial_rotary_factor=0.5):
```
@Cyrilvallez (Member):

No need to pass the extra arg! See above

```diff
@@ -85,6 +85,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+            partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor by which the rotary embedding.
```
@Cyrilvallez (Member):

Same, can be removed

Comment on lines 99 to 100
```python
    # Keep half or full tensor for later concatenation
    rotary_dim = int(q.shape[-1] * partial_rotary_factor)
```
@Cyrilvallez (Member):

Suggested change

```diff
     # Keep half or full tensor for later concatenation
-    rotary_dim = int(q.shape[-1] * partial_rotary_factor)
+    rotary_dim = cos.shape[-1]
```

Same as above

@Cyrilvallez (Member) left a comment

Sorry, I forgot a comment: you need to ensure the dim is an integer as well.

To automatically generate the modeling file from the modular file, you can run

```bash
python utils/modular_model_converter.py --files_to_parse src/transformers/models/glm/modular_glm.py
```

from the root of the transformers repo 🤗

```diff
         self.rotary_emb = GlmRotaryEmbedding(
-            dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
+            dim=config.head_dim * config.partial_rotary_factor,
```
@Cyrilvallez (Member):

Suggested change

```diff
-            dim=config.head_dim * config.partial_rotary_factor,
+            dim=int(config.head_dim * config.partial_rotary_factor),
```

You need int here as well
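A tiny illustration of why the cast matters (illustrative numbers, not from this PR): the config factor is a float, so the product is a float even when it is mathematically whole, while the rotary embedding expects an integer dim.

```python
head_dim, partial_rotary_factor = 128, 0.5     # illustrative values
dim = head_dim * partial_rotary_factor         # 64.0, a float
dim = int(head_dim * partial_rotary_factor)    # 64, the integer dim the rotary embedding expects
```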

@zRzRzRzRzRzRzR (Contributor, Author):

This modification should meet the requirements, and I have tried to remove all unnecessary code. The remaining code is all the code that will be used.

@Cyrilvallez (Member) left a comment

All good, thanks for iterating!

@ArthurZucker (Collaborator) left a comment

Thanks! A small naming convention for q_rot, q_pass, and a nit! We can merge afterwards!

Comment on lines 100 to 101
```python
    q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
```
@ArthurZucker (Collaborator):

Suggested change

```diff
-    q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
-    k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
```

we usually use these notations!

```diff
@@ -151,8 +152,11 @@ def __init__(self, config: GlmConfig):
             [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.partial_rotary_factor = config.partial_rotary_factor
```
@ArthurZucker (Collaborator):

Suggested change

```diff
-        self.partial_rotary_factor = config.partial_rotary_factor
```

I don't think this is used, no?

@Cyrilvallez (Member):

All good, merging!

@Cyrilvallez merged commit 5a45617 into huggingface:main on Nov 26, 2024
7 checks passed
@zRzRzRzRzRzRzR deleted the glm-4-1108 branch on November 26, 2024 at 14:10
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…ggingface#34629)

* change apply_rotary_pos_emb

* upload for glm-edge

* remove useless part

* follow the suggestion

* fix

* format

* format

* test

* format again

* format again

* remove modular change

* remove modular change

* this apply_rotary_pos_emb need modify?

* fix with this

* format

* format

* ruff check

* modify modular_glm failed

* remove partial_rotary_factor of function  partial_rotary_factor

* fix wrong change of examples/research_projects

* revert

* remove line 118

* use q_rot
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024