/* * Copyright (c) 2021, Jan de Visser * * SPDX-License-Identifier: BSD-2-Clause */ #include #include #include #include #include namespace SQLServer { static HashMap> s_statements; static u64 s_next_statement_id = 0; RefPtr SQLStatement::statement_for(u64 statement_id) { if (s_statements.contains(statement_id)) return *s_statements.get(statement_id).value(); dbgln_if(SQLSERVER_DEBUG, "Invalid statement_id {}", statement_id); return nullptr; } SQL::ResultOr> SQLStatement::create(DatabaseConnection& connection, StringView sql) { auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql)); auto statement = parser.next_statement(); if (parser.has_errors()) return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() }; return TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQLStatement(connection, move(statement)))); } SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr statement) : Core::Object(&connection) , m_statement_id(s_next_statement_id++) , m_statement(move(statement)) { dbgln_if(SQLSERVER_DEBUG, "SQLStatement({})", connection.connection_id()); s_statements.set(m_statement_id, *this); } void SQLStatement::report_error(SQL::Result result, u64 execution_id) { dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string()); auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); s_statements.remove(statement_id()); remove_from_parent(); if (client_connection) client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string()); else warnln("Cannot return execution error. Client disconnected"); m_result = {}; } Optional SQLStatement::execute(Vector placeholder_values) { dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id()); auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); if (!client_connection) { warnln("Cannot yield next result. Client disconnected"); return {}; } auto execution_id = m_next_execution_id++; m_ongoing_executions.set(execution_id); deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] { auto execution_result = m_statement->execute(connection()->database(), placeholder_values); m_ongoing_executions.remove(execution_id); if (execution_result.is_error()) { report_error(execution_result.release_error(), execution_id); return; } auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); if (!client_connection) { warnln("Cannot return statement execution results. Client disconnected"); return; } m_result = execution_result.release_value(); if (should_send_result_rows()) { client_connection->async_execution_success(statement_id(), execution_id, true, 0, 0, 0); m_index = 0; next(execution_id); } else { client_connection->async_execution_success(statement_id(), execution_id, false, 0, m_result->size(), 0); } }); return execution_id; } bool SQLStatement::should_send_result_rows() const { VERIFY(m_result.has_value()); if (m_result->is_empty()) return false; switch (m_result->command()) { case SQL::SQLCommand::Describe: case SQL::SQLCommand::Select: return true; default: return false; } } void SQLStatement::next(u64 execution_id) { VERIFY(!m_result->is_empty()); auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); if (!client_connection) { warnln("Cannot yield next result. Client disconnected"); return; } if (m_index < m_result->size()) { auto& tuple = m_result->at(m_index++).row; client_connection->async_next_result(statement_id(), execution_id, tuple.to_deprecated_string_vector()); deferred_invoke([this, execution_id]() { next(execution_id); }); } else { client_connection->async_results_exhausted(statement_id(), execution_id, m_index); } } }