diff options
author | Ali Mohammad Pur <ali.mpfard@gmail.com> | 2022-02-14 16:49:53 +0330 |
---|---|---|
committer | Idan Horowitz <idan.horowitz@gmail.com> | 2022-02-15 18:03:02 +0200 |
commit | a1cb2c371a72af9556a629fb4f89acee92626708 (patch) | |
tree | 9eaeb5057f5f4d2036879ac43c53b897406ae0a7 /AK/Trie.h | |
parent | 80e61985632b59acb10a9cf384c7942b2d9bd27d (diff) | |
download | serenity-a1cb2c371a72af9556a629fb4f89acee92626708.zip |
AK+Kernel: OOM-harden most parts of Trie
The only part of Unveil that can't handle OOM gracefully is the
String::formatted() use in the node metadata.
Diffstat (limited to 'AK/Trie.h')
-rw-r--r-- | AK/Trie.h | 171 |
1 files changed, 97 insertions, 74 deletions
@@ -6,6 +6,7 @@ #pragma once +#include <AK/Concepts.h> #include <AK/Forward.h> #include <AK/HashMap.h> #include <AK/OwnPtr.h> @@ -29,62 +30,6 @@ template<typename DeclaredBaseType, typename DefaultBaseType, typename ValueType class Trie { using BaseType = typename SubstituteIfVoid<DeclaredBaseType, DefaultBaseType>::Type; - class ConstIterator { - - public: - static ConstIterator end() { return {}; } - - bool operator==(const ConstIterator& other) const { return m_current_node == other.m_current_node; } - - const BaseType& operator*() const { return static_cast<const BaseType&>(*m_current_node); } - const BaseType* operator->() const { return static_cast<const BaseType*>(m_current_node); } - void operator++() { skip_to_next(); } - - explicit ConstIterator(const Trie& node) - { - m_current_node = &node; - // FIXME: Figure out how to OOM harden this iterator. - MUST(m_state.try_empend(false, node.m_children.begin(), node.m_children.end())); - } - - private: - void skip_to_next() - { - auto& current_state = m_state.last(); - if (current_state.did_generate_root) - ++current_state.it; - else - current_state.did_generate_root = true; - if (current_state.it == current_state.end) - return pop_and_get_next(); - - m_current_node = &*(*current_state.it).value; - - // FIXME: Figure out how to OOM harden this iterator. - MUST(m_state.try_empend(false, m_current_node->m_children.begin(), m_current_node->m_children.end())); - } - void pop_and_get_next() - { - m_state.take_last(); - if (m_state.is_empty()) { - m_current_node = nullptr; - return; - } - - skip_to_next(); - } - - ConstIterator() = default; - - struct State { - bool did_generate_root { false }; - typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType it; - typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType end; - }; - Vector<State> m_state; - const Trie* m_current_node { nullptr }; - }; - public: using MetadataType = MetadataT; @@ -127,48 +72,56 @@ public: Optional<MetadataType> metadata() const requires(!IsNullPointer<MetadataType>) { return m_metadata; } void set_metadata(MetadataType metadata) requires(!IsNullPointer<MetadataType>) { m_metadata = move(metadata); } const MetadataType& metadata_value() const requires(!IsNullPointer<MetadataType>) { return m_metadata.value(); } + MetadataType& metadata_value() requires(!IsNullPointer<MetadataType>) { return m_metadata.value(); } const ValueType& value() const { return m_value; } ValueType& value() { return m_value; } - Trie& ensure_child(ValueType value, Optional<MetadataType> metadata = {}) + ErrorOr<Trie*> ensure_child(ValueType value, Optional<MetadataType> metadata = {}) { auto it = m_children.find(value); if (it == m_children.end()) { - auto node = adopt_nonnull_own_or_enomem(new (nothrow) Trie(value, move(metadata))).release_value_but_fixme_should_propagate_errors(); + auto node = TRY(adopt_nonnull_own_or_enomem(new (nothrow) Trie(value, move(metadata)))); auto& node_ref = *node; - m_children.set(move(value), move(node)); - return static_cast<BaseType&>(node_ref); + TRY(m_children.try_set(move(value), move(node))); + return &static_cast<BaseType&>(node_ref); } auto& node_ref = *it->value; if (metadata.has_value()) node_ref.m_metadata = move(metadata); - return static_cast<BaseType&>(node_ref); + return &static_cast<BaseType&>(node_ref); } template<typename It, typename ProvideMetadataFunction> - BaseType& insert( + ErrorOr<BaseType*> insert( It& it, const It& end, MetadataType metadata, ProvideMetadataFunction provide_missing_metadata) requires(!IsNullPointer<MetadataType>) { Trie* last_root_node = &traverse_until_last_accessible_node(it, end); + auto invoke_provide_missing_metadata = [&]<typename... Ts>(Ts && ... args)->ErrorOr<Optional<MetadataType>> + { + if constexpr (SameAs<MetadataType, decltype(provide_missing_metadata(forward<Ts>(args)...))>) + return Optional<MetadataType>(provide_missing_metadata(forward<Ts>(args)...)); + else + return provide_missing_metadata(forward<Ts>(args)...); + }; for (; it != end; ++it) - last_root_node = static_cast<Trie*>(&last_root_node->ensure_child(*it, provide_missing_metadata(static_cast<BaseType&>(*last_root_node), it))); + last_root_node = static_cast<Trie*>(TRY(last_root_node->ensure_child(*it, TRY(invoke_provide_missing_metadata(static_cast<BaseType&>(*last_root_node), it))))); last_root_node->set_metadata(move(metadata)); - return static_cast<BaseType&>(*last_root_node); + return static_cast<BaseType*>(last_root_node); } template<typename It> - BaseType& insert(It& it, const It& end) requires(IsNullPointer<MetadataType>) + ErrorOr<BaseType*> insert(It& it, const It& end) requires(IsNullPointer<MetadataType>) { Trie* last_root_node = &traverse_until_last_accessible_node(it, end); for (; it != end; ++it) - last_root_node = static_cast<Trie*>(&last_root_node->ensure_child(*it, {})); - return static_cast<BaseType&>(*last_root_node); + last_root_node = static_cast<Trie*>(TRY(last_root_node->ensure_child(*it, {}))); + return static_cast<BaseType*>(last_root_node); } template<typename It, typename ProvideMetadataFunction> - BaseType& insert( + ErrorOr<BaseType*> insert( const It& begin, const It& end, MetadataType metadata, ProvideMetadataFunction provide_missing_metadata) requires(!IsNullPointer<MetadataType>) { auto it = begin; @@ -176,7 +129,7 @@ public: } template<typename It> - BaseType& insert(const It& begin, const It& end) requires(IsNullPointer<MetadataType>) + ErrorOr<BaseType*> insert(const It& begin, const It& end) requires(IsNullPointer<MetadataType>) { auto it = begin; return insert(it, end); @@ -185,21 +138,91 @@ public: HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>& children() { return m_children; } HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits> const& children() const { return m_children; } - ConstIterator begin() const { return ConstIterator(*this); } - ConstIterator end() const { return ConstIterator::end(); } + template<typename Fn> + ErrorOr<void> for_each_node_in_tree_order(Fn callback) const + { + struct State { + bool did_generate_root { false }; + typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType it; + typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType end; + }; + Vector<State> state; + TRY(state.try_empend(false, m_children.begin(), m_children.end())); + + auto invoke = [&](auto& current_node) -> ErrorOr<IterationDecision> { + if constexpr (VoidFunction<Fn, const BaseType&>) { + callback(static_cast<const BaseType&>(current_node)); + return IterationDecision::Continue; + } else if constexpr (IsSpecializationOf<decltype(callback(declval<const BaseType&>())), ErrorOr>) { + return callback(static_cast<const BaseType&>(current_node)); + } else if constexpr (IteratorFunction<Fn, const BaseType&>) { + return callback(static_cast<const BaseType&>(current_node)); + } else { + static_assert(DependentFalse<Fn>, "Invalid iterator function type signature"); + } + return IterationDecision::Continue; + }; + + for (auto* current_node = this; current_node != nullptr;) { + if (TRY(invoke(*current_node)) == IterationDecision::Break) + break; + TRY(skip_to_next_iterator(state, current_node)); + } + return {}; + } [[nodiscard]] bool is_empty() const { return m_children.is_empty(); } void clear() { m_children.clear(); } - BaseType deep_copy() + ErrorOr<BaseType> deep_copy() { - Trie root(m_value, m_metadata); + Trie root(m_value, TRY(copy_metadata(m_metadata))); for (auto& it : m_children) - root.m_children.set(it.key, adopt_nonnull_own_or_enomem(new (nothrow) Trie(it.value->deep_copy())).release_value_but_fixme_should_propagate_errors()); + TRY(root.m_children.try_set(it.key, TRY(adopt_nonnull_own_or_enomem(new (nothrow) Trie(TRY(it.value->deep_copy())))))); return static_cast<BaseType&&>(move(root)); } private: + static ErrorOr<Optional<MetadataType>> copy_metadata(Optional<MetadataType> const& metadata) + { + if (!metadata.has_value()) + return Optional<MetadataType> {}; + + if constexpr (requires(MetadataType t) { { t.copy() } -> SpecializationOf<ErrorOr>; }) + return Optional<MetadataType> { TRY(metadata->copy()) }; +#ifndef KERNEL + else + return Optional<MetadataType> { MetadataType(metadata.value()) }; +#endif + } + + static ErrorOr<void> skip_to_next_iterator(auto& state, auto& current_node) + { + auto& current_state = state.last(); + if (current_state.did_generate_root) + ++current_state.it; + else + current_state.did_generate_root = true; + + if (current_state.it == current_state.end) + return pop_and_get_next_iterator(state, current_node); + + current_node = &*(*current_state.it).value; + + TRY(state.try_empend(false, current_node->m_children.begin(), current_node->m_children.end())); + return {}; + } + + static ErrorOr<void> pop_and_get_next_iterator(auto& state, auto& current_node) + { + state.take_last(); + if (state.is_empty()) { + current_node = nullptr; + return {}; + } + return skip_to_next_iterator(state, current_node); + } + ValueType m_value; Optional<MetadataType> m_metadata; HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits> m_children; |