summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--AK/Function.h144
1 files changed, 131 insertions, 13 deletions
diff --git a/AK/Function.h b/AK/Function.h
index ca1ac2e1c8..5e4f639966 100644
--- a/AK/Function.h
+++ b/AK/Function.h
@@ -1,5 +1,6 @@
/*
* Copyright (C) 2016 Apple Inc. All rights reserved.
+ * Copyright (c) 2021, Gunnar Beutner <gbeutner@serenityos.org>
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -26,8 +27,12 @@
#pragma once
#include <AK/Assertions.h>
-#include <AK/OwnPtr.h>
+#include <AK/Atomic.h>
+#include <AK/BitCast.h>
+#include <AK/Noncopyable.h>
+#include <AK/ScopeGuard.h>
#include <AK/StdLibExtras.h>
+#include <AK/Types.h>
namespace AK {
@@ -36,46 +41,75 @@ class Function;
template<typename Out, typename... In>
class Function<Out(In...)> {
+ AK_MAKE_NONCOPYABLE(Function);
+
public:
Function() = default;
+ ~Function()
+ {
+ clear();
+ }
+
template<typename CallableType, class = typename EnableIf<!(IsPointer<CallableType> && IsFunction<RemovePointer<CallableType>>)&&IsRvalueReference<CallableType&&>>::Type>
Function(CallableType&& callable)
- : m_callable_wrapper(make<CallableWrapper<CallableType>>(move(callable)))
{
+ init_with_callable(move(callable));
}
template<typename FunctionType, class = typename EnableIf<IsPointer<FunctionType> && IsFunction<RemovePointer<FunctionType>>>::Type>
Function(FunctionType f)
- : m_callable_wrapper(make<CallableWrapper<FunctionType>>(move(f)))
{
+ init_with_callable(move(f));
+ }
+
+ Function(Function&& other)
+ {
+ move_from(move(other));
}
Out operator()(In... in) const
{
- VERIFY(m_callable_wrapper);
- return m_callable_wrapper->call(forward<In>(in)...);
+ auto* wrapper = callable_wrapper();
+ VERIFY(wrapper);
+ ++m_call_nesting_level;
+ ScopeGuard guard([this] {
+ --m_call_nesting_level;
+ });
+ return wrapper->call(forward<In>(in)...);
}
- explicit operator bool() const { return !!m_callable_wrapper; }
+ explicit operator bool() const { return !!callable_wrapper(); }
template<typename CallableType, class = typename EnableIf<!(IsPointer<CallableType> && IsFunction<RemovePointer<CallableType>>)&&IsRvalueReference<CallableType&&>>::Type>
Function& operator=(CallableType&& callable)
{
- m_callable_wrapper = make<CallableWrapper<CallableType>>(move(callable));
+ clear();
+ init_with_callable(move(callable));
return *this;
}
template<typename FunctionType, class = typename EnableIf<IsPointer<FunctionType> && IsFunction<RemovePointer<FunctionType>>>::Type>
Function& operator=(FunctionType f)
{
- m_callable_wrapper = make<CallableWrapper<FunctionType>>(move(f));
+ clear();
+ if (f)
+ init_with_callable(move(f));
return *this;
}
Function& operator=(std::nullptr_t)
{
- m_callable_wrapper = nullptr;
+ clear();
+ return *this;
+ }
+
+ Function& operator=(Function&& other)
+ {
+ if (this != &other) {
+ clear();
+ move_from(move(other));
+ }
return *this;
}
@@ -84,19 +118,21 @@ private:
public:
virtual ~CallableWrapperBase() = default;
virtual Out call(In...) const = 0;
+ virtual void destroy() = 0;
+ virtual void init_and_swap(u8*, size_t) = 0;
};
template<typename CallableType>
class CallableWrapper final : public CallableWrapperBase {
+ AK_MAKE_NONMOVABLE(CallableWrapper);
+ AK_MAKE_NONCOPYABLE(CallableWrapper);
+
public:
explicit CallableWrapper(CallableType&& callable)
: m_callable(move(callable))
{
}
- CallableWrapper(const CallableWrapper&) = delete;
- CallableWrapper& operator=(const CallableWrapper&) = delete;
-
Out call(In... in) const final override
{
if constexpr (requires { m_callable(forward<In>(in)...); }) {
@@ -110,11 +146,93 @@ private:
}
}
+ void destroy() final override
+ {
+ delete this;
+ }
+
+ void init_and_swap(u8* destination, size_t size) final override
+ {
+ VERIFY(size >= sizeof(CallableWrapper));
+ new (destination) CallableWrapper { move(m_callable) };
+ }
+
private:
CallableType m_callable;
};
- OwnPtr<CallableWrapperBase> m_callable_wrapper;
+ enum class FunctionKind {
+ NullPointer,
+ Inline,
+ Outline,
+ };
+
+ CallableWrapperBase* callable_wrapper() const
+ {
+ switch (m_kind) {
+ case FunctionKind::NullPointer:
+ return nullptr;
+ case FunctionKind::Inline:
+ return bit_cast<CallableWrapperBase*>(&m_storage);
+ case FunctionKind::Outline:
+ return *bit_cast<CallableWrapperBase**>(&m_storage);
+ default:
+ VERIFY_NOT_REACHED();
+ }
+ }
+
+ void clear()
+ {
+ auto* wrapper = callable_wrapper();
+ if (m_kind == FunctionKind::Inline) {
+ VERIFY(wrapper);
+ wrapper->~CallableWrapperBase();
+ } else if (m_kind == FunctionKind::Outline) {
+ VERIFY(wrapper);
+ wrapper->destroy();
+ }
+ m_kind = FunctionKind::NullPointer;
+ }
+
+ template<typename Callable>
+ void init_with_callable(Callable&& callable)
+ {
+ using WrapperType = CallableWrapper<Callable>;
+ if constexpr (sizeof(WrapperType) > inline_capacity) {
+ *bit_cast<CallableWrapperBase**>(&m_storage) = new WrapperType(move(callable));
+ m_kind = FunctionKind::Outline;
+ } else {
+ new (m_storage) WrapperType(move(callable));
+ m_kind = FunctionKind::Inline;
+ }
+ }
+
+ void move_from(Function&& other)
+ {
+ VERIFY(m_call_nesting_level == 0 && other.m_call_nesting_level == 0);
+ auto* other_wrapper = other.callable_wrapper();
+ switch (other.m_kind) {
+ case FunctionKind::NullPointer:
+ break;
+ case FunctionKind::Inline:
+ other_wrapper->init_and_swap(m_storage, inline_capacity);
+ m_kind = FunctionKind::Inline;
+ break;
+ case FunctionKind::Outline:
+ *bit_cast<CallableWrapperBase**>(&m_storage) = other_wrapper;
+ m_kind = FunctionKind::Outline;
+ break;
+ default:
+ VERIFY_NOT_REACHED();
+ }
+ other.m_kind = FunctionKind::NullPointer;
+ }
+
+ FunctionKind m_kind { FunctionKind::NullPointer };
+ mutable Atomic<u16> m_call_nesting_level { 0 };
+ // Empirically determined to fit most lambdas and functions.
+ static constexpr size_t inline_capacity = 4 * sizeof(void*);
+ alignas(max(alignof(CallableWrapperBase), alignof(CallableWrapperBase*))) u8 m_storage[inline_capacity];
};
}