diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py
index 29c7120feb7..7f6abe980e1 100644
--- a/backends/apple/coreml/compiler/torch_ops.py
+++ b/backends/apple/coreml/compiler/torch_ops.py
@@ -28,6 +28,21 @@
 from executorch.exir.dim_order_utils import get_memory_format
 
 
+_IOS18_QUANT_HINT = (
+    "ExecuTorch hint: pass `compile_specs=CoreMLBackend.generate_compile_specs("
+    "minimum_deployment_target=ct.target.iOS18)` (or higher) to "
+    "`CoreMLPartitioner` when lowering models that use `quantize_(...)`."
+)
+
+
+def _raise_with_executorch_hint(err: Exception) -> "BaseException":
+    """Re-raise a coremltools quantization error with ExecuTorch-specific guidance."""
+    msg = str(err)
+    if "iOS18" in msg or "iOS 18" in msg:
+        raise ValueError(f"{msg}\n{_IOS18_QUANT_HINT}") from err
+    raise err
+
+
 # https://github.com/apple/coremltools/pull/2563
 @register_torch_op(override=False)
 def split_copy(context, node):
@@ -159,12 +174,15 @@ def dequantize_affine(context, node):
             f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
         )
 
-    output = _utils._construct_constexpr_dequant_op(
-        int_data.astype(quantized_np_dtype),
-        zero_point,
-        scale,
-        name=node.name,
-    )
+    try:
+        output = _utils._construct_constexpr_dequant_op(
+            int_data.astype(quantized_np_dtype),
+            zero_point,
+            scale,
+            name=node.name,
+        )
+    except ValueError as e:
+        _raise_with_executorch_hint(e)
 
     context.add(output, node.name)
 
@@ -211,9 +229,12 @@ def dequantize_codebook(context, node):
             f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
         )
-    output = _utils._construct_constexpr_lut_op(
-        codes.astype(np.int8),
-        codebook,
-        name=node.name,
-    )
+    try:
+        output = _utils._construct_constexpr_lut_op(
+            codes.astype(np.int8),
+            codebook,
+            name=node.name,
+        )
+    except ValueError as e:
+        _raise_with_executorch_hint(e)
 
     context.add(output, node.name)
diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py
index de54b684ee7..10c3f01a585 100644
--- a/backends/apple/coreml/test/test_torch_ops.py
+++ b/backends/apple/coreml/test/test_torch_ops.py
@@ -317,6 +317,33 @@ def forward(self, x):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
+    def test_dequantize_affine_below_ios18_raises_with_hint(self):
+        """
+        Regression test for https://github.com/pytorch/executorch/issues/13122.
+
+        `quantize_(...)` with blockwise / int4 configurations requires iOS18.
+        coremltools raises a ValueError that does not mention how to fix the
+        deployment target on the ExecuTorch side; we wrap it to add the
+        partitioner-level guidance.
+        """
+        model = torch.nn.Linear(64, 64)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
+        ep = torch.export.export(model.eval(), (torch.randn(1, 64),), strict=True)
+        with self.assertRaises(ValueError) as cm:
+            executorch.exir.to_edge_transform_and_lower(
+                ep,
+                partitioner=[
+                    self._coreml_partitioner(minimum_deployment_target=ct.target.iOS17)
+                ],
+            )
+        msg = str(cm.exception)
+        self.assertIn("iOS18", msg)
+        self.assertIn("CoreMLPartitioner", msg)
+        self.assertIn("minimum_deployment_target", msg)
+
 
 if __name__ == "__main__":
     test_runner = TestTorchOps()