Compile API
clika_compile / torch.compile(..., backend="clika", ...)
Compiles a PyTorch model using the CLIKA optimization backend.

This function serves as a high-level interface to torch.compile, specifically configured to use the 'clika' backend. It applies advanced model optimization techniques such as quantization, pruning, and LoRA adaptation based on the provided settings. A usage sketch is given at the end of this section.

Args:
model (ModelArgType): The PyTorch model (e.g., torch.nn.Module) to be compiled and optimized.
calibration_inputs (Union[torch.Tensor, Iterable[object]]): Representative data samples required by the CLIKA backend for calibration, particularly crucial for quantization.
tracing_inputs (Optional[Union[torch.Tensor, Iterable[object]]]): Input data used for tracing the model graph if needed by the backend. Defaults to calibration_inputs. The tracing_inputs are passed to torch.dynamo to specialize the model to the given input shapes. Torch Dynamo does not specialize dimensions of size 1 properly, so pass inputs with batch_size > 1 if you plan to run the model in a batched fashion later; the same applies to any other dimension that varies at runtime, such as spatial dimensions for image inputs or sequence lengths. See the data-handling sketch after this parameter list.
deployment_settings (BaseDeploymentSettings): Mandatory. Configuration object specifying target deployment constraints and settings required by the CLIKA backend.
apply_on_data_fn (Optional[Callable[[Any], Any]]): A function applied to each data sample (from calibration_inputs or tracing_inputs) before feeding it to the model during internal processing. Defaults to an identity function (lambda x: x).
forward_fn (Optional[Callable[[torch.nn.Module, Any], Any]]): A custom function defining how data is fed to the model during internal calibration or tracing steps. Defaults to clika_feed_data_to_model.
quantization_settings (Optional[QuantizationSettings]): Global settings for quantization applied by the CLIKA backend. Defaults to None (no global quantization).
layer_quantization_settings (Optional[Mapping[str, LayerQuantizationSettings]]): Per-layer overrides for quantization settings, allowing fine-grained control. Keys are layer names/paths, values are layer-specific settings. Defaults to None.
pruning_settings (Optional[PruningSettings]): Global settings for model pruning applied by the CLIKA backend. Defaults to None (no global pruning).
layer_pruning_settings (Optional[Mapping[str, LayerPruningSettings]]): Per-layer overrides for pruning settings. Keys are layer names/paths, values are layer-specific settings. Defaults to None.
lora_settings (Optional[LoraSettings]): Settings for applying LoRA (Low-Rank Adaptation) to the model. Defaults to None.
logs_dir (Optional[Union[Path, str]]): Directory path for saving optimization logs generated by the CLIKA backend. If None, logs may not be saved or may be written to a default location. Defaults to None.
discard_input_model (bool, optional): If True, the original input model object might be modified or discarded after compilation to potentially save memory. Set to False to ensure the original model remains unchanged. Defaults to True.
is_model_pretrained (bool, optional): Flag indicating if the input model uses pre-trained weights. This might influence optimization strategies within the backend. Defaults to True.
is_dry_run (bool, optional): If True, performs a dry run of the compilation and optimization process (e.g., for validation or analysis) without generating a fully optimized artifact. Defaults to False.
skip_eval_set_mode (bool, optional): If True, prevents the function from automatically calling model.eval() before processing. Useful if the model's mode is managed externally. Defaults to False.
fullgraph (bool, optional): Argument passed to torch.compile. If True, attempts to compile the entire model into a single graph. Defaults to True.
dynamic (bool, optional): Argument passed to torch.compile. Enables support for dynamic shapes in the model graph. Defaults to True.
disable (bool, optional): Argument passed to torch.compile. If True, disables the torch.compile call entirely, returning the original model without attempting compilation. Defaults to False.
**extra_options (dict): Additional keyword arguments that are passed directly as backend-specific options to the 'clika' backend via torch.compile.
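
To make the data-handling parameters above concrete: each sample from calibration_inputs (or tracing_inputs) is passed through apply_on_data_fn before it reaches the model, and forward_fn controls how the processed sample is fed to the model. Below is a minimal sketch assuming a hypothetical dataset of (image, label) tuples; the dataset, the layer name, and the per-layer settings value are illustrative placeholders, not part of the CLIKA API:

    import torch

    # Hypothetical calibration batches of (image, label) tuples.
    # Batch size is deliberately > 1 so that Torch Dynamo does not
    # specialize any dimension of size 1 (see tracing_inputs above).
    calibration_batches = [
        (torch.randn(4, 3, 224, 224), torch.randint(0, 1000, (4,)))
        for _ in range(8)
    ]

    def apply_on_data_fn(sample):
        # Each sample is an (image, label) tuple; the model only consumes the image.
        image, _label = sample
        return image

    def forward_fn(model, data):
        # How a pre-processed sample is fed to the model during calibration/tracing.
        return model(data)

    # Per-layer overrides are keyed by layer name/path; the values are
    # LayerQuantizationSettings instances (construction not shown here).
    layer_quantization_settings = {
        "layer4.0.conv1": ...,
    }

These objects are passed to clika_compile through the correspondingly named keyword arguments; the usage sketch at the end of this section shows a full call.
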
Returns:
Union[torch.nn.Module, ClikaModule]: The compiled and potentially optimized model. The specific type might be a standard torch.nn.Module or a specialized ClikaModule, depending on the optimizations applied by the backend.

Raises:
ValueError: If deployment_settings is not provided (is None).
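
A minimal end-to-end usage sketch. The import path shown for clika_compile is an assumption (check the CLIKA SDK documentation for the actual module), and the deployment settings object is a placeholder for a concrete BaseDeploymentSettings subclass configured for your target:

    import torch
    import torchvision

    from clika import clika_compile  # NOTE: import path is an assumption; see the CLIKA SDK docs

    model = torchvision.models.resnet18(weights="DEFAULT")

    # Representative calibration data; batch size > 1 as recommended for tracing_inputs.
    calibration_inputs = [torch.randn(8, 3, 224, 224) for _ in range(16)]

    deployment_settings = ...  # mandatory: a concrete BaseDeploymentSettings instance for the target runtime

    compiled = clika_compile(
        model=model,
        calibration_inputs=calibration_inputs,
        deployment_settings=deployment_settings,
        logs_dir="clika_logs",
        discard_input_model=False,  # keep the original model object unchanged
    )

    # The result behaves like a torch.nn.Module (it may be a ClikaModule).
    with torch.inference_mode():
        output = compiled(torch.randn(8, 3, 224, 224))

The same backend can also be selected directly through torch.compile(model, backend="clika", ...), as indicated in the signature above.
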