summaryrefslogtreecommitdiff
path: root/Userland/Services/ChessEngine/MCTSTree.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Userland/Services/ChessEngine/MCTSTree.cpp')
-rw-r--r--Userland/Services/ChessEngine/MCTSTree.cpp37
1 files changed, 18 insertions, 19 deletions
diff --git a/Userland/Services/ChessEngine/MCTSTree.cpp b/Userland/Services/ChessEngine/MCTSTree.cpp
index 9e5284100a..e331784a9d 100644
--- a/Userland/Services/ChessEngine/MCTSTree.cpp
+++ b/Userland/Services/ChessEngine/MCTSTree.cpp
@@ -8,13 +8,12 @@
#include <AK/String.h>
#include <stdlib.h>
-MCTSTree::MCTSTree(const Chess::Board& board, double exploration_parameter, MCTSTree* parent)
+MCTSTree::MCTSTree(const Chess::Board& board, MCTSTree* parent)
: m_parent(parent)
- , m_exploration_parameter(exploration_parameter)
- , m_board(board)
+ , m_board(make<Chess::Board>(board))
+ , m_last_move(board.last_move())
+ , m_turn(board.turn())
{
- if (m_parent)
- m_eval_method = m_parent->eval_method();
}
MCTSTree& MCTSTree::select_leaf()
@@ -25,7 +24,7 @@ MCTSTree& MCTSTree::select_leaf()
MCTSTree* node = nullptr;
double max_uct = -double(INFINITY);
for (auto& child : m_children) {
- double uct = child.uct(m_board.turn());
+ double uct = child.uct(m_turn);
if (uct >= max_uct) {
max_uct = uct;
node = &child;
@@ -40,13 +39,15 @@ MCTSTree& MCTSTree::expand()
VERIFY(!expanded() || m_children.size() == 0);
if (!m_moves_generated) {
- m_board.generate_moves([&](Chess::Move move) {
- Chess::Board clone = m_board;
+ m_board->generate_moves([&](Chess::Move move) {
+ Chess::Board clone = *m_board;
clone.apply_move(move);
- m_children.append(make<MCTSTree>(clone, m_exploration_parameter, this));
+ m_children.append(make<MCTSTree>(clone, this));
return IterationDecision::Continue;
});
m_moves_generated = true;
+ if (m_children.size() != 0)
+ m_board = nullptr; // Release the board to save memory.
}
if (m_children.size() == 0) {
@@ -63,8 +64,7 @@ MCTSTree& MCTSTree::expand()
int MCTSTree::simulate_game() const
{
- VERIFY_NOT_REACHED();
- Chess::Board clone = m_board;
+ Chess::Board clone = *m_board;
while (!clone.game_finished()) {
clone.apply_move(clone.random_move());
}
@@ -73,10 +73,10 @@ int MCTSTree::simulate_game() const
int MCTSTree::heuristic() const
{
- if (m_board.game_finished())
- return m_board.game_score();
+ if (m_board->game_finished())
+ return m_board->game_score();
- double winchance = max(min(double(m_board.material_imbalance()) / 6, 1.0), -1.0);
+ double winchance = max(min(double(m_board->material_imbalance()) / 6, 1.0), -1.0);
double random = double(rand()) / RAND_MAX;
if (winchance >= random)
@@ -101,7 +101,7 @@ void MCTSTree::do_round()
auto& node = select_leaf().expand();
int result;
- if (m_eval_method == EvalMethod::Simulation) {
+ if constexpr (s_eval_method == EvalMethod::Simulation) {
result = node.simulate_game();
} else {
result = node.heuristic();
@@ -111,7 +111,7 @@ void MCTSTree::do_round()
Chess::Move MCTSTree::best_move() const
{
- int score_multiplier = (m_board.turn() == Chess::Color::White) ? 1 : -1;
+ int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1;
Chess::Move best_move = { { 0, 0 }, { 0, 0 } };
double best_score = -double(INFINITY);
@@ -119,8 +119,7 @@ Chess::Move MCTSTree::best_move() const
for (auto& node : m_children) {
double node_score = node.expected_value() * score_multiplier;
if (node_score >= best_score) {
- // The best move is the last move made in the child.
- best_move = node.m_board.moves()[node.m_board.moves().size() - 1];
+ best_move = node.m_last_move.value();
best_score = node_score;
}
}
@@ -143,7 +142,7 @@ double MCTSTree::uct(Chess::Color color) const
// Fun fact: Szepesvári was my data structures professor.
double expected = expected_value() * ((color == Chess::Color::White) ? 1 : -1);
- return expected + m_exploration_parameter * sqrt(log(m_parent->m_simulations) / m_simulations);
+ return expected + s_exploration_parameter * sqrt(log(m_parent->m_simulations) / m_simulations);
}
bool MCTSTree::expanded() const