diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index 26b97e5a7a2..b9cbab4c1ef 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include // std::array #include // PRId64 @@ -932,7 +933,13 @@ inline size_t getLeadingDims( ssize_t(tensor.dim())); size_t dims = 1; for (const auto i : c10::irange(dim)) { - dims *= static_cast(tensor.size(i)); + size_t next_dims; + ET_CHECK_MSG( + !c10::mul_overflows( + dims, static_cast(tensor.size(i)), &next_dims), + "Overflow computing leading dims at dimension %zd", + (ssize_t)i); + dims = next_dims; } return dims; } @@ -949,7 +956,13 @@ inline size_t getTrailingDims( ssize_t(tensor.dim())); size_t dims = 1; for (size_t i = dim + 1; i < static_cast(tensor.dim()); ++i) { - dims *= static_cast(tensor.size(i)); + size_t next_dims; + ET_CHECK_MSG( + !c10::mul_overflows( + dims, static_cast(tensor.size(i)), &next_dims), + "Overflow computing trailing dims at dimension %zu", + i); + dims = next_dims; } return dims; }