diff options
Diffstat (limited to 'Userland/Services/ChessEngine/MCTSTree.cpp')
-rw-r--r-- | Userland/Services/ChessEngine/MCTSTree.cpp | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/Userland/Services/ChessEngine/MCTSTree.cpp b/Userland/Services/ChessEngine/MCTSTree.cpp index 524b347104..bcd00747bc 100644 --- a/Userland/Services/ChessEngine/MCTSTree.cpp +++ b/Userland/Services/ChessEngine/MCTSTree.cpp @@ -16,6 +16,19 @@ MCTSTree::MCTSTree(Chess::Board const& board, MCTSTree* parent) { } +MCTSTree::MCTSTree(MCTSTree&& other) + : m_children(move(other.m_children)) + , m_parent(other.m_parent) + , m_white_points(other.m_white_points) + , m_simulations(other.m_simulations) + , m_board(move(other.m_board)) + , m_last_move(move(other.m_last_move)) + , m_turn(other.m_turn) + , m_moves_generated(other.m_moves_generated) +{ + other.m_parent = nullptr; +} + MCTSTree& MCTSTree::select_leaf() { if (!expanded() || m_children.size() == 0) @@ -117,22 +130,37 @@ void MCTSTree::do_round() node.apply_result(result); } -Chess::Move MCTSTree::best_move() const +Optional<MCTSTree&> MCTSTree::child_with_move(Chess::Move chess_move) +{ + for (auto& node : m_children) { + if (node.last_move() == chess_move) + return node; + } + return {}; +} + +MCTSTree& MCTSTree::best_node() { int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1; - Chess::Move best_move = { { 0, 0 }, { 0, 0 } }; + MCTSTree* best_node_ptr = nullptr; double best_score = -double(INFINITY); VERIFY(m_children.size()); for (auto& node : m_children) { double node_score = node.expected_value() * score_multiplier; if (node_score >= best_score) { - best_move = node.m_last_move.value(); + best_node_ptr = &node; best_score = node_score; } } + VERIFY(best_node_ptr); - return best_move; + return *best_node_ptr; +} + +Chess::Move MCTSTree::last_move() const +{ + return m_last_move.value(); } double MCTSTree::expected_value() const |