summaryrefslogtreecommitdiff
path: root/Userland/Services/SQLServer
diff options
context:
space:
mode:
authorTimothy Flynn <trflynn89@pm.me>2022-12-02 16:25:27 -0500
committerAndreas Kling <kling@serenityos.org>2022-12-07 13:09:00 +0100
commite2f71d280817408e93fc79c652e2de29cdc82660 (patch)
tree0958b10c247e23fe93d34ee531ae6c74ccd122b0 /Userland/Services/SQLServer
parent3a915483b022f78be5c2547786ad55832d9fb028 (diff)
downloadserenity-e2f71d280817408e93fc79c652e2de29cdc82660.zip
LibSQL+SQLServer+SQLStudio+sql: Use proper types for SQL IPC and IDs
When storing IDs and sending values over IPC, this changes SQLServer to: 1. Stop using -1 as a nominal "bad" ID. Store the IDs as unsigned, and use Optional in the one place that the IPC needs to indicate an ID was not allocated. 2. Let LibIPC encode/decode enumerations (SQLErrorCode) on our behalf. 3. Use size_t for array sizes.
Diffstat (limited to 'Userland/Services/SQLServer')
-rw-r--r--Userland/Services/SQLServer/ConnectionFromClient.cpp12
-rw-r--r--Userland/Services/SQLServer/ConnectionFromClient.h6
-rw-r--r--Userland/Services/SQLServer/DatabaseConnection.cpp14
-rw-r--r--Userland/Services/SQLServer/DatabaseConnection.h10
-rw-r--r--Userland/Services/SQLServer/SQLClient.ipc16
-rw-r--r--Userland/Services/SQLServer/SQLServer.ipc8
-rw-r--r--Userland/Services/SQLServer/SQLStatement.cpp11
-rw-r--r--Userland/Services/SQLServer/SQLStatement.h6
8 files changed, 42 insertions, 41 deletions
diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp
index c22045d610..cb4fd185c4 100644
--- a/Userland/Services/SQLServer/ConnectionFromClient.cpp
+++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp
@@ -41,7 +41,7 @@ Messages::SQLServer::ConnectResponse ConnectionFromClient::connect(DeprecatedStr
return { database_connection->connection_id() };
}
-void ConnectionFromClient::disconnect(int connection_id)
+void ConnectionFromClient::disconnect(u64 connection_id)
{
dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::disconnect(connection_id: {})", connection_id);
auto database_connection = DatabaseConnection::connection_for(connection_id);
@@ -51,27 +51,27 @@ void ConnectionFromClient::disconnect(int connection_id)
dbgln("Database connection has disappeared");
}
-Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_statement(int connection_id, DeprecatedString const& sql)
+Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_statement(u64 connection_id, DeprecatedString const& sql)
{
dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement(connection_id: {}, sql: '{}')", connection_id, sql);
auto database_connection = DatabaseConnection::connection_for(connection_id);
if (!database_connection) {
dbgln("Database connection has disappeared");
- return { -1 };
+ return { {} };
}
auto result = database_connection->prepare_statement(sql);
if (result.is_error()) {
dbgln_if(SQLSERVER_DEBUG, "Could not parse SQL statement: {}", result.error().error_string());
- return { -1 };
+ return { {} };
}
dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement -> statement_id = {}", result.value());
return { result.value() };
}
-void ConnectionFromClient::execute_statement(int statement_id, Vector<SQL::Value> const& placeholder_values)
+void ConnectionFromClient::execute_statement(u64 statement_id, Vector<SQL::Value> const& placeholder_values)
{
dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::execute_query_statement(statement_id: {})", statement_id);
auto statement = SQLStatement::statement_for(statement_id);
@@ -80,7 +80,7 @@ void ConnectionFromClient::execute_statement(int statement_id, Vector<SQL::Value
statement->execute(move(const_cast<Vector<SQL::Value>&>(placeholder_values)));
} else {
dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared");
- async_execution_error(statement_id, (int)SQL::SQLErrorCode::StatementUnavailable, DeprecatedString::formatted("{}", statement_id));
+ async_execution_error(statement_id, SQL::SQLErrorCode::StatementUnavailable, DeprecatedString::formatted("{}", statement_id));
}
}
diff --git a/Userland/Services/SQLServer/ConnectionFromClient.h b/Userland/Services/SQLServer/ConnectionFromClient.h
index 1c6f612d83..05e2186e03 100644
--- a/Userland/Services/SQLServer/ConnectionFromClient.h
+++ b/Userland/Services/SQLServer/ConnectionFromClient.h
@@ -29,9 +29,9 @@ private:
explicit ConnectionFromClient(NonnullOwnPtr<Core::Stream::LocalSocket>, int client_id);
virtual Messages::SQLServer::ConnectResponse connect(DeprecatedString const&) override;
- virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(int, DeprecatedString const&) override;
- virtual void execute_statement(int, Vector<SQL::Value> const& placeholder_values) override;
- virtual void disconnect(int) override;
+ virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(u64, DeprecatedString const&) override;
+ virtual void execute_statement(u64, Vector<SQL::Value> const& placeholder_values) override;
+ virtual void disconnect(u64) override;
};
}
diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp
index 0f3bebf224..ccea8d4061 100644
--- a/Userland/Services/SQLServer/DatabaseConnection.cpp
+++ b/Userland/Services/SQLServer/DatabaseConnection.cpp
@@ -11,9 +11,10 @@
namespace SQLServer {
-static HashMap<int, NonnullRefPtr<DatabaseConnection>> s_connections;
+static HashMap<u64, NonnullRefPtr<DatabaseConnection>> s_connections;
+static u64 s_next_connection_id = 0;
-RefPtr<DatabaseConnection> DatabaseConnection::connection_for(int connection_id)
+RefPtr<DatabaseConnection> DatabaseConnection::connection_for(u64 connection_id)
{
if (s_connections.contains(connection_id))
return *s_connections.get(connection_id).value();
@@ -21,8 +22,6 @@ RefPtr<DatabaseConnection> DatabaseConnection::connection_for(int connection_id)
return nullptr;
}
-static int s_next_connection_id = 0;
-
DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int client_id)
: Object()
, m_database_name(move(database_name))
@@ -31,17 +30,18 @@ DatabaseConnection::DatabaseConnection(DeprecatedString database_name, int clien
{
if (LexicalPath path(m_database_name); (path.title() != m_database_name) || (path.dirname() != ".")) {
auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
- client_connection->async_connection_error(m_connection_id, (int)SQL::SQLErrorCode::InvalidDatabaseName, m_database_name);
+ client_connection->async_connection_error(m_connection_id, SQL::SQLErrorCode::InvalidDatabaseName, m_database_name);
return;
}
dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection {} initiating connection with database '{}'", connection_id(), m_database_name);
s_connections.set(m_connection_id, *this);
+
deferred_invoke([this]() {
m_database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", m_database_name));
auto client_connection = ConnectionFromClient::client_connection_for(m_client_id);
if (auto maybe_error = m_database->open(); maybe_error.is_error()) {
- client_connection->async_connection_error(m_connection_id, to_underlying(maybe_error.error().error()), maybe_error.error().error_string());
+ client_connection->async_connection_error(m_connection_id, maybe_error.error().error(), maybe_error.error().error_string());
return;
}
m_accept_statements = true;
@@ -67,7 +67,7 @@ void DatabaseConnection::disconnect()
});
}
-SQL::ResultOr<int> DatabaseConnection::prepare_statement(StringView sql)
+SQL::ResultOr<u64> DatabaseConnection::prepare_statement(StringView sql)
{
dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql);
diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h
index 77632f0938..3ded1919bc 100644
--- a/Userland/Services/SQLServer/DatabaseConnection.h
+++ b/Userland/Services/SQLServer/DatabaseConnection.h
@@ -19,20 +19,20 @@ class DatabaseConnection final : public Core::Object {
public:
~DatabaseConnection() override = default;
- static RefPtr<DatabaseConnection> connection_for(int connection_id);
- int connection_id() const { return m_connection_id; }
+ static RefPtr<DatabaseConnection> connection_for(u64 connection_id);
+ u64 connection_id() const { return m_connection_id; }
int client_id() const { return m_client_id; }
RefPtr<SQL::Database> database() { return m_database; }
void disconnect();
- SQL::ResultOr<int> prepare_statement(StringView sql);
+ SQL::ResultOr<u64> prepare_statement(StringView sql);
private:
DatabaseConnection(DeprecatedString database_name, int client_id);
RefPtr<SQL::Database> m_database { nullptr };
DeprecatedString m_database_name;
- int m_connection_id;
- int m_client_id;
+ u64 m_connection_id { 0 };
+ int m_client_id { 0 };
bool m_accept_statements { false };
};
diff --git a/Userland/Services/SQLServer/SQLClient.ipc b/Userland/Services/SQLServer/SQLClient.ipc
index c33e4fff3d..7228851b6e 100644
--- a/Userland/Services/SQLServer/SQLClient.ipc
+++ b/Userland/Services/SQLServer/SQLClient.ipc
@@ -1,10 +1,12 @@
+#include <LibSQL/Result.h>
+
endpoint SQLClient
{
- connected(int connection_id, DeprecatedString connected_to_database) =|
- connection_error(int connection_id, int code, DeprecatedString message) =|
- execution_success(int statement_id, bool has_results, int created, int updated, int deleted) =|
- next_result(int statement_id, Vector<DeprecatedString> row) =|
- results_exhausted(int statement_id, int total_rows) =|
- execution_error(int statement_id, int code, DeprecatedString message) =|
- disconnected(int connection_id) =|
+ connected(u64 connection_id, DeprecatedString connected_to_database) =|
+ connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =|
+ execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted) =|
+ next_result(u64 statement_id, Vector<DeprecatedString> row) =|
+ results_exhausted(u64 statement_id, size_t total_rows) =|
+ execution_error(u64 statement_id, SQL::SQLErrorCode code, DeprecatedString message) =|
+ disconnected(u64 connection_id) =|
}
diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc
index 9c67f134d8..89fb7b8392 100644
--- a/Userland/Services/SQLServer/SQLServer.ipc
+++ b/Userland/Services/SQLServer/SQLServer.ipc
@@ -2,8 +2,8 @@
endpoint SQLServer
{
- connect(DeprecatedString name) => (int connection_id)
- prepare_statement(int connection_id, DeprecatedString statement) => (int statement_id)
- execute_statement(int statement_id, Vector<SQL::Value> placeholder_values) =|
- disconnect(int connection_id) =|
+ connect(DeprecatedString name) => (u64 connection_id)
+ prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional<u64> statement_id)
+ execute_statement(u64 statement_id, Vector<SQL::Value> placeholder_values) =|
+ disconnect(u64 connection_id) =|
}
diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp
index 3bd03dc89b..d4b9f583fc 100644
--- a/Userland/Services/SQLServer/SQLStatement.cpp
+++ b/Userland/Services/SQLServer/SQLStatement.cpp
@@ -12,9 +12,10 @@
namespace SQLServer {
-static HashMap<int, NonnullRefPtr<SQLStatement>> s_statements;
+static HashMap<u64, NonnullRefPtr<SQLStatement>> s_statements;
+static u64 s_next_statement_id = 0;
-RefPtr<SQLStatement> SQLStatement::statement_for(int statement_id)
+RefPtr<SQLStatement> SQLStatement::statement_for(u64 statement_id)
{
if (s_statements.contains(statement_id))
return *s_statements.get(statement_id).value();
@@ -22,8 +23,6 @@ RefPtr<SQLStatement> SQLStatement::statement_for(int statement_id)
return nullptr;
}
-static int s_next_statement_id = 0;
-
SQL::ResultOr<NonnullRefPtr<SQLStatement>> SQLStatement::create(DatabaseConnection& connection, StringView sql)
{
auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql));
@@ -54,7 +53,7 @@ void SQLStatement::report_error(SQL::Result result)
remove_from_parent();
if (client_connection)
- client_connection->async_execution_error(statement_id(), (int)result.error(), result.error_string());
+ client_connection->async_execution_error(statement_id(), result.error(), result.error_string());
else
warnln("Cannot return execution error. Client disconnected");
@@ -129,7 +128,7 @@ void SQLStatement::next()
next();
});
} else {
- client_connection->async_results_exhausted(statement_id(), (int)m_index);
+ client_connection->async_results_exhausted(statement_id(), m_index);
}
}
diff --git a/Userland/Services/SQLServer/SQLStatement.h b/Userland/Services/SQLServer/SQLStatement.h
index c4658599f3..19e97d0718 100644
--- a/Userland/Services/SQLServer/SQLStatement.h
+++ b/Userland/Services/SQLServer/SQLStatement.h
@@ -25,8 +25,8 @@ public:
static SQL::ResultOr<NonnullRefPtr<SQLStatement>> create(DatabaseConnection&, StringView sql);
~SQLStatement() override = default;
- static RefPtr<SQLStatement> statement_for(int statement_id);
- int statement_id() const { return m_statement_id; }
+ static RefPtr<SQLStatement> statement_for(u64 statement_id);
+ u64 statement_id() const { return m_statement_id; }
DatabaseConnection* connection() { return dynamic_cast<DatabaseConnection*>(parent()); }
void execute(Vector<SQL::Value> placeholder_values);
@@ -37,7 +37,7 @@ private:
void next();
void report_error(SQL::Result);
- int m_statement_id;
+ u64 m_statement_id { 0 };
size_t m_index { 0 };
NonnullRefPtr<SQL::AST::Statement> m_statement;
Optional<SQL::ResultSet> m_result {};