/* * Copyright (c) 2021, Jan de Visser * Copyright (c) 2022, Tim Flynn * * SPDX-License-Identifier: BSD-2-Clause */ #include #include #include #include #include #include #include namespace SQL { // We use the upper 4 bits of the encoded type to store extra information about the type. This // includes if the value is null, and the encoded size of any integer type. Of course, this encoding // only works if the SQL type itself fits in the lower 4 bits. enum class SQLTypeWithCount { #undef __ENUMERATE_SQL_TYPE #define __ENUMERATE_SQL_TYPE(name, type) type, ENUMERATE_SQL_TYPES(__ENUMERATE_SQL_TYPE) #undef __ENUMERATE_SQL_TYPE Count, }; static_assert(to_underlying(SQLTypeWithCount::Count) <= 0x0f, "Too many SQL types for current encoding"); // Adding to this list is fine, but changing the order of any value here will result in LibSQL // becoming unable to read existing .db files. If the order must absolutely be changed, be sure // to bump Heap::VERSION. enum class TypeData : u8 { Null = 1 << 4, Int8 = 2 << 4, Int16 = 3 << 4, Int32 = 4 << 4, Int64 = 5 << 4, Uint8 = 6 << 4, Uint16 = 7 << 4, Uint32 = 8 << 4, Uint64 = 9 << 4, }; template static decltype(auto) downsize_integer(Integer auto value, Callback&& callback) { if constexpr (IsSigned) { if (AK::is_within_range(value)) return callback(static_cast(value), TypeData::Int8); if (AK::is_within_range(value)) return callback(static_cast(value), TypeData::Int16); if (AK::is_within_range(value)) return callback(static_cast(value), TypeData::Int32); return callback(value, TypeData::Int64); } else { if (AK::is_within_range(value)) return callback(static_cast(value), TypeData::Uint8); if (AK::is_within_range(value)) return callback(static_cast(value), TypeData::Uint16); if (AK::is_within_range(value)) return callback(static_cast(value), TypeData::Uint32); return callback(value, TypeData::Uint64); } } template static decltype(auto) downsize_integer(Value const& value, Callback&& callback) { VERIFY(value.is_int()); if (value.value().has()) return downsize_integer(value.value().get(), forward(callback)); return downsize_integer(value.value().get(), forward(callback)); } template static ResultOr perform_integer_operation(Value const& lhs, Value const& rhs, Callback&& callback) { VERIFY(lhs.is_int()); VERIFY(rhs.is_int()); if (lhs.value().has()) { if (auto rhs_value = rhs.to_int(); rhs_value.has_value()) return callback(lhs.to_int().value(), rhs_value.value()); } else { if (auto rhs_value = rhs.to_int(); rhs_value.has_value()) return callback(lhs.to_int().value(), rhs_value.value()); } return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; } Value::Value(SQLType type) : m_type(type) { } Value::Value(DeprecatedString value) : m_type(SQLType::Text) , m_value(move(value)) { } Value::Value(double value) { if (trunc(value) == value) { if (AK::is_within_range(value)) { m_type = SQLType::Integer; m_value = static_cast(value); return; } if (AK::is_within_range(value)) { m_type = SQLType::Integer; m_value = static_cast(value); return; } } m_type = SQLType::Float; m_value = value; } Value::Value(NonnullRefPtr descriptor, Vector values) : m_type(SQLType::Tuple) , m_value(TupleValue { move(descriptor), move(values) }) { } Value::Value(Value const& other) : m_type(other.m_type) , m_value(other.m_value) { } Value::Value(Value&& other) : m_type(other.m_type) , m_value(move(other.m_value)) { } Value::Value(Duration duration) : m_type(SQLType::Integer) , m_value(duration.to_milliseconds()) { } Value::Value(UnixDateTime time) : Value(time.offset_to_epoch()) { } Value::~Value() = default; ResultOr Value::create_tuple(NonnullRefPtr descriptor) { Vector values; TRY(values.try_resize(descriptor->size())); for (size_t i = 0; i < descriptor->size(); ++i) values[i].m_type = descriptor->at(i).type; return Value { move(descriptor), move(values) }; } ResultOr Value::create_tuple(Vector values) { auto descriptor = TRY(infer_tuple_descriptor(values)); return Value { move(descriptor), move(values) }; } SQLType Value::type() const { return m_type; } StringView Value::type_name() const { switch (type()) { #undef __ENUMERATE_SQL_TYPE #define __ENUMERATE_SQL_TYPE(name, type) \ case SQLType::type: \ return name##sv; ENUMERATE_SQL_TYPES(__ENUMERATE_SQL_TYPE) #undef __ENUMERATE_SQL_TYPE default: VERIFY_NOT_REACHED(); } } bool Value::is_type_compatible_with(SQLType other_type) const { switch (type()) { case SQLType::Null: return false; case SQLType::Integer: case SQLType::Float: return other_type == SQLType::Integer || other_type == SQLType::Float; default: break; } return type() == other_type; } bool Value::is_null() const { return !m_value.has_value(); } bool Value::is_int() const { return m_value.has_value() && (m_value->has() || m_value->has()); } DeprecatedString Value::to_deprecated_string() const { if (is_null()) return "(null)"sv; return m_value->visit( [](DeprecatedString const& value) -> DeprecatedString { return value; }, [](Integer auto value) -> DeprecatedString { return DeprecatedString::number(value); }, [](double value) -> DeprecatedString { return DeprecatedString::number(value); }, [](bool value) -> DeprecatedString { return value ? "true"sv : "false"sv; }, [](TupleValue const& value) -> DeprecatedString { StringBuilder builder; builder.append('('); builder.join(',', value.values); builder.append(')'); return builder.to_deprecated_string(); }); } Optional Value::to_double() const { if (is_null()) return {}; return m_value->visit( [](DeprecatedString const& value) -> Optional { return value.to_double(); }, [](Integer auto value) -> Optional { return static_cast(value); }, [](double value) -> Optional { return value; }, [](bool value) -> Optional { return static_cast(value); }, [](TupleValue const&) -> Optional { return {}; }); } Optional Value::to_bool() const { if (is_null()) return {}; return m_value->visit( [](DeprecatedString const& value) -> Optional { if (value.equals_ignoring_ascii_case("true"sv) || value.equals_ignoring_ascii_case("t"sv)) return true; if (value.equals_ignoring_ascii_case("false"sv) || value.equals_ignoring_ascii_case("f"sv)) return false; return {}; }, [](Integer auto value) -> Optional { return static_cast(value); }, [](double value) -> Optional { return fabs(value) > NumericLimits::epsilon(); }, [](bool value) -> Optional { return value; }, [](TupleValue const& value) -> Optional { for (auto const& element : value.values) { auto as_bool = element.to_bool(); if (!as_bool.has_value()) return {}; if (!as_bool.value()) return false; } return true; }); } Optional> Value::to_vector() const { if (is_null() || (type() != SQLType::Tuple)) return {}; auto const& tuple = m_value->get(); return tuple.values; } Value& Value::operator=(Value value) { m_type = value.m_type; m_value = move(value.m_value); return *this; } Value& Value::operator=(DeprecatedString value) { m_type = SQLType::Text; m_value = move(value); return *this; } Value& Value::operator=(double value) { m_type = SQLType::Float; m_value = value; return *this; } ResultOr Value::assign_tuple(NonnullRefPtr descriptor) { Vector values; TRY(values.try_resize(descriptor->size())); for (size_t i = 0; i < descriptor->size(); ++i) values[i].m_type = descriptor->at(i).type; m_type = SQLType::Tuple; m_value = TupleValue { move(descriptor), move(values) }; return {}; } ResultOr Value::assign_tuple(Vector values) { if (is_null() || (type() != SQLType::Tuple)) { auto descriptor = TRY(infer_tuple_descriptor(values)); m_type = SQLType::Tuple; m_value = TupleValue { move(descriptor), move(values) }; return {}; } auto& tuple = m_value->get(); if (values.size() > tuple.descriptor->size()) return Result { SQLCommand::Unknown, SQLErrorCode::InvalidNumberOfValues }; for (size_t i = 0; i < values.size(); ++i) { if (values[i].type() != tuple.descriptor->at(i).type) return Result { SQLCommand::Unknown, SQLErrorCode::InvalidType, SQLType_name(values[i].type()) }; } if (values.size() < tuple.descriptor->size()) { size_t original_size = values.size(); MUST(values.try_resize(tuple.descriptor->size())); for (size_t i = original_size; i < values.size(); ++i) values[i].m_type = tuple.descriptor->at(i).type; } m_value = TupleValue { move(tuple.descriptor), move(values) }; return {}; } size_t Value::length() const { if (is_null()) return 0; // FIXME: This seems to be more of an encoded byte size rather than a length. return m_value->visit( [](DeprecatedString const& value) -> size_t { return sizeof(u32) + value.length(); }, [](Integer auto value) -> size_t { return downsize_integer(value, [](auto integer, auto) { return sizeof(integer); }); }, [](double value) -> size_t { return sizeof(value); }, [](bool value) -> size_t { return sizeof(value); }, [](TupleValue const& value) -> size_t { auto size = value.descriptor->length() + sizeof(u32); for (auto const& element : value.values) size += element.length(); return size; }); } u32 Value::hash() const { if (is_null()) return 0; return m_value->visit( [](DeprecatedString const& value) -> u32 { return value.hash(); }, [](Integer auto value) -> u32 { return downsize_integer(value, [](auto integer, auto) { if constexpr (sizeof(decltype(integer)) == 8) return u64_hash(integer); else return int_hash(integer); }); }, [](double) -> u32 { VERIFY_NOT_REACHED(); }, [](bool value) -> u32 { return int_hash(value); }, [](TupleValue const& value) -> u32 { u32 hash = 0; for (auto const& element : value.values) { if (hash == 0) hash = element.hash(); else hash = pair_int_hash(hash, element.hash()); } return hash; }); } int Value::compare(Value const& other) const { if (is_null()) return -1; if (other.is_null()) return 1; return m_value->visit( [&](DeprecatedString const& value) -> int { return value.view().compare(other.to_deprecated_string()); }, [&](Integer auto value) -> int { auto casted = other.to_int>(); if (!casted.has_value()) return 1; if (value == *casted) return 0; return value < *casted ? -1 : 1; }, [&](double value) -> int { auto casted = other.to_double(); if (!casted.has_value()) return 1; auto diff = value - *casted; if (fabs(diff) < NumericLimits::epsilon()) return 0; return diff < 0 ? -1 : 1; }, [&](bool value) -> int { auto casted = other.to_bool(); if (!casted.has_value()) return 1; return value ^ *casted; }, [&](TupleValue const& value) -> int { if (other.is_null() || (other.type() != SQLType::Tuple)) { if (value.values.size() == 1) return value.values[0].compare(other); return 1; } auto const& other_value = other.m_value->get(); if (auto result = value.descriptor->compare_ignoring_names(*other_value.descriptor); result != 0) return 1; if (value.values.size() != other_value.values.size()) return value.values.size() < other_value.values.size() ? -1 : 1; for (size_t i = 0; i < value.values.size(); ++i) { auto result = value.values[i].compare(other_value.values[i]); if (result == 0) continue; if (value.descriptor->at(i).order == Order::Descending) result = -result; return result; } return 0; }); } bool Value::operator==(Value const& value) const { return compare(value) == 0; } bool Value::operator==(StringView value) const { return to_deprecated_string() == value; } bool Value::operator==(double value) const { return to_double() == value; } bool Value::operator!=(Value const& value) const { return compare(value) != 0; } bool Value::operator<(Value const& value) const { return compare(value) < 0; } bool Value::operator<=(Value const& value) const { return compare(value) <= 0; } bool Value::operator>(Value const& value) const { return compare(value) > 0; } bool Value::operator>=(Value const& value) const { return compare(value) >= 0; } template static Result invalid_type_for_numeric_operator(Operator op) { if constexpr (IsSame) return { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, BinaryOperator_name(op) }; else if constexpr (IsSame) return { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(op) }; else static_assert(DependentFalse); } ResultOr Value::add(Value const& other) const { if (is_int() && other.is_int()) { return perform_integer_operation(*this, other, [](auto lhs, auto rhs) -> ResultOr { Checked result { lhs }; result.add(rhs); if (result.has_overflow()) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { result.value_unchecked() }; }); } auto lhs = to_double(); auto rhs = other.to_double(); if (!lhs.has_value() || !rhs.has_value()) return invalid_type_for_numeric_operator(AST::BinaryOperator::Plus); return Value { lhs.value() + rhs.value() }; } ResultOr Value::subtract(Value const& other) const { if (is_int() && other.is_int()) { return perform_integer_operation(*this, other, [](auto lhs, auto rhs) -> ResultOr { Checked result { lhs }; result.sub(rhs); if (result.has_overflow()) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { result.value_unchecked() }; }); } auto lhs = to_double(); auto rhs = other.to_double(); if (!lhs.has_value() || !rhs.has_value()) return invalid_type_for_numeric_operator(AST::BinaryOperator::Minus); return Value { lhs.value() - rhs.value() }; } ResultOr Value::multiply(Value const& other) const { if (is_int() && other.is_int()) { return perform_integer_operation(*this, other, [](auto lhs, auto rhs) -> ResultOr { Checked result { lhs }; result.mul(rhs); if (result.has_overflow()) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { result.value_unchecked() }; }); } auto lhs = to_double(); auto rhs = other.to_double(); if (!lhs.has_value() || !rhs.has_value()) return invalid_type_for_numeric_operator(AST::BinaryOperator::Multiplication); return Value { lhs.value() * rhs.value() }; } ResultOr Value::divide(Value const& other) const { auto lhs = to_double(); auto rhs = other.to_double(); if (!lhs.has_value() || !rhs.has_value()) return invalid_type_for_numeric_operator(AST::BinaryOperator::Division); if (rhs == 0.0) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { lhs.value() / rhs.value() }; } ResultOr Value::modulo(Value const& other) const { if (!is_int() || !other.is_int()) return invalid_type_for_numeric_operator(AST::BinaryOperator::Modulo); return perform_integer_operation(*this, other, [](auto lhs, auto rhs) -> ResultOr { Checked result { lhs }; result.mod(rhs); if (result.has_overflow()) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { result.value_unchecked() }; }); } ResultOr Value::negate() const { if (type() == SQLType::Integer) { auto value = to_int(); if (!value.has_value()) return invalid_type_for_numeric_operator(AST::UnaryOperator::Minus); return Value { value.value() * -1 }; } if (type() == SQLType::Float) return Value { -to_double().value() }; return invalid_type_for_numeric_operator(AST::UnaryOperator::Minus); } ResultOr Value::shift_left(Value const& other) const { if (!is_int() || !other.is_int()) return invalid_type_for_numeric_operator(AST::BinaryOperator::ShiftLeft); return perform_integer_operation(*this, other, [](auto lhs, auto rhs) -> ResultOr { using LHS = decltype(lhs); using RHS = decltype(rhs); static constexpr auto max_shift = static_cast(sizeof(LHS) * 8); if (rhs < 0 || rhs >= max_shift) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { lhs << rhs }; }); } ResultOr Value::shift_right(Value const& other) const { if (!is_int() || !other.is_int()) return invalid_type_for_numeric_operator(AST::BinaryOperator::ShiftRight); return perform_integer_operation(*this, other, [](auto lhs, auto rhs) -> ResultOr { using LHS = decltype(lhs); using RHS = decltype(rhs); static constexpr auto max_shift = static_cast(sizeof(LHS) * 8); if (rhs < 0 || rhs >= max_shift) return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOverflow }; return Value { lhs >> rhs }; }); } ResultOr Value::bitwise_or(Value const& other) const { if (!is_int() || !other.is_int()) return invalid_type_for_numeric_operator(AST::BinaryOperator::BitwiseOr); return perform_integer_operation(*this, other, [](auto lhs, auto rhs) { return Value { lhs | rhs }; }); } ResultOr Value::bitwise_and(Value const& other) const { if (!is_int() || !other.is_int()) return invalid_type_for_numeric_operator(AST::BinaryOperator::BitwiseAnd); return perform_integer_operation(*this, other, [](auto lhs, auto rhs) { return Value { lhs & rhs }; }); } ResultOr Value::bitwise_not() const { if (!is_int()) return invalid_type_for_numeric_operator(AST::UnaryOperator::BitwiseNot); return downsize_integer(*this, [](auto value, auto) { return Value { ~value }; }); } static u8 encode_type_flags(Value const& value) { auto type_flags = to_underlying(value.type()); if (value.is_null()) { type_flags |= to_underlying(TypeData::Null); } else if (value.is_int()) { downsize_integer(value, [&](auto, auto type_data) { type_flags |= to_underlying(type_data); }); } return type_flags; } void Value::serialize(Serializer& serializer) const { auto type_flags = encode_type_flags(*this); serializer.serialize(type_flags); if (is_null()) return; if (is_int()) { downsize_integer(*this, [&](auto integer, auto) { serializer.serialize(integer); }); return; } m_value->visit( [&](TupleValue const& value) { serializer.serialize(*value.descriptor); serializer.serialize(static_cast(value.values.size())); for (auto const& element : value.values) serializer.serialize(element); }, [&](auto const& value) { serializer.serialize(value); }); } void Value::deserialize(Serializer& serializer) { auto type_flags = serializer.deserialize(); auto type_data = static_cast(type_flags & 0xf0); m_type = static_cast(type_flags & 0x0f); if (type_data == TypeData::Null) return; switch (m_type) { case SQLType::Null: VERIFY_NOT_REACHED(); case SQLType::Text: m_value = serializer.deserialize(); break; case SQLType::Integer: switch (type_data) { case TypeData::Int8: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Int16: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Int32: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Int64: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Uint8: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Uint16: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Uint32: m_value = static_cast(serializer.deserialize(0)); break; case TypeData::Uint64: m_value = static_cast(serializer.deserialize(0)); break; default: VERIFY_NOT_REACHED(); } break; case SQLType::Float: m_value = serializer.deserialize(0.0); break; case SQLType::Boolean: m_value = serializer.deserialize(false); break; case SQLType::Tuple: { auto descriptor = serializer.adopt_and_deserialize(); auto size = serializer.deserialize(); Vector values; values.ensure_capacity(size); for (size_t i = 0; i < size; ++i) values.unchecked_append(serializer.deserialize()); m_value = TupleValue { move(descriptor), move(values) }; break; } } } TupleElementDescriptor Value::descriptor() const { return { "", "", "", type(), Order::Ascending }; } ResultOr> Value::infer_tuple_descriptor(Vector const& values) { auto descriptor = TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQL::TupleDescriptor)); TRY(descriptor->try_ensure_capacity(values.size())); for (auto const& element : values) descriptor->unchecked_append({ ""sv, ""sv, ""sv, element.type(), Order::Ascending }); return descriptor; } } template<> ErrorOr IPC::encode(Encoder& encoder, SQL::Value const& value) { auto type_flags = encode_type_flags(value); TRY(encoder.encode(type_flags)); if (value.is_null()) return {}; switch (value.type()) { case SQL::SQLType::Null: return {}; case SQL::SQLType::Text: return encoder.encode(value.to_deprecated_string()); case SQL::SQLType::Integer: return SQL::downsize_integer(value, [&](auto integer, auto) { return encoder.encode(integer); }); case SQL::SQLType::Float: return encoder.encode(value.to_double().value()); case SQL::SQLType::Boolean: return encoder.encode(value.to_bool().value()); case SQL::SQLType::Tuple: return encoder.encode(value.to_vector().value()); } VERIFY_NOT_REACHED(); } template<> ErrorOr IPC::decode(Decoder& decoder) { auto type_flags = TRY(decoder.decode()); auto type_data = static_cast(type_flags & 0xf0); auto type = static_cast(type_flags & 0x0f); if (type_data == SQL::TypeData::Null) return SQL::Value { type }; switch (type) { case SQL::SQLType::Null: return SQL::Value {}; case SQL::SQLType::Text: return SQL::Value { TRY(decoder.decode()) }; case SQL::SQLType::Integer: switch (type_data) { case SQL::TypeData::Int8: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Int16: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Int32: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Int64: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Uint8: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Uint16: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Uint32: return SQL::Value { TRY(decoder.decode()) }; case SQL::TypeData::Uint64: return SQL::Value { TRY(decoder.decode()) }; default: break; } break; case SQL::SQLType::Float: return SQL::Value { TRY(decoder.decode()) }; case SQL::SQLType::Boolean: return SQL::Value { TRY(decoder.decode()) }; case SQL::SQLType::Tuple: { auto tuple = TRY(decoder.decode>()); auto value = SQL::Value::create_tuple(move(tuple)); if (value.is_error()) return Error::from_errno(to_underlying(value.error().error())); return value.release_value(); } } VERIFY_NOT_REACHED(); }