Model requirements and limitations
To be used as input for ACE, a torch.nn.Module must satisfy a traceability requirement: the model must be traceable by either torch.dynamo or torch.fx.symbolic_trace.
The CLIKA SDK will first attempt to trace the entire model using torch.dynamo; if that trace fails, the SDK falls back to torch.fx.symbolic_trace.
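If you want to verify this yourself before handing a model to ACE, you can mirror the same order. The helper below is a rough sketch of our own (precheck_traceability is not a CLIKA API); it uses torch.compile(..., fullgraph=True) as a stand-in for a full torch.dynamo capture and torch.fx.symbolic_trace as the fallback:
import torch
from torch import nn

def precheck_traceability(model: nn.Module, example_input: torch.Tensor) -> str:
    # Hypothetical helper mirroring the SDK's order: dynamo first, then torch.fx
    try:
        # fullgraph=True makes dynamo raise on anything it cannot capture
        # instead of silently inserting graph breaks
        compiled = torch.compile(model, fullgraph=True, dynamic=True)
        compiled(example_input)
        return "traceable with torch.dynamo"
    except Exception:
        pass
    try:
        torch.fx.symbolic_trace(model)
        return "traceable with torch.fx.symbolic_trace"
    except Exception:
        return "not traceable"

print(precheck_traceability(nn.Linear(10, 10), torch.rand(1, 10)))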
Traceability
The input model must be traceable, meaning it must pass either of the following:
torch.fx.symbolic_trace(model)
or, for Hugging Face models, transformers.utils.fx.symbolic_trace(model).
These calls typically fail if the model contains data-dependent control flow operations.
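For example, a Hugging Face model can be checked directly with the transformers tracer. The snippet below is a minimal sketch of such a check; the small BertConfig is chosen only for illustration and is not required by the SDK:
from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace

# Build a tiny, randomly initialized BERT from a config so that no
# pretrained weights need to be downloaded for this check
config = BertConfig(num_hidden_layers=2, hidden_size=64,
                    num_attention_heads=2, intermediate_size=128)
model = BertModel(config)

# The transformers FX tracer needs to know which forward arguments to trace
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
print(type(traced))  # a torch.fx.GraphModule, so the model is traceable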
In general, we believe torch.dynamo to be the future of tracing; therefore we recommend it over torch.fx.symbolic_trace or torch.jit.script (which is slated for deprecation: https://github.com/pytorch/pytorch/issues/103841).
In future releases, the CLIKA SDK will support torch.compile(model, backend="inductor", fullgraph=False, dynamic=True).
Because fullgraph=False allows graph breaks, even models that contain data-dependent Python control flow will be supported without modification.
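As a sketch of what that would allow (our own illustration, not CLIKA code): the model below branches on a tensor value and fails with fullgraph=True, yet runs under fullgraph=False because torch.dynamo simply inserts a graph break at the data-dependent if.
import torch
from torch import nn

class DataDependentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # data-dependent "if": the branch taken depends on the tensor values
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros(10)

# fullgraph=False tolerates the branch by falling back to eager execution
# around it (a "graph break") instead of raising an error
model = torch.compile(DataDependentModel(), backend="inductor",
                      fullgraph=False, dynamic=True)
print(model(torch.rand(1, 10)).shape)  # torch.Size([1, 10])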
Data-dependent control flow operations
Control flow operators (if, else, while, ...) can be used, but not when they are data-dependent.
Data-dependent control flow refers to situations where the execution path of a program depends on the values of tensors. Representing and capturing the divergence of computational paths is a hard problem.
In general, there is no problem capturing control flow operators when they merely help guide the tracing of the model, for example branching on a plain Python attribute or constant.
Example - data-dependent if statement
In this example, we'll refactor a simple non-traceable model into a traceable one:
import torch
from torch import nn

class NonTraceableModel(nn.Module):
    def __init__(self):
        super(NonTraceableModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # use "if" statement which will NOT be traceable
        if x.sum() > 0:
            return self.linear(x)
        else:
            return torch.zeros(10)
# Both options below will crash, so this model cannot be used by ACE
non_traceable_model = NonTraceableModel()
###### Option 1:
torch.fx.symbolic_trace(non_traceable_model)
# Throws:
# torch.fx.proxy.TraceError:
# symbolically traced variables cannot be used as inputs to control flow
###### Option 2:
model = torch.compile(non_traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
# Throws:
# torch._dynamo.exc.UserError:
# Dynamic control flow is not supported at the moment.
# Please use functorch.experimental.control_flow.cond to explicitly capture the control flow.
# For more information about this error, see:
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
Now we will modify it to be traceable:
class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # use torch operations instead of an "if" statement:
        # multiplying by the boolean condition zeroes the output
        # whenever the condition is False
        condition = x.sum() > 0
        linear_output = self.linear(x)
        linear_output = linear_output * condition
        return linear_output
# Both options below will run successfully, so this model can be used by ACE
###### Option 1:
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)
###### Option 2:
model = torch.compile(traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))
Example - non-data-dependent control flow operator
Model control flow can be handled without issue provided that the control flow is not data-dependent. The model below contains an if statement and a for loop, but neither depends on the value of a tensor: they are driven by plain Python attributes that are fixed at trace time.
import torch
from torch import nn

class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_cond = True
        self.num_loops = 5

    def forward(self, x):
        if self.my_cond:
            for i in range(self.num_loops):
                x = x + 3
            return x
        else:
            return x
# Both options below will run successfully, so this model can be used by ACE
###### Option 1:
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)
###### Option 2:
model = torch.compile(traceable_model, fullgraph=True, dynamic=True)
model(torch.rand(1, 10))