Version: 25.12.0

Compile API

clika_compile

Compiles a PyTorch model using the CLIKA optimization backend.

For more information about torch.dynamo, see:

This function serves as a high-level interface to torch.compile, specifically configured to utilize the 'clika' backend. It applies advanced model optimization techniques such as quantization, pruning, and LoRA adaptation based on the provided settings.

def clika_compile(
    module_or_func: ModelArgType,
    *,
    # clika related:
    tracing_inputs: Optional[Iterable[Any]] = None,
    calibration_inputs: Optional[Iterable[Any]] = None,
    deployment_settings: Optional[BaseDeploymentSettings] = None,
    data_transform_fn: Optional[AnyCallable] = lambda x: x,
    forward_fn: Optional[AnyCallable] = None,
    quantization_settings: Optional[QuantizationSettings] = None,
    layer_quantization_settings: Optional[Mapping[str, LayerQuantizationSettings]] = None,
    logs_dir: Optional[Union[Path, str]] = None,
    clika_compile_options: Optional[ClikaCompileOptions] = None,
    discard_input_model: bool = True,
    is_model_pretrained: bool = True,
    is_dry_run: bool = False,
    skip_eval_set_mode: bool = False,
    graph_name: Optional[str] = None,
    # dynamo related:
    fullgraph: bool = True,
    dynamic: bool = True,
    disable: bool = False,
    **extra_options: dict,
) -> Union[torch.nn.Module, ClikaModule]:

Arguments

module_or_func

Type: ModelArgType

The PyTorch model (e.g., torch.nn.Module) or callable (e.g., my_model.infer) to be compiled and optimized.

tracing_inputs

Type: Optional[Iterable[Any]] (Default: None)

A list of sample inputs used to trace the model's computational graph. This is essential for operations like ONNX export or torch.compile (Dynamo).

  • Content vs. Structure: The actual data values in the inputs do not matter (e.g., torch.rand is fine). However, the inputs must be valid in type, shape, and structure (e.g., a dict with the correct keys) so that the model's forward method can execute without error.
  • Input Format: Each element in the iterable represents a single, complete input for one call to the model's forward method, e.g., [tensor1, tensor2], [(ids1, mask1), (ids2, mask2)], or [(ids1, mask1, {'attn_mask': tensor}), (ids2, mask2, {'attn_mask': tensor})].

Handling Dynamic Shapes

To export a model that supports dynamic input shapes (e.g., variable batch size, sequence length, or image dimensions), you must provide multiple tracing_inputs. The tracer analyzes the differences between these inputs to identify and enable dynamic axes.

Rule of Thumb: You must provide at least two examples that differ along the dimension(s) you wish to make dynamic. The tracer compares all provided inputs to determine the full set of dynamic axes.

Examples:

  • Dynamic Batch Size: Provide inputs with different batch sizes.
    • tracing_inputs shapes: [(1, 3, 224, 224), (2, 3, 224, 224)]
  • Dynamic Spatial Dimensions: Provide inputs with different heights/widths.
    • tracing_inputs shapes: [(1, 3, 224, 224), (1, 3, 320, 640)]
  • Dynamic Batch & Spatial: Provide inputs covering the range of variability.
    • tracing_inputs shapes: [(1, 3, 224, 224), (2, 3, 320, 640)]
  • LLM Sequence Length: Provide inputs with different sequence lengths to handle both "prefill" (long) and "decode" (short) steps.
    • tracing_inputs shapes (conceptual): [{"input_ids": tensor(1, 1)}, {"input_ids": tensor(1, 256)}]

How it works (conceptual): The tracer finds differences between all inputs to mark axes as dynamic.

  • Compares (1, 3, 224, 224) and (2, 3, 224, 224) -> Marks dim 0 as dynamic.
  • Compares (1, 3, 224, 224) and (1, 3, 320, 224) -> Marks dim 2 as dynamic.
  • ...and so on.
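
The comparison described above can be sketched in a few lines of plain Python. This is only a conceptual illustration, not the CLIKA tracer itself: the helper name find_dynamic_axes is hypothetical, and it operates on shape tuples rather than on real tensors.

```python
from typing import Iterable, Set, Tuple

Shape = Tuple[int, ...]


def find_dynamic_axes(shapes: Iterable[Shape]) -> Set[int]:
    """Return the axes whose size differs between any of the given shapes.

    Hypothetical helper for illustration only; the real tracer is internal
    to the CLIKA backend and may behave differently.
    """
    shapes = list(shapes)
    if not shapes:
        return set()
    reference = shapes[0]
    dynamic: Set[int] = set()
    for shape in shapes[1:]:
        if len(shape) != len(reference):
            raise ValueError("all shapes must have the same number of dimensions")
        # An axis is marked dynamic as soon as one example disagrees with the first.
        dynamic.update(i for i, (a, b) in enumerate(zip(reference, shape)) if a != b)
    return dynamic


batch_only = find_dynamic_axes([(1, 3, 224, 224), (2, 3, 224, 224)])
spatial = find_dynamic_axes([(1, 3, 224, 224), (1, 3, 320, 640)])
batch_and_spatial = find_dynamic_axes([(1, 3, 224, 224), (2, 3, 320, 640)])
single = find_dynamic_axes([(1, 3, 224, 224)])
```

With a single example no axis can differ, so the result is empty. That is why at least two inputs that differ along the intended dimension are required.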

deployment_settings

Type: Optional[BaseDeploymentSettings] (Default: None)

Configuration object specifying target deployment constraints and settings required by the CLIKA backend.

  • Mandatory: Required whenever quantization_settings is provided, because the algorithm initialization step needs it.
  • If quantization_settings is None, this can be omitted, but clika_model.clika_initialize_algorithms() must then be called manually later.

calibration_inputs

Type: Optional[Iterable[Any]] (Default: None)

A list of sample inputs used for model calibration.

  • Mandatory: Required whenever quantization_settings is provided, because the algorithm initialization step needs it (e.g., for PTQ).

  • Requirement: These inputs should be representative of the actual data your model will encounter in production (e.g., real images, tokenized text samples).

  • Behavior: If calibration_inputs is provided but tracing_inputs is None, this list will be used for both calibration and tracing.

  • Input Format: Each element in the iterable represents a single, complete input for one call to the model's forward method. The structure of the elements must match what forward expects.

    • For def forward(self, x: torch.Tensor): inputs_list = [tensor1, tensor2, ...]
    • For def forward(self, input_ids, attention_mask): inputs_list = [(ids1, mask1), (ids2, mask2), ...]
    • For def forward(self, batch: Dict[str, torch.Tensor]): inputs_list = [{"input_ids": ..., "attention_mask": ...}, ...]

data_transform_fn

Type: Optional[AnyCallable] (Default: lambda x: x)

A function applied to each data sample (from calibration_inputs or tracing_inputs) before it is passed to forward_fn. This is for preprocessing, e.g., lambda x: x.to('cuda').
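
Because a sample may be a tensor, a tuple, or a dict, a data_transform_fn often has to walk a nested structure. The following is a minimal sketch, assuming a hypothetical helper named map_leaves that is not part of the CLIKA API. Plain floats stand in for tensors so that the example runs without PyTorch; in real use the leaf function would be something like lambda t: t.to('cuda').

```python
from typing import Any, Callable


def map_leaves(sample: Any, leaf_fn: Callable[[Any], Any]) -> Any:
    """Apply leaf_fn to every non-container value, preserving the structure.

    Hypothetical helper, not part of the CLIKA API.
    """
    if isinstance(sample, dict):
        return {key: map_leaves(value, leaf_fn) for key, value in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(map_leaves(item, leaf_fn) for item in sample)
    return leaf_fn(sample)


def data_transform_fn(sample: Any) -> Any:
    # Stand-in for a per-tensor operation such as `lambda t: t.to("cuda")`.
    return map_leaves(sample, lambda value: value * 2.0)


transformed = data_transform_fn((1.0, {"mask": [0.5, 1.5]}))
```

The transformed sample keeps the same tuple/dict/list layout as the original, so it remains a valid input for the model's forward method.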

forward_fn

Type: Optional[AnyCallable] (Default: None)

A custom function defining how data is fed to the model during internal calibration or tracing steps. This is useful if your model's forward method has a complex signature or requires additional static arguments.

The CLIKA backend handles this automatically; forward_fn is only needed in unusual cases that require finer control over how inputs are passed to the model.

  • NOTE: The data will already be on the model's device, so avoid moving it between devices. The function should only handle the positional and keyword arguments the model needs.

forward_fn receives the model together with one input element (inp) from calibration_inputs or tracing_inputs, after that element has been packed into args (a tuple) and kwargs (a dict).

  • Default: None, taken care of automatically by the CLIKA backend.
  • Example: If forward is def forward(self, x, use_cache=True, offset=6):
    • Your inputs list could be: inputs_list = [{"x": tensor1}, {"x": tensor2}]
    • And your forward_fn might be: forward_fn = lambda model, args, kwargs: model(*args, **kwargs, use_cache=True, offset=6)
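
The packing of an input element into args and kwargs, followed by the forward_fn call, can be sketched as follows. The pack_input helper and the ToyModel class are hypothetical stand-ins. Treating a dict element as keyword arguments, and a trailing dict inside a tuple as keyword arguments, are assumptions inferred from the input format examples on this page, not confirmed CLIKA behavior.

```python
from typing import Any, Dict, Tuple


def pack_input(inp: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Pack one input element into (args, kwargs). Hypothetical helper."""
    if isinstance(inp, dict):
        # Assumption: a dict element is passed as keyword arguments.
        return (), dict(inp)
    if isinstance(inp, (tuple, list)):
        items = tuple(inp)
        # Assumption: a trailing dict carries keyword arguments.
        if items and isinstance(items[-1], dict):
            return items[:-1], dict(items[-1])
        return items, {}
    return (inp,), {}


class ToyModel:
    """Stand-in for a torch.nn.Module whose forward takes static arguments."""

    def __call__(self, x, use_cache=True, offset=6):
        return (x + offset, use_cache)


def forward_fn(model, args, kwargs):
    # Supply the static arguments that are not part of the input list.
    return model(*args, **kwargs, use_cache=False, offset=10)


model = ToyModel()
inputs_list = [{"x": 1}, {"x": 2}]
outputs = [forward_fn(model, *pack_input(inp)) for inp in inputs_list]
```

Keeping the static arguments inside forward_fn means the input lists only need to carry the data that changes from sample to sample.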

quantization_settings

Type: Optional[QuantizationSettings] (Default: None)

Global settings for quantization applied by the CLIKA backend.

layer_quantization_settings

Type: Optional[Mapping[str, LayerQuantizationSettings]] (Default: None)

Per-layer overrides for quantization settings, allowing fine-grained control. Keys are layer names/paths; values are layer-specific settings.

logs_dir

Type: Optional[Union[Path, str]] (Default: None)

Directory path for saving optimization logs generated by the CLIKA backend. If None, logs may not be saved, or may be saved to a default location.

clika_compile_options

Type: Optional[ClikaCompileOptions] (Default: None)

A dataclass that controls certain aspects of the compression, model initialization, and the choice of algorithms or methods that are applied.

discard_input_model

Type: bool (Default: True)

If True, the original input model object may be modified or discarded after compilation to save memory. Set to False to ensure the original model remains unchanged.

is_model_pretrained

Type: bool (Default: True)

Flag indicating whether the input model uses pre-trained weights. This may influence optimization strategies within the backend.

is_dry_run

Type: bool (Default: False)

If True, performs a dry run of the compilation and optimization process (e.g., for validation or analysis) without generating a fully optimized artifact.

skip_eval_set_mode

Type: bool (Default: False)

If True, prevents the function from automatically calling model.eval() before processing. Useful if the model's mode is managed externally.

fullgraph

Type: bool (Default: True)

Argument passed to torch.compile. If True, attempts to compile the entire model into a single graph.

dynamic

Type: bool (Default: True)

Argument passed to torch.compile. Enables support for dynamic shapes in the model graph.

disable

Type: bool (Default: False)

Argument passed to torch.compile. If True, disables the torch.compile call entirely, returning the original model without attempting compilation.

**extra_options

Type: dict

Additional keyword arguments that are passed directly as backend-specific options to the 'clika' backend.


Returns

Type: Union[torch.nn.Module, ClikaModule]

The compiled and potentially optimized model. The specific type might be a standard torch.nn.Module or a specialized ClikaModule depending on the optimizations applied by the backend.

Raises

  • ValueError: If quantization_settings is set but deployment_settings or calibration_inputs are not provided.
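
The preconditions documented on this page can be summarized as a small check. This is a sketch of the rules as stated here, with a hypothetical function name (resolve_inputs) and a made-up error message; it is not the library's actual validation code.

```python
from typing import Any, Iterable, Optional, Tuple


def resolve_inputs(
    tracing_inputs: Optional[Iterable[Any]],
    calibration_inputs: Optional[Iterable[Any]],
    deployment_settings: Optional[Any],
    quantization_settings: Optional[Any],
) -> Tuple[Optional[Iterable[Any]], Optional[Iterable[Any]]]:
    """Mirror the documented rules; hypothetical, for illustration only."""
    if quantization_settings is not None:
        missing = [
            name
            for name, value in (
                ("deployment_settings", deployment_settings),
                ("calibration_inputs", calibration_inputs),
            )
            if value is None
        ]
        if missing:
            raise ValueError("quantization_settings requires: " + ", ".join(missing))
    # Documented fallback: calibration inputs double as tracing inputs.
    if tracing_inputs is None:
        tracing_inputs = calibration_inputs
    return tracing_inputs, calibration_inputs


samples = [(1, 3, 224, 224), (2, 3, 224, 224)]
tracing, calibration = resolve_inputs(None, samples, object(), object())

try:
    resolve_inputs(samples, None, None, object())
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Running the same check before calling clika_compile can surface a missing argument early, before any tracing or calibration work starts.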