diff options
Diffstat (limited to 'Userland/Libraries/LibSQL')
-rw-r--r-- | Userland/Libraries/LibSQL/AST.h | 72 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Forward.h | 1 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.cpp | 75 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.h | 1 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp | 73 |
5 files changed, 218 insertions, 4 deletions
diff --git a/Userland/Libraries/LibSQL/AST.h b/Userland/Libraries/LibSQL/AST.h index 44c5134ba3..a0b8f09a3b 100644 --- a/Userland/Libraries/LibSQL/AST.h +++ b/Userland/Libraries/LibSQL/AST.h @@ -58,7 +58,7 @@ public: } const String& name() const { return m_name; } - const NonnullRefPtrVector<SignedNumber> signed_numbers() const { return m_signed_numbers; } + const NonnullRefPtrVector<SignedNumber>& signed_numbers() const { return m_signed_numbers; } private: String m_name; @@ -706,7 +706,7 @@ public: const RefPtr<Select>& select_statement() const { return m_select_statement; } bool has_columns() const { return !m_columns.is_empty(); } - const NonnullRefPtrVector<ColumnDefinition> columns() const { return m_columns; } + const NonnullRefPtrVector<ColumnDefinition>& columns() const { return m_columns; } bool is_temporary() const { return m_is_temporary; } bool is_error_if_table_exists() const { return m_is_error_if_table_exists; } @@ -739,6 +739,74 @@ private: bool m_is_error_if_table_does_not_exist; }; +enum class ConflictResolution { + Abort, + Fail, + Ignore, + Replace, + Rollback, +}; + +class Insert : public Statement { +public: + Insert(RefPtr<CommonTableExpressionList> common_table_expression_list, ConflictResolution conflict_resolution, String schema_name, String table_name, String alias, Vector<String> column_names, NonnullRefPtrVector<ChainedExpression> chained_expressions) + : m_common_table_expression_list(move(common_table_expression_list)) + , m_conflict_resolution(conflict_resolution) + , m_schema_name(move(schema_name)) + , m_table_name(move(table_name)) + , m_alias(move(alias)) + , m_column_names(move(column_names)) + , m_chained_expressions(move(chained_expressions)) + { + } + + Insert(RefPtr<CommonTableExpressionList> common_table_expression_list, ConflictResolution conflict_resolution, String schema_name, String table_name, String alias, Vector<String> column_names, RefPtr<Select> select_statement) + : m_common_table_expression_list(move(common_table_expression_list)) + , m_conflict_resolution(conflict_resolution) + , m_schema_name(move(schema_name)) + , m_table_name(move(table_name)) + , m_alias(move(alias)) + , m_column_names(move(column_names)) + , m_select_statement(move(select_statement)) + { + } + + Insert(RefPtr<CommonTableExpressionList> common_table_expression_list, ConflictResolution conflict_resolution, String schema_name, String table_name, String alias, Vector<String> column_names) + : m_common_table_expression_list(move(common_table_expression_list)) + , m_conflict_resolution(conflict_resolution) + , m_schema_name(move(schema_name)) + , m_table_name(move(table_name)) + , m_alias(move(alias)) + , m_column_names(move(column_names)) + { + } + + const RefPtr<CommonTableExpressionList>& common_table_expression_list() const { return m_common_table_expression_list; } + ConflictResolution conflict_resolution() const { return m_conflict_resolution; } + const String& schema_name() const { return m_schema_name; } + const String& table_name() const { return m_table_name; } + const String& alias() const { return m_alias; } + const Vector<String>& column_names() const { return m_column_names; } + + bool default_values() const { return !has_expressions() && !has_selection(); }; + + bool has_expressions() const { return !m_chained_expressions.is_empty(); } + const NonnullRefPtrVector<ChainedExpression>& chained_expressions() const { return m_chained_expressions; } + + bool has_selection() const { return !m_select_statement.is_null(); } + const RefPtr<Select>& select_statement() const { return m_select_statement; } + +private: + RefPtr<CommonTableExpressionList> m_common_table_expression_list; + ConflictResolution m_conflict_resolution; + String m_schema_name; + String m_table_name; + String m_alias; + Vector<String> m_column_names; + NonnullRefPtrVector<ChainedExpression> m_chained_expressions; + RefPtr<Select> m_select_statement; +}; + class Delete : public Statement { public: Delete(RefPtr<CommonTableExpressionList> common_table_expression_list, NonnullRefPtr<QualifiedTableName> qualified_table_name, RefPtr<Expression> where_clause, RefPtr<ReturningClause> returning_clause) diff --git a/Userland/Libraries/LibSQL/Forward.h b/Userland/Libraries/LibSQL/Forward.h index 2569fc65d0..ff2ee64f43 100644 --- a/Userland/Libraries/LibSQL/Forward.h +++ b/Userland/Libraries/LibSQL/Forward.h @@ -29,6 +29,7 @@ class Expression; class GroupByClause; class InChainedExpression; class InSelectionExpression; +class Insert; class InTableExpression; class InvertibleNestedDoubleExpression; class InvertibleNestedExpression; diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp index 3ac2527b6f..9e273b1de5 100644 --- a/Userland/Libraries/LibSQL/Parser.cpp +++ b/Userland/Libraries/LibSQL/Parser.cpp @@ -36,12 +36,14 @@ NonnullRefPtr<Statement> Parser::parse_statement() return parse_create_table_statement(); case TokenType::Drop: return parse_drop_table_statement(); + case TokenType::Insert: + return parse_insert_statement({}); case TokenType::Delete: return parse_delete_statement({}); case TokenType::Select: return parse_select_statement({}); default: - expected("CREATE, DROP, DELETE, or SELECT"); + expected("CREATE, DROP, INSERT, DELETE, or SELECT"); return create_ast_node<ErrorStatement>(); } } @@ -49,12 +51,14 @@ NonnullRefPtr<Statement> Parser::parse_statement() NonnullRefPtr<Statement> Parser::parse_statement_with_expression_list(RefPtr<CommonTableExpressionList> common_table_expression_list) { switch (m_parser_state.m_token.type()) { + case TokenType::Insert: + return parse_insert_statement(move(common_table_expression_list)); case TokenType::Delete: return parse_delete_statement(move(common_table_expression_list)); case TokenType::Select: return parse_select_statement(move(common_table_expression_list)); default: - expected("DELETE or SELECT"); + expected("INSERT, DELETE or SELECT"); return create_ast_node<ErrorStatement>(); } } @@ -113,6 +117,73 @@ NonnullRefPtr<DropTable> Parser::parse_drop_table_statement() return create_ast_node<DropTable>(move(schema_name), move(table_name), is_error_if_table_does_not_exist); } +NonnullRefPtr<Delete> Parser::parse_insert_statement(RefPtr<CommonTableExpressionList> common_table_expression_list) +{ + // https://sqlite.org/lang_insert.html + consume(TokenType::Insert); + + auto conflict_resolution = ConflictResolution::Abort; + if (consume_if(TokenType::Or)) { + // https://sqlite.org/lang_conflict.html + if (consume_if(TokenType::Abort)) + conflict_resolution = ConflictResolution::Abort; + else if (consume_if(TokenType::Fail)) + conflict_resolution = ConflictResolution::Fail; + else if (consume_if(TokenType::Ignore)) + conflict_resolution = ConflictResolution::Ignore; + else if (consume_if(TokenType::Replace)) + conflict_resolution = ConflictResolution::Replace; + else if (consume_if(TokenType::Rollback)) + conflict_resolution = ConflictResolution::Rollback; + else + expected("ABORT, FAIL, IGNORE, REPLACE, or ROLLBACK"); + } + + consume(TokenType::Into); + + String schema_name; + String table_name; + parse_schema_and_table_name(schema_name, table_name); + + String alias; + if (consume_if(TokenType::As)) + alias = consume(TokenType::Identifier).value(); + + Vector<String> column_names; + if (match(TokenType::ParenOpen)) + parse_comma_separated_list(true, [&]() { column_names.append(consume(TokenType::Identifier).value()); }); + + NonnullRefPtrVector<ChainedExpression> chained_expressions; + RefPtr<Select> select_statement; + + if (consume_if(TokenType::Values)) { + parse_comma_separated_list(false, [&]() { + if (auto chained_expression = parse_chained_expression(); chained_expression.has_value()) + chained_expressions.append(move(chained_expression.value())); + else + expected("Chained expression"); + }); + } else if (match(TokenType::Select)) { + select_statement = parse_select_statement({}); + } else { + consume(TokenType::Default); + consume(TokenType::Values); + } + + RefPtr<ReturningClause> returning_clause; + if (match(TokenType::Returning)) + returning_clause = parse_returning_clause(); + + // FIXME: Parse 'upsert-clause'. + + if (!chained_expressions.is_empty()) + return create_ast_node<Insert>(move(common_table_expression_list), conflict_resolution, move(schema_name), move(table_name), move(alias), move(column_names), move(chained_expressions)); + if (!select_statement.is_null()) + return create_ast_node<Insert>(move(common_table_expression_list), conflict_resolution, move(schema_name), move(table_name), move(alias), move(column_names), move(select_statement)); + + return create_ast_node<Insert>(move(common_table_expression_list), conflict_resolution, move(schema_name), move(table_name), move(alias), move(column_names)); +} + NonnullRefPtr<Delete> Parser::parse_delete_statement(RefPtr<CommonTableExpressionList> common_table_expression_list) { // https://sqlite.org/lang_delete.html diff --git a/Userland/Libraries/LibSQL/Parser.h b/Userland/Libraries/LibSQL/Parser.h index 14b03bbf52..1a0e635e83 100644 --- a/Userland/Libraries/LibSQL/Parser.h +++ b/Userland/Libraries/LibSQL/Parser.h @@ -54,6 +54,7 @@ private: NonnullRefPtr<Statement> parse_statement_with_expression_list(RefPtr<CommonTableExpressionList>); NonnullRefPtr<CreateTable> parse_create_table_statement(); NonnullRefPtr<DropTable> parse_drop_table_statement(); + NonnullRefPtr<Delete> parse_insert_statement(RefPtr<CommonTableExpressionList>); NonnullRefPtr<Delete> parse_delete_statement(RefPtr<CommonTableExpressionList>); NonnullRefPtr<Select> parse_select_statement(RefPtr<CommonTableExpressionList>); NonnullRefPtr<CommonTableExpressionList> parse_common_table_expression_list(); diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp index 946089a169..2a417a427b 100644 --- a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp +++ b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp @@ -150,6 +150,79 @@ TEST_CASE(drop_table) validate("DROP TABLE IF EXISTS test;", {}, "test", false); } +TEST_CASE(insert) +{ + EXPECT(parse("INSERT").is_error()); + EXPECT(parse("INSERT INTO").is_error()); + EXPECT(parse("INSERT INTO table").is_error()); + EXPECT(parse("INSERT INTO table (column)").is_error()); + EXPECT(parse("INSERT INTO table (column, ) DEFAULT VALUES;").is_error()); + EXPECT(parse("INSERT INTO table VALUES").is_error()); + EXPECT(parse("INSERT INTO table VALUES ();").is_error()); + EXPECT(parse("INSERT INTO table VALUES (1)").is_error()); + EXPECT(parse("INSERT INTO table SELECT").is_error()); + EXPECT(parse("INSERT INTO table SELECT * from table").is_error()); + EXPECT(parse("INSERT OR INTO table DEFAULT VALUES;").is_error()); + EXPECT(parse("INSERT OR foo INTO table DEFAULT VALUES;").is_error()); + + auto validate = [](StringView sql, SQL::ConflictResolution expected_conflict_resolution, StringView expected_schema, StringView expected_table, StringView expected_alias, Vector<StringView> expected_column_names, Vector<size_t> expected_chain_sizes, bool expect_select_statement) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto statement = result.release_value(); + EXPECT(is<SQL::Insert>(*statement)); + + const auto& insert = static_cast<const SQL::Insert&>(*statement); + EXPECT_EQ(insert.conflict_resolution(), expected_conflict_resolution); + EXPECT_EQ(insert.schema_name(), expected_schema); + EXPECT_EQ(insert.table_name(), expected_table); + EXPECT_EQ(insert.alias(), expected_alias); + + const auto& column_names = insert.column_names(); + EXPECT_EQ(column_names.size(), expected_column_names.size()); + for (size_t i = 0; i < column_names.size(); ++i) + EXPECT_EQ(column_names[i], expected_column_names[i]); + + EXPECT_EQ(insert.has_expressions(), !expected_chain_sizes.is_empty()); + if (insert.has_expressions()) { + const auto& chained_expressions = insert.chained_expressions(); + EXPECT_EQ(chained_expressions.size(), expected_chain_sizes.size()); + + for (size_t i = 0; i < chained_expressions.size(); ++i) { + const auto& chained_expression = chained_expressions[i]; + const auto& expressions = chained_expression.expressions(); + EXPECT_EQ(expressions.size(), expected_chain_sizes[i]); + + for (const auto& expression : expressions) + EXPECT(!is<SQL::ErrorExpression>(expression)); + } + } + + EXPECT_EQ(insert.has_selection(), expect_select_statement); + EXPECT_EQ(insert.default_values(), expected_chain_sizes.is_empty() && !expect_select_statement); + }; + + validate("INSERT OR ABORT INTO table DEFAULT VALUES;", SQL::ConflictResolution::Abort, {}, "table", {}, {}, {}, false); + validate("INSERT OR FAIL INTO table DEFAULT VALUES;", SQL::ConflictResolution::Fail, {}, "table", {}, {}, {}, false); + validate("INSERT OR IGNORE INTO table DEFAULT VALUES;", SQL::ConflictResolution::Ignore, {}, "table", {}, {}, {}, false); + validate("INSERT OR REPLACE INTO table DEFAULT VALUES;", SQL::ConflictResolution::Replace, {}, "table", {}, {}, {}, false); + validate("INSERT OR ROLLBACK INTO table DEFAULT VALUES;", SQL::ConflictResolution::Rollback, {}, "table", {}, {}, {}, false); + + auto resolution = SQL::ConflictResolution::Abort; + validate("INSERT INTO table DEFAULT VALUES;", resolution, {}, "table", {}, {}, {}, false); + validate("INSERT INTO schema.table DEFAULT VALUES;", resolution, "schema", "table", {}, {}, {}, false); + validate("INSERT INTO table AS foo DEFAULT VALUES;", resolution, {}, "table", "foo", {}, {}, false); + + validate("INSERT INTO table (column) DEFAULT VALUES;", resolution, {}, "table", {}, { "column" }, {}, false); + validate("INSERT INTO table (column1, column2) DEFAULT VALUES;", resolution, {}, "table", {}, { "column1", "column2" }, {}, false); + + validate("INSERT INTO table VALUES (1);", resolution, {}, "table", {}, {}, { 1 }, false); + validate("INSERT INTO table VALUES (1, 2);", resolution, {}, "table", {}, {}, { 2 }, false); + validate("INSERT INTO table VALUES (1, 2), (3, 4, 5);", resolution, {}, "table", {}, {}, { 2, 3 }, false); + + validate("INSERT INTO table SELECT * FROM table;", resolution, {}, "table", {}, {}, {}, true); +} + TEST_CASE(delete_) { EXPECT(parse("DELETE").is_error()); |