# **CCO** input requirements

## Loss function

When writing a loss function or a callable class for **CCO**, it must satisfy the following rules:

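- Takes two input arguments (order matters): the model output, then the target.
- Returns one of the following: a scalar `Tensor`, a `tuple` or `list` of scalar `Tensor` objects, or a `dict` with `str` keys and scalar `Tensor` values.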

### Examples

```python
from typing import Dict, List, Tuple, Union

from torch import Tensor


def loss_fn(
    model_output: Tensor, target: Tensor
) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Dict[str, Tensor]]:
    # calculate loss...
    loss: Tensor = some_scalar_tensor

    # returns a loss as a tensor
    return loss

    # or a tuple of scalar tensors
    return loss_1, loss_2, loss_3

    # or a dictionary with scalar tensors as dictionary values
    return {"cls_loss": loss_1, "box_loss": loss_2, "obj_loss": loss_3}
```
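As a concrete illustration of the dictionary form, here is a minimal runnable sketch. The loss terms, the key names `ce_loss` and `logit_penalty`, and the tensor shapes are assumptions chosen for the example; how **CCO** combines multiple returned values is not specified here.

```python
from typing import Dict

import torch
import torch.nn.functional as F
from torch import Tensor


def loss_fn(model_output: Tensor, target: Tensor) -> Dict[str, Tensor]:
    # Assumed shapes: model_output is (N, C) raw logits, target is (N,) class indices.
    ce_loss = F.cross_entropy(model_output, target)
    # Hypothetical auxiliary term: a small penalty on logit magnitude.
    logit_penalty = 1e-4 * model_output.pow(2).mean()
    # Every value is a scalar (0-dim) tensor keyed by a string name.
    return {"ce_loss": ce_loss, "logit_penalty": logit_penalty}


logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
losses = loss_fn(logits, labels)
```

The argument order matches the rule above: model output first, target second.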

## Dataset / Dataloader

### Dataset / Dataloader output format

A dataset class written for **CCO** must return a tuple of length two (order matters):

- First item (the input data to the model) can be one of the following:
  - single `Tensor`
  - `tuple` containing `Tensor` objects
  - `list` containing `Tensor` objects
  - `dict` with matching forward function argument names as `string` keys and `Tensor` values. For example, when a forward function has the signature `def forward(self, x: Tensor, y: Tensor)`, the first item would be the dictionary `{"x": tensor_1, "y": tensor_2}`
- Second item (the labels) can be any type
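The dictionary form can be sketched as follows. `TwoInputDataset` and `TwoInputModel` are hypothetical names, and the random data is only a stand-in; the point is that the dictionary keys match the `forward` argument names, so the inputs can be unpacked as keyword arguments.

```python
from typing import Dict, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset


class TwoInputDataset(Dataset):
    """Hypothetical dataset whose first item is a dict keyed by forward() argument names."""

    def __init__(self, num_samples: int = 4) -> None:
        self.x = torch.randn(num_samples, 3)
        self.y = torch.randn(num_samples, 2)
        self.labels = torch.randint(0, 2, (num_samples,))

    def __len__(self) -> int:
        return self.x.shape[0]

    def __getitem__(self, index: int) -> Tuple[Dict[str, Tensor], Tensor]:
        # Keys "x" and "y" mirror the parameter names of TwoInputModel.forward.
        inputs = {"x": self.x[index], "y": self.y[index]}
        return inputs, self.labels[index]


class TwoInputModel(torch.nn.Module):
    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Concatenate the two inputs along the last dimension.
        return torch.cat([x, y], dim=-1)


dataset = TwoInputDataset()
inputs, label = dataset[0]
output = TwoInputModel()(**inputs)
```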

### Dataloader input for multi-GPU distributed compression

When using multi-GPU, if you require finer control over the `Dataset` in the context of distributed compression, you may set the `torch.utils.data.DataLoader` initializer function so that it takes the following arguments, either by keyword or as positional arguments (see example #3):

- `world_size: Optional[int]`
- `global_rank: Optional[int]`
- `local_rank: Optional[int]`

### Example #1

In this example we return a tuple containing two `Tensor` objects:

```python
from typing import Tuple

from torch import Tensor
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    ...

    def __getitem__(self, x) -> Tuple[Tensor, Tensor]:
        ...
        return (input_tensor, label_tensor)
```

### Example #2

If you require auxiliary information to be included, you may set the second item in the returned tuple to be a dictionary or a tuple (or any other type) containing any auxiliary information necessary as shown below:

```python
from typing import Dict, Tuple

from torch import Tensor
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    ...

    def __getitem__(self, x) -> Tuple[Tensor, Tuple[Tensor, Dict]]:
        ...
        return input_tensor, (labels, auxiliary_info)

    ...
```

### Example #3

In this example, the dataloader initializer function can take distributed-compression-related arguments for finer control during **CCO**:

```python
from typing import Optional

from torch.utils.data import DataLoader, Dataset

from clika_compression.settings import Settings, DistributedTrainingSettings


def get_train_loader(world_size: Optional[int], global_rank: Optional[int],
                     local_rank: Optional[int], train_dataset: Dataset):
    # use the args passed
    return DataLoader(train_dataset, batch_size=64, shuffle=True)


settings = Settings()

# set distributed learning settings
settings.distributed_training_settings = DistributedTrainingSettings(multi_gpu=True, use_sharding=True)
```
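One possible way to "use the args passed" is to shard the dataset per process with PyTorch's `DistributedSampler`. This is a sketch under assumptions, not a documented **CCO** recipe: it assumes the arguments are `None` when running on a single device, and it does not account for whatever sharding **CCO** may already apply when `use_sharding=True`, so check that before combining the two.

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def get_train_loader(world_size: Optional[int], global_rank: Optional[int],
                     local_rank: Optional[int], train_dataset: Dataset) -> DataLoader:
    # Assumption: the distributed arguments are None in a single-device run.
    if world_size is None or global_rank is None:
        return DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Each process sees a disjoint 1/world_size slice of the dataset.
    sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True
    )
    # shuffle must stay unset on the DataLoader when a sampler is supplied.
    return DataLoader(train_dataset, batch_size=64, sampler=sampler)


# Toy dataset with 10 samples, used only to exercise the function.
toy_dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.zeros(10))
sharded_loader = get_train_loader(2, 0, 0, toy_dataset)
single_loader = get_train_loader(None, None, None, toy_dataset)
```

Passing `num_replicas` and `rank` explicitly means the sampler does not need an initialized process group, which keeps the function easy to test in isolation.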

## Metric function / class

When writing a metric function or class for **CCO**, it needs to be one of the following objects (all metrics will be shown in the CCO output log):

- A function that takes two inputs, `(model_output: Tensor, target: Tensor)`, and returns one of the Metric Return Types (see example #1)
- A class that implements (or inherits from) a `TorchMetrics`-like API with the following methods (see examples #2 and #3)
  - The `update` method takes two inputs, `(model_output: Tensor, target: Tensor)`, and updates the state of the object. The `update` method may return one of the Metric Return Types (return values will be shown on the output log training steps and epoch summary).
  - The `compute` method takes no inputs, calculates the final metric, and returns it as one of the Metric Return Types. The return values will be shown on the output log epoch summary.
  - The `reset` method should be implemented when not inheriting from a `TorchMetrics` object. The `reset` method takes no inputs, resets the metric object's state attributes to their default values, and is called at the end of each epoch.
- A **stateless** class inheriting from `torch.nn.Module` that implements a `forward` method which takes two inputs, `(model_output: Tensor, target: Tensor)`, and returns one of the Metric Return Types (see examples #4 and #5)
- A **stateful** class inheriting from `torch.nn.Module` that implements the `update`, `compute` and `reset` methods in a `TorchMetrics`-like API fashion (see example #6)
  - The `update` method takes two inputs, `(model_output: Tensor, target: Tensor)`, and updates the state of the object. It may return one of the Metric Return Types (return values will be shown on the output log training steps and epoch summary).
  - The `compute` method takes no inputs, calculates the final metric, and returns it as one of the Metric Return Types (return values will be shown on the output log epoch summary).
  - The `reset` method resets the metric object's state attributes to their default values. It is called at the end of each epoch and takes no inputs.

### Metric return types

These are types that are acceptable for the output of a metric method.

- a single `int`/`float` or a scalar `Tensor`
- `tuple` or `list` containing `int`/`float` or a scalar `Tensor`
- `dict` with `str` as keys (metric name) and `int`/`float` or a scalar `Tensor` as values
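The list above can be turned into a quick self-check while developing a metric. The helper names `is_scalar` and `is_metric_return_type` are hypothetical and not part of **CCO**; the sketch also assumes that "scalar `Tensor`" means a tensor holding exactly one element.

```python
from typing import Any

import torch
from torch import Tensor


def is_scalar(value: Any) -> bool:
    # Assumption: a "scalar Tensor" is any tensor holding exactly one element.
    if isinstance(value, Tensor):
        return value.numel() == 1
    return isinstance(value, (int, float))


def is_metric_return_type(value: Any) -> bool:
    # Case 1: a single int/float or scalar Tensor.
    if is_scalar(value):
        return True
    # Case 2: a tuple or list of scalars.
    if isinstance(value, (tuple, list)):
        return all(is_scalar(item) for item in value)
    # Case 3: a dict mapping metric names (str) to scalars.
    if isinstance(value, dict):
        return all(isinstance(k, str) and is_scalar(v) for k, v in value.items())
    return False


checks = [
    is_metric_return_type(0.5),
    is_metric_return_type(torch.tensor(1.0)),
    is_metric_return_type([1, torch.tensor(2.0)]),
    is_metric_return_type({"acc": 0.9}),
    is_metric_return_type(torch.zeros(3)),
    is_metric_return_type({1: 0.5}),
]
```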

### Example #1

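A lambda function example that takes two arguments and returns a scalar tensor:

```python
import torch

# cast the boolean comparison to float, since torch.mean requires a floating-point tensor
metric_fn = lambda model_output, target: torch.mean(
    (torch.argmax(model_output, 1) == target).float())
```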

### Example #2

Inheriting from `torchmetrics.classification.MulticlassAccuracy`

```python
import torchmetrics
from torch import Tensor


class Accuracy(torchmetrics.classification.MulticlassAccuracy):
    ...

    def update(self, model_output: Tensor, target: Tensor) -> None:
        super().update(model_output, target)

    def compute(self):
        result = super().compute()
        return result * 100


metric_fn = Accuracy(num_classes=num_classes)
```

### Example #3

A custom class implementing a TorchMetrics-like API:

```python
import torch
from torch import Tensor


class Accuracy(object):
    def __init__(self):
        self.avgs = []

    def update(self, model_output: Tensor, target: Tensor) -> None:
        batch_mean = torch.mean((torch.argmax(model_output, 1) == target).float())
        self.avgs.append(batch_mean.item())

    def compute(self):
        return sum(self.avgs) / len(self.avgs)

    def reset(self):
        # resetting the object state attributes to default values
        self.avgs = []


metric_fn = Accuracy()
```

### Example #4

Using a `torch.nn.Module` class directly as a metric in a stateless fashion (it already satisfies the input and output constraints):

```python
import torch

metric_fn = torch.nn.NLLLoss()
```

### Example #5

A stateless class inheriting from `torch.nn.Module`:

```python
from typing import Dict

import torch
from torch import Tensor


class Accuracy(torch.nn.Module):
    def forward(self, model_output: Tensor, target: Tensor) -> Dict[str, Tensor]:
        # calculate metric values
        num_correct_predictions = torch.sum(torch.argmax(model_output, 1) == target)
        accuracy = torch.mean((torch.argmax(model_output, 1) == target).float())

        # return a dict with strings as keys and scalar tensors as values
        res = {"num_correct_predictions": num_correct_predictions, "accuracy": accuracy}
        return res


metric_fn = Accuracy()
```

### Example #6

A stateful class inheriting from `torch.nn.Module`

```python
from typing import Union, Tuple, List, Dict

import torch


class Accuracy(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # buffers hold the running state and move with the module across devices
        self.register_buffer("correct", torch.zeros(1))
        self.register_buffer("total", torch.zeros(1))

    def compute(self) -> Union[torch.Tensor, Tuple, List, Dict]:
        return self.correct / self.total

    def reset(self) -> None:
        # zero the state in place
        self.correct.zero_()
        self.total.zero_()

    def update(self, preds, targets) -> Union[torch.Tensor, Tuple, List, Dict, None]:
        # assuming preds is (N, C) and targets is (N,)
        correct = (preds.argmax(-1) == targets).sum().item()
        total = preds.shape[0]
        self.correct += correct
        self.total += total
        return {"batch_acc": correct / total}


metric_fn = Accuracy()
```