summaryrefslogtreecommitdiff
path: root/Userland
diff options
context:
space:
mode:
authorTimothy Flynn <trflynn89@pm.me>2021-04-21 17:12:06 -0400
committerAndreas Kling <kling@serenityos.org>2021-04-22 18:08:15 +0200
commit9331293e4417e1dc3813396defa2c51793f19570 (patch)
treec0b2053653d849328d135e7ff9f608e1958d177f /Userland
parent6a7d7624a7175ec52864172b872e08c89c0dc4c3 (diff)
downloadserenity-9331293e4417e1dc3813396defa2c51793f19570.zip
LibSQL: Separate parsing of common-table-expression list
Statements like SELECT, INSERT, and UPDATE also optionally include this list, so move its parsing out of parse_delete_statement(). Since it will appear before the actual statement, parse it first in next_statement(); then only parse for statements that are allowed to include the list.
Diffstat (limited to 'Userland')
-rw-r--r--Userland/Libraries/LibSQL/Parser.cpp60
-rw-r--r--Userland/Libraries/LibSQL/Parser.h5
-rw-r--r--Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp2
3 files changed, 46 insertions, 21 deletions
diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp
index c215fcedc9..3208ef9f78 100644
--- a/Userland/Libraries/LibSQL/Parser.cpp
+++ b/Userland/Libraries/LibSQL/Parser.cpp
@@ -16,20 +16,40 @@ Parser::Parser(Lexer lexer)
NonnullRefPtr<Statement> Parser::next_statement()
{
+ if (match(TokenType::With)) {
+ auto common_table_expression_list = parse_common_table_expression_list();
+ return parse_statement_with_expression_list(move(common_table_expression_list));
+ }
+
+ return parse_statement();
+}
+
+NonnullRefPtr<Statement> Parser::parse_statement()
+{
switch (m_parser_state.m_token.type()) {
case TokenType::Create:
return parse_create_table_statement();
case TokenType::Drop:
return parse_drop_table_statement();
case TokenType::Delete:
- case TokenType::With:
- return parse_delete_statement();
+ return parse_delete_statement({});
default:
expected("CREATE, DROP, or DELETE");
return create_ast_node<ErrorStatement>();
}
}
+NonnullRefPtr<Statement> Parser::parse_statement_with_expression_list(RefPtr<CommonTableExpressionList> common_table_expression_list)
+{
+ switch (m_parser_state.m_token.type()) {
+ case TokenType::Delete:
+ return parse_delete_statement(move(common_table_expression_list));
+ default:
+ expected("DELETE");
+ return create_ast_node<ErrorStatement>();
+ }
+}
+
NonnullRefPtr<CreateTable> Parser::parse_create_table_statement()
{
// https://sqlite.org/lang_createtable.html
@@ -108,26 +128,9 @@ NonnullRefPtr<DropTable> Parser::parse_drop_table_statement()
return create_ast_node<DropTable>(move(schema_name), move(table_name), is_error_if_table_does_not_exist);
}
-NonnullRefPtr<Delete> Parser::parse_delete_statement()
+NonnullRefPtr<Delete> Parser::parse_delete_statement(RefPtr<CommonTableExpressionList> common_table_expression_list)
{
// https://sqlite.org/lang_delete.html
-
- RefPtr<CommonTableExpressionList> common_table_expression_list;
- if (consume_if(TokenType::With)) {
- NonnullRefPtrVector<CommonTableExpression> common_table_expression;
- bool recursive = consume_if(TokenType::Recursive);
-
- do {
- common_table_expression.append(parse_common_table_expression());
- if (!match(TokenType::Comma))
- break;
-
- consume(TokenType::Comma);
- } while (!match(TokenType::Eof));
-
- common_table_expression_list = create_ast_node<CommonTableExpressionList>(recursive, move(common_table_expression));
- }
-
consume(TokenType::Delete);
consume(TokenType::From);
auto qualified_table_name = parse_qualified_table_name();
@@ -145,6 +148,23 @@ NonnullRefPtr<Delete> Parser::parse_delete_statement()
return create_ast_node<Delete>(move(common_table_expression_list), move(qualified_table_name), move(where_clause), move(returning_clause));
}
+NonnullRefPtr<CommonTableExpressionList> Parser::parse_common_table_expression_list()
+{
+ consume(TokenType::With);
+ bool recursive = consume_if(TokenType::Recursive);
+
+ NonnullRefPtrVector<CommonTableExpression> common_table_expression;
+ do {
+ common_table_expression.append(parse_common_table_expression());
+ if (!match(TokenType::Comma))
+ break;
+
+ consume(TokenType::Comma);
+ } while (!match(TokenType::Eof));
+
+ return create_ast_node<CommonTableExpressionList>(recursive, move(common_table_expression));
+}
+
NonnullRefPtr<Expression> Parser::parse_expression()
{
// https://sqlite.org/lang_expr.html
diff --git a/Userland/Libraries/LibSQL/Parser.h b/Userland/Libraries/LibSQL/Parser.h
index 953fa9abd2..1a4796171a 100644
--- a/Userland/Libraries/LibSQL/Parser.h
+++ b/Userland/Libraries/LibSQL/Parser.h
@@ -50,9 +50,12 @@ private:
Vector<Error> m_errors;
};
+ NonnullRefPtr<Statement> parse_statement();
+ NonnullRefPtr<Statement> parse_statement_with_expression_list(RefPtr<CommonTableExpressionList>);
NonnullRefPtr<CreateTable> parse_create_table_statement();
NonnullRefPtr<DropTable> parse_drop_table_statement();
- NonnullRefPtr<Delete> parse_delete_statement();
+ NonnullRefPtr<Delete> parse_delete_statement(RefPtr<CommonTableExpressionList>);
+ NonnullRefPtr<CommonTableExpressionList> parse_common_table_expression_list();
NonnullRefPtr<Expression> parse_primary_expression();
NonnullRefPtr<Expression> parse_secondary_expression(NonnullRefPtr<Expression> primary);
diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp
index 6b45a252c3..5dc40b8311 100644
--- a/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp
+++ b/Userland/Libraries/LibSQL/Tests/TestSqlStatementParser.cpp
@@ -54,6 +54,7 @@ TEST_CASE(create_table)
EXPECT(parse("CREATE TABLE test ( column1 varchar(.abc) )").is_error());
EXPECT(parse("CREATE TABLE test ( column1 varchar(0x) )").is_error());
EXPECT(parse("CREATE TABLE test ( column1 varchar(0xzzz) )").is_error());
+ EXPECT(parse("WITH table AS () CREATE TABLE test ( column1 );").is_error());
struct Column {
StringView name;
@@ -118,6 +119,7 @@ TEST_CASE(drop_table)
EXPECT(parse("DROP TABLE").is_error());
EXPECT(parse("DROP TABLE test").is_error());
EXPECT(parse("DROP TABLE IF test;").is_error());
+ EXPECT(parse("WITH table AS () DROP TABLE test;").is_error());
auto validate = [](StringView sql, StringView expected_schema, StringView expected_table, bool expected_is_error_if_table_does_not_exist = true) {
auto result = parse(sql);