summaryrefslogtreecommitdiff
path: root/AK/Trie.h
diff options
context:
space:
mode:
authorAli Mohammad Pur <ali.mpfard@gmail.com>2022-02-14 16:49:53 +0330
committerIdan Horowitz <idan.horowitz@gmail.com>2022-02-15 18:03:02 +0200
commita1cb2c371a72af9556a629fb4f89acee92626708 (patch)
tree9eaeb5057f5f4d2036879ac43c53b897406ae0a7 /AK/Trie.h
parent80e61985632b59acb10a9cf384c7942b2d9bd27d (diff)
downloadserenity-a1cb2c371a72af9556a629fb4f89acee92626708.zip
AK+Kernel: OOM-harden most parts of Trie
The only part of Unveil that can't handle OOM gracefully is the String::formatted() use in the node metadata.
Diffstat (limited to 'AK/Trie.h')
-rw-r--r--AK/Trie.h171
1 files changed, 97 insertions, 74 deletions
diff --git a/AK/Trie.h b/AK/Trie.h
index 6436d7375c..8c60e2f724 100644
--- a/AK/Trie.h
+++ b/AK/Trie.h
@@ -6,6 +6,7 @@
#pragma once
+#include <AK/Concepts.h>
#include <AK/Forward.h>
#include <AK/HashMap.h>
#include <AK/OwnPtr.h>
@@ -29,62 +30,6 @@ template<typename DeclaredBaseType, typename DefaultBaseType, typename ValueType
class Trie {
using BaseType = typename SubstituteIfVoid<DeclaredBaseType, DefaultBaseType>::Type;
- class ConstIterator {
-
- public:
- static ConstIterator end() { return {}; }
-
- bool operator==(const ConstIterator& other) const { return m_current_node == other.m_current_node; }
-
- const BaseType& operator*() const { return static_cast<const BaseType&>(*m_current_node); }
- const BaseType* operator->() const { return static_cast<const BaseType*>(m_current_node); }
- void operator++() { skip_to_next(); }
-
- explicit ConstIterator(const Trie& node)
- {
- m_current_node = &node;
- // FIXME: Figure out how to OOM harden this iterator.
- MUST(m_state.try_empend(false, node.m_children.begin(), node.m_children.end()));
- }
-
- private:
- void skip_to_next()
- {
- auto& current_state = m_state.last();
- if (current_state.did_generate_root)
- ++current_state.it;
- else
- current_state.did_generate_root = true;
- if (current_state.it == current_state.end)
- return pop_and_get_next();
-
- m_current_node = &*(*current_state.it).value;
-
- // FIXME: Figure out how to OOM harden this iterator.
- MUST(m_state.try_empend(false, m_current_node->m_children.begin(), m_current_node->m_children.end()));
- }
- void pop_and_get_next()
- {
- m_state.take_last();
- if (m_state.is_empty()) {
- m_current_node = nullptr;
- return;
- }
-
- skip_to_next();
- }
-
- ConstIterator() = default;
-
- struct State {
- bool did_generate_root { false };
- typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType it;
- typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType end;
- };
- Vector<State> m_state;
- const Trie* m_current_node { nullptr };
- };
-
public:
using MetadataType = MetadataT;
@@ -127,48 +72,56 @@ public:
Optional<MetadataType> metadata() const requires(!IsNullPointer<MetadataType>) { return m_metadata; }
void set_metadata(MetadataType metadata) requires(!IsNullPointer<MetadataType>) { m_metadata = move(metadata); }
const MetadataType& metadata_value() const requires(!IsNullPointer<MetadataType>) { return m_metadata.value(); }
+ MetadataType& metadata_value() requires(!IsNullPointer<MetadataType>) { return m_metadata.value(); }
const ValueType& value() const { return m_value; }
ValueType& value() { return m_value; }
- Trie& ensure_child(ValueType value, Optional<MetadataType> metadata = {})
+ ErrorOr<Trie*> ensure_child(ValueType value, Optional<MetadataType> metadata = {})
{
auto it = m_children.find(value);
if (it == m_children.end()) {
- auto node = adopt_nonnull_own_or_enomem(new (nothrow) Trie(value, move(metadata))).release_value_but_fixme_should_propagate_errors();
+ auto node = TRY(adopt_nonnull_own_or_enomem(new (nothrow) Trie(value, move(metadata))));
auto& node_ref = *node;
- m_children.set(move(value), move(node));
- return static_cast<BaseType&>(node_ref);
+ TRY(m_children.try_set(move(value), move(node)));
+ return &static_cast<BaseType&>(node_ref);
}
auto& node_ref = *it->value;
if (metadata.has_value())
node_ref.m_metadata = move(metadata);
- return static_cast<BaseType&>(node_ref);
+ return &static_cast<BaseType&>(node_ref);
}
template<typename It, typename ProvideMetadataFunction>
- BaseType& insert(
+ ErrorOr<BaseType*> insert(
It& it, const It& end, MetadataType metadata, ProvideMetadataFunction provide_missing_metadata) requires(!IsNullPointer<MetadataType>)
{
Trie* last_root_node = &traverse_until_last_accessible_node(it, end);
+ auto invoke_provide_missing_metadata = [&]<typename... Ts>(Ts && ... args)->ErrorOr<Optional<MetadataType>>
+ {
+ if constexpr (SameAs<MetadataType, decltype(provide_missing_metadata(forward<Ts>(args)...))>)
+ return Optional<MetadataType>(provide_missing_metadata(forward<Ts>(args)...));
+ else
+ return provide_missing_metadata(forward<Ts>(args)...);
+ };
for (; it != end; ++it)
- last_root_node = static_cast<Trie*>(&last_root_node->ensure_child(*it, provide_missing_metadata(static_cast<BaseType&>(*last_root_node), it)));
+ last_root_node = static_cast<Trie*>(TRY(last_root_node->ensure_child(*it, TRY(invoke_provide_missing_metadata(static_cast<BaseType&>(*last_root_node), it)))));
last_root_node->set_metadata(move(metadata));
- return static_cast<BaseType&>(*last_root_node);
+ return static_cast<BaseType*>(last_root_node);
}
template<typename It>
- BaseType& insert(It& it, const It& end) requires(IsNullPointer<MetadataType>)
+ ErrorOr<BaseType*> insert(It& it, const It& end) requires(IsNullPointer<MetadataType>)
{
Trie* last_root_node = &traverse_until_last_accessible_node(it, end);
for (; it != end; ++it)
- last_root_node = static_cast<Trie*>(&last_root_node->ensure_child(*it, {}));
- return static_cast<BaseType&>(*last_root_node);
+ last_root_node = static_cast<Trie*>(TRY(last_root_node->ensure_child(*it, {})));
+ return static_cast<BaseType*>(last_root_node);
}
template<typename It, typename ProvideMetadataFunction>
- BaseType& insert(
+ ErrorOr<BaseType*> insert(
const It& begin, const It& end, MetadataType metadata, ProvideMetadataFunction provide_missing_metadata) requires(!IsNullPointer<MetadataType>)
{
auto it = begin;
@@ -176,7 +129,7 @@ public:
}
template<typename It>
- BaseType& insert(const It& begin, const It& end) requires(IsNullPointer<MetadataType>)
+ ErrorOr<BaseType*> insert(const It& begin, const It& end) requires(IsNullPointer<MetadataType>)
{
auto it = begin;
return insert(it, end);
@@ -185,21 +138,91 @@ public:
HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>& children() { return m_children; }
HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits> const& children() const { return m_children; }
- ConstIterator begin() const { return ConstIterator(*this); }
- ConstIterator end() const { return ConstIterator::end(); }
+ template<typename Fn>
+ ErrorOr<void> for_each_node_in_tree_order(Fn callback) const
+ {
+ struct State {
+ bool did_generate_root { false };
+ typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType it;
+ typename HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits>::ConstIteratorType end;
+ };
+ Vector<State> state;
+ TRY(state.try_empend(false, m_children.begin(), m_children.end()));
+
+ auto invoke = [&](auto& current_node) -> ErrorOr<IterationDecision> {
+ if constexpr (VoidFunction<Fn, const BaseType&>) {
+ callback(static_cast<const BaseType&>(current_node));
+ return IterationDecision::Continue;
+ } else if constexpr (IsSpecializationOf<decltype(callback(declval<const BaseType&>())), ErrorOr>) {
+ return callback(static_cast<const BaseType&>(current_node));
+ } else if constexpr (IteratorFunction<Fn, const BaseType&>) {
+ return callback(static_cast<const BaseType&>(current_node));
+ } else {
+ static_assert(DependentFalse<Fn>, "Invalid iterator function type signature");
+ }
+ return IterationDecision::Continue;
+ };
+
+ for (auto* current_node = this; current_node != nullptr;) {
+ if (TRY(invoke(*current_node)) == IterationDecision::Break)
+ break;
+ TRY(skip_to_next_iterator(state, current_node));
+ }
+ return {};
+ }
[[nodiscard]] bool is_empty() const { return m_children.is_empty(); }
void clear() { m_children.clear(); }
- BaseType deep_copy()
+ ErrorOr<BaseType> deep_copy()
{
- Trie root(m_value, m_metadata);
+ Trie root(m_value, TRY(copy_metadata(m_metadata)));
for (auto& it : m_children)
- root.m_children.set(it.key, adopt_nonnull_own_or_enomem(new (nothrow) Trie(it.value->deep_copy())).release_value_but_fixme_should_propagate_errors());
+ TRY(root.m_children.try_set(it.key, TRY(adopt_nonnull_own_or_enomem(new (nothrow) Trie(TRY(it.value->deep_copy()))))));
return static_cast<BaseType&&>(move(root));
}
private:
+ static ErrorOr<Optional<MetadataType>> copy_metadata(Optional<MetadataType> const& metadata)
+ {
+ if (!metadata.has_value())
+ return Optional<MetadataType> {};
+
+ if constexpr (requires(MetadataType t) { { t.copy() } -> SpecializationOf<ErrorOr>; })
+ return Optional<MetadataType> { TRY(metadata->copy()) };
+#ifndef KERNEL
+ else
+ return Optional<MetadataType> { MetadataType(metadata.value()) };
+#endif
+ }
+
+ static ErrorOr<void> skip_to_next_iterator(auto& state, auto& current_node)
+ {
+ auto& current_state = state.last();
+ if (current_state.did_generate_root)
+ ++current_state.it;
+ else
+ current_state.did_generate_root = true;
+
+ if (current_state.it == current_state.end)
+ return pop_and_get_next_iterator(state, current_node);
+
+ current_node = &*(*current_state.it).value;
+
+ TRY(state.try_empend(false, current_node->m_children.begin(), current_node->m_children.end()));
+ return {};
+ }
+
+ static ErrorOr<void> pop_and_get_next_iterator(auto& state, auto& current_node)
+ {
+ state.take_last();
+ if (state.is_empty()) {
+ current_node = nullptr;
+ return {};
+ }
+ return skip_to_next_iterator(state, current_node);
+ }
+
ValueType m_value;
Optional<MetadataType> m_metadata;
HashMap<ValueType, NonnullOwnPtr<Trie>, ValueTraits> m_children;