Support for Label-Dependent Loss Functions (e.g., Supervised Contrastive Loss) #32

@penguinwang96825

Description

First of all, thank you for developing GradCache and making it available for the community. It's been incredibly useful for my work.

Currently, GradCache supports loss functions that do not require label information, such as the InfoNCE loss used in SimCLR. However, I would like to use GradCache with label-dependent loss functions like the Supervised Contrastive (SupCon) loss.

The current implementation of contrastive_loss in the README only supports inputs without labels. Here is a sample code snippet from the README for reference:

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from grad_cache.functional import cached, cat_input_tensor

@cached
@autocast()
def call_model(model, input):
    # Encode one sub-batch and cache the pooled representations.
    return model(**input).pooler_output

@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
    # x: query embeddings (N, D); y: passage embeddings (M, D), with M a multiple of N.
    # Each query's positive passage is assumed to sit at a fixed stride in y,
    # so the targets are the indices 0, M/N, 2*M/N, ...
    target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
    scores = torch.matmul(x, y.transpose(0, 1))
    return F.cross_entropy(scores, target=target)
```

Could you provide guidance on how to incorporate label information in the contrastive_loss function with GradCache? Specifically, how can we adapt the current GradCache framework to support supervised loss functions like the SupCon loss?
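For concreteness, a label-aware loss in the style of SupCon (Khosla et al., 2020) might look like the sketch below. It is written in plain PyTorch; the function name `supcon_loss` and the temperature value are illustrative, not from GradCache. My (unverified) assumption is that, since the labels are themselves a tensor argument, decorating such a function with `@cat_input_tensor` would concatenate the per-sub-batch label tensors across chunks alongside the embeddings:

```python
import torch
import torch.nn.functional as F

def supcon_loss(features, labels, temperature=0.1):
    # features: (N, D) embeddings; labels: (N,) integer class labels.
    features = F.normalize(features, dim=1)
    logits = features @ features.T / temperature

    # Exclude self-similarity from both the positives and the denominator.
    self_mask = torch.eye(features.size(0), dtype=torch.bool, device=features.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Positives: pairs with the same label, excluding self-pairs.
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask

    # Log-softmax over each row; -inf entries contribute nothing to logsumexp.
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # Average log-probability over positives, for anchors that have at least one.
    pos_counts = pos_mask.sum(1)
    valid = pos_counts > 0
    sum_log_prob_pos = log_prob.masked_fill(~pos_mask, 0.0).sum(1)
    mean_log_prob_pos = sum_log_prob_pos[valid] / pos_counts[valid]
    return -mean_log_prob_pos.mean()
```

Under that assumption, this could replace `contrastive_loss` in the README's training loop, with the sub-batch label tensors passed to the decorated loss alongside the cached embeddings.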
