diff options
Diffstat (limited to 'Userland')
-rw-r--r-- | Userland/DevTools/SQLStudio/MainWidget.cpp | 11 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/SQLClient.cpp | 20 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/SQLClient.h | 6 | ||||
-rw-r--r-- | Userland/Services/SQLServer/ConnectionFromClient.cpp | 6 | ||||
-rw-r--r-- | Userland/Services/SQLServer/DatabaseConnection.cpp | 60 | ||||
-rw-r--r-- | Userland/Services/SQLServer/DatabaseConnection.h | 11 | ||||
-rw-r--r-- | Userland/Services/SQLServer/SQLClient.ipc | 3 | ||||
-rw-r--r-- | Userland/Services/SQLServer/SQLServer.ipc | 4 | ||||
-rw-r--r-- | Userland/Services/SQLServer/SQLStatement.cpp | 4 | ||||
-rw-r--r-- | Userland/Utilities/sql.cpp | 49 |
10 files changed, 57 insertions, 117 deletions
diff --git a/Userland/DevTools/SQLStudio/MainWidget.cpp b/Userland/DevTools/SQLStudio/MainWidget.cpp index 25f1f2f5b0..f47b5a8e9e 100644 --- a/Userland/DevTools/SQLStudio/MainWidget.cpp +++ b/Userland/DevTools/SQLStudio/MainWidget.cpp @@ -145,9 +145,16 @@ MainWidget::MainWidget() m_run_script_action = GUI::Action::create("Run script", { Mod_Alt, Key_F9 }, Gfx::Bitmap::try_load_from_file("/res/icons/16x16/play.png"sv).release_value_but_fixme_should_propagate_errors(), [&](auto&) { m_results.clear(); m_current_line_for_parsing = 0; + // TODO select the database to use in UI. - m_connection_id = m_sql_client->connect("test"); - read_next_sql_statement_of_editor(); + constexpr auto database_name = "Test"sv; + + if (auto connection_id = m_sql_client->connect(database_name); connection_id.has_value()) { + m_connection_id = connection_id.release_value(); + read_next_sql_statement_of_editor(); + } else { + warnln("\033[33;1mCould not connect to:\033[0m {}", database_name); + } }); auto& toolbar_container = add<GUI::ToolbarContainer>(); diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index 3b1bd35f44..4a5a2693ec 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -9,26 +9,6 @@ namespace SQL { -void SQLClient::connected(u64 connection_id, DeprecatedString const& connected_to_database) -{ - if (on_connected) - on_connected(connection_id, connected_to_database); -} - -void SQLClient::disconnected(u64 connection_id) -{ - if (on_disconnected) - on_disconnected(connection_id); -} - -void SQLClient::connection_error(u64 connection_id, SQLErrorCode const& code, DeprecatedString const& message) -{ - if (on_connection_error) - on_connection_error(connection_id, code, message); - else - warnln("Connection error for connection_id {}: {} ({})", connection_id, message, to_underlying(code)); -} - void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) { if (on_execution_error) diff --git a/Userland/Libraries/LibSQL/SQLClient.h b/Userland/Libraries/LibSQL/SQLClient.h index 52ecb33375..ae45d852de 100644 --- a/Userland/Libraries/LibSQL/SQLClient.h +++ b/Userland/Libraries/LibSQL/SQLClient.h @@ -20,9 +20,6 @@ class SQLClient IPC_CLIENT_CONNECTION(SQLClient, "/tmp/session/%sid/portal/sql"sv) virtual ~SQLClient() = default; - Function<void(u64, DeprecatedString const&)> on_connected; - Function<void(u64)> on_disconnected; - Function<void(u64, SQLErrorCode, DeprecatedString const&)> on_connection_error; Function<void(u64, u64, SQLErrorCode, DeprecatedString const&)> on_execution_error; Function<void(u64, u64, bool, size_t, size_t, size_t)> on_execution_success; Function<void(u64, u64, Vector<DeprecatedString> const&)> on_next_result; @@ -34,13 +31,10 @@ private: { } - virtual void connected(u64 connection_id, DeprecatedString const& connected_to_database) override; - virtual void connection_error(u64 connection_id, SQLErrorCode const& code, DeprecatedString const& message) override; virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override; virtual void next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const&) override; virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override; virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override; - virtual void disconnected(u64 connection_id) override; }; } diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 95d54622da..3c358dd9ae 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -37,8 +37,10 @@ void ConnectionFromClient::die() Messages::SQLServer::ConnectResponse ConnectionFromClient::connect(DeprecatedString const& database_name) { dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::connect(database_name: {})", database_name); - auto database_connection = DatabaseConnection::construct(database_name, client_id()); - return { database_connection->connection_id() }; + + if (auto database_connection = DatabaseConnection::create(database_name, client_id()); !database_connection.is_error()) + return { database_connection.value()->connection_id() }; + return { {} }; } void ConnectionFromClient::disconnect(u64 connection_id) diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp index ccea8d4061..46f3232000 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.cpp +++ b/Userland/Services/SQLServer/DatabaseConnection.cpp @@ -5,7 +5,6 @@ */ #include <AK/LexicalPath.h> -#include <SQLServer/ConnectionFromClient.h> #include <SQLServer/DatabaseConnection.h> #include <SQLServer/SQLStatement.h> @@ -22,64 +21,41 @@ RefPtr<DatabaseConnection> DatabaseConnection::connection_for(u64 connection_id) return nullptr; } -DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int client_id) +ErrorOr<NonnullRefPtr<DatabaseConnection>> DatabaseConnection::create(DeprecatedString database_name, int client_id) +{ + if (LexicalPath path(database_name); (path.title() != database_name) || (path.dirname() != ".")) + return Error::from_string_view("Invalid database name"sv); + + auto database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", database_name)); + if (auto result = database->open(); result.is_error()) { + warnln("Could not open database: {}", result.error().error_string()); + return Error::from_string_view("Could not open database"sv); + } + + return adopt_nonnull_ref_or_enomem(new (nothrow) DatabaseConnection(move(database), move(database_name), client_id)); +} + +DatabaseConnection::DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id) : Object() + , m_database(move(database)) , m_database_name(move(database_name)) , m_connection_id(s_next_connection_id++) , m_client_id(client_id) { - if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) { - auto client_connection = ConnectionFromClient::client_connection_for(m_client_id); - client_connection->async_connection_error(m_connection_id, SQL::SQLErrorCode::InvalidDatabaseName, m_database_name); - return; - } - - dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name); + dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiatedconnection with database '{}'", connection_id(), m_database_name); s_connections.set(m_connection_id, *this); - - deferred_invoke([this]() { - m_database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", m_database_name)); - auto client_connection = ConnectionFromClient::client_connection_for(m_client_id); - if (auto maybe_error = m_database->open(); maybe_error.is_error()) { - client_connection->async_connection_error(m_connection_id, maybe_error.error().error(), maybe_error.error().error_string()); - return; - } - m_accept_statements = true; - if (client_connection) - client_connection->async_connected(m_connection_id, m_database_name); - else - warnln("Cannot notify client of database connection. Client disconnected"); - }); } void DatabaseConnection::disconnect() { dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::disconnect(connection_id {}, database '{}'", connection_id(), m_database_name); - m_accept_statements = false; - deferred_invoke([this]() { - m_database = nullptr; - s_connections.remove(m_connection_id); - auto client_connection = ConnectionFromClient::client_connection_for(client_id()); - if (client_connection) - client_connection->async_disconnected(m_connection_id); - else - warnln("Cannot notify client of database disconnection. Client disconnected"); - }); + s_connections.remove(connection_id()); } SQL::ResultOr<u64> DatabaseConnection::prepare_statement(StringView sql) { dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql); - if (!m_accept_statements) - return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable }; - - auto client_connection = ConnectionFromClient::client_connection_for(client_id()); - if (!client_connection) { - warnln("Cannot notify client of database disconnection. Client disconnected"); - return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv }; - } - auto statement = TRY(SQLStatement::create(*this, sql)); return statement->statement_id(); } diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h index 3ded1919bc..eea9acbfe2 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.h +++ b/Userland/Services/SQLServer/DatabaseConnection.h @@ -6,6 +6,7 @@ #pragma once +#include <AK/NonnullRefPtr.h> #include <LibCore/Object.h> #include <LibSQL/Database.h> #include <LibSQL/Result.h> @@ -14,26 +15,26 @@ namespace SQLServer { class DatabaseConnection final : public Core::Object { - C_OBJECT(DatabaseConnection) + C_OBJECT_ABSTRACT(DatabaseConnection) public: + static ErrorOr<NonnullRefPtr<DatabaseConnection>> create(DeprecatedString database_name, int client_id); ~DatabaseConnection() override = default; static RefPtr<DatabaseConnection> connection_for(u64 connection_id); u64 connection_id() const { return m_connection_id; } int client_id() const { return m_client_id; } - RefPtr<SQL::Database> database() { return m_database; } + NonnullRefPtr<SQL::Database> database() { return m_database; } void disconnect(); SQL::ResultOr<u64> prepare_statement(StringView sql); private: - DatabaseConnection(DeprecatedString database_name, int client_id); + DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id); - RefPtr<SQL::Database> m_database { nullptr }; + NonnullRefPtr<SQL::Database> m_database; DeprecatedString m_database_name; u64 m_connection_id { 0 }; int m_client_id { 0 }; - bool m_accept_statements { false }; }; } diff --git a/Userland/Services/SQLServer/SQLClient.ipc b/Userland/Services/SQLServer/SQLClient.ipc index e5b87110f4..15d3533333 100644 --- a/Userland/Services/SQLServer/SQLClient.ipc +++ b/Userland/Services/SQLServer/SQLClient.ipc @@ -2,11 +2,8 @@ endpoint SQLClient { - connected(u64 connection_id, DeprecatedString connected_to_database) =| - connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =| execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) =| next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> row) =| results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) =| execution_error(u64 statement_id, u64 execution_id, SQL::SQLErrorCode code, DeprecatedString message) =| - disconnected(u64 connection_id) =| } diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc index ba06d5eb39..2a01ae22f6 100644 --- a/Userland/Services/SQLServer/SQLServer.ipc +++ b/Userland/Services/SQLServer/SQLServer.ipc @@ -2,8 +2,8 @@ endpoint SQLServer { - connect(DeprecatedString name) => (u64 connection_id) + connect(DeprecatedString name) => (Optional<u64> connection_id) prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional<u64> statement_id) execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) => (Optional<u64> execution_id) - disconnect(u64 connection_id) =| + disconnect(u64 connection_id) => () } diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index cce1528f05..a1049bd835 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -74,9 +74,7 @@ Optional<u64> SQLStatement::execute(Vector<SQL::Value> placeholder_values) m_ongoing_executions.set(execution_id); deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] { - VERIFY(!connection()->database().is_null()); - - auto execution_result = m_statement->execute(connection()->database().release_nonnull(), placeholder_values); + auto execution_result = m_statement->execute(connection()->database(), placeholder_values); m_ongoing_executions.remove(execution_id); if (execution_result.is_error()) { diff --git a/Userland/Utilities/sql.cpp b/Userland/Utilities/sql.cpp index a0c046c9ac..6020c3f01a 100644 --- a/Userland/Utilities/sql.cpp +++ b/Userland/Utilities/sql.cpp @@ -76,14 +76,6 @@ public: m_sql_client = SQL::SQLClient::try_create().release_value_but_fixme_should_propagate_errors(); - m_sql_client->on_connected = [this](auto connection_id, auto const& connected_to_database) { - outln("Connected to \033[33;1m{}\033[0m", connected_to_database); - m_current_database = connected_to_database; - m_pending_database = ""; - m_connection_id = connection_id; - read_sql(); - }; - m_sql_client->on_execution_success = [this](auto, auto, auto has_results, auto updated, auto created, auto deleted) { if (updated != 0 || created != 0 || deleted != 0) { outln("{} row(s) updated, {} created, {} deleted", updated, created, deleted); @@ -104,27 +96,11 @@ public: read_sql(); }; - m_sql_client->on_connection_error = [this](auto, auto code, auto const& message) { - outln("\033[33;1mConnection error:\033[0m {}", message); - m_loop.quit(to_underlying(code)); - }; - m_sql_client->on_execution_error = [this](auto, auto, auto, auto const& message) { outln("\033[33;1mExecution error:\033[0m {}", message); read_sql(); }; - m_sql_client->on_disconnected = [this](auto) { - if (m_pending_database.is_empty()) { - outln("Disconnected from \033[33;1m{}\033[0m and terminating", m_current_database); - m_loop.quit(0); - } else { - outln("Disconnected from \033[33;1m{}\033[0m", m_current_database); - m_current_database = ""; - m_sql_client->async_connect(m_pending_database); - } - }; - if (!database_name.is_empty()) connect(database_name); } @@ -136,11 +112,18 @@ public: void connect(DeprecatedString const& database_name) { - if (m_current_database.is_empty()) { - m_sql_client->async_connect(database_name); + if (!m_database_name.is_empty()) { + m_sql_client->disconnect(m_connection_id); + m_database_name = {}; + } + + if (auto connection_id = m_sql_client->connect(database_name); connection_id.has_value()) { + outln("Connected to \033[33;1m{}\033[0m", database_name); + m_database_name = database_name; + m_connection_id = *connection_id; } else { - m_pending_database = database_name; - m_sql_client->async_disconnect(m_connection_id); + warnln("\033[33;1mCould not connect to:\033[0m {}", database_name); + m_loop.quit(1); } } @@ -158,6 +141,7 @@ public: auto run() { + read_sql(); return m_loop.exec(); } @@ -166,8 +150,7 @@ private: RefPtr<Line::Editor> m_editor { nullptr }; int m_repl_line_level { 0 }; bool m_keep_running { true }; - DeprecatedString m_pending_database {}; - DeprecatedString m_current_database {}; + DeprecatedString m_database_name {}; AK::RefPtr<SQL::SQLClient> m_sql_client { nullptr }; u64 m_connection_id { 0 }; Core::EventLoop m_loop; @@ -280,7 +263,8 @@ private: // m_keep_running can be set to false when the file we are reading // from is exhausted... if (!m_keep_running) { - m_sql_client->async_disconnect(m_connection_id); + m_sql_client->disconnect(m_connection_id); + m_loop.quit(0); return; } @@ -296,7 +280,8 @@ private: // ...But m_keep_running can also be set to false by a command handler. if (!m_keep_running) { - m_sql_client->async_disconnect(m_connection_id); + m_sql_client->disconnect(m_connection_id); + m_loop.quit(0); return; } }; |