summaryrefslogtreecommitdiff
path: root/Userland/Libraries/LibSQL/AST
diff options
context:
space:
mode:
Diffstat (limited to 'Userland/Libraries/LibSQL/AST')
-rw-r--r--Userland/Libraries/LibSQL/AST/AST.h24
-rw-r--r--Userland/Libraries/LibSQL/AST/CreateSchema.cpp6
-rw-r--r--Userland/Libraries/LibSQL/AST/CreateTable.cpp8
-rw-r--r--Userland/Libraries/LibSQL/AST/Expression.cpp90
-rw-r--r--Userland/Libraries/LibSQL/AST/Insert.cpp49
-rw-r--r--Userland/Libraries/LibSQL/AST/Parser.cpp12
-rw-r--r--Userland/Libraries/LibSQL/AST/Select.cpp37
7 files changed, 213 insertions, 13 deletions
diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h
index cd9001cf31..7a8ca428eb 100644
--- a/Userland/Libraries/LibSQL/AST/AST.h
+++ b/Userland/Libraries/LibSQL/AST/AST.h
@@ -295,7 +295,15 @@ private:
// Expressions
//==================================================================================================
+struct ExecutionContext {
+ NonnullRefPtr<Database> database;
+ RefPtr<SQLResult> result { nullptr };
+ Tuple current_row {};
+};
+
class Expression : public ASTNode {
+public:
+ virtual Value evaluate(ExecutionContext&) const;
};
class ErrorExpression final : public Expression {
@@ -309,6 +317,7 @@ public:
}
double value() const { return m_value; }
+ virtual Value evaluate(ExecutionContext&) const override;
private:
double m_value;
@@ -322,6 +331,7 @@ public:
}
const String& value() const { return m_value; }
+ virtual Value evaluate(ExecutionContext&) const override;
private:
String m_value;
@@ -341,11 +351,14 @@ private:
};
class NullLiteral : public Expression {
+public:
+ virtual Value evaluate(ExecutionContext&) const override;
};
class NestedExpression : public Expression {
public:
const NonnullRefPtr<Expression>& expression() const { return m_expression; }
+ virtual Value evaluate(ExecutionContext&) const override;
protected:
explicit NestedExpression(NonnullRefPtr<Expression> expression)
@@ -439,6 +452,7 @@ public:
}
UnaryOperator type() const { return m_type; }
+ virtual Value evaluate(ExecutionContext&) const override;
private:
UnaryOperator m_type;
@@ -488,6 +502,7 @@ public:
}
const NonnullRefPtrVector<Expression>& expressions() const { return m_expressions; }
+ virtual Value evaluate(ExecutionContext&) const override;
private:
NonnullRefPtrVector<Expression> m_expressions;
@@ -667,7 +682,7 @@ private:
class Statement : public ASTNode {
public:
- virtual RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const { return nullptr; }
+ virtual RefPtr<SQLResult> execute(ExecutionContext&) const { return nullptr; }
};
class ErrorStatement final : public Statement {
@@ -684,7 +699,7 @@ public:
const String& schema_name() const { return m_schema_name; }
bool is_error_if_schema_exists() const { return m_is_error_if_schema_exists; }
- RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const override;
+ RefPtr<SQLResult> execute(ExecutionContext&) const override;
private:
String m_schema_name;
@@ -723,7 +738,7 @@ public:
bool is_temporary() const { return m_is_temporary; }
bool is_error_if_table_exists() const { return m_is_error_if_table_exists; }
- RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const override;
+ RefPtr<SQLResult> execute(ExecutionContext&) const override;
private:
String m_schema_name;
@@ -886,6 +901,8 @@ public:
bool has_selection() const { return !m_select_statement.is_null(); }
const RefPtr<Select>& select_statement() const { return m_select_statement; }
+ RefPtr<SQLResult> execute(ExecutionContext&) const;
+
private:
RefPtr<CommonTableExpressionList> m_common_table_expression_list;
ConflictResolution m_conflict_resolution;
@@ -977,6 +994,7 @@ public:
const RefPtr<GroupByClause>& group_by_clause() const { return m_group_by_clause; }
const NonnullRefPtrVector<OrderingTerm>& ordering_term_list() const { return m_ordering_term_list; }
const RefPtr<LimitClause>& limit_clause() const { return m_limit_clause; }
+ RefPtr<SQLResult> execute(ExecutionContext&) const override;
private:
RefPtr<CommonTableExpressionList> m_common_table_expression_list;
diff --git a/Userland/Libraries/LibSQL/AST/CreateSchema.cpp b/Userland/Libraries/LibSQL/AST/CreateSchema.cpp
index 423fbf44b4..e387279506 100644
--- a/Userland/Libraries/LibSQL/AST/CreateSchema.cpp
+++ b/Userland/Libraries/LibSQL/AST/CreateSchema.cpp
@@ -10,9 +10,9 @@
namespace SQL::AST {
-RefPtr<SQLResult> CreateSchema::execute(NonnullRefPtr<Database> database) const
+RefPtr<SQLResult> CreateSchema::execute(ExecutionContext& context) const
{
- auto schema_def = database->get_schema(m_schema_name);
+ auto schema_def = context.database->get_schema(m_schema_name);
if (schema_def) {
if (m_is_error_if_schema_exists) {
return SQLResult::construct(SQLCommand::Create, SQLErrorCode::SchemaExists, m_schema_name);
@@ -21,7 +21,7 @@ RefPtr<SQLResult> CreateSchema::execute(NonnullRefPtr<Database> database) const
}
schema_def = SchemaDef::construct(m_schema_name);
- database->add_schema(*schema_def);
+ context.database->add_schema(*schema_def);
return SQLResult::construct(SQLCommand::Create, 0, 1);
}
diff --git a/Userland/Libraries/LibSQL/AST/CreateTable.cpp b/Userland/Libraries/LibSQL/AST/CreateTable.cpp
index 56729a7da1..090cb7036d 100644
--- a/Userland/Libraries/LibSQL/AST/CreateTable.cpp
+++ b/Userland/Libraries/LibSQL/AST/CreateTable.cpp
@@ -9,13 +9,13 @@
namespace SQL::AST {
-RefPtr<SQLResult> CreateTable::execute(NonnullRefPtr<Database> database) const
+RefPtr<SQLResult> CreateTable::execute(ExecutionContext& context) const
{
auto schema_name = (!m_schema_name.is_null() && !m_schema_name.is_empty()) ? m_schema_name : "default";
- auto schema_def = database->get_schema(schema_name);
+ auto schema_def = context.database->get_schema(schema_name);
if (!schema_def)
return SQLResult::construct(SQLCommand::Create, SQLErrorCode::SchemaDoesNotExist, m_schema_name);
- auto table_def = database->get_table(schema_name, m_table_name);
+ auto table_def = context.database->get_table(schema_name, m_table_name);
if (table_def) {
if (m_is_error_if_table_exists) {
return SQLResult::construct(SQLCommand::Create, SQLErrorCode::TableExists, m_table_name);
@@ -37,7 +37,7 @@ RefPtr<SQLResult> CreateTable::execute(NonnullRefPtr<Database> database) const
}
table_def->append_column(column.name(), type);
}
- database->add_table(*table_def);
+ context.database->add_table(*table_def);
return SQLResult::construct(SQLCommand::Create, 0, 1);
}
diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp
new file mode 100644
index 0000000000..a66257fb41
--- /dev/null
+++ b/Userland/Libraries/LibSQL/AST/Expression.cpp
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include <LibSQL/AST/AST.h>
+#include <LibSQL/Database.h>
+
+namespace SQL::AST {
+
+Value Expression::evaluate(ExecutionContext&) const
+{
+ return Value::null();
+}
+
+Value NumericLiteral::evaluate(ExecutionContext&) const
+{
+ Value ret(SQLType::Float);
+ ret = value();
+ return ret;
+}
+
+Value StringLiteral::evaluate(ExecutionContext&) const
+{
+ Value ret(SQLType::Text);
+ ret = value();
+ return ret;
+}
+
+Value NullLiteral::evaluate(ExecutionContext&) const
+{
+ return Value::null();
+}
+
+Value NestedExpression::evaluate(ExecutionContext& context) const
+{
+ return expression()->evaluate(context);
+}
+
+Value ChainedExpression::evaluate(ExecutionContext& context) const
+{
+ Value ret(SQLType::Tuple);
+ Vector<Value> values;
+ for (auto& expression : expressions()) {
+ values.append(expression.evaluate(context));
+ }
+ ret = values;
+ return ret;
+}
+
+Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
+{
+ Value expression_value = NestedExpression::evaluate(context);
+ switch (type()) {
+ case UnaryOperator::Plus:
+ if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float)
+ return expression_value;
+ // TODO: Error handling.
+ VERIFY_NOT_REACHED();
+ case UnaryOperator::Minus:
+ if (expression_value.type() == SQLType::Integer) {
+ expression_value = -int(expression_value);
+ return expression_value;
+ }
+ if (expression_value.type() == SQLType::Float) {
+ expression_value = -double(expression_value);
+ return expression_value;
+ }
+ // TODO: Error handling.
+ VERIFY_NOT_REACHED();
+ case UnaryOperator::Not:
+ if (expression_value.type() == SQLType::Boolean) {
+ expression_value = !bool(expression_value);
+ return expression_value;
+ }
+ // TODO: Error handling.
+ VERIFY_NOT_REACHED();
+ case UnaryOperator::BitwiseNot:
+ if (expression_value.type() == SQLType::Integer) {
+ expression_value = ~u32(expression_value);
+ return expression_value;
+ }
+ // TODO: Error handling.
+ VERIFY_NOT_REACHED();
+ }
+ VERIFY_NOT_REACHED();
+}
+
+}
diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp
new file mode 100644
index 0000000000..2f7edc2486
--- /dev/null
+++ b/Userland/Libraries/LibSQL/AST/Insert.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include <LibSQL/AST/AST.h>
+#include <LibSQL/Database.h>
+#include <LibSQL/Meta.h>
+#include <LibSQL/Row.h>
+
+namespace SQL::AST {
+
+RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const
+{
+ auto table_def = context.database->get_table(m_schema_name, m_table_name);
+ if (!table_def) {
+ auto schema_name = m_schema_name;
+ if (schema_name.is_null() || schema_name.is_empty())
+ schema_name = "default";
+ return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::TableDoesNotExist, String::formatted("{}.{}", schema_name, m_table_name));
+ }
+
+ Row row(table_def);
+ for (auto& column : m_column_names) {
+ if (!row.has(column)) {
+ return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::ColumnDoesNotExist, column);
+ }
+ }
+
+ for (auto& row_expr : m_chained_expressions) {
+ for (auto& column_def : table_def->columns()) {
+ if (!m_column_names.contains_slow(column_def.name())) {
+ row[column_def.name()] = column_def.default_value();
+ }
+ }
+ auto row_value = row_expr.evaluate(context);
+ VERIFY(row_value.type() == SQLType::Tuple);
+ auto values = row_value.to_vector().value();
+ for (auto ix = 0u; ix < values.size(); ix++) {
+ auto& column_name = m_column_names[ix];
+ row[column_name] = values[ix];
+ }
+ context.database->insert(row);
+ }
+ return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0);
+}
+
+}
diff --git a/Userland/Libraries/LibSQL/AST/Parser.cpp b/Userland/Libraries/LibSQL/AST/Parser.cpp
index e32b4c0c64..2c0f0366e2 100644
--- a/Userland/Libraries/LibSQL/AST/Parser.cpp
+++ b/Userland/Libraries/LibSQL/AST/Parser.cpp
@@ -205,10 +205,16 @@ NonnullRefPtr<Insert> Parser::parse_insert_statement(RefPtr<CommonTableExpressio
if (consume_if(TokenType::Values)) {
parse_comma_separated_list(false, [&]() {
- if (auto chained_expression = parse_chained_expression(); chained_expression.has_value())
- chained_expressions.append(move(chained_expression.value()));
- else
+ if (auto chained_expression = parse_chained_expression(); chained_expression.has_value()) {
+ auto chained_expr = dynamic_cast<ChainedExpression*>(chained_expression->ptr());
+ if ((column_names.size() > 0) && (chained_expr->expressions().size() != column_names.size())) {
+ syntax_error("Number of expressions does not match number of columns");
+ } else {
+ chained_expressions.append(move(chained_expression.value()));
+ }
+ } else {
expected("Chained expression");
+ }
});
} else if (match(TokenType::Select)) {
select_statement = parse_select_statement({});
diff --git a/Userland/Libraries/LibSQL/AST/Select.cpp b/Userland/Libraries/LibSQL/AST/Select.cpp
new file mode 100644
index 0000000000..4239497e49
--- /dev/null
+++ b/Userland/Libraries/LibSQL/AST/Select.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include <LibSQL/AST/AST.h>
+#include <LibSQL/Database.h>
+#include <LibSQL/Meta.h>
+#include <LibSQL/Row.h>
+
+namespace SQL::AST {
+
+RefPtr<SQLResult> Select::execute(ExecutionContext& context) const
+{
+ if (table_or_subquery_list().size() == 1 && table_or_subquery_list()[0].is_table()) {
+ if (result_column_list().size() == 1 && result_column_list()[0].type() == ResultType::All) {
+ auto table = context.database->get_table(table_or_subquery_list()[0].schema_name(), table_or_subquery_list()[0].table_name());
+ if (!table) {
+ return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::TableDoesNotExist, table_or_subquery_list()[0].table_name());
+ }
+ NonnullRefPtr<TupleDescriptor> descriptor = table->to_tuple_descriptor();
+ context.result = SQLResult::construct();
+ for (auto& row : context.database->select_all(*table)) {
+ Tuple tuple(descriptor);
+ for (auto ix = 0u; ix < descriptor->size(); ix++) {
+ tuple[ix] = row[ix];
+ }
+ context.result->append(tuple);
+ }
+ return context.result;
+ }
+ }
+ return SQLResult::construct();
+}
+
+}