Version: 0.1

CCO Input Requirements

Loss Function

When writing a loss function or a callable class for CCO, it must satisfy the following rules:

  • It takes two inputs, (model_output: Tensor, target: Tensor), in that order.
  • It returns either a scalar Tensor, a tuple (or list) of scalar Tensors, or a dictionary with strings as keys and scalar Tensors as values.

Examples

from torch import Tensor
from typing import Union, Tuple, List, Dict

def loss_fn(model_output: Tensor, target: Tensor) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Dict[str, Tensor]]:

    # calculate loss...
    loss: Tensor = some_scalar_tensor

    # return a loss as a scalar tensor
    return loss

    # or a tuple of scalar tensors
    return loss_1, loss_2, loss_3

    # or a dictionary with strings as keys and scalar tensors as values
    return {"cls_loss": loss_1, "box_loss": loss_2, "obj_loss": loss_3}

Dataset / Dataloader

When writing a dataset class for CCO, its __getitem__ must return a tuple of length two (order matters):

  • First item: the input data to the model (e.g. a Tensor, as in the examples below).
  • Second item: the labels; this can be any type.

Examples

In this example, we return a tuple containing two Tensor objects:

import torch
from torch import Tensor
from typing import Tuple

class CustomDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, x) -> Tuple[Tensor, Tensor]:
        ...
        return (input_tensor, label_tensor)

If you need to include auxiliary information, you may set the second item of the returned tuple to a dictionary or a tuple (or any other type) containing whatever auxiliary information is necessary, as shown below:

from typing import Tuple, Dict, Any 
import torch
from torch import Tensor


class CustomDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, x) -> Tuple[Tensor, Tuple[Any, ...]]:
        ...
        return input_tensor, (label_tensor, img_size, img_info)

    # ============== OR ============== #

    def __getitem__(self, x) -> Tuple[Tensor, Dict[str, Any]]:
        ...
        return input_tensor, {"label": label_tensor, "size": img_size, "info": img_info}
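Once the dataset follows this two-item contract, it can be wrapped in a standard torch.utils.data.DataLoader. The sketch below is illustrative only: the batch size and worker count are arbitrary, and CustomDataset stands for a fully implemented dataset like the ones above.

import torch

dataset = CustomDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,    # illustrative value
    shuffle=True,
    num_workers=4,    # illustrative value
)

# Each batch keeps the same two-item structure: (input_batch, labels_batch).
# If the auxiliary items are not tensors or numbers, a custom collate_fn may be needed.
input_batch, labels_batch = next(iter(dataloader))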

Metric Function / Class

When writing a metric function or class for CCO, it needs to be one of the following objects
(all metrics will be shown in the CCO output log):

  • A function (or lambda) taking two inputs, (model_output: Tensor, target: Tensor), and returning one of the Metric Return Types.

    See example #1 below

  • A class implementing (or inheriting) a TorchMetrics-like API with the following methods (a sketch of when these methods are called appears after this list):

    • update takes two inputs, (model_output: Tensor, target: Tensor), and updates the state of the object.
      It may return one of the Metric Return Types (return values will be shown in the output log training steps and epoch summary)

    • compute takes no inputs, calculates the final metric, and returns it as one of the Metric Return Types (return values will be shown in the output log epoch summary)

    • reset - when not inheriting from a TorchMetrics object, this method should be implemented. The method resets the metric object's state attributes to their default values. It is called at the end of each epoch and takes no inputs.

    See examples #2 and #3 below

  • A stateless class inheriting from torch.nn.Module (or an existing torch.nn loss module) whose forward method takes two inputs, (model_output: Tensor, target: Tensor), and returns one of the Metric Return Types.

    See examples #4 and #5 below

  • A stateful class inheriting from torch.nn.Module that implements update, compute and reset methods in a TorchMetrics-like API fashion:
    • update takes two inputs, (model_output: Tensor, target: Tensor), and updates the state of the object.
      It may return one of the Metric Return Types (return values will be shown in the output log training steps and epoch summary)
    • compute takes no inputs, calculates the final metric, and returns it as one of the Metric Return Types (return values will be shown in the output log epoch summary)
    • reset resets the metric object's state attributes to their default values. It is called at the end of each epoch and takes no inputs.

    See example #6 below
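For orientation, here is a rough sketch of how such a stateful metric's methods are typically exercised during training: update runs every step, while compute and reset run once per epoch. It illustrates the method contract only and is not CCO's actual training loop; num_epochs, batches and metric_fn are placeholders.

# Illustrative only - not CCO internals.
# `num_epochs`, `batches` and `metric_fn` are placeholders.
for epoch in range(num_epochs):
    for model_output, target in batches:
        step_values = metric_fn.update(model_output, target)  # may return per-step values
    epoch_values = metric_fn.compute()                         # epoch-level metric value(s)
    metric_fn.reset()                                          # clear state before the next epoch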


Metric Return Types

These are the standard types for metric method output: a scalar Tensor, a tuple or list of scalar Tensors, or a dictionary with strings as keys and scalar Tensors as values. The update method may also return None.
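For readers who prefer type hints, the same set of return types can be written as an alias; the name MetricReturnType below is illustrative only and is not exported by CCO:

from typing import Union, Tuple, List, Dict
from torch import Tensor

# Illustrative alias only - CCO does not export this name.
MetricReturnType = Union[Tensor, Tuple[Tensor, ...], List[Tensor], Dict[str, Tensor]]
# `update` may additionally return None.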

Examples

import torch
import torchmetrics
from torch import Tensor
from typing import Dict

# Example 1 - lambda function example that takes two arguments and returns a tensor scalar

metric_fn = lambda model_output, target: torch.mean((torch.argmax(model_output, 1) == target).float()) * 100

# ============================ #

# Example 2 - inheriting from `torchmetrics.classification.MulticlassAccuracy`

class Accuracy(torchmetrics.classification.MulticlassAccuracy):
    ...

    def update(self, model_output: Tensor, target: Tensor) -> None:
        super().update(model_output, target)

    def compute(self):
        result = super().compute()
        return result * 100

metric_fn = Accuracy(num_classes=num_classes)


# ============================ #

# Example 3 - a custom class implementing TorchMetrics-like API

class Accuracy(object):

    def __init__(self):
        self.avgs = []

    def update(self, model_output: Tensor, target: Tensor) -> None:
        batch_mean = torch.mean((torch.argmax(model_output, 1) == target).float())
        self.avgs.append(batch_mean.item())

    def compute(self):
        return sum(self.avgs) / len(self.avgs)

    def reset(self):
        # resetting the object's state attributes to their default values
        self.avgs = []

metric_fn = Accuracy()
# ============================ #

# Example 4 - using a torch.nn.Module class directly as a metric (it already satisfies the input and output constraints) in a stateless fashion

metric_fn = torch.nn.NLLLoss()

# ============================ #

# Example 5 - a stateless class inheriting from torch.nn.Module

class Accuracy(torch.nn.Module):

    def forward(self, model_output: Tensor, target: Tensor) -> Dict[str, Tensor]:
        # calculate metric values
        num_correct_predictions = torch.sum(torch.argmax(model_output, 1) == target)
        accuracy = torch.mean((torch.argmax(model_output, 1) == target).float())
        # return a dict with strings as keys and scalar tensors as values
        res = {"num_correct_predictions": num_correct_predictions, "accuracy": accuracy}
        return res

metric_fn = Accuracy()

# ============================ #

# Example 6 - a stateful class inheriting from torch.nn.Module

from typing import Union, Tuple, List, Dict

class Accuracy(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("correct", torch.zeros(1))
        self.register_buffer("total", torch.zeros(1))

    def compute(self) -> Union[Tensor, Tuple, List, Dict]:
        return self.correct / self.total

    def reset(self) -> None:
        self.correct.zero_()
        self.total.zero_()

    def update(self, preds, targets) -> Union[Tensor, Tuple, List, Dict, None]:
        # assuming preds is (N, C) and targets is (N,)
        correct = (preds.argmax(-1) == targets).sum().item()
        total = preds.shape[0]
        self.correct += correct
        self.total += total
        return {"batch_acc": correct / total}

metric_fn = Accuracy()