Avoid materializing the entire logit matrix for logp calculations. #2772
Conversation
…avoid recomputation of ref logps
Wait, I thought we didn't materialize logits but folded it into a torch.compile kernel @Datta0 @pluesclues
unsloth/models/rl_replacements.py
hidden_states = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1).logits
# Note: with UNSLOTH_RETURN_HIDDEN_STATES=1 (see discussion below), .logits holds the final hidden states rather than logits.
# Add a dummy input_id at the end. The last logp is excluded.
input_ids_batch = torch.cat((input_ids_batch[:, -logits_to_keep:], torch.zeros((batch_size, 1), dtype=input_ids_batch.dtype, device=input_ids_batch.device)), dim=-1)
logps = -1 * linear_cross_entropy(hidden_states.to(dtype=lm_head.dtype), lm_head, input_ids_batch, reduction="none", impl="cce")
Um, why do we need cross entropy in `get_per_token_logps`?
Apparently, these will return logprobs and are equivalent to selective softmax? But I am not sure if we want to return logprobs in this matrix because, like @danielhanchen said, we folded it into a torch.compile kernel. I am questioning whether the memory saved here actually comes from the cut cross-entropy loss rather than from the chunked concatenation of the hidden states. I am currently at work, but we can check later whether chunking the hidden states conserves a similar amount of memory.
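For reference, a minimal numerical sanity check of the claimed equivalence — a sketch only, assuming trl's `selective_log_softmax` and cut-cross-entropy's `linear_cross_entropy` are importable under these names, with made-up toy shapes:

```python
import torch
from trl.trainer.utils import selective_log_softmax   # assumed import path
from cut_cross_entropy import linear_cross_entropy    # assumed import path

B, T, H, V = 2, 16, 64, 1000                                     # toy sizes
e = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16)    # stand-in hidden states
c = torch.randn(V, H, device="cuda", dtype=torch.bfloat16)       # stand-in lm_head weight
index = torch.randint(0, V, (B, T), device="cuda")               # target token ids

# Reference path: materialize the (B, T, V) logit matrix, then gather per-token logps.
ref = selective_log_softmax((e @ c.T).float(), index)

# CCE path: per-token logps without ever materializing the logit matrix.
cce = -linear_cross_entropy(e, c, index, reduction="none", impl="cce")

print((ref - cce.float()).abs().max())  # expected to be small, up to bf16/kernel tolerance
```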
Another thing I see is that, according to this person's post, there also seems to be some speed-up. Instead of materializing the logits outside of here, we could also put `linear_cross_entropy` in place of the code in `selective_softmax`, so we get the speed and memory benefits and do not materialize logits outside of the kernel.
@pluesclues That would work too. The only reason I did it this way is to ensure consistency with HF. That being said, we may, at some point, need to write a custom kernel anyway to run fused operations on the logit matrix chunk. Currently, the implementation in HF scales the logits with temperature before computing logps (https://github.com/huggingface/trl/blob/4c92de00001379ceedaf073512ce4df5da304d08/trl/trainer/grpo_trainer.py#L871).
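For illustration, a hedged sketch of how that temperature scaling could be folded into the CCE call without materializing logits: since (h @ W.T) / T == (h / T) @ W.T, dividing the hidden states by the temperature scales the implicit logits. `hidden_states`, `lm_head`, `shifted_ids`, and `temperature` are stand-ins for the corresponding values in `_get_per_token_logps`, not existing code:

```python
# Sketch only: fold HF-style temperature scaling into the CCE call.
# Dividing the hidden states by `temperature` scales the implicit logits
# (hidden_states @ lm_head.T) by 1 / temperature without materializing them.
scaled_hidden = hidden_states.to(dtype=lm_head.dtype) / temperature
logps = -1 * linear_cross_entropy(scaled_hidden, lm_head, shifted_ids, reduction="none", impl="cce")
```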
Okay, I just tested this method inside of the kernel, and it's as I suspected: we cannot use `linear_cross_entropy`, which is a torch.compile kernel in itself, inside of a torch.compile kernel. I confirmed this by running `ref = -1 * linear_cross_entropy(ref_hidden_states_j[:, :-1, :].to(dtype=lm_head.dtype), lm_head, input_ids_j, reduction="none", impl="cce")` right before `accumulate_chunk` outside the kernel, and also calling this line inside the kernel: outside the kernel it works just fine, inside it seems to break. I still haven't tested the speed-up on my machine yet, but so far it looks like we can either merge this or just change the way we calculate logprobs to exactly how CCE does it in their kernel.
About the memory-saving speed-up I reported, I believe the manner in which I profiled it does not provide an accurate account. I am only logging the peak memory allocated throughout a training step, clearing it at the beginning. This approach fails to account for the memory allocated for the old and ref policies, as they are computed and cached outside the new policy update loop, i.e., every `_step % num_iterations == 0`. I expected much higher memory savings. I would appreciate some help with this.
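For what it's worth, a minimal sketch of a profiling loop that also captures the cached old/ref policy passes, assuming a single CUDA device; `num_steps` and `run_training_step` are hypothetical placeholders, not existing trainer APIs:

```python
import torch

torch.cuda.reset_peak_memory_stats()       # reset once before training, not per step
for step in range(num_steps):              # num_steps: placeholder
    run_training_step()                    # placeholder for one full GRPO step,
                                           # including the cached old/ref logp computation
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"step {step}: peak allocated so far = {peak_gib:.2f} GiB")
```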
Moreover, I would like to confirm that `UNSLOTH_USE_NEW_MODEL` being set to `0` must be interpreted as the pathway to `UnslothEfficientGRPO`, as is the case in the current implementation.
Also, `UNSLOTH_RETURN_HIDDEN_STATES` is set to `1` before executing the forward pass in `_get_per_token_logps` but never reset to its original value, creating an unintended side-effect. This is done in a couple of places. Would it not be better to reset it?
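One possible way to avoid the side-effect, as a sketch only (the env-var name comes from the discussion above; the helper itself is hypothetical, not existing unsloth code):

```python
import os
from contextlib import contextmanager

@contextmanager
def temporary_env(name, value):
    """Set an environment variable for the duration of a block, then restore it."""
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

# Hypothetical usage inside _get_per_token_logps:
# with temporary_env("UNSLOTH_RETURN_HIDDEN_STATES", "1"):
#     hidden_states = model(...).logits
```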
Do you have the wandb plot of memory usage over time (as tracked by trl/wandb itself) for the run?
@zkpranav I'm actually quite surprised. Also, the loss plots look mostly the same - I'm assuming it's the generation dynamics. Can you set temperature to a small number for GRPO, say 0.001, top_k = 1, and seed = 3407 or something? The main issue is the losses aren't exactly matching.
@danielhanchen Sure, I'll write the tests and profile with a range of temperature and top_k values. I believe the differences in the loss value are a direct result of the differences in reward and KL divergence. I have seeded the runs, but perhaps I missed something? I could also try with […]. I am also confused as to how the reward values diverge, albeit slightly. The changes in this PR only affect what comes after the reward and advantage calculation.
Oh, for GRPO temperature = 1.0, so sampling is done - seeding won't work since vLLM and other systems are not fully deterministic. It's best to change temperature = 0.001 and top_k = 1 to at least provide the illusion of mostly non-sampling.
@danielhanchen I also did a check on the selective softmax function vs CCE, and the logits generally matched up with slight differences, which I am assuming is part of the reason the losses do not exactly match, along with other non-deterministic parts of the system.
Oh interesting
Related to: unslothai/unsloth-zoo#172
Avoids materializing the entire logit matrix for the ref, old, and new policies' log-probability calculations by using CCE with no reduction.
selective_log_softmax(e @ c.T, index) == -cce(e, c, index, reduction="none")
The default invocation of `linear_cross_entropy` applies gradient filtering, which can be turned off by setting `filter_eps` to `-inf`.
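As a sketch of what that looks like in the call used by this PR (argument names follow the snippet above; `shifted_ids` stands in for the target ids with the dummy token appended):

```python
# Sketch: per-token logps via CCE with gradient filtering disabled.
logps = -1 * linear_cross_entropy(
    hidden_states.to(dtype=lm_head.dtype), lm_head, shifted_ids,
    reduction="none", impl="cce", filter_eps=float("-inf"),
)
```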
Profiling configuration:
num_generations = 8
num_iterations = 4
batch_size = 8
unsloth_num_chunks = 4
max_prompt_length = 512
max_completion_length = 1024
vocab_size = 128256
Reduces VRAM usage by around 15% - 20%, though the memory usage should be lower still with CCE. Moreover, for larger values of `batch_size`, `max_completion_length`, and `vocab_size`, the difference will be much more pronounced.
Other changes:
- Changes `_get_per_token_logps` to accept a `batch_size` (https://github.com/huggingface/trl/blob/5206c927f6bb161e45114531b0bca8286acfeada/trl/trainer/grpo_trainer.py#L853). Removes `calc_logprob_flag`.
- Changes `compute_loss` (before calling into `UnslothEfficientGRPO`), ensuring a consistent interface with HF.