diff options
Diffstat (limited to 'Userland/Libraries/LibSQL/AST')
-rw-r--r-- | Userland/Libraries/LibSQL/AST/AST.h | 24 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/CreateSchema.cpp | 6 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/CreateTable.cpp | 8 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/Expression.cpp | 90 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/Insert.cpp | 49 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/Parser.cpp | 12 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/AST/Select.cpp | 37 |
7 files changed, 213 insertions, 13 deletions
diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h index cd9001cf31..7a8ca428eb 100644 --- a/Userland/Libraries/LibSQL/AST/AST.h +++ b/Userland/Libraries/LibSQL/AST/AST.h @@ -295,7 +295,15 @@ private: // Expressions //================================================================================================== +struct ExecutionContext { + NonnullRefPtr<Database> database; + RefPtr<SQLResult> result { nullptr }; + Tuple current_row {}; +}; + class Expression : public ASTNode { +public: + virtual Value evaluate(ExecutionContext&) const; }; class ErrorExpression final : public Expression { @@ -309,6 +317,7 @@ public: } double value() const { return m_value; } + virtual Value evaluate(ExecutionContext&) const override; private: double m_value; @@ -322,6 +331,7 @@ public: } const String& value() const { return m_value; } + virtual Value evaluate(ExecutionContext&) const override; private: String m_value; @@ -341,11 +351,14 @@ private: }; class NullLiteral : public Expression { +public: + virtual Value evaluate(ExecutionContext&) const override; }; class NestedExpression : public Expression { public: const NonnullRefPtr<Expression>& expression() const { return m_expression; } + virtual Value evaluate(ExecutionContext&) const override; protected: explicit NestedExpression(NonnullRefPtr<Expression> expression) @@ -439,6 +452,7 @@ public: } UnaryOperator type() const { return m_type; } + virtual Value evaluate(ExecutionContext&) const override; private: UnaryOperator m_type; @@ -488,6 +502,7 @@ public: } const NonnullRefPtrVector<Expression>& expressions() const { return m_expressions; } + virtual Value evaluate(ExecutionContext&) const override; private: NonnullRefPtrVector<Expression> m_expressions; @@ -667,7 +682,7 @@ private: class Statement : public ASTNode { public: - virtual RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const { return nullptr; } + virtual RefPtr<SQLResult> execute(ExecutionContext&) const { return nullptr; } }; class ErrorStatement final : public Statement { @@ -684,7 +699,7 @@ public: const String& schema_name() const { return m_schema_name; } bool is_error_if_schema_exists() const { return m_is_error_if_schema_exists; } - RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const override; + RefPtr<SQLResult> execute(ExecutionContext&) const override; private: String m_schema_name; @@ -723,7 +738,7 @@ public: bool is_temporary() const { return m_is_temporary; } bool is_error_if_table_exists() const { return m_is_error_if_table_exists; } - RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const override; + RefPtr<SQLResult> execute(ExecutionContext&) const override; private: String m_schema_name; @@ -886,6 +901,8 @@ public: bool has_selection() const { return !m_select_statement.is_null(); } const RefPtr<Select>& select_statement() const { return m_select_statement; } + RefPtr<SQLResult> execute(ExecutionContext&) const; + private: RefPtr<CommonTableExpressionList> m_common_table_expression_list; ConflictResolution m_conflict_resolution; @@ -977,6 +994,7 @@ public: const RefPtr<GroupByClause>& group_by_clause() const { return m_group_by_clause; } const NonnullRefPtrVector<OrderingTerm>& ordering_term_list() const { return m_ordering_term_list; } const RefPtr<LimitClause>& limit_clause() const { return m_limit_clause; } + RefPtr<SQLResult> execute(ExecutionContext&) const override; private: RefPtr<CommonTableExpressionList> m_common_table_expression_list; diff --git a/Userland/Libraries/LibSQL/AST/CreateSchema.cpp b/Userland/Libraries/LibSQL/AST/CreateSchema.cpp index 423fbf44b4..e387279506 100644 --- a/Userland/Libraries/LibSQL/AST/CreateSchema.cpp +++ b/Userland/Libraries/LibSQL/AST/CreateSchema.cpp @@ -10,9 +10,9 @@ namespace SQL::AST { -RefPtr<SQLResult> CreateSchema::execute(NonnullRefPtr<Database> database) const +RefPtr<SQLResult> CreateSchema::execute(ExecutionContext& context) const { - auto schema_def = database->get_schema(m_schema_name); + auto schema_def = context.database->get_schema(m_schema_name); if (schema_def) { if (m_is_error_if_schema_exists) { return SQLResult::construct(SQLCommand::Create, SQLErrorCode::SchemaExists, m_schema_name); @@ -21,7 +21,7 @@ RefPtr<SQLResult> CreateSchema::execute(NonnullRefPtr<Database> database) const } schema_def = SchemaDef::construct(m_schema_name); - database->add_schema(*schema_def); + context.database->add_schema(*schema_def); return SQLResult::construct(SQLCommand::Create, 0, 1); } diff --git a/Userland/Libraries/LibSQL/AST/CreateTable.cpp b/Userland/Libraries/LibSQL/AST/CreateTable.cpp index 56729a7da1..090cb7036d 100644 --- a/Userland/Libraries/LibSQL/AST/CreateTable.cpp +++ b/Userland/Libraries/LibSQL/AST/CreateTable.cpp @@ -9,13 +9,13 @@ namespace SQL::AST { -RefPtr<SQLResult> CreateTable::execute(NonnullRefPtr<Database> database) const +RefPtr<SQLResult> CreateTable::execute(ExecutionContext& context) const { auto schema_name = (!m_schema_name.is_null() && !m_schema_name.is_empty()) ? m_schema_name : "default"; - auto schema_def = database->get_schema(schema_name); + auto schema_def = context.database->get_schema(schema_name); if (!schema_def) return SQLResult::construct(SQLCommand::Create, SQLErrorCode::SchemaDoesNotExist, m_schema_name); - auto table_def = database->get_table(schema_name, m_table_name); + auto table_def = context.database->get_table(schema_name, m_table_name); if (table_def) { if (m_is_error_if_table_exists) { return SQLResult::construct(SQLCommand::Create, SQLErrorCode::TableExists, m_table_name); @@ -37,7 +37,7 @@ RefPtr<SQLResult> CreateTable::execute(NonnullRefPtr<Database> database) const } table_def->append_column(column.name(), type); } - database->add_table(*table_def); + context.database->add_table(*table_def); return SQLResult::construct(SQLCommand::Create, 0, 1); } diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp new file mode 100644 index 0000000000..a66257fb41 --- /dev/null +++ b/Userland/Libraries/LibSQL/AST/Expression.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021, Jan de Visser <jan@de-visser.net> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include <LibSQL/AST/AST.h> +#include <LibSQL/Database.h> + +namespace SQL::AST { + +Value Expression::evaluate(ExecutionContext&) const +{ + return Value::null(); +} + +Value NumericLiteral::evaluate(ExecutionContext&) const +{ + Value ret(SQLType::Float); + ret = value(); + return ret; +} + +Value StringLiteral::evaluate(ExecutionContext&) const +{ + Value ret(SQLType::Text); + ret = value(); + return ret; +} + +Value NullLiteral::evaluate(ExecutionContext&) const +{ + return Value::null(); +} + +Value NestedExpression::evaluate(ExecutionContext& context) const +{ + return expression()->evaluate(context); +} + +Value ChainedExpression::evaluate(ExecutionContext& context) const +{ + Value ret(SQLType::Tuple); + Vector<Value> values; + for (auto& expression : expressions()) { + values.append(expression.evaluate(context)); + } + ret = values; + return ret; +} + +Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const +{ + Value expression_value = NestedExpression::evaluate(context); + switch (type()) { + case UnaryOperator::Plus: + if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float) + return expression_value; + // TODO: Error handling. + VERIFY_NOT_REACHED(); + case UnaryOperator::Minus: + if (expression_value.type() == SQLType::Integer) { + expression_value = -int(expression_value); + return expression_value; + } + if (expression_value.type() == SQLType::Float) { + expression_value = -double(expression_value); + return expression_value; + } + // TODO: Error handling. + VERIFY_NOT_REACHED(); + case UnaryOperator::Not: + if (expression_value.type() == SQLType::Boolean) { + expression_value = !bool(expression_value); + return expression_value; + } + // TODO: Error handling. + VERIFY_NOT_REACHED(); + case UnaryOperator::BitwiseNot: + if (expression_value.type() == SQLType::Integer) { + expression_value = ~u32(expression_value); + return expression_value; + } + // TODO: Error handling. + VERIFY_NOT_REACHED(); + } + VERIFY_NOT_REACHED(); +} + +} diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp new file mode 100644 index 0000000000..2f7edc2486 --- /dev/null +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021, Jan de Visser <jan@de-visser.net> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include <LibSQL/AST/AST.h> +#include <LibSQL/Database.h> +#include <LibSQL/Meta.h> +#include <LibSQL/Row.h> + +namespace SQL::AST { + +RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const +{ + auto table_def = context.database->get_table(m_schema_name, m_table_name); + if (!table_def) { + auto schema_name = m_schema_name; + if (schema_name.is_null() || schema_name.is_empty()) + schema_name = "default"; + return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::TableDoesNotExist, String::formatted("{}.{}", schema_name, m_table_name)); + } + + Row row(table_def); + for (auto& column : m_column_names) { + if (!row.has(column)) { + return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::ColumnDoesNotExist, column); + } + } + + for (auto& row_expr : m_chained_expressions) { + for (auto& column_def : table_def->columns()) { + if (!m_column_names.contains_slow(column_def.name())) { + row[column_def.name()] = column_def.default_value(); + } + } + auto row_value = row_expr.evaluate(context); + VERIFY(row_value.type() == SQLType::Tuple); + auto values = row_value.to_vector().value(); + for (auto ix = 0u; ix < values.size(); ix++) { + auto& column_name = m_column_names[ix]; + row[column_name] = values[ix]; + } + context.database->insert(row); + } + return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0); +} + +} diff --git a/Userland/Libraries/LibSQL/AST/Parser.cpp b/Userland/Libraries/LibSQL/AST/Parser.cpp index e32b4c0c64..2c0f0366e2 100644 --- a/Userland/Libraries/LibSQL/AST/Parser.cpp +++ b/Userland/Libraries/LibSQL/AST/Parser.cpp @@ -205,10 +205,16 @@ NonnullRefPtr<Insert> Parser::parse_insert_statement(RefPtr<CommonTableExpressio if (consume_if(TokenType::Values)) { parse_comma_separated_list(false, [&]() { - if (auto chained_expression = parse_chained_expression(); chained_expression.has_value()) - chained_expressions.append(move(chained_expression.value())); - else + if (auto chained_expression = parse_chained_expression(); chained_expression.has_value()) { + auto chained_expr = dynamic_cast<ChainedExpression*>(chained_expression->ptr()); + if ((column_names.size() > 0) && (chained_expr->expressions().size() != column_names.size())) { + syntax_error("Number of expressions does not match number of columns"); + } else { + chained_expressions.append(move(chained_expression.value())); + } + } else { expected("Chained expression"); + } }); } else if (match(TokenType::Select)) { select_statement = parse_select_statement({}); diff --git a/Userland/Libraries/LibSQL/AST/Select.cpp b/Userland/Libraries/LibSQL/AST/Select.cpp new file mode 100644 index 0000000000..4239497e49 --- /dev/null +++ b/Userland/Libraries/LibSQL/AST/Select.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021, Jan de Visser <jan@de-visser.net> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include <LibSQL/AST/AST.h> +#include <LibSQL/Database.h> +#include <LibSQL/Meta.h> +#include <LibSQL/Row.h> + +namespace SQL::AST { + +RefPtr<SQLResult> Select::execute(ExecutionContext& context) const +{ + if (table_or_subquery_list().size() == 1 && table_or_subquery_list()[0].is_table()) { + if (result_column_list().size() == 1 && result_column_list()[0].type() == ResultType::All) { + auto table = context.database->get_table(table_or_subquery_list()[0].schema_name(), table_or_subquery_list()[0].table_name()); + if (!table) { + return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::TableDoesNotExist, table_or_subquery_list()[0].table_name()); + } + NonnullRefPtr<TupleDescriptor> descriptor = table->to_tuple_descriptor(); + context.result = SQLResult::construct(); + for (auto& row : context.database->select_all(*table)) { + Tuple tuple(descriptor); + for (auto ix = 0u; ix < descriptor->size(); ix++) { + tuple[ix] = row[ix]; + } + context.result->append(tuple); + } + return context.result; + } + } + return SQLResult::construct(); +} + +} |