
Conversation

@zkpranav commented Jun 19, 2025

Related to: unslothai/unsloth-zoo#172

Avoids materializing the entire logit matrix for the ref, old, and new policies' log-probability calculations by using CCE with no reduction:

```python
selective_log_softmax(e @ c.T, index) == -cce(e, c, index, reduction="none")
```

The default invocation of linear_cross_entropy applies gradient filtering, which can be turned off by setting filter_eps to -inf.
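A minimal sanity check of this identity (and of disabling the filtering) might look like the sketch below; the shapes and the bf16 tolerance are illustrative, and it assumes trl's selective_log_softmax helper and the cut-cross-entropy package are available:

```python
import torch
from trl.trainer.utils import selective_log_softmax
from cut_cross_entropy import linear_cross_entropy

B, T, D, V = 2, 8, 64, 256
e = torch.randn(B, T, D, device="cuda", dtype=torch.bfloat16)  # hidden states
c = torch.randn(V, D, device="cuda", dtype=torch.bfloat16)     # lm_head weight
index = torch.randint(0, V, (B, T), device="cuda")             # target token ids

logps_ref = selective_log_softmax((e @ c.T).float(), index)
# filter_eps=-inf turns off CCE's gradient filtering, per the note above.
logps_cce = -linear_cross_entropy(e, c, index, reduction="none",
                                  impl="cce", filter_eps=float("-inf"))

print(torch.dist(logps_ref, logps_cce.float()))
assert torch.allclose(logps_ref, logps_cce.float(), atol=1e-2)  # loose bf16 tolerance
```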

```
num_generations = 8
num_iterations = 4
batch_size = 8
unsloth_num_chunks = 4
max_prompt_length = 512
max_completion_length = 1024
vocab_size = 128256
```

[Plots: GPU memory usage and loss]

Reduces VRAM usage by around 15-20%, though the memory usage should be lower still with CCE. Moreover, for larger values of batch_size, max_completion_length, and vocab_size, the difference will be much more pronounced.

Other changes:

  1. Modifies _get_per_token_logps to accept a batch_size (https://github.com/huggingface/trl/blob/5206c927f6bb161e45114531b0bca8286acfeada/trl/trainer/grpo_trainer.py#L853) and removes calc_logprob_flag (a rough sketch follows this list).
  2. Computes logps in compute_loss (before calling into UnslothEfficientGRPO), ensuring a consistent interface with HF.
  3. Removes the explicit computation of ref logps, since HF does that now (https://github.com/huggingface/trl/blob/5206c927f6bb161e45114531b0bca8286acfeada/trl/trainer/grpo_trainer.py#L1292).
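A rough sketch of the resulting chunked logp computation (simplified from the actual diff; the names and the dummy-token handling follow the snippet quoted later in the thread):

```python
import torch
from cut_cross_entropy import linear_cross_entropy

def get_per_token_logps(model, lm_head, input_ids, attention_mask,
                        logits_to_keep, batch_size):
    # Process rows in sub-batches of `batch_size` so the full (B, T, V)
    # logit matrix is never materialized.
    all_logps = []
    for start in range(0, input_ids.size(0), batch_size):
        ids = input_ids[start : start + batch_size]
        mask = attention_mask[start : start + batch_size]
        hidden = model(input_ids=ids, attention_mask=mask,
                       logits_to_keep=logits_to_keep + 1).logits
        # Append a dummy token so shapes line up; its logp is dropped below.
        labels = torch.cat(
            (ids[:, -logits_to_keep:],
             torch.zeros((ids.size(0), 1), dtype=ids.dtype, device=ids.device)),
            dim=-1,
        )
        logps = -linear_cross_entropy(hidden.to(lm_head.dtype), lm_head,
                                      labels, reduction="none", impl="cce")
        all_logps.append(logps[:, :-1])  # exclude the last (dummy) position
    return torch.cat(all_logps, dim=0)
```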

@danielhanchen (Contributor)

Wait, I thought we didn't materialize logits but folded it into a torch.compile kernel @Datta0 @pluesclues

```python
hidden_states = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1).logits
# Add a dummy input_id at the end. The last logp is excluded.
input_ids_batch = torch.cat((input_ids_batch[:, -logits_to_keep:], torch.zeros((batch_size, 1), dtype=input_ids_batch.dtype, device=input_ids_batch.device)), dim=-1)
logps = -1 * linear_cross_entropy(hidden_states.to(dtype=lm_head.dtype), lm_head, input_ids_batch, reduction="none", impl="cce")
```
@Datta0 (Collaborator) commented Jun 20, 2025

Um, why do we need cross entropy in get_per_token_logps?

@pluesclues (Contributor) commented Jun 20, 2025

Apparently these return logprobs and are equivalent to selective_log_softmax? But I am not sure we want to return logprobs this way because, as @danielhanchen said, we folded that computation into a torch.compile kernel. I also question whether the memory saved here actually comes from cut cross entropy rather than from the chunked concatenation of the hidden states. I am currently at work, but we can check later whether chunking the hidden states alone conserves a similar amount of memory.

Contributor

Another thing I see: according to this person's post, there also seems to be some speedup. Instead of materializing the logits outside of here, we could put linear_cross_entropy in place of the code in selective_log_softmax, so we get both the speedup and the memory savings and do not materialize logits outside of the kernel.
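Concretely, the swap being suggested might look like this sketch (the call shape mirrors selective_log_softmax(hidden @ lm_head.T, index); treat it as an assumption about the surrounding code):

```python
from cut_cross_entropy import linear_cross_entropy

def selective_log_softmax_cce(hidden_states, lm_head, index):
    # Equivalent to selective_log_softmax(hidden_states @ lm_head.T, index),
    # but CCE computes the per-token logps without materializing the logits.
    return -linear_cross_entropy(
        hidden_states.to(lm_head.dtype), lm_head, index,
        reduction="none", impl="cce",
    )
```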

@zkpranav (Author)

@pluesclues That would work too. The only reason I did it this way is to ensure consistency with HF. That being said, we may, at some point, need to write a custom kernel anyway to run fused operations on the logit matrix chunk. Currently, the implementation in HF scales the logits with temperature before computing logps (https://github.com/huggingface/trl/blob/4c92de00001379ceedaf073512ce4df5da304d08/trl/trainer/grpo_trainer.py#L871).
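For reference, since the logits are linear in the hidden states, HF's temperature scaling can be folded into the CCE inputs without materializing the logit chunk; a sketch of the idea:

```python
from cut_cross_entropy import linear_cross_entropy

def per_token_logps_with_temperature(hidden_states, lm_head, index, temperature):
    # logits / temperature == (hidden_states / temperature) @ lm_head.T,
    # so dividing the hidden states reproduces HF's scaled log-softmax.
    scaled = (hidden_states / temperature).to(lm_head.dtype)
    return -linear_cross_entropy(scaled, lm_head, index,
                                 reduction="none", impl="cce")
```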

Contributor

Okay, I just tested this method inside of the kernel, and it's as I suspected: we cannot use linear_cross_entropy, which is a torch.compile kernel in itself, inside of a torch.compile kernel. I confirmed this by running `ref = -1 * linear_cross_entropy(ref_hidden_states_j[:, :-1, :].to(dtype=lm_head.dtype), lm_head, input_ids_j, reduction="none", impl="cce")` right before accumulate_chunk outside the kernel, and also calling the same line inside the kernel. Outside the kernel it works just fine; inside, it seems to break. I still haven't tested the speedup on my machine yet, but so far it looks like we can either merge this or change the way we calculate logprobs to exactly how CCE does it in their kernel.
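A minimal repro sketch of the nesting issue described above (shapes are illustrative; assumes the cut-cross-entropy package and a CUDA device):

```python
import torch
from cut_cross_entropy import linear_cross_entropy

B, T, D, V = 2, 16, 64, 128
hidden = torch.randn(B, T, D, device="cuda", dtype=torch.bfloat16)
lm_head = torch.randn(V, D, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, T), device="cuda")

# Outside any torch.compile region this works fine.
logps = -linear_cross_entropy(hidden, lm_head, labels, reduction="none", impl="cce")

@torch.compile
def inner(hidden, lm_head, labels):
    # linear_cross_entropy is itself a compiled kernel, and nesting it inside
    # another torch.compile region is what breaks per the comment above.
    return -linear_cross_entropy(hidden, lm_head, labels, reduction="none", impl="cce")

# inner(hidden, lm_head, labels)  # expected to fail as described above
```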

@zkpranav (Author) commented Jun 20, 2025

About the memory savings I reported: I believe the way I profiled them does not give an accurate account. I am only logging the peak memory allocated throughout a training step, clearing it at the beginning. This approach fails to account for the memory allocated for the old and ref policies, since they are computed and cached outside the new-policy update loop, i.e. only when _step % num_iterations == 0. I expected much higher memory savings. I would appreciate some help with this.
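For context, the profiling described above amounts to something like this sketch (an assumption about the setup, not code from the PR):

```python
import torch

def log_peak_memory(step: int, step_fn) -> None:
    # Reset the peak counter at the start of the step, run the step, then read
    # the high-water mark. This misses allocations made before the reset, e.g.
    # the cached old/ref-policy logps, which is the blind spot described above.
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"step {step}: peak allocated {peak_gib:.2f} GiB")
```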

Moreover, I would like to confirm that UNSLOTH_USE_NEW_MODEL being set to 0 should be interpreted as the pathway to UnslothEfficientGRPO, as is the case in the current implementation.
Also, UNSLOTH_RETURN_HIDDEN_STATES is set to 1 before executing the forward pass in _get_per_token_logps but is never reset to its original value, creating an unintended side effect. This happens in a couple of places. Would it not be better to reset it, as sketched below?
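A sketch of the reset being suggested, wrapping the forward pass so the flag is always restored (the helper name is hypothetical):

```python
import os

def forward_with_hidden_states(run_forward):
    # Hypothetical wrapper: set the flag for the forward pass, then restore
    # the original value (or remove the key) to avoid the global side effect.
    previous = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES")
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
    try:
        return run_forward()
    finally:
        if previous is None:
            del os.environ["UNSLOTH_RETURN_HIDDEN_STATES"]
        else:
            os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = previous
```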

Collaborator

Do you have the wandb plot of memory usage over time (as tracked by trl/wandb itself) for the run?

@zkpranav (Author)

[Plot: GPU memory usage]

@zkpranav (Author)

This is a much smaller run with double the batch size. The CCE version completes its 4 training steps in 7 mins, whereas the current implementation OOMs on my machine after 12 mins.
In this case, the amount of memory saved is roughly 25%.

```
batch_size = 16
unsloth_num_chunks = 4
```

[Plot: GPU memory usage (OOM run)]

@zkpranav (Author)

These are the loss and grad_norm graphs for the current version and CCE with no gradient filtering.

[Plots: loss and grad_norm]

@zkpranav (Author)

I believe this is relevant to:
#2752
#2736
#2702

@danielhanchen (Contributor)

@zkpranav I'm actually quite surprised selective_log_softmax(e @ c.T, index) == -cce(e, c, index, reduction="none") is equivalent :) Is it possible to write a simple Python test to verify this? That would be cool. It would also be good to check with torch.allclose or torch.dist().

Also, the loss plots look mostly the same; I'm assuming the differences come from the generation dynamics. Can you set the temperature to a small number for GRPO, say 0.001, with top_k = 1 and seed = 3407 or something? The main issue is that the losses aren't exactly matching.
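Something like the following settings sketch could pin the generation down (field names follow trl's GRPOConfig; treat the exact names as assumptions to check against the installed version):

```python
from trl import GRPOConfig

config = GRPOConfig(
    temperature=0.001,  # near-greedy sampling to suppress generation noise
    top_k=1,
    seed=3407,
)
```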

@zkpranav (Author) commented Jun 23, 2025

@danielhanchen Sure, I'll write the tests and profile with a range of temperature and top_k values. I believe the differences in the loss values are a direct result of the differences in reward and KL divergence. I have seeded the runs, but perhaps I missed something? I could also try beta = 0.0 to isolate the issue.

I am also confused as to how the reward values diverge, albeit slightly. The changes in this PR only affect what comes after the reward and advantage calculation.

[Plot: reward curves]

@danielhanchen (Contributor)

Oh, for GRPO temperature = 1.0, so sampling is done. Seeding won't work since vLLM and other systems are not fully deterministic; it's best to change temperature = 0.001 and top_k = 1 to at least approximate mostly non-sampled generation.

@pluesclues (Contributor)

@danielhanchen I also checked the selective_log_softmax function against CCE, and the logits generally matched up with slight differences, which I assume is part of the reason the losses do not match exactly, along with other non-deterministic parts of the system.

@danielhanchen (Contributor)

Oh interesting
