ACE model requirements
To be able to use a torch.nn.Module as input for ACE, the model must be traceable by either TorchDynamo (torch._dynamo) or torch.fx.symbolic_trace.
The CLIKA SDK first attempts to trace the entire model with TorchDynamo; if that trace fails, the SDK falls back to torch.fx.symbolic_trace.
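The tracing order described above can be approximated with a small helper for checking a model before handing it to ACE. This is a minimal sketch, not the SDK's implementation: `check_traceability`, `Plain`, and `Branchy` are hypothetical names, and the sketch assumes that a successful `torch.compile(..., fullgraph=True, backend="eager")` call is a reasonable stand-in for a full-graph TorchDynamo capture.

```python
import torch
from torch import nn


def check_traceability(model: nn.Module, example_inputs: tuple):
    """Hypothetical helper: return "dynamo", "fx", or None.

    Tries TorchDynamo first and falls back to torch.fx.symbolic_trace,
    mirroring the order described above.
    """
    try:
        # fullgraph=True turns any graph break into an error instead of a
        # silent fallback to Python; the "eager" backend needs no compiler.
        compiled = torch.compile(model, fullgraph=True, backend="eager")
        compiled(*example_inputs)  # compilation happens lazily on the first call
        return "dynamo"
    except Exception:
        pass  # Dynamo capture failed (or Dynamo is unavailable): try FX
    try:
        torch.fx.symbolic_trace(model)
        return "fx"
    except Exception:
        return None


class Plain(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Branchy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # data-dependent control flow: the branch depends on tensor values
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros_like(x)


x = torch.randn(10)
print(check_traceability(Plain(), (x,)))    # "dynamo" on most setups, otherwise "fx"
print(check_traceability(Branchy(), (x,)))  # neither tracer succeeds
```

Catching `Exception` broadly is deliberate here: the goal is only to report which tracer succeeds, not to diagnose why one failed.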
Traceability
The input model must be traceable, meaning that at least one of the following calls must succeed:
torch.fx.symbolic_trace(model)
or, for Hugging Face models, transformers.utils.fx.symbolic_trace(model)
These traces typically fail if the model contains data-dependent Python control flow, such as an if statement whose condition depends on tensor values.
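Not every if statement breaks tracing. torch.fx.symbolic_trace fails only when the condition depends on a traced tensor; a branch on a plain Python attribute fixed in `__init__` is resolved once, at trace time. The sketch below illustrates both cases; `ConfigurableModel` and `DataDependentModel` are illustrative names, not part of the SDK.

```python
import torch
from torch import nn


class ConfigurableModel(nn.Module):
    def __init__(self, use_relu: bool):
        super().__init__()
        self.use_relu = use_relu  # plain Python bool, fixed before tracing
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        out = self.linear(x)
        # static control flow: the condition is an ordinary bool, so the
        # tracer simply records whichever branch is taken
        if self.use_relu:
            out = torch.relu(out)
        return out


class DataDependentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # data-dependent control flow: during tracing, x.sum() > 0 is a Proxy,
        # and converting it to a bool raises TraceError
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros_like(x)


gm = torch.fx.symbolic_trace(ConfigurableModel(use_relu=True))
print(gm.code)  # generated forward calls torch.relu and contains no branch

try:
    torch.fx.symbolic_trace(DataDependentModel())
except torch.fx.proxy.TraceError as err:
    print("tracing failed:", err)
```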
In future releases, the CLIKA SDK will support torch.compile(model, backend="inductor", fullgraph=False, dynamic=True).
Because fullgraph=False lets TorchDynamo split the graph at unsupported Python control flow (a "graph break") instead of raising an error, models that contain such control flow will be supported without modification.
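The effect of fullgraph=False can be sketched by compiling a model with a data-dependent branch and comparing it against eager execution. This sketch is an assumption-laden illustration, not a statement about how the SDK will call torch.compile: `BranchyModel` is an illustrative name, and it uses `backend="eager"` instead of `backend="inductor"` only so that no C++ compiler toolchain is needed to run it.

```python
import torch
from torch import nn


class BranchyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # data-dependent branch: with fullgraph=False TorchDynamo inserts a
        # graph break here and evaluates the condition in regular Python
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros_like(x)


model = BranchyModel().eval()
# Same options as above except backend="eager", which skips code generation
compiled = torch.compile(model, backend="eager", fullgraph=False, dynamic=True)

with torch.no_grad():
    # torch.ones takes the linear branch; -torch.ones takes the zeros branch
    for x in (torch.ones(10), -torch.ones(10)):
        assert torch.allclose(compiled(x), model(x))
print("compiled outputs match eager outputs for both branches")
```

Passing fullgraph=True to the same call would instead raise an error at the data-dependent branch.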
Example - torch.fx.symbolic_trace
In this example we'll refactor a simple non-traceable model into a traceable model:
import torch
from torch import nn


class NonTraceableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # this "if" depends on the value of x, so it is NOT traceable
        if x.sum() > 0:
            return self.linear(x)
        else:
            return torch.zeros(10)


# raises torch.fx.proxy.TraceError and therefore cannot be used by ACE
non_traceable_model = NonTraceableModel()
try:
    torch.fx.symbolic_trace(non_traceable_model)
except torch.fx.proxy.TraceError as err:
    print(f"Tracing failed: {err}")


class TraceableModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        # use torch operations instead of an "if" statement: multiplying by
        # the boolean condition zeroes the output when the sum is not positive
        condition = x.sum() > 0
        linear_output = self.linear(x)
        linear_output = linear_output * condition
        return linear_output


# runs successfully and therefore can be used by ACE
traceable_model = TraceableModel()
torch.fx.symbolic_trace(traceable_model)
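To confirm that a refactor like this preserves behavior, one option is to copy the weights into the refactored model and compare the traced module against the original on inputs that take each branch. The sketch below also uses torch.where instead of multiplying by the boolean condition; this is one possible refactor, not something ACE requires. `WhereModel` is an illustrative name, and `NonTraceableModel` is redefined so the block runs on its own.

```python
import torch
from torch import nn


class NonTraceableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros(10)


class WhereModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        condition = x.sum() > 0  # 0-dim bool tensor, broadcast by torch.where
        linear_output = self.linear(x)
        # pick the linear output or zeros without any Python control flow
        return torch.where(condition, linear_output, torch.zeros_like(linear_output))


torch.manual_seed(0)
reference = NonTraceableModel()
refactored = WhereModel()
refactored.load_state_dict(reference.state_dict())  # identical weights in both

traced = torch.fx.symbolic_trace(refactored)
# torch.ones takes the linear branch; -torch.ones takes the zeros branch
for x in (torch.ones(10), -torch.ones(10)):
    assert torch.equal(traced(x), reference(x))
print(traced.code)  # generated forward with torch.where and no branch
```

Unlike multiplication by the condition, torch.where does not propagate NaN or infinite values from the unselected branch, which can matter if the linear output is not guaranteed to be finite.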