diff options
author | Timothy Flynn <trflynn89@pm.me> | 2021-04-20 17:49:26 -0400 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2021-04-21 21:37:55 +0200 |
commit | ce6c7ae18ac905e1ac5017f05e4c27345d323d68 (patch) | |
tree | 1fdb15fd1a6d512a681a7203261cdf064172041a /Userland/Libraries/LibSQL | |
parent | 8c8d611fb3bce515dcda9d1c4b82bdce28d2b12e (diff) | |
download | serenity-ce6c7ae18ac905e1ac5017f05e4c27345d323d68.zip |
LibSQL: Parse most language expressions
https://sqlite.org/lang_expr.html
The entry point to using expressions, parse_expression(), is not used
by SQL::Parser in this commit. But there's so much here that it's easier
to grok as its own commit.
Diffstat (limited to 'Userland/Libraries/LibSQL')
-rw-r--r-- | Userland/Libraries/LibSQL/AST.h | 361 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Forward.h | 23 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.cpp | 444 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Parser.h | 21 | ||||
-rw-r--r-- | Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp | 570 |
5 files changed, 1414 insertions, 5 deletions
diff --git a/Userland/Libraries/LibSQL/AST.h b/Userland/Libraries/LibSQL/AST.h index 74d2619eb7..e90fcb7777 100644 --- a/Userland/Libraries/LibSQL/AST.h +++ b/Userland/Libraries/LibSQL/AST.h @@ -29,7 +29,10 @@ #include <AK/NonnullRefPtr.h> #include <AK/NonnullRefPtrVector.h> #include <AK/RefCounted.h> +#include <AK/RefPtr.h> #include <AK/String.h> +#include <LibSQL/Forward.h> +#include <LibSQL/Token.h> namespace SQL { @@ -48,11 +51,9 @@ protected: ASTNode() = default; }; -class Statement : public ASTNode { -}; - -class ErrorStatement final : public Statement { -}; +//================================================================================================== +// Language types +//================================================================================================== class SignedNumber final : public ASTNode { public: @@ -100,6 +101,356 @@ private: NonnullRefPtr<TypeName> m_type_name; }; +//================================================================================================== +// Expressions +//================================================================================================== + +class Expression : public ASTNode { +}; + +class ErrorExpression final : public Expression { +}; + +class NumericLiteral : public Expression { +public: + explicit NumericLiteral(double value) + : m_value(value) + { + } + + double value() const { return m_value; } + +private: + double m_value; +}; + +class StringLiteral : public Expression { +public: + explicit StringLiteral(String value) + : m_value(move(value)) + { + } + + const String& value() const { return m_value; } + +private: + String m_value; +}; + +class BlobLiteral : public Expression { +public: + explicit BlobLiteral(String value) + : m_value(move(value)) + { + } + + const String& value() const { return m_value; } + +private: + String m_value; +}; + +class NullLiteral : public Expression { +}; + +class NestedExpression : public Expression { +public: + const NonnullRefPtr<Expression>& expression() const { return m_expression; } + +protected: + explicit NestedExpression(NonnullRefPtr<Expression> expression) + : m_expression(move(expression)) + { + } + +private: + NonnullRefPtr<Expression> m_expression; +}; + +class NestedDoubleExpression : public Expression { +public: + const NonnullRefPtr<Expression>& lhs() const { return m_lhs; } + const NonnullRefPtr<Expression>& rhs() const { return m_rhs; } + +protected: + NestedDoubleExpression(NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs) + : m_lhs(move(lhs)) + , m_rhs(move(rhs)) + { + } + +private: + NonnullRefPtr<Expression> m_lhs; + NonnullRefPtr<Expression> m_rhs; +}; + +class InvertibleNestedExpression : public NestedExpression { +public: + bool invert_expression() const { return m_invert_expression; } + +protected: + InvertibleNestedExpression(NonnullRefPtr<Expression> expression, bool invert_expression) + : NestedExpression(move(expression)) + , m_invert_expression(invert_expression) + { + } + +private: + bool m_invert_expression; +}; + +class InvertibleNestedDoubleExpression : public NestedDoubleExpression { +public: + bool invert_expression() const { return m_invert_expression; } + +protected: + InvertibleNestedDoubleExpression(NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs, bool invert_expression) + : NestedDoubleExpression(move(lhs), move(rhs)) + , m_invert_expression(invert_expression) + { + } + +private: + bool m_invert_expression; +}; + +class ColumnNameExpression : public Expression { +public: + ColumnNameExpression(String schema_name, String table_name, String column_name) + : m_schema_name(move(schema_name)) + , m_table_name(move(table_name)) + , m_column_name(move(column_name)) + { + } + + const String& schema_name() const { return m_schema_name; } + const String& table_name() const { return m_table_name; } + const String& column_name() const { return m_column_name; } + +private: + String m_schema_name; + String m_table_name; + String m_column_name; +}; + +enum class UnaryOperator { + Minus, + Plus, + BitwiseNot, + Not, +}; + +class UnaryOperatorExpression : public NestedExpression { +public: + UnaryOperatorExpression(UnaryOperator type, NonnullRefPtr<Expression> expression) + : NestedExpression(move(expression)) + , m_type(type) + { + } + + UnaryOperator type() const { return m_type; } + +private: + UnaryOperator m_type; +}; + +enum class BinaryOperator { + // Note: These are in order of highest-to-lowest operator precedence. + Concatenate, + Multiplication, + Division, + Modulo, + Plus, + Minus, + ShiftLeft, + ShiftRight, + BitwiseAnd, + BitwiseOr, + LessThan, + LessThanEquals, + GreaterThan, + GreaterThanEquals, + Equals, + NotEquals, + And, + Or, +}; + +class BinaryOperatorExpression : public NestedDoubleExpression { +public: + BinaryOperatorExpression(BinaryOperator type, NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs) + : NestedDoubleExpression(move(lhs), move(rhs)) + , m_type(type) + { + } + + BinaryOperator type() const { return m_type; } + +private: + BinaryOperator m_type; +}; + +class ChainedExpression : public Expression { +public: + explicit ChainedExpression(NonnullRefPtrVector<Expression> expressions) + : m_expressions(move(expressions)) + { + } + + const NonnullRefPtrVector<Expression>& expressions() const { return m_expressions; } + +private: + NonnullRefPtrVector<Expression> m_expressions; +}; + +class CastExpression : public NestedExpression { +public: + CastExpression(NonnullRefPtr<Expression> expression, NonnullRefPtr<TypeName> type_name) + : NestedExpression(move(expression)) + , m_type_name(move(type_name)) + { + } + + const NonnullRefPtr<TypeName>& type_name() const { return m_type_name; } + +private: + NonnullRefPtr<TypeName> m_type_name; +}; + +class CaseExpression : public Expression { +public: + struct WhenThenClause { + NonnullRefPtr<Expression> when; + NonnullRefPtr<Expression> then; + }; + + CaseExpression(RefPtr<Expression> case_expression, Vector<WhenThenClause> when_then_clauses, RefPtr<Expression> else_expression) + : m_case_expression(case_expression) + , m_when_then_clauses(when_then_clauses) + , m_else_expression(else_expression) + { + VERIFY(!m_when_then_clauses.is_empty()); + } + + const RefPtr<Expression>& case_expression() const { return m_case_expression; } + const Vector<WhenThenClause>& when_then_clauses() const { return m_when_then_clauses; } + const RefPtr<Expression>& else_expression() const { return m_else_expression; } + +private: + RefPtr<Expression> m_case_expression; + Vector<WhenThenClause> m_when_then_clauses; + RefPtr<Expression> m_else_expression; +}; + +class CollateExpression : public NestedExpression { +public: + CollateExpression(NonnullRefPtr<Expression> expression, String collation_name) + : NestedExpression(move(expression)) + , m_collation_name(move(collation_name)) + { + } + + const String& collation_name() const { return m_collation_name; } + +private: + String m_collation_name; +}; + +enum class MatchOperator { + Like, + Glob, + Match, + Regexp, +}; + +class MatchExpression : public InvertibleNestedDoubleExpression { +public: + MatchExpression(MatchOperator type, NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs, RefPtr<Expression> escape, bool invert_expression) + : InvertibleNestedDoubleExpression(move(lhs), move(rhs), invert_expression) + , m_type(type) + , m_escape(move(escape)) + { + } + + MatchOperator type() const { return m_type; } + const RefPtr<Expression>& escape() const { return m_escape; } + +private: + MatchOperator m_type; + RefPtr<Expression> m_escape; +}; + +class NullExpression : public InvertibleNestedExpression { +public: + NullExpression(NonnullRefPtr<Expression> expression, bool invert_expression) + : InvertibleNestedExpression(move(expression), invert_expression) + { + } +}; + +class IsExpression : public InvertibleNestedDoubleExpression { +public: + IsExpression(NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs, bool invert_expression) + : InvertibleNestedDoubleExpression(move(lhs), move(rhs), invert_expression) + { + } +}; + +class BetweenExpression : public InvertibleNestedDoubleExpression { +public: + BetweenExpression(NonnullRefPtr<Expression> expression, NonnullRefPtr<Expression> lhs, NonnullRefPtr<Expression> rhs, bool invert_expression) + : InvertibleNestedDoubleExpression(move(lhs), move(rhs), invert_expression) + , m_expression(move(expression)) + { + } + + const NonnullRefPtr<Expression>& expression() const { return m_expression; } + +private: + NonnullRefPtr<Expression> m_expression; +}; + +class InChainedExpression : public InvertibleNestedExpression { +public: + InChainedExpression(NonnullRefPtr<Expression> expression, NonnullRefPtr<ChainedExpression> expression_chain, bool invert_expression) + : InvertibleNestedExpression(move(expression), invert_expression) + , m_expression_chain(move(expression_chain)) + { + } + + const NonnullRefPtr<ChainedExpression>& expression_chain() const { return m_expression_chain; } + +private: + NonnullRefPtr<ChainedExpression> m_expression_chain; +}; + +class InTableExpression : public InvertibleNestedExpression { +public: + InTableExpression(NonnullRefPtr<Expression> expression, String schema_name, String table_name, bool invert_expression) + : InvertibleNestedExpression(move(expression), invert_expression) + , m_schema_name(move(schema_name)) + , m_table_name(move(table_name)) + { + } + + const String& schema_name() const { return m_schema_name; } + const String& table_name() const { return m_table_name; } + +private: + String m_schema_name; + String m_table_name; +}; + +//================================================================================================== +// Statements +//================================================================================================== + +class Statement : public ASTNode { +}; + +class ErrorStatement final : public Statement { +}; + class CreateTable : public Statement { public: CreateTable(String schema_name, String table_name, NonnullRefPtrVector<ColumnDefinition> columns, bool is_temporary, bool is_error_if_table_exists) diff --git a/Userland/Libraries/LibSQL/Forward.h b/Userland/Libraries/LibSQL/Forward.h index b27e83a158..0430a97bb5 100644 --- a/Userland/Libraries/LibSQL/Forward.h +++ b/Userland/Libraries/LibSQL/Forward.h @@ -28,14 +28,37 @@ namespace SQL { class ASTNode; +class BetweenExpression; +class BinaryOperatorExpression; +class BlobLiteral; +class CaseExpression; +class CastExpression; +class ChainedExpression; +class CollateExpression; class ColumnDefinition; +class ColumnNameExpression; class CreateTable; class DropTable; +class ErrorExpression; class ErrorStatement; +class Expression; +class InChainedExpression; +class InTableExpression; +class InvertibleNestedDoubleExpression; +class InvertibleNestedExpression; +class IsExpression; class Lexer; +class MatchExpression; +class NestedDoubleExpression; +class NestedExpression; +class NullExpression; +class NullLiteral; +class NumericLiteral; class Parser; class SignedNumber; class Statement; +class StringLiteral; class Token; class TypeName; +class UnaryOperatorExpression; } diff --git a/Userland/Libraries/LibSQL/Parser.cpp b/Userland/Libraries/LibSQL/Parser.cpp index fb25d1cddd..e35398db15 100644 --- a/Userland/Libraries/LibSQL/Parser.cpp +++ b/Userland/Libraries/LibSQL/Parser.cpp @@ -25,6 +25,7 @@ */ #include "Parser.h" +#include <AK/TypeCasts.h> namespace SQL { @@ -124,6 +125,449 @@ NonnullRefPtr<DropTable> Parser::parse_drop_table_statement() return create_ast_node<DropTable>(move(schema_name), move(table_name), is_error_if_table_does_not_exist); } +NonnullRefPtr<Expression> Parser::parse_expression() +{ + // https://sqlite.org/lang_expr.html + auto expression = parse_primary_expression(); + + if (match_secondary_expression()) + expression = parse_secondary_expression(move(expression)); + + // FIXME: Parse 'bind-parameter'. + // FIXME: Parse 'function-name'. + // FIXME: Parse 'exists'. + // FIXME: Parse 'raise-function'. + + return expression; +} + +NonnullRefPtr<Expression> Parser::parse_primary_expression() +{ + if (auto expression = parse_literal_value_expression(); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_column_name_expression(); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_unary_operator_expression(); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_chained_expression(); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_cast_expression(); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_case_expression(); expression.has_value()) + return move(expression.value()); + + expected("Primary Expression"); + consume(); + + return create_ast_node<ErrorExpression>(); +} + +NonnullRefPtr<Expression> Parser::parse_secondary_expression(NonnullRefPtr<Expression> primary) +{ + if (auto expression = parse_binary_operator_expression(primary); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_collate_expression(primary); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_is_expression(primary); expression.has_value()) + return move(expression.value()); + + bool invert_expression = false; + if (consume_if(TokenType::Not)) + invert_expression = true; + + if (auto expression = parse_match_expression(primary, invert_expression); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_null_expression(primary, invert_expression); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_between_expression(primary, invert_expression); expression.has_value()) + return move(expression.value()); + + if (auto expression = parse_in_expression(primary, invert_expression); expression.has_value()) + return move(expression.value()); + + expected("Secondary Expression"); + consume(); + + return create_ast_node<ErrorExpression>(); +} + +bool Parser::match_secondary_expression() const +{ + return match(TokenType::Not) + || match(TokenType::DoublePipe) + || match(TokenType::Asterisk) + || match(TokenType::Divide) + || match(TokenType::Modulus) + || match(TokenType::Plus) + || match(TokenType::Minus) + || match(TokenType::ShiftLeft) + || match(TokenType::ShiftRight) + || match(TokenType::Ampersand) + || match(TokenType::Pipe) + || match(TokenType::LessThan) + || match(TokenType::LessThanEquals) + || match(TokenType::GreaterThan) + || match(TokenType::GreaterThanEquals) + || match(TokenType::Equals) + || match(TokenType::EqualsEquals) + || match(TokenType::NotEquals1) + || match(TokenType::NotEquals2) + || match(TokenType::And) + || match(TokenType::Or) + || match(TokenType::Collate) + || match(TokenType::Is) + || match(TokenType::Like) + || match(TokenType::Glob) + || match(TokenType::Match) + || match(TokenType::Regexp) + || match(TokenType::Isnull) + || match(TokenType::Notnull) + || match(TokenType::Between) + || match(TokenType::In); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_literal_value_expression() +{ + if (match(TokenType::NumericLiteral)) { + auto value = consume().double_value(); + return create_ast_node<NumericLiteral>(value); + } + if (match(TokenType::StringLiteral)) { + // TODO: Should the surrounding ' ' be removed here? + auto value = consume().value(); + return create_ast_node<StringLiteral>(value); + } + if (match(TokenType::BlobLiteral)) { + // TODO: Should the surrounding x' ' be removed here? + auto value = consume().value(); + return create_ast_node<BlobLiteral>(value); + } + if (consume_if(TokenType::Null)) + return create_ast_node<NullLiteral>(); + + return {}; +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_column_name_expression() +{ + if (!match(TokenType::Identifier)) + return {}; + + String first_identifier = consume(TokenType::Identifier).value(); + String schema_name; + String table_name; + String column_name; + + if (consume_if(TokenType::Period)) { + String second_identifier = consume(TokenType::Identifier).value(); + + if (consume_if(TokenType::Period)) { + schema_name = move(first_identifier); + table_name = move(second_identifier); + column_name = consume(TokenType::Identifier).value(); + } else { + table_name = move(first_identifier); + column_name = move(second_identifier); + } + } else { + column_name = move(first_identifier); + } + + return create_ast_node<ColumnNameExpression>(move(schema_name), move(table_name), move(column_name)); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_unary_operator_expression() +{ + if (consume_if(TokenType::Minus)) + return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Minus, parse_expression()); + + if (consume_if(TokenType::Plus)) + return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Plus, parse_expression()); + + if (consume_if(TokenType::Tilde)) + return create_ast_node<UnaryOperatorExpression>(UnaryOperator::BitwiseNot, parse_expression()); + + if (consume_if(TokenType::Not)) + return create_ast_node<UnaryOperatorExpression>(UnaryOperator::Not, parse_expression()); + + return {}; +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_binary_operator_expression(NonnullRefPtr<Expression> lhs) +{ + if (consume_if(TokenType::DoublePipe)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Concatenate, move(lhs), parse_expression()); + + if (consume_if(TokenType::Asterisk)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Multiplication, move(lhs), parse_expression()); + + if (consume_if(TokenType::Divide)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Division, move(lhs), parse_expression()); + + if (consume_if(TokenType::Modulus)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Modulo, move(lhs), parse_expression()); + + if (consume_if(TokenType::Plus)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Plus, move(lhs), parse_expression()); + + if (consume_if(TokenType::Minus)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Minus, move(lhs), parse_expression()); + + if (consume_if(TokenType::ShiftLeft)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::ShiftLeft, move(lhs), parse_expression()); + + if (consume_if(TokenType::ShiftRight)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::ShiftRight, move(lhs), parse_expression()); + + if (consume_if(TokenType::Ampersand)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::BitwiseAnd, move(lhs), parse_expression()); + + if (consume_if(TokenType::Pipe)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::BitwiseOr, move(lhs), parse_expression()); + + if (consume_if(TokenType::LessThan)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::LessThan, move(lhs), parse_expression()); + + if (consume_if(TokenType::LessThanEquals)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::LessThanEquals, move(lhs), parse_expression()); + + if (consume_if(TokenType::GreaterThan)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::GreaterThan, move(lhs), parse_expression()); + + if (consume_if(TokenType::GreaterThanEquals)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::GreaterThanEquals, move(lhs), parse_expression()); + + if (consume_if(TokenType::Equals) || consume_if(TokenType::EqualsEquals)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Equals, move(lhs), parse_expression()); + + if (consume_if(TokenType::NotEquals1) || consume_if(TokenType::NotEquals2)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::NotEquals, move(lhs), parse_expression()); + + if (consume_if(TokenType::And)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::And, move(lhs), parse_expression()); + + if (consume_if(TokenType::Or)) + return create_ast_node<BinaryOperatorExpression>(BinaryOperator::Or, move(lhs), parse_expression()); + + return {}; +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_chained_expression() +{ + if (!match(TokenType::ParenOpen)) + return {}; + + NonnullRefPtrVector<Expression> expressions; + consume(TokenType::ParenOpen); + + do { + expressions.append(parse_expression()); + if (match(TokenType::ParenClose)) + break; + + consume(TokenType::Comma); + } while (!match(TokenType::Eof)); + + consume(TokenType::ParenClose); + + return create_ast_node<ChainedExpression>(move(expressions)); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_cast_expression() +{ + if (!match(TokenType::Cast)) + return {}; + + consume(TokenType::Cast); + consume(TokenType::ParenOpen); + auto expression = parse_expression(); + consume(TokenType::As); + auto type_name = parse_type_name(); + consume(TokenType::ParenClose); + + return create_ast_node<CastExpression>(move(expression), move(type_name)); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_case_expression() +{ + if (!match(TokenType::Case)) + return {}; + + consume(); + + RefPtr<Expression> case_expression; + if (!match(TokenType::When)) { + case_expression = parse_expression(); + } + + Vector<CaseExpression::WhenThenClause> when_then_clauses; + + do { + consume(TokenType::When); + auto when = parse_expression(); + consume(TokenType::Then); + auto then = parse_expression(); + + when_then_clauses.append({ move(when), move(then) }); + + if (!match(TokenType::When)) + break; + } while (!match(TokenType::Eof)); + + RefPtr<Expression> else_expression; + if (consume_if(TokenType::Else)) + else_expression = parse_expression(); + + consume(TokenType::End); + return create_ast_node<CaseExpression>(move(case_expression), move(when_then_clauses), move(else_expression)); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_collate_expression(NonnullRefPtr<Expression> expression) +{ + if (!match(TokenType::Collate)) + return {}; + + consume(); + String collation_name = consume(TokenType::Identifier).value(); + + return create_ast_node<CollateExpression>(move(expression), move(collation_name)); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_is_expression(NonnullRefPtr<Expression> expression) +{ + if (!match(TokenType::Is)) + return {}; + + consume(); + + bool invert_expression = false; + if (match(TokenType::Not)) { + consume(); + invert_expression = true; + } + + auto rhs = parse_expression(); + return create_ast_node<IsExpression>(move(expression), move(rhs), invert_expression); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_match_expression(NonnullRefPtr<Expression> lhs, bool invert_expression) +{ + auto parse_escape = [this]() { + RefPtr<Expression> escape; + if (consume_if(TokenType::Escape)) + escape = parse_expression(); + return escape; + }; + + if (consume_if(TokenType::Like)) + return create_ast_node<MatchExpression>(MatchOperator::Like, move(lhs), parse_expression(), parse_escape(), invert_expression); + + if (consume_if(TokenType::Glob)) + return create_ast_node<MatchExpression>(MatchOperator::Glob, move(lhs), parse_expression(), parse_escape(), invert_expression); + + if (consume_if(TokenType::Match)) + return create_ast_node<MatchExpression>(MatchOperator::Match, move(lhs), parse_expression(), parse_escape(), invert_expression); + + if (consume_if(TokenType::Regexp)) + return create_ast_node<MatchExpression>(MatchOperator::Regexp, move(lhs), parse_expression(), parse_escape(), invert_expression); + + return {}; +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_null_expression(NonnullRefPtr<Expression> expression, bool invert_expression) +{ + if (!match(TokenType::Isnull) && !match(TokenType::Notnull) && !(invert_expression && match(TokenType::Null))) + return {}; + + auto type = consume().type(); + invert_expression |= (type == TokenType::Notnull); + + return create_ast_node<NullExpression>(move(expression), invert_expression); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_between_expression(NonnullRefPtr<Expression> expression, bool invert_expression) +{ + if (!match(TokenType::Between)) + return {}; + + consume(); + + auto nested = parse_expression(); + if (!is<BinaryOperatorExpression>(*nested)) { + expected("Binary Expression"); + return create_ast_node<ErrorExpression>(); + } + + const auto& binary_expression = static_cast<const BinaryOperatorExpression&>(*nested); + if (binary_expression.type() != BinaryOperator::And) { + expected("AND Expression"); + return create_ast_node<ErrorExpression>(); + } + + return create_ast_node<BetweenExpression>(move(expression), binary_expression.lhs(), binary_expression.rhs(), invert_expression); +} + +Optional<NonnullRefPtr<Expression>> Parser::parse_in_expression(NonnullRefPtr<Expression> expression, bool invert_expression) +{ + if (!match(TokenType::In)) + return {}; + + consume(); + + if (consume_if(TokenType::ParenOpen)) { + if (match(TokenType::Select)) { + // FIXME: Parse "select-stmt". + return {}; + } + + // FIXME: Consolidate this with parse_chained_expression(). That method consumes the opening paren as + // well, and also requires at least one expression (whereas this allows for an empty chain). + NonnullRefPtrVector<Expression> expressions; + + if (!match(TokenType::ParenClose)) { + do { + expressions.append(parse_expression()); + if (match(TokenType::ParenClose)) + break; + + consume(TokenType::Comma); + } while (!match(TokenType::Eof)); + } + + consume(TokenType::ParenClose); + + auto chain = create_ast_node<ChainedExpression>(move(expressions)); + return create_ast_node<InChainedExpression>(move(expression), move(chain), invert_expression); + } + + String schema_or_table_name = consume(TokenType::Identifier).value(); + String schema_name; + String table_name; + + if (consume_if(TokenType::Period)) { + schema_name = move(schema_or_table_name); + table_name = consume(TokenType::Identifier).value(); + } else { + table_name = move(schema_or_table_name); + } + + if (match(TokenType::ParenOpen)) { + // FIXME: Parse "table-function". + return {}; + } + + return create_ast_node<InTableExpression>(move(expression), move(schema_name), move(table_name), invert_expression); +} + NonnullRefPtr<ColumnDefinition> Parser::parse_column_definition() { // https://sqlite.org/syntax/column-def.html diff --git a/Userland/Libraries/LibSQL/Parser.h b/Userland/Libraries/LibSQL/Parser.h index 26fb7d2a36..9b0ea5d2ab 100644 --- a/Userland/Libraries/LibSQL/Parser.h +++ b/Userland/Libraries/LibSQL/Parser.h @@ -58,6 +58,9 @@ public: bool has_errors() const { return m_parser_state.m_errors.size(); } const Vector<Error>& errors() const { return m_parser_state.m_errors; } +protected: + NonnullRefPtr<Expression> parse_expression(); // Protected for unit testing. + private: struct ParserState { explicit ParserState(Lexer); @@ -69,6 +72,24 @@ private: NonnullRefPtr<CreateTable> parse_create_table_statement(); NonnullRefPtr<DropTable> parse_drop_table_statement(); + + NonnullRefPtr<Expression> parse_primary_expression(); + NonnullRefPtr<Expression> parse_secondary_expression(NonnullRefPtr<Expression> primary); + bool match_secondary_expression() const; + Optional<NonnullRefPtr<Expression>> parse_literal_value_expression(); + Optional<NonnullRefPtr<Expression>> parse_column_name_expression(); + Optional<NonnullRefPtr<Expression>> parse_unary_operator_expression(); + Optional<NonnullRefPtr<Expression>> parse_binary_operator_expression(NonnullRefPtr<Expression> lhs); + Optional<NonnullRefPtr<Expression>> parse_chained_expression(); + Optional<NonnullRefPtr<Expression>> parse_cast_expression(); + Optional<NonnullRefPtr<Expression>> parse_case_expression(); + Optional<NonnullRefPtr<Expression>> parse_collate_expression(NonnullRefPtr<Expression> expression); + Optional<NonnullRefPtr<Expression>> parse_is_expression(NonnullRefPtr<Expression> expression); + Optional<NonnullRefPtr<Expression>> parse_match_expression(NonnullRefPtr<Expression> lhs, bool invert_expression); + Optional<NonnullRefPtr<Expression>> parse_null_expression(NonnullRefPtr<Expression> expression, bool invert_expression); + Optional<NonnullRefPtr<Expression>> parse_between_expression(NonnullRefPtr<Expression> expression, bool invert_expression); + Optional<NonnullRefPtr<Expression>> parse_in_expression(NonnullRefPtr<Expression> expression, bool invert_expression); + NonnullRefPtr<ColumnDefinition> parse_column_definition(); NonnullRefPtr<TypeName> parse_type_name(); NonnullRefPtr<SignedNumber> parse_signed_number(); diff --git a/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp b/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp new file mode 100644 index 0000000000..da3f78974f --- /dev/null +++ b/Userland/Libraries/LibSQL/Tests/TestSqlExpressionParser.cpp @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2021, Tim Flynn <trflynn89@pm.me> + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include <AK/TestSuite.h> + +#include <AK/HashMap.h> +#include <AK/Result.h> +#include <AK/String.h> +#include <AK/StringBuilder.h> +#include <AK/StringView.h> +#include <AK/TypeCasts.h> +#include <LibSQL/Lexer.h> +#include <LibSQL/Parser.h> + +namespace { + +class ExpressionParser : public SQL::Parser { +public: + explicit ExpressionParser(SQL::Lexer lexer) + : SQL::Parser(move(lexer)) + { + } + + NonnullRefPtr<SQL::Expression> parse() + { + return SQL::Parser::parse_expression(); + } +}; + +using ParseResult = AK::Result<NonnullRefPtr<SQL::Expression>, String>; + +ParseResult parse(StringView sql) +{ + auto parser = ExpressionParser(SQL::Lexer(sql)); + auto expression = parser.parse(); + + if (parser.has_errors()) { + return parser.errors()[0].to_string(); + } + + return expression; +} + +} + +TEST_CASE(numeric_literal) +{ + auto validate = [](StringView sql, double expected_value) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::NumericLiteral>(*expression)); + + const auto& literal = static_cast<const SQL::NumericLiteral&>(*expression); + EXPECT_EQ(literal.value(), expected_value); + }; + + validate("123", 123); + validate("3.14", 3.14); + validate("0xff", 255); + validate("1e3", 1000); +} + +TEST_CASE(string_literal) +{ + EXPECT(parse("'").is_error()); + EXPECT(parse("'unterminated").is_error()); + + auto validate = [](StringView sql, StringView expected_value) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::StringLiteral>(*expression)); + + const auto& literal = static_cast<const SQL::StringLiteral&>(*expression); + EXPECT_EQ(literal.value(), expected_value); + }; + + validate("''", "''"); + validate("'hello friends'", "'hello friends'"); +} + +TEST_CASE(blob_literal) +{ + EXPECT(parse("x'").is_error()); + EXPECT(parse("x'unterminated").is_error()); + + auto validate = [](StringView sql, StringView expected_value) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::BlobLiteral>(*expression)); + + const auto& literal = static_cast<const SQL::BlobLiteral&>(*expression); + EXPECT_EQ(literal.value(), expected_value); + }; + + validate("x''", "x''"); + validate("x'hello friends'", "x'hello friends'"); +} + +TEST_CASE(null_literal) +{ + auto validate = [](StringView sql) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::NullLiteral>(*expression)); + }; + + validate("NULL"); +} + +TEST_CASE(column_name) +{ + EXPECT(parse(".column").is_error()); + EXPECT(parse("table.").is_error()); + EXPECT(parse("schema.table.").is_error()); + + auto validate = [](StringView sql, StringView expected_schema, StringView expected_table, StringView expected_column) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::ColumnNameExpression>(*expression)); + + const auto& column = static_cast<const SQL::ColumnNameExpression&>(*expression); + EXPECT_EQ(column.schema_name(), expected_schema); + EXPECT_EQ(column.table_name(), expected_table); + EXPECT_EQ(column.column_name(), expected_column); + }; + + validate("column", {}, {}, "column"); + validate("table.column", {}, "table", "column"); + validate("schema.table.column", "schema", "table", "column"); +} + +TEST_CASE(unary_operator) +{ + EXPECT(parse("-").is_error()); + EXPECT(parse("--").is_error()); + EXPECT(parse("+").is_error()); + EXPECT(parse("++").is_error()); + EXPECT(parse("~").is_error()); + EXPECT(parse("~~").is_error()); + EXPECT(parse("NOT").is_error()); + + auto validate = [](StringView sql, SQL::UnaryOperator expected_operator) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::UnaryOperatorExpression>(*expression)); + + const auto& unary = static_cast<const SQL::UnaryOperatorExpression&>(*expression); + EXPECT_EQ(unary.type(), expected_operator); + + const auto& secondary_expression = unary.expression(); + EXPECT(!is<SQL::ErrorExpression>(*secondary_expression)); + }; + + validate("-15", SQL::UnaryOperator::Minus); + validate("+15", SQL::UnaryOperator::Plus); + validate("~15", SQL::UnaryOperator::BitwiseNot); + validate("NOT 15", SQL::UnaryOperator::Not); +} + +TEST_CASE(binary_operator) +{ + HashMap<StringView, SQL::BinaryOperator> operators { + { "||", SQL::BinaryOperator::Concatenate }, + { "*", SQL::BinaryOperator::Multiplication }, + { "/", SQL::BinaryOperator::Division }, + { "%", SQL::BinaryOperator::Modulo }, + { "+", SQL::BinaryOperator::Plus }, + { "-", SQL::BinaryOperator::Minus }, + { "<<", SQL::BinaryOperator::ShiftLeft }, + { ">>", SQL::BinaryOperator::ShiftRight }, + { "&", SQL::BinaryOperator::BitwiseAnd }, + { "|", SQL::BinaryOperator::BitwiseOr }, + { "<", SQL::BinaryOperator::LessThan }, + { "<=", SQL::BinaryOperator::LessThanEquals }, + { ">", SQL::BinaryOperator::GreaterThan }, + { ">=", SQL::BinaryOperator::GreaterThanEquals }, + { "=", SQL::BinaryOperator::Equals }, + { "==", SQL::BinaryOperator::Equals }, + { "!=", SQL::BinaryOperator::NotEquals }, + { "<>", SQL::BinaryOperator::NotEquals }, + { "AND", SQL::BinaryOperator::And }, + { "OR", SQL::BinaryOperator::Or }, + }; + + for (auto op : operators) { + EXPECT(parse(op.key).is_error()); + + StringBuilder builder; + builder.append("1 "); + builder.append(op.key); + EXPECT(parse(builder.build()).is_error()); + + builder.clear(); + + if (op.key != "+" && op.key != "-") { // "+1" and "-1" are fine (unary operator). + builder.append(op.key); + builder.append(" 1"); + EXPECT(parse(builder.build()).is_error()); + } + } + + auto validate = [](StringView sql, SQL::BinaryOperator expected_operator) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::BinaryOperatorExpression>(*expression)); + + const auto& binary = static_cast<const SQL::BinaryOperatorExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*binary.lhs())); + EXPECT(!is<SQL::ErrorExpression>(*binary.rhs())); + EXPECT_EQ(binary.type(), expected_operator); + }; + + for (auto op : operators) { + StringBuilder builder; + builder.append("1 "); + builder.append(op.key); + builder.append(" 1"); + validate(builder.build(), op.value); + } +} + +TEST_CASE(chained_expression) +{ + EXPECT(parse("()").is_error()); + EXPECT(parse("(,)").is_error()); + EXPECT(parse("(15,)").is_error()); + + auto validate = [](StringView sql, size_t expected_chain_size) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::ChainedExpression>(*expression)); + + const auto& chain = static_cast<const SQL::ChainedExpression&>(*expression).expressions(); + EXPECT_EQ(chain.size(), expected_chain_size); + + for (const auto& chained_expression : chain) + EXPECT(!is<SQL::ErrorExpression>(chained_expression)); + }; + + validate("(15)", 1); + validate("(15, 16)", 2); + validate("(15, 16, column)", 3); +} + +TEST_CASE(cast_expression) +{ + EXPECT(parse("CAST").is_error()); + EXPECT(parse("CAST (").is_error()); + EXPECT(parse("CAST ()").is_error()); + EXPECT(parse("CAST (15)").is_error()); + EXPECT(parse("CAST (15 AS").is_error()); + EXPECT(parse("CAST (15 AS)").is_error()); + EXPECT(parse("CAST (15 AS int").is_error()); + + auto validate = [](StringView sql, StringView expected_type_name) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::CastExpression>(*expression)); + + const auto& cast = static_cast<const SQL::CastExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*cast.expression())); + + const auto& type_name = cast.type_name(); + EXPECT_EQ(type_name->name(), expected_type_name); + }; + + validate("CAST (15 AS int)", "int"); + validate("CAST ('NULL' AS null)", "null"); + validate("CAST (15 AS varchar(255))", "varchar"); +} + +TEST_CASE(case_expression) +{ + EXPECT(parse("CASE").is_error()); + EXPECT(parse("CASE END").is_error()); + EXPECT(parse("CASE 15").is_error()); + EXPECT(parse("CASE 15 END").is_error()); + EXPECT(parse("CASE WHEN").is_error()); + EXPECT(parse("CASE WHEN THEN").is_error()); + EXPECT(parse("CASE WHEN 15 THEN 16").is_error()); + EXPECT(parse("CASE WHEN 15 THEN 16 ELSE").is_error()); + EXPECT(parse("CASE WHEN 15 THEN 16 ELSE END").is_error()); + + auto validate = [](StringView sql, bool expect_case_expression, size_t expected_when_then_size, bool expect_else_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::CaseExpression>(*expression)); + + const auto& case_ = static_cast<const SQL::CaseExpression&>(*expression); + + const auto& case_expression = case_.case_expression(); + EXPECT_EQ(case_expression.is_null(), !expect_case_expression); + if (case_expression) + EXPECT(!is<SQL::ErrorExpression>(*case_expression)); + + const auto& when_then_clauses = case_.when_then_clauses(); + EXPECT_EQ(when_then_clauses.size(), expected_when_then_size); + for (const auto& when_then_clause : when_then_clauses) { + EXPECT(!is<SQL::ErrorExpression>(*when_then_clause.when)); + EXPECT(!is<SQL::ErrorExpression>(*when_then_clause.then)); + } + + const auto& else_expression = case_.else_expression(); + EXPECT_EQ(else_expression.is_null(), !expect_else_expression); + if (else_expression) + EXPECT(!is<SQL::ErrorExpression>(*else_expression)); + }; + + validate("CASE WHEN 16 THEN 17 END", false, 1, false); + validate("CASE WHEN 16 THEN 17 WHEN 18 THEN 19 END", false, 2, false); + validate("CASE WHEN 16 THEN 17 WHEN 18 THEN 19 ELSE 20 END", false, 2, true); + + validate("CASE 15 WHEN 16 THEN 17 END", true, 1, false); + validate("CASE 15 WHEN 16 THEN 17 WHEN 18 THEN 19 END", true, 2, false); + validate("CASE 15 WHEN 16 THEN 17 WHEN 18 THEN 19 ELSE 20 END", true, 2, true); +} + +TEST_CASE(collate_expression) +{ + EXPECT(parse("COLLATE").is_error()); + EXPECT(parse("COLLATE name").is_error()); + EXPECT(parse("15 COLLATE").is_error()); + + auto validate = [](StringView sql, StringView expected_collation_name) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::CollateExpression>(*expression)); + + const auto& collate = static_cast<const SQL::CollateExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*collate.expression())); + EXPECT_EQ(collate.collation_name(), expected_collation_name); + }; + + validate("15 COLLATE fifteen", "fifteen"); + validate("(15, 16) COLLATE chain", "chain"); +} + +TEST_CASE(is_expression) +{ + EXPECT(parse("IS").is_error()); + EXPECT(parse("IS 1").is_error()); + EXPECT(parse("1 IS").is_error()); + EXPECT(parse("IS NOT").is_error()); + EXPECT(parse("IS NOT 1").is_error()); + EXPECT(parse("1 IS NOT").is_error()); + + auto validate = [](StringView sql, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::IsExpression>(*expression)); + + const auto& is_ = static_cast<const SQL::IsExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*is_.lhs())); + EXPECT(!is<SQL::ErrorExpression>(*is_.rhs())); + EXPECT_EQ(is_.invert_expression(), expected_invert_expression); + }; + + validate("1 IS NULL", false); + validate("1 IS NOT NULL", true); +} + +TEST_CASE(match_expression) +{ + HashMap<StringView, SQL::MatchOperator> operators { + { "LIKE", SQL::MatchOperator::Like }, + { "GLOB", SQL::MatchOperator::Glob }, + { "MATCH", SQL::MatchOperator::Match }, + { "REGEXP", SQL::MatchOperator::Regexp }, + }; + + for (auto op : operators) { + EXPECT(parse(op.key).is_error()); + + StringBuilder builder; + builder.append("1 "); + builder.append(op.key); + EXPECT(parse(builder.build()).is_error()); + + builder.clear(); + builder.append(op.key); + builder.append(" 1"); + EXPECT(parse(builder.build()).is_error()); + } + + auto validate = [](StringView sql, SQL::MatchOperator expected_operator, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::MatchExpression>(*expression)); + + const auto& match = static_cast<const SQL::MatchExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*match.lhs())); + EXPECT(!is<SQL::ErrorExpression>(*match.rhs())); + EXPECT_EQ(match.type(), expected_operator); + EXPECT_EQ(match.invert_expression(), expected_invert_expression); + }; + + for (auto op : operators) { + StringBuilder builder; + builder.append("1 "); + builder.append(op.key); + builder.append(" 1"); + validate(builder.build(), op.value, false); + + builder.clear(); + builder.append("1 NOT "); + builder.append(op.key); + builder.append(" 1"); + validate(builder.build(), op.value, true); + } +} + +TEST_CASE(null_expression) +{ + EXPECT(parse("ISNULL").is_error()); + EXPECT(parse("NOTNULL").is_error()); + EXPECT(parse("15 NOT").is_error()); + + auto validate = [](StringView sql, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::NullExpression>(*expression)); + + const auto& null = static_cast<const SQL::NullExpression&>(*expression); + EXPECT_EQ(null.invert_expression(), expected_invert_expression); + }; + + validate("15 ISNULL", false); + validate("15 NOTNULL", true); + validate("15 NOT NULL", true); +} + +TEST_CASE(between_expression) +{ + EXPECT(parse("BETWEEN").is_error()); + EXPECT(parse("NOT BETWEEN").is_error()); + EXPECT(parse("BETWEEN 10 AND 20").is_error()); + EXPECT(parse("NOT BETWEEN 10 AND 20").is_error()); + EXPECT(parse("15 BETWEEN 10").is_error()); + EXPECT(parse("15 BETWEEN 10 AND").is_error()); + EXPECT(parse("15 BETWEEN AND 20").is_error()); + EXPECT(parse("15 BETWEEN 10 OR 20").is_error()); + + auto validate = [](StringView sql, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::BetweenExpression>(*expression)); + + const auto& between = static_cast<const SQL::BetweenExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*between.expression())); + EXPECT(!is<SQL::ErrorExpression>(*between.lhs())); + EXPECT(!is<SQL::ErrorExpression>(*between.rhs())); + EXPECT_EQ(between.invert_expression(), expected_invert_expression); + }; + + validate("15 BETWEEN 10 AND 20", false); + validate("15 NOT BETWEEN 10 AND 20", true); +} + +TEST_CASE(in_table_expression) +{ + EXPECT(parse("IN").is_error()); + EXPECT(parse("IN table").is_error()); + EXPECT(parse("NOT IN").is_error()); + EXPECT(parse("NOT IN table").is_error()); + + auto validate = [](StringView sql, StringView expected_schema, StringView expected_table, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::InTableExpression>(*expression)); + + const auto& in = static_cast<const SQL::InTableExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*in.expression())); + EXPECT_EQ(in.schema_name(), expected_schema); + EXPECT_EQ(in.table_name(), expected_table); + EXPECT_EQ(in.invert_expression(), expected_invert_expression); + }; + + validate("15 IN table", {}, "table", false); + validate("15 IN schema.table", "schema", "table", false); + + validate("15 NOT IN table", {}, "table", true); + validate("15 NOT IN schema.table", "schema", "table", true); +} + +TEST_CASE(in_chained_expression) +{ + EXPECT(parse("IN ()").is_error()); + EXPECT(parse("NOT IN ()").is_error()); + + auto validate = [](StringView sql, size_t expected_chain_size, bool expected_invert_expression) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is<SQL::InChainedExpression>(*expression)); + + const auto& in = static_cast<const SQL::InChainedExpression&>(*expression); + EXPECT(!is<SQL::ErrorExpression>(*in.expression())); + EXPECT_EQ(in.expression_chain()->expressions().size(), expected_chain_size); + EXPECT_EQ(in.invert_expression(), expected_invert_expression); + + for (const auto& chained_expression : in.expression_chain()->expressions()) + EXPECT(!is<SQL::ErrorExpression>(chained_expression)); + }; + + validate("15 IN ()", 0, false); + validate("15 IN (15)", 1, false); + validate("15 IN (15, 16)", 2, false); + + validate("15 NOT IN ()", 0, true); + validate("15 NOT IN (15)", 1, true); + validate("15 NOT IN (15, 16)", 2, true); +} + +TEST_MAIN(SqlExpressionParser) |