207 changes: 115 additions & 92 deletions src/transformers/generation/utils.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/transformers/models/csm/generation_csm.py
@@ -153,8 +153,8 @@ def _sample(
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         """
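Note on the signature change above: `generate()` now forwards mode-specific arguments as a bundled `**generation_mode_kwargs` (see the dia diff below), so `synced_gpus` and `streamer` need safe defaults for callers that omit them. A minimal standalone sketch of why the defaults matter, with illustrative names only:

```python
# Illustrative stub, not the transformers source: defaults let the
# mode-specific arguments be omitted or forwarded as a bundle.
from typing import Any, Optional


def _sample(input_ids: list[int], *, synced_gpus: bool = False,
            streamer: Optional[Any] = None, **model_kwargs: Any) -> list[int]:
    if streamer is not None:
        streamer.put(input_ids)  # stream tokens as they are produced
    return input_ids


generation_mode_kwargs: dict[str, Any] = {}   # possibly empty bundle
_sample([0, 1, 2], **generation_mode_kwargs)  # defaults fill the gaps
_sample([0, 1, 2], synced_gpus=True)          # explicit override still works
```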
35 changes: 17 additions & 18 deletions src/transformers/models/dia/generation_dia.py
@@ -265,14 +265,20 @@ def _main_generate_loop(
     ):
         # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
-        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
+        generation_mode_kwargs = self._extract_generation_mode_kwargs(
+            custom_generate,
+            kwargs,
+            synced_gpus,
+            assistant_model,
+            streamer,
+        )
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, use_model_defaults, **kwargs
         )
+        generation_mode = generation_config.get_generation_mode(assistant_model)

         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
+        self._validate_generation_mode(generation_mode, generation_mode_kwargs)

         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
@@ -308,7 +314,7 @@ def _main_generate_loop(
         )

         if generation_config.token_healing:
-            input_ids = self.heal_tokens(input_ids, tokenizer)
+            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))

         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -347,18 +353,10 @@ def _main_generate_loop(
         ):
             max_cache_length += inputs_tensor.shape[1]
         self._prepare_cache_for_generation(
-            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length
+            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
         )

-        # 8. determine generation mode
-        generation_mode = generation_config.get_generation_mode(assistant_model)
-
-        if streamer is not None and (generation_config.num_beams > 1):
-            raise ValueError(
-                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
-            )
-
-        # 9. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
@@ -371,7 +369,9 @@ def _main_generate_loop(
             negative_prompt_attention_mask=negative_prompt_attention_mask,
         )
         prepared_stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=generation_mode_kwargs.get("tokenizer"),
         )

         # Set model_kwargs `use_cache` so we can use it later in forward runs
@@ -393,8 +393,7 @@ def _main_generate_loop(
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
+                **generation_mode_kwargs,
                 **model_kwargs,
             )
         else:
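The dia loop now mirrors the main `generate()` flow: per-mode arguments are pulled into `generation_mode_kwargs` once, up front, and the old inline guards (such as the deleted streamer-with-beam-search check) are presumably folded into `_validate_generation_mode`. A standalone sketch of that pattern under assumed, simplified names; the real helpers live in `src/transformers/generation/utils.py`, whose diff is not rendered above:

```python
# Simplified sketch of the bundling + validation pattern; names and logic
# here are assumptions, not the actual transformers implementation.
from typing import Any, Optional


def extract_generation_mode_kwargs(
    kwargs: dict[str, Any], synced_gpus: Optional[bool], streamer: Optional[Any]
) -> dict[str, Any]:
    """Pop per-mode arguments out of the generic kwargs dict once, up front."""
    mode_kwargs: dict[str, Any] = {"synced_gpus": bool(synced_gpus), "streamer": streamer}
    if "tokenizer" in kwargs:  # only needed for stopping criteria / token healing
        mode_kwargs["tokenizer"] = kwargs.pop("tokenizer")
    return mode_kwargs


def validate_generation_mode(generation_mode: str, mode_kwargs: dict[str, Any]) -> None:
    """Centralizes checks such as the streamer-vs-beam-search guard deleted above."""
    if mode_kwargs.get("streamer") is not None and generation_mode == "beam_search":
        raise ValueError("`streamer` cannot be used with beam search (yet!).")
```

With the bundle built once, every decoding method can simply receive `**generation_mode_kwargs`, which is why the explicit `synced_gpus=synced_gpus, streamer=streamer` forwarding disappears from the call above.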
@@ -1222,7 +1222,7 @@ def _prepare_model_inputs(
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=self.config.codec_config.sliding_window,
         )
@@ -357,7 +357,7 @@ def _prepare_model_inputs(
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=self.config.codec_config.sliding_window,
         )
2 changes: 1 addition & 1 deletion src/transformers/models/musicgen/modeling_musicgen.py
@@ -1258,7 +1258,7 @@ def generate(
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=max_cache_length,
         )
@@ -2173,7 +2173,7 @@ def generate(
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=batch_size,
             max_cache_length=max_cache_length,
         )
2 changes: 1 addition & 1 deletion src/transformers/models/rag/modeling_rag.py
@@ -1566,7 +1566,7 @@ def extend_enc_output(tensor, num_beams=None):
         self._prepare_cache_for_generation(
             generation_config,
             model_kwargs,
-            assistant_model=None,
+            generation_mode=None,
             batch_size=input_ids.shape[0],
             max_cache_length=generation_config.max_length - 1,
         )
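Across the codec models, musicgen, and RAG above, every call site swaps `assistant_model=None` for `generation_mode=None`: cache preparation now keys off the resolved generation mode rather than the presence of an assistant model object. A hedched sketch of what such an interface looks like, with hypothetical internals (the real helper is in `src/transformers/generation/utils.py`, whose 207-line diff is not rendered above):

```python
# Hypothetical sketch only: the real _prepare_cache_for_generation has more logic.
from typing import Any, Optional

ASSISTED_GENERATION = "assisted_generation"  # stand-in for the GenerationMode enum


def prepare_cache_for_generation(
    generation_config: Any,
    model_kwargs: dict[str, Any],
    generation_mode: Optional[str],  # previously: assistant_model
    batch_size: int,
    max_cache_length: int,
) -> None:
    """Size the cache from the decoding mode, not from an assistant model object."""
    if generation_mode == ASSISTED_GENERATION:
        # Assisted decoding verifies several candidate tokens per step, so we
        # leave headroom beyond the nominal length (assumed config attribute).
        max_cache_length += getattr(generation_config, "num_assistant_tokens", 0)
    model_kwargs.setdefault("cache_kwargs", {"batch_size": batch_size, "max_len": max_cache_length})
```

Passing `generation_mode=None` at these call sites then means "no special mode": the helper sizes a plain cache, and callers no longer need an assistant model object just to signal that.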