First of all, thank you for developing GradCache and making it available for the community. It's been incredibly useful for my work.
Currently, GradCache supports loss functions that do not require label information, such as SimCLR. However, I would like to use GradCache with label-dependent loss functions like the Supervised Contrastive (SupCon) loss.
The current implementation of `contrastive_loss` in the README only supports inputs without labels. Here is a sample code snippet from the README for reference:
```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

from grad_cache.functional import cached, cat_input_tensor

@cached
@autocast()
def call_model(model, input):
    return model(**input).pooler_output

@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
    target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
    scores = torch.matmul(x, y.transpose(0, 1))
    return F.cross_entropy(scores, target=target)
```
Could you provide guidance on how to incorporate label information in the `contrastive_loss` function with GradCache? Specifically, how can we adapt the current GradCache framework to support supervised loss functions like the SupCon loss?
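For what it's worth, here is one way this might work, assuming `cat_input_tensor` concatenates every tensor argument across sub-batches (as it does for `x` and `y` in the README example): the labels tensor can simply be passed to the loss function alongside the embeddings. Below is a minimal sketch of a SupCon-style loss; `supcon_loss` and its `temperature` parameter are illustrative names I made up, not part of GradCache, and this is a simplified single-view formulation rather than the full multi-view loss from the SupCon paper.

```python
import torch
import torch.nn.functional as F


def supcon_loss(features, labels, temperature=0.1):
    """SupCon-style loss over one batch of embeddings.

    features: (N, D) embeddings; labels: (N,) integer class labels.
    Anchors with no positive (no other sample sharing their label)
    are skipped when averaging.
    """
    features = F.normalize(features, dim=1)
    sim = features @ features.T / temperature

    n = features.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=features.device)
    # Exclude self-similarity from the softmax denominator.
    sim = sim.masked_fill(self_mask, float('-inf'))

    # Log-softmax of each anchor's similarities to all other samples.
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    # Positives: same label, excluding self.
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    pos_counts = pos_mask.sum(dim=1)
    valid = pos_counts > 0

    # Average negative log-probability over each anchor's positives.
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return (-pos_log_prob[valid] / pos_counts[valid]).mean()
```

If this is on the right track, it could presumably be decorated with `@cat_input_tensor` and `@autocast()` exactly like `contrastive_loss` above, with the per-chunk `labels` tensors collected and passed in the same way as the cached representations, but I'd appreciate confirmation from the maintainers.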