From 182ae1d600b15857ace70cf93731e36cdcef7c92 Mon Sep 17 00:00:00 2001
From: john-rocky
Date: Fri, 1 May 2026 13:48:13 +0900
Subject: [PATCH 1/4] Skip argmin/argmax with dim=None in CoreML partitioner

argmax/argmin with dim=None reduces over the flattened input, which CoreML
does not support and which intermittently crashes the process at runtime.
Reject these in the partitioner so they fall back to the portable backend;
the ordinary dim=int form is still delegated.

Fixes #11715.
---
 .../coreml/partition/coreml_partitioner.py | 17 +++++++
 .../coreml/test/test_coreml_partitioner.py | 45 +++++++++++++++++++
 2 files changed, 62 insertions(+)

diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index 447387be3c3..622dc95973f 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -116,6 +116,23 @@ def should_override_support(self, node) -> bool:
             )
             return True
+        # https://github.com/pytorch/executorch/issues/11715
+        # argmin/argmax with dim=None reduces over the flattened input, which
+        # CoreML does not support and causes intermittent process crashes.
+        if node.target in [
+            torch.ops.aten.argmax.default,
+            torch.ops.aten.argmin.default,
+            exir_ops.edge.aten.argmax.default,
+            exir_ops.edge.aten.argmin.default,
+        ]:
+            dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None)
+            if dim is None:
+                self.log_once(
+                    "torch.ops.aten.{argmax, argmin}.default with dim=None is "
+                    "not supported by CoreML. Overriding op support."
+                )
+                return True
+
         # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
         # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
         # # in the placeholders due to partitioning, which CoreML does not support

diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py
index 4b38806ee63..3bc4832550f 100644
--- a/backends/apple/coreml/test/test_coreml_partitioner.py
+++ b/backends/apple/coreml/test/test_coreml_partitioner.py
@@ -338,6 +338,51 @@ def forward(self, x):
             torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
         )

+    def test_argmax_argmin_dim_none_is_skipped(self):
+        """
+        Regression test for https://github.com/pytorch/executorch/issues/11715.
+
+        argmax/argmin with dim=None reduces over the flattened tensor, which
+        CoreML does not support; the resulting model intermittently crashes
+        the process at runtime. The partitioner must reject these so they
+        fall back to the portable backend, while still delegating the
+        ordinary dim=int form.
+        """
+
+        class FlatModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(
+                    x, dim=None
+                )
+
+        ep = torch.export.export(FlatModel().eval(), (torch.randn(10, 10),), strict=True)
+        edge = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[CoreMLPartitioner()]
+        )
+        op_names = [
+            n.target.__name__
+            for n in edge.exported_program().graph.nodes
+            if n.op == "call_function"
+        ]
+        self.assertIn("aten.argmax.default", op_names)
+        self.assertIn("aten.argmin.default", op_names)
+
+        class DimModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=1)
+
+        ep = torch.export.export(DimModel().eval(), (torch.randn(10, 10),), strict=True)
+        edge = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[CoreMLPartitioner()]
+        )
+        op_names = [
+            n.target.__name__
+            for n in edge.exported_program().graph.nodes
+            if n.op == "call_function"
+        ]
+        self.assertIn("executorch_call_delegate", op_names)
+        self.assertNotIn("aten.argmax.default", op_names)
+
     def test_deprecation_warning_for_to_backend_workflow(self):
         """
         Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.

From 2fe428f2536175f7b1bb541c95fd65eb48537449 Mon Sep 17 00:00:00 2001
From: john-rocky
Date: Fri, 1 May 2026 14:44:31 +0900
Subject: [PATCH 2/4] ci: trigger Meta CLA re-check after signature

From d950f597bdda3a502da4f4c9842917874722dd4e Mon Sep 17 00:00:00 2001
From: john-rocky
Date: Tue, 5 May 2026 05:36:29 +0900
Subject: [PATCH 3/4] review: factor argmin/argmax dim=None check into helper

Per @metascroy on #19247.
---
 .../coreml/partition/coreml_partitioner.py | 42 ++++++++++++------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index 622dc95973f..4b5d4a6de82 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -33,6 +33,26 @@
 logger.setLevel(get_coreml_log_level(default_level=logging.INFO))


+_ARG_MIN_MAX_TARGETS = (
+    torch.ops.aten.argmax.default,
+    torch.ops.aten.argmin.default,
+    exir_ops.edge.aten.argmax.default,
+    exir_ops.edge.aten.argmin.default,
+)
+
+
+def _is_arg_min_max_over_flattened_input(node: torch.fx.Node) -> bool:
+    """``argmin``/``argmax`` with ``dim=None`` reduces over the flattened input.
+
+    CoreML doesn't support that reduction shape and intermittently crashes
+    the process at runtime — see pytorch/executorch#11715.
+    """
+    if node.target not in _ARG_MIN_MAX_TARGETS:
+        return False
+    dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None)
+    return dim is None
+
+
 def _is_view_op(op: torch._ops.OpOverload) -> bool:
     schema = op._schema
     if len(schema.arguments) == 0:
@@ -116,22 +136,12 @@ def should_override_support(self, node) -> bool:
             )
             return True
-        # https://github.com/pytorch/executorch/issues/11715
-        # argmin/argmax with dim=None reduces over the flattened input, which
-        # CoreML does not support and causes intermittent process crashes.
-        if node.target in [
-            torch.ops.aten.argmax.default,
-            torch.ops.aten.argmin.default,
-            exir_ops.edge.aten.argmax.default,
-            exir_ops.edge.aten.argmin.default,
-        ]:
-            dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None)
-            if dim is None:
-                self.log_once(
-                    "torch.ops.aten.{argmax, argmin}.default with dim=None is "
-                    "not supported by CoreML. Overriding op support."
-                )
-                return True
+        if _is_arg_min_max_over_flattened_input(node):
+            self.log_once(
+                "torch.ops.aten.{argmax, argmin}.default with dim=None is "
+                "not supported by CoreML. Overriding op support."
+            )
+            return True

         # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
         # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
         # # in the placeholders due to partitioning, which CoreML does not support

From a01c061f92db20342b0e2239d7555201955effe0 Mon Sep 17 00:00:00 2001
From: john-rocky
Date: Wed, 6 May 2026 04:39:28 +0900
Subject: [PATCH 4/4] review: lintrunner -a (line wrap on test_coreml_partitioner.py)

---
 backends/apple/coreml/test/test_coreml_partitioner.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py
index 3bc4832550f..28695302bd9 100644
--- a/backends/apple/coreml/test/test_coreml_partitioner.py
+++ b/backends/apple/coreml/test/test_coreml_partitioner.py
@@ -355,7 +355,9 @@ def forward(self, x):
                     x, dim=None
                 )

-        ep = torch.export.export(FlatModel().eval(), (torch.randn(10, 10),), strict=True)
+        ep = torch.export.export(
+            FlatModel().eval(), (torch.randn(10, 10),), strict=True
+        )
         edge = executorch.exir.to_edge_transform_and_lower(
             ep, partitioner=[CoreMLPartitioner()]
         )
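
For reference, a minimal standalone sketch of the behavior the new check keys on. It is not part of the patches: it assumes only that torch is installed, the helper below mirrors _is_arg_min_max_over_flattened_input from patch 3 instead of importing ExecuTorch, and FlatDemo / has_flattened_arg_min_max are illustrative names that exist only in this example.

# Sketch: why dim=None is the special case, and how it can be detected on an
# exported graph. Assumes only `torch`; nothing here imports ExecuTorch.
import torch

x = torch.randn(10, 10)

# dim=None reduces over the flattened input: the result is a single 0-d index
# into x.reshape(-1), while dim=1 keeps one index per row.
print(torch.argmax(x).shape)         # torch.Size([])
print(torch.argmax(x, dim=1).shape)  # torch.Size([10])


def has_flattened_arg_min_max(gm: torch.fx.GraphModule) -> bool:
    """Return True if any argmin/argmax node in the graph omits `dim`."""
    targets = (torch.ops.aten.argmax.default, torch.ops.aten.argmin.default)
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in targets:
            dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim")
            if dim is None:
                return True
    return False


class FlatDemo(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=None)


ep = torch.export.export(FlatDemo().eval(), (x,), strict=True)
# Expected to print True: the graph carries the flattened form, which the
# patched partitioner leaves to the portable backend instead of CoreML.
print(has_flattened_arg_min_max(ep.graph_module))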