diff options
Diffstat (limited to 'Userland')
-rw-r--r-- | Userland/Libraries/LibSQL/AST.h | 16 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.cpp | 35 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.h | 1 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp | 33 |
4 files changed, 80 insertions, 5 deletions
diff --git a/Userland/Libraries/LibSQL/AST.h b/Userland/Libraries/LibSQL/AST.h index 922432b48a..96f0e970dc 100644 --- a/Userland/Libraries/LibSQL/AST.h +++ b/Userland/Libraries/LibSQL/AST.h @@ -537,6 +537,22 @@ private: RefPtr<Expression> m_else_expression; }; +class ExistsExpression : public Expression { +public: + ExistsExpression(NonnullRefPtr<Select> select_statement, bool invert_expression) + : m_select_statement(move(select_statement)) + , m_invert_expression(invert_expression) + { + } + + const NonnullRefPtr<Select>& select_statement() const { return m_select_statement; } + bool invert_expression() const { return m_invert_expression; } + +private: + NonnullRefPtr<Select> m_select_statement; + bool m_invert_expression; +}; + class CollateExpression : public NestedExpression { public: CollateExpression(NonnullRefPtr<Expression> expression, String collation_name) diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp index 69b30e37ad..8775c2d3b3 100644 --- a/Userland/Libraries/LibSQL/Parser.cpp +++ b/Userland/Libraries/LibSQL/Parser.cpp @@ -215,7 +215,6 @@ NonnullRefPtr<Expression> Parser::parse_expression() // FIXME: Parse 'bind-parameter'. // FIXME: Parse 'function-name'. - // FIXME: Parse 'exists'. // FIXME: Parse 'raise-function'. return expression; @@ -241,6 +240,9 @@ NonnullRefPtr<Expression> Parser::parse_primary_expression() if (auto expression = parse_case_expression(); expression.has_value()) return move(expression.value()); + if (auto expression = parse_exists_expression(false); expression.has_value()) + return move(expression.value()); + expected("Primary Expression"); consume(); @@ -381,8 +383,12 @@ Optional<NonnullRefPtr<Expression>> Parser::parse_unary_operator_expression() if (consume_if(TokenType::Tilde)) return create_ast_node<UnaryOperatorExpression>(UnaryOperator::BitwiseNot, parse_expression()); - if (consume_if(TokenType::Not)) - return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Not, parse_expression()); + if (consume_if(TokenType::Not)) { + if (match(TokenType::Exists)) + return parse_exists_expression(true); + else + return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Not, parse_expression()); + } return {}; } @@ -448,11 +454,15 @@ Optional<NonnullRefPtr<Expression>> Parser::parse_binary_operator_expression(Non Optional<NonnullRefPtr<Expression>> Parser::parse_chained_expression() { - if (!match(TokenType::ParenOpen)) + if (!consume_if(TokenType::ParenOpen)) return {}; + if (match(TokenType::Select)) + return parse_exists_expression(false, TokenType::Select); + NonnullRefPtrVector<Expression> expressions; - parse_comma_separated_list(true, [&]() { expressions.append(parse_expression()); }); + parse_comma_separated_list(false, [&]() { expressions.append(parse_expression()); }); + consume(TokenType::ParenClose); return create_ast_node<ChainedExpression>(move(expressions)); } @@ -506,6 +516,21 @@ Optional<NonnullRefPtr<Expression>> Parser::parse_case_expression() return create_ast_node<CaseExpression>(move(case_expression), move(when_then_clauses), move(else_expression)); } +Optional<NonnullRefPtr<Expression>> Parser::parse_exists_expression(bool invert_expression, TokenType opening_token) +{ + VERIFY((opening_token == TokenType::Exists) || (opening_token == TokenType::Select)); + + if ((opening_token == TokenType::Exists) && !consume_if(TokenType::Exists)) + return {}; + + if (opening_token == TokenType::Exists) + consume(TokenType::ParenOpen); + auto select_statement = parse_select_statement({}); + consume(TokenType::ParenClose); + + return create_ast_node<ExistsExpression>(move(select_statement), invert_expression); +} + Optional<NonnullRefPtr<Expression>> Parser::parse_collate_expression(NonnullRefPtr<Expression> expression) { if (!match(TokenType::Collate)) diff --git a/Userland/Libraries/LibSQL/Parser.h b/Userland/Libraries/LibSQL/Parser.h index 6d6cf91591..14b03bbf52 100644 --- a/Userland/Libraries/LibSQL/Parser.h +++ b/Userland/Libraries/LibSQL/Parser.h @@ -68,6 +68,7 @@ private: Optional<NonnullRefPtr<Expression>> parse_chained_expression(); Optional<NonnullRefPtr<Expression>> parse_cast_expression(); Optional<NonnullRefPtr<Expression>> parse_case_expression(); + Optional<NonnullRefPtr<Expression>> parse_exists_expression(bool invert_expression, TokenType opening_token = TokenType::Exists); Optional<NonnullRefPtr<Expression>> parse_collate_expression(NonnullRefPtr<Expression> expression); Optional<NonnullRefPtr<Expression>> parse_is_expression(NonnullRefPtr<Expression> expression); Optional<NonnullRefPtr<Expression>> parse_match_expression(NonnullRefPtr<Expression> lhs, bool invert_expression); diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp index 4d8952dc14..1e7a2ef187 100644 --- a/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp +++ b/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp @@ -338,6 +338,39 @@ TEST_CASE(case_expression) validate("CASE 15 WHEN 16 THEN 17 WHEN 18 THEN 19 ELSE 20 END", true, 2, true); } +TEST_CASE(exists_expression) +{ + EXPECT(parse("EXISTS").is_error()); + EXPECT(parse("EXISTS (").is_error()); + EXPECT(parse("EXISTS (SELECT").is_error()); + EXPECT(parse("EXISTS (SELECT)").is_error()); + EXPECT(parse("EXISTS (SELECT * FROM table").is_error()); + EXPECT(parse("NOT EXISTS").is_error()); + EXPECT(parse("NOT EXISTS (").is_error()); + EXPECT(parse("NOT EXISTS (SELECT").is_error()); + EXPECT(parse("NOT EXISTS (SELECT)").is_error()); + EXPECT(parse("NOT EXISTS (SELECT * FROM table").is_error()); + EXPECT(parse("(").is_error()); + EXPECT(parse("(SELECT").is_error()); + EXPECT(parse("(SELECT)").is_error()); + EXPECT(parse("(SELECT * FROM table").is_error()); + + auto validate = [](StringView sql, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::ExistsExpression>(*expression)); + + const auto& exists = static_cast<const SQL::ExistsExpression&>(*expression); + EXPECT_EQ(exists.invert_expression(), expected_invert_expression); + }; + + validate("EXISTS (SELECT * FROM table)", false); + validate("NOT EXISTS (SELECT * FROM table)", true); + validate("(SELECT * FROM table)", false); +} + TEST_CASE(collate_expression) { EXPECT(parse("COLLATE").is_error()); |