# Model requirements and limitations
To use a `torch.nn.Module` as input for clika-ace, the model must be traceable by either `torch.dynamo` or `torch.fx.symbolic_trace`. The CLIKA SDK first attempts to trace the entire model with `torch.dynamo`; if that trace fails, it falls back to `torch.fx.symbolic_trace`.
## Traceability
The input model must be traceable, meaning it must pass at least one of the following:

- `torch.compile(model, fullgraph=True, dynamic=True)`, followed by a forward pass
- `torch.fx.symbolic_trace(model)`, or, for Hugging Face models, `transformers.utils.fx.symbolic_trace(model)`
The conditions above typically fail if the model contains data-dependent control flow operations.
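As a quick sanity check before handing a model to clika-ace, you can try both trace paths yourself. The following is a minimal sketch; the `is_traceable` helper is illustrative and not part of the SDK, and `backend="eager"` is passed only so the check does not require a compiler toolchain (we only care about graph capture, not codegen):

```python
import torch


def is_traceable(model: torch.nn.Module, example_input: torch.Tensor) -> bool:
    """Return True if the model passes at least one of the trace paths above.

    Illustrative helper only -- not part of the clika-ace SDK.
    """
    try:
        # Path 1: torch.dynamo full-graph capture via torch.compile.
        # backend="eager" skips code generation; only graph capture matters here.
        compiled = torch.compile(model, fullgraph=True, dynamic=True, backend="eager")
        compiled(example_input)
        return True
    except Exception:
        pass
    try:
        # Path 2: FX symbolic tracing
        torch.fx.symbolic_trace(model)
        return True
    except Exception:
        return False
```

A model with data-dependent control flow will fail both paths and return `False`; a plain feed-forward model passes on the first attempt.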
We recommend torch.dynamo over torch.fx.symbolic_trace or torch.jit.script (which is being deprecated), as we believe it is the future of PyTorch model tracing.
A future release of clika-ace will support partial compilation, even if full-graph tracing is not possible. This means that models with control flow operations will be supported without modification.
## Data-dependent control flow operations
Control flow operators (if, else, while, etc.) cannot be data-dependent. Control flow is data-dependent when the program's execution path depends on the values of a tensor. Capturing this divergence of computational paths is a difficult problem for graph tracing.
However, control flow that is not data-dependent (i.e., static) can be captured without issue.
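For instance, a data-dependent `while` loop (where the number of iterations depends on a tensor value) fails tracing for the same reason as a data-dependent `if`. A minimal sketch:

```python
import torch
from torch import nn


class DataDependentLoop(nn.Module):
    def forward(self, x):
        # the loop condition depends on tensor values, so the number of
        # iterations cannot be determined at trace time
        while x.sum() < 10:
            x = x + 1
        return x


try:
    torch.fx.symbolic_trace(DataDependentLoop())
except torch.fx.proxy.TraceError:
    # "symbolically traced variables cannot be used as inputs to control flow"
    print("not traceable")
```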
### Example - data-dependent if statement
In this example, we'll refactor a simple non-traceable model into a traceable model:
```python
import torch
from torch import nn


class NonTraceableModel(nn.Module):
    def __init__(self):
        super(NonTraceableModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # data-dependent "if" statement, which will NOT be traceable
        if x.sum() > 0:
            return self.linear(x)
        else:
            return torch.zeros(10)


# Both options will crash and therefore cannot be used by ACE
non_traceable_model = NonTraceableModel()

###### Option 1:
torch.fx.symbolic_trace(non_traceable_model)
# Throws:
# torch.fx.proxy.TraceError:
# symbolically traced variables cannot be used as inputs to control flow

###### Option 2:
model = torch.compile(non_traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
# Throws:
# torch._dynamo.exc.UserError:
# Dynamic control flow is not supported at the moment.
# Please use functorch.experimental.control_flow.cond to explicitly capture the control flow.
# For more information about this error, see:
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
```
Now we will modify it to be traceable:
```python
class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # use tensor operations instead of an "if" statement:
        # "condition" is a 0-dim boolean tensor that acts as a 0/1 mask
        condition = x.sum() > 0
        linear_output = self.linear(x)
        linear_output = linear_output * condition
        return linear_output


# Both options will run successfully and therefore can be used by ACE

###### Option 1:
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)

###### Option 2:
model = torch.compile(traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
```
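The masking refactor keeps both branches inside the graph by multiplying. `torch.where` is an equally traceable alternative that selects between the two branches elementwise; a sketch (note this variant returns zeros shaped like the linear output, rather than the original `torch.zeros(10)`):

```python
import torch
from torch import nn


class WhereModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # both branches are computed; torch.where picks between them
        # elementwise, so no Python control flow reaches the tracer
        return torch.where(x.sum() > 0, self.linear(x), torch.zeros_like(x))


traced = torch.fx.symbolic_trace(WhereModel())
traced(torch.rand(1, 10))
```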
### Example - non-data-dependent control flow operator
Note that model control flow can be handled without issue provided the control flow is not data-dependent. The model below contains conditional statements, but none of them depend on tensor values.
```python
import torch
from torch import nn


class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_cond = True
        self.num_loops = 5

    def forward(self, x):
        # the condition and loop bound are plain Python values,
        # not tensors, so this control flow is static
        if self.my_cond:
            for _ in range(self.num_loops):
                x = x + 3
            return x
        else:
            return x


# Both options will run successfully and therefore can be used by ACE

###### Option 1:
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)

###### Option 2:
model = torch.compile(traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
```
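Because the condition and loop bound are plain Python values, the tracer evaluates them once and bakes the result into the graph: the loop is unrolled and the `if` disappears entirely. A sketch inspecting the FX graph of the model above (redefined here so the snippet is self-contained):

```python
import torch
from torch import nn


class TraceableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_cond = True
        self.num_loops = 5

    def forward(self, x):
        if self.my_cond:
            for _ in range(self.num_loops):
                x = x + 3
            return x
        return x


traced = torch.fx.symbolic_trace(TraceableModel())
# the static loop was unrolled into five add nodes; no conditional remains
add_nodes = [n for n in traced.graph.nodes if n.op == "call_function"]
print(len(add_nodes))  # 5
```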