diff options
author | Stephan Unverwerth <s.unverwerth@serenityos.org> | 2021-12-31 00:38:38 +0100 |
---|---|---|
committer | Ali Mohammad Pur <Ali.mpfard@gmail.com> | 2022-01-09 16:21:13 +0330 |
commit | 7adcdecc7bfbaf416f709645a01ebe45e0884705 (patch) | |
tree | 8ff989e550b17357711ce16c9af4df036473d3c3 /AK/SIMDExtras.h | |
parent | 75e31a4749d448c0468e2b7fc8197c384dfdca0f (diff) | |
download | serenity-7adcdecc7bfbaf416f709645a01ebe45e0884705.zip |
AK: Add SIMDExtras.h with SIMD related functions
Adds a header to AK with helper functions for writing vectorized code.
Co-authored-by: Hendiadyoin <leon2002.la@gmail.com>
Diffstat (limited to 'AK/SIMDExtras.h')
-rw-r--r-- | AK/SIMDExtras.h | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/AK/SIMDExtras.h b/AK/SIMDExtras.h new file mode 100644 index 0000000000..4b4a406116 --- /dev/null +++ b/AK/SIMDExtras.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021, Stephan Unverwerth <s.unverwerth@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include <AK/SIMD.h> + +// Returning a vector on i686 target generates warning "psabi". +// This prevents the CI, treating this as an error, from running to completion. +#pragma GCC diagnostic push +#pragma GCC diagnostic warning "-Wpsabi" + +namespace AK::SIMD { + +// SIMD Vector Expansion + +ALWAYS_INLINE static constexpr f32x4 expand4(float f) +{ + return f32x4 { f, f, f, f }; +} + +ALWAYS_INLINE static constexpr i32x4 expand4(i32 i) +{ + return i32x4 { i, i, i, i }; +} + +ALWAYS_INLINE static constexpr u32x4 expand4(u32 u) +{ + return u32x4 { u, u, u, u }; +} + +// Casting + +template<typename TSrc> +ALWAYS_INLINE static u32x4 to_u32x4(TSrc v) +{ + return __builtin_convertvector(v, u32x4); +} + +template<typename TSrc> +ALWAYS_INLINE static i32x4 to_i32x4(TSrc v) +{ + return __builtin_convertvector(v, i32x4); +} + +template<typename TSrc> +ALWAYS_INLINE static f32x4 to_f32x4(TSrc v) +{ + return __builtin_convertvector(v, f32x4); +} + +// Masking + +ALWAYS_INLINE static i32 maskbits(i32x4 mask) +{ +#if defined(__SSE__) + return __builtin_ia32_movmskps((f32x4)mask); +#else + return ((mask[0] & 0x80000000) >> 31) | ((mask[1] & 0x80000000) >> 30) | ((mask[2] & 0x80000000) >> 29) | ((mask[3] & 0x80000000) >> 28); +#endif +} + +ALWAYS_INLINE static bool all(i32x4 mask) +{ + return maskbits(mask) == 15; +} + +ALWAYS_INLINE static bool any(i32x4 mask) +{ + return maskbits(mask) != 0; +} + +ALWAYS_INLINE static bool none(i32x4 mask) +{ + return maskbits(mask) == 0; +} + +ALWAYS_INLINE static int maskcount(i32x4 mask) +{ + constexpr static int count_lut[16] { 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4 }; + return count_lut[maskbits(mask)]; +} + +// Load / Store + +ALWAYS_INLINE static f32x4 load4(float const* a, float const* b, float const* c, float const* d) +{ + return f32x4 { *a, *b, *c, *d }; +} + +ALWAYS_INLINE static u32x4 load4(u32 const* a, u32 const* b, u32 const* c, u32 const* d) +{ + return u32x4 { *a, *b, *c, *d }; +} + +ALWAYS_INLINE static f32x4 load4_masked(float const* a, float const* b, float const* c, float const* d, i32x4 mask) +{ + int bits = maskbits(mask); + return f32x4 { + bits & 1 ? *a : 0.f, + bits & 2 ? *b : 0.f, + bits & 4 ? *c : 0.f, + bits & 8 ? *d : 0.f, + }; +} + +ALWAYS_INLINE static u32x4 load4_masked(u32 const* a, u32 const* b, u32 const* c, u32 const* d, i32x4 mask) +{ + int bits = maskbits(mask); + return u32x4 { + bits & 1 ? *a : 0u, + bits & 2 ? *b : 0u, + bits & 4 ? *c : 0u, + bits & 8 ? *d : 0u, + }; +} + +template<typename VectorType, typename UnderlyingType = decltype(declval<VectorType>()[0])> +ALWAYS_INLINE static void store4(VectorType v, UnderlyingType* a, UnderlyingType* b, UnderlyingType* c, UnderlyingType* d) +{ + *a = v[0]; + *b = v[1]; + *c = v[2]; + *d = v[3]; +} + +template<typename VectorType, typename UnderlyingType = decltype(declval<VectorType>()[0])> +ALWAYS_INLINE static void store4_masked(VectorType v, UnderlyingType* a, UnderlyingType* b, UnderlyingType* c, UnderlyingType* d, i32x4 mask) +{ + int bits = maskbits(mask); + if (bits & 1) + *a = v[0]; + if (bits & 2) + *b = v[1]; + if (bits & 4) + *c = v[2]; + if (bits & 8) + *d = v[3]; +} + +#pragma GCC diagnostic pop + +} |