Allow custom args in `custom_generate` Callables and unify generation args structure #40586
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
        model_kwargs: dict,
        assistant_model: Optional["PreTrainedModel"] = None,
```
They are actually optional, so better to have them as keyword args
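A minimal sketch of the point above, with hypothetical names (this is not the actual `transformers` signature): declaring optional collaborators as keyword-only arguments keeps the positional surface small and makes positional misuse fail loudly.

```python
from typing import Optional

# Hypothetical sketch: everything after the bare "*" is keyword-only,
# so call sites must name assistant_model/streamer explicitly.
def generate(inputs, *, assistant_model: Optional[object] = None,
             streamer: Optional[object] = None, **model_kwargs):
    mode = "assisted" if assistant_model is not None else "default"
    return mode, model_kwargs

# Callers spell out what they pass; leftover kwargs land in model_kwargs.
mode, extra = generate([1, 2, 3], assistant_model="draft", temperature=0.7)
```

Passing those arguments positionally (e.g. `generate([1], "draft")`) raises a `TypeError`, which is the behavior the reviewer is after.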
src/transformers/generation/utils.py (Outdated)
```diff
         generation_mode = generation_config.get_generation_mode(assistant_model)

-        if streamer is not None and (generation_config.num_beams > 1):
+        if streamer is not None and generation_mode == GenerationMode.BEAM_SEARCH:
```
now that we compute the generation mode early, these checks should be moved into a validation method (make `validate_assistant` more general?)
Yeah, that makes sense! `_validate_generation_mode(**generation_mode_kwargs)`? We could also make the `streamer` part of the generation mode kwargs.
and `synced_gpus`?
[Partial review]
Force-pushed from beb2b5f to 331a87a.
src/transformers/generation/utils.py (Outdated)
```python
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **logits_processor_kwargs,
```
In the case of logits processors, I think it adds clarity to encapsulate them into `logits_processor_kwargs`, with no change to the signature. Lmk if you prefer to completely remove `logits_processor_kwargs`!
Looking good!
A few more upgrades and I think we're golden 👌
src/transformers/generation/utils.py (Outdated)
```diff
@@ -2317,15 +2354,14 @@ def generate(
             stopping_criteria,
             prefix_allowed_tokens_fn,
             synced_gpus,
-            assistant_model,
+            gen_mode_kwargs.pop("assistant_model"),
```
Suggestion: in this function call, replace all positional arguments by keyword arguments. Then this line is no longer needed and `**gen_mode_kwargs` works :) Also, `**logits_processors_kwargs` can be used here!
Actually `logits_processors_kwargs` is not doing anything right now (until we clean up the `generate` signature); packing them just creates an extra abstraction for

```python
logits_processor_kwargs = {
    "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn,
    "negative_prompt_ids": negative_prompt_ids,
    "negative_prompt_attention_mask": negative_prompt_attention_mask,
}
```

so I think it's better to remove it, right? Less abstraction to have the args themselves.
src/transformers/generation/utils.py (Outdated)
```python
            target_tokenizer=gen_mode_kwargs.get("tokenizer"),
            assistant_tokenizer=gen_mode_kwargs.get("assistant_tokenizer"),
            model_kwargs=model_kwargs,
```
```diff
-            target_tokenizer=gen_mode_kwargs.get("tokenizer"),
-            assistant_tokenizer=gen_mode_kwargs.get("assistant_tokenizer"),
-            model_kwargs=model_kwargs,
+            model_kwargs=model_kwargs,
+            **gen_mode_kwargs
```

If `gen_mode_kwargs` contains keys that are not used, it will raise an exception -> this is wanted.
Won't work if we include `streamer` and `synced_gpus`, which are generation_mode_args but not for the candidate generator. My suggestion is that we leave this as is in this PR; I will then open another one that puts `_get_candidate_generator` inside `_assisted_decoding`, unifying the interface for decoding loops and making generation_mode_args the unique data pack for decoding methods, unlike `_get_candidate_generator`, which is another thing and should not be called in `generate`. We only need to solve the issue of passing `inputs_tensors` (which is useful to have generally available for decoding loops) -> future work anyway.
[For maintainers] Suggested jobs to run (before merge): run-slow: csm, dia, kyutai_speech_to_text, musicgen, musicgen_melody, rag
This PR depends on #40518.
The intent of this PR is to allow custom CALLABLE generation modes (that reuse `generate`'s preparation step) to define their own parameters. To do this, we need to inspect the signature and remove those parameters before `validate_model_kwargs`. So we need some notion of `gen_mode_kwargs`. I introduce that and `logit_processor_kwargs`, which also allows us to move to a more functional, stage-by-stage approach, as discussed before with @gante. An added benefit is that there are no longer weird parametrizations like `get_generation_mode(assistant_model)`, but rather a much more natural `get_generation_mode(**gen_mode_kwargs)`.
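The signature-inspection step described in the summary could look roughly like this. It is a sketch under assumed names (`split_custom_kwargs`, `my_decoding_loop` are illustrative, not the actual PR code):

```python
import inspect

# Hypothetical sketch: before the leftover-kwargs validation, peel off any
# parameters that the custom generate callable declares in its own
# signature, so user-defined args don't trip `validate_model_kwargs`.
def split_custom_kwargs(custom_generate, kwargs):
    accepted = set(inspect.signature(custom_generate).parameters)
    custom = {k: kwargs.pop(k) for k in list(kwargs) if k in accepted}
    return custom, kwargs

# A custom decoding loop declaring its own parameter, `left_padding`.
def my_decoding_loop(model, input_ids, left_padding=None, **model_kwargs):
    pass

kwargs = {"left_padding": 4, "attention_mask": [1, 1]}
custom, remaining = split_custom_kwargs(my_decoding_loop, kwargs)
```

`custom` would be routed to the callable, while `remaining` continues through the usual model-kwargs validation.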