Version: 0.3

CCO input requirements

Loss function

When writing a loss function or a callable class for CCO, it must satisfy the following rules:

Examples

from typing import Dict, List, Tuple, Union

import torch.nn.functional as F
from torch import Tensor

def loss_fn(model_output: Tensor, target: Tensor) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Dict[str, Tensor]]:
    # calculate the loss (cross-entropy is only an example; use your own loss)
    loss: Tensor = F.cross_entropy(model_output, target)
    # return the loss as a single scalar tensor
    return loss

    # or return a tuple of scalar tensors:
    # return loss_1, loss_2, loss_3

    # or return a dictionary with scalar tensors as dictionary values:
    # return {"cls_loss": loss_1, "box_loss": loss_2, "obj_loss": loss_3}
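The return formats in the signature above can be checked mechanically. The sketch below defines a hypothetical helper, `is_valid_loss_output` (not part of CCO), that accepts a scalar tensor or a non-empty tuple, list, or string-keyed dict of scalar tensors, plus an example loss that returns a dictionary of named terms. The choice of cross-entropy plus a small logit-magnitude penalty is purely illustrative.

```python
from typing import Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

LossOutput = Union[Tensor, Tuple[Tensor, ...], List[Tensor], Dict[str, Tensor]]


def is_valid_loss_output(out: object) -> bool:
    """Hypothetical helper: True if `out` is a scalar tensor or a tuple/list/dict of them."""
    def is_scalar(t: object) -> bool:
        return isinstance(t, Tensor) and t.dim() == 0

    if isinstance(out, Tensor):
        return is_scalar(out)
    if isinstance(out, (tuple, list)):
        return len(out) > 0 and all(is_scalar(t) for t in out)
    if isinstance(out, dict):
        return len(out) > 0 and all(
            isinstance(k, str) and is_scalar(v) for k, v in out.items()
        )
    return False


def multi_term_loss(model_output: Tensor, target: Tensor) -> LossOutput:
    # illustrative terms: classification loss plus a small penalty on logit magnitude
    cls_loss = F.cross_entropy(model_output, target)
    logit_penalty = 1e-4 * model_output.pow(2).mean()
    return {"cls_loss": cls_loss, "logit_penalty": logit_penalty}


logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
print(is_valid_loss_output(multi_term_loss(logits, labels)))  # True
print(is_valid_loss_output(torch.ones(3)))  # False: not a scalar
```

Each dictionary value is reduced to a scalar (`cross_entropy` uses mean reduction by default), which is what keeps the dict form valid.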

Dataset / Dataloader

Dataset / Dataloader output format

When writing a dataset class for CCO, each sample it returns (from its __getitem__ method) must be a tuple of length two (order matters):

  • First item (the input data to the model) can be one of the following:
    • single Tensor
    • tuple containing Tensor objects
    • list containing Tensor objects
    • dict whose string keys match the argument names of the model's forward function and whose values are Tensor objects. For example, when the forward function has the signature def forward(self, x: Tensor, y: Tensor), the first item would be the dictionary {"x": tensor_1, "y": tensor_2}
  • Second item (the labels) can be any type

(See examples #1 and #2)
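The dict form of the first item is easiest to see next to a model that takes several named inputs. The sketch below uses a hypothetical two-input model and dataset (`TwoInputModel`, `DictInputDataset`) and assumes the dict is mapped onto forward's parameters by keyword unpacking (`model(**inputs)`); how CCO performs that mapping internally is not described here.

```python
from typing import Dict, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset


class TwoInputModel(torch.nn.Module):
    """Hypothetical model whose forward takes two named tensors."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return self.linear(x + y)


class DictInputDataset(Dataset):
    """Hypothetical dataset whose first item is a dict keyed by forward() argument names."""

    def __init__(self, n: int = 16):
        self.x = torch.randn(n, 4)
        self.y = torch.randn(n, 4)
        self.labels = torch.randint(0, 2, (n,))

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, Tensor], Tensor]:
        inputs = {"x": self.x[idx], "y": self.y[idx]}  # keys match forward(self, x, y)
        return inputs, self.labels[idx]


model = TwoInputModel()
inputs, label = DictInputDataset()[0]
# keyword unpacking maps each dict key onto the forward parameter of the same name
output = model(**inputs)
print(output.shape)  # torch.Size([2])
```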

Dataloader input for multi-GPU distributed compression

When using multiple GPUs, if you require finer control over the Dataset in the context of distributed compression, you may define the function that creates your torch.utils.data.DataLoader so that it accepts the following arguments, either by keyword or positionally:

(See example #3)

Example #1

In this example we return a tuple containing two Tensor objects:

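from typing import Tuple

from torch import Tensor
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, inputs: Tensor, labels: Tensor):
        # illustrative constructor: store pre-loaded input and label tensors
        self.inputs = inputs
        self.labels = labels

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        # first item: model input, second item: label
        return (self.inputs[idx], self.labels[idx])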
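To see what the two-item format looks like after batching, the sketch below wraps a minimal dataset of this shape (with an illustrative constructor that takes pre-loaded tensors) in a standard DataLoader. PyTorch's default collate function stacks each position of the tuple separately, so every batch keeps the (input, label) order.

```python
from typing import Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset


class CustomDataset(Dataset):
    def __init__(self, inputs: Tensor, labels: Tensor):
        self.inputs = inputs
        self.labels = labels

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        return self.inputs[idx], self.labels[idx]


# illustrative random data: 10 samples of shape (3, 8, 8) with integer labels
dataset = CustomDataset(torch.randn(10, 3, 8, 8), torch.randint(0, 4, (10,)))
loader = DataLoader(dataset, batch_size=4)

# the default collate function stacks inputs and labels independently
batch_inputs, batch_labels = next(iter(loader))
print(batch_inputs.shape)  # torch.Size([4, 3, 8, 8])
print(batch_labels.shape)  # torch.Size([4])
```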

Example #2

If you need to pass auxiliary information along with the labels, the second item of the returned tuple can be a tuple, a dictionary, or any other type that contains the labels together with that information, as shown below:

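from typing import Any, Dict, List, Tuple

from torch import Tensor
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, inputs: Tensor, labels: Tensor, sample_ids: List[str]):
        self.inputs = inputs
        self.labels = labels
        self.sample_ids = sample_ids  # illustrative auxiliary information

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tuple[Tensor, Dict[str, Any]]]:
        auxiliary_info = {"sample_id": self.sample_ids[idx]}
        # the second item bundles the labels with the auxiliary information
        return self.inputs[idx], (self.labels[idx], auxiliary_info)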

Example #3

In this example, the function that creates the DataLoader accepts the distributed-compression arguments for finer control during CCO:

from typing import Optional
from torch.utils.data import DataLoader, Dataset
from clika_compression.settings import Settings, DistributedTrainingSettings

def get_train_loader(world_size: Optional[int], global_rank: Optional[int],
                     local_rank: Optional[int], train_dataset: Dataset):
    # world_size, global_rank and local_rank are available here for finer
    # control over how each process loads its data
    return DataLoader(train_dataset, batch_size=64, shuffle=True)

settings = Settings()

# set distributed learning settings
settings.distributed_training_settings = DistributedTrainingSettings(multi_gpu=True, use_sharding=True)
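One way to actually use the rank arguments is PyTorch's DistributedSampler, which gives each process a disjoint shard of the dataset. The sketch below is an assumption about how you might fill in the loader function above, not CCO's own behavior; in particular, whether CCO already shards data when use_sharding=True is not stated here, so check that before adding your own sampler. local_rank is left unused because it typically selects the GPU on the current node rather than the data shard.

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def get_train_loader(
    world_size: Optional[int],
    global_rank: Optional[int],
    local_rank: Optional[int],
    train_dataset: Dataset,
) -> DataLoader:
    if world_size is None or global_rank is None or world_size <= 1:
        # single process: plain shuffling over the whole dataset
        return DataLoader(train_dataset, batch_size=64, shuffle=True)
    # multi-process: each global rank reads its own shard of the dataset
    sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True
    )
    # shuffle must not be passed together with a sampler; the sampler shuffles
    return DataLoader(train_dataset, batch_size=64, sampler=sampler)


dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
loader = get_train_loader(world_size=4, global_rank=1, local_rank=1, train_dataset=dataset)
print(len(loader.sampler))  # 25 samples assigned to this rank
```

Passing num_replicas and rank explicitly means the sampler can be built without an initialized process group. With a DistributedSampler, calling sampler.set_epoch(epoch) at the start of each epoch gives a different shuffle per epoch.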

Metric function / class

When writing a metric function or class for CCO, it needs to be one of the following objects (all metrics will be shown in the CCO output log):

  • A function that takes two inputs, (model_output: Tensor, target: Tensor), and returns one of the Metric Return Types (see example #1)
  • A class that implements (or inherits from) a TorchMetrics-like API with the following methods (see examples #2 and #3):
    • The update method takes two inputs, (model_output: Tensor, target: Tensor), and updates the state of the object. It may return one of the Metric Return Types; return values are shown in the output log for training steps and in the epoch summary.
    • The compute method takes no inputs, calculates the final metric, and returns it as one of the Metric Return Types. The return values are shown in the epoch summary of the output log.
    • The reset method should be implemented when not inheriting from a TorchMetrics object. It takes no inputs, resets the metric object's state attributes to their default values, and is called at the end of each epoch.
  • A stateless class inheriting from torch.nn.Module that implements a forward method which takes two inputs, (model_output: Tensor, target: Tensor), and returns one of the Metric Return Types (see examples #4 and #5)
  • A stateful class inheriting from torch.nn.Module that implements the update, compute, and reset methods in a TorchMetrics-like fashion (see example #6):
    • The update method takes two inputs, (model_output: Tensor, target: Tensor), and updates the state of the object. It may return one of the Metric Return Types; return values are shown in the output log for training steps and in the epoch summary.
    • The compute method takes no inputs, calculates the final metric, and returns it as one of the Metric Return Types; return values are shown in the epoch summary of the output log.
    • The reset method resets the metric object's state attributes to their default values. It is called at the end of each epoch and takes no inputs.
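The difference between the stateless and stateful kinds above is mostly in how they are called over an epoch. The sketch below is a hypothetical driver (`run_metric_over_epoch`, not CCO's actual training loop) that assumes the order the text implies: stateless metrics are called once per batch, while stateful ones get update per batch, then compute, then reset at the end of the epoch. The `CorrectCount` class is an illustrative stateful metric.

```python
from typing import Any, Iterable, Tuple

import torch
from torch import Tensor


def run_metric_over_epoch(metric: Any, batches: Iterable[Tuple[Tensor, Tensor]]) -> Any:
    """Hypothetical driver showing the call pattern for each metric kind."""
    is_stateful = all(hasattr(metric, name) for name in ("update", "compute"))
    if not is_stateful:
        # plain function or stateless nn.Module: one call per batch
        return [metric(model_output, target) for model_output, target in batches]
    for model_output, target in batches:
        metric.update(model_output, target)  # per-batch state update
    result = metric.compute()  # epoch-level value
    metric.reset()  # state cleared at the end of the epoch
    return result


class CorrectCount:
    """Illustrative stateful metric: counts correct predictions."""

    def __init__(self):
        self.correct = 0

    def update(self, model_output: Tensor, target: Tensor) -> None:
        self.correct += int((model_output.argmax(1) == target).sum())

    def compute(self) -> int:
        return self.correct

    def reset(self) -> None:
        self.correct = 0


batches = [
    (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])),
    (torch.tensor([[0.3, 0.7]]), torch.tensor([0])),
]
stateless = lambda out, tgt: (out.argmax(1) == tgt).float().mean()
print(run_metric_over_epoch(stateless, batches))  # per-batch values 1.0 and 0.0
counter = CorrectCount()
print(run_metric_over_epoch(counter, batches))  # 2
print(counter.correct)  # 0 after reset
```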

Metric return types

These are types that are acceptable for the output of a metric method.

Example #1

A lambda function that takes two arguments and returns a scalar tensor:

import torch

# fraction of samples whose predicted class (argmax over dim 1) equals the target
metric_fn = lambda model_output, target: torch.mean(
    (torch.argmax(model_output, 1) == target).float())
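A named function works the same way and can report several values at once. The sketch below assumes that a dictionary with string keys and scalar tensor values is an accepted Metric Return Type (the form used in example #5); the function name and the choice of top-1 and top-2 accuracy are illustrative, and k=2 requires at least two classes.

```python
from typing import Dict

import torch
from torch import Tensor


def topk_accuracy(model_output: Tensor, target: Tensor) -> Dict[str, Tensor]:
    # indices of the 2 highest-scoring classes per sample, shape (N, 2)
    top2 = torch.topk(model_output, k=2, dim=1).indices
    # top-1: the best class equals the target
    top1_correct = top2[:, 0] == target
    # top-2: the target appears anywhere among the two best classes
    top2_correct = (top2 == target.unsqueeze(1)).any(dim=1)
    return {
        "top1_acc": top1_correct.float().mean(),
        "top2_acc": top2_correct.float().mean(),
    }


logits = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
target = torch.tensor([2, 0])
result = topk_accuracy(logits, target)
print(result)  # top1_acc = 0.5, top2_acc = 1.0
```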

Example #2

Inheriting from torchmetrics.classification.MulticlassAccuracy

import torchmetrics
from torch import Tensor

num_classes = 10  # illustrative value: set to the number of classes in your task

class Accuracy(torchmetrics.classification.MulticlassAccuracy):
    def update(self, model_output: Tensor, target: Tensor) -> None:
        super().update(model_output, target)

    def compute(self) -> Tensor:
        result = super().compute()
        # report accuracy as a percentage
        return result * 100

metric_fn = Accuracy(num_classes=num_classes)

Example #3

A custom class implementing a TorchMetrics-like API

import torch
from torch import Tensor

class Accuracy:

    def __init__(self):
        self.avgs = []

    def update(self, model_output: Tensor, target: Tensor) -> None:
        batch_mean = torch.mean((torch.argmax(model_output, 1) == target).float())
        self.avgs.append(batch_mean.item())

    def compute(self):
        # mean of the per-batch accuracies collected since the last reset
        return sum(self.avgs) / len(self.avgs)

    def reset(self):
        # resetting the object state attributes to default values
        self.avgs = []

metric_fn = Accuracy()
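Averaging per-batch means weights every batch equally, so the result drifts from the true accuracy when batch sizes differ (for example, a smaller final batch). The sketch below is a variant that accumulates raw counts instead; the class name `CountingAccuracy` is hypothetical, and returning 0.0 for an empty epoch is an arbitrary choice.

```python
import torch
from torch import Tensor


class CountingAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, model_output: Tensor, target: Tensor) -> None:
        # accumulate raw counts so every sample carries equal weight
        self.correct += int((torch.argmax(model_output, 1) == target).sum())
        self.total += target.numel()

    def compute(self) -> float:
        # guard against an epoch with no batches
        return self.correct / self.total if self.total else 0.0

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


metric_fn = CountingAccuracy()
# a batch of 3 with 3 correct, then a batch of 1 with 0 correct
metric_fn.update(torch.eye(3), torch.tensor([0, 1, 2]))
metric_fn.update(torch.tensor([[1.0, 0.0]]), torch.tensor([1]))
print(metric_fn.compute())  # 0.75 (a mean of batch means would give 0.5)
```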

Example #4

Using a torch.nn.Module loss class directly as a stateless metric (it already satisfies the input and output constraints)

import torch

metric_fn = torch.nn.NLLLoss()
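Note that torch.nn.NLLLoss expects log-probabilities as its first input. If the model produces raw logits, one option is a small stateless wrapper that applies log_softmax first, which makes it equivalent to cross-entropy on the logits. The wrapper name `LogitsNLL` is hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


class LogitsNLL(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.nll = torch.nn.NLLLoss()

    def forward(self, model_output: Tensor, target: Tensor) -> Tensor:
        # convert raw logits to log-probabilities before applying NLLLoss
        return self.nll(F.log_softmax(model_output, dim=1), target)


metric_fn = LogitsNLL()
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 0])
value = metric_fn(logits, target)
# matches cross-entropy computed on the same logits
print(torch.allclose(value, F.cross_entropy(logits, target)))  # True
```

Using torch.nn.CrossEntropyLoss directly would give the same value without the wrapper.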

Example #5

A stateless class inheriting from torch.nn.Module

from typing import Dict
import torch
from torch import Tensor

class Accuracy(torch.nn.Module):

    def forward(self, model_output: Tensor, target: Tensor) -> Dict[str, Tensor]:
        # calculate metric values
        correct = torch.argmax(model_output, 1) == target
        num_correct_predictions = torch.sum(correct)
        accuracy = torch.mean(correct.float())
        # return a dict with strings as keys and scalar tensors as values
        res = {"num_correct_predictions": num_correct_predictions, "accuracy": accuracy}
        return res

metric_fn = Accuracy()

Example #6

A stateful class inheriting from torch.nn.Module

from typing import Union, Tuple, List, Dict
import torch

class Accuracy(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # buffers follow the module across devices and are saved in its state_dict
        self.register_buffer("correct", torch.tensor(0.0))
        self.register_buffer("total", torch.tensor(0.0))

    def compute(self) -> Union[torch.Tensor, Tuple, List, Dict]:
        return self.correct / self.total

    def reset(self) -> None:
        # reset the state buffers in place to their default values
        self.correct.zero_()
        self.total.zero_()

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> Union[torch.Tensor, Tuple, List, Dict, None]:
        # assuming preds has shape (N, C) and targets has shape (N,)
        correct = (preds.argmax(-1) == targets).sum().item()
        total = preds.shape[0]
        self.correct += correct
        self.total += total
        return {"batch_acc": correct / total}

metric_fn = Accuracy()