/* * Copyright (c) 2021, Jan de Visser * Copyright (c) 2022, the SerenityOS developers. * * SPDX-License-Identifier: BSD-2-Clause */ #include #include #include #if !defined(AK_OS_SERENITY) # include # include # include # include # include # include #endif namespace SQL { #if !defined(AK_OS_SERENITY) // This is heavily based on how SystemServer's Service creates its socket. static ErrorOr create_database_socket(DeprecatedString const& socket_path) { if (Core::DeprecatedFile::exists(socket_path)) TRY(Core::System::unlink(socket_path)); # ifdef SOCK_NONBLOCK auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); # else auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM, 0)); int option = 1; TRY(Core::System::ioctl(socket_fd, FIONBIO, &option)); TRY(Core::System::fcntl(socket_fd, F_SETFD, FD_CLOEXEC)); # endif # if !defined(AK_OS_MACOS) && !defined(AK_OS_FREEBSD) && !defined(AK_OS_OPENBSD) TRY(Core::System::fchmod(socket_fd, 0600)); # endif auto socket_address = Core::SocketAddress::local(socket_path); auto socket_address_un = socket_address.to_sockaddr_un().release_value(); TRY(Core::System::bind(socket_fd, reinterpret_cast(&socket_address_un), sizeof(socket_address_un))); TRY(Core::System::listen(socket_fd, 16)); return socket_fd; } static ErrorOr launch_server(DeprecatedString const& socket_path, DeprecatedString const& pid_path, Vector candidate_server_paths) { auto server_fd_or_error = create_database_socket(socket_path); if (server_fd_or_error.is_error()) { warnln("Failed to create a database socket at {}: {}", socket_path, server_fd_or_error.error()); return server_fd_or_error.release_error(); } auto server_fd = server_fd_or_error.value(); auto server_pid = TRY(Core::System::fork()); if (server_pid == 0) { TRY(Core::System::setsid()); TRY(Core::System::signal(SIGCHLD, SIG_IGN)); server_pid = TRY(Core::System::fork()); if (server_pid != 0) { auto server_pid_file = TRY(Core::Stream::File::open(pid_path, Core::Stream::OpenMode::Write)); TRY(server_pid_file->write(DeprecatedString::number(server_pid).bytes())); TRY(Core::System::kill(getpid(), SIGTERM)); } server_fd = TRY(Core::System::dup(server_fd)); auto takeover_string = DeprecatedString::formatted("SQLServer:{}", server_fd); TRY(Core::System::setenv("SOCKET_TAKEOVER"sv, takeover_string, true)); ErrorOr result; for (auto const& server_path : candidate_server_paths) { auto arguments = Array { server_path.bytes_as_string_view(), "--pid-file"sv, pid_path, }; result = Core::System::exec(arguments[0], arguments, Core::System::SearchInPath::Yes); if (!result.is_error()) break; } if (result.is_error()) { warnln("Could not launch any of {}: {}", candidate_server_paths, result.error()); TRY(Core::System::unlink(pid_path)); } VERIFY_NOT_REACHED(); } TRY(Core::System::waitpid(server_pid)); return {}; } static ErrorOr should_launch_server(DeprecatedString const& pid_path) { if (!Core::DeprecatedFile::exists(pid_path)) return true; Optional pid; { auto server_pid_file = Core::Stream::File::open(pid_path, Core::Stream::OpenMode::Read); if (server_pid_file.is_error()) { warnln("Could not open SQLServer PID file '{}': {}", pid_path, server_pid_file.error()); return server_pid_file.release_error(); } auto contents = server_pid_file.value()->read_until_eof(); if (contents.is_error()) { warnln("Could not read SQLServer PID file '{}': {}", pid_path, contents.error()); return contents.release_error(); } pid = StringView { contents.value() }.to_int(); } if (!pid.has_value()) { warnln("SQLServer PID file '{}' exists, but with an invalid PID", pid_path); TRY(Core::System::unlink(pid_path)); return true; } if (kill(*pid, 0) < 0) { warnln("SQLServer PID file '{}' exists with PID {}, but process cannot be found", pid_path, *pid); TRY(Core::System::unlink(pid_path)); return true; } return false; } ErrorOr> SQLClient::launch_server_and_create_client(Vector candidate_server_paths) { auto runtime_directory = TRY(Core::StandardPaths::runtime_directory()); auto socket_path = DeprecatedString::formatted("{}/SQLServer.socket", runtime_directory); auto pid_path = DeprecatedString::formatted("{}/SQLServer.pid", runtime_directory); if (TRY(should_launch_server(pid_path))) TRY(launch_server(socket_path, pid_path, move(candidate_server_paths))); auto socket = TRY(Core::Stream::LocalSocket::connect(move(socket_path))); TRY(socket->set_blocking(true)); return adopt_nonnull_ref_or_enomem(new (nothrow) SQLClient(move(socket))); } #endif void SQLClient::execution_success(u64 statement_id, u64 execution_id, Vector const& column_names, bool has_results, size_t created, size_t updated, size_t deleted) { if (!on_execution_success) { outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted); return; } ExecutionSuccess success { .statement_id = statement_id, .execution_id = execution_id, .column_names = move(const_cast&>(column_names)), .has_results = has_results, .rows_created = created, .rows_updated = updated, .rows_deleted = deleted, }; on_execution_success(move(success)); } void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) { if (!on_execution_error) { warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code)); return; } ExecutionError error { .statement_id = statement_id, .execution_id = execution_id, .error_code = code, .error_message = move(const_cast(message)), }; on_execution_error(move(error)); } void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) { if (!on_next_result) { StringBuilder builder; builder.join(", "sv, row, "\"{}\""sv); outln("{}", builder.string_view()); return; } ExecutionResult result { .statement_id = statement_id, .execution_id = execution_id, .values = move(const_cast&>(row)), }; on_next_result(move(result)); } void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) { if (!on_results_exhausted) { outln("{} total row(s)", total_rows); return; } ExecutionComplete success { .statement_id = statement_id, .execution_id = execution_id, .total_rows = total_rows, }; on_results_exhausted(move(success)); } }