diff --git a/src/uint/boxed/mul_mod.rs b/src/uint/boxed/mul_mod.rs index 9e5adccd..e965a3be 100644 --- a/src/uint/boxed/mul_mod.rs +++ b/src/uint/boxed/mul_mod.rs @@ -22,6 +22,10 @@ impl BoxedUint { pub fn mul_mod_special(&self, rhs: &Self, c: Limb) -> Self { debug_assert_eq!(self.bits_precision(), rhs.bits_precision()); + if c.is_zero().into() { + return self.wrapping_mul(rhs); + } + if self.nlimbs() == 1 { let reduced = mul_rem( self.limbs[0], @@ -93,7 +97,7 @@ fn mac_by_limb(a: &UintRef, b: &UintRef, c: Limb, carry: Limb) -> (BoxedUint, Li #[cfg(all(test, feature = "rand_core"))] mod tests { - use crate::{BoxedUint, ConcatenatingMul, Limb, NonZero, Random, RandomMod}; + use crate::{BoxedUint, ConcatenatingMul, Limb, NonZero, Random, RandomMod, Resize}; use rand_core::SeedableRng; #[test] @@ -148,4 +152,18 @@ mod tests { } } } + + #[test] + fn mul_mod_special_zero_c_is_wrapping_multiplication() { + for bits in [Limb::BITS, 2 * Limb::BITS, 4 * Limb::BITS] { + let a = BoxedUint::from(0x1234_5678u32).resize(bits); + let b = BoxedUint::from(0xfedc_ba91u32).resize(bits); + + assert_eq!( + a.mul_mod_special(&b, Limb::ZERO), + a.wrapping_mul(&b), + "c = 0 represents the power-of-two modulus" + ); + } + } } diff --git a/src/uint/mul_mod.rs b/src/uint/mul_mod.rs index 58e3a1a9..24d32225 100644 --- a/src/uint/mul_mod.rs +++ b/src/uint/mul_mod.rs @@ -25,6 +25,10 @@ impl Uint { /// and S. Vanstone, CRC Press, 1996. #[must_use] pub const fn mul_mod_special(&self, rhs: &Self, c: Limb) -> Self { + if c.is_zero().to_bool_vartime() { + return self.wrapping_mul(rhs); + } + // We implicitly assume `LIMBS > 0`, because `Uint<0>` doesn't compile. // Still the case `LIMBS == 1` needs special handling. if LIMBS == 1 { @@ -163,4 +167,25 @@ mod tests { test_size::<16>(); } } + + #[test] + fn mul_mod_special_zero_c_is_wrapping_multiplication() { + let a = Uint::<1>::from_u32(0x1234_5678); + let b = Uint::<1>::from_u32(0xfedc_ba91); + + assert_eq!( + a.mul_mod_special(&b, Limb::ZERO), + a.wrapping_mul(&b), + "c = 0 represents the power-of-two modulus" + ); + + let a = Uint::<2>::from_u32(0x1234_5678); + let b = Uint::<2>::from_u32(0xfedc_ba91); + + assert_eq!( + a.mul_mod_special(&b, Limb::ZERO), + a.wrapping_mul(&b), + "c = 0 represents the power-of-two modulus" + ); + } }