
Commit

Added widening multiply for u16x8, u32x4, u32x8, i32x4 and i32x8 (#182)
* add widening unsigned mul

* added wasm

* fix ordering of lanes in neon

* try again

* fix for avx2

* improve tests

* added plain widen_mul

* add mul_keep_high

* fix dumb stuff

* added comments

* make u32x8 version generic

* optimized avx2 of mul_keep_high

* make calling convention more consistent

* fix arm

* comments

* add WASM simd implementations... turns out they are easy
mcroomp authored Nov 19, 2024
1 parent 37d7dca commit a53f5e1
Showing 16 changed files with 367 additions and 55 deletions.
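For orientation, here is a hedged sketch of how the new operations might be exercised from user code. It is illustrative only, not part of the commit, and it assumes the crate's usual `new`/`to_array` array helpers.

use wide::u32x4;

fn main() {
  let a = u32x4::new([1, 2, 3, u32::MAX]);
  let b = u32x4::new([10, 20, 30, u32::MAX]);

  // mul_widen keeps every 32x32-bit product at full 64-bit precision.
  let products = a.mul_widen(b); // u64x4
  assert_eq!(products.to_array()[3], u64::from(u32::MAX) * u64::from(u32::MAX));

  // mul_keep_high keeps only the upper 32 bits of each 64-bit product.
  let high = a.mul_keep_high(b);
  assert_eq!(high.to_array(), [0, 0, 0, 0xFFFF_FFFE]);
}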
6 changes: 4 additions & 2 deletions src/f64x2_.rs
@@ -1598,7 +1598,8 @@ impl f64x2 {
cast_mut(self)
}

/// Converts the lower two `i32` lanes to two `f64` lanes (dropping the higher two `i32` lanes)
/// Converts the lower two `i32` lanes to two `f64` lanes (dropping the
/// higher two `i32` lanes)
#[inline]
pub fn from_i32x4_lower2(v: i32x4) -> Self {
pick! {
@@ -1619,7 +1620,8 @@ }
}

impl From<i32x4> for f64x2 {
/// Converts the lower two `i32` lanes to two `f64` lanes (dropping the higher two `i32` lanes)
/// Converts the lower two `i32` lanes to two `f64` lanes (dropping the
/// higher two `i32` lanes)
#[inline]
fn from(v: i32x4) -> Self {
Self::from_i32x4_lower2(v)
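Illustrative only (not part of the diff): a minimal sketch of the conversion documented above, assuming the usual `new`/`to_array` helpers.

use wide::{f64x2, i32x4};

fn main() {
  let v = i32x4::new([1, -2, 3, 4]);
  // Only lanes 0 and 1 take part; lanes 2 and 3 are dropped.
  assert_eq!(f64x2::from_i32x4_lower2(v).to_array(), [1.0, -2.0]);
  assert_eq!(f64x2::from(v).to_array(), [1.0, -2.0]); // the From impl is equivalent
}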
56 changes: 52 additions & 4 deletions src/i32x4_.rs
@@ -327,8 +327,8 @@ impl_shr_t_for_i32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
/// Shifts lanes by the corresponding lane.
///
/// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
/// the type. (same as `wrapping_shr`)
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
/// of the type. (same as `wrapping_shr`)
impl Shr<i32x4> for i32x4 {
type Output = Self;

@@ -364,8 +364,8 @@ impl Shr<i32x4> for i32x4 {
/// Shifts lanes by the corresponding lane.
///
/// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
/// the type. (same as `wrapping_shl`)
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
/// of the type. (same as `wrapping_shl`)
impl Shl<i32x4> for i32x4 {
type Output = Self;

@@ -490,6 +490,54 @@ impl i32x4 {
}
}
}

/// Multiplies corresponding 32-bit lanes and returns the 64-bit results in
/// the corresponding lanes.
///
/// Effectively does two multiplies on 128-bit platforms, but is easier
/// to use than wrapping `mul_widen_i32_odd_m128i` individually.
#[inline]
#[must_use]
pub fn mul_widen(self, rhs: Self) -> i64x4 {
pick! {
if #[cfg(target_feature="avx2")] {
let a = convert_to_i64_m256i_from_i32_m128i(self.sse);
let b = convert_to_i64_m256i_from_i32_m128i(rhs.sse);
cast(mul_i64_low_bits_m256i(a, b))
} else if #[cfg(target_feature="sse4.1")] {
let evenp = mul_widen_i32_odd_m128i(self.sse, rhs.sse);

let oddp = mul_widen_i32_odd_m128i(
shr_imm_u64_m128i::<32>(self.sse),
shr_imm_u64_m128i::<32>(rhs.sse));

i64x4 {
a: i64x2 { sse: unpack_low_i64_m128i(evenp, oddp)},
b: i64x2 { sse: unpack_high_i64_m128i(evenp, oddp)}
}
} else if #[cfg(target_feature="simd128")] {
i64x4 {
a: i64x2 { simd: i64x2_extmul_low_i32x4(self.simd, rhs.simd) },
b: i64x2 { simd: i64x2_extmul_high_i32x4(self.simd, rhs.simd) },
}
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
unsafe {
i64x4 { a: i64x2 { neon: vmull_s32(vget_low_s32(self.neon), vget_low_s32(rhs.neon)) },
b: i64x2 { neon: vmull_s32(vget_high_s32(self.neon), vget_high_s32(rhs.neon)) } }
}
} else {
let a: [i32; 4] = cast(self);
let b: [i32; 4] = cast(rhs);
cast([
i64::from(a[0]) * i64::from(b[0]),
i64::from(a[1]) * i64::from(b[1]),
i64::from(a[2]) * i64::from(b[2]),
i64::from(a[3]) * i64::from(b[3]),
])
}
}
}

#[inline]
#[must_use]
pub fn abs(self) -> Self {
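An illustrative (non-committed) sketch of the signed widening multiply added to `i32x4` above; `new`/`to_array` are assumed to be the crate's existing array helpers.

use wide::i32x4;

fn main() {
  let a = i32x4::new([i32::MIN, -3, 2, i32::MAX]);
  let b = i32x4::new([2, 7, -5, i32::MAX]);
  // Every product is computed at 64-bit precision, so nothing wraps.
  let p = a.mul_widen(b); // i64x4
  assert_eq!(
    p.to_array(),
    [
      i64::from(i32::MIN) * 2,
      -21,
      -10,
      i64::from(i32::MAX) * i64::from(i32::MAX),
    ]
  );
}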
8 changes: 4 additions & 4 deletions src/i32x8_.rs
@@ -232,8 +232,8 @@ impl_shr_t_for_i32x8!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
/// Shifts lanes by the corresponding lane.
///
/// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
/// the type. (same as `wrapping_shr`)
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
/// of the type. (same as `wrapping_shr`)
impl Shr<i32x8> for i32x8 {
type Output = Self;

@@ -258,8 +258,8 @@ impl Shr<i32x8> for i32x8 {
/// Shifts lanes by the corresponding lane.
///
/// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
/// the type. (same as `wrapping_shl`)
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
/// of the type. (same as `wrapping_shl`)
impl Shl<i32x8> for i32x8 {
type Output = Self;

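The per-lane shifts documented above mask the shift amount to the lane width, like `wrapping_shr`/`wrapping_shl`. A hedged sketch of that behaviour (illustrative only; `new`/`to_array` assumed):

use wide::i32x8;

fn main() {
  let v = i32x8::new([16, 16, -16, 1, 8, 8, 8, 8]);
  // 34 & 31 == 2, 35 & 31 == 3, 32 & 31 == 0: amounts wrap at the lane width.
  let s = i32x8::new([2, 34, 1, 0, 3, 35, 0, 32]);
  assert_eq!((v >> s).to_array(), [4, 4, -8, 1, 1, 1, 8, 8]);
}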
4 changes: 2 additions & 2 deletions src/i64x4_.rs
@@ -4,11 +4,11 @@ pick! {
if #[cfg(target_feature="avx2")] {
#[derive(Default, Clone, Copy, PartialEq, Eq)]
#[repr(C, align(32))]
pub struct i64x4 { avx2: m256i }
pub struct i64x4 { pub(crate) avx2: m256i }
} else {
#[derive(Default, Clone, Copy, PartialEq, Eq)]
#[repr(C, align(32))]
pub struct i64x4 { a : i64x2, b : i64x2 }
pub struct i64x4 { pub(crate) a : i64x2, pub(crate) b : i64x2 }
}
}

6 changes: 4 additions & 2 deletions src/i8x16_.rs
@@ -700,7 +700,8 @@ impl i8x16 {
/// `rhs`.
///
/// * Index values in the range `[0, 15]` select the i-th element of `self`.
/// * Index values that are out of range will cause that output lane to be `0`.
/// * Index values that are out of range will cause that output lane to be
/// `0`.
#[inline]
pub fn swizzle(self, rhs: i8x16) -> i8x16 {
pick! {
@@ -727,7 +728,8 @@ }
}
}

/// Works like [`swizzle`](Self::swizzle) with the following additional details
/// Works like [`swizzle`](Self::swizzle) with the following additional
/// details
///
/// * Indices in the range `[0, 15]` will select the i-th element of `self`.
/// * If the high bit of any index is set (meaning that the index is
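A hedged usage sketch of the `swizzle` contract quoted above (illustrative only, not part of the diff; `new`/`to_array` assumed):

use wide::i8x16;

fn main() {
  let table = i8x16::new([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]);
  // Indices 0..=15 pick that lane of `table`; out-of-range indices (here -1)
  // make the output lane 0, per the documented contract.
  let idx = i8x16::new([15, 0, 3, 3, 7, -1, -1, 0, 1, 2, 4, 5, 6, 8, 9, 10]);
  assert_eq!(
    table.swizzle(idx).to_array(),
    [25, 10, 13, 13, 17, 0, 0, 10, 11, 12, 14, 15, 16, 18, 19, 20]
  );
}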
27 changes: 15 additions & 12 deletions src/i8x32_.rs
@@ -331,13 +331,14 @@ impl i8x32 {
!self.any()
}

/// Returns a new vector with lanes selected from the lanes of the first input vector
/// as specified in the second input vector `rhs`.
/// The indices i in range `[0, 15]` select the i-th element of `self`. For indices
/// outside of the range the resulting lane is `0`.
/// Returns a new vector with lanes selected from the lanes of the first input
/// vector as specified in the second input vector `rhs`.
/// The indices i in range `[0, 15]` select the i-th element of `self`. For
/// indices outside of the range the resulting lane is `0`.
///
/// Note that this is the equivalent of two parallel swizzle operations on the two halves of the vector,
/// and the indexes each refer to the corresponding half.
/// Note that this is the equivalent of two parallel swizzle operations on the
/// two halves of the vector, and the indexes each refer to the
/// corresponding half.
#[inline]
pub fn swizzle_half(self, rhs: i8x32) -> i8x32 {
pick! {
@@ -352,13 +353,15 @@ }
}
}

/// Indices in the range `[0, 15]` will select the i-th element of `self`. If the high bit
/// of any element of `rhs` is set (negative) then the corresponding output
/// lane is guaranteed to be zero. Otherwise if the element of `rhs` is within the range `[32, 127]`
/// then the output lane is either `0` or `self[rhs[i] % 16]` depending on the implementation.
/// Indices in the range `[0, 15]` will select the i-th element of `self`. If
/// the high bit of any element of `rhs` is set (negative) then the
/// corresponding output lane is guaranteed to be zero. Otherwise if the
/// element of `rhs` is within the range `[32, 127]` then the output lane is
/// either `0` or `self[rhs[i] % 16]` depending on the implementation.
///
/// This is the equivalent of two parallel swizzle operations on the two halves of the vector,
/// and the indexes each refer to their corresponding half.
/// This is the equivalent of two parallel swizzle operations on the two
/// halves of the vector, and the indexes each refer to their corresponding
/// half.
#[inline]
pub fn swizzle_half_relaxed(self, rhs: i8x32) -> i8x32 {
pick! {
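For the half-wise behaviour described above, a hedged sketch (not part of the diff; `new`/`to_array` assumed): each 16-lane half of `rhs` indexes only its own half of `self`.

use wide::i8x32;

fn main() {
  let mut lanes = [0i8; 32];
  for (i, l) in lanes.iter_mut().enumerate() {
    *l = i as i8; // low half holds 0..=15, high half holds 16..=31
  }
  let table = i8x32::new(lanes);

  let mut idx = [0i8; 32];
  idx[0] = 15;  // lane 15 of the LOW half  -> value 15
  idx[16] = 15; // lane 15 of the HIGH half -> value 31
  idx[1] = -1;  // out of range             -> 0
  let sel = i8x32::new(idx);

  let r = table.swizzle_half(sel).to_array();
  assert_eq!((r[0], r[16], r[1]), (15, 31, 0));
}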
38 changes: 38 additions & 0 deletions src/u16x8_.rs
@@ -591,6 +591,44 @@ impl u16x8 {
}
}

/// Multiplies two `u16x8` and returns the high part of the intermediate `u32x8`
#[inline]
#[must_use]
pub fn mul_keep_high(self, rhs: Self) -> Self {
pick! {
if #[cfg(target_feature="sse2")] {
Self { sse: mul_u16_keep_high_m128i(self.sse, rhs.sse) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
let lhs_low = unsafe { vget_low_u16(self.neon) };
let rhs_low = unsafe { vget_low_u16(rhs.neon) };

let lhs_high = unsafe { vget_high_u16(self.neon) };
let rhs_high = unsafe { vget_high_u16(rhs.neon) };

let low = unsafe { vmull_u16(lhs_low, rhs_low) };
let high = unsafe { vmull_u16(lhs_high, rhs_high) };

u16x8 { neon: unsafe { vuzpq_u16(vreinterpretq_u16_u32(low), vreinterpretq_u16_u32(high)).1 } }
} else if #[cfg(target_feature="simd128")] {
let low = u32x4_extmul_low_u16x8(self.simd, rhs.simd);
let high = u32x4_extmul_high_u16x8(self.simd, rhs.simd);

Self { simd: u16x8_shuffle::<1, 3, 5, 7, 9, 11, 13, 15>(low, high) }
} else {
u16x8::new([
((u32::from(rhs.as_array_ref()[0]) * u32::from(self.as_array_ref()[0])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[1]) * u32::from(self.as_array_ref()[1])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[2]) * u32::from(self.as_array_ref()[2])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[3]) * u32::from(self.as_array_ref()[3])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[4]) * u32::from(self.as_array_ref()[4])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[5]) * u32::from(self.as_array_ref()[5])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[6]) * u32::from(self.as_array_ref()[6])) >> 16) as u16,
((u32::from(rhs.as_array_ref()[7]) * u32::from(self.as_array_ref()[7])) >> 16) as u16,
])
}
}
}

#[inline]
pub fn to_array(self) -> [u16; 8] {
cast(self)
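A hedged sketch of the `u16x8::mul_keep_high` semantics above: each lane is the top 16 bits of the 32-bit product (illustrative only; `new`/`to_array` assumed).

use wide::u16x8;

fn main() {
  let a = u16x8::new([0xFFFF, 0x8000, 300, 2, 7, 100, 1000, 0]);
  let b = u16x8::new([0xFFFF, 2, 300, 3, 9, 100, 1000, 5]);
  // e.g. 0xFFFF * 0xFFFF == 0xFFFE_0001, so the kept high half is 0xFFFE.
  assert_eq!(
    a.mul_keep_high(b).to_array(),
    [0xFFFE, 1, 1, 0, 0, 0, 15, 0]
  );
}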
109 changes: 104 additions & 5 deletions src/u32x4_.rs
@@ -327,8 +327,8 @@ impl_shr_t_for_u32x4!(i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
/// Shifts lanes by the corresponding lane.
///
/// Bitwise shift-right; yields `self >> mask(rhs)`, where mask removes any
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
/// the type. (same as `wrapping_shr`)
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
/// of the type. (same as `wrapping_shr`)
impl Shr<u32x4> for u32x4 {
type Output = Self;
#[inline]
@@ -363,8 +363,8 @@ impl Shr<u32x4> for u32x4 {
/// Shifts lanes by the corresponding lane.
///
/// Bitwise shift-left; yields `self << mask(rhs)`, where mask removes any
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth of
/// the type. (same as `wrapping_shl`)
/// high-order bits of `rhs` that would cause the shift to exceed the bitwidth
/// of the type. (same as `wrapping_shl`)
impl Shl<u32x4> for u32x4 {
type Output = Self;
#[inline]
@@ -431,7 +431,7 @@ impl u32x4 {
Self { sse: cmp_gt_mask_i32_m128i((self ^ h).sse, (rhs ^ h).sse) }
} else if #[cfg(target_feature="simd128")] {
Self { simd: u32x4_gt(self.simd, rhs.simd) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
unsafe {Self { neon: vcgtq_u32(self.neon, rhs.neon) }}
} else {
Self { arr: [
@@ -450,6 +450,105 @@ impl u32x4 {
rhs.cmp_gt(self)
}

/// Multiplies each pair of 32-bit lanes into a 64-bit intermediate and keeps
/// only the high 32 bits of the result. Useful for implementing division by a
/// constant value (see the t_usefulness example).
#[inline]
#[must_use]
pub fn mul_keep_high(self, rhs: Self) -> Self {
pick! {
if #[cfg(target_feature="avx2")] {
let a = convert_to_i64_m256i_from_u32_m128i(self.sse);
let b = convert_to_i64_m256i_from_u32_m128i(rhs.sse);
let r = mul_u64_low_bits_m256i(a, b);

// the compiler does a good job shuffling the lanes around
let b : [u32;8] = cast(r);
cast([b[1],b[3],b[5],b[7]])
} else if #[cfg(target_feature="sse2")] {
let evenp = mul_widen_u32_odd_m128i(self.sse, rhs.sse);

let oddp = mul_widen_u32_odd_m128i(
shr_imm_u64_m128i::<32>(self.sse),
shr_imm_u64_m128i::<32>(rhs.sse));

// the compiler does a good job shuffling the lanes around
let a : [u32;4]= cast(evenp);
let b : [u32;4]= cast(oddp);
cast([a[1],b[1],a[3],b[3]])

} else if #[cfg(target_feature="simd128")] {
let low = u64x2_extmul_low_u32x4(self.simd, rhs.simd);
let high = u64x2_extmul_high_u32x4(self.simd, rhs.simd);

Self { simd: u32x4_shuffle::<1, 3, 5, 7>(low, high) }
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
unsafe {
let l = vmull_u32(vget_low_u32(self.neon), vget_low_u32(rhs.neon));
let h = vmull_u32(vget_high_u32(self.neon), vget_high_u32(rhs.neon));
u32x4 { neon: vcombine_u32(vshrn_n_u64(l,32), vshrn_n_u64(h,32)) }
}
} else {
let a: [u32; 4] = cast(self);
let b: [u32; 4] = cast(rhs);
cast([
((u64::from(a[0]) * u64::from(b[0])) >> 32) as u32,
((u64::from(a[1]) * u64::from(b[1])) >> 32) as u32,
((u64::from(a[2]) * u64::from(b[2])) >> 32) as u32,
((u64::from(a[3]) * u64::from(b[3])) >> 32) as u32,
])
}
}
}

/// Multiplies corresponding 32-bit lanes and returns the 64-bit results in
/// the corresponding lanes.
///
/// Effectively does two multiplies on 128-bit platforms, but is easier
/// to use than wrapping `mul_widen_u32_odd_m128i` individually.
#[inline]
#[must_use]
pub fn mul_widen(self, rhs: Self) -> u64x4 {
pick! {
if #[cfg(target_feature="avx2")] {
// ok to sign extend since we are throwing away the high half of the result anyway
let a = convert_to_i64_m256i_from_i32_m128i(self.sse);
let b = convert_to_i64_m256i_from_i32_m128i(rhs.sse);
cast(mul_u64_low_bits_m256i(a, b))
} else if #[cfg(target_feature="sse2")] {
let evenp = mul_widen_u32_odd_m128i(self.sse, rhs.sse);

let oddp = mul_widen_u32_odd_m128i(
shr_imm_u64_m128i::<32>(self.sse),
shr_imm_u64_m128i::<32>(rhs.sse));

u64x4 {
a: u64x2 { sse: unpack_low_i64_m128i(evenp, oddp)},
b: u64x2 { sse: unpack_high_i64_m128i(evenp, oddp)}
}
} else if #[cfg(target_feature="simd128")] {
u64x4 {
a: u64x2 { simd: u64x2_extmul_low_u32x4(self.simd, rhs.simd) },
b: u64x2 { simd: u64x2_extmul_high_u32x4(self.simd, rhs.simd) },
}
} else if #[cfg(all(target_feature="neon",target_arch="aarch64"))] {
unsafe {
u64x4 { a: u64x2 { neon: vmull_u32(vget_low_u32(self.neon), vget_low_u32(rhs.neon)) },
b: u64x2 { neon: vmull_u32(vget_high_u32(self.neon), vget_high_u32(rhs.neon)) } }
}
} else {
let a: [u32; 4] = cast(self);
let b: [u32; 4] = cast(rhs);
cast([
u64::from(a[0]) * u64::from(b[0]),
u64::from(a[1]) * u64::from(b[1]),
u64::from(a[2]) * u64::from(b[2]),
u64::from(a[3]) * u64::from(b[3]),
])
}
}
}

#[inline]
#[must_use]
pub fn blend(self, t: Self, f: Self) -> Self {
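The doc comment on `u32x4::mul_keep_high` mentions division by a constant; here is a hedged sketch of that use, dividing every lane by 3 with the usual reciprocal-multiplication magic number. It is illustrative only and not taken from the crate's t_usefulness test; `new`/`to_array` are assumed.

use wide::u32x4;

fn main() {
  // floor(x / 3) == (x * 0xAAAA_AAAB) >> 33 for every u32 x.
  // mul_keep_high supplies the ">> 32" part, leaving a shift of 1.
  let x = u32x4::new([0, 7, 3_000_000_000, u32::MAX]);
  let magic = u32x4::new([0xAAAA_AAAB; 4]);
  let q = x.mul_keep_high(magic) >> 1u32;
  assert_eq!(q.to_array(), [0, 2, 1_000_000_000, u32::MAX / 3]);
  // mul_widen (above) is the same two-multiply machinery but keeps the full
  // 64-bit products instead of only their high halves.
}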
