diff options
Diffstat (limited to 'Userland/Services/ChessEngine')
-rw-r--r-- | Userland/Services/ChessEngine/ChessEngine.cpp | 14 | ||||
-rw-r--r-- | Userland/Services/ChessEngine/ChessEngine.h | 2 | ||||
-rw-r--r-- | Userland/Services/ChessEngine/MCTSTree.cpp | 36 | ||||
-rw-r--r-- | Userland/Services/ChessEngine/MCTSTree.h | 7 |
4 files changed, 52 insertions, 7 deletions
diff --git a/Userland/Services/ChessEngine/ChessEngine.cpp b/Userland/Services/ChessEngine/ChessEngine.cpp index 88b0aa89bc..4c92f4c221 100644 --- a/Userland/Services/ChessEngine/ChessEngine.cpp +++ b/Userland/Services/ChessEngine/ChessEngine.cpp @@ -38,7 +38,14 @@ void ChessEngine::handle_go(GoCommand const& command) auto elapsed_time = Core::ElapsedTimer::start_new(); - MCTSTree mcts(m_board); + auto mcts = [this]() -> MCTSTree { + if (!m_last_tree.has_value()) + return { m_board }; + auto x = m_last_tree.value().child_with_move(m_board.last_move().value()); + if (x.has_value()) + return move(x.value()); + return { m_board }; + }(); int rounds = 0; while (elapsed_time.elapsed() <= command.movetime.value()) { @@ -47,7 +54,10 @@ void ChessEngine::handle_go(GoCommand const& command) } dbgln("MCTS finished {} rounds.", rounds); dbgln("MCTS evaluation {}", mcts.expected_value()); - auto best_move = mcts.best_move(); + auto& best_node = mcts.best_node(); + auto const& best_move = best_node.last_move(); dbgln("MCTS best move {}", best_move.to_long_algebraic()); send_command(BestMoveCommand(best_move)); + + m_last_tree = move(best_node); } diff --git a/Userland/Services/ChessEngine/ChessEngine.h b/Userland/Services/ChessEngine/ChessEngine.h index 94c61a86ae..cf2d2a2e49 100644 --- a/Userland/Services/ChessEngine/ChessEngine.h +++ b/Userland/Services/ChessEngine/ChessEngine.h @@ -6,6 +6,7 @@ #pragma once +#include "MCTSTree.h" #include <LibChess/Chess.h> #include <LibChess/UCIEndpoint.h> @@ -26,4 +27,5 @@ private: } Chess::Board m_board; + Optional<MCTSTree> m_last_tree; }; diff --git a/Userland/Services/ChessEngine/MCTSTree.cpp b/Userland/Services/ChessEngine/MCTSTree.cpp index 524b347104..bcd00747bc 100644 --- a/Userland/Services/ChessEngine/MCTSTree.cpp +++ b/Userland/Services/ChessEngine/MCTSTree.cpp @@ -16,6 +16,19 @@ MCTSTree::MCTSTree(Chess::Board const& board, MCTSTree* parent) { } +MCTSTree::MCTSTree(MCTSTree&& other) + : m_children(move(other.m_children)) + , m_parent(other.m_parent) + , m_white_points(other.m_white_points) + , m_simulations(other.m_simulations) + , m_board(move(other.m_board)) + , m_last_move(move(other.m_last_move)) + , m_turn(other.m_turn) + , m_moves_generated(other.m_moves_generated) +{ + other.m_parent = nullptr; +} + MCTSTree& MCTSTree::select_leaf() { if (!expanded() || m_children.size() == 0) @@ -117,22 +130,37 @@ void MCTSTree::do_round() node.apply_result(result); } -Chess::Move MCTSTree::best_move() const +Optional<MCTSTree&> MCTSTree::child_with_move(Chess::Move chess_move) +{ + for (auto& node : m_children) { + if (node.last_move() == chess_move) + return node; + } + return {}; +} + +MCTSTree& MCTSTree::best_node() { int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1; - Chess::Move best_move = { { 0, 0 }, { 0, 0 } }; + MCTSTree* best_node_ptr = nullptr; double best_score = -double(INFINITY); VERIFY(m_children.size()); for (auto& node : m_children) { double node_score = node.expected_value() * score_multiplier; if (node_score >= best_score) { - best_move = node.m_last_move.value(); + best_node_ptr = &node; best_score = node_score; } } + VERIFY(best_node_ptr); - return best_move; + return *best_node_ptr; +} + +Chess::Move MCTSTree::last_move() const +{ + return m_last_move.value(); } double MCTSTree::expected_value() const diff --git a/Userland/Services/ChessEngine/MCTSTree.h b/Userland/Services/ChessEngine/MCTSTree.h index 15ce2943c8..2bf48d6abc 100644 --- a/Userland/Services/ChessEngine/MCTSTree.h +++ b/Userland/Services/ChessEngine/MCTSTree.h @@ -20,6 +20,7 @@ public: }; MCTSTree(Chess::Board const& board, MCTSTree* parent = nullptr); + MCTSTree(MCTSTree&&); MCTSTree& select_leaf(); MCTSTree& expand(); @@ -28,7 +29,11 @@ public: void apply_result(int game_score); void do_round(); - Chess::Move best_move() const; + Optional<MCTSTree&> child_with_move(Chess::Move); + + MCTSTree& best_node(); + + Chess::Move last_move() const; double expected_value() const; double uct(Chess::Color color) const; bool expanded() const; |