diff --git a/backends/transforms/channels_last_ops.py b/backends/transforms/channels_last_ops.py new file mode 100644 index 00000000000..cbb182ccb02 --- /dev/null +++ b/backends/transforms/channels_last_ops.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""The ``channels_last`` operator dialect. + +Operators in this dialect interpret their activation input/output as channels-last +``(N, H, W, C)`` with contiguous strides and a fixed (identity) dim-order, as +opposed to the implicit dim-order handling used elsewhere. They let layout-handling +passes (see RFC #19299) make channels-last regions explicit in the graph. + +Efficiency is a non-goal: kernels are implemented as ``permute -> aten op -> permute``. +Importing this module registers the dialect. +""" + +import torch +from torch.library import Library, register_fake + +lib = Library("channels_last", "DEF") + + +def _conv( + input, weight, bias, stride, padding, dilation, transposed, output_padding, groups +): + nchw = input.permute(0, 3, 1, 2) + out = torch.ops.aten.convolution( + nchw, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + return out.permute(0, 2, 3, 1).contiguous() + + +def _avg_pool2d( + input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override +): + nchw = input.permute(0, 3, 1, 2) + out = torch.ops.aten.avg_pool2d( + nchw, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + return out.permute(0, 2, 3, 1).contiguous() + + +def _permute_copy(input, dims): + return torch.ops.aten.permute_copy(input, dims).contiguous() + + +lib.define( + "convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, " + "int[] padding, int[] dilation, bool transposed, int[] output_padding, " + "int groups) -> Tensor" +) +lib.impl("convolution", _conv, "CompositeExplicitAutograd") +register_fake("channels_last::convolution", _conv, lib=lib) + +lib.define( + "avg_pool2d(Tensor input, int[2] kernel_size, int[2] stride, int[2] padding, " + "bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor" +) +lib.impl("avg_pool2d", _avg_pool2d, "CompositeExplicitAutograd") +register_fake("channels_last::avg_pool2d", _avg_pool2d, lib=lib) + +lib.define("permute_copy(Tensor input, int[] dims) -> Tensor") +lib.impl("permute_copy", _permute_copy, "CompositeExplicitAutograd") +register_fake("channels_last::permute_copy", _permute_copy, lib=lib) diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 36466ec4aa0..d978e601359 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -191,6 +191,19 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "channels_last_ops", + srcs = [ + "channels_last_ops.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + ], + ) + runtime.python_library( name = "rank_0_to_rank_1", srcs = [ @@ -269,6 +282,19 @@ def define_common_targets(): ], ) + runtime.python_test( + name = "test_channels_last_ops", + srcs = [ + "test/test_channels_last_ops.py", + ], + deps = [ + "//caffe2:torch", + ":channels_last_ops", + "//executorch/exir:lib", + "fbsource//third-party/pypi/pytest:pytest", + ], + ) + runtime.python_test( name = "test_rank_0_to_rank_1", srcs = [ diff --git a/backends/transforms/test/test_channels_last_ops.py b/backends/transforms/test/test_channels_last_ops.py new file mode 100644 index 00000000000..b06a2773c59 --- /dev/null +++ b/backends/transforms/test/test_channels_last_ops.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Importing the module registers the channels_last operator dialect. +import executorch.backends.transforms.channels_last_ops # noqa: F401 +import pytest +import torch +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops + + +def _to_nhwc(nchw: torch.Tensor) -> torch.Tensor: + return nchw.permute(0, 2, 3, 1).contiguous() + + +def _find(graph_module: torch.fx.GraphModule, target): + nodes = [ + n + for n in graph_module.graph.nodes + if n.op == "call_function" and n.target == target + ] + assert len(nodes) == 1, f"expected exactly one {target}, found {len(nodes)}" + return nodes[0] + + +_CONV_CASES = [ + # (N, C_in, H, W, C_out, kernel, stride, padding, dilation, groups, bias) + (2, 3, 8, 8, 4, 3, 1, 0, 1, 1, True), + (2, 3, 8, 8, 4, 3, 1, 0, 1, 1, False), + (1, 4, 10, 10, 6, 3, 2, 1, 1, 1, True), + (1, 4, 7, 7, 4, 3, 1, 1, 1, 4, True), # depthwise (groups == C_in == C_out) +] + + +@pytest.mark.parametrize("n,cin,h,w,cout,k,stride,pad,dil,groups,bias", _CONV_CASES) +def test_convolution_matches_aten( + n, cin, h, w, cout, k, stride, pad, dil, groups, bias +): + torch.manual_seed(0) + nchw = torch.randn(n, cin, h, w) + weight = torch.randn(cout, cin // groups, k, k) + bias_t = torch.randn(cout) if bias else None + nhwc = _to_nhwc(nchw) + + expected = _to_nhwc( + torch.ops.aten.convolution( + nchw, + weight, + bias_t, + [stride, stride], + [pad, pad], + [dil, dil], + False, + [0, 0], + groups, + ) + ) + actual = torch.ops.channels_last.convolution( + nhwc, + weight, + bias_t, + [stride, stride], + [pad, pad], + [dil, dil], + False, + [0, 0], + groups, + ) + + assert actual.shape == expected.shape + assert torch.allclose(actual, expected, atol=1e-5) + + +@pytest.mark.parametrize( + "kernel,stride,pad,ceil_mode,count_include_pad", + [ + (2, 2, 0, False, True), + (3, 1, 1, False, True), + (3, 2, 1, True, False), + ], +) +def test_avg_pool2d_matches_aten(kernel, stride, pad, ceil_mode, count_include_pad): + torch.manual_seed(0) + nchw = torch.randn(2, 3, 9, 9) + nhwc = _to_nhwc(nchw) + + expected = _to_nhwc( + torch.ops.aten.avg_pool2d( + nchw, + [kernel, kernel], + [stride, stride], + [pad, pad], + ceil_mode, + count_include_pad, + None, + ) + ) + actual = torch.ops.channels_last.avg_pool2d( + nhwc, + [kernel, kernel], + [stride, stride], + [pad, pad], + ceil_mode, + count_include_pad, + None, + ) + + assert actual.shape == expected.shape + assert torch.allclose(actual, expected, atol=1e-5) + + +@pytest.mark.parametrize("dims", [(0, 3, 1, 2), (0, 2, 3, 1), (3, 2, 1, 0)]) +def test_permute_copy_moves_data(dims): + torch.manual_seed(0) + x = torch.randn(2, 4, 5, 3) + + expected = torch.ops.aten.permute_copy(x, list(dims)) + actual = torch.ops.channels_last.permute_copy(x, list(dims)) + + assert actual.shape == expected.shape + assert torch.equal(actual, expected) + assert actual.is_contiguous() + + +def test_convolution_lowers_to_edge_dialect(): + class M(torch.nn.Module): + def forward(self, x, w, b): + return torch.ops.channels_last.convolution( + x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1 + ) + + nhwc = torch.randn(2, 8, 8, 3) + weight = torch.randn(4, 3, 3, 3) + bias = torch.randn(4) + + ep = torch.export.export(M().eval(), (nhwc, weight, bias), strict=True) + edge = to_edge(ep) + + node = _find( + edge.exported_program().graph_module, + exir_ops.edge.channels_last.convolution.default, + ) + # Fake kernel must yield the correct NHWC output shape (N, H_out, W_out, C_out). + assert tuple(node.meta["val"].shape) == (2, 6, 6, 4)