
CCO Model Requirements

To be used as input for CCO, a PyTorch model must meet the following requirement:


The input model must be traceable, meaning its forward pass must not contain dynamic Python control flow, such as an "if" statement whose condition depends on the values of an input tensor. To check whether a model is traceable, use the function torch.fx.symbolic_trace().

For more information, see the PyTorch documentation.
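As a minimal sketch of that check, the hypothetical helper is_traceable() below calls torch.fx.symbolic_trace() and reports whether it raised a TraceError, the exception torch.fx raises when a traced value is used in control flow. Other tracing failures (for example, unsupported built-ins) may raise different exception types and are deliberately left to propagate. The two toy models are also hypothetical: they show that an "if" on a constructor flag traces fine, while an "if" on the input's values does not.

```python
import torch
import torch.fx
from torch import nn


def is_traceable(model: nn.Module) -> bool:
    """Hypothetical helper: True if torch.fx can symbolically trace the model."""
    try:
        torch.fx.symbolic_trace(model)
    except torch.fx.proxy.TraceError:
        # Raised when a traced (symbolic) value is used as a control-flow condition
        return False
    return True


class StaticBranch(nn.Module):
    # The "if" depends on a constructor flag, not on the input values
    def __init__(self, use_relu: bool = True):
        super().__init__()
        self.use_relu = use_relu
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear(x)
        if self.use_relu:  # plain Python bool, resolved at trace time
            x = torch.relu(x)
        return x


class DataDependentBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # The "if" depends on the values in x, which tracing cannot resolve
        if x.sum() > 0:
            return self.linear(x)
        return x


print(is_traceable(StaticBranch()))         # True
print(is_traceable(DataDependentBranch()))  # False
```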

If you are using the Hugging Face Transformers library, use transformers.utils.fx.symbolic_trace() instead to check the traceability of your model.

For more information, see the Transformers documentation.
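A minimal sketch of the Transformers check, assuming the transformers package is installed: it builds a tiny, randomly initialized BertModel (so no weights are downloaded) and traces it with transformers.utils.fx.symbolic_trace(). The configuration sizes are arbitrary, and forcing eager attention is an assumption aimed at newer transformers releases; the set of supported architectures and accepted arguments varies between versions, so check the documentation for the version you use.

```python
import torch
import torch.fx
from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace

# Tiny, randomly initialized BERT: the sizes are arbitrary illustration values
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    attn_implementation="eager",  # assumption: avoid fused attention paths
)
model = BertModel(config)
model.eval()

# input_names selects which forward() arguments become inputs of the traced graph
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
print(isinstance(traced, torch.fx.GraphModule))
```

If the model is not traceable, symbolic_trace() raises an exception instead of returning a GraphModule.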


The following example refactors a simple non-traceable model into a traceable one:

```python
import torch
import torch.fx
from torch import nn


class NonTraceableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # The "if" condition depends on the values in x, which symbolic
        # tracing cannot evaluate, so this model is NOT traceable
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros(10)


non_traceable_model = NonTraceableModel()
try:
    torch.fx.symbolic_trace(non_traceable_model)
except torch.fx.proxy.TraceError as err:
    # Tracing fails, so this model cannot be used for CCO
    print(f"NonTraceableModel is not traceable: {err}")


class TraceableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # Express the condition as a tensor operation instead of an "if" statement
        condition = x.sum() > 0
        linear_output = self.linear(x)
        # Multiplying by the boolean tensor zeroes the output when the condition is False
        linear_output = linear_output * condition
        return linear_output


traceable_model = TraceableModel()
# Tracing succeeds, so this model can be used for CCO
traced_model = torch.fx.symbolic_trace(traceable_model)
```
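After refactoring, it is worth confirming that the traced module behaves like the eager one. The sketch below repeats the traceable model in condensed form, traces it, and compares outputs for an input whose sum is positive (condition true) and one whose sum is negative (condition false, so the output should be all zeros). The input shapes are illustrative assumptions.

```python
import torch
import torch.fx
from torch import nn


class TraceableModel(nn.Module):
    # Condensed version of the traceable model from the example
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        condition = x.sum() > 0
        return self.linear(x) * condition


model = TraceableModel()
traced = torch.fx.symbolic_trace(model)

x_pos = torch.ones(2, 10)   # positive sum: condition is True
x_neg = -torch.ones(2, 10)  # negative sum: condition is False

with torch.no_grad():
    out_pos = traced(x_pos)
    out_neg = traced(x_neg)
    eager_pos = model(x_pos)

# The traced GraphModule should reproduce the eager output
print(torch.allclose(out_pos, eager_pos))   # True
# With the condition False, every element is multiplied by zero
print(torch.count_nonzero(out_neg).item())  # 0
```

Note that the refactored model is not an exact drop-in for NonTraceableModel: when the condition is false, it returns zeros shaped like the linear output rather than torch.zeros(10) (the two coincide for a single unbatched input of shape (10,)). torch.where(condition, linear_output, torch.zeros_like(linear_output)) is an equally traceable way to express the same selection.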