diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index 447387be3c3..4b5d4a6de82 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -33,6 +33,26 @@
 logger.setLevel(get_coreml_log_level(default_level=logging.INFO))
 
 
+_ARG_MIN_MAX_TARGETS = (
+    torch.ops.aten.argmax.default,
+    torch.ops.aten.argmin.default,
+    exir_ops.edge.aten.argmax.default,
+    exir_ops.edge.aten.argmin.default,
+)
+
+
+def _is_arg_min_max_over_flattened_input(node: torch.fx.Node) -> bool:
+    """``argmin``/``argmax`` with ``dim=None`` reduces over the flattened input.
+
+    CoreML doesn't support that reduction shape and intermittently crashes
+    the process at runtime; see pytorch/executorch#11715.
+    """
+    if node.target not in _ARG_MIN_MAX_TARGETS:
+        return False
+    dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None)
+    return dim is None
+
+
 def _is_view_op(op: torch._ops.OpOverload) -> bool:
     schema = op._schema
     if len(schema.arguments) == 0:
@@ -116,6 +136,13 @@ def should_override_support(self, node) -> bool:
             )
             return True
 
+        if _is_arg_min_max_over_flattened_input(node):
+            self.log_once(
+                "torch.ops.aten.{argmax, argmin}.default with dim=None is "
+                "not supported by CoreML. Overriding op support."
+            )
+            return True
+
         # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
         # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
         # # in the placeholders due to partitioning, which CoreML does not support
diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py
index 4b38806ee63..28695302bd9 100644
--- a/backends/apple/coreml/test/test_coreml_partitioner.py
+++ b/backends/apple/coreml/test/test_coreml_partitioner.py
@@ -338,6 +338,53 @@ def forward(self, x):
             torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
         )
 
+    def test_argmax_argmin_dim_none_is_skipped(self):
+        """
+        Regression test for https://github.com/pytorch/executorch/issues/11715.
+
+        argmax/argmin with dim=None reduces over the flattened tensor, which
+        CoreML does not support; the resulting model intermittently crashes
+        the process at runtime. The partitioner must reject these so they
+        fall back to the portable backend, while still delegating the
+        ordinary dim=int form.
+        """
+
+        class FlatModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(
+                    x, dim=None
+                )
+
+        ep = torch.export.export(
+            FlatModel().eval(), (torch.randn(10, 10),), strict=True
+        )
+        edge = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[CoreMLPartitioner()]
+        )
+        op_names = [
+            n.target.__name__
+            for n in edge.exported_program().graph.nodes
+            if n.op == "call_function"
+        ]
+        self.assertIn("aten.argmax.default", op_names)
+        self.assertIn("aten.argmin.default", op_names)
+
+        class DimModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=1)
+
+        ep = torch.export.export(DimModel().eval(), (torch.randn(10, 10),), strict=True)
+        edge = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[CoreMLPartitioner()]
+        )
+        op_names = [
+            n.target.__name__
+            for n in edge.exported_program().graph.nodes
+            if n.op == "call_function"
+        ]
+        self.assertIn("executorch_call_delegate", op_names)
+        self.assertNotIn("aten.argmax.default", op_names)
+
     def test_deprecation_warning_for_to_backend_workflow(self):
         """
         Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.