From a53f5e18966be1655a669244a118a4262760005d Mon Sep 17 00:00:00 2001 From: Kristof Roomp Date: Tue, 19 Nov 2024 19:38:38 +0100 Subject: [PATCH] Added widening multiply for u16x8, u32x4, u32x8, i32x4 and i32x8 (#182) * add widening unsigned mul * added wasm * fix ordering of lanes in neon * try again * fix for avx2 * improve tests * added plain widen_mul * add mul_keep_high * fix dumb stuff * added comments * make u32x8 version generic * optimized avx2 of mul_keep_high * make calling convention more consistant * fix arm * comments * add WASM simd implementations... turns out they are easy --- src/f64x2_.rs | 6 +- src/i32x4_.rs | 56 ++++++++++++++-- src/i32x8_.rs | 8 +-- src/i64x4_.rs | 4 +- src/i8x16_.rs | 6 +- src/i8x32_.rs | 27 ++++---- src/u16x8_.rs | 38 +++++++++++ src/u32x4_.rs | 109 ++++++++++++++++++++++++++++++-- src/u32x8_.rs | 33 ++++++++-- tests/all_tests/main.rs | 3 +- tests/all_tests/t_i16x8.rs | 23 ++++++- tests/all_tests/t_i32x4.rs | 19 ++++++ tests/all_tests/t_u16x8.rs | 25 ++++++++ tests/all_tests/t_u32x4.rs | 40 ++++++++++++ tests/all_tests/t_u32x8.rs | 8 +++ tests/all_tests/t_usefulness.rs | 17 +---- 16 files changed, 367 insertions(+), 55 deletions(-) diff --git a/src/f64x2_.rs b/src/f64x2_.rs index 429fadf..f300bb9 100644 --- a/src/f64x2_.rs +++ b/src/f64x2_.rs @@ -1598,7 +1598,8 @@ impl f64x2 { cast_mut(self) } - /// Converts the lower two `i32` lanes to two `f64` lanes (and dropping the higher two `i32` lanes) + /// Converts the lower two `i32` lanes to two `f64` lanes (and dropping the + /// higher two `i32` lanes) #[inline] pub fn from_i32x4_lower2(v: i32x4) -> Self { pick! { @@ -1619,7 +1620,8 @@ impl f64x2 { } impl From for f64x2 { - /// Converts the lower two `i32` lanes to two `f64` lanes (and dropping the higher two `i32` lanes) + /// Converts the lower two `i32` lanes to two `f64` lanes (and dropping the + /// higher two `i32` lanes) #[inline] fn from(v: i32x4) -> Self { Self::from_i32x4_lower2(v) diff --git a/src/i32x4_.rs b/src/i32x4_.rs index 0a6c238..e8053bc 100644 --- a/src/i32x4_.rs +++ b/src/i32x4_.rs @@ -327,8 +327,8 @@ impl_shr_t_for_i32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128); /// Shifts lanes by the corresponding lane. /// /// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any -/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of -/// the type. (same as `wrapping_shr`) +/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth +/// of the type. (same as `wrapping_shr`) impl Shr for i32x4 { type Output = Self; @@ -364,8 +364,8 @@ impl Shr for i32x4 { /// Shifts lanes by the corresponding lane. /// /// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any -/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of -/// the type. (same as `wrapping_shl`) +/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth +/// of the type. (same as `wrapping_shl`) impl Shl for i32x4 { type Output = Self; @@ -490,6 +490,54 @@ impl i32x4 { } } } + + /// Multiplies corresponding 32 bit lanes and returns the 64 bit result + /// on the corresponding lanes. + /// + /// Effectively does two multiplies on 128 bit platforms, but is easier + /// to use than wrapping mul_widen_i32_odd_m128i individually. + #[inline] + #[must_use] + pub fn mul_widen(self, rhs: Self) -> i64x4 { + pick! 
{ + if #[cfg(target_feature="avx2")] { + let a = convert_to_i64_m256i_from_i32_m128i(self.sse); + let b = convert_to_i64_m256i_from_i32_m128i(rhs.sse); + cast(mul_i64_low_bits_m256i(a, b)) + } else if #[cfg(target_feature="sse4.1")] { + let evenp = mul_widen_i32_odd_m128i(self.sse, rhs.sse); + + let oddp = mul_widen_i32_odd_m128i( + shr_imm_u64_m128i::<32>(self.sse), + shr_imm_u64_m128i::<32>(rhs.sse)); + + i64x4 { + a: i64x2 { sse: unpack_low_i64_m128i(evenp, oddp)}, + b: i64x2 { sse: unpack_high_i64_m128i(evenp, oddp)} + } + } else if #[cfg(target_feature="simd128")] { + i64x4 { + a: i64x2 { simd: i64x2_extmul_low_i32x4(self.simd, rhs.simd) }, + b: i64x2 { simd: i64x2_extmul_high_i32x4(self.simd, rhs.simd) }, + } + } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] { + unsafe { + i64x4 { a: i64x2 { neon: vmull_s32(vget_low_s32(self.neon), vget_low_s32(rhs.neon)) }, + b: i64x2 { neon: vmull_s32(vget_high_s32(self.neon), vget_high_s32(rhs.neon)) } } + } + } else { + let a: [i32; 4] = cast(self); + let b: [i32; 4] = cast(rhs); + cast([ + i64::from(a[0]) * i64::from(b[0]), + i64::from(a[1]) * i64::from(b[1]), + i64::from(a[2]) * i64::from(b[2]), + i64::from(a[3]) * i64::from(b[3]), + ]) + } + } + } + #[inline] #[must_use] pub fn abs(self) -> Self { diff --git a/src/i32x8_.rs b/src/i32x8_.rs index 30835da..8297d7b 100644 --- a/src/i32x8_.rs +++ b/src/i32x8_.rs @@ -232,8 +232,8 @@ impl_shr_t_for_i32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128); /// Shifts lanes by the corresponding lane. /// /// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any -/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of -/// the type. (same as `wrapping_shr`) +/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth +/// of the type. (same as `wrapping_shr`) impl Shr for i32x8 { type Output = Self; @@ -258,8 +258,8 @@ impl Shr for i32x8 { /// Shifts lanes by the corresponding lane. /// /// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any -/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of -/// the type. (same as `wrapping_shl`) +/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth +/// of the type. (same as `wrapping_shl`) impl Shl for i32x8 { type Output = Self; diff --git a/src/i64x4_.rs b/src/i64x4_.rs index d9f6636..faaf038 100644 --- a/src/i64x4_.rs +++ b/src/i64x4_.rs @@ -4,11 +4,11 @@ pick! { if #[cfg(target_feature="avx2")] { #[derive(Default, Clone, Copy, PartialEq, Eq)] #[repr(C, align(32))] - pub struct i64x4 { avx2: m256i } + pub struct i64x4 { pub(crate) avx2: m256i } } else { #[derive(Default, Clone, Copy, PartialEq, Eq)] #[repr(C, align(32))] - pub struct i64x4 { a : i64x2, b : i64x2 } + pub struct i64x4 { pub(crate) a : i64x2, pub(crate) b : i64x2 } } } diff --git a/src/i8x16_.rs b/src/i8x16_.rs index 504956c..023b382 100644 --- a/src/i8x16_.rs +++ b/src/i8x16_.rs @@ -700,7 +700,8 @@ impl i8x16 { /// `rhs`. /// /// * Index values in the range `[0, 15]` select the i-th element of `self`. - /// * Index values that are out of range will cause that output lane to be `0`. + /// * Index values that are out of range will cause that output lane to be + /// `0`. #[inline] pub fn swizzle(self, rhs: i8x16) -> i8x16 { pick! 
{
@@ -727,7 +728,8 @@ impl i8x16 {
     }
   }
 
-  /// Works like [`swizzle`](Self::swizzle) with the following additional details
+  /// Works like [`swizzle`](Self::swizzle) with the following additional
+  /// details
   ///
   /// * Indices in the range `[0, 15]` will select the i-th element of `self`.
   /// * If the high bit of any index is set (meaning that the index is
diff --git a/src/i8x32_.rs b/src/i8x32_.rs
index b2b72b0..ebc1034 100644
--- a/src/i8x32_.rs
+++ b/src/i8x32_.rs
@@ -331,13 +331,14 @@ impl i8x32 {
     !self.any()
   }
 
-  /// Returns a new vector with lanes selected from the lanes of the first input vector
-  /// a specified in the second input vector `rhs`.
-  /// The indices i in range `[0, 15]` select the i-th element of `self`. For indices
-  /// outside of the range the resulting lane is `0`.
+  /// Returns a new vector with lanes selected from the lanes of the first input
+  /// vector as specified in the second input vector `rhs`.
+  /// The indices i in range `[0, 15]` select the i-th element of `self`. For
+  /// indices outside of the range the resulting lane is `0`.
   ///
-  /// This note that is the equivalent of two parallel swizzle operations on the two halves of the vector,
-  /// and the indexes each refer to the corresponding half.
+  /// Note that this is the equivalent of two parallel swizzle operations on the
+  /// two halves of the vector, and the indexes each refer to the
+  /// corresponding half.
   #[inline]
   pub fn swizzle_half(self, rhs: i8x32) -> i8x32 {
     pick! {
@@ -352,13 +353,15 @@ impl i8x32 {
     }
   }
 
-  /// Indices in the range `[0, 15]` will select the i-th element of `self`. If the high bit
-  /// of any element of `rhs` is set (negative) then the corresponding output
-  /// lane is guaranteed to be zero. Otherwise if the element of `rhs` is within the range `[32, 127]`
-  /// then the output lane is either `0` or `self[rhs[i] % 16]` depending on the implementation.
+  /// Indices in the range `[0, 15]` will select the i-th element of `self`. If
+  /// the high bit of any element of `rhs` is set (negative) then the
+  /// corresponding output lane is guaranteed to be zero. Otherwise if the
+  /// element of `rhs` is within the range `[32, 127]` then the output lane is
+  /// either `0` or `self[rhs[i] % 16]` depending on the implementation.
   ///
-  /// This is the equivalent to two parallel swizzle operations on the two halves of the vector,
-  /// and the indexes each refer to their corresponding half.
+  /// This is equivalent to two parallel swizzle operations on the two
+  /// halves of the vector, and the indexes each refer to their corresponding
+  /// half.
   #[inline]
   pub fn swizzle_half_relaxed(self, rhs: i8x32) -> i8x32 {
     pick! {
diff --git a/src/u16x8_.rs b/src/u16x8_.rs
index c1e0d64..64931c4 100644
--- a/src/u16x8_.rs
+++ b/src/u16x8_.rs
@@ -591,6 +591,44 @@ impl u16x8 {
     }
   }
 
+  /// Multiplies two `u16x8` and returns the high part of the intermediate `u32x8`
+  #[inline]
+  #[must_use]
+  pub fn mul_keep_high(self, rhs: Self) -> Self {
+    pick! {
+      if #[cfg(target_feature="sse2")] {
+        Self { sse: mul_u16_keep_high_m128i(self.sse, rhs.sse) }
+      } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
+        let lhs_low = unsafe { vget_low_u16(self.neon) };
+        let rhs_low = unsafe { vget_low_u16(rhs.neon) };
+
+        let lhs_high = unsafe { vget_high_u16(self.neon) };
+        let rhs_high = unsafe { vget_high_u16(rhs.neon) };
+
+        let low = unsafe { vmull_u16(lhs_low, rhs_low) };
+        let high = unsafe { vmull_u16(lhs_high, rhs_high) };
+
+        u16x8 { neon: unsafe { vuzpq_u16(vreinterpretq_u16_u32(low), vreinterpretq_u16_u32(high)).1 } }
+      } else if #[cfg(target_feature="simd128")] {
+        let low = u32x4_extmul_low_u16x8(self.simd, rhs.simd);
+        let high = u32x4_extmul_high_u16x8(self.simd, rhs.simd);
+
+        Self { simd: u16x8_shuffle::<1, 3, 5, 7, 9, 11, 13, 15>(low, high) }
+      } else {
+        u16x8::new([
+          ((u32::from(rhs.as_array_ref()[0]) * u32::from(self.as_array_ref()[0])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[1]) * u32::from(self.as_array_ref()[1])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[2]) * u32::from(self.as_array_ref()[2])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[3]) * u32::from(self.as_array_ref()[3])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[4]) * u32::from(self.as_array_ref()[4])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[5]) * u32::from(self.as_array_ref()[5])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[6]) * u32::from(self.as_array_ref()[6])) >> 16) as u16,
+          ((u32::from(rhs.as_array_ref()[7]) * u32::from(self.as_array_ref()[7])) >> 16) as u16,
+        ])
+      }
+    }
+  }
+
   #[inline]
   pub fn to_array(self) -> [u16; 8] {
     cast(self)
diff --git a/src/u32x4_.rs b/src/u32x4_.rs
index e85b0a8..c860545 100644
--- a/src/u32x4_.rs
+++ b/src/u32x4_.rs
@@ -327,8 +327,8 @@ impl_shr_t_for_u32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
 /// Shifts lanes by the corresponding lane.
 ///
 /// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any
-/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
-/// the type. (same as `wrapping_shr`)
+/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
+/// of the type. (same as `wrapping_shr`)
 impl Shr<u32x4> for u32x4 {
   type Output = Self;
   #[inline]
@@ -363,8 +363,8 @@ impl Shr<u32x4> for u32x4 {
 /// Shifts lanes by the corresponding lane.
 ///
 /// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any
-/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
-/// the type. (same as `wrapping_shl`)
+/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
+/// of the type. (same as `wrapping_shl`)
 impl Shl<u32x4> for u32x4 {
   type Output = Self;
   #[inline]
@@ -431,7 +431,7 @@ impl u32x4 {
       Self { sse: cmp_gt_mask_i32_m128i((self ^ h).sse, (rhs ^ h).sse) }
     } else if #[cfg(target_feature="simd128")] {
       Self { simd: u32x4_gt(self.simd, rhs.simd) }
-    } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
+    } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
      unsafe {Self { neon: vcgtq_u32(self.neon, rhs.neon) }}
     } else {
       Self { arr: [
@@ -450,6 +450,105 @@ impl u32x4 {
     rhs.cmp_gt(self)
   }
 
+  /// Multiplies 32x32 bits to a 64-bit intermediate and then only keeps the
+  /// high 32 bits of the result. Useful for implementing division by a
+  /// constant value (see the t_usefulness example).
+  #[inline]
+  #[must_use]
+  pub fn mul_keep_high(self, rhs: Self) -> Self {
+    pick! 
{ + if #[cfg(target_feature="avx2")] { + let a = convert_to_i64_m256i_from_u32_m128i(self.sse); + let b = convert_to_i64_m256i_from_u32_m128i(rhs.sse); + let r = mul_u64_low_bits_m256i(a, b); + + // the compiler does a good job shuffling the lanes around + let b : [u32;8] = cast(r); + cast([b[1],b[3],b[5],b[7]]) + } else if #[cfg(target_feature="sse2")] { + let evenp = mul_widen_u32_odd_m128i(self.sse, rhs.sse); + + let oddp = mul_widen_u32_odd_m128i( + shr_imm_u64_m128i::<32>(self.sse), + shr_imm_u64_m128i::<32>(rhs.sse)); + + // the compiler does a good job shuffling the lanes around + let a : [u32;4]= cast(evenp); + let b : [u32;4]= cast(oddp); + cast([a[1],b[1],a[3],b[3]]) + + } else if #[cfg(target_feature="simd128")] { + let low = u64x2_extmul_low_u32x4(self.simd, rhs.simd); + let high = u64x2_extmul_high_u32x4(self.simd, rhs.simd); + + Self { simd: u32x4_shuffle::<1, 3, 5, 7>(low, high) } + } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] { + unsafe { + let l = vmull_u32(vget_low_u32(self.neon), vget_low_u32(rhs.neon)); + let h = vmull_u32(vget_high_u32(self.neon), vget_high_u32(rhs.neon)); + u32x4 { neon: vcombine_u32(vshrn_n_u64(l,32), vshrn_n_u64(h,32)) } + } + } else { + let a: [u32; 4] = cast(self); + let b: [u32; 4] = cast(rhs); + cast([ + ((u64::from(a[0]) * u64::from(b[0])) >> 32) as u32, + ((u64::from(a[1]) * u64::from(b[1])) >> 32) as u32, + ((u64::from(a[2]) * u64::from(b[2])) >> 32) as u32, + ((u64::from(a[3]) * u64::from(b[3])) >> 32) as u32, + ]) + } + } + } + + /// Multiplies corresponding 32 bit lanes and returns the 64 bit result + /// on the corresponding lanes. + /// + /// Effectively does two multiplies on 128 bit platforms, but is easier + /// to use than wrapping mul_widen_u32_odd_m128i individually. + #[inline] + #[must_use] + pub fn mul_widen(self, rhs: Self) -> u64x4 { + pick! { + if #[cfg(target_feature="avx2")] { + // ok to sign extend since we are throwing away the high half of the result anyway + let a = convert_to_i64_m256i_from_i32_m128i(self.sse); + let b = convert_to_i64_m256i_from_i32_m128i(rhs.sse); + cast(mul_u64_low_bits_m256i(a, b)) + } else if #[cfg(target_feature="sse2")] { + let evenp = mul_widen_u32_odd_m128i(self.sse, rhs.sse); + + let oddp = mul_widen_u32_odd_m128i( + shr_imm_u64_m128i::<32>(self.sse), + shr_imm_u64_m128i::<32>(rhs.sse)); + + u64x4 { + a: u64x2 { sse: unpack_low_i64_m128i(evenp, oddp)}, + b: u64x2 { sse: unpack_high_i64_m128i(evenp, oddp)} + } + } else if #[cfg(target_feature="simd128")] { + u64x4 { + a: u64x2 { simd: u64x2_extmul_low_u32x4(self.simd, rhs.simd) }, + b: u64x2 { simd: u64x2_extmul_high_u32x4(self.simd, rhs.simd) }, + } + } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] { + unsafe { + u64x4 { a: u64x2 { neon: vmull_u32(vget_low_u32(self.neon), vget_low_u32(rhs.neon)) }, + b: u64x2 { neon: vmull_u32(vget_high_u32(self.neon), vget_high_u32(rhs.neon)) } } + } + } else { + let a: [u32; 4] = cast(self); + let b: [u32; 4] = cast(rhs); + cast([ + u64::from(a[0]) * u64::from(b[0]), + u64::from(a[1]) * u64::from(b[1]), + u64::from(a[2]) * u64::from(b[2]), + u64::from(a[3]) * u64::from(b[3]), + ]) + } + } + } + #[inline] #[must_use] pub fn blend(self, t: Self, f: Self) -> Self { diff --git a/src/u32x8_.rs b/src/u32x8_.rs index 6def6df..376f8be 100644 --- a/src/u32x8_.rs +++ b/src/u32x8_.rs @@ -208,8 +208,8 @@ impl_shr_t_for_u32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128); /// Shifts lanes by the corresponding lane. 
/// /// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any -/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of -/// the type. (same as `wrapping_shr`) +/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth +/// of the type. (same as `wrapping_shr`) impl Shr for u32x8 { type Output = Self; @@ -234,8 +234,8 @@ impl Shr for u32x8 { /// Shifts lanes by the corresponding lane. /// /// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any -/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of -/// the type. (same as `wrapping_shl`) +/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth +/// of the type. (same as `wrapping_shl`) impl Shl for u32x8 { type Output = Self; @@ -301,6 +301,31 @@ impl u32x8 { rhs.cmp_gt(self) } + /// Multiplies 32x32 bit to 64 bit and then only keeps the high 32 bits of the + /// result. Useful for implementing divide constant value (see t_usefulness + /// example) + #[inline] + #[must_use] + pub fn mul_keep_high(self, rhs: u32x8) -> u32x8 { + pick! { + if #[cfg(target_feature="avx2")] { + let a : [u32;8]= cast(self); + let b : [u32;8]= cast(rhs); + + // let the compiler shuffle the values around, it does the right thing + let r1 : [u32;8] = cast(mul_u64_low_bits_m256i(cast([a[0], 0, a[1], 0, a[2], 0, a[3], 0]), cast([b[0], 0, b[1], 0, b[2], 0, b[3], 0]))); + let r2 : [u32;8] = cast(mul_u64_low_bits_m256i(cast([a[4], 0, a[5], 0, a[6], 0, a[7], 0]), cast([b[4], 0, b[5], 0, b[6], 0, b[7], 0]))); + + cast([r1[1], r1[3], r1[5], r1[7], r2[1], r2[3], r2[5], r2[7]]) + } else { + Self { + a : self.a.mul_keep_high(rhs.a), + b : self.b.mul_keep_high(rhs.b), + } + } + } + } + #[inline] #[must_use] pub fn blend(self, t: Self, f: Self) -> Self { diff --git a/tests/all_tests/main.rs b/tests/all_tests/main.rs index 81833f7..98fa169 100644 --- a/tests/all_tests/main.rs +++ b/tests/all_tests/main.rs @@ -32,7 +32,8 @@ fn next_rand_u64(state: &mut u64) -> u64 { const A: u64 = 6364136223846793005; const C: u64 = 1442695040888963407; - // Update the state and calculate the next number (rotate to avoid lack of randomness in low bits) + // Update the state and calculate the next number (rotate to avoid lack of + // randomness in low bits) *state = state.wrapping_mul(A).wrapping_add(C).rotate_left(31); *state diff --git a/tests/all_tests/t_i16x8.rs b/tests/all_tests/t_i16x8.rs index f6a1ec0..5f8375d 100644 --- a/tests/all_tests/t_i16x8.rs +++ b/tests/all_tests/t_i16x8.rs @@ -369,10 +369,27 @@ fn impl_i16x8_reduce_max() { #[test] fn impl_mul_keep_high() { - let a = i16x8::from([1, 200, 300, 4568, -1, -2, -3, -4]); - let b = i16x8::from([5, 600, 700, 8910, -15, -26, -37, 48]); + let a = i16x8::from([i16::MAX, 200, 300, 4568, -1, -2, -3, -4]); + let b = i16x8::from([i16::MIN, 600, 700, 8910, -15, -26, -37, 48]); let c: [i16; 8] = i16x8::mul_keep_high(a, b).into(); - assert_eq!(c, [0, 1, 3, 621, 0, 0, 0, -1]); + assert_eq!( + c, + [ + (i32::from(i16::MAX) * i32::from(i16::MIN) >> 16) as i16, + 1, + 3, + 621, + 0, + 0, + 0, + -1 + ] + ); + + crate::test_random_vector_vs_scalar( + |a: i16x8, b| i16x8::mul_keep_high(a, b), + |a, b| ((i32::from(a) * i32::from(b)) >> 16) as i16, + ); } #[test] diff --git a/tests/all_tests/t_i32x4.rs b/tests/all_tests/t_i32x4.rs index fe839b0..3f588d6 100644 --- a/tests/all_tests/t_i32x4.rs +++ b/tests/all_tests/t_i32x4.rs @@ -266,3 +266,22 @@ fn impl_i32x4_shl_each() { |a, b| a.wrapping_shl(b as u32), ); } + +#[test] +fn 
impl_i32x4_mul_widen() { + let a = i32x4::from([1, 2, 3 * -1000000, i32::MAX]); + let b = i32x4::from([5, 6, 7 * -1000000, i32::MIN]); + let expected = i64x4::from([ + 1 * 5, + 2 * 6, + 3 * 7 * 1000000 * 1000000, + i32::MIN as i64 * i32::MAX as i64, + ]); + let actual = a.mul_widen(b); + assert_eq!(expected, actual); + + crate::test_random_vector_vs_scalar( + |a: i32x4, b| a.mul_widen(b), + |a, b| a as i64 * b as i64, + ); +} diff --git a/tests/all_tests/t_u16x8.rs b/tests/all_tests/t_u16x8.rs index 81b1358..69ef206 100644 --- a/tests/all_tests/t_u16x8.rs +++ b/tests/all_tests/t_u16x8.rs @@ -218,6 +218,31 @@ fn impl_u16x8_from_u8x16_high() { assert_eq!(expected, actual); } +#[test] +fn impl_u16x8_mul_keep_high() { + let a = u16x8::from([u16::MAX, 200, 300, 4568, 1, 2, 3, 200]); + let b = u16x8::from([u16::MAX, 600, 700, 8910, 15, 26, 37, 600]); + let c: [u16; 8] = u16x8::mul_keep_high(a, b).into(); + assert_eq!( + c, + [ + (u32::from(u16::MAX) * u32::from(u16::MAX) >> 16) as u16, + 1, + 3, + 621, + 0, + 0, + 0, + 1 + ] + ); + + crate::test_random_vector_vs_scalar( + |a: u16x8, b| u16x8::mul_keep_high(a, b), + |a, b| ((u32::from(a) * u32::from(b)) >> 16) as u16, + ); +} + #[test] fn impl_u16x8_mul_widen() { let a = u16x8::from([1, 2, 3, 4, 5, 6, i16::MAX as u16, u16::MAX]); diff --git a/tests/all_tests/t_u32x4.rs b/tests/all_tests/t_u32x4.rs index 1f10c31..f943fd7 100644 --- a/tests/all_tests/t_u32x4.rs +++ b/tests/all_tests/t_u32x4.rs @@ -1,4 +1,5 @@ use std::num::Wrapping; + use wide::*; #[test] @@ -234,3 +235,42 @@ fn test_u32x4_none() { let a = u32x4::from([0; 4]); assert!(a.none()); } + +#[test] +fn impl_u32x4_mul_widen() { + let a = u32x4::from([1, 2, 3 * 1000000, u32::MAX]); + let b = u32x4::from([5, 6, 7 * 1000000, u32::MAX]); + let expected = u64x4::from([ + 1 * 5, + 2 * 6, + 3 * 7 * 1000000 * 1000000, + u32::MAX as u64 * u32::MAX as u64, + ]); + let actual = a.mul_widen(b); + assert_eq!(expected, actual); + + crate::test_random_vector_vs_scalar( + |a: u32x4, b| a.mul_widen(b), + |a, b| u64::from(a) * u64::from(b), + ); +} + +#[test] +fn impl_u32x4_mul_keep_high() { + let mul_high = |a: u32, b: u32| ((u64::from(a) * u64::from(b)) >> 32) as u32; + let a = u32x4::from([1, 2 * 10000000, 3 * 1000000, u32::MAX]); + let b = u32x4::from([5, 6 * 100, 7 * 1000000, u32::MAX]); + let expected = u32x4::from([ + mul_high(1, 5), + mul_high(2 * 10000000, 6 * 100), + mul_high(3 * 1000000, 7 * 1000000), + mul_high(u32::MAX, u32::MAX), + ]); + let actual = a.mul_keep_high(b); + assert_eq!(expected, actual); + + crate::test_random_vector_vs_scalar( + |a: u32x4, b| a.mul_keep_high(b), + |a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32, + ); +} diff --git a/tests/all_tests/t_u32x8.rs b/tests/all_tests/t_u32x8.rs index da6599a..72a5d85 100644 --- a/tests/all_tests/t_u32x8.rs +++ b/tests/all_tests/t_u32x8.rs @@ -295,3 +295,11 @@ fn test_u32x8_none() { let a = u32x8::from([0; 8]); assert!(a.none()); } + +#[test] +fn impl_u32x8_mul_keep_high() { + crate::test_random_vector_vs_scalar( + |a: u32x8, b| u32x8::mul_keep_high(a, b), + |a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32, + ); +} diff --git a/tests/all_tests/t_usefulness.rs b/tests/all_tests/t_usefulness.rs index 311bce2..506e6d5 100644 --- a/tests/all_tests/t_usefulness.rs +++ b/tests/all_tests/t_usefulness.rs @@ -391,22 +391,7 @@ fn generate_branch_free_divide_magic_shift(denom: u32x8) -> (u32x8, u32x8) { // using the previously generated magic and shift, calculate the division fn branch_free_divide(numerator: u32x8, magic: u32x8, shift: 
u32x8) -> u32x8 { - // Returns 32 high bits of the 64 bit result of multiplication of two u32s - let mul_hi = |a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32; - - let a = numerator.as_array_ref(); - let b = magic.as_array_ref(); - - let q = u32x8::from([ - mul_hi(a[0], b[0]), - mul_hi(a[1], b[1]), - mul_hi(a[2], b[2]), - mul_hi(a[3], b[3]), - mul_hi(a[4], b[4]), - mul_hi(a[5], b[5]), - mul_hi(a[6], b[6]), - mul_hi(a[7], b[7]), - ]); + let q = u32x8::mul_keep_high(numerator, magic); let t = ((numerator - q) >> 1) + q; t >> shift
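
A minimal usage sketch of the widening multiply added above, written in the same
style as the tests in this patch. It assumes only the new `u32x4::mul_widen`
method and the existing `u64x4` type; it is an illustration, not part of the
diff.

  use wide::*;

  fn main() {
    // Each pair of u32 lanes is multiplied into a full 64-bit lane, so even
    // u32::MAX * u32::MAX cannot overflow.
    let a = u32x4::from([1, 2, 3_000_000, u32::MAX]);
    let b = u32x4::from([5, 6, 7_000_000, u32::MAX]);
    let p: u64x4 = a.mul_widen(b);
    assert_eq!(
      p,
      u64x4::from([5, 12, 21_000_000_000_000, u32::MAX as u64 * u32::MAX as u64])
    );
  }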
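
Similarly, a small sketch of `mul_keep_high`, the primitive that the
branch_free_divide change above builds on. It assumes only the new `u32x8` API
and is not part of the diff. Multiplying by u32::MAX (that is, 2^32 - 1) and
keeping the high 32 bits yields x - 1 for any lane value x >= 1, which makes the
per-lane behaviour easy to check by eye.

  use wide::*;

  fn main() {
    let a = u32x8::from([1, 2, 3, 4, 5, 6, 7, u32::MAX]);
    let all_max = u32x8::from([u32::MAX; 8]);
    // (x * (2^32 - 1)) >> 32 == x - 1 for 1 <= x <= u32::MAX.
    let hi = a.mul_keep_high(all_max);
    assert_eq!(hi, u32x8::from([0, 1, 2, 3, 4, 5, 6, u32::MAX - 1]));
  }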