summaryrefslogtreecommitdiff
path: root/Userland/Libraries/LibSQL/AST
diff options
context:
space:
mode:
authorTimothy Flynn <trflynn89@pm.me>2022-12-05 07:55:21 -0500
committerAndreas Kling <kling@serenityos.org>2022-12-07 13:09:00 +0100
commit53f8d62ea4442f31add38a06bbdd2bfa3037a0a5 (patch)
tree62cac31ca1b945514b962a7a97258ae543c2b8f2 /Userland/Libraries/LibSQL/AST
parent1574f2c3f60808d53ba5d2f785694ec9541239d2 (diff)
downloadserenity-53f8d62ea4442f31add38a06bbdd2bfa3037a0a5.zip
LibSQL: Partially implement the UPDATE command
This implements enough to update rows filtered by a WHERE clause.
Diffstat (limited to 'Userland/Libraries/LibSQL/AST')
-rw-r--r--Userland/Libraries/LibSQL/AST/AST.h2
-rw-r--r--Userland/Libraries/LibSQL/AST/Insert.cpp12
-rw-r--r--Userland/Libraries/LibSQL/AST/Update.cpp63
3 files changed, 66 insertions, 11 deletions
diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h
index e1d89956ab..4d64c9d48a 100644
--- a/Userland/Libraries/LibSQL/AST/AST.h
+++ b/Userland/Libraries/LibSQL/AST/AST.h
@@ -992,6 +992,8 @@ public:
RefPtr<Expression> const& where_clause() const { return m_where_clause; }
RefPtr<ReturningClause> const& returning_clause() const { return m_returning_clause; }
+ virtual ResultOr<ResultSet> execute(ExecutionContext&) const override;
+
private:
RefPtr<CommonTableExpressionList> m_common_table_expression_list;
ConflictResolution m_conflict_resolution;
diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp
index bda6c6fa7d..76be3de763 100644
--- a/Userland/Libraries/LibSQL/AST/Insert.cpp
+++ b/Userland/Libraries/LibSQL/AST/Insert.cpp
@@ -12,15 +12,6 @@
namespace SQL::AST {
-static bool does_value_data_type_match(SQLType expected, SQLType actual)
-{
- if (actual == SQLType::Null)
- return false;
- if (expected == SQLType::Integer)
- return actual == SQLType::Integer || actual == SQLType::Float;
- return expected == actual;
-}
-
ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
{
auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name));
@@ -49,13 +40,12 @@ ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, DeprecatedString::empty() };
for (auto ix = 0u; ix < values.size(); ix++) {
- auto input_value_type = values[ix].type();
auto& tuple_descriptor = *row.descriptor();
// In case of having column names, this must succeed since we checked for every column name for existence in the table.
auto element_index = m_column_names.is_empty() ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index();
auto element_type = tuple_descriptor[element_index].type;
- if (!does_value_data_type_match(element_type, input_value_type))
+ if (!values[ix].is_type_compatible_with(element_type))
return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() };
row[element_index] = move(values[ix]);
diff --git a/Userland/Libraries/LibSQL/AST/Update.cpp b/Userland/Libraries/LibSQL/AST/Update.cpp
new file mode 100644
index 0000000000..de65951f41
--- /dev/null
+++ b/Userland/Libraries/LibSQL/AST/Update.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2022, Tim Flynn <trflynn89@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include <LibSQL/AST/AST.h>
+#include <LibSQL/Database.h>
+#include <LibSQL/Meta.h>
+#include <LibSQL/Row.h>
+
+namespace SQL::AST {
+
+ResultOr<ResultSet> Update::execute(ExecutionContext& context) const
+{
+ auto const& schema_name = m_qualified_table_name->schema_name();
+ auto const& table_name = m_qualified_table_name->table_name();
+ auto table_def = TRY(context.database->get_table(schema_name, table_name));
+
+ Vector<Row> matched_rows;
+
+ for (auto& table_row : TRY(context.database->select_all(*table_def))) {
+ context.current_row = &table_row;
+
+ if (auto const& where_clause = this->where_clause()) {
+ auto where_result = TRY(where_clause->evaluate(context)).to_bool();
+ if (!where_result.has_value() || !where_result.value())
+ continue;
+ }
+
+ TRY(matched_rows.try_append(move(table_row)));
+ }
+
+ ResultSet result { SQLCommand::Update };
+
+ for (auto& update_column : m_update_columns) {
+ auto row_value = TRY(update_column.expression->evaluate(context));
+
+ for (auto& table_row : matched_rows) {
+ auto& row_descriptor = *table_row.descriptor();
+
+ for (auto const& column_name : update_column.column_names) {
+ if (!table_row.has(column_name))
+ return Result { SQLCommand::Update, SQLErrorCode::ColumnDoesNotExist, column_name };
+
+ auto column_index = row_descriptor.find_if([&](auto element) { return element.name == column_name; }).index();
+ auto column_type = row_descriptor[column_index].type;
+
+ if (!row_value.is_type_compatible_with(column_type))
+ return Result { SQLCommand::Update, SQLErrorCode::InvalidValueType, column_name };
+
+ table_row[column_index] = row_value;
+ }
+
+ TRY(context.database->update(table_row));
+ result.insert_row(table_row, {});
+ }
+ }
+
+ return result;
+}
+
+}