summaryrefslogtreecommitdiff
path: root/Userland/Services
diff options
context:
space:
mode:
Diffstat (limited to 'Userland/Services')
-rw-r--r--Userland/Services/SQLServer/ConnectionFromClient.cpp6
-rw-r--r--Userland/Services/SQLServer/DatabaseConnection.cpp60
-rw-r--r--Userland/Services/SQLServer/DatabaseConnection.h11
-rw-r--r--Userland/Services/SQLServer/SQLClient.ipc3
-rw-r--r--Userland/Services/SQLServer/SQLServer.ipc4
-rw-r--r--Userland/Services/SQLServer/SQLStatement.cpp4
6 files changed, 31 insertions, 57 deletions
diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp
index 95d54622da..3c358dd9ae 100644
--- a/Userland/Services/SQLServer/ConnectionFromClient.cpp
+++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp
@@ -37,8 +37,10 @@ void ConnectionFromClient::die()
Messages::SQLServer::ConnectResponse ConnectionFromClient::connect(DeprecatedString const& database_name)
{
dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::connect(database_name: {})", database_name);
- auto database_connection = DatabaseConnection::construct(database_name, client_id());
- return { database_connection->connection_id() };
+
+ if (auto database_connection = DatabaseConnection::create(database_name, client_id()); !database_connection.is_error())
+ return { database_connection.value()->connection_id() };
+ return { {} };
}
void ConnectionFromClient::disconnect(u64 connection_id)
diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp
index ccea8d4061..46f3232000 100644
--- a/Userland/Services/SQLServer/DatabaseConnection.cpp
+++ b/Userland/Services/SQLServer/DatabaseConnection.cpp
@@ -5,7 +5,6 @@
*/
#include <AK/LexicalPath.h>
-#include <SQLServer/ConnectionFromClient.h>
#include <SQLServer/DatabaseConnection.h>
#include <SQLServer/SQLStatement.h>
@@ -22,64 +21,41 @@ RefPtr<DatabaseConnection> DatabaseConnection::connection_for(u64 connection_id)
return nullptr;
}
-DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int client_id)
+ErrorOr<NonnullRefPtr<DatabaseConnection>> DatabaseConnection::create(DeprecatedString database_name, int client_id)
+{
+ if (LexicalPath path(database_name); (path.title() != database_name) || (path.dirname() != "."))
+ return Error::from_string_view("Invalid database name"sv);
+
+ auto database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", database_name));
+ if (auto result = database->open(); result.is_error()) {
+ warnln("Could not open database: {}", result.error().error_string());
+ return Error::from_string_view("Could not open database"sv);
+ }
+
+ return adopt_nonnull_ref_or_enomem(new (nothrow) DatabaseConnection(move(database), move(database_name), client_id));
+}
+
+DatabaseConnection::DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id)
: Object()
+ , m_database(move(database))
, m_database_name(move(database_name))
, m_connection_id(s_next_connection_id++)
, m_client_id(client_id)
{
- if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) {
- auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
- client_connection->async_connection_error(m_connection_id, SQL::SQLErrorCode::InvalidDatabaseName, m_database_name);
- return;
- }
-
- dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name);
+ dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiatedconnection with database '{}'", connection_id(), m_database_name);
s_connections.set(m_connection_id, *this);
-
- deferred_invoke([this]() {
- m_database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", m_database_name));
- auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
- if (auto maybe_error = m_database->open(); maybe_error.is_error()) {
- client_connection->async_connection_error(m_connection_id, maybe_error.error().error(), maybe_error.error().error_string());
- return;
- }
- m_accept_statements = true;
- if (client_connection)
- client_connection->async_connected(m_connection_id, m_database_name);
- else
- warnln("Cannot notify client of database connection. Client disconnected");
- });
}
void DatabaseConnection::disconnect()
{
dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::disconnect(connection_id {}, database '{}'", connection_id(), m_database_name);
- m_accept_statements = false;
- deferred_invoke([this]() {
- m_database = nullptr;
- s_connections.remove(m_connection_id);
- auto client_connection = ConnectionFromClient::client_connection_for(client_id());
- if (client_connection)
- client_connection->async_disconnected(m_connection_id);
- else
- warnln("Cannot notify client of database disconnection. Client disconnected");
- });
+ s_connections.remove(connection_id());
}
SQL::ResultOr<u64> DatabaseConnection::prepare_statement(StringView sql)
{
dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
- if (!m_accept_statements)
- return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable };
-
- auto client_connection = ConnectionFromClient::client_connection_for(client_id());
- if (!client_connection) {
- warnln("Cannot notify client of database disconnection. Client disconnected");
- return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv };
- }
-
auto statement = TRY(SQLStatement::create(*this, sql));
return statement->statement_id();
}
diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h
index 3ded1919bc..eea9acbfe2 100644
--- a/Userland/Services/SQLServer/DatabaseConnection.h
+++ b/Userland/Services/SQLServer/DatabaseConnection.h
@@ -6,6 +6,7 @@
#pragma once
+#include <AK/NonnullRefPtr.h>
#include <LibCore/Object.h>
#include <LibSQL/Database.h>
#include <LibSQL/Result.h>
@@ -14,26 +15,26 @@
namespace SQLServer {
class DatabaseConnection final : public Core::Object {
- C_OBJECT(DatabaseConnection)
+ C_OBJECT_ABSTRACT(DatabaseConnection)
public:
+ static ErrorOr<NonnullRefPtr<DatabaseConnection>> create(DeprecatedString database_name, int client_id);
~DatabaseConnection() override = default;
static RefPtr<DatabaseConnection> connection_for(u64 connection_id);
u64 connection_id() const { return m_connection_id; }
int client_id() const { return m_client_id; }
- RefPtr<SQL::Database> database() { return m_database; }
+ NonnullRefPtr<SQL::Database> database() { return m_database; }
void disconnect();
SQL::ResultOr<u64> prepare_statement(StringView sql);
private:
- DatabaseConnection(DeprecatedString database_name, int client_id);
+ DatabaseConnection(NonnullRefPtr<SQL::Database> database, DeprecatedString database_name, int client_id);
- RefPtr<SQL::Database> m_database { nullptr };
+ NonnullRefPtr<SQL::Database> m_database;
DeprecatedString m_database_name;
u64 m_connection_id { 0 };
int m_client_id { 0 };
- bool m_accept_statements { false };
};
}
diff --git a/Userland/Services/SQLServer/SQLClient.ipc b/Userland/Services/SQLServer/SQLClient.ipc
index e5b87110f4..15d3533333 100644
--- a/Userland/Services/SQLServer/SQLClient.ipc
+++ b/Userland/Services/SQLServer/SQLClient.ipc
@@ -2,11 +2,8 @@
endpoint SQLClient
{
- connected(u64 connection_id, DeprecatedString connected_to_database) =|
- connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =|
execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) =|
next_result(u64 statement_id, u64 execution_id, Vector<DeprecatedString> row) =|
results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) =|
execution_error(u64 statement_id, u64 execution_id, SQL::SQLErrorCode code, DeprecatedString message) =|
- disconnected(u64 connection_id) =|
}
diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc
index ba06d5eb39..2a01ae22f6 100644
--- a/Userland/Services/SQLServer/SQLServer.ipc
+++ b/Userland/Services/SQLServer/SQLServer.ipc
@@ -2,8 +2,8 @@
endpoint SQLServer
{
- connect(DeprecatedString name) => (u64 connection_id)
+ connect(DeprecatedString name) => (Optional<u64> connection_id)
prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional<u64> statement_id)
execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) => (Optional<u64> execution_id)
- disconnect(u64 connection_id) =|
+ disconnect(u64 connection_id) => ()
}
diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp
index cce1528f05..a1049bd835 100644
--- a/Userland/Services/SQLServer/SQLStatement.cpp
+++ b/Userland/Services/SQLServer/SQLStatement.cpp
@@ -74,9 +74,7 @@ Optional<u64> SQLStatement::execute(Vector<SQL::Value> placeholder_values)
m_ongoing_executions.set(execution_id);
deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] {
- VERIFY(!connection()->database().is_null());
-
- auto execution_result = m_statement->execute(connection()->database().release_nonnull(), placeholder_values);
+ auto execution_result = m_statement->execute(connection()->database(), placeholder_values);
m_ongoing_executions.remove(execution_id);
if (execution_result.is_error()) {