Version: 25.4.0

Compile API

clika_compile / torch.compile(..., backend="clika", ...)

Compiles a PyTorch model using the CLIKA optimization backend.

This function serves as a high-level interface to torch.compile, specifically configured to utilize the 'clika' backend. It applies advanced model optimization techniques such as quantization, pruning, and LoRA adaptation based on the provided settings.
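A minimal usage sketch follows. The clika_ace import path and the DeploymentSettings_TensorRT_ONNX class name are assumptions for illustration only; substitute the package and BaseDeploymentSettings subclass that match your installation and deployment target.

```python
import torch
import torchvision

# NOTE: the import path and the concrete deployment-settings class below are
# assumptions, not taken from this reference; adjust to your CLIKA installation.
from clika_ace import clika_compile, QuantizationSettings, DeploymentSettings_TensorRT_ONNX

model = torchvision.models.resnet18(weights=None)

# Representative calibration samples (batch_size > 1; see the tracing_inputs note below).
calibration_inputs = [torch.randn(4, 3, 224, 224) for _ in range(16)]

compiled_model = clika_compile(
    model,
    calibration_inputs=calibration_inputs,
    deployment_settings=DeploymentSettings_TensorRT_ONNX(),  # mandatory
    quantization_settings=QuantizationSettings(),            # enable global quantization
    logs_dir="./clika_logs",
)
```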

Args:

model

Type: (ModelArgType):

The PyTorch model (e.g., torch.nn.Module) to be compiled and optimized.

calibration_inputs

Type: (Union[torch.Tensor, Iterable[object]]):

Representative data samples required by the CLIKA backend for calibration, particularly crucial for quantization.

tracing_inputs

Type: (Optional[Union[torch.Tensor, Iterable[object]]]):

Input data used for tracing the model graph if needed by the backend. Defaults to calibration_inputs. Note that the tracing_inputs are passed to TorchDynamo to help it generalize the compiled graph across different input shapes. TorchDynamo cannot treat dimensions of size 1 as dynamic, so make sure to pass inputs with batch_size > 1 if you plan to run the model in a batched fashion later. The same applies to any other dimension that may vary at runtime, whether spatial dimensions for image inputs or different sequence lengths.
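For example, a sketch (reusing model, deployment_settings, and clika_compile from the example above) that traces with a batch size greater than 1 and two spatial resolutions, so neither the batch dimension nor the image size is baked into the graph:

```python
# Calibration data can stay at a fixed, representative shape.
calibration_inputs = [torch.randn(4, 3, 224, 224) for _ in range(16)]

# Tracing data covers batch_size > 1 and more than one spatial resolution,
# so the compiled graph is not specialized to a single input shape.
tracing_inputs = [
    torch.randn(2, 3, 224, 224),
    torch.randn(4, 3, 320, 320),
]

compiled_model = clika_compile(
    model,
    calibration_inputs=calibration_inputs,
    tracing_inputs=tracing_inputs,
    deployment_settings=deployment_settings,
)
```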

deployment_settings

Type: (BaseDeploymentSettings):

Mandatory. Configuration object specifying target deployment constraints and settings required by the CLIKA backend.

apply_on_data_fn

Type: (Optional[Callable[[Any], Any]]):

A function applied to each data sample (from calibration_inputs or tracing_inputs) before feeding it to the model during internal processing. Defaults to an identity function (lambda x: x).
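For example, a sketch assuming calibration_inputs is a DataLoader over a hypothetical calibration_dataset that yields (image, label) pairs, where only the image should reach the model:

```python
from torch.utils.data import DataLoader

# calibration_dataset is a placeholder for your own dataset of (image, label) pairs.
calibration_loader = DataLoader(calibration_dataset, batch_size=4)

def keep_image_only(sample):
    image, _label = sample
    return image

compiled_model = clika_compile(
    model,
    calibration_inputs=calibration_loader,
    apply_on_data_fn=keep_image_only,   # strip the label before it reaches the model
    deployment_settings=deployment_settings,
)
```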

forward_fn

Type: (Optional[Callable[[torch.nn.Module, Any], Any]]):

A custom function defining how data is fed to the model during internal calibration or tracing steps. Defaults to clika_feed_data_to_model.
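For example, a sketch for a model whose forward expects keyword arguments, assuming each calibration sample is a dict of tensors (calibration_samples is a placeholder name):

```python
# Each calibration sample is assumed to be a dict of tensors, e.g.
# {"input_ids": ..., "attention_mask": ...}.
def forward_with_kwargs(model, data):
    return model(input_ids=data["input_ids"], attention_mask=data["attention_mask"])

compiled_model = clika_compile(
    model,
    calibration_inputs=calibration_samples,  # iterable of dicts
    forward_fn=forward_with_kwargs,
    deployment_settings=deployment_settings,
)
```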

quantization_settings

Type: (Optional[QuantizationSettings]):

Global settings for quantization applied by the CLIKA backend. Defaults to None (no global quantization).

layer_quantization_settings

Type: (Optional[Mapping[str, LayerQuantizationSettings]]):

Per-layer overrides for quantization settings, allowing fine-grained control. Keys are layer names/paths, values are layer-specific settings. Defaults to None.
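For example, a sketch combining global quantization settings with per-layer overrides. The constructor arguments of QuantizationSettings and LayerQuantizationSettings are not documented in this section, so they are left at their defaults here, and the layer names are placeholders:

```python
from clika_ace import QuantizationSettings, LayerQuantizationSettings  # assumed import path

# Keys are layer names/paths as they appear in the model (placeholder names below);
# values override the global settings for those layers only.
layer_overrides = {
    "backbone.stem.conv": LayerQuantizationSettings(),
    "head.classifier": LayerQuantizationSettings(),
}

compiled_model = clika_compile(
    model,
    calibration_inputs=calibration_inputs,
    deployment_settings=deployment_settings,
    quantization_settings=QuantizationSettings(),   # global behaviour
    layer_quantization_settings=layer_overrides,    # fine-grained exceptions
)
```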

pruning_settings

Type: (Optional[PruningSettings]):

Global settings for model pruning applied by the CLIKA backend. Defaults to None (no global pruning).

layer_pruning_settings

Type: (Optional[Mapping[str, LayerPruningSettings]]):

Per-layer overrides for pruning settings. Keys are layer names/paths, values are layer-specific settings. Defaults to None.

lora_settings

Type: (Optional[LoraSettings]):

Settings for applying LoRA (Low-Rank Adaptation) to the model. Defaults to None.

logs_dir

Type: (Optional[Union[Path, str]]):

Directory path for saving optimization logs generated by the CLIKA backend. If None, logs may not be saved, or may be written to a default location. Defaults to None.

discard_input_model

Type: (bool, optional):

If True, the original input model object may be modified or discarded after compilation to reduce memory usage. Set to False to ensure the original model remains unchanged. Defaults to True.

is_model_pretrained

Type: (bool, optional):

Flag indicating if the input model uses pre-trained weights. This might influence optimization strategies within the backend. Defaults to True.

is_dry_run

Type: (bool, optional):

If True, performs a dry run of the compilation and optimization process (e.g., for validation or analysis) without generating a fully optimized artifact. Defaults to False.

skip_eval_set_mode

Type: (bool, optional):

If True, prevents the function from automatically calling model.eval() before processing. Useful if the model's mode is managed externally. Defaults to False.

fullgraph

Type: (bool, optional):

Argument passed to torch.compile. If True, attempts to compile the entire model into a single graph. Defaults to True.

dynamic

Type: (bool, optional):

Argument passed to torch.compile. Enables support for dynamic shapes in the model graph. Defaults to True.

disable

Type: (bool, optional):

Argument passed to torch.compile. If True, disables the torch.compile call entirely, returning the original model without attempting compilation. Defaults to False.

**extra_options

Type: (dict):

Additional keyword arguments that are passed directly as backend-specific options to the 'clika' backend via torch.compile.

Returns:

Union[torch.nn.Module, ClikaModule]: The compiled and potentially optimized model. The specific type might be a standard torch.nn.Module or a specialized ClikaModule, depending on the optimizations applied by the backend.

Raises:

ValueError: If deployment_settings is not provided (is None).
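A short sketch of consuming the return value (continuing the examples above), assuming the returned object is callable like a regular torch.nn.Module:

```python
compiled_model = clika_compile(
    model,
    calibration_inputs=calibration_inputs,
    deployment_settings=deployment_settings,
)

# Run inference with the compiled model as with any torch.nn.Module.
with torch.no_grad():
    outputs = compiled_model(torch.randn(2, 3, 224, 224))
```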