Version: 25.4.0

Model requirements and limitations

To use a torch.nn.Module as input for ACE, the model must be traceable by either torch.dynamo or torch.fx.symbolic_trace. The CLIKA SDK first attempts to trace the entire model using torch.dynamo; if that trace fails, the SDK falls back to torch.fx.symbolic_trace.

Traceability

The input model must be traceable, meaning it must pass either of the following:

  1. compiled_model = torch.compile(model, backend="inductor", fullgraph=True, dynamic=True); compiled_model(dummy_input).

  2. torch.fx.symbolic_trace(model). Or, for Hugging Face models: transformers.utils.fx.symbolic_trace(model)

The conditions above typically fail if the model contains data-dependent control flow operations.

In general, we believe torch.dynamo to be the future of tracing; we therefore recommend it over torch.fx.symbolic_trace or torch.jit.script (slated for deprecation: https://github.com/pytorch/pytorch/issues/103841).
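Before handing a model to ACE, you can run the two checks above yourself. The helper below is an illustrative sketch (not part of the SDK) that mirrors the fallback order described above; it uses `backend="eager"` instead of `"inductor"` only so the sketch runs without a compiler toolchain:

```python
import torch
from torch import nn


def check_traceability(model: nn.Module, dummy_input: torch.Tensor) -> str:
    """Try torch.dynamo first, then fall back to torch.fx.
    Illustrative helper; names and return values are our own."""
    try:
        # the SDK description above uses backend="inductor"; "eager" keeps
        # this sketch runnable without a compiler toolchain
        compiled = torch.compile(model, backend="eager", fullgraph=True, dynamic=True)
        compiled(dummy_input)  # tracing happens on the first call
        return "dynamo"
    except Exception:
        pass
    try:
        torch.fx.symbolic_trace(model)
        return "fx"
    except Exception:
        return "untraceable"
```

For a plain `nn.Linear(10, 10)` with input `torch.rand(1, 10)`, the dynamo path succeeds.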

note

In future releases, the CLIKA SDK will support torch.compile(model, backend="inductor", **fullgraph=False**, dynamic=True). This means that Python control-flow operations in the model will be supported without modification.

Data-dependent control flow operations

Control flow operators (if, else, while, ...) can be used but not when they are data-dependent. Data-dependent control flow refers to situations where the execution path of a program depends on the values of tensors. Representing and capturing the divergence of computational paths is a hard problem.

In general, there is no problem capturing control flow operators when they help guide the tracing of the model.

Example - data-dependent if statement

In this example, we'll refactor a simple non-traceable model into a traceable model:

import torch
from torch import nn


class NonTraceableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # use "if" statement which will NOT be traceable
        if x.sum() > 0:
            return self.linear(x)
        else:
            return torch.zeros(10)


# Both options will crash and therefore cannot be used by ACE

non_traceable_model = NonTraceableModel()

###### Option 1:
torch.fx.symbolic_trace(non_traceable_model)
# Throws:
# torch.fx.proxy.TraceError:
# symbolically traced variables cannot be used as inputs to control flow

###### Option 2:
model = torch.compile(non_traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
# Throws:
# torch._dynamo.exc.UserError:
# Dynamic control flow is not supported at the moment.
# Please use functorch.experimental.control_flow.cond to explicitly capture the control flow.
# For more information about this error, see:
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

Now we will refactor it into a traceable model:

class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # use tensor operations instead of an "if" statement
        condition = x.sum() > 0
        linear_output = self.linear(x)
        linear_output = linear_output * condition
        return linear_output

# Both options will run successfully and can therefore be used by ACE

###### Option 1:
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)

###### Option 2:
model = torch.compile(traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
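The refactor preserves the original behavior: multiplying by the boolean condition passes the linear output through unchanged when the condition holds, and zeroes it otherwise. A quick sanity check (restating the model above so the snippet is self-contained):

```python
import torch
from torch import nn


class TraceableModel(nn.Module):  # same model as above
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        condition = x.sum() > 0
        return self.linear(x) * condition


model = TraceableModel()
pos, neg = torch.ones(1, 10), -torch.ones(1, 10)
assert torch.equal(model(neg), torch.zeros(1, 10))  # condition False -> zeros
assert torch.equal(model(pos), model.linear(pos))   # condition True -> pass-through
```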

Example - non-data-dependent control flow operator

Note that model control flow can be handled without issue provided it is not data-dependent. The model below contains conditional statements, but none of them depend on tensor values, so at trace time they behave like constants.

import torch
from torch import nn

class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_cond = True
        self.num_loops = 5

    def forward(self, x):
        if self.my_cond:
            for i in range(self.num_loops):
                x = x + 3
            return x
        else:
            return x

# Both options will run successfully and can therefore be used by ACE

###### Option 1:
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)

###### Option 2:
model = torch.compile(traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
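One caveat worth illustrating: torch.fx records only the branch taken at trace time, so the constants self.my_cond and self.num_loops are baked into the traced graph, and mutating them afterwards has no effect. A small check (restating the model above):

```python
import torch
from torch import nn


class TraceableModel(nn.Module):  # same model as above
    def __init__(self):
        super().__init__()
        self.my_cond = True
        self.num_loops = 5

    def forward(self, x):
        if self.my_cond:
            for _ in range(self.num_loops):
                x = x + 3
            return x
        return x


traced = torch.fx.symbolic_trace(TraceableModel())
x = torch.zeros(1, 10)
assert torch.equal(traced(x), x + 15)  # the loop was unrolled into five "+ 3" adds

traced.num_loops = 2  # ignored: the trace is fixed once captured
assert torch.equal(traced(x), x + 15)
```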