summaryrefslogtreecommitdiff
path: root/Userland/Services/ChessEngine
diff options
context:
space:
mode:
Diffstat (limited to 'Userland/Services/ChessEngine')
-rw-r--r--Userland/Services/ChessEngine/ChessEngine.cpp14
-rw-r--r--Userland/Services/ChessEngine/ChessEngine.h2
-rw-r--r--Userland/Services/ChessEngine/MCTSTree.cpp36
-rw-r--r--Userland/Services/ChessEngine/MCTSTree.h7
4 files changed, 52 insertions, 7 deletions
diff --git a/Userland/Services/ChessEngine/ChessEngine.cpp b/Userland/Services/ChessEngine/ChessEngine.cpp
index 88b0aa89bc..4c92f4c221 100644
--- a/Userland/Services/ChessEngine/ChessEngine.cpp
+++ b/Userland/Services/ChessEngine/ChessEngine.cpp
@@ -38,7 +38,14 @@ void ChessEngine::handle_go(GoCommand const& command)
auto elapsed_time = Core::ElapsedTimer::start_new();
- MCTSTree mcts(m_board);
+ auto mcts = [this]() -> MCTSTree {
+ if (!m_last_tree.has_value())
+ return { m_board };
+ auto x = m_last_tree.value().child_with_move(m_board.last_move().value());
+ if (x.has_value())
+ return move(x.value());
+ return { m_board };
+ }();
int rounds = 0;
while (elapsed_time.elapsed() <= command.movetime.value()) {
@@ -47,7 +54,10 @@ void ChessEngine::handle_go(GoCommand const& command)
}
dbgln("MCTS finished {} rounds.", rounds);
dbgln("MCTS evaluation {}", mcts.expected_value());
- auto best_move = mcts.best_move();
+ auto& best_node = mcts.best_node();
+ auto const& best_move = best_node.last_move();
dbgln("MCTS best move {}", best_move.to_long_algebraic());
send_command(BestMoveCommand(best_move));
+
+ m_last_tree = move(best_node);
}
diff --git a/Userland/Services/ChessEngine/ChessEngine.h b/Userland/Services/ChessEngine/ChessEngine.h
index 94c61a86ae..cf2d2a2e49 100644
--- a/Userland/Services/ChessEngine/ChessEngine.h
+++ b/Userland/Services/ChessEngine/ChessEngine.h
@@ -6,6 +6,7 @@
#pragma once
+#include "MCTSTree.h"
#include <LibChess/Chess.h>
#include <LibChess/UCIEndpoint.h>
@@ -26,4 +27,5 @@ private:
}
Chess::Board m_board;
+ Optional<MCTSTree> m_last_tree;
};
diff --git a/Userland/Services/ChessEngine/MCTSTree.cpp b/Userland/Services/ChessEngine/MCTSTree.cpp
index 524b347104..bcd00747bc 100644
--- a/Userland/Services/ChessEngine/MCTSTree.cpp
+++ b/Userland/Services/ChessEngine/MCTSTree.cpp
@@ -16,6 +16,19 @@ MCTSTree::MCTSTree(Chess::Board const& board, MCTSTree* parent)
{
}
+MCTSTree::MCTSTree(MCTSTree&& other)
+ : m_children(move(other.m_children))
+ , m_parent(other.m_parent)
+ , m_white_points(other.m_white_points)
+ , m_simulations(other.m_simulations)
+ , m_board(move(other.m_board))
+ , m_last_move(move(other.m_last_move))
+ , m_turn(other.m_turn)
+ , m_moves_generated(other.m_moves_generated)
+{
+ other.m_parent = nullptr;
+}
+
MCTSTree& MCTSTree::select_leaf()
{
if (!expanded() || m_children.size() == 0)
@@ -117,22 +130,37 @@ void MCTSTree::do_round()
node.apply_result(result);
}
-Chess::Move MCTSTree::best_move() const
+Optional<MCTSTree&> MCTSTree::child_with_move(Chess::Move chess_move)
+{
+ for (auto& node : m_children) {
+ if (node.last_move() == chess_move)
+ return node;
+ }
+ return {};
+}
+
+MCTSTree& MCTSTree::best_node()
{
int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1;
- Chess::Move best_move = { { 0, 0 }, { 0, 0 } };
+ MCTSTree* best_node_ptr = nullptr;
double best_score = -double(INFINITY);
VERIFY(m_children.size());
for (auto& node : m_children) {
double node_score = node.expected_value() * score_multiplier;
if (node_score >= best_score) {
- best_move = node.m_last_move.value();
+ best_node_ptr = &node;
best_score = node_score;
}
}
+ VERIFY(best_node_ptr);
- return best_move;
+ return *best_node_ptr;
+}
+
+Chess::Move MCTSTree::last_move() const
+{
+ return m_last_move.value();
}
double MCTSTree::expected_value() const
diff --git a/Userland/Services/ChessEngine/MCTSTree.h b/Userland/Services/ChessEngine/MCTSTree.h
index 15ce2943c8..2bf48d6abc 100644
--- a/Userland/Services/ChessEngine/MCTSTree.h
+++ b/Userland/Services/ChessEngine/MCTSTree.h
@@ -20,6 +20,7 @@ public:
};
MCTSTree(Chess::Board const& board, MCTSTree* parent = nullptr);
+ MCTSTree(MCTSTree&&);
MCTSTree& select_leaf();
MCTSTree& expand();
@@ -28,7 +29,11 @@ public:
void apply_result(int game_score);
void do_round();
- Chess::Move best_move() const;
+ Optional<MCTSTree&> child_with_move(Chess::Move);
+
+ MCTSTree& best_node();
+
+ Chess::Move last_move() const;
double expected_value() const;
double uct(Chess::Color color) const;
bool expanded() const;