summaryrefslogtreecommitdiff
path: root/Userland
diff options
context:
space:
mode:
Diffstat (limited to 'Userland')
-rw-r--r--Userland/DevTools/SQLStudio/MainWidget.cpp11
-rw-r--r--Userland/Libraries/LibSQL/SQLClient.cpp20
-rw-r--r--Userland/Libraries/LibSQL/SQLClient.h6
-rw-r--r--Userland/Services/SQLServer/ConnectionFromClient.cpp6
-rw-r--r--Userland/Services/SQLServer/DatabaseConnection.cpp60
-rw-r--r--Userland/Services/SQLServer/DatabaseConnection.h11
-rw-r--r--Userland/Services/SQLServer/SQLClient.ipc3
-rw-r--r--Userland/Services/SQLServer/SQLServer.ipc4
-rw-r--r--Userland/Services/SQLServer/SQLStatement.cpp4
-rw-r--r--Userland/Utilities/sql.cpp49
10 files changed, 57 insertions, 117 deletions
diff --git a/Userland/DevTools/SQLStudio/MainWidget.cpp b/Userland/DevTools/SQLStudio/MainWidget.cpp
index 25f1f2f5b0..f47b5a8e9e 100644
--- a/Userland/DevTools/SQLStudio/MainWidget.cpp
+++ b/Userland/DevTools/SQLStudio/MainWidget.cpp
@@ -145,9 +145,16 @@ MainWidget::MainWidget()
m_run_script_action = GUI::Action::create("Run script", { Mod_Alt, Key_F9 }, Gfx::Bitmap::try_load_from_file("/res/icons/16x16/play.png"sv).release_value_but_fixme_should_propagate_errors(), [&](auto&) {
m_results.clear();
m_current_line_for_parsing = 0;
+
// TODO select the database to use in UI.
- m_connection_id = m_sql_client->connect("test");
- read_next_sql_statement_of_editor();
+ constexpr auto database_name = "Test"sv;
+
+ if (auto connection_id = m_sql_client->connect(database_name); connection_id.has_value()) {
+ m_connection_id = connection_id.release_value();
+ read_next_sql_statement_of_editor();
+ } else {
+ warnln("\033[33;1mCould not connect to:\033[0m {}", database_name);
+ }
});
auto& toolbar_container = add<GUI::ToolbarContainer>();
diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp
index 3b1bd35f44..4a5a2693ec 100644
--- a/Userland/Libraries/LibSQL/SQLClient.cpp
+++ b/Userland/Libraries/LibSQL/SQLClient.cpp
@@ -9,26 +9,6 @@
namespace SQL {
-void SQLClient::connected(u64 connection_id, DeprecatedString const& connected_to_database)
-{
- if (on_connected)
- on_connected(connection_id, connected_to_database);
-}
-
-void SQLClient::disconnected(u64 connection_id)
-{
- if (on_disconnected)
- on_disconnected(connection_id);
-}
-
-void SQLClient::connection_error(u64 connection_id, SQLErrorCode const& code, DeprecatedString const& message)
-{
- if (on_connection_error)
- on_connection_error(connection_id, code, message);
- else
- warnln("Connection error for connection_id {}: {} ({})", connection_id, message, to_underlying(code));
-}
-
void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message)
{
if (on_execution_error)
diff --git a/Userland/Libraries/LibSQL/SQLClient.h b/Userland/Libraries/LibSQL/SQLClient.h
index 52ecb33375..ae45d852de 100644
--- a/Userland/Libraries/LibSQL/SQLClient.h
+++ b/Userland/Libraries/LibSQL/SQLClient.h
@@ -20,9 +20,6 @@ class SQLClient
IPC_CLIENT_CONNECTION(SQLClient, "/tmp/session/%sid/portal/sql"sv)
virtual ~SQLClient() = default;
- Function<void(u64, DeprecatedString const&)> on_connected;
- Function<void(u64)> on_disconnected;
- Function<void(u64, SQLErrorCode, DeprecatedString const&)> on_connection_error;
Function<void(u64, u64, SQLErrorCode, DeprecatedString const&)> on_execution_error;
Function<void(u64, u64, bool, size_t, size_t, size_t)> on_execution_success;
Function<void(u64, u64, Vector<DeprecatedString> const&)> on_next_result;
@@ -34,13 +31,10 @@ private:
{
}
- virtual void connected(u64 connection_id, DeprecatedString const& connected_to_database) override;
- virtual void connection_error(u64 connection_id, SQLErrorCode const& code, DeprecatedString const& message) override;
virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override;
virtual void next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const&) override;
virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override;
virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override;
- virtual void disconnected(u64 connection_id) override;
};
}
diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp
index 95d54622da..3c358dd9ae 100644
--- a/Userland/Services/SQLServer/ConnectionFromClient.cpp
+++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp
@@ -37,8 +37,10 @@ void ConnectionFromClient::die()
Messages::SQLServer::ConnectResponse ConnectionFromClient::connect(DeprecatedString const& database_name)
{
dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::connect(database_name: {})", database_name);
- auto database_connection = DatabaseConnection::construct(database_name, client_id());
- return { database_connection->connection_id() };
+
+ if (auto database_connection = DatabaseConnection::create(database_name, client_id()); !database_connection.is_error())
+ return { database_connection.value()->connection_id() };
+ return { {} };
}
void ConnectionFromClient::disconnect(u64 connection_id)
diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp
index ccea8d4061..46f3232000 100644
--- a/Userland/Services/SQLServer/DatabaseConnection.cpp
+++ b/Userland/Services/SQLServer/DatabaseConnection.cpp
@@ -5,7 +5,6 @@
*/
#include <AK/LexicalPath.h>
-#include <SQLServer/ConnectionFromClient.h>
#include <SQLServer/DatabaseConnection.h>
#include <SQLServer/SQLStatement.h>
@@ -22,64 +21,41 @@ RefPtr<DatabaseConnection> DatabaseConnection::connection_for(u64 connection_id)
return nullptr;
}
-DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int client_id)
+ErrorOr<NonnullRefPtr<DatabaseConnection>> DatabaseConnection::create(DeprecatedString database_name, int client_id)
+{
+ if (LexicalPath path(database_name); (path.title() != database_name) || (path.dirname() != "."))
+ return Error::from_string_view("Invalid database name"sv);
+
+ auto database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", database_name));
+ if (auto result = database->open(); result.is_error()) {
+ warnln("Could not open database: {}", result.error().error_string());
+ return Error::from_string_view("Could not open database"sv);
+ }
+
+ return adopt_nonnull_ref_or_enomem(new (nothrow) DatabaseConnection(move(database), move(database_name), client_id));
+}
+
+DatabaseConnection::DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id)
: Object()
+ , m_database(move(database))
, m_database_name(move(database_name))
, m_connection_id(s_next_connection_id++)
, m_client_id(client_id)
{
- if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) {
- auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
- client_connection->async_connection_error(m_connection_id, SQL::SQLErrorCode::InvalidDatabaseName, m_database_name);
- return;
- }
-
- dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name);
+ dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiatedconnection with database '{}'", connection_id(), m_database_name);
s_connections.set(m_connection_id, *this);
-
- deferred_invoke([this]() {
- m_database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", m_database_name));
- auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
- if (auto maybe_error = m_database->open(); maybe_error.is_error()) {
- client_connection->async_connection_error(m_connection_id, maybe_error.error().error(), maybe_error.error().error_string());
- return;
- }
- m_accept_statements = true;
- if (client_connection)
- client_connection->async_connected(m_connection_id, m_database_name);
- else
- warnln("Cannot notify client of database connection. Client disconnected");
- });
}
void DatabaseConnection::disconnect()
{
dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::disconnect(connection_id {}, database '{}'", connection_id(), m_database_name);
- m_accept_statements = false;
- deferred_invoke([this]() {
- m_database = nullptr;
- s_connections.remove(m_connection_id);
- auto client_connection = ConnectionFromClient::client_connection_for(client_id());
- if (client_connection)
- client_connection->async_disconnected(m_connection_id);
- else
- warnln("Cannot notify client of database disconnection. Client disconnected");
- });
+ s_connections.remove(connection_id());
}
SQL::ResultOr<u64> DatabaseConnection::prepare_statement(StringView sql)
{
dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
- if (!m_accept_statements)
- return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable };
-
- auto client_connection = ConnectionFromClient::client_connection_for(client_id());
- if (!client_connection) {
- warnln("Cannot notify client of database disconnection. Client disconnected");
- return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv };
- }
-
auto statement = TRY(SQLStatement::create(*this, sql));
return statement->statement_id();
}
diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h
index 3ded1919bc..eea9acbfe2 100644
--- a/Userland/Services/SQLServer/DatabaseConnection.h
+++ b/Userland/Services/SQLServer/DatabaseConnection.h
@@ -6,6 +6,7 @@
#pragma once
+#include <AK/NonnullRefPtr.h>
#include <LibCore/Object.h>
#include <LibSQL/Database.h>
#include <LibSQL/Result.h>
@@ -14,26 +15,26 @@
namespace SQLServer {
class DatabaseConnection final : public Core::Object {
- C_OBJECT(DatabaseConnection)
+ C_OBJECT_ABSTRACT(DatabaseConnection)
public:
+ static ErrorOr<NonnullRefPtr<DatabaseConnection>> create(DeprecatedString database_name, int client_id);
~DatabaseConnection() override = default;
static RefPtr<DatabaseConnection> connection_for(u64 connection_id);
u64 connection_id() const { return m_connection_id; }
int client_id() const { return m_client_id; }
- RefPtr<SQL::Database> database() { return m_database; }
+ NonnullRefPtr<SQL::Database> database() { return m_database; }
void disconnect();
SQL::ResultOr<u64> prepare_statement(StringView sql);
private:
- DatabaseConnection(DeprecatedString database_name, int client_id);
+ DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id);
- RefPtr<SQL::Database> m_database { nullptr };
+ NonnullRefPtr<SQL::Database> m_database;
DeprecatedString m_database_name;
u64 m_connection_id { 0 };
int m_client_id { 0 };
- bool m_accept_statements { false };
};
}
diff --git a/Userland/Services/SQLServer/SQLClient.ipc b/Userland/Services/SQLServer/SQLClient.ipc
index e5b87110f4..15d3533333 100644
--- a/Userland/Services/SQLServer/SQLClient.ipc
+++ b/Userland/Services/SQLServer/SQLClient.ipc
@@ -2,11 +2,8 @@
endpoint SQLClient
{
- connected(u64 connection_id, DeprecatedString connected_to_database) =|
- connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =|
execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) =|
next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> row) =|
results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) =|
execution_error(u64 statement_id, u64 execution_id, SQL::SQLErrorCode code, DeprecatedString message) =|
- disconnected(u64 connection_id) =|
}
diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc
index ba06d5eb39..2a01ae22f6 100644
--- a/Userland/Services/SQLServer/SQLServer.ipc
+++ b/Userland/Services/SQLServer/SQLServer.ipc
@@ -2,8 +2,8 @@
endpoint SQLServer
{
- connect(DeprecatedString name) => (u64 connection_id)
+ connect(DeprecatedString name) => (Optional<u64> connection_id)
prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional<u64> statement_id)
execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) => (Optional<u64> execution_id)
- disconnect(u64 connection_id) =|
+ disconnect(u64 connection_id) => ()
}
diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp
index cce1528f05..a1049bd835 100644
--- a/Userland/Services/SQLServer/SQLStatement.cpp
+++ b/Userland/Services/SQLServer/SQLStatement.cpp
@@ -74,9 +74,7 @@ Optional<u64> SQLStatement::execute(Vector<SQL::Value> placeholder_values)
m_ongoing_executions.set(execution_id);
deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] {
- VERIFY(!connection()->database().is_null());
-
- auto execution_result = m_statement->execute(connection()->database().release_nonnull(), placeholder_values);
+ auto execution_result = m_statement->execute(connection()->database(), placeholder_values);
m_ongoing_executions.remove(execution_id);
if (execution_result.is_error()) {
diff --git a/Userland/Utilities/sql.cpp b/Userland/Utilities/sql.cpp
index a0c046c9ac..6020c3f01a 100644
--- a/Userland/Utilities/sql.cpp
+++ b/Userland/Utilities/sql.cpp
@@ -76,14 +76,6 @@ public:
m_sql_client = SQL::SQLClient::try_create().release_value_but_fixme_should_propagate_errors();
- m_sql_client->on_connected = [this](auto connection_id, auto const& connected_to_database) {
- outln("Connected to \033[33;1m{}\033[0m", connected_to_database);
- m_current_database = connected_to_database;
- m_pending_database = "";
- m_connection_id = connection_id;
- read_sql();
- };
-
m_sql_client->on_execution_success = [this](auto, auto, auto has_results, auto updated, auto created, auto deleted) {
if (updated != 0 || created != 0 || deleted != 0) {
outln("{} row(s) updated, {} created, {} deleted", updated, created, deleted);
@@ -104,27 +96,11 @@ public:
read_sql();
};
- m_sql_client->on_connection_error = [this](auto, auto code, auto const& message) {
- outln("\033[33;1mConnection error:\033[0m {}", message);
- m_loop.quit(to_underlying(code));
- };
-
m_sql_client->on_execution_error = [this](auto, auto, auto, auto const& message) {
outln("\033[33;1mExecution error:\033[0m {}", message);
read_sql();
};
- m_sql_client->on_disconnected = [this](auto) {
- if (m_pending_database.is_empty()) {
- outln("Disconnected from \033[33;1m{}\033[0m and terminating", m_current_database);
- m_loop.quit(0);
- } else {
- outln("Disconnected from \033[33;1m{}\033[0m", m_current_database);
- m_current_database = "";
- m_sql_client->async_connect(m_pending_database);
- }
- };
-
if (!database_name.is_empty())
connect(database_name);
}
@@ -136,11 +112,18 @@ public:
void connect(DeprecatedString const& database_name)
{
- if (m_current_database.is_empty()) {
- m_sql_client->async_connect(database_name);
+ if (!m_database_name.is_empty()) {
+ m_sql_client->disconnect(m_connection_id);
+ m_database_name = {};
+ }
+
+ if (auto connection_id = m_sql_client->connect(database_name); connection_id.has_value()) {
+ outln("Connected to \033[33;1m{}\033[0m", database_name);
+ m_database_name = database_name;
+ m_connection_id = *connection_id;
} else {
- m_pending_database = database_name;
- m_sql_client->async_disconnect(m_connection_id);
+ warnln("\033[33;1mCould not connect to:\033[0m {}", database_name);
+ m_loop.quit(1);
}
}
@@ -158,6 +141,7 @@ public:
auto run()
{
+ read_sql();
return m_loop.exec();
}
@@ -166,8 +150,7 @@ private:
RefPtr<Line::Editor> m_editor { nullptr };
int m_repl_line_level { 0 };
bool m_keep_running { true };
- DeprecatedString m_pending_database {};
- DeprecatedString m_current_database {};
+ DeprecatedString m_database_name {};
AK::RefPtr<SQL::SQLClient> m_sql_client { nullptr };
u64 m_connection_id { 0 };
Core::EventLoop m_loop;
@@ -280,7 +263,8 @@ private:
// m_keep_running can be set to false when the file we are reading
// from is exhausted...
if (!m_keep_running) {
- m_sql_client->async_disconnect(m_connection_id);
+ m_sql_client->disconnect(m_connection_id);
+ m_loop.quit(0);
return;
}
@@ -296,7 +280,8 @@ private:
// ...But m_keep_running can also be set to false by a command handler.
if (!m_keep_running) {
- m_sql_client->async_disconnect(m_connection_id);
+ m_sql_client->disconnect(m_connection_id);
+ m_loop.quit(0);
return;
}
};