summaryrefslogtreecommitdiff
path: root/AK/SIMDExtras.h
diff options
context:
space:
mode:
author Stephan Unverwerth <s.unverwerth@serenityos.org> 2021-12-31 00:38:38 +0100
committer Ali Mohammad Pur <Ali.mpfard@gmail.com> 2022-01-09 16:21:13 +0330
commit 7adcdecc7bfbaf416f709645a01ebe45e0884705 (patch)
tree 8ff989e550b17357711ce16c9af4df036473d3c3 /AK/SIMDExtras.h
parent 75e31a4749d448c0468e2b7fc8197c384dfdca0f (diff)
download serenity-7adcdecc7bfbaf416f709645a01ebe45e0884705.zip
AK: Add SIMDExtras.h with SIMD related functions
Adds a header to AK with helper functions for writing vectorized code. Co-authored-by: Hendiadyoin <leon2002.la@gmail.com>
Diffstat (limited to 'AK/SIMDExtras.h')
-rw-r--r--AK/SIMDExtras.h146
1 files changed, 146 insertions, 0 deletions
diff --git a/AK/SIMDExtras.h b/AK/SIMDExtras.h
new file mode 100644
index 0000000000..4b4a406116
--- /dev/null
+++ b/AK/SIMDExtras.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2021, Stephan Unverwerth <s.unverwerth@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#pragma once
+
+#include <AK/SIMD.h>
+
+// Returning a vector on the i686 target triggers GCC's "-Wpsabi" diagnostic.
+// CI treats diagnostics as errors, which would stop the build, so it is
+// downgraded to a plain warning for the duration of this header.
+#pragma GCC diagnostic push
+#pragma GCC diagnostic warning "-Wpsabi"
+
+namespace AK::SIMD {
+
+// SIMD Vector Expansion
+
+// Broadcasts the scalar `f` into all four lanes of an f32x4.
+ALWAYS_INLINE static constexpr f32x4 expand4(float f)
+{
+    f32x4 const lanes { f, f, f, f };
+    return lanes;
+}
+
+// Broadcasts the scalar `i` into all four lanes of an i32x4.
+ALWAYS_INLINE static constexpr i32x4 expand4(i32 i)
+{
+    i32x4 const lanes { i, i, i, i };
+    return lanes;
+}
+
+// Broadcasts the scalar `u` into all four lanes of a u32x4.
+ALWAYS_INLINE static constexpr u32x4 expand4(u32 u)
+{
+    u32x4 const lanes { u, u, u, u };
+    return lanes;
+}
+
+// Casting
+
+// Converts every lane of `v` to u32 using __builtin_convertvector.
+// The builtin requires TSrc to be a vector type with as many elements as u32x4.
+template<typename TSrc>
+ALWAYS_INLINE static u32x4 to_u32x4(TSrc v)
+{
+    u32x4 const converted = __builtin_convertvector(v, u32x4);
+    return converted;
+}
+
+// Converts every lane of `v` to i32 using __builtin_convertvector.
+// The builtin requires TSrc to be a vector type with as many elements as i32x4.
+template<typename TSrc>
+ALWAYS_INLINE static i32x4 to_i32x4(TSrc v)
+{
+    i32x4 const converted = __builtin_convertvector(v, i32x4);
+    return converted;
+}
+
+// Converts every lane of `v` to float using __builtin_convertvector.
+// The builtin requires TSrc to be a vector type with as many elements as f32x4.
+template<typename TSrc>
+ALWAYS_INLINE static f32x4 to_f32x4(TSrc v)
+{
+    f32x4 const converted = __builtin_convertvector(v, f32x4);
+    return converted;
+}
+
+// Masking
+
+// Packs the sign bit (bit 31) of each lane of `mask` into the low four bits
+// of the result: lane 0 -> bit 0, lane 1 -> bit 1, lane 2 -> bit 2, lane 3 -> bit 3.
+// The returned value is therefore always in the range 0..15.
+ALWAYS_INLINE static i32 maskbits(i32x4 mask)
+{
+#if defined(__SSE__)
+    // Reinterpret the lanes as floats so the movmskps builtin can collect the sign bits.
+    f32x4 const as_floats = (f32x4)mask;
+    return __builtin_ia32_movmskps(as_floats);
+#else
+    // Portable fallback: isolate each lane's sign bit and shift it down to its
+    // slot. The constant is unsigned, so the right shifts are logical shifts.
+    constexpr unsigned sign_bit = 0x80000000;
+    unsigned const lane0 = (mask[0] & sign_bit) >> 31;
+    unsigned const lane1 = (mask[1] & sign_bit) >> 30;
+    unsigned const lane2 = (mask[2] & sign_bit) >> 29;
+    unsigned const lane3 = (mask[3] & sign_bit) >> 28;
+    return lane0 | lane1 | lane2 | lane3;
+#endif
+}
+
+// Returns true when the sign bit is set in all four lanes of `mask`.
+ALWAYS_INLINE static bool all(i32x4 mask)
+{
+    constexpr i32 all_lanes_set = 0b1111;
+    return maskbits(mask) == all_lanes_set;
+}
+
+// Returns true when the sign bit is set in at least one lane of `mask`.
+ALWAYS_INLINE static bool any(i32x4 mask)
+{
+    i32 const bits = maskbits(mask);
+    return bits != 0;
+}
+
+// Returns true when the sign bit is clear in every lane of `mask`.
+ALWAYS_INLINE static bool none(i32x4 mask)
+{
+    i32 const bits = maskbits(mask);
+    return bits == 0;
+}
+
+// Returns how many lanes of `mask` have their sign bit set (0 through 4).
+ALWAYS_INLINE static int maskcount(i32x4 mask)
+{
+    // maskbits() yields a 4-bit value, so counting its set bits gives the
+    // number of active lanes.
+    auto const bits = static_cast<unsigned>(maskbits(mask));
+    return __builtin_popcount(bits);
+}
+
+// Load / Store
+
+// Gathers four floats from four independent addresses into one f32x4.
+// `a` feeds lane 0, `b` lane 1, `c` lane 2 and `d` lane 3; all four pointers
+// are dereferenced unconditionally.
+ALWAYS_INLINE static f32x4 load4(float const* a, float const* b, float const* c, float const* d)
+{
+    float const lane0 = *a;
+    float const lane1 = *b;
+    float const lane2 = *c;
+    float const lane3 = *d;
+    return f32x4 { lane0, lane1, lane2, lane3 };
+}
+
+// Gathers four u32 values from four independent addresses into one u32x4.
+// `a` feeds lane 0, `b` lane 1, `c` lane 2 and `d` lane 3; all four pointers
+// are dereferenced unconditionally.
+ALWAYS_INLINE static u32x4 load4(u32 const* a, u32 const* b, u32 const* c, u32 const* d)
+{
+    u32 const lane0 = *a;
+    u32 const lane1 = *b;
+    u32 const lane2 = *c;
+    u32 const lane3 = *d;
+    return u32x4 { lane0, lane1, lane2, lane3 };
+}
+
+// Gathers up to four floats into an f32x4, guided by `mask`.
+// A lane is loaded from its pointer only when that lane's sign bit is set in
+// `mask`; otherwise the pointer is not dereferenced and the lane is 0.f.
+ALWAYS_INLINE static f32x4 load4_masked(float const* a, float const* b, float const* c, float const* d, i32x4 mask)
+{
+    int const active = maskbits(mask);
+    f32x4 result { 0.f, 0.f, 0.f, 0.f };
+    if (active & 1)
+        result[0] = *a;
+    if (active & 2)
+        result[1] = *b;
+    if (active & 4)
+        result[2] = *c;
+    if (active & 8)
+        result[3] = *d;
+    return result;
+}
+
+// Gathers up to four u32 values into a u32x4, guided by `mask`.
+// A lane is loaded from its pointer only when that lane's sign bit is set in
+// `mask`; otherwise the pointer is not dereferenced and the lane is 0.
+ALWAYS_INLINE static u32x4 load4_masked(u32 const* a, u32 const* b, u32 const* c, u32 const* d, i32x4 mask)
+{
+    int const active = maskbits(mask);
+    u32x4 result { 0u, 0u, 0u, 0u };
+    if (active & 1)
+        result[0] = *a;
+    if (active & 2)
+        result[1] = *b;
+    if (active & 4)
+        result[2] = *c;
+    if (active & 8)
+        result[3] = *d;
+    return result;
+}
+
+// Scatters the four lanes of `v` to four independent addresses:
+// lane 0 -> *a, lane 1 -> *b, lane 2 -> *c, lane 3 -> *d.
+// The stores happen in that order, which matters if any of the pointers alias.
+template<typename VectorType, typename UnderlyingType = decltype(declval<VectorType>()[0])>
+ALWAYS_INLINE static void store4(VectorType v, UnderlyingType* a, UnderlyingType* b, UnderlyingType* c, UnderlyingType* d)
+{
+    // `v` is a by-value copy, so reading all lanes up front cannot be affected
+    // by the stores below.
+    auto const lane0 = v[0];
+    auto const lane1 = v[1];
+    auto const lane2 = v[2];
+    auto const lane3 = v[3];
+
+    *a = lane0;
+    *b = lane1;
+    *c = lane2;
+    *d = lane3;
+}
+
+// Scatters lanes of `v` to four independent addresses, guided by `mask`.
+// Lane N is written through its pointer only when lane N of `mask` has its
+// sign bit set; the pointers of inactive lanes are never dereferenced.
+// Active stores happen in lane order (a, b, c, d).
+template<typename VectorType, typename UnderlyingType = decltype(declval<VectorType>()[0])>
+ALWAYS_INLINE static void store4_masked(VectorType v, UnderlyingType* a, UnderlyingType* b, UnderlyingType* c, UnderlyingType* d, i32x4 mask)
+{
+    int const active = maskbits(mask);
+    bool const lane0_active = (active & 1) != 0;
+    bool const lane1_active = (active & 2) != 0;
+    bool const lane2_active = (active & 4) != 0;
+    bool const lane3_active = (active & 8) != 0;
+
+    if (lane0_active)
+        *a = v[0];
+    if (lane1_active)
+        *b = v[1];
+    if (lane2_active)
+        *c = v[2];
+    if (lane3_active)
+        *d = v[3];
+}
+
+#pragma GCC diagnostic pop
+
+}