Compile API
clika_compile
Compiles a PyTorch model using the CLIKA optimization backend.
This function serves as a high-level interface to torch.compile, specifically configured to utilize the 'clika' backend. It applies advanced model optimization techniques such as quantization, pruning, and LoRA adaptation based on the provided settings.
```python
def clika_compile(
    module_or_func: ModelArgType,
    *,
    # clika related:
    tracing_inputs: Optional[Iterable[Any]] = None,
    calibration_inputs: Optional[Iterable[Any]] = None,
    deployment_settings: Optional[BaseDeploymentSettings] = None,
    data_transform_fn: Optional[AnyCallable] = lambda x: x,
    forward_fn: Optional[AnyCallable] = None,
    quantization_settings: Optional[QuantizationSettings] = None,
    layer_quantization_settings: Optional[Mapping[str, LayerQuantizationSettings]] = None,
    logs_dir: Optional[Union[Path, str]] = None,
    clika_compile_options: Optional[ClikaCompileOptions] = None,
    discard_input_model: bool = True,
    is_model_pretrained: bool = True,
    is_dry_run: bool = False,
    skip_eval_set_mode: bool = False,
    graph_name: Optional[str] = None,
    # dynamo related:
    fullgraph: bool = True,
    dynamic: bool = True,
    disable: bool = False,
    **extra_options: dict,
) -> Union[torch.nn.Module, ClikaModule]:
```
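For orientation, a minimal usage sketch. The import path and the concrete deployment-settings class are assumptions, as are the `QuantizationSettings` constructor defaults; consult your CLIKA distribution for the exact names.

```python
import torch
import torchvision

# Import path is an assumption -- adjust to your CLIKA installation.
from clika_ace import clika_compile, QuantizationSettings

model = torchvision.models.resnet18()

# Representative inputs: one element per forward() call.
calibration_inputs = [torch.randn(1, 3, 224, 224) for _ in range(32)]

deployment_settings = ...  # a BaseDeploymentSettings instance for your target backend

clika_model = clika_compile(
    model,
    calibration_inputs=calibration_inputs,  # reused for tracing since tracing_inputs is None
    deployment_settings=deployment_settings,
    quantization_settings=QuantizationSettings(),
    logs_dir="clika_logs",
)
```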
Arguments
module_or_func
Type: ModelArgType
The PyTorch model (e.g., torch.nn.Module) or callable (e.g., my_model.infer) to be compiled and optimized.
tracing_inputs
Type: Optional[Iterable[Any]] (Default: None)
A list of sample inputs used to trace the model's computational graph. This is essential for operations like ONNX export or torch.compile (Dynamo).
- Content vs. Structure: The actual data values in the inputs do not matter (e.g., `torch.rand` is fine). However, the inputs must be valid in type, shape, and structure (e.g., a `dict` with the correct keys) so that the model's `forward` method can execute without error.
- Input Format: Each element in the iterable represents a single, complete input for one call to the model's `forward` method (e.g., `[tensor1, tensor2]`, or `[(ids1, mask1), (ids2, mask2)]`, or `[(ids1, mask1, {'attn_mask': tensor}), (ids2, mask2, {'attn_mask': tensor})]`).
Handling Dynamic Shapes
To export a model that supports dynamic input shapes (e.g., variable batch size, sequence length, or image dimensions), you must provide multiple tracing_inputs. The tracer analyzes the differences between these inputs to identify and enable dynamic axes.
Rule of Thumb: You must provide at least two examples that differ along the dimension(s) you wish to make dynamic. The tracer compares all provided inputs to determine the full set of dynamic axes.
Examples:
- Dynamic Batch Size: Provide inputs with different batch sizes.
  `tracing_inputs` shapes: `[(1, 3, 224, 224), (2, 3, 224, 224)]`
- Dynamic Spatial Dimensions: Provide inputs with different heights/widths.
  `tracing_inputs` shapes: `[(1, 3, 224, 224), (1, 3, 320, 640)]`
- Dynamic Batch & Spatial: Provide inputs covering the range of variability.
  `tracing_inputs` shapes: `[(1, 3, 224, 224), (2, 3, 320, 640)]`
- LLM Sequence Length: Provide inputs with different sequence lengths to handle both "prefill" (long) and "decode" (short) steps.
  `tracing_inputs` shapes (conceptual): `[{"input_ids": tensor(1, 1)}, {"input_ids": tensor(1, 256)}]`
How it works (conceptual): The tracer finds differences between all inputs to mark axes as dynamic.
- Compares `(1, 3, 224, 224)` and `(2, 3, 224, 224)` -> marks dim 0 as dynamic.
- Compares `(1, 3, 224, 224)` and `(1, 3, 320, 224)` -> marks dim 2 as dynamic.
- ...and so on.
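For illustration, a hedged sketch of tracing inputs that enable a dynamic batch dimension and dynamic spatial dimensions; the shapes and vocabulary size are placeholders:

```python
import torch

# Two inputs that differ along dim 0 (batch) and dims 2/3 (height/width);
# the tracer compares them and marks those axes as dynamic.
tracing_inputs = [
    torch.randn(1, 3, 224, 224),
    torch.randn(2, 3, 320, 640),
]

# For an LLM, vary the sequence length instead (dict-style inputs):
llm_tracing_inputs = [
    {"input_ids": torch.randint(0, 32000, (1, 1))},    # "decode"-like step
    {"input_ids": torch.randint(0, 32000, (1, 256))},  # "prefill"-like step
]
```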
deployment_settings
Type: Optional[BaseDeploymentSettings] (Default: None)
Configuration object specifying target deployment constraints and settings required by the CLIKA backend.
- Mandatory: This is mandatory if `quantization_settings` is provided, as it is required by the algorithm initialization step.
- If `quantization_settings` is `None`, this can be omitted, but `clika_model.clika_initialize_algorithms()` must be called manually later.
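A hedged sketch of that manual path; the exact signature of `clika_initialize_algorithms` is an assumption here:

```python
# Compile without quantization_settings: deployment_settings may be omitted.
clika_model = clika_compile(model, tracing_inputs=tracing_inputs)

# Later, initialize the optimization algorithms manually
# (argument list is an assumption -- check your CLIKA docs).
clika_model.clika_initialize_algorithms()
```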
calibration_inputs
Type: Optional[Iterable[Any]] (Default: None)
A list of sample inputs used for model calibration.
- Mandatory: This is mandatory if `quantization_settings` is provided, as it is required by the algorithm initialization step (e.g., for PTQ).
- Requirement: These inputs should be representative of the actual data your model will encounter in production (e.g., real images, tokenized text samples).
- Behavior: If `calibration_inputs` is provided but `tracing_inputs` is `None`, this list will be used for both calibration and tracing.
- Input Format: Each element in the iterable represents a single, complete input for one call to the model's `forward` method. The structure of the elements must match what `forward` expects; see the sketch after this list.
  - For `def forward(self, x: torch.Tensor):` -> `inputs_list = [tensor1, tensor2, ...]`
  - For `def forward(self, input_ids, attention_mask):` -> `inputs_list = [(ids1, mask1), (ids2, mask2), ...]`
  - For `def forward(self, batch: Dict[str, torch.Tensor]):` -> `inputs_list = [{"input_ids": ..., "attention_mask": ...}, ...]`
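For example, a hedged sketch that materializes a bounded number of representative samples from an existing `DataLoader`; the dataset here is a random placeholder, whereas production calibration should use real data:

```python
import itertools
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset; in practice, iterate over real production-like data.
dataset = TensorDataset(torch.randn(256, 3, 224, 224))
loader = DataLoader(dataset, batch_size=1)

# Each element is one complete forward() input -- here a single tensor,
# matching `def forward(self, x: torch.Tensor)`.
calibration_inputs = [x for (x,) in itertools.islice(loader, 64)]
```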
data_transform_fn
Type: Optional[AnyCallable] (Default: lambda x: x)
A function applied to each data sample (from calibration_inputs or tracing_inputs) before it is passed to forward_fn. This is for preprocessing, e.g., `lambda x: x.to('cuda')`.
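A hedged sketch of a more general transform that moves whatever structure a sample has onto the GPU; the helper name and structure handling are purely illustrative:

```python
import torch

def to_device(sample, device="cuda"):
    # Handle the common sample structures: tensor, tuple/list of tensors, dict of tensors.
    if isinstance(sample, torch.Tensor):
        return sample.to(device)
    if isinstance(sample, (tuple, list)):
        return type(sample)(to_device(s, device) for s in sample)
    if isinstance(sample, dict):
        return {k: to_device(v, device) for k, v in sample.items()}
    return sample

# Passed as: clika_compile(..., data_transform_fn=to_device)
```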
forward_fn
Type: Optional[AnyCallable] (Default: None)
A custom function defining how data is fed to the model during internal calibration or tracing steps. This is useful if your model's forward method has a complex signature or requires additional static arguments.
The CLIKA backend handles this automatically; however, in some very exotic use cases this gives finer control over the mechanism.
- NOTE: The data will already be on the model's device, so avoid moving data to or from devices inside this function. The function should focus only on the positional and keyword arguments the model needs.
The `forward_fn` receives the model and one input element from the given `calibration_inputs` or `tracing_inputs`, after that element has been packed into `args` (a tuple) and `kwargs` (a dict).
- Default: `None`; handled automatically by the CLIKA backend.
- Example: If `forward` is `def forward(self, x, use_cache=True, offset=6):`, your inputs list could be `inputs_list = [{"x": tensor1}, {"x": tensor2}]`, and your `forward_fn` might be `forward_fn = lambda model, args, kwargs: model(*args, **kwargs, use_cache=True, offset=6)`; see the runnable sketch below.
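The same example expanded into a runnable sketch; the model definition is a placeholder built around the signature from the doc text:

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x, use_cache=True, offset=6):
        return self.linear(x) + offset

inputs_list = [{"x": torch.randn(2, 8)}, {"x": torch.randn(2, 8)}]

# args/kwargs are the packed forms of one element of inputs_list;
# the extra static arguments are supplied here by the forward_fn.
forward_fn = lambda model, args, kwargs: model(*args, **kwargs, use_cache=True, offset=6)

# Passed as: clika_compile(Model(), calibration_inputs=inputs_list, forward_fn=forward_fn, ...)
```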
quantization_settings
Type: Optional[QuantizationSettings] (Default: None)
Global settings for quantization applied by the CLIKA backend.
layer_quantization_settings
Type: Optional[Mapping[str, LayerQuantizationSettings]] (Default: None)
Per-layer overrides for quantization settings, allowing fine-grained control. Keys are layer names/paths, values are layer-specific settings.
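For instance, a hedged sketch of per-layer overrides; the layer names are hypothetical, and the `LayerQuantizationSettings` constructor arguments are not shown because they are assumptions:

```python
# Stand-in for a configured LayerQuantizationSettings instance;
# its constructor arguments are assumptions and omitted here.
my_layer_settings = LayerQuantizationSettings(...)

layer_quantization_settings = {
    # Hypothetical layer names -- use the real module paths from your model,
    # e.g. as listed by model.named_modules().
    "backbone.conv1": my_layer_settings,
    "head.fc": my_layer_settings,
}

# Passed as: clika_compile(..., layer_quantization_settings=layer_quantization_settings)
```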
logs_dir
Type: Optional[Union[Path, str]] (Default: None)
Directory path for saving optimization logs generated by the CLIKA backend. If `None`, logs may not be saved, or may be saved to a default location.
clika_compile_options
Type: Optional[ClikaCompileOptions] (Default: None)
A dataclass that controls certain aspects of the compression, model initialization, and the choice of algorithms or methods that are applied.
discard_input_model
Type: bool (Default: True)
If True, the original input model object might be modified or discarded after compilation to potentially save memory. Set to False to ensure the original model remains unchanged.
is_model_pretrained
Type: bool (Default: True)
Flag indicating if the input model uses pre-trained weights. This might influence optimization strategies within the backend.
is_dry_run
Type: bool (Default: False)
If True, performs a dry run of the compilation and optimization process (e.g., for validation or analysis) without generating a fully optimized artifact.
skip_eval_set_mode
Type: bool (Default: False)
If True, prevents the function from automatically calling model.eval() before processing. Useful if the model's mode is managed externally.
fullgraph
Type: bool (Default: True)
Argument passed to torch.compile. If True, attempts to compile the entire model into a single graph.
dynamic
Type: bool (Default: True)
Argument passed to torch.compile. Enables support for dynamic shapes in the model graph.
disable
Type: bool (Default: False)
Argument passed to torch.compile. If True, disables the torch.compile call entirely, returning the original model without attempting compilation.
**extra_options
Type: dict
Additional keyword arguments that are passed directly as backend-specific options to the 'clika' backend.
Returns
Type: Union[torch.nn.Module, ClikaModule]
The compiled and potentially optimized model. The specific type might be a standard torch.nn.Module or a specialized ClikaModule depending on the optimizations applied by the backend.
Raises
- ValueError: If `quantization_settings` is set but `deployment_settings` or `calibration_inputs` are not provided.
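Conceptually, the check behaves like the following hedged sketch; this is not the library's actual code:

```python
if quantization_settings is not None and (
    deployment_settings is None or calibration_inputs is None
):
    raise ValueError(
        "quantization_settings requires both deployment_settings and calibration_inputs"
    )
```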