summaryrefslogtreecommitdiff
path: root/Userland/Libraries/LibSQL/AST/Insert.cpp
blob: 48ae68b73a96f931b92b32bd0e5a4e56ccd87200 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
/*
 * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
 * Copyright (c) 2021, Mahmoud Mandour <ma.mandourr@gmail.com>
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */

#include <LibSQL/AST/AST.h>
#include <LibSQL/Database.h>
#include <LibSQL/Meta.h>
#include <LibSQL/Row.h>

namespace SQL::AST {

// Returns whether a value of type `actual` may be stored in a column declared
// as `expected`. NULL values are never accepted; Integer columns accept both
// Integer and Float values; every other column type requires an exact match.
static bool does_value_data_type_match(SQLType expected, SQLType actual)
{
    if (actual == SQLType::Null)
        return false;
    if (expected == SQLType::Integer)
        return actual == SQLType::Integer || actual == SQLType::Float;
    return expected == actual;
}

// Executes this INSERT statement against `context.database`.
//
// Every chained VALUES expression is evaluated to a tuple and stored as one row:
//  - Without a column list, the tuple supplies one value per table column, by position.
//  - With a column list, the tuple supplies one value per listed column, in list order,
//    and every unlisted column is set to its default value.
//
// On success, returns a ResultSet with one entry per inserted row. Otherwise returns
// a Result with one of these error codes:
//  - TableDoesNotExist: the table is not found (an empty schema name is reported as "default").
//  - ColumnDoesNotExist: a listed column is not part of the table.
//  - InvalidNumberOfValues: a tuple does not have exactly the expected number of values.
//  - InvalidValueType: a value's type is rejected by does_value_data_type_match().
// Errors from the database or from expression evaluation are propagated unchanged via TRY.
ResultOr<ResultSet> Insert::execute(ExecutionContext& context) const
{
    auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name));

    if (!table_def) {
        auto schema_name = m_schema_name.is_empty() ? String("default"sv) : m_schema_name;
        return Result { SQLCommand::Insert, SQLErrorCode::TableDoesNotExist, String::formatted("{}.{}", schema_name, m_table_name) };
    }

    Row row(table_def);
    for (auto& column : m_column_names) {
        if (!row.has(column))
            return Result { SQLCommand::Insert, SQLErrorCode::ColumnDoesNotExist, column };
    }

    // Each tuple must carry exactly one value per target column. Enforcing this here
    // (rather than relying on earlier stages) keeps the m_column_names[ix] lookup below
    // in bounds, and guarantees every listed column is overwritten for every row: `row`
    // is reused across iterations, so a short tuple would otherwise silently inherit
    // values from the previously inserted row.
    auto const expected_value_count = m_column_names.is_empty() ? row.size() : m_column_names.size();

    ResultSet result { SQLCommand::Insert };
    TRY(result.try_ensure_capacity(m_chained_expressions.size()));

    for (auto& row_expr : m_chained_expressions) {
        // Unlisted columns get their defaults; listed columns are filled from the tuple below.
        for (auto& column_def : table_def->columns()) {
            if (!m_column_names.contains_slow(column_def.name()))
                row[column_def.name()] = column_def.default_value();
        }

        auto row_value = TRY(row_expr.evaluate(context));
        VERIFY(row_value.type() == SQLType::Tuple);
        auto values = row_value.to_vector().value();

        if (values.size() != expected_value_count)
            return Result { SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, String::empty() };

        auto& tuple_descriptor = *row.descriptor();
        for (auto ix = 0u; ix < values.size(); ix++) {
            // With a column list, this lookup must succeed: every listed name was verified to exist above.
            auto element_index = m_column_names.is_empty()
                ? ix
                : tuple_descriptor.find_if([&](auto const& element) { return element.name == m_column_names[ix]; }).index();
            auto element_type = tuple_descriptor[element_index].type;

            if (!does_value_data_type_match(element_type, values[ix].type()))
                return Result { SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name() };

            row[element_index] = values[ix];
        }

        TRY(context.database->insert(row));
        result.insert_row(row, {});
    }

    return result;
}

}