diff options
author | Timothy Flynn <trflynn89@pm.me> | 2021-04-23 21:40:19 -0400 |
---|---|---|
committer | Linus Groh <mail@linusgroh.de> | 2021-04-24 14:22:08 +0200 |
commit | 0764a686160e8db4473d52051e0b5324593eddf1 (patch) | |
tree | d61ee2f275e66000ee6cdc8856c115ece417faa2 /Userland/Libraries | |
parent | 8d79b4a3e19c123c228539e10695f7251e7ebbe6 (diff) | |
download | serenity-0764a686160e8db4473d52051e0b5324593eddf1.zip |
LibSQL: Parse UPDATE statement
This also migrates parsing of conflict resolution to a helper method,
since both INSERT and UPDATE need it.
Diffstat (limited to 'Userland/Libraries')
-rw-r--r-- | Userland/Libraries/LibSQL/AST.h | 36 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Forward.h | 1 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.cpp | 86 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.h | 2 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp | 93 |
5 files changed, 198 insertions, 20 deletions
diff --git a/Userland/Libraries/LibSQL/AST.h b/Userland/Libraries/LibSQL/AST.h index a0b8f09a3b..1e1f2714b6 100644 --- a/Userland/Libraries/LibSQL/AST.h +++ b/Userland/Libraries/LibSQL/AST.h @@ -807,6 +807,42 @@ private: RefPtr<Select> m_select_statement; }; +class Update : public Statement { +public: + struct UpdateColumns { + Vector<String> column_names; + NonnullRefPtr<Expression> expression; + }; + + Update(RefPtr<CommonTableExpressionList> common_table_expression_list, ConflictResolution conflict_resolution, NonnullRefPtr<QualifiedTableName> qualified_table_name, Vector<UpdateColumns> update_columns, NonnullRefPtrVector<TableOrSubquery> table_or_subquery_list, RefPtr<Expression> where_clause, RefPtr<ReturningClause> returning_clause) + : m_common_table_expression_list(move(common_table_expression_list)) + , m_conflict_resolution(conflict_resolution) + , m_qualified_table_name(move(qualified_table_name)) + , m_update_columns(move(update_columns)) + , m_table_or_subquery_list(move(table_or_subquery_list)) + , m_where_clause(move(where_clause)) + , m_returning_clause(move(returning_clause)) + { + } + + const RefPtr<CommonTableExpressionList>& common_table_expression_list() const { return m_common_table_expression_list; } + ConflictResolution conflict_resolution() const { return m_conflict_resolution; } + const NonnullRefPtr<QualifiedTableName>& qualified_table_name() const { return m_qualified_table_name; } + const Vector<UpdateColumns>& update_columns() const { return m_update_columns; } + const NonnullRefPtrVector<TableOrSubquery>& table_or_subquery_list() const { return m_table_or_subquery_list; } + const RefPtr<Expression>& where_clause() const { return m_where_clause; } + const RefPtr<ReturningClause>& returning_clause() const { return m_returning_clause; } + +private: + RefPtr<CommonTableExpressionList> m_common_table_expression_list; + ConflictResolution m_conflict_resolution; + NonnullRefPtr<QualifiedTableName> m_qualified_table_name; + Vector<UpdateColumns> m_update_columns; + NonnullRefPtrVector<TableOrSubquery> m_table_or_subquery_list; + RefPtr<Expression> m_where_clause; + RefPtr<ReturningClause> m_returning_clause; +}; + class Delete : public Statement { public: Delete(RefPtr<CommonTableExpressionList> common_table_expression_list, NonnullRefPtr<QualifiedTableName> qualified_table_name, RefPtr<Expression> where_clause, RefPtr<ReturningClause> returning_clause) diff --git a/Userland/Libraries/LibSQL/Forward.h b/Userland/Libraries/LibSQL/Forward.h index ff2ee64f43..3f83078e40 100644 --- a/Userland/Libraries/LibSQL/Forward.h +++ b/Userland/Libraries/LibSQL/Forward.h @@ -55,4 +55,5 @@ class TableOrSubquery; class Token; class TypeName; class UnaryOperatorExpression; +class Update; } diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp index 9e273b1de5..de96359776 100644 --- a/Userland/Libraries/LibSQL/Parser.cpp +++ b/Userland/Libraries/LibSQL/Parser.cpp @@ -38,12 +38,14 @@ NonnullRefPtr<Statement> Parser::parse_statement() return parse_drop_table_statement(); case TokenType::Insert: return parse_insert_statement({}); + case TokenType::Update: + return parse_update_statement({}); case TokenType::Delete: return parse_delete_statement({}); case TokenType::Select: return parse_select_statement({}); default: - expected("CREATE, DROP, INSERT, DELETE, or SELECT"); + expected("CREATE, DROP, INSERT, UPDATE, DELETE, or SELECT"); return create_ast_node<ErrorStatement>(); } } @@ -53,12 +55,14 @@ NonnullRefPtr<Statement> Parser::parse_statement_with_expression_list(RefPtr<Com switch (m_parser_state.m_token.type()) { case TokenType::Insert: return parse_insert_statement(move(common_table_expression_list)); + case TokenType::Update: + return parse_update_statement(move(common_table_expression_list)); case TokenType::Delete: return parse_delete_statement(move(common_table_expression_list)); case TokenType::Select: return parse_select_statement(move(common_table_expression_list)); default: - expected("INSERT, DELETE or SELECT"); + expected("INSERT, UPDATE, DELETE or SELECT"); return create_ast_node<ErrorStatement>(); } } @@ -121,24 +125,7 @@ NonnullRefPtr<Delete> Parser::parse_insert_statement(RefPtr<CommonTableExpressio { // https://sqlite.org/lang_insert.html consume(TokenType::Insert); - - auto conflict_resolution = ConflictResolution::Abort; - if (consume_if(TokenType::Or)) { - // https://sqlite.org/lang_conflict.html - if (consume_if(TokenType::Abort)) - conflict_resolution = ConflictResolution::Abort; - else if (consume_if(TokenType::Fail)) - conflict_resolution = ConflictResolution::Fail; - else if (consume_if(TokenType::Ignore)) - conflict_resolution = ConflictResolution::Ignore; - else if (consume_if(TokenType::Replace)) - conflict_resolution = ConflictResolution::Replace; - else if (consume_if(TokenType::Rollback)) - conflict_resolution = ConflictResolution::Rollback; - else - expected("ABORT, FAIL, IGNORE, REPLACE, or ROLLBACK"); - } - + auto conflict_resolution = parse_conflict_resolution(); consume(TokenType::Into); String schema_name; @@ -184,6 +171,44 @@ NonnullRefPtr<Delete> Parser::parse_insert_statement(RefPtr<CommonTableExpressio return create_ast_node<Insert>(move(common_table_expression_list), conflict_resolution, move(schema_name), move(table_name), move(alias), move(column_names)); } +NonnullRefPtr<Delete> Parser::parse_update_statement(RefPtr<CommonTableExpressionList> common_table_expression_list) +{ + // https://sqlite.org/lang_update.html + consume(TokenType::Update); + auto conflict_resolution = parse_conflict_resolution(); + auto qualified_table_name = parse_qualified_table_name(); + consume(TokenType::Set); + + Vector<Update::UpdateColumns> update_columns; + parse_comma_separated_list(false, [&]() { + Vector<String> column_names; + if (match(TokenType::ParenOpen)) { + parse_comma_separated_list(true, [&]() { column_names.append(consume(TokenType::Identifier).value()); }); + } else { + column_names.append(consume(TokenType::Identifier).value()); + } + + consume(TokenType::Equals); + update_columns.append({ move(column_names), parse_expression() }); + }); + + NonnullRefPtrVector<TableOrSubquery> table_or_subquery_list; + if (consume_if(TokenType::From)) { + // FIXME: Parse join-clause. + parse_comma_separated_list(false, [&]() { table_or_subquery_list.append(parse_table_or_subquery()); }); + } + + RefPtr<Expression> where_clause; + if (consume_if(TokenType::Where)) + where_clause = parse_expression(); + + RefPtr<ReturningClause> returning_clause; + if (match(TokenType::Returning)) + returning_clause = parse_returning_clause(); + + return create_ast_node<Update>(move(common_table_expression_list), conflict_resolution, move(qualified_table_name), move(update_columns), move(table_or_subquery_list), move(where_clause), move(returning_clause)); +} + NonnullRefPtr<Delete> Parser::parse_delete_statement(RefPtr<CommonTableExpressionList> common_table_expression_list) { // https://sqlite.org/lang_delete.html @@ -932,6 +957,27 @@ void Parser::parse_schema_and_table_name(String& schema_name, String& table_name } } +ConflictResolution Parser::parse_conflict_resolution() +{ + // https://sqlite.org/lang_conflict.html + if (consume_if(TokenType::Or)) { + if (consume_if(TokenType::Abort)) + return ConflictResolution::Abort; + if (consume_if(TokenType::Fail)) + return ConflictResolution::Fail; + if (consume_if(TokenType::Ignore)) + return ConflictResolution::Ignore; + if (consume_if(TokenType::Replace)) + return ConflictResolution::Replace; + if (consume_if(TokenType::Rollback)) + return ConflictResolution::Rollback; + + expected("ABORT, FAIL, IGNORE, REPLACE, or ROLLBACK"); + } + + return ConflictResolution::Abort; +} + Token Parser::consume() { auto old_token = m_parser_state.m_token; diff --git a/Userland/Libraries/LibSQL/Parser.h b/Userland/Libraries/LibSQL/Parser.h index 1a0e635e83..43463508ad 100644 --- a/Userland/Libraries/LibSQL/Parser.h +++ b/Userland/Libraries/LibSQL/Parser.h @@ -55,6 +55,7 @@ private: NonnullRefPtr<CreateTable> parse_create_table_statement(); NonnullRefPtr<DropTable> parse_drop_table_statement(); NonnullRefPtr<Delete> parse_insert_statement(RefPtr<CommonTableExpressionList>); + NonnullRefPtr<Delete> parse_update_statement(RefPtr<CommonTableExpressionList>); NonnullRefPtr<Delete> parse_delete_statement(RefPtr<CommonTableExpressionList>); NonnullRefPtr<Select> parse_select_statement(RefPtr<CommonTableExpressionList>); NonnullRefPtr<CommonTableExpressionList> parse_common_table_expression_list(); @@ -87,6 +88,7 @@ private: NonnullRefPtr<TableOrSubquery> parse_table_or_subquery(); NonnullRefPtr<OrderingTerm> parse_ordering_term(); void parse_schema_and_table_name(String& schema_name, String& table_name); + ConflictResolution parse_conflict_resolution(); template<typename ParseCallback> void parse_comma_separated_list(bool surrounded_by_parentheses, ParseCallback&& parse_callback) diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp index 2a417a427b..a7456d3444 100644 --- a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp +++ b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp @@ -223,6 +223,99 @@ TEST_CASE(insert) validate("INSERT INTO table SELECT * FROM table;", resolution, {}, "table", {}, {}, {}, true); } +TEST_CASE(update) +{ + EXPECT(parse("UPDATE").is_error()); + EXPECT(parse("UPDATE table").is_error()); + EXPECT(parse("UPDATE table SET").is_error()); + EXPECT(parse("UPDATE table SET column").is_error()); + EXPECT(parse("UPDATE table SET column=4").is_error()); + EXPECT(parse("UPDATE table SET column=4, ;").is_error()); + EXPECT(parse("UPDATE table SET (column)=4").is_error()); + EXPECT(parse("UPDATE table SET (column)=4, ;").is_error()); + EXPECT(parse("UPDATE table SET (column, )=4;").is_error()); + EXPECT(parse("UPDATE table SET column=4 FROM").is_error()); + EXPECT(parse("UPDATE table SET column=4 FROM table").is_error()); + EXPECT(parse("UPDATE table SET column=4 WHERE").is_error()); + EXPECT(parse("UPDATE table SET column=4 WHERE 1==1").is_error()); + EXPECT(parse("UPDATE table SET column=4 RETURNING").is_error()); + EXPECT(parse("UPDATE table SET column=4 RETURNING *").is_error()); + EXPECT(parse("UPDATE table SET column=4 RETURNING column").is_error()); + EXPECT(parse("UPDATE table SET column=4 RETURNING column AS").is_error()); + EXPECT(parse("UPDATE OR table SET column=4;").is_error()); + EXPECT(parse("UPDATE OR foo table SET column=4;").is_error()); + + auto validate = [](StringView sql, SQL::ConflictResolution expected_conflict_resolution, StringView expected_schema, StringView expected_table, StringView expected_alias, Vector<Vector<String>> expected_update_columns, bool expect_where_clause, bool expect_returning_clause, Vector<StringView> expected_returned_column_aliases) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto statement = result.release_value(); + EXPECT(is<SQL::Update>(*statement)); + + const auto& update = static_cast<const SQL::Update&>(*statement); + EXPECT_EQ(update.conflict_resolution(), expected_conflict_resolution); + + const auto& qualified_table_name = update.qualified_table_name(); + EXPECT_EQ(qualified_table_name->schema_name(), expected_schema); + EXPECT_EQ(qualified_table_name->table_name(), expected_table); + EXPECT_EQ(qualified_table_name->alias(), expected_alias); + + const auto& update_columns = update.update_columns(); + EXPECT_EQ(update_columns.size(), expected_update_columns.size()); + for (size_t i = 0; i < update_columns.size(); ++i) { + const auto& update_column = update_columns[i]; + const auto& expected_update_column = expected_update_columns[i]; + EXPECT_EQ(update_column.column_names.size(), expected_update_column.size()); + EXPECT(!is<SQL::ErrorExpression>(*update_column.expression)); + + for (size_t j = 0; j < update_column.column_names.size(); ++j) + EXPECT_EQ(update_column.column_names[j], expected_update_column[j]); + } + + const auto& where_clause = update.where_clause(); + EXPECT_EQ(where_clause.is_null(), !expect_where_clause); + if (where_clause) + EXPECT(!is<SQL::ErrorExpression>(*where_clause)); + + const auto& returning_clause = update.returning_clause(); + EXPECT_EQ(returning_clause.is_null(), !expect_returning_clause); + if (returning_clause) { + EXPECT_EQ(returning_clause->columns().size(), expected_returned_column_aliases.size()); + + for (size_t i = 0; i < returning_clause->columns().size(); ++i) { + const auto& column = returning_clause->columns()[i]; + const auto& expected_column_alias = expected_returned_column_aliases[i]; + + EXPECT(!is<SQL::ErrorExpression>(*column.expression)); + EXPECT_EQ(column.column_alias, expected_column_alias); + } + } + }; + + Vector<Vector<String>> update_columns { { "column" } }; + validate("UPDATE OR ABORT table SET column=1;", SQL::ConflictResolution::Abort, {}, "table", {}, update_columns, false, false, {}); + validate("UPDATE OR FAIL table SET column=1;", SQL::ConflictResolution::Fail, {}, "table", {}, update_columns, false, false, {}); + validate("UPDATE OR IGNORE table SET column=1;", SQL::ConflictResolution::Ignore, {}, "table", {}, update_columns, false, false, {}); + validate("UPDATE OR REPLACE table SET column=1;", SQL::ConflictResolution::Replace, {}, "table", {}, update_columns, false, false, {}); + validate("UPDATE OR ROLLBACK table SET column=1;", SQL::ConflictResolution::Rollback, {}, "table", {}, update_columns, false, false, {}); + + auto resolution = SQL::ConflictResolution::Abort; + validate("UPDATE table SET column=1;", resolution, {}, "table", {}, update_columns, false, false, {}); + validate("UPDATE schema.table SET column=1;", resolution, "schema", "table", {}, update_columns, false, false, {}); + validate("UPDATE table AS foo SET column=1;", resolution, {}, "table", "foo", update_columns, false, false, {}); + + validate("UPDATE table SET column=1;", resolution, {}, "table", {}, { { "column" } }, false, false, {}); + validate("UPDATE table SET column1=1, column2=2;", resolution, {}, "table", {}, { { "column1" }, { "column2" } }, false, false, {}); + validate("UPDATE table SET (column1, column2)=1, column3=2;", resolution, {}, "table", {}, { { "column1", "column2" }, { "column3" } }, false, false, {}); + + validate("UPDATE table SET column=1 WHERE 1==1;", resolution, {}, "table", {}, update_columns, true, false, {}); + + validate("UPDATE table SET column=1 RETURNING *;", resolution, {}, "table", {}, update_columns, false, true, {}); + validate("UPDATE table SET column=1 RETURNING column;", resolution, {}, "table", {}, update_columns, false, true, { {} }); + validate("UPDATE table SET column=1 RETURNING column AS alias;", resolution, {}, "table", {}, update_columns, false, true, { "alias" }); + validate("UPDATE table SET column=1 RETURNING column1 AS alias1, column2 AS alias2;", resolution, {}, "table", {}, update_columns, false, true, { "alias1", "alias2" }); +} + TEST_CASE(delete_) { EXPECT(parse("DELETE").is_error()); |