From 9cec903c0028da590b51fb58c04d196e7a26a723 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 14 May 2026 03:06:45 +0800 Subject: [PATCH] fix(tests): generate integer inputs portably --- tests/utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 982d05aec..3bfa8e93c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -75,9 +75,21 @@ def rand_strided(shape, strides, *, dtype=None, device=None): def randint_strided(low, high, shape, strides, *, dtype=None, device=None): output = empty_strided(shape, strides, dtype=dtype, device=device) - output.as_strided( + flat = output.as_strided( (output.untyped_storage().size() // output.element_size(),), (1,) - ).random_(low, high) + ) + + try: + flat.random_(low, high) + except RuntimeError as exc: + if "random_" not in str(exc) or "not implemented" not in str(exc): + raise + + values = torch.randint(low, high, flat.shape, dtype=torch.int64).to( + dtype=output.dtype, + device=output.device, + ) + flat.copy_(values) return output