diff options
Diffstat (limited to 'Userland/Libraries/LibSQL/AST')
-rw-r--r-- | Userland/Libraries/LibSQL/AST/AST.h | 2 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/Insert.cpp | 12 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/Update.cpp | 63 |
3 files changed, 66 insertions, 11 deletions
diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h index e1d89956ab..4d64c9d48a 100644 --- a/Userland/Libraries/LibSQL/AST/AST.h +++ b/Userland/Libraries/LibSQL/AST/AST.h @@ -992,6 +992,8 @@ public: RefPtr<Expression> const& where_clause() const { return m_where_clause; } RefPtr<ReturningClause> const& returning_clause() const { return m_returning_clause; } + virtual ResultOr<ResultSet> execute(ExecutionContext&) const override; + private: RefPtr<CommonTableExpressionList> m_common_table_expression_list; ConflictResolution m_conflict_resolution; diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp index bda6c6fa7d..76be3de763 100644 --- a/Userland/Libraries/LibSQL/AST/Insert.cpp +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -12,15 +12,6 @@ namespace SQL::AST { -static bool does_value_data_type_match(SQLType expected, SQLType actual) -{ - if (actual == SQLType::Null) - return false; - if (expected == SQLType::Integer) - return actual == SQLType::Integer || actual == SQLType::Float; - return expected == actual; -} - ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const { auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name)); @@ -49,13 +40,12 @@ ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, DeprecatedString::empty() }; for (auto ix = 0u; ix < values.size(); ix++) { - auto input_value_type = values[ix].type(); auto& tuple_descriptor = *row.descriptor(); // In case of having column names, this must succeed since we checked for every column name for existence in the table. auto element_index = m_column_names.is_empty() ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index(); auto element_type = tuple_descriptor[element_index].type; - if (!does_value_data_type_match(element_type, input_value_type)) + if (!values[ix].is_type_compatible_with(element_type)) return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() }; row[element_index] = move(values[ix]); diff --git a/Userland/Libraries/LibSQL/AST/Update.cpp b/Userland/Libraries/LibSQL/AST/Update.cpp new file mode 100644 index 0000000000..de65951f41 --- /dev/null +++ b/Userland/Libraries/LibSQL/AST/Update.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022, Tim Flynn <trflynn89@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include <LibSQL/AST/AST.h> +#include <LibSQL/Database.h> +#include <LibSQL/Meta.h> +#include <LibSQL/Row.h> + +namespace SQL::AST { + +ResultOr<ResultSet> Update::execute(ExecutionContext& context) const +{ + auto const& schema_name = m_qualified_table_name->schema_name(); + auto const& table_name = m_qualified_table_name->table_name(); + auto table_def = TRY(context.database->get_table(schema_name, table_name)); + + Vector<Row> matched_rows; + + for (auto& table_row : TRY(context.database->select_all(*table_def))) { + context.current_row = &table_row; + + if (auto const& where_clause = this->where_clause()) { + auto where_result = TRY(where_clause->evaluate(context)).to_bool(); + if (!where_result.has_value() || !where_result.value()) + continue; + } + + TRY(matched_rows.try_append(move(table_row))); + } + + ResultSet result { SQLCommand::Update }; + + for (auto& update_column : m_update_columns) { + auto row_value = TRY(update_column.expression->evaluate(context)); + + for (auto& table_row : matched_rows) { + auto& row_descriptor = *table_row.descriptor(); + + for (auto const& column_name : update_column.column_names) { + if (!table_row.has(column_name)) + return Result { SQLCommand::Update, SQLErrorCode::ColumnDoesNotExist, column_name }; + + auto column_index = row_descriptor.find_if([&](auto element) { return element.name == column_name; }).index(); + auto column_type = row_descriptor[column_index].type; + + if (!row_value.is_type_compatible_with(column_type)) + return Result { SQLCommand::Update, SQLErrorCode::InvalidValueType, column_name }; + + table_row[column_index] = row_value; + } + + TRY(context.database->update(table_row)); + result.insert_row(table_row, {}); + } + } + + return result; +} + +} |