summaryrefslogtreecommitdiff
path: root/Userland/Libraries/LibSQL
diff options
context:
space:
mode:
Diffstat (limited to 'Userland/Libraries/LibSQL')
-rw-r--r--Userland/Libraries/LibSQL/AST.h16
-rw-r--r--Userland/Libraries/LibSQL/Parser.cpp35
-rw-r--r--Userland/Libraries/LibSQL/Parser.h1
-rw-r--r--Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp33
4 files changed, 80 insertions, 5 deletions
diff --git a/Userland/Libraries/LibSQL/AST.h b/Userland/Libraries/LibSQL/AST.h
index 922432b48a..96f0e970dc 100644
--- a/Userland/Libraries/LibSQL/AST.h
+++ b/Userland/Libraries/LibSQL/AST.h
@@ -537,6 +537,22 @@ private:
RefPtr<Expression> m_else_expression;
};
+class ExistsExpression : public Expression {
+public:
+ ExistsExpression(NonnullRefPtr<Select> select_statement, bool invert_expression)
+ : m_select_statement(move(select_statement))
+ , m_invert_expression(invert_expression)
+ {
+ }
+
+ const NonnullRefPtr<Select>& select_statement() const { return m_select_statement; }
+ bool invert_expression() const { return m_invert_expression; }
+
+private:
+ NonnullRefPtr<Select> m_select_statement;
+ bool m_invert_expression;
+};
+
class CollateExpression : public NestedExpression {
public:
CollateExpression(NonnullRefPtr<Expression> expression, String collation_name)
diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp
index 69b30e37ad..8775c2d3b3 100644
--- a/Userland/Libraries/LibSQL/Parser.cpp
+++ b/Userland/Libraries/LibSQL/Parser.cpp
@@ -215,7 +215,6 @@ NonnullRefPtr<Expression> Parser::parse_expression()
// FIXME: Parse 'bind-parameter'.
// FIXME: Parse 'function-name'.
- // FIXME: Parse 'exists'.
// FIXME: Parse 'raise-function'.
return expression;
@@ -241,6 +240,9 @@ NonnullRefPtr<Expression> Parser::parse_primary_expression()
if (auto expression = parse_case_expression(); expression.has_value())
return move(expression.value());
+ if (auto expression = parse_exists_expression(false); expression.has_value())
+ return move(expression.value());
+
expected("Primary Expression");
consume();
@@ -381,8 +383,12 @@ Optional<NonnullRefPtr<Expression>> Parser::parse_unary_operator_expression()
if (consume_if(TokenType::Tilde))
return create_ast_node<UnaryOperatorExpression>(UnaryOperator::BitwiseNot, parse_expression());
- if (consume_if(TokenType::Not))
- return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Not, parse_expression());
+ if (consume_if(TokenType::Not)) {
+ if (match(TokenType::Exists))
+ return parse_exists_expression(true);
+ else
+ return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Not, parse_expression());
+ }
return {};
}
@@ -448,11 +454,15 @@ Optional<NonnullRefPtr<Expression>> Parser::parse_binary_operator_expression(Non
Optional<NonnullRefPtr<Expression>> Parser::parse_chained_expression()
{
- if (!match(TokenType::ParenOpen))
+ if (!consume_if(TokenType::ParenOpen))
return {};
+ if (match(TokenType::Select))
+ return parse_exists_expression(false, TokenType::Select);
+
NonnullRefPtrVector<Expression> expressions;
- parse_comma_separated_list(true, [&]() { expressions.append(parse_expression()); });
+ parse_comma_separated_list(false, [&]() { expressions.append(parse_expression()); });
+ consume(TokenType::ParenClose);
return create_ast_node<ChainedExpression>(move(expressions));
}
@@ -506,6 +516,21 @@ Optional<NonnullRefPtr<Expression>> Parser::parse_case_expression()
return create_ast_node<CaseExpression>(move(case_expression), move(when_then_clauses), move(else_expression));
}
+Optional<NonnullRefPtr<Expression>> Parser::parse_exists_expression(bool invert_expression, TokenType opening_token)
+{
+ VERIFY((opening_token == TokenType::Exists) || (opening_token == TokenType::Select));
+
+ if ((opening_token == TokenType::Exists) && !consume_if(TokenType::Exists))
+ return {};
+
+ if (opening_token == TokenType::Exists)
+ consume(TokenType::ParenOpen);
+ auto select_statement = parse_select_statement({});
+ consume(TokenType::ParenClose);
+
+ return create_ast_node<ExistsExpression>(move(select_statement), invert_expression);
+}
+
Optional<NonnullRefPtr<Expression>> Parser::parse_collate_expression(NonnullRefPtr<Expression> expression)
{
if (!match(TokenType::Collate))
diff --git a/Userland/Libraries/LibSQL/Parser.h b/Userland/Libraries/LibSQL/Parser.h
index 6d6cf91591..14b03bbf52 100644
--- a/Userland/Libraries/LibSQL/Parser.h
+++ b/Userland/Libraries/LibSQL/Parser.h
@@ -68,6 +68,7 @@ private:
Optional<NonnullRefPtr<Expression>> parse_chained_expression();
Optional<NonnullRefPtr<Expression>> parse_cast_expression();
Optional<NonnullRefPtr<Expression>> parse_case_expression();
+ Optional<NonnullRefPtr<Expression>> parse_exists_expression(bool invert_expression, TokenType opening_token = TokenType::Exists);
Optional<NonnullRefPtr<Expression>> parse_collate_expression(NonnullRefPtr<Expression> expression);
Optional<NonnullRefPtr<Expression>> parse_is_expression(NonnullRefPtr<Expression> expression);
Optional<NonnullRefPtr<Expression>> parse_match_expression(NonnullRefPtr<Expression> lhs, bool invert_expression);
diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp
index 4d8952dc14..1e7a2ef187 100644
--- a/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp
+++ b/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp
@@ -338,6 +338,39 @@ TEST_CASE(case_expression)
validate("CASE 15 WHEN 16 THEN 17 WHEN 18 THEN 19 ELSE 20 END", true, 2, true);
}
+TEST_CASE(exists_expression)
+{
+ EXPECT(parse("EXISTS").is_error());
+ EXPECT(parse("EXISTS (").is_error());
+ EXPECT(parse("EXISTS (SELECT").is_error());
+ EXPECT(parse("EXISTS (SELECT)").is_error());
+ EXPECT(parse("EXISTS (SELECT * FROM table").is_error());
+ EXPECT(parse("NOT EXISTS").is_error());
+ EXPECT(parse("NOT EXISTS (").is_error());
+ EXPECT(parse("NOT EXISTS (SELECT").is_error());
+ EXPECT(parse("NOT EXISTS (SELECT)").is_error());
+ EXPECT(parse("NOT EXISTS (SELECT * FROM table").is_error());
+ EXPECT(parse("(").is_error());
+ EXPECT(parse("(SELECT").is_error());
+ EXPECT(parse("(SELECT)").is_error());
+ EXPECT(parse("(SELECT * FROM table").is_error());
+
+ auto validate = [](StringView sql, bool expected_invert_expression) {
+ auto result = parse(sql);
+ EXPECT(!result.is_error());
+
+ auto expression = result.release_value();
+ EXPECT(is<SQL::ExistsExpression>(*expression));
+
+ const auto& exists = static_cast<const SQL::ExistsExpression&>(*expression);
+ EXPECT_EQ(exists.invert_expression(), expected_invert_expression);
+ };
+
+ validate("EXISTS (SELECT * FROM table)", false);
+ validate("NOT EXISTS (SELECT * FROM table)", true);
+ validate("(SELECT * FROM table)", false);
+}
+
TEST_CASE(collate_expression)
{
EXPECT(parse("COLLATE").is_error());