diff options
Diffstat (limited to 'Userland/Services')
-rw-r--r-- | Userland/Services/SQLServer/ConnectionFromClient.cpp | 6 | ||||
-rw-r--r-- | Userland/Services/SQLServer/DatabaseConnection.cpp | 60 | ||||
-rw-r--r-- | Userland/Services/SQLServer/DatabaseConnection.h | 11 | ||||
-rw-r--r-- | Userland/Services/SQLServer/SQLClient.ipc | 3 | ||||
-rw-r--r-- | Userland/Services/SQLServer/SQLServer.ipc | 4 | ||||
-rw-r--r-- | Userland/Services/SQLServer/SQLStatement.cpp | 4 |
6 files changed, 31 insertions, 57 deletions
diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 95d54622da..3c358dd9ae 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -37,8 +37,10 @@ void ConnectionFromClient::die() Messages::SQLServer::ConnectResponse ConnectionFromClient::connect(DeprecatedString const& database_name) { dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::connect(database_name: {})", database_name); - auto database_connection = DatabaseConnection::construct(database_name, client_id()); - return { database_connection->connection_id() }; + + if (auto database_connection = DatabaseConnection::create(database_name, client_id()); !database_connection.is_error()) + return { database_connection.value()->connection_id() }; + return { {} }; } void ConnectionFromClient::disconnect(u64 connection_id) diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp index ccea8d4061..46f3232000 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.cpp +++ b/Userland/Services/SQLServer/DatabaseConnection.cpp @@ -5,7 +5,6 @@ */ #include <AK/LexicalPath.h> -#include <SQLServer/ConnectionFromClient.h> #include <SQLServer/DatabaseConnection.h> #include <SQLServer/SQLStatement.h> @@ -22,64 +21,41 @@ RefPtr<DatabaseConnection> DatabaseConnection::connection_for(u64 connection_id) return nullptr; } -DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int client_id) +ErrorOr<NonnullRefPtr<DatabaseConnection>> DatabaseConnection::create(DeprecatedString database_name, int client_id) +{ + if (LexicalPath path(database_name); (path.title() != database_name) || (path.dirname() != ".")) + return Error::from_string_view("Invalid database name"sv); + + auto database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", database_name)); + if (auto result = database->open(); result.is_error()) { + warnln("Could not open database: {}", result.error().error_string()); + return Error::from_string_view("Could not open database"sv); + } + + return adopt_nonnull_ref_or_enomem(new (nothrow) DatabaseConnection(move(database), move(database_name), client_id)); +} + +DatabaseConnection::DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id) : Object() + , m_database(move(database)) , m_database_name(move(database_name)) , m_connection_id(s_next_connection_id++) , m_client_id(client_id) { - if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) { - auto client_connection = ConnectionFromClient::client_connection_for(m_client_id); - client_connection->async_connection_error(m_connection_id, SQL::SQLErrorCode::InvalidDatabaseName, m_database_name); - return; - } - - dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name); + dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiatedconnection with database '{}'", connection_id(), m_database_name); s_connections.set(m_connection_id, *this); - - deferred_invoke([this]() { - m_database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", m_database_name)); - auto client_connection = ConnectionFromClient::client_connection_for(m_client_id); - if (auto maybe_error = m_database->open(); maybe_error.is_error()) { - client_connection->async_connection_error(m_connection_id, maybe_error.error().error(), maybe_error.error().error_string()); - return; - } - m_accept_statements = true; - if (client_connection) - client_connection->async_connected(m_connection_id, m_database_name); - else - warnln("Cannot notify client of database connection. Client disconnected"); - }); } void DatabaseConnection::disconnect() { dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::disconnect(connection_id {}, database '{}'", connection_id(), m_database_name); - m_accept_statements = false; - deferred_invoke([this]() { - m_database = nullptr; - s_connections.remove(m_connection_id); - auto client_connection = ConnectionFromClient::client_connection_for(client_id()); - if (client_connection) - client_connection->async_disconnected(m_connection_id); - else - warnln("Cannot notify client of database disconnection. Client disconnected"); - }); + s_connections.remove(connection_id()); } SQL::ResultOr<u64> DatabaseConnection::prepare_statement(StringView sql) { dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql); - if (!m_accept_statements) - return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable }; - - auto client_connection = ConnectionFromClient::client_connection_for(client_id()); - if (!client_connection) { - warnln("Cannot notify client of database disconnection. Client disconnected"); - return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv }; - } - auto statement = TRY(SQLStatement::create(*this, sql)); return statement->statement_id(); } diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h index 3ded1919bc..eea9acbfe2 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.h +++ b/Userland/Services/SQLServer/DatabaseConnection.h @@ -6,6 +6,7 @@ #pragma once +#include <AK/NonnullRefPtr.h> #include <LibCore/Object.h> #include <LibSQL/Database.h> #include <LibSQL/Result.h> @@ -14,26 +15,26 @@ namespace SQLServer { class DatabaseConnection final : public Core::Object { - C_OBJECT(DatabaseConnection) + C_OBJECT_ABSTRACT(DatabaseConnection) public: + static ErrorOr<NonnullRefPtr<DatabaseConnection>> create(DeprecatedString database_name, int client_id); ~DatabaseConnection() override = default; static RefPtr<DatabaseConnection> connection_for(u64 connection_id); u64 connection_id() const { return m_connection_id; } int client_id() const { return m_client_id; } - RefPtr<SQL::Database> database() { return m_database; } + NonnullRefPtr<SQL::Database> database() { return m_database; } void disconnect(); SQL::ResultOr<u64> prepare_statement(StringView sql); private: - DatabaseConnection(DeprecatedString database_name, int client_id); + DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id); - RefPtr<SQL::Database> m_database { nullptr }; + NonnullRefPtr<SQL::Database> m_database; DeprecatedString m_database_name; u64 m_connection_id { 0 }; int m_client_id { 0 }; - bool m_accept_statements { false }; }; } diff --git a/Userland/Services/SQLServer/SQLClient.ipc b/Userland/Services/SQLServer/SQLClient.ipc index e5b87110f4..15d3533333 100644 --- a/Userland/Services/SQLServer/SQLClient.ipc +++ b/Userland/Services/SQLServer/SQLClient.ipc @@ -2,11 +2,8 @@ endpoint SQLClient { - connected(u64 connection_id, DeprecatedString connected_to_database) =| - connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =| execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) =| next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> row) =| results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) =| execution_error(u64 statement_id, u64 execution_id, SQL::SQLErrorCode code, DeprecatedString message) =| - disconnected(u64 connection_id) =| } diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc index ba06d5eb39..2a01ae22f6 100644 --- a/Userland/Services/SQLServer/SQLServer.ipc +++ b/Userland/Services/SQLServer/SQLServer.ipc @@ -2,8 +2,8 @@ endpoint SQLServer { - connect(DeprecatedString name) => (u64 connection_id) + connect(DeprecatedString name) => (Optional<u64> connection_id) prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional<u64> statement_id) execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) => (Optional<u64> execution_id) - disconnect(u64 connection_id) =| + disconnect(u64 connection_id) => () } diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index cce1528f05..a1049bd835 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -74,9 +74,7 @@ Optional<u64> SQLStatement::execute(Vector<SQL::Value> placeholder_values) m_ongoing_executions.set(execution_id); deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] { - VERIFY(!connection()->database().is_null()); - - auto execution_result = m_statement->execute(connection()->database().release_nonnull(), placeholder_values); + auto execution_result = m_statement->execute(connection()->database(), placeholder_values); m_ongoing_executions.remove(execution_id); if (execution_result.is_error()) { |