Mistral-related models for QnA #34045
Conversation
```python
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.__init__ with Llama->Mistral,transformer->model
def __init__(self, config):
    super().__init__(config)
    self.model = MistralModel(config)
    self.qa_outputs = nn.Linear(config.hidden_size, 2)

    # Initialize weights and apply final processing
    self.post_init()

# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.get_input_embeddings with transformer->model
def get_input_embeddings(self):
    return self.model.embed_tokens

# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.set_input_embeddings with transformer->model
def set_input_embeddings(self, value):
    self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering.forward with Llama->Mistral, transformer->model
```
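For context, the forward being copied here is a plain extractive-QA head on top of the backbone. A condensed sketch follows; the real method in modeling_llama.py takes more arguments and handles return_dict and multi-GPU position tensors, so treat this only as an outline of the copied logic.

```python
# Condensed sketch of the copied forward, as it would sit in the class above.
# Assumes `from torch import nn` and
# `from transformers.modeling_outputs import QuestionAnsweringModelOutput` are in scope.
def forward(self, input_ids=None, attention_mask=None, start_positions=None, end_positions=None, **kwargs):
    outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
    sequence_output = outputs[0]                            # (batch, seq_len, hidden_size)

    logits = self.qa_outputs(sequence_output)               # (batch, seq_len, 2)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1).contiguous()    # (batch, seq_len)
    end_logits = end_logits.squeeze(-1).contiguous()

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # Gold positions that fall outside the sequence are clamped and ignored in the loss.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2

    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
```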
So it's more of a stylistic choice: individual copies vs. including the (unnecessary) base model prefix.
If #34061 gets merged, we can top-level copy from llama without any problems.
You can also use `# Ignore copy` on the single place where the copy does not match!
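For anyone following along, a rough sketch of the two styles being discussed (this is not the code in this PR, and the exact placement rules for `# Ignore copy` are whatever utils/check_copies.py enforces):

```python
# Style A (per-method copies, as in the diff above): each method carries its own
# "Copied from ... with Llama->Mistral,transformer->model" line.

# Style B (suggested here): one class-level "Copied from", with "# Ignore copy"
# marking the single spot that intentionally diverges. Rough sketch only.
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral
class MistralForQuestionAnswering(MistralPreTrainedModel):
    # Ignore copy
    base_model_prefix = "model"
    ...
```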
Ah ok, perfect! I'll change it later and ping you when ready ;)
LGTM in general! Would be nice to have a single `# Copied from` at the top of the class (either `# Ignore copy` or just don't copy from llama for one of them!)
@ArthurZucker Changed it to a top-level `# Copied from` now. Lmk if I should change something else.
```python
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model
class MistralForQuestionAnswering(MistralPreTrainedModel):
    base_model_prefix = "model"
```
`base_model_prefix = "model"` is due to the llama naming I mentioned; otherwise the classes have different structures and the `# Copied from` check will fail with an error.
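A tiny illustration of why the prefix has to match the attribute name (a sketch with a random-weight toy config, assuming the new class is exported from transformers as usual; `PreTrainedModel.base_model` resolves the backbone via `getattr(self, self.base_model_prefix)`):

```python
from transformers import MistralConfig, MistralForQuestionAnswering  # class added in this PR

# Toy config just to build the module; not a real checkpoint.
config = MistralConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=1,
                       num_attention_heads=4, num_key_value_heads=2, vocab_size=128)
qa_model = MistralForQuestionAnswering(config)

# .base_model looks up the attribute named by base_model_prefix, so the prefix must
# match what __init__ sets (self.model here, vs. self.transformer in the llama QA class).
assert qa_model.base_model is qa_model.model
```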
That's it! Merging 🤗
* mistral qna start
* mixtral qna
* oops
* qwen2 qna
* qwen2moe qna
* add missing input embed methods
* add copied to all methods, can't directly from llama due to the prefix
* make top level copied from
What does this PR do?
Adds question answering heads to mistral, mixtral, qwen2, and qwen2moe. Because of the `# Copied from` statements, either we add the head to every one of these models or we have to ignore the missing ones in the copy checks. Based on #29168 but using `# Copied from` instead.
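As a quick sanity check of what this enables, here is a minimal sketch using a tiny random-weight config (not a real checkpoint; with a pretrained model one would typically go through AutoModelForQuestionAnswering instead):

```python
import torch
from transformers import MistralConfig, MistralForQuestionAnswering

# Tiny random-weight config just to exercise the new QA head.
config = MistralConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                       num_attention_heads=4, num_key_value_heads=2, vocab_size=1000)
model = MistralForQuestionAnswering(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids)

# One start and one end logit per token; the argmax of each gives the predicted answer span.
print(outputs.start_logits.shape, outputs.end_logits.shape)  # torch.Size([1, 16]) twice
start, end = outputs.start_logits.argmax(-1), outputs.end_logits.argmax(-1)
```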
Motivation: We have a benchmark paper at https://github.com/LSX-UniWue/SuperGLEBer which uses the transformers QnA models for simplicity, but since these heads are not available in main they are currently patched in manually. Would be great to see this land in main!
Fixes #28908
Before submitting
* Did you read the contributor guideline, Pull Request section?
* Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
* Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@LysandreJik @ArthurZucker