diff options
author | Ali Mohammad Pur <ali.mpfard@gmail.com> | 2022-02-02 19:21:55 +0330 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2022-02-06 13:10:10 +0100 |
commit | aafc451016b342886f7f1d0d51794f064f313e04 (patch) | |
tree | db1d8931ea335c87b508149e0f46c916bc6dbb3c /Userland | |
parent | 7a95c451a3bc4eb379e12c8243b878f73a3dd03d (diff) | |
download | serenity-aafc451016b342886f7f1d0d51794f064f313e04.zip |
Userland: Convert TLS::TLSv12 to a Core::Stream::Socket
This commit converts TLS::TLSv12 to a Core::Stream object, and in the
process allows TLS to now wrap other Core::Stream::Socket objects.
As a large part of LibHTTP and LibGemini depend on LibTLS's interface,
this also converts those to support Core::Stream, which leads to a
simplification of LibHTTP (as there's no need to care about the
underlying socket type anymore).
Note that RequestServer now controls the TLS socket options, which is a
better place anyway, as RS is the first receiver of the user-requested
options (though this is currently not particularly useful).
Diffstat (limited to 'Userland')
46 files changed, 808 insertions, 1113 deletions
diff --git a/Userland/Applications/Browser/Tab.h b/Userland/Applications/Browser/Tab.h index 02fbc9e3f9..919bf93325 100644 --- a/Userland/Applications/Browser/Tab.h +++ b/Userland/Applications/Browser/Tab.h @@ -11,7 +11,7 @@ #include <LibGUI/ActionGroup.h> #include <LibGUI/Widget.h> #include <LibGfx/ShareableBitmap.h> -#include <LibHTTP/HttpJob.h> +#include <LibHTTP/Job.h> #include <LibWeb/Forward.h> namespace Web { diff --git a/Userland/Libraries/LibCore/NetworkJob.cpp b/Userland/Libraries/LibCore/NetworkJob.cpp index 4dcc927ca0..03efac0069 100644 --- a/Userland/Libraries/LibCore/NetworkJob.cpp +++ b/Userland/Libraries/LibCore/NetworkJob.cpp @@ -10,7 +10,7 @@ namespace Core { -NetworkJob::NetworkJob(OutputStream& output_stream) +NetworkJob::NetworkJob(Core::Stream::Stream& output_stream) : m_output_stream(output_stream) { } @@ -19,7 +19,7 @@ NetworkJob::~NetworkJob() { } -void NetworkJob::start(NonnullRefPtr<Core::Socket>) +void NetworkJob::start(Core::Stream::Socket&) { } diff --git a/Userland/Libraries/LibCore/NetworkJob.h b/Userland/Libraries/LibCore/NetworkJob.h index e2eda0b407..e17f531796 100644 --- a/Userland/Libraries/LibCore/NetworkJob.h +++ b/Userland/Libraries/LibCore/NetworkJob.h @@ -8,7 +8,9 @@ #include <AK/Function.h> #include <AK/Stream.h> +#include <LibCore/Forward.h> #include <LibCore/Object.h> +#include <LibCore/Stream.h> namespace Core { @@ -39,8 +41,9 @@ public: DetachFromSocket, CloseSocket, }; - virtual void start(NonnullRefPtr<Core::Socket>) = 0; + virtual void start(Core::Stream::Socket&) = 0; virtual void shutdown(ShutdownMode) = 0; + virtual void fail(Error error) { did_fail(error); } void cancel() { @@ -49,16 +52,16 @@ public: } protected: - NetworkJob(OutputStream&); + NetworkJob(Core::Stream::Stream&); void did_finish(NonnullRefPtr<NetworkResponse>&&); void did_fail(Error); void did_progress(Optional<u32> total_size, u32 downloaded); - size_t do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); } + ErrorOr<size_t> do_write(ReadonlyBytes bytes) { return m_output_stream.write(bytes); } private: RefPtr<NetworkResponse> m_response; - OutputStream& m_output_stream; + Core::Stream::Stream& m_output_stream; Error m_error { Error::None }; }; diff --git a/Userland/Libraries/LibGemini/CMakeLists.txt b/Userland/Libraries/LibGemini/CMakeLists.txt index dbd453fed2..97fae1c763 100644 --- a/Userland/Libraries/LibGemini/CMakeLists.txt +++ b/Userland/Libraries/LibGemini/CMakeLists.txt @@ -1,6 +1,5 @@ set(SOURCES Document.cpp - GeminiJob.cpp GeminiRequest.cpp GeminiResponse.cpp Job.cpp diff --git a/Userland/Libraries/LibGemini/Forward.h b/Userland/Libraries/LibGemini/Forward.h index 390175473f..0a6cc72bca 100644 --- a/Userland/Libraries/LibGemini/Forward.h +++ b/Userland/Libraries/LibGemini/Forward.h @@ -11,7 +11,6 @@ namespace Gemini { class Document; class GeminiRequest; class GeminiResponse; -class GeminiJob; class Job; } diff --git a/Userland/Libraries/LibGemini/GeminiJob.cpp b/Userland/Libraries/LibGemini/GeminiJob.cpp deleted file mode 100644 index 57cde2260f..0000000000 --- a/Userland/Libraries/LibGemini/GeminiJob.cpp +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2020, the SerenityOS developers. - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#include <AK/Debug.h> -#include <LibCore/EventLoop.h> -#include <LibGemini/GeminiJob.h> -#include <LibGemini/GeminiResponse.h> -#include <LibTLS/TLSv12.h> -#include <stdio.h> -#include <unistd.h> - -namespace Gemini { - -void GeminiJob::start(NonnullRefPtr<Core::Socket> socket) -{ - VERIFY(!m_socket); - VERIFY(is<TLS::TLSv12>(*socket)); - m_socket = static_ptr_cast<TLS::TLSv12>(socket); - m_socket->on_tls_error = [this](TLS::AlertDescription error) { - if (error == TLS::AlertDescription::HandshakeFailure) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ProtocolFailed); - }); - } else if (error == TLS::AlertDescription::DecryptError) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ConnectionFailed); - }); - } else { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::TransmissionFailed); - }); - } - }; - m_socket->on_tls_finished = [this] { - finish_up(); - }; - m_socket->on_tls_certificate_request = [this](auto&) { - if (on_certificate_requested) - on_certificate_requested(*this); - }; - - m_socket->set_idle(false); - if (m_socket->is_established()) { - deferred_invoke([this] { on_socket_connected(); }); - } else { - m_socket->set_root_certificates(m_override_ca_certificates ? *m_override_ca_certificates : DefaultRootCACertificates::the().certificates()); - m_socket->on_tls_connected = [this] { - on_socket_connected(); - }; - bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port_or_default()); - if (!success) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ConnectionFailed); - }); - } - } -} - -void GeminiJob::shutdown(ShutdownMode mode) -{ - if (!m_socket) - return; - if (mode == ShutdownMode::CloseSocket) { - m_socket->close(); - } else { - m_socket->on_tls_ready_to_read = nullptr; - m_socket->on_tls_connected = nullptr; - m_socket->set_idle(true); - m_socket = nullptr; - } -} - -void GeminiJob::read_while_data_available(Function<IterationDecision()> read) -{ - while (m_socket->can_read()) { - if (read() == IterationDecision::Break) - break; - } -} - -void GeminiJob::set_certificate(String certificate, String private_key) -{ - if (!m_socket->add_client_key(certificate.bytes(), private_key.bytes())) { - dbgln("LibGemini: Failed to set a client certificate"); - // FIXME: Do something about this failure - VERIFY_NOT_REACHED(); - } -} - -void GeminiJob::register_on_ready_to_read(Function<void()> callback) -{ - m_socket->on_tls_ready_to_read = [callback = move(callback)](auto&) { - callback(); - }; -} - -void GeminiJob::register_on_ready_to_write(Function<void()> callback) -{ - m_socket->set_on_tls_ready_to_write([callback = move(callback)](auto& tls) { - Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); - callback(); - }); -} - -bool GeminiJob::can_read_line() const -{ - return m_socket->can_read_line(); -} - -String GeminiJob::read_line(size_t size) -{ - return m_socket->read_line(size); -} - -ByteBuffer GeminiJob::receive(size_t size) -{ - return m_socket->read(size); -} - -bool GeminiJob::can_read() const -{ - return m_socket->can_read(); -} - -bool GeminiJob::eof() const -{ - return m_socket->eof(); -} - -bool GeminiJob::write(ReadonlyBytes bytes) -{ - return m_socket->write(bytes); -} - -} diff --git a/Userland/Libraries/LibGemini/GeminiJob.h b/Userland/Libraries/LibGemini/GeminiJob.h deleted file mode 100644 index ec554ee294..0000000000 --- a/Userland/Libraries/LibGemini/GeminiJob.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2020, the SerenityOS developers. - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#pragma once - -#include <LibCore/NetworkJob.h> -#include <LibGemini/GeminiRequest.h> -#include <LibGemini/GeminiResponse.h> -#include <LibGemini/Job.h> -#include <LibTLS/TLSv12.h> - -namespace Gemini { - -class GeminiJob final : public Job { - C_OBJECT(GeminiJob) -public: - virtual ~GeminiJob() override - { - } - - virtual void start(NonnullRefPtr<Core::Socket>) override; - virtual void shutdown(ShutdownMode) override; - void set_certificate(String certificate, String key); - - Core::Socket const* socket() const { return m_socket; } - URL url() const { return m_request.url(); } - - Function<void(GeminiJob&)> on_certificate_requested; - -protected: - virtual void register_on_ready_to_read(Function<void()>) override; - virtual void register_on_ready_to_write(Function<void()>) override; - virtual bool can_read_line() const override; - virtual String read_line(size_t) override; - virtual bool can_read() const override; - virtual ByteBuffer receive(size_t) override; - virtual bool eof() const override; - virtual bool write(ReadonlyBytes) override; - virtual bool is_established() const override { return m_socket->is_established(); } - virtual bool should_fail_on_empty_payload() const override { return false; } - virtual void read_while_data_available(Function<IterationDecision()>) override; - -private: - explicit GeminiJob(const GeminiRequest& request, OutputStream& output_stream, const Vector<Certificate>* override_certificates = nullptr) - : Job(request, output_stream) - , m_override_ca_certificates(override_certificates) - { - } - - RefPtr<TLS::TLSv12> m_socket; - const Vector<Certificate>* m_override_ca_certificates { nullptr }; -}; - -} diff --git a/Userland/Libraries/LibGemini/GeminiRequest.cpp b/Userland/Libraries/LibGemini/GeminiRequest.cpp index 3f219ca273..f1ffa98c63 100644 --- a/Userland/Libraries/LibGemini/GeminiRequest.cpp +++ b/Userland/Libraries/LibGemini/GeminiRequest.cpp @@ -6,7 +6,6 @@ #include <AK/StringBuilder.h> #include <AK/URL.h> -#include <LibGemini/GeminiJob.h> #include <LibGemini/GeminiRequest.h> namespace Gemini { diff --git a/Userland/Libraries/LibGemini/Job.cpp b/Userland/Libraries/LibGemini/Job.cpp index ed9aec3f89..9cb02f454b 100644 --- a/Userland/Libraries/LibGemini/Job.cpp +++ b/Userland/Libraries/LibGemini/Job.cpp @@ -5,14 +5,14 @@ */ #include <AK/Debug.h> +#include <LibCore/Stream.h> #include <LibGemini/GeminiResponse.h> #include <LibGemini/Job.h> -#include <stdio.h> #include <unistd.h> namespace Gemini { -Job::Job(const GeminiRequest& request, OutputStream& output_stream) +Job::Job(const GeminiRequest& request, Core::Stream::Stream& output_stream) : Core::NetworkJob(output_stream) , m_request(request) { @@ -22,12 +22,83 @@ Job::~Job() { } +void Job::start(Core::Stream::Socket& socket) +{ + VERIFY(!m_socket); + m_socket = verify_cast<Core::Stream::BufferedSocketBase>(&socket); + on_socket_connected(); +} + +void Job::shutdown(ShutdownMode mode) +{ + if (!m_socket) + return; + if (mode == ShutdownMode::CloseSocket) { + m_socket->close(); + } else { + m_socket->on_ready_to_read = nullptr; + m_socket = nullptr; + } +} + +void Job::register_on_ready_to_read(Function<void()> callback) +{ + m_socket->on_ready_to_read = [this, callback = move(callback)] { + callback(); + + while (can_read()) { + callback(); + } + }; +} + +bool Job::can_read_line() const +{ + return MUST(m_socket->can_read_line()); +} + +String Job::read_line(size_t size) +{ + ByteBuffer buffer = ByteBuffer::create_uninitialized(size).release_value_but_fixme_should_propagate_errors(); + auto nread = MUST(m_socket->read_until(buffer, "\r\n"sv)); + return String::copy(buffer.span().slice(0, nread)); +} + +ByteBuffer Job::receive(size_t size) +{ + ByteBuffer buffer = ByteBuffer::create_uninitialized(size).release_value_but_fixme_should_propagate_errors(); + auto nread = MUST(m_socket->read(buffer)); + return buffer.slice(0, nread); +} + +bool Job::can_read() const +{ + return MUST(m_socket->can_read_without_blocking()); +} + +bool Job::write(ReadonlyBytes bytes) +{ + return m_socket->write_or_error(bytes); +} + void Job::flush_received_buffers() { for (size_t i = 0; i < m_received_buffers.size(); ++i) { auto& payload = m_received_buffers[i]; - auto written = do_write(payload); - m_received_size -= written; + auto result = do_write(payload); + if (result.is_error()) { + if (!result.error().is_errno()) { + dbgln("Job: Failed to flush received buffers: {}", result.error()); + continue; + } + if (result.error().code() == EINTR) { + i--; + continue; + } + break; + } + auto written = result.release_value(); + m_buffered_size -= written; if (written == payload.size()) { // FIXME: Make this a take-first-friendly object? m_received_buffers.take_first(); @@ -41,20 +112,16 @@ void Job::flush_received_buffers() void Job::on_socket_connected() { - register_on_ready_to_write([this] { - if (m_sent_data) - return; - m_sent_data = true; - auto raw_request = m_request.to_raw_request(); + auto raw_request = m_request.to_raw_request(); + + if constexpr (JOB_DEBUG) { + dbgln("Job: raw_request:"); + dbgln("{}", String::copy(raw_request)); + } + bool success = write(raw_request); + if (!success) + deferred_invoke([this] { did_fail(Core::NetworkJob::Error::TransmissionFailed); }); - if constexpr (JOB_DEBUG) { - dbgln("Job: raw_request:"); - dbgln("{}", String::copy(raw_request)); - } - bool success = write(raw_request); - if (!success) - deferred_invoke([this] { did_fail(Core::NetworkJob::Error::TransmissionFailed); }); - }); register_on_ready_to_read([this] { if (is_cancelled()) return; @@ -65,19 +132,19 @@ void Job::on_socket_connected() auto line = read_line(PAGE_SIZE); if (line.is_null()) { - warnln("Job: Expected status line"); + dbgln("Job: Expected status line"); return deferred_invoke([this] { did_fail(Core::NetworkJob::Error::TransmissionFailed); }); } auto parts = line.split_limit(' ', 2); if (parts.size() != 2) { - warnln("Job: Expected 2-part status line, got '{}'", line); + dbgln("Job: Expected 2-part status line, got '{}'", line); return deferred_invoke([this] { did_fail(Core::NetworkJob::Error::ProtocolFailed); }); } auto status = parts[0].to_uint(); if (!status.has_value()) { - warnln("Job: Expected numeric status code"); + dbgln("Job: Expected numeric status code"); return deferred_invoke([this] { did_fail(Core::NetworkJob::Error::ProtocolFailed); }); } @@ -97,41 +164,41 @@ void Job::on_socket_connected() } else if (m_status >= 60 && m_status < 70) { m_state = State::InBody; } else { - warnln("Job: Expected status between 10 and 69; instead got {}", m_status); + dbgln("Job: Expected status between 10 and 69; instead got {}", m_status); return deferred_invoke([this] { did_fail(Core::NetworkJob::Error::ProtocolFailed); }); } - return; + if (!can_read()) { + dbgln("Can't read further :("); + return; + } } VERIFY(m_state == State::InBody || m_state == State::Finished); - read_while_data_available([&] { + while (MUST(m_socket->can_read_without_blocking())) { auto read_size = 64 * KiB; auto payload = receive(read_size); if (payload.is_empty()) { - if (eof()) { + if (m_socket->is_eof()) { finish_up(); - return IterationDecision::Break; - } - - if (should_fail_on_empty_payload()) { - deferred_invoke([this] { did_fail(Core::NetworkJob::Error::ProtocolFailed); }); - return IterationDecision::Break; + break; } } m_received_size += payload.size(); + m_buffered_size += payload.size(); m_received_buffers.append(move(payload)); flush_received_buffers(); deferred_invoke([this] { did_progress({}, m_received_size); }); - return IterationDecision::Continue; - }); + if (m_socket->is_eof()) + break; + } - if (!is_established()) { + if (!m_socket->is_open() || m_socket->is_eof()) { dbgln_if(JOB_DEBUG, "Connection appears to have closed, finishing up"); finish_up(); } @@ -142,7 +209,7 @@ void Job::finish_up() { m_state = State::Finished; flush_received_buffers(); - if (m_received_size != 0) { + if (m_buffered_size != 0) { // We have to wait for the client to consume all the downloaded data // before we can actually call `did_finish`. in a normal flow, this should // never be hit since the client is reading as we are writing, unless there diff --git a/Userland/Libraries/LibGemini/Job.h b/Userland/Libraries/LibGemini/Job.h index d8e2167385..d58ab211b8 100644 --- a/Userland/Libraries/LibGemini/Job.h +++ b/Userland/Libraries/LibGemini/Job.h @@ -15,31 +15,31 @@ namespace Gemini { class Job : public Core::NetworkJob { + C_OBJECT(Job); + public: - explicit Job(const GeminiRequest&, OutputStream&); + explicit Job(const GeminiRequest&, Core::Stream::Stream&); virtual ~Job() override; - virtual void start(NonnullRefPtr<Core::Socket>) override = 0; - virtual void shutdown(ShutdownMode) override = 0; + virtual void start(Core::Stream::Socket&) override; + virtual void shutdown(ShutdownMode) override; GeminiResponse* response() { return static_cast<GeminiResponse*>(Core::NetworkJob::response()); } const GeminiResponse* response() const { return static_cast<const GeminiResponse*>(Core::NetworkJob::response()); } + const URL& url() const { return m_request.url(); } + Core::Stream::Socket const* socket() const { return m_socket; } + protected: void finish_up(); void on_socket_connected(); void flush_received_buffers(); - virtual void register_on_ready_to_read(Function<void()>) = 0; - virtual void register_on_ready_to_write(Function<void()>) = 0; - virtual bool can_read_line() const = 0; - virtual String read_line(size_t) = 0; - virtual bool can_read() const = 0; - virtual ByteBuffer receive(size_t) = 0; - virtual bool eof() const = 0; - virtual bool write(ReadonlyBytes) = 0; - virtual bool is_established() const = 0; - virtual bool should_fail_on_empty_payload() const { return false; } - virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; + void register_on_ready_to_read(Function<void()>); + bool can_read_line() const; + String read_line(size_t); + bool can_read() const; + ByteBuffer receive(size_t); + bool write(ReadonlyBytes); enum class State { InStatus, @@ -53,8 +53,8 @@ protected: String m_meta; Vector<ByteBuffer, 2> m_received_buffers; size_t m_received_size { 0 }; - bool m_sent_data { false }; - bool m_should_have_payload { false }; + size_t m_buffered_size { 0 }; + Core::Stream::BufferedSocketBase* m_socket { nullptr }; }; } diff --git a/Userland/Libraries/LibHTTP/CMakeLists.txt b/Userland/Libraries/LibHTTP/CMakeLists.txt index 42cd10fac3..6b729209e1 100644 --- a/Userland/Libraries/LibHTTP/CMakeLists.txt +++ b/Userland/Libraries/LibHTTP/CMakeLists.txt @@ -1,5 +1,4 @@ set(SOURCES - HttpJob.cpp HttpRequest.cpp HttpResponse.cpp HttpsJob.cpp diff --git a/Userland/Libraries/LibHTTP/Forward.h b/Userland/Libraries/LibHTTP/Forward.h index cb8cd3b83a..2c8614a8be 100644 --- a/Userland/Libraries/LibHTTP/Forward.h +++ b/Userland/Libraries/LibHTTP/Forward.h @@ -10,7 +10,6 @@ namespace HTTP { class HttpRequest; class HttpResponse; -class HttpJob; class HttpsJob; class Job; diff --git a/Userland/Libraries/LibHTTP/HttpJob.cpp b/Userland/Libraries/LibHTTP/HttpJob.cpp deleted file mode 100644 index 26c084c846..0000000000 --- a/Userland/Libraries/LibHTTP/HttpJob.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#include <AK/Debug.h> -#include <LibCore/TCPSocket.h> -#include <LibHTTP/HttpJob.h> -#include <LibHTTP/HttpResponse.h> -#include <stdio.h> -#include <unistd.h> - -namespace HTTP { -void HttpJob::start(NonnullRefPtr<Core::Socket> socket) -{ - VERIFY(!m_socket); - m_socket = move(socket); - m_socket->on_error = [this] { - dbgln_if(HTTPJOB_DEBUG, "HttpJob: on_error callback"); - deferred_invoke([this] { - did_fail(Core::NetworkJob::Error::ConnectionFailed); - }); - }; - m_socket->set_idle(false); - if (m_socket->is_connected()) { - dbgln_if(HTTPJOB_DEBUG, "Reusing previous connection for {}", url()); - deferred_invoke([this] { - dbgln_if(HTTPJOB_DEBUG, "HttpJob: on_connected callback"); - on_socket_connected(); - }); - } else { - dbgln_if(HTTPJOB_DEBUG, "Creating new connection for {}", url()); - m_socket->on_connected = [this] { - dbgln_if(HTTPJOB_DEBUG, "HttpJob: on_connected callback"); - on_socket_connected(); - }; - bool success = m_socket->connect(m_request.url().host(), m_request.url().port_or_default()); - if (!success) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ConnectionFailed); - }); - } - }; -} - -void HttpJob::shutdown(ShutdownMode mode) -{ - if (!m_socket) - return; - if (mode == ShutdownMode::CloseSocket) { - m_socket->close(); - } else { - m_socket->on_ready_to_read = nullptr; - m_socket->on_connected = nullptr; - m_socket->set_idle(true); - m_socket = nullptr; - } -} - -void HttpJob::register_on_ready_to_read(Function<void()> callback) -{ - m_socket->on_ready_to_read = [callback = move(callback), this] { - callback(); - // As IODevice so graciously buffers everything, there's a possible - // scenario where it buffers the entire response, and we get stuck waiting - // for select() in the notifier (which will never return). - // So handle this case by exhausting the buffer here. - if (m_socket->can_read_only_from_buffer() && m_state != State::Finished && !has_error()) { - deferred_invoke([this] { - if (m_socket && m_socket->on_ready_to_read) - m_socket->on_ready_to_read(); - }); - } - }; -} - -void HttpJob::register_on_ready_to_write(Function<void()> callback) -{ - // There is no need to wait, the connection is already established - callback(); -} - -bool HttpJob::can_read_line() const -{ - return m_socket->can_read_line(); -} - -String HttpJob::read_line(size_t size) -{ - return m_socket->read_line(size); -} - -ByteBuffer HttpJob::receive(size_t size) -{ - return m_socket->receive(size); -} - -bool HttpJob::can_read() const -{ - return m_socket->can_read(); -} - -bool HttpJob::eof() const -{ - return m_socket->eof(); -} - -bool HttpJob::write(ReadonlyBytes bytes) -{ - return m_socket->write(bytes); -} - -} diff --git a/Userland/Libraries/LibHTTP/HttpJob.h b/Userland/Libraries/LibHTTP/HttpJob.h deleted file mode 100644 index 4f756dbe57..0000000000 --- a/Userland/Libraries/LibHTTP/HttpJob.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#pragma once - -#include <AK/HashMap.h> -#include <LibCore/NetworkJob.h> -#include <LibCore/TCPSocket.h> -#include <LibHTTP/HttpRequest.h> -#include <LibHTTP/HttpResponse.h> -#include <LibHTTP/Job.h> - -namespace HTTP { - -class HttpJob final : public Job { - C_OBJECT(HttpJob) -public: - virtual ~HttpJob() override - { - } - - virtual void start(NonnullRefPtr<Core::Socket>) override; - virtual void shutdown(ShutdownMode) override; - - Core::Socket const* socket() const { return m_socket; } - URL url() const { return m_request.url(); } - -protected: - virtual bool should_fail_on_empty_payload() const override { return false; } - virtual void register_on_ready_to_read(Function<void()>) override; - virtual void register_on_ready_to_write(Function<void()>) override; - virtual bool can_read_line() const override; - virtual String read_line(size_t) override; - virtual bool can_read() const override; - virtual ByteBuffer receive(size_t) override; - virtual bool eof() const override; - virtual bool write(ReadonlyBytes) override; - virtual bool is_established() const override { return true; } - -private: - explicit HttpJob(HttpRequest&& request, OutputStream& output_stream) - : Job(move(request), output_stream) - { - } - - RefPtr<Core::Socket> m_socket; -}; - -} diff --git a/Userland/Libraries/LibHTTP/HttpRequest.cpp b/Userland/Libraries/LibHTTP/HttpRequest.cpp index 034bc56766..bea1a98b18 100644 --- a/Userland/Libraries/LibHTTP/HttpRequest.cpp +++ b/Userland/Libraries/LibHTTP/HttpRequest.cpp @@ -6,8 +6,8 @@ #include <AK/Base64.h> #include <AK/StringBuilder.h> -#include <LibHTTP/HttpJob.h> #include <LibHTTP/HttpRequest.h> +#include <LibHTTP/Job.h> namespace HTTP { diff --git a/Userland/Libraries/LibHTTP/HttpResponse.cpp b/Userland/Libraries/LibHTTP/HttpResponse.cpp index c93f296ed0..fbea69cc39 100644 --- a/Userland/Libraries/LibHTTP/HttpResponse.cpp +++ b/Userland/Libraries/LibHTTP/HttpResponse.cpp @@ -8,9 +8,10 @@ namespace HTTP { -HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers) +HttpResponse::HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, size_t size) : m_code(code) , m_headers(move(headers)) + , m_downloaded_size(size) { } diff --git a/Userland/Libraries/LibHTTP/HttpResponse.h b/Userland/Libraries/LibHTTP/HttpResponse.h index d6e31eed03..c4ba77ec5a 100644 --- a/Userland/Libraries/LibHTTP/HttpResponse.h +++ b/Userland/Libraries/LibHTTP/HttpResponse.h @@ -15,22 +15,24 @@ namespace HTTP { class HttpResponse : public Core::NetworkResponse { public: virtual ~HttpResponse() override; - static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers) + static NonnullRefPtr<HttpResponse> create(int code, HashMap<String, String, CaseInsensitiveStringTraits>&& headers, size_t downloaded_size) { - return adopt_ref(*new HttpResponse(code, move(headers))); + return adopt_ref(*new HttpResponse(code, move(headers), downloaded_size)); } int code() const { return m_code; } + size_t downloaded_size() const { return m_downloaded_size; } StringView reason_phrase() const { return reason_phrase_for_code(m_code); } HashMap<String, String, CaseInsensitiveStringTraits> const& headers() const { return m_headers; } static StringView reason_phrase_for_code(int code); private: - HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&); + HttpResponse(int code, HashMap<String, String, CaseInsensitiveStringTraits>&&, size_t size); int m_code { 0 }; HashMap<String, String, CaseInsensitiveStringTraits> m_headers; + size_t m_downloaded_size { 0 }; }; } diff --git a/Userland/Libraries/LibHTTP/HttpsJob.cpp b/Userland/Libraries/LibHTTP/HttpsJob.cpp index 88a2137fc1..00c9437949 100644 --- a/Userland/Libraries/LibHTTP/HttpsJob.cpp +++ b/Userland/Libraries/LibHTTP/HttpsJob.cpp @@ -1,143 +1,17 @@ /* - * Copyright (c) 2020, the SerenityOS developers. + * Copyright (c) 2020-2022, the SerenityOS developers. * * SPDX-License-Identifier: BSD-2-Clause */ #include <AK/Debug.h> -#include <LibCore/EventLoop.h> -#include <LibHTTP/HttpResponse.h> #include <LibHTTP/HttpsJob.h> -#include <LibTLS/TLSv12.h> -#include <stdio.h> -#include <unistd.h> namespace HTTP { -void HttpsJob::start(NonnullRefPtr<Core::Socket> socket) +void HttpsJob::set_certificate(String certificate, String key) { - VERIFY(!m_socket); - VERIFY(is<TLS::TLSv12>(*socket)); - - m_socket = static_ptr_cast<TLS::TLSv12>(socket); - m_socket->on_tls_error = [&](TLS::AlertDescription error) { - if (error == TLS::AlertDescription::HandshakeFailure) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ProtocolFailed); - }); - } else if (error == TLS::AlertDescription::DecryptError) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ConnectionFailed); - }); - } else { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::TransmissionFailed); - }); - } - }; - m_socket->on_tls_finished = [this] { - if (!m_has_scheduled_finish) - finish_up(); - }; - m_socket->on_tls_certificate_request = [this](auto&) { - if (on_certificate_requested) - on_certificate_requested(*this); - }; - m_socket->set_idle(false); - if (m_socket->is_established()) { - dbgln_if(HTTPSJOB_DEBUG, "Reusing previous connection for {}", url()); - deferred_invoke([this] { on_socket_connected(); }); - } else { - dbgln_if(HTTPSJOB_DEBUG, "Creating a new connection for {}", url()); - m_socket->set_root_certificates(m_override_ca_certificates ? *m_override_ca_certificates : DefaultRootCACertificates::the().certificates()); - m_socket->on_tls_connected = [this] { - dbgln_if(HTTPSJOB_DEBUG, "HttpsJob: on_connected callback"); - on_socket_connected(); - }; - bool success = ((TLS::TLSv12&)*m_socket).connect(m_request.url().host(), m_request.url().port_or_default()); - if (!success) { - deferred_invoke([this] { - return did_fail(Core::NetworkJob::Error::ConnectionFailed); - }); - } - } -} - -void HttpsJob::shutdown(ShutdownMode mode) -{ - if (!m_socket) - return; - if (mode == ShutdownMode::CloseSocket) { - m_socket->close(); - } else { - m_socket->on_tls_ready_to_read = nullptr; - m_socket->on_tls_connected = nullptr; - m_socket->set_on_tls_ready_to_write(nullptr); - m_socket->set_idle(true); - m_socket = nullptr; - } -} - -void HttpsJob::set_certificate(String certificate, String private_key) -{ - if (!m_socket->add_client_key(certificate.bytes(), private_key.bytes())) { - dbgln("LibHTTP: Failed to set a client certificate"); - // FIXME: Do something about this failure - VERIFY_NOT_REACHED(); - } -} - -void HttpsJob::read_while_data_available(Function<IterationDecision()> read) -{ - while (m_socket->can_read()) { - if (read() == IterationDecision::Break) - break; - } -} - -void HttpsJob::register_on_ready_to_read(Function<void()> callback) -{ - m_socket->on_tls_ready_to_read = [callback = move(callback)](auto&) { - callback(); - }; -} - -void HttpsJob::register_on_ready_to_write(Function<void()> callback) -{ - m_socket->set_on_tls_ready_to_write([callback = move(callback)](auto& tls) { - Core::deferred_invoke([&tls] { tls.set_on_tls_ready_to_write(nullptr); }); - callback(); - }); -} - -bool HttpsJob::can_read_line() const -{ - return m_socket->can_read_line(); -} - -String HttpsJob::read_line(size_t size) -{ - return m_socket->read_line(size); -} - -ByteBuffer HttpsJob::receive(size_t size) -{ - return m_socket->read(size); -} - -bool HttpsJob::can_read() const -{ - return m_socket->can_read(); -} - -bool HttpsJob::eof() const -{ - return m_socket->eof(); -} - -bool HttpsJob::write(ReadonlyBytes data) -{ - return m_socket->write(data); + m_received_client_certificates = TLS::TLSv12::parse_pem_certificate(certificate.bytes(), key.bytes()); } } diff --git a/Userland/Libraries/LibHTTP/HttpsJob.h b/Userland/Libraries/LibHTTP/HttpsJob.h index e277cd14c2..32a9bcc7b2 100644 --- a/Userland/Libraries/LibHTTP/HttpsJob.h +++ b/Userland/Libraries/LibHTTP/HttpsJob.h @@ -8,6 +8,7 @@ #include <AK/HashMap.h> #include <LibCore/NetworkJob.h> +#include <LibCore/Stream.h> #include <LibHTTP/HttpRequest.h> #include <LibHTTP/HttpResponse.h> #include <LibHTTP/Job.h> @@ -22,37 +23,20 @@ public: { } - virtual void start(NonnullRefPtr<Core::Socket>) override; - virtual void shutdown(ShutdownMode) override; - void set_certificate(String certificate, String key); - - Core::Socket const* socket() const { return m_socket; } - URL url() const { return m_request.url(); } + bool received_client_certificates() const { return m_received_client_certificates.has_value(); } + Vector<TLS::Certificate> take_client_certificates() const { return m_received_client_certificates.release_value(); } - Function<void(HttpsJob&)> on_certificate_requested; + void set_certificate(String certificate, String key); -protected: - virtual void register_on_ready_to_read(Function<void()>) override; - virtual void register_on_ready_to_write(Function<void()>) override; - virtual bool can_read_line() const override; - virtual String read_line(size_t) override; - virtual bool can_read() const override; - virtual ByteBuffer receive(size_t) override; - virtual bool eof() const override; - virtual bool write(ReadonlyBytes) override; - virtual bool is_established() const override { return m_socket->is_established(); } - virtual bool should_fail_on_empty_payload() const override { return false; } - virtual void read_while_data_available(Function<IterationDecision()>) override; + Function<Vector<TLS::Certificate>()> on_certificate_requested; private: - explicit HttpsJob(HttpRequest&& request, OutputStream& output_stream, const Vector<Certificate>* override_certs = nullptr) + explicit HttpsJob(HttpRequest&& request, Core::Stream::Stream& output_stream) : Job(move(request), output_stream) - , m_override_ca_certificates(override_certs) { } - RefPtr<TLS::TLSv12> m_socket; - const Vector<Certificate>* m_override_ca_certificates { nullptr }; + mutable Optional<Vector<TLS::Certificate>> m_received_client_certificates; }; } diff --git a/Userland/Libraries/LibHTTP/Job.cpp b/Userland/Libraries/LibHTTP/Job.cpp index 00828add3c..5895a121b3 100644 --- a/Userland/Libraries/LibHTTP/Job.cpp +++ b/Userland/Libraries/LibHTTP/Job.cpp @@ -72,7 +72,7 @@ static Optional<ByteBuffer> handle_content_encoding(const ByteBuffer& buf, const return buf; } -Job::Job(HttpRequest&& request, OutputStream& output_stream) +Job::Job(HttpRequest&& request, Core::Stream::Stream& output_stream) : Core::NetworkJob(output_stream) , m_request(move(request)) { @@ -82,6 +82,29 @@ Job::~Job() { } +void Job::start(Core::Stream::Socket& socket) +{ + VERIFY(!m_socket); + m_socket = static_cast<Core::Stream::BufferedSocketBase*>(&socket); + dbgln_if(HTTPJOB_DEBUG, "Reusing previous connection for {}", url()); + deferred_invoke([this] { + dbgln_if(HTTPJOB_DEBUG, "HttpJob: on_connected callback"); + on_socket_connected(); + }); +} + +void Job::shutdown(ShutdownMode mode) +{ + if (!m_socket) + return; + if (mode == ShutdownMode::CloseSocket) { + m_socket->close(); + } else { + m_socket->on_ready_to_read = nullptr; + m_socket = nullptr; + } +} + void Job::flush_received_buffers() { if (!m_can_stream_response || m_buffered_size == 0) @@ -89,7 +112,19 @@ void Job::flush_received_buffers() dbgln_if(JOB_DEBUG, "Job: Flushing received buffers: have {} bytes in {} buffers for {}", m_buffered_size, m_received_buffers.size(), m_request.url()); for (size_t i = 0; i < m_received_buffers.size(); ++i) { auto& payload = m_received_buffers[i]; - auto written = do_write(payload); + auto result = do_write(payload); + if (result.is_error()) { + if (!result.error().is_errno()) { + dbgln_if(JOB_DEBUG, "Job: Failed to flush received buffers: {}", result.error()); + continue; + } + if (result.error().code() == EINTR) { + i--; + continue; + } + break; + } + auto written = result.release_value(); m_buffered_size -= written; if (written == payload.size()) { // FIXME: Make this a take-first-friendly object? @@ -104,23 +139,63 @@ void Job::flush_received_buffers() dbgln_if(JOB_DEBUG, "Job: Flushing received buffers done: have {} bytes in {} buffers for {}", m_buffered_size, m_received_buffers.size(), m_request.url()); } -void Job::on_socket_connected() +void Job::register_on_ready_to_read(Function<void()> callback) { - register_on_ready_to_write([&] { - if (m_sent_data) - return; - m_sent_data = true; - auto raw_request = m_request.to_raw_request(); + m_socket->on_ready_to_read = [this, callback = move(callback)] { + callback(); + + // As `m_socket` is a buffered object, we might not get notifications for data in the buffer + // so exhaust the buffer to ensure we don't end up waiting forever. + if (MUST(m_socket->can_read_without_blocking()) && m_state != State::Finished && !has_error()) { + deferred_invoke([this] { + if (m_socket && m_socket->on_ready_to_read) + m_socket->on_ready_to_read(); + }); + } + }; +} - if constexpr (JOB_DEBUG) { - dbgln("Job: raw_request:"); - dbgln("{}", String::copy(raw_request)); +String Job::read_line(size_t size) +{ + auto buffer = ByteBuffer::create_uninitialized(size).release_value_but_fixme_should_propagate_errors(); + auto nread = m_socket->read_until(buffer, "\r\n"sv).release_value_but_fixme_should_propagate_errors(); + return String::copy(buffer.span().slice(0, nread)); +} + +ByteBuffer Job::receive(size_t size) +{ + if (size == 0) + return {}; + + auto buffer = ByteBuffer::create_uninitialized(size).release_value_but_fixme_should_propagate_errors(); + size_t nread; + do { + auto result = m_socket->read(buffer); + if (result.is_error() && result.error().is_errno() && result.error().code() == EINTR) + continue; + if (result.is_error()) { + dbgln_if(JOB_DEBUG, "Failed while reading: {}", result.error()); + VERIFY_NOT_REACHED(); } + nread = MUST(result); + break; + } while (true); + return buffer.slice(0, nread); +} + +void Job::on_socket_connected() +{ + auto raw_request = m_request.to_raw_request(); + + if constexpr (JOB_DEBUG) { + dbgln("Job: raw_request:"); + dbgln("{}", String::copy(raw_request)); + } + + bool success = m_socket->write_or_error(raw_request); + if (!success) + deferred_invoke([this] { did_fail(Core::NetworkJob::Error::TransmissionFailed); }); - bool success = write(raw_request); - if (!success) - deferred_invoke([this] { did_fail(Core::NetworkJob::Error::TransmissionFailed); }); - }); register_on_ready_to_read([&] { dbgln_if(JOB_DEBUG, "Ready to read for {}, state = {}, cancelled = {}", m_request.url(), to_underlying(m_state), is_cancelled()); if (is_cancelled()) @@ -133,12 +208,16 @@ void Job::on_socket_connected() return; } - if (eof()) + if (m_socket->is_eof()) { + dbgln_if(JOB_DEBUG, "Read failure: Actually EOF!"); return deferred_invoke([this] { did_fail(Core::NetworkJob::Error::ProtocolFailed); }); + } - if (m_state == State::InStatus) { - if (!can_read_line()) { + while (m_state == State::InStatus) { + if (!MUST(m_socket->can_read_line())) { dbgln_if(JOB_DEBUG, "Job {} cannot read line", m_request.url()); + auto buf = receive(64); + dbgln_if(JOB_DEBUG, "{} bytes was read", buf.bytes().size()); return; } auto line = read_line(PAGE_SIZE); @@ -159,11 +238,14 @@ void Job::on_socket_connected() } m_code = code.value(); m_state = State::InHeaders; - return; + if (!MUST(m_socket->can_read_without_blocking())) + return; } - if (m_state == State::InHeaders || m_state == State::Trailers) { - if (!can_read_line()) + while (m_state == State::InHeaders || m_state == State::Trailers) { + if (!MUST(m_socket->can_read_line())) { + dbgln_if(JOB_DEBUG, "Can't read lines anymore :("); return; + } // There's no max limit defined on headers, but for our sanity, let's limit it to 32K. auto line = read_line(32 * KiB); if (line.is_null()) { @@ -179,14 +261,13 @@ void Job::on_socket_connected() if (line.is_empty()) { if (m_state == State::Trailers) { return finish_up(); - } else { - if (on_headers_received) { - if (!m_set_cookie_headers.is_empty()) - m_headers.set("Set-Cookie", JsonArray { m_set_cookie_headers }.to_string()); - on_headers_received(m_headers, m_code > 0 ? m_code : Optional<u32> {}); - } - m_state = State::InBody; } + if (on_headers_received) { + if (!m_set_cookie_headers.is_empty()) + m_headers.set("Set-Cookie", JsonArray { m_set_cookie_headers }.to_string()); + on_headers_received(m_headers, m_code > 0 ? m_code : Optional<u32> {}); + } + m_state = State::InBody; // We've reached the end of the headers, there's a possibility that the server // responds with nothing (content-length = 0 with normal encoding); if that's the case, @@ -195,7 +276,9 @@ void Job::on_socket_connected() if (result.value() == 0 && !m_headers.get("Transfer-Encoding"sv).value_or(""sv).view().trim_whitespace().equals_ignoring_case("chunked"sv)) return finish_up(); } - return; + if (!MUST(m_socket->can_read_line())) + return; + break; } auto parts = line.split_view(':'); if (parts.is_empty()) { @@ -223,9 +306,9 @@ void Job::on_socket_connected() if (name.equals_ignoring_case("Set-Cookie")) { dbgln_if(JOB_DEBUG, "Job: Received Set-Cookie header: '{}'", value); m_set_cookie_headers.append(move(value)); - return; - } - if (auto existing_value = m_headers.get(name); existing_value.has_value()) { + if (!MUST(m_socket->can_read_without_blocking())) + return; + } else if (auto existing_value = m_headers.get(name); existing_value.has_value()) { StringBuilder builder; builder.append(existing_value.value()); builder.append(','); @@ -244,12 +327,16 @@ void Job::on_socket_connected() m_content_length = length.value(); } dbgln_if(JOB_DEBUG, "Job: [{}] = '{}'", name, value); - return; + if (!MUST(m_socket->can_read_without_blocking())) { + dbgln_if(JOB_DEBUG, "Can't read headers anymore, byebye :("); + return; + } } VERIFY(m_state == State::InBody); - VERIFY(can_read()); + if (!MUST(m_socket->can_read_without_blocking())) + return; - read_while_data_available([&] { + while (MUST(m_socket->can_read_without_blocking())) { auto read_size = 64 * KiB; if (m_current_chunk_remaining_size.has_value()) { read_chunk_size:; @@ -260,16 +347,16 @@ void Job::on_socket_connected() if (m_should_read_chunk_ending_line) { VERIFY(size_data.is_empty()); m_should_read_chunk_ending_line = false; - return IterationDecision::Continue; + continue; } auto size_lines = size_data.view().lines(); dbgln_if(JOB_DEBUG, "Job: Received a chunk with size '{}'", size_data); if (size_lines.size() == 0) { - if (!eof()) - return AK::IterationDecision::Break; + if (!m_socket->is_eof()) + break; dbgln("Job: Reached end of stream"); finish_up(); - return IterationDecision::Break; + break; } else { auto chunk = size_lines[0].split_view(';', true); String size_string = chunk[0]; @@ -278,7 +365,7 @@ void Job::on_socket_connected() if (*endptr) { // invalid number deferred_invoke([this] { did_fail(Core::NetworkJob::Error::TransmissionFailed); }); - return IterationDecision::Break; + break; } if (size == 0) { // This is the last chunk @@ -323,19 +410,14 @@ void Job::on_socket_connected() } } + if (!MUST(m_socket->can_read_without_blocking())) + break; + dbgln_if(JOB_DEBUG, "Waiting for payload for {}", m_request.url()); auto payload = receive(read_size); - dbgln_if(JOB_DEBUG, "Received {} bytes of payload from {}", payload.size(), m_request.url()); - if (payload.is_empty()) { - if (eof()) { - finish_up(); - return IterationDecision::Break; - } - - if (should_fail_on_empty_payload()) { - deferred_invoke([this] { did_fail(Core::NetworkJob::Error::ProtocolFailed); }); - return IterationDecision::Break; - } + if (payload.is_empty() && m_socket->is_eof()) { + finish_up(); + break; } bool read_everything = false; @@ -357,7 +439,7 @@ void Job::on_socket_connected() if (read_everything) { VERIFY(m_received_size <= m_content_length.value()); finish_up(); - return IterationDecision::Break; + break; } if (m_current_chunk_remaining_size.has_value()) { @@ -369,12 +451,12 @@ void Job::on_socket_connected() if (m_current_chunk_total_size.value() == 0) { m_state = State::Trailers; - return IterationDecision::Break; + break; } // we've read everything, now let's get the next chunk size = -1; - if (can_read_line()) { + if (MUST(m_socket->can_read_line())) { auto line = read_line(PAGE_SIZE); VERIFY(line.is_empty()); } else { @@ -383,11 +465,9 @@ void Job::on_socket_connected() } m_current_chunk_remaining_size = size; } + } - return IterationDecision::Continue; - }); - - if (!is_established()) { + if (!m_socket->is_open()) { dbgln_if(JOB_DEBUG, "Connection appears to have closed, finishing up"); finish_up(); } @@ -443,7 +523,7 @@ void Job::finish_up() } m_has_scheduled_finish = true; - auto response = HttpResponse::create(m_code, move(m_headers)); + auto response = HttpResponse::create(m_code, move(m_headers), m_received_size); deferred_invoke([this, response = move(response)] { // If the server responded with "Connection: close", close the connection // as the server may or may not want to close the socket. diff --git a/Userland/Libraries/LibHTTP/Job.h b/Userland/Libraries/LibHTTP/Job.h index 1540ed9587..6c45fabae9 100644 --- a/Userland/Libraries/LibHTTP/Job.h +++ b/Userland/Libraries/LibHTTP/Job.h @@ -17,12 +17,17 @@ namespace HTTP { class Job : public Core::NetworkJob { + C_OBJECT(Job); + public: - explicit Job(HttpRequest&&, OutputStream&); + explicit Job(HttpRequest&&, Core::Stream::Stream&); virtual ~Job() override; - virtual void start(NonnullRefPtr<Core::Socket>) override = 0; - virtual void shutdown(ShutdownMode) override = 0; + virtual void start(Core::Stream::Socket&) override; + virtual void shutdown(ShutdownMode) override; + + Core::Stream::Socket const* socket() const { return m_socket; } + URL url() const { return m_request.url(); } HttpResponse* response() { return static_cast<HttpResponse*>(Core::NetworkJob::response()); } const HttpResponse* response() const { return static_cast<const HttpResponse*>(Core::NetworkJob::response()); } @@ -31,18 +36,10 @@ protected: void finish_up(); void on_socket_connected(); void flush_received_buffers(); - virtual void register_on_ready_to_read(Function<void()>) = 0; - virtual void register_on_ready_to_write(Function<void()>) = 0; - virtual bool can_read_line() const = 0; - virtual String read_line(size_t) = 0; - virtual bool can_read() const = 0; - virtual ByteBuffer receive(size_t) = 0; - virtual bool eof() const = 0; - virtual bool write(ReadonlyBytes) = 0; - virtual bool is_established() const = 0; - virtual bool should_fail_on_empty_payload() const { return true; } - virtual void read_while_data_available(Function<IterationDecision()> read) { read(); }; - virtual void timer_event(Core::TimerEvent&) override; + void register_on_ready_to_read(Function<void()>); + String read_line(size_t); + ByteBuffer receive(size_t); + void timer_event(Core::TimerEvent&) override; enum class State { InStatus, @@ -54,13 +51,13 @@ protected: HttpRequest m_request; State m_state { State::InStatus }; + Core::Stream::BufferedSocketBase* m_socket { nullptr }; int m_code { -1 }; HashMap<String, String, CaseInsensitiveStringTraits> m_headers; Vector<String> m_set_cookie_headers; Vector<ByteBuffer, 2> m_received_buffers; size_t m_buffered_size { 0 }; size_t m_received_size { 0 }; - bool m_sent_data { 0 }; Optional<u32> m_content_length; Optional<ssize_t> m_current_chunk_remaining_size; Optional<size_t> m_current_chunk_total_size; diff --git a/Userland/Libraries/LibIMAP/Client.cpp b/Userland/Libraries/LibIMAP/Client.cpp index d33429dc52..318239205e 100644 --- a/Userland/Libraries/LibIMAP/Client.cpp +++ b/Userland/Libraries/LibIMAP/Client.cpp @@ -10,20 +10,9 @@ namespace IMAP { -Client::Client(StringView host, u16 port, NonnullRefPtr<TLS::TLSv12> socket) - : m_host(host) - , m_port(port) - , m_tls(true) - , m_tls_socket(move(socket)) - , m_connect_pending(Promise<Empty>::construct()) -{ - setup_callbacks(); -} - Client::Client(StringView host, u16 port, NonnullOwnPtr<Core::Stream::Socket> socket) : m_host(host) , m_port(port) - , m_tls(false) , m_socket(move(socket)) , m_connect_pending(Promise<Empty>::construct()) { @@ -33,9 +22,7 @@ Client::Client(StringView host, u16 port, NonnullOwnPtr<Core::Stream::Socket> so Client::Client(Client&& other) : m_host(other.m_host) , m_port(other.m_port) - , m_tls(other.m_tls) , m_socket(move(other.m_socket)) - , m_tls_socket(move(other.m_tls_socket)) , m_connect_pending(move(other.m_connect_pending)) { setup_callbacks(); @@ -43,42 +30,21 @@ Client::Client(Client&& other) void Client::setup_callbacks() { - if (m_tls) { - m_tls_socket->on_tls_ready_to_read = [&](TLS::TLSv12&) { - auto maybe_error = on_tls_ready_to_receive(); - if (maybe_error.is_error()) { - dbgln("Error receiving from the socket: {}", maybe_error.error()); - close(); - } - }; - - } else { - m_socket->on_ready_to_read = [&] { - auto maybe_error = on_ready_to_receive(); - if (maybe_error.is_error()) { - dbgln("Error receiving from the socket: {}", maybe_error.error()); - close(); - } - }; - } + m_socket->on_ready_to_read = [&] { + auto maybe_error = on_ready_to_receive(); + if (maybe_error.is_error()) { + dbgln("Error receiving from the socket: {}", maybe_error.error()); + close(); + } + }; } ErrorOr<NonnullOwnPtr<Client>> Client::connect_tls(StringView host, u16 port) { - auto tls_socket = TLS::TLSv12::construct(nullptr); - tls_socket->set_root_certificates(DefaultRootCACertificates::the().certificates()); - - tls_socket->on_tls_error = [&](TLS::AlertDescription alert) { - dbgln("failed: {}", alert_name(alert)); - }; - tls_socket->on_tls_connected = [&] { - dbgln("connected"); - }; + auto tls_socket = TRY(TLS::TLSv12::connect(host, port)); + dbgln("connecting to {}:{}", host, port); - auto success = tls_socket->connect(host, port); - dbgln("connecting to {}:{} {}", host, port, success); - - return adopt_nonnull_own_or_enomem(new (nothrow) Client(host, port, tls_socket)); + return adopt_nonnull_own_or_enomem(new (nothrow) Client(host, port, move(tls_socket))); } ErrorOr<NonnullOwnPtr<Client>> Client::connect_plaintext(StringView host, u16 port) @@ -88,34 +54,6 @@ ErrorOr<NonnullOwnPtr<Client>> Client::connect_plaintext(StringView host, u16 po return adopt_nonnull_own_or_enomem(new (nothrow) Client(host, port, move(socket))); } -ErrorOr<void> Client::on_tls_ready_to_receive() -{ - if (!m_tls_socket->can_read()) - return {}; - auto data = m_tls_socket->read(); - // FIXME: Make TLSv12 return the actual error instead of returning a bogus - // one here. - if (!data.has_value()) - return Error::from_errno(EIO); - - // Once we get server hello we can start sending - if (m_connect_pending) { - m_connect_pending->resolve({}); - m_connect_pending.clear(); - return {}; - } - - m_buffer += data.value(); - if (m_buffer[m_buffer.size() - 1] == '\n') { - // Don't try parsing until we have a complete line. - auto response = m_parser.parse(move(m_buffer), m_expecting_response); - MUST(handle_parsed_response(move(response))); - m_buffer.clear(); - } - - return {}; -} - ErrorOr<void> Client::on_ready_to_receive() { if (!TRY(m_socket->can_read_without_blocking())) @@ -208,13 +146,8 @@ static ReadonlyBytes command_byte_buffer(CommandType command) ErrorOr<void> Client::send_raw(StringView data) { - if (m_tls) { - m_tls_socket->write(data.bytes()); - m_tls_socket->write("\r\n"sv.bytes()); - } else { - TRY(m_socket->write(data.bytes())); - TRY(m_socket->write("\r\n"sv.bytes())); - } + TRY(m_socket->write(data.bytes())); + TRY(m_socket->write("\r\n"sv.bytes())); return {}; } @@ -496,16 +429,12 @@ RefPtr<Promise<Optional<SolidResponse>>> Client::copy(Sequence sequence_set, Str void Client::close() { - if (m_tls) { - m_tls_socket->close(); - } else { - m_socket->close(); - } + m_socket->close(); } bool Client::is_open() { - return m_tls ? m_tls_socket->is_open() : m_socket->is_open(); + return m_socket->is_open(); } } diff --git a/Userland/Libraries/LibIMAP/Client.h b/Userland/Libraries/LibIMAP/Client.h index 9796dab86e..2e4096ad56 100644 --- a/Userland/Libraries/LibIMAP/Client.h +++ b/Userland/Libraries/LibIMAP/Client.h @@ -60,7 +60,6 @@ public: Function<void(ResponseData&&)> unrequested_response_callback; private: - Client(StringView host, u16 port, NonnullRefPtr<TLS::TLSv12>); Client(StringView host, u16 port, NonnullOwnPtr<Core::Stream::Socket>); void setup_callbacks(); @@ -73,11 +72,7 @@ private: StringView m_host; u16 m_port; - bool m_tls; - // FIXME: Convert this to a single `NonnullOwnPtr<Core::Stream::Socket>` - // once `TLS::TLSv12` is converted to a `Socket` as well. - OwnPtr<Core::Stream::Socket> m_socket; - RefPtr<TLS::TLSv12> m_tls_socket; + NonnullOwnPtr<Core::Stream::Socket> m_socket; RefPtr<Promise<Empty>> m_connect_pending {}; int m_current_command = 1; diff --git a/Userland/Libraries/LibProtocol/Request.cpp b/Userland/Libraries/LibProtocol/Request.cpp index 741ff6956b..35b7cbb2c7 100644 --- a/Userland/Libraries/LibProtocol/Request.cpp +++ b/Userland/Libraries/LibProtocol/Request.cpp @@ -20,72 +20,61 @@ bool Request::stop() return m_client->stop_request({}, *this); } -void Request::stream_into(OutputStream& stream) +template<typename T> +void Request::stream_into_impl(T& stream) { VERIFY(!m_internal_stream_data); - auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read); - - m_internal_stream_data = make<InternalStreamData>(fd()); - m_internal_stream_data->read_notifier = notifier; + m_internal_stream_data = make<InternalStreamData>(MUST(Core::Stream::File::adopt_fd(fd(), Core::Stream::OpenMode::Read))); + m_internal_stream_data->read_notifier = Core::Notifier::construct(fd(), Core::Notifier::Read); auto user_on_finish = move(on_finish); on_finish = [this](auto success, auto total_size) { m_internal_stream_data->success = success; m_internal_stream_data->total_size = total_size; m_internal_stream_data->request_done = true; + m_internal_stream_data->on_finish(); }; - notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] { - constexpr size_t buffer_size = 4096; - static char buf[buffer_size]; - auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size }); - if (!stream.write_or_error({ buf, nread })) { - // FIXME: What do we do here? - TODO(); + m_internal_stream_data->on_finish = [this, user_on_finish = move(user_on_finish)] { + if (!m_internal_stream_data->user_finish_called && m_internal_stream_data->read_stream->is_eof()) { + m_internal_stream_data->user_finish_called = true; + user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size); } - - if (m_internal_stream_data->read_stream.eof() && m_internal_stream_data->request_done) { + }; + m_internal_stream_data->read_notifier->on_ready_to_read = [this, &stream] { + constexpr size_t buffer_size = 16 * KiB; + static char buf[buffer_size]; + do { + auto result = m_internal_stream_data->read_stream->read({ buf, buffer_size }); + if (result.is_error() && (!result.error().is_errno() || (result.error().is_errno() && result.error().code() != EINTR))) + break; + if (result.is_error()) + continue; + auto nread = result.value(); + if (!stream.write_or_error({ buf, nread })) { + // FIXME: What do we do here? + TODO(); + } + if (nread == 0) + break; + } while (true); + + if (m_internal_stream_data->read_stream->is_eof() && m_internal_stream_data->request_done) { m_internal_stream_data->read_notifier->close(); - user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size); - } else { - m_internal_stream_data->read_stream.handle_any_error(); + m_internal_stream_data->on_finish(); } }; } void Request::stream_into(Core::Stream::Stream& stream) { - VERIFY(!m_internal_stream_data); - - auto notifier = Core::Notifier::construct(fd(), Core::Notifier::Read); - - m_internal_stream_data = make<InternalStreamData>(fd()); - m_internal_stream_data->read_notifier = notifier; - - auto user_on_finish = move(on_finish); - on_finish = [this](auto success, auto total_size) { - m_internal_stream_data->success = success; - m_internal_stream_data->total_size = total_size; - m_internal_stream_data->request_done = true; - }; - - notifier->on_ready_to_read = [this, &stream, user_on_finish = move(user_on_finish)] { - constexpr size_t buffer_size = 4096; - static char buf[buffer_size]; - auto nread = m_internal_stream_data->read_stream.read({ buf, buffer_size }); - if (!stream.write_or_error({ buf, nread })) { - // FIXME: What do we do here? - TODO(); - } + stream_into_impl(stream); +} - if (m_internal_stream_data->read_stream.eof() && m_internal_stream_data->request_done) { - m_internal_stream_data->read_notifier->close(); - user_on_finish(m_internal_stream_data->success, m_internal_stream_data->total_size); - } else { - m_internal_stream_data->read_stream.handle_any_error(); - } - }; +void Request::stream_into(OutputStream& stream) +{ + stream_into_impl(stream); } void Request::set_should_buffer_all_input(bool value) @@ -102,7 +91,7 @@ void Request::set_should_buffer_all_input(bool value) VERIFY(!m_internal_stream_data); VERIFY(!m_internal_buffered_data); VERIFY(on_buffered_request_finish); // Not having this set makes no sense. - m_internal_buffered_data = make<InternalBufferedData>(fd()); + m_internal_buffered_data = make<InternalBufferedData>(); m_should_buffer_all_input = true; on_headers_received = [this](auto& headers, auto response_code) { diff --git a/Userland/Libraries/LibProtocol/Request.h b/Userland/Libraries/LibProtocol/Request.h index cf5cf51262..b649a208ae 100644 --- a/Userland/Libraries/LibProtocol/Request.h +++ b/Userland/Libraries/LibProtocol/Request.h @@ -62,6 +62,9 @@ public: private: explicit Request(RequestClient&, i32 request_id); + template<typename T> + void stream_into_impl(T&); + WeakPtr<RequestClient> m_client; int m_request_id { -1 }; RefPtr<Core::Notifier> m_write_notifier; @@ -69,28 +72,24 @@ private: bool m_should_buffer_all_input { false }; struct InternalBufferedData { - InternalBufferedData(int fd) - : read_stream(fd) - { - } - - InputFileStream read_stream; DuplexMemoryStream payload_stream; HashMap<String, String, CaseInsensitiveStringTraits> response_headers; Optional<u32> response_code; }; struct InternalStreamData { - InternalStreamData(int fd) - : read_stream(fd) + InternalStreamData(NonnullOwnPtr<Core::Stream::Stream> stream) + : read_stream(move(stream)) { } - InputFileStream read_stream; + NonnullOwnPtr<Core::Stream::Stream> read_stream; RefPtr<Core::Notifier> read_notifier; bool success; u32 total_size { 0 }; bool request_done { false }; + Function<void()> on_finish {}; + bool user_finish_called { false }; }; OwnPtr<InternalBufferedData> m_internal_buffered_data; diff --git a/Userland/Libraries/LibTLS/Handshake.cpp b/Userland/Libraries/LibTLS/Handshake.cpp index df45a157ed..4fb065b677 100644 --- a/Userland/Libraries/LibTLS/Handshake.cpp +++ b/Userland/Libraries/LibTLS/Handshake.cpp @@ -202,8 +202,8 @@ ssize_t TLSv12::handle_handshake_finished(ReadonlyBytes buffer, WritePacketStage m_handshake_timeout_timer = nullptr; } - if (on_tls_ready_to_write) - on_tls_ready_to_write(*this); + if (on_connected) + on_connected(); return index + size; } diff --git a/Userland/Libraries/LibTLS/Record.cpp b/Userland/Libraries/LibTLS/Record.cpp index 77dfc71faa..2063e168d9 100644 --- a/Userland/Libraries/LibTLS/Record.cpp +++ b/Userland/Libraries/LibTLS/Record.cpp @@ -7,6 +7,7 @@ #include <AK/Debug.h> #include <AK/Endian.h> #include <AK/MemoryStream.h> +#include <LibCore/EventLoop.h> #include <LibCore/Timer.h> #include <LibCrypto/PK/Code/EMSA_PSS.h> #include <LibTLS/TLSv12.h> @@ -32,7 +33,7 @@ void TLSv12::alert(AlertLevel level, AlertDescription code) { auto the_alert = build_alert(level == AlertLevel::Critical, (u8)code); write_packet(the_alert); - flush(); + MUST(flush()); } void TLSv12::write_packet(ByteBuffer& packet) @@ -41,7 +42,7 @@ void TLSv12::write_packet(ByteBuffer& packet) if (m_context.connection_status > ConnectionStatus::Disconnected) { if (!m_has_scheduled_write_flush && !immediate) { dbgln_if(TLS_DEBUG, "Scheduling write of {}", m_context.tls_buffer.size()); - deferred_invoke([this] { write_into_socket(); }); + Core::deferred_invoke([this] { write_into_socket(); }); m_has_scheduled_write_flush = true; } else { // multiple packet are available, let's flush some out @@ -540,15 +541,17 @@ ssize_t TLSv12::handle_message(ReadonlyBytes buffer) if (code == (u8)AlertDescription::CloseNotify) { res += 2; alert(AlertLevel::Critical, AlertDescription::CloseNotify); - m_context.connection_finished = true; if (!m_context.cipher_spec_set) { // AWS CloudFront hits this. dbgln("Server sent a close notify and we haven't agreed on a cipher suite. Treating it as a handshake failure."); m_context.critical_error = (u8)AlertDescription::HandshakeFailure; try_disambiguate_error(); } + m_context.close_notify = true; } m_context.error_code = (Error)code; + check_connection_state(false); + notify_client_for_app_data(); // Give the user one more chance to observe the EOF } break; default: diff --git a/Userland/Libraries/LibTLS/Socket.cpp b/Userland/Libraries/LibTLS/Socket.cpp index 44434591df..f63900d790 100644 --- a/Userland/Libraries/LibTLS/Socket.cpp +++ b/Userland/Libraries/LibTLS/Socket.cpp @@ -6,6 +6,7 @@ #include <AK/Debug.h> #include <LibCore/DateTime.h> +#include <LibCore/EventLoop.h> #include <LibCore/Timer.h> #include <LibCrypto/PK/Code/EMSA_PSS.h> #include <LibTLS/TLSv12.h> @@ -17,24 +18,18 @@ constexpr static size_t MaximumApplicationDataChunkSize = 16 * KiB; namespace TLS { -Optional<ByteBuffer> TLSv12::read() +ErrorOr<size_t> TLSv12::read(Bytes bytes) { - if (m_context.application_buffer.size()) { - auto buf = move(m_context.application_buffer); - return { move(buf) }; + m_eof = false; + auto size_to_read = min(bytes.size(), m_context.application_buffer.size()); + if (size_to_read == 0) { + m_eof = true; + return 0; } - return {}; -} -ByteBuffer TLSv12::read(size_t max_size) -{ - if (m_context.application_buffer.size()) { - auto length = min(m_context.application_buffer.size(), max_size); - auto buf = m_context.application_buffer.slice(0, length); - m_context.application_buffer = m_context.application_buffer.slice(length, m_context.application_buffer.size() - length); - return buf; - } - return {}; + m_context.application_buffer.span().slice(0, size_to_read).copy_to(bytes); + m_context.application_buffer = m_context.application_buffer.slice(size_to_read, m_context.application_buffer.size() - size_to_read); + return size_to_read; } String TLSv12::read_line(size_t max_size) @@ -57,99 +52,112 @@ String TLSv12::read_line(size_t max_size) return line; } -bool TLSv12::write(ReadonlyBytes buffer) +ErrorOr<size_t> TLSv12::write(ReadonlyBytes bytes) { if (m_context.connection_status != ConnectionStatus::Established) { dbgln_if(TLS_DEBUG, "write request while not connected"); - return false; + return AK::Error::from_string_literal("TLS write request while not connected"); } - for (size_t offset = 0; offset < buffer.size(); offset += MaximumApplicationDataChunkSize) { - PacketBuilder builder { MessageType::ApplicationData, m_context.options.version, buffer.size() - offset }; - builder.append(buffer.slice(offset, min(buffer.size() - offset, MaximumApplicationDataChunkSize))); + for (size_t offset = 0; offset < bytes.size(); offset += MaximumApplicationDataChunkSize) { + PacketBuilder builder { MessageType::ApplicationData, m_context.options.version, bytes.size() - offset }; + builder.append(bytes.slice(offset, min(bytes.size() - offset, MaximumApplicationDataChunkSize))); auto packet = builder.build(); update_packet(packet); write_packet(packet); } - return true; + return bytes.size(); } -bool TLSv12::connect(const String& hostname, int port) +ErrorOr<NonnullOwnPtr<TLSv12>> TLSv12::connect(const String& host, u16 port, Options options) { - set_sni(hostname); - return Core::Socket::connect(hostname, port); + Core::EventLoop loop; + OwnPtr<Core::Stream::Socket> tcp_socket = TRY(Core::Stream::TCPSocket::connect(host, port)); + TRY(tcp_socket->set_blocking(false)); + auto tls_socket = make<TLSv12>(move(tcp_socket), move(options)); + tls_socket->set_sni(host); + tls_socket->on_connected = [&] { + loop.quit(0); + }; + tls_socket->on_tls_error = [&](auto alert) { + loop.quit(256 - to_underlying(alert)); + }; + auto result = loop.exec(); + if (result == 0) + return tls_socket; + + tls_socket->try_disambiguate_error(); + // FIXME: Should return richer information here. + return AK::Error::from_string_literal(alert_name(static_cast<AlertDescription>(256 - result))); } -bool TLSv12::common_connect(const struct sockaddr* saddr, socklen_t length) +ErrorOr<NonnullOwnPtr<TLSv12>> TLSv12::connect(const String& host, Core::Stream::Socket& underlying_stream, Options options) { - if (m_context.critical_error) - return false; + StreamVariantType socket { &underlying_stream }; + auto tls_socket = make<TLSv12>(&underlying_stream, move(options)); + tls_socket->set_sni(host); + Core::EventLoop loop; + tls_socket->on_connected = [&] { + loop.quit(0); + }; + tls_socket->on_tls_error = [&](auto alert) { + loop.quit(256 - to_underlying(alert)); + }; + auto result = loop.exec(); + if (result == 0) + return tls_socket; - if (Core::Socket::is_connected()) { - if (is_established()) { - VERIFY_NOT_REACHED(); - } else { - Core::Socket::close(); // reuse? - } - } + tls_socket->try_disambiguate_error(); + // FIXME: Should return richer information here. + return AK::Error::from_string_literal(alert_name(static_cast<AlertDescription>(256 - result))); +} - Core::Socket::on_connected = [this] { - Core::Socket::on_ready_to_read = [this] { - read_from_socket(); +void TLSv12::setup_connection() +{ + Core::deferred_invoke([this] { + auto& stream = underlying_stream(); + stream.on_ready_to_read = [this] { + auto result = read_from_socket(); + if (result.is_error()) + dbgln("Read error: {}", result.error()); }; + m_handshake_timeout_timer = Core::Timer::create_single_shot( + m_max_wait_time_for_handshake_in_seconds * 1000, [&] { + dbgln("Handshake timeout :("); + auto timeout_diff = Core::DateTime::now().timestamp() - m_context.handshake_initiation_timestamp; + // If the timeout duration was actually within the max wait time (with a margin of error), + // we're not operating slow, so the server timed out. + // otherwise, it's our fault that the negotiation is taking too long, so extend the timer :P + if (timeout_diff < m_max_wait_time_for_handshake_in_seconds + 1) { + // The server did not respond fast enough, + // time the connection out. + alert(AlertLevel::Critical, AlertDescription::UserCanceled); + m_context.tls_buffer.clear(); + m_context.error_code = Error::TimedOut; + m_context.critical_error = (u8)Error::TimedOut; + check_connection_state(false); // Notify the client. + } else { + // Extend the timer, we are too slow. + m_handshake_timeout_timer->restart(m_max_wait_time_for_handshake_in_seconds * 1000); + } + }); auto packet = build_hello(); write_packet(packet); - - deferred_invoke([&] { - m_handshake_timeout_timer = Core::Timer::create_single_shot( - m_max_wait_time_for_handshake_in_seconds * 1000, [&] { - auto timeout_diff = Core::DateTime::now().timestamp() - m_context.handshake_initiation_timestamp; - // If the timeout duration was actually within the max wait time (with a margin of error), - // we're not operating slow, so the server timed out. - // otherwise, it's our fault that the negotiation is taking too long, so extend the timer :P - if (timeout_diff < m_max_wait_time_for_handshake_in_seconds + 1) { - // The server did not respond fast enough, - // time the connection out. - alert(AlertLevel::Critical, AlertDescription::UserCanceled); - m_context.connection_finished = true; - m_context.tls_buffer.clear(); - m_context.error_code = Error::TimedOut; - m_context.critical_error = (u8)Error::TimedOut; - check_connection_state(false); // Notify the client. - } else { - // Extend the timer, we are too slow. - m_handshake_timeout_timer->restart(m_max_wait_time_for_handshake_in_seconds * 1000); - } - }, - this); - write_into_socket(); - m_handshake_timeout_timer->start(); - m_context.handshake_initiation_timestamp = Core::DateTime::now().timestamp(); - }); - m_has_scheduled_write_flush = true; - - if (on_tls_connected) - on_tls_connected(); - }; - bool success = Core::Socket::common_connect(saddr, length); - if (!success) - return false; - - return true; + write_into_socket(); + m_handshake_timeout_timer->start(); + m_context.handshake_initiation_timestamp = Core::DateTime::now().timestamp(); + }); + m_has_scheduled_write_flush = true; } void TLSv12::notify_client_for_app_data() { if (m_context.application_buffer.size() > 0) { - if (!m_has_scheduled_app_data_flush) { - deferred_invoke([this] { notify_client_for_app_data(); }); - m_has_scheduled_app_data_flush = true; - } - if (on_tls_ready_to_read) - on_tls_ready_to_read(*this); + if (on_ready_to_read) + on_ready_to_read(); } else { if (m_context.connection_finished && !m_context.has_invoked_finish_or_error_callback) { m_context.has_invoked_finish_or_error_callback = true; @@ -160,7 +168,7 @@ void TLSv12::notify_client_for_app_data() m_has_scheduled_app_data_flush = false; } -void TLSv12::read_from_socket() +ErrorOr<void> TLSv12::read_from_socket() { // If there's anything before we consume stuff, let the client know // since we won't be consuming things if the connection is terminated. @@ -173,12 +181,28 @@ void TLSv12::read_from_socket() } }; - if (!check_connection_state(true)) { - set_idle(true); - return; - } + if (!check_connection_state(true)) + return {}; - consume(Core::Socket::read(4 * MiB)); + u8 buffer[16 * KiB]; + Bytes bytes { buffer, array_size(buffer) }; + size_t nread = 0; + auto& stream = underlying_stream(); + do { + auto result = stream.read(bytes); + if (result.is_error()) { + if (result.error().is_errno() && result.error().code() != EINTR) { + if (result.error().code() != EAGAIN) + dbgln("TLS Socket read failed, error: {}", result.error()); + break; + } + continue; + } + nread = result.release_value(); + consume(bytes.slice(0, nread)); + } while (nread > 0 && !m_context.critical_error); + + return {}; } void TLSv12::write_into_socket() @@ -188,14 +212,8 @@ void TLSv12::write_into_socket() m_has_scheduled_write_flush = false; if (!check_connection_state(false)) return; - flush(); - if (!is_established()) - return; - - if (!m_context.application_buffer.size()) // hey client, you still have stuff to read... - if (on_tls_ready_to_write) - on_tls_ready_to_write(*this); + MUST(flush()); } bool TLSv12::check_connection_state(bool read) @@ -203,16 +221,21 @@ bool TLSv12::check_connection_state(bool read) if (m_context.connection_finished) return false; - if (!Core::Socket::is_open() || !Core::Socket::is_connected()) { + if (m_context.close_notify) + m_context.connection_finished = true; + + auto& stream = underlying_stream(); + + if (!stream.is_open()) { // an abrupt closure (the server is a jerk) dbgln_if(TLS_DEBUG, "Socket not open, assuming abrupt closure"); m_context.connection_finished = true; m_context.connection_status = ConnectionStatus::Disconnected; - Core::Socket::close(); + close(); return false; } - if (read && Core::Socket::eof()) { + if (read && stream.is_eof()) { if (m_context.application_buffer.size() == 0 && m_context.connection_status != ConnectionStatus::Disconnected) { m_context.has_invoked_finish_or_error_callback = true; if (on_tls_finished) @@ -229,9 +252,10 @@ bool TLSv12::check_connection_state(bool read) on_tls_error((AlertDescription)m_context.critical_error); m_context.connection_finished = true; m_context.connection_status = ConnectionStatus::Disconnected; - Core::Socket::close(); + close(); return false; } + if (((read && m_context.application_buffer.size() == 0) || !read) && m_context.connection_finished) { if (m_context.application_buffer.size() == 0 && m_context.connection_status != ConnectionStatus::Disconnected) { m_context.has_invoked_finish_or_error_callback = true; @@ -250,30 +274,51 @@ bool TLSv12::check_connection_state(bool read) return true; } -bool TLSv12::flush() +ErrorOr<bool> TLSv12::flush() { - auto out_buffer = write_buffer().data(); - size_t out_buffer_index { 0 }; - size_t out_buffer_length = write_buffer().size(); + auto out_bytes = m_context.tls_buffer.bytes(); - if (out_buffer_length == 0) + if (out_bytes.is_empty()) return true; if constexpr (TLS_DEBUG) { dbgln("SENDING..."); - print_buffer(out_buffer, out_buffer_length); + print_buffer(out_bytes); } - if (Core::Socket::write(&out_buffer[out_buffer_index], out_buffer_length)) { - write_buffer().clear(); + + auto& stream = underlying_stream(); + Optional<AK::Error> error; + size_t written; + do { + auto result = stream.write(out_bytes); + if (result.is_error() && result.error().code() != EINTR && result.error().code() != EAGAIN) { + error = result.release_error(); + dbgln("TLS Socket write error: {}", *error); + break; + } + written = result.value(); + out_bytes = out_bytes.slice(written); + } while (!out_bytes.is_empty()); + + if (out_bytes.is_empty() && !error.has_value()) { + m_context.tls_buffer.clear(); return true; } + if (m_context.send_retries++ == 10) { // drop the records, we can't send - dbgln_if(TLS_DEBUG, "Dropping {} bytes worth of TLS records as max retries has been reached", write_buffer().size()); - write_buffer().clear(); + dbgln_if(TLS_DEBUG, "Dropping {} bytes worth of TLS records as max retries has been reached", m_context.tls_buffer.size()); + m_context.tls_buffer.clear(); m_context.send_retries = 0; } return false; } +void TLSv12::close() +{ + alert(AlertLevel::Critical, AlertDescription::CloseNotify); + // bye bye. + m_context.connection_status = ConnectionStatus::Disconnected; +} + } diff --git a/Userland/Libraries/LibTLS/TLSv12.cpp b/Userland/Libraries/LibTLS/TLSv12.cpp index e8c35a2a6c..0fb57f43ae 100644 --- a/Userland/Libraries/LibTLS/TLSv12.cpp +++ b/Userland/Libraries/LibTLS/TLSv12.cpp @@ -183,6 +183,7 @@ void TLSv12::set_root_certificates(Vector<Certificate> certificates) // FIXME: Figure out what we should do when our root certs are invalid. } m_context.root_certificates = move(certificates); + dbgln_if(TLS_DEBUG, "{}: Set {} root certificates", this, m_context.root_certificates.size()); } bool Context::verify_chain() const @@ -223,7 +224,7 @@ bool Context::verify_chain() const auto ref = chain.get(it.value); if (!ref.has_value()) { - dbgln("Certificate for {} is not signed by anyone we trust ({})", it.key, it.value); + dbgln("{}: Certificate for {} is not signed by anyone we trust ({})", this, it.key, it.value); return false; } @@ -301,50 +302,43 @@ void TLSv12::pseudorandom_function(Bytes output, ReadonlyBytes secret, const u8* } } -TLSv12::TLSv12(Core::Object* parent, Options options) - : Core::Socket(Core::Socket::Type::TCP, parent) +TLSv12::TLSv12(StreamVariantType stream, Options options) + : m_stream(move(stream)) { m_context.options = move(options); m_context.is_server = false; m_context.tls_buffer = {}; -#ifdef SOCK_NONBLOCK - int fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); -#else - int fd = socket(AF_INET, SOCK_STREAM, 0); - int option = 1; - ioctl(fd, FIONBIO, &option); -#endif - if (fd < 0) { - set_error(errno); - } else { - set_fd(fd); - set_mode(Core::OpenMode::ReadWrite); - set_error(0); - } + + set_root_certificates(m_context.options.root_certificates.has_value() + ? *m_context.options.root_certificates + : DefaultRootCACertificates::the().certificates()); + + setup_connection(); } -bool TLSv12::add_client_key(ReadonlyBytes certificate_pem_buffer, ReadonlyBytes rsa_key) // FIXME: This should not be bound to RSA +Vector<Certificate> TLSv12::parse_pem_certificate(ReadonlyBytes certificate_pem_buffer, ReadonlyBytes rsa_key) // FIXME: This should not be bound to RSA { if (certificate_pem_buffer.is_empty() || rsa_key.is_empty()) { - return true; + return {}; } + auto decoded_certificate = Crypto::decode_pem(certificate_pem_buffer); if (decoded_certificate.is_empty()) { dbgln("Certificate not PEM"); - return false; + return {}; } auto maybe_certificate = Certificate::parse_asn1(decoded_certificate); if (!maybe_certificate.has_value()) { dbgln("Invalid certificate"); - return false; + return {}; } Crypto::PK::RSA rsa(rsa_key); auto certificate = maybe_certificate.release_value(); certificate.private_key = rsa.private_key(); - return add_client_key(certificate); + return { move(certificate) }; } Singleton<DefaultRootCACertificates> DefaultRootCACertificates::s_the; @@ -364,5 +358,6 @@ DefaultRootCACertificates::DefaultRootCACertificates() cert.not_after = Crypto::ASN1::parse_generalized_time(config->read_entry(entity, "not_after", "")).value_or(next_year); m_ca_certificates.append(move(cert)); } + dbgln("Loaded {} CA Certificates", m_ca_certificates.size()); } } diff --git a/Userland/Libraries/LibTLS/TLSv12.h b/Userland/Libraries/LibTLS/TLSv12.h index 2d82cf00bc..23742e9080 100644 --- a/Userland/Libraries/LibTLS/TLSv12.h +++ b/Userland/Libraries/LibTLS/TLSv12.h @@ -10,8 +10,8 @@ #include <AK/IPv4Address.h> #include <AK/WeakPtr.h> #include <LibCore/Notifier.h> -#include <LibCore/Socket.h> -#include <LibCore/TCPSocket.h> +#include <LibCore/Stream.h> +#include <LibCore/Timer.h> #include <LibCrypto/Authentication/HMAC.h> #include <LibCrypto/BigInt/UnsignedBigInteger.h> #include <LibCrypto/Cipher/AES.h> @@ -215,7 +215,17 @@ struct Options { #define OPTION_WITH_DEFAULTS(typ, name, ...) \ static typ default_##name() { return typ { __VA_ARGS__ }; } \ - typ name = default_##name(); + typ name = default_##name(); \ + Options& set_##name(typ new_value)& \ + { \ + name = move(new_value); \ + return *this; \ + } \ + Options&& set_##name(typ new_value)&& \ + { \ + name = move(new_value); \ + return move(*this); \ + } OPTION_WITH_DEFAULTS(Version, version, Version::V12) OPTION_WITH_DEFAULTS(Vector<SignatureAndHashAlgorithm>, supported_signature_algorithms, @@ -227,6 +237,10 @@ struct Options { OPTION_WITH_DEFAULTS(bool, use_sni, true) OPTION_WITH_DEFAULTS(bool, use_compression, false) OPTION_WITH_DEFAULTS(bool, validate_certificates, true) + OPTION_WITH_DEFAULTS(Optional<Vector<Certificate>>, root_certificates, ) + OPTION_WITH_DEFAULTS(Function<void(AlertDescription)>, alert_handler, [](auto) {}) + OPTION_WITH_DEFAULTS(Function<void()>, finish_callback, [] {}) + OPTION_WITH_DEFAULTS(Function<Vector<Certificate>()>, certificate_provider, [] { return Vector<Certificate> {}; }) #undef OPTION_WITH_DEFAULTS }; @@ -290,6 +304,7 @@ struct Context { ClientVerificationStaus client_verified { Verified }; bool connection_finished { false }; + bool close_notify { false }; bool has_invoked_finish_or_error_callback { false }; // message flags @@ -311,12 +326,55 @@ struct Context { } server_diffie_hellman_params; }; -class TLSv12 : public Core::Socket { - C_OBJECT(TLSv12) +class TLSv12 final : public Core::Stream::Socket { +private: + Core::Stream::Socket& underlying_stream() + { + return *m_stream.visit([&](auto& stream) -> Core::Stream::Socket* { return stream; }); + } + Core::Stream::Socket const& underlying_stream() const + { + return *m_stream.visit([&](auto& stream) -> Core::Stream::Socket const* { return stream; }); + } + public: - ByteBuffer& write_buffer() { return m_context.tls_buffer; } + virtual bool is_readable() const override { return true; } + virtual bool is_writable() const override { return true; } + + /// Reads into a buffer, with the maximum size being the size of the buffer. + /// The amount of bytes read can be smaller than the size of the buffer. + /// Returns either the amount of bytes read, or an errno in the case of + /// failure. + virtual ErrorOr<size_t> read(Bytes) override; + + /// Tries to write the entire contents of the buffer. It is possible for + /// less than the full buffer to be written. Returns either the amount of + /// bytes written into the stream, or an errno in the case of failure. + virtual ErrorOr<size_t> write(ReadonlyBytes) override; + + virtual bool is_eof() const override { return m_context.connection_finished && m_context.application_buffer.is_empty(); } + + virtual bool is_open() const override { return is_established(); } + virtual void close() override; + + virtual ErrorOr<size_t> pending_bytes() const override { return m_context.application_buffer.size(); } + virtual ErrorOr<bool> can_read_without_blocking(int = 0) const override { return !m_context.application_buffer.is_empty(); } + virtual ErrorOr<void> set_blocking(bool block) override + { + VERIFY(!block); + return {}; + } + virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return underlying_stream().set_close_on_exec(enabled); } + + virtual void set_notifications_enabled(bool enabled) override { underlying_stream().set_notifications_enabled(enabled); } + + static ErrorOr<NonnullOwnPtr<TLSv12>> connect(String const& host, u16 port, Options = {}); + static ErrorOr<NonnullOwnPtr<TLSv12>> connect(String const& host, Core::Stream::Socket& underlying_stream, Options = {}); + + using StreamVariantType = Variant<OwnPtr<Core::Stream::Socket>, Core::Stream::Socket*>; + explicit TLSv12(StreamVariantType, Options); + bool is_established() const { return m_context.connection_status == ConnectionStatus::Established; } - virtual bool connect(const String&, int) override; void set_sni(StringView sni) { @@ -332,12 +390,7 @@ public: void set_root_certificates(Vector<Certificate>); - bool add_client_key(ReadonlyBytes certificate_pem_buffer, ReadonlyBytes key_pem_buffer); - bool add_client_key(Certificate certificate) - { - m_context.client_certificates.append(move(certificate)); - return true; - } + static Vector<Certificate> parse_pem_certificate(ReadonlyBytes certificate_pem_buffer, ReadonlyBytes key_pem_buffer); ByteBuffer finish_build(); @@ -363,35 +416,19 @@ public: return v == Version::V12; } - Optional<ByteBuffer> read(); - ByteBuffer read(size_t max_size); - - bool write(ReadonlyBytes); void alert(AlertLevel, AlertDescription); bool can_read_line() const { return m_context.application_buffer.size() && memchr(m_context.application_buffer.data(), '\n', m_context.application_buffer.size()); } bool can_read() const { return m_context.application_buffer.size() > 0; } String read_line(size_t max_size); - void set_on_tls_ready_to_write(Function<void(TLSv12&)> function) - { - on_tls_ready_to_write = move(function); - if (on_tls_ready_to_write) { - if (is_established()) - on_tls_ready_to_write(*this); - } - } - - Function<void(TLSv12&)> on_tls_ready_to_read; Function<void(AlertDescription)> on_tls_error; - Function<void()> on_tls_connected; Function<void()> on_tls_finished; Function<void(TLSv12&)> on_tls_certificate_request; + Function<void()> on_connected; private: - explicit TLSv12(Core::Object* parent, Options = {}); - - virtual bool common_connect(const struct sockaddr*, socklen_t) override; + void setup_connection(); void consume(ReadonlyBytes record); @@ -416,9 +453,9 @@ private: void build_rsa_pre_master_secret(PacketBuilder&); void build_dhe_rsa_pre_master_secret(PacketBuilder&); - bool flush(); + ErrorOr<bool> flush(); void write_into_socket(); - void read_from_socket(); + ErrorOr<void> read_from_socket(); bool check_connection_state(bool read); void notify_client_for_app_data(); @@ -512,6 +549,8 @@ private: void try_disambiguate_error() const; + bool m_eof { false }; + StreamVariantType m_stream; Context m_context; OwnPtr<Crypto::Authentication::HMAC<Crypto::Hash::Manager>> m_hmac_local; @@ -529,7 +568,6 @@ private: i32 m_max_wait_time_for_handshake_in_seconds { 10 }; RefPtr<Core::Timer> m_handshake_timeout_timer; - Function<void(TLSv12&)> on_tls_ready_to_write; }; } diff --git a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp index b3313ca9e6..1d8150a470 100644 --- a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp +++ b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp @@ -24,36 +24,26 @@ void TLSv12WebSocketConnectionImpl::connect(ConnectionInfo const& connection) VERIFY(on_connected); VERIFY(on_connection_error); VERIFY(on_ready_to_read); - m_socket = TLS::TLSv12::construct(this); + m_socket = TLS::TLSv12::connect(connection.url().host(), connection.url().port_or_default()).release_value_but_fixme_should_propagate_errors(); - m_socket->set_root_certificates(DefaultRootCACertificates::the().certificates()); m_socket->on_tls_error = [this](TLS::AlertDescription) { on_connection_error(); }; - m_socket->on_tls_ready_to_read = [this](auto&) { + m_socket->on_ready_to_read = [this] { on_ready_to_read(); }; - m_socket->set_on_tls_ready_to_write([this](auto& tls) { - tls.set_on_tls_ready_to_write(nullptr); - on_connected(); - }); m_socket->on_tls_finished = [this] { on_connection_error(); }; m_socket->on_tls_certificate_request = [](auto&) { // FIXME : Once we handle TLS certificate requests, handle it here as well. }; - bool success = m_socket->connect(connection.url().host(), connection.url().port_or_default()); - if (!success) { - deferred_invoke([this] { - on_connection_error(); - }); - } + on_connected(); } bool TLSv12WebSocketConnectionImpl::send(ReadonlyBytes data) { - return m_socket->write(data); + return m_socket->write_or_error(data); } bool TLSv12WebSocketConnectionImpl::can_read_line() @@ -73,24 +63,24 @@ bool TLSv12WebSocketConnectionImpl::can_read() ByteBuffer TLSv12WebSocketConnectionImpl::read(int max_size) { - return m_socket->read(max_size); + auto buffer = ByteBuffer::create_uninitialized(max_size).release_value_but_fixme_should_propagate_errors(); + auto nread = m_socket->read(buffer).release_value_but_fixme_should_propagate_errors(); + return buffer.slice(0, nread); } bool TLSv12WebSocketConnectionImpl::eof() { - return m_socket->eof(); + return m_socket->is_eof(); } void TLSv12WebSocketConnectionImpl::discard_connection() { if (!m_socket) return; - m_socket->on_tls_connected = nullptr; m_socket->on_tls_error = nullptr; m_socket->on_tls_finished = nullptr; m_socket->on_tls_certificate_request = nullptr; m_socket->on_ready_to_read = nullptr; - remove_child(*m_socket); m_socket = nullptr; } diff --git a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h index 6cfe8b1592..db39b2bc63 100644 --- a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h +++ b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h @@ -39,7 +39,7 @@ public: private: explicit TLSv12WebSocketConnectionImpl(Core::Object* parent = nullptr); - RefPtr<TLS::TLSv12> m_socket; + OwnPtr<TLS::TLSv12> m_socket; }; } diff --git a/Userland/Services/RequestServer/ClientConnection.cpp b/Userland/Services/RequestServer/ClientConnection.cpp index 676b90a5a1..4c931391e6 100644 --- a/Userland/Services/RequestServer/ClientConnection.cpp +++ b/Userland/Services/RequestServer/ClientConnection.cpp @@ -127,40 +127,15 @@ void ClientConnection::ensure_connection(URL const& url, ::RequestServer::CacheL struct { URL const& m_url; - void start(NonnullRefPtr<Core::Socket> socket) + void start(Core::Stream::Socket& socket) { - auto is_tls = is<TLS::TLSv12>(*socket); - auto* tls_instance = is_tls ? static_cast<TLS::TLSv12*>(socket.ptr()) : nullptr; - - auto is_connected = false; - if (is_tls && tls_instance->is_established()) - is_connected = true; - if (!is_tls && socket->is_connected()) - is_connected = true; - - VERIFY(!is_connected); - - bool did_connect; - if (is_tls) { - tls_instance->set_root_certificates(DefaultRootCACertificates::the().certificates()); - tls_instance->on_tls_connected = [socket = socket.ptr(), url = m_url, tls_instance] { - tls_instance->set_on_tls_ready_to_write([socket, url](auto&) { - ConnectionCache::request_did_finish(url, socket); - }); - }; - tls_instance->on_tls_error = [socket = socket.ptr(), url = m_url](auto) { - ConnectionCache::request_did_finish(url, socket); - }; - did_connect = tls_instance->connect(m_url.host(), m_url.port_or_default()); - } else { - socket->on_connected = [socket = socket.ptr(), url = m_url]() mutable { - ConnectionCache::request_did_finish(url, socket); - }; - did_connect = socket->connect(m_url.host(), m_url.port_or_default()); - } - - if (!did_connect) - ConnectionCache::request_did_finish(m_url, socket); + auto is_connected = socket.is_open(); + VERIFY(is_connected); + ConnectionCache::request_did_finish(m_url, &socket); + } + void fail(Core::NetworkJob::Error error) + { + dbgln("Pre-connect to {} failed: {}", m_url, Core::to_string(error)); } } job { url }; diff --git a/Userland/Services/RequestServer/ConnectionCache.cpp b/Userland/Services/RequestServer/ConnectionCache.cpp index 896e9d2f0d..6582ad2059 100644 --- a/Userland/Services/RequestServer/ConnectionCache.cpp +++ b/Userland/Services/RequestServer/ConnectionCache.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org> + * Copyright (c) 2021-2022, Ali Mohammad Pur <mpfard@serenityos.org> * * SPDX-License-Identifier: BSD-2-Clause */ @@ -10,10 +10,10 @@ namespace RequestServer::ConnectionCache { -HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::TCPSocket>>>> g_tcp_connection_cache {}; +HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::Stream::TCPSocket>>>> g_tcp_connection_cache {}; HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<TLS::TLSv12>>>> g_tls_connection_cache {}; -void request_did_finish(URL const& url, Core::Socket const* socket) +void request_did_finish(URL const& url, Core::Stream::Socket const* socket) { if (!socket) { dbgln("Request with a null socket finished for URL {}", url); @@ -37,34 +37,44 @@ void request_did_finish(URL const& url, Core::Socket const* socket) auto& connection = *connection_it; if (connection->request_queue.is_empty()) { - connection->has_started = false; - connection->current_url = {}; - connection->removal_timer->on_timeout = [ptr = connection.ptr(), &cache_entry = *it->value, key = it->key, &cache]() mutable { - Core::deferred_invoke([&, key = move(key), ptr] { - dbgln_if(REQUESTSERVER_DEBUG, "Removing no-longer-used connection {} (socket {})", ptr, ptr->socket); - auto did_remove = cache_entry.remove_first_matching([&](auto& entry) { return entry == ptr; }); - VERIFY(did_remove); - if (cache_entry.is_empty()) - cache.remove(key); - }); - }; - connection->removal_timer->start(); + Core::deferred_invoke([&connection, &cache_entry = *it->value, key = it->key, &cache] { + connection->socket->set_notifications_enabled(false); + connection->has_started = false; + connection->current_url = {}; + connection->job_data = {}; + connection->removal_timer->on_timeout = [ptr = connection.ptr(), &cache_entry, key = move(key), &cache]() mutable { + Core::deferred_invoke([&, key = move(key), ptr] { + dbgln_if(REQUESTSERVER_DEBUG, "Removing no-longer-used connection {} (socket {})", ptr, ptr->socket); + auto did_remove = cache_entry.remove_first_matching([&](auto& entry) { return entry == ptr; }); + VERIFY(did_remove); + if (cache_entry.is_empty()) + cache.remove(key); + }); + }; + connection->removal_timer->start(); + }); } else { - recreate_socket_if_needed(*connection, url); - dbgln_if(REQUESTSERVER_DEBUG, "Running next job in queue for connection {} @{}", &connection, connection->socket); - auto request = connection->request_queue.take_first(); - connection->timer.start(); - connection->current_url = url; - request(connection->socket); + if (auto result = recreate_socket_if_needed(*connection, url); result.is_error()) { + dbgln("ConnectionCache request finish handler, reconnection failed with {}", result.error()); + connection->job_data.fail(Core::NetworkJob::Error::ConnectionFailed); + return; + } + Core::deferred_invoke([&, url] { + dbgln_if(REQUESTSERVER_DEBUG, "Running next job in queue for connection {} @{}", &connection, connection->socket); + connection->timer.start(); + connection->current_url = url; + connection->job_data = connection->request_queue.take_first(); + connection->job_data.start(*connection->socket); + }); } }; - if (is<TLS::TLSv12>(socket)) + if (is<Core::Stream::BufferedSocket<TLS::TLSv12>>(socket)) fire_off_next_job(g_tls_connection_cache); - else if (is<Core::TCPSocket>(socket)) + else if (is<Core::Stream::BufferedSocket<Core::Stream::TCPSocket>>(socket)) fire_off_next_job(g_tcp_connection_cache); else - dbgln("Unknown socket {} finished for URL {}", *socket, url); + dbgln("Unknown socket {} finished for URL {}", socket, url); } void dump_jobs() @@ -74,7 +84,7 @@ void dump_jobs() dbgln(" - {}:{}", connection.key.hostname, connection.key.port); for (auto& entry : *connection.value) { dbgln(" - Connection {} (started={}) (socket={})", &entry, entry.has_started, entry.socket); - dbgln(" Currently loading {} ({} elapsed)", entry.current_url, entry.timer.elapsed()); + dbgln(" Currently loading {} ({} elapsed)", entry.current_url, entry.timer.is_valid() ? entry.timer.elapsed() : 0); dbgln(" Request Queue:"); for (auto& job : entry.request_queue) dbgln(" - {}", &job); @@ -85,7 +95,7 @@ void dump_jobs() dbgln(" - {}:{}", connection.key.hostname, connection.key.port); for (auto& entry : *connection.value) { dbgln(" - Connection {} (started={}) (socket={})", &entry, entry.has_started, entry.socket); - dbgln(" Currently loading {} ({} elapsed)", entry.current_url, entry.timer.elapsed()); + dbgln(" Currently loading {} ({} elapsed)", entry.current_url, entry.timer.is_valid() ? entry.timer.elapsed() : 0); dbgln(" Request Queue:"); for (auto& job : entry.request_queue) dbgln(" - {}", &job); diff --git a/Userland/Services/RequestServer/ConnectionCache.h b/Userland/Services/RequestServer/ConnectionCache.h index aa1079b051..716c764c59 100644 --- a/Userland/Services/RequestServer/ConnectionCache.h +++ b/Userland/Services/RequestServer/ConnectionCache.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org> + * Copyright (c) 2021-2022, Ali Mohammad Pur <mpfard@serenityos.org> * Copyright (c) 2022, the SerenityOS developers. * * SPDX-License-Identifier: BSD-2-Clause @@ -13,6 +13,8 @@ #include <AK/URL.h> #include <AK/Vector.h> #include <LibCore/ElapsedTimer.h> +#include <LibCore/EventLoop.h> +#include <LibCore/NetworkJob.h> #include <LibCore/TCPSocket.h> #include <LibCore/Timer.h> #include <LibTLS/TLSv12.h> @@ -30,15 +32,47 @@ namespace RequestServer::ConnectionCache { template<typename Socket> struct Connection { - using QueueType = Vector<Function<void(Core::Socket&)>>; + struct JobData { + Function<void(Core::Stream::Socket&)> start {}; + Function<void(Core::NetworkJob::Error)> fail {}; + Function<Vector<TLS::Certificate>()> provide_client_certificates {}; + + template<typename T> + static JobData create(T& job) + { + // Clang-format _really_ messes up formatting this, so just format it manually. + // clang-format off + return JobData { + .start = [&job](auto& socket) { + job.start(socket); + }, + .fail = [&job](auto error) { + job.fail(error); + }, + .provide_client_certificates = [&job] { + if constexpr (requires { job.on_certificate_requested; }) { + if (job.on_certificate_requested) + return job.on_certificate_requested(); + } else { + // "use" `job`, otherwise clang gets sad. + (void)job; + } + return Vector<TLS::Certificate> {}; + }, + }; + // clang-format on + } + }; + using QueueType = Vector<JobData>; using SocketType = Socket; - NonnullRefPtr<Socket> socket; + NonnullOwnPtr<Core::Stream::BufferedSocket<Socket>> socket; QueueType request_queue; NonnullRefPtr<Core::Timer> removal_timer; bool has_started { false }; URL current_url {}; Core::ElapsedTimer timer {}; + JobData job_data {}; }; struct ConnectionKey { @@ -60,44 +94,82 @@ struct AK::Traits<RequestServer::ConnectionCache::ConnectionKey> : public AK::Ge namespace RequestServer::ConnectionCache { -extern HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::TCPSocket>>>> g_tcp_connection_cache; +extern HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::Stream::TCPSocket>>>> g_tcp_connection_cache; extern HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<TLS::TLSv12>>>> g_tls_connection_cache; -void request_did_finish(URL const&, Core::Socket const*); +void request_did_finish(URL const&, Core::Stream::Socket const*); void dump_jobs(); constexpr static size_t MaxConcurrentConnectionsPerURL = 2; constexpr static size_t ConnectionKeepAliveTimeMilliseconds = 10'000; template<typename T> -void recreate_socket_if_needed(T& connection, URL const& url) +ErrorOr<void> recreate_socket_if_needed(T& connection, URL const& url) { - using SocketType = RemoveCVReference<decltype(*connection.socket)>; - bool is_connected; - if constexpr (IsSame<SocketType, TLS::TLSv12>) - is_connected = connection.socket->is_established(); - else - is_connected = connection.socket->is_connected(); - if (!is_connected) { + using SocketType = typename T::SocketType; + if (!connection.socket->is_open()) { // Create another socket for the connection. - connection.socket = SocketType::construct(nullptr); + auto set_socket = [&](auto socket) -> ErrorOr<void> { + connection.socket = TRY(Core::Stream::BufferedSocket<SocketType>::create(move(socket))); + return {}; + }; + + if constexpr (IsSame<TLS::TLSv12, SocketType>) { + TLS::Options options; + options.set_alert_handler([&connection](TLS::AlertDescription alert) { + Core::NetworkJob::Error reason; + if (alert == TLS::AlertDescription::HandshakeFailure) + reason = Core::NetworkJob::Error::ProtocolFailed; + else if (alert == TLS::AlertDescription::DecryptError) + reason = Core::NetworkJob::Error::ConnectionFailed; + else + reason = Core::NetworkJob::Error::TransmissionFailed; + + if (connection.job_data.fail) + connection.job_data.fail(reason); + }); + options.set_certificate_provider([&connection]() -> Vector<TLS::Certificate> { + if (connection.job_data.provide_client_certificates) + return connection.job_data.provide_client_certificates(); + return {}; + }); + TRY(set_socket(TRY(SocketType::connect(url.host(), url.port_or_default(), move(options))))); + } else { + TRY(set_socket(TRY(SocketType::connect(url.host(), url.port_or_default())))); + } dbgln_if(REQUESTSERVER_DEBUG, "Creating a new socket for {} -> {}", url, connection.socket); } + return {}; } decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) { using CacheEntryType = RemoveCVReference<decltype(*cache.begin()->value)>; - auto start_job = [&job](auto& socket) { - job.start(socket); - }; auto& sockets_for_url = *cache.ensure({ url.host(), url.port_or_default() }, [] { return make<CacheEntryType>(); }); + + using ReturnType = decltype(&sockets_for_url[0]); auto it = sockets_for_url.find_if([](auto& connection) { return connection->request_queue.is_empty(); }); auto did_add_new_connection = false; if (it.is_end() && sockets_for_url.size() < ConnectionCache::MaxConcurrentConnectionsPerURL) { using ConnectionType = RemoveCVReference<decltype(cache.begin()->value->at(0))>; + auto connection_result = ConnectionType::SocketType::connect(url.host(), url.port_or_default()); + if (connection_result.is_error()) { + dbgln("ConnectionCache: Connection to {} failed: {}", url, connection_result.error()); + Core::deferred_invoke([&job] { + job.fail(Core::NetworkJob::Error::ConnectionFailed); + }); + return ReturnType { nullptr }; + } + auto socket_result = Core::Stream::BufferedSocket<typename ConnectionType::SocketType>::create(connection_result.release_value()); + if (socket_result.is_error()) { + dbgln("ConnectionCache: Failed to make a buffered socket for {}: {}", url, socket_result.error()); + Core::deferred_invoke([&job] { + job.fail(Core::NetworkJob::Error::ConnectionFailed); + }); + return ReturnType { nullptr }; + } sockets_for_url.append(make<ConnectionType>( - ConnectionType::SocketType::construct(nullptr), + socket_result.release_value(), typename ConnectionType::QueueType {}, Core::Timer::create_single_shot(ConnectionKeepAliveTimeMilliseconds, nullptr))); did_add_new_connection = true; @@ -107,7 +179,7 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) if (did_add_new_connection) { index = sockets_for_url.size() - 1; } else { - // Find the least backed-up connection (based on how many entries are in their request queue. + // Find the least backed-up connection (based on how many entries are in their request queue). index = 0; auto min_queue_size = (size_t)-1; for (auto it = sockets_for_url.begin(); it != sockets_for_url.end(); ++it) { @@ -120,20 +192,35 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) } else { index = it.index(); } + if (sockets_for_url.is_empty()) { + Core::deferred_invoke([&job] { + job.fail(Core::NetworkJob::Error::ConnectionFailed); + }); + return ReturnType { nullptr }; + } + auto& connection = sockets_for_url[index]; if (!connection.has_started) { - recreate_socket_if_needed(connection, url); + if (auto result = recreate_socket_if_needed(connection, url); result.is_error()) { + dbgln("ConnectionCache: request failed to start, failed to make a socket: {}", result.error()); + Core::deferred_invoke([&job] { + job.fail(Core::NetworkJob::Error::ConnectionFailed); + }); + return ReturnType { nullptr }; + } dbgln_if(REQUESTSERVER_DEBUG, "Immediately start request for url {} in {} - {}", url, &connection, connection.socket); connection.has_started = true; connection.removal_timer->stop(); connection.timer.start(); connection.current_url = url; - start_job(*connection.socket); + connection.job_data = decltype(connection.job_data)::create(job); + connection.socket->set_notifications_enabled(true); + connection.job_data.start(*connection.socket); } else { dbgln_if(REQUESTSERVER_DEBUG, "Enqueue request for URL {} in {} - {}", url, &connection, connection.socket); - connection.request_queue.append(move(start_job)); + connection.request_queue.append(decltype(connection.job_data)::create(job)); } - return connection; + return &connection; } } diff --git a/Userland/Services/RequestServer/GeminiProtocol.cpp b/Userland/Services/RequestServer/GeminiProtocol.cpp index 45686caa75..9958a3b720 100644 --- a/Userland/Services/RequestServer/GeminiProtocol.cpp +++ b/Userland/Services/RequestServer/GeminiProtocol.cpp @@ -5,8 +5,8 @@ */ #include "ConnectionCache.h" -#include <LibGemini/GeminiJob.h> #include <LibGemini/GeminiRequest.h> +#include <LibGemini/Job.h> #include <RequestServer/GeminiProtocol.h> #include <RequestServer/GeminiRequest.h> @@ -30,10 +30,9 @@ OwnPtr<Request> GeminiProtocol::start_request(ClientConnection& client, const St if (pipe_result.is_error()) return {}; - auto output_stream = make<OutputFileStream>(pipe_result.value().write_fd); - output_stream->make_unbuffered(); - auto job = Gemini::GeminiJob::construct(request, *output_stream); - auto protocol_request = GeminiRequest::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream)); + auto output_stream = MUST(Core::Stream::File::adopt_fd(pipe_result.value().write_fd, Core::Stream::OpenMode::Write)); + auto job = Gemini::Job::construct(request, *output_stream); + auto protocol_request = GeminiRequest::create_with_job({}, client, *job, move(output_stream)); protocol_request->set_request_fd(pipe_result.value().read_fd); ConnectionCache::get_or_create_connection(ConnectionCache::g_tls_connection_cache, url, *job); diff --git a/Userland/Services/RequestServer/GeminiRequest.cpp b/Userland/Services/RequestServer/GeminiRequest.cpp index 9b087d783b..52cb3bbe2c 100644 --- a/Userland/Services/RequestServer/GeminiRequest.cpp +++ b/Userland/Services/RequestServer/GeminiRequest.cpp @@ -6,22 +6,22 @@ #include "ConnectionCache.h" #include <LibCore/EventLoop.h> -#include <LibGemini/GeminiJob.h> #include <LibGemini/GeminiResponse.h> +#include <LibGemini/Job.h> #include <RequestServer/GeminiRequest.h> namespace RequestServer { -GeminiRequest::GeminiRequest(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) +GeminiRequest::GeminiRequest(ClientConnection& client, NonnullRefPtr<Gemini::Job> job, NonnullOwnPtr<Core::Stream::File>&& output_stream) : Request(client, move(output_stream)) - , m_job(job) + , m_job(move(job)) { m_job->on_finish = [this](bool success) { Core::deferred_invoke([url = m_job->url(), socket = m_job->socket()] { ConnectionCache::request_did_finish(url, socket); }); if (auto* response = m_job->response()) { - set_downloaded_size(this->output_stream().size()); + set_downloaded_size(MUST(const_cast<Core::Stream::File&>(this->output_stream()).size())); if (!response->meta().is_empty()) { HashMap<String, String, CaseInsensitiveStringTraits> headers; headers.set("meta", response->meta()); @@ -42,16 +42,12 @@ GeminiRequest::GeminiRequest(ClientConnection& client, NonnullRefPtr<Gemini::Gem did_finish(success); }; m_job->on_progress = [this](Optional<u32> total, u32 current) { - did_progress(total, current); - }; - m_job->on_certificate_requested = [this](auto&) { - did_request_certificates(); + did_progress(move(total), current); }; } -void GeminiRequest::set_certificate(String certificate, String key) +void GeminiRequest::set_certificate(String, String) { - m_job->set_certificate(move(certificate), move(key)); } GeminiRequest::~GeminiRequest() @@ -61,7 +57,7 @@ GeminiRequest::~GeminiRequest() m_job->cancel(); } -NonnullOwnPtr<GeminiRequest> GeminiRequest::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) +NonnullOwnPtr<GeminiRequest> GeminiRequest::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::Job> job, NonnullOwnPtr<Core::Stream::File>&& output_stream) { return adopt_own(*new GeminiRequest(client, move(job), move(output_stream))); } diff --git a/Userland/Services/RequestServer/GeminiRequest.h b/Userland/Services/RequestServer/GeminiRequest.h index 0d9ffe4801..87c8fc6663 100644 --- a/Userland/Services/RequestServer/GeminiRequest.h +++ b/Userland/Services/RequestServer/GeminiRequest.h @@ -16,18 +16,18 @@ namespace RequestServer { class GeminiRequest final : public Request { public: virtual ~GeminiRequest() override; - static NonnullOwnPtr<GeminiRequest> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&); + static NonnullOwnPtr<GeminiRequest> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::Job>, NonnullOwnPtr<Core::Stream::File>&&); - Gemini::GeminiJob const& job() const { return *m_job; } + Gemini::Job const& job() const { return *m_job; } virtual URL url() const override { return m_job->url(); } private: - explicit GeminiRequest(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&); + explicit GeminiRequest(ClientConnection&, NonnullRefPtr<Gemini::Job>, NonnullOwnPtr<Core::Stream::File>&&); virtual void set_certificate(String certificate, String key) override; - NonnullRefPtr<Gemini::GeminiJob> m_job; + NonnullRefPtr<Gemini::Job> m_job; }; } diff --git a/Userland/Services/RequestServer/HttpCommon.h b/Userland/Services/RequestServer/HttpCommon.h index 301ccb77f3..797ddab245 100644 --- a/Userland/Services/RequestServer/HttpCommon.h +++ b/Userland/Services/RequestServer/HttpCommon.h @@ -36,7 +36,7 @@ void init(TSelf* self, TJob job) if (auto* response = self->job().response()) { self->set_status_code(response->code()); self->set_response_headers(response->headers()); - self->set_downloaded_size(self->output_stream().size()); + self->set_downloaded_size(response->downloaded_size()); } // if we didn't know the total size, pretend that the request finished successfully @@ -50,8 +50,12 @@ void init(TSelf* self, TJob job) self->did_progress(total, current); }; if constexpr (requires { job->on_certificate_requested; }) { - job->on_certificate_requested = [self](auto&) { + job->on_certificate_requested = [job, self] { self->did_request_certificates(); + Core::EventLoop::current().spin_until([&] { + return job->received_client_certificates(); + }); + return job->take_client_certificates(); }; } } @@ -79,8 +83,7 @@ OwnPtr<Request> start_request(TBadgedProtocol&& protocol, ClientConnection& clie return {}; request.set_body(allocated_body_result.release_value()); - auto output_stream = make<OutputFileStream>(pipe_result.value().write_fd); - output_stream->make_unbuffered(); + auto output_stream = MUST(Core::Stream::File::adopt_fd(pipe_result.value().write_fd, Core::Stream::OpenMode::Write)); auto job = TJob::construct(move(request), *output_stream); auto protocol_request = TRequest::create_with_job(forward<TBadgedProtocol>(protocol), client, (TJob&)*job, move(output_stream)); protocol_request->set_request_fd(pipe_result.value().read_fd); diff --git a/Userland/Services/RequestServer/HttpProtocol.h b/Userland/Services/RequestServer/HttpProtocol.h index 5641142e32..23f8680cab 100644 --- a/Userland/Services/RequestServer/HttpProtocol.h +++ b/Userland/Services/RequestServer/HttpProtocol.h @@ -11,7 +11,7 @@ #include <AK/OwnPtr.h> #include <AK/String.h> #include <AK/URL.h> -#include <LibHTTP/HttpJob.h> +#include <LibHTTP/Job.h> #include <RequestServer/ClientConnection.h> #include <RequestServer/HttpRequest.h> #include <RequestServer/Protocol.h> @@ -21,7 +21,7 @@ namespace RequestServer { class HttpProtocol final : public Protocol { public: - using JobType = HTTP::HttpJob; + using JobType = HTTP::Job; using RequestType = HttpRequest; HttpProtocol(); diff --git a/Userland/Services/RequestServer/HttpRequest.cpp b/Userland/Services/RequestServer/HttpRequest.cpp index a0f29f2e3c..2fc4f7fc73 100644 --- a/Userland/Services/RequestServer/HttpRequest.cpp +++ b/Userland/Services/RequestServer/HttpRequest.cpp @@ -4,14 +4,14 @@ * SPDX-License-Identifier: BSD-2-Clause */ -#include <LibHTTP/HttpJob.h> +#include <LibHTTP/Job.h> #include <RequestServer/HttpCommon.h> #include <RequestServer/HttpProtocol.h> #include <RequestServer/HttpRequest.h> namespace RequestServer { -HttpRequest::HttpRequest(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) +HttpRequest::HttpRequest(ClientConnection& client, NonnullRefPtr<HTTP::Job> job, NonnullOwnPtr<Core::Stream::File>&& output_stream) : Request(client, move(output_stream)) , m_job(job) { @@ -25,7 +25,7 @@ HttpRequest::~HttpRequest() m_job->cancel(); } -NonnullOwnPtr<HttpRequest> HttpRequest::create_with_job(Badge<HttpProtocol>&&, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) +NonnullOwnPtr<HttpRequest> HttpRequest::create_with_job(Badge<HttpProtocol>&&, ClientConnection& client, NonnullRefPtr<HTTP::Job> job, NonnullOwnPtr<Core::Stream::File>&& output_stream) { return adopt_own(*new HttpRequest(client, move(job), move(output_stream))); } diff --git a/Userland/Services/RequestServer/HttpRequest.h b/Userland/Services/RequestServer/HttpRequest.h index 3ef9bb206b..36f5fe1045 100644 --- a/Userland/Services/RequestServer/HttpRequest.h +++ b/Userland/Services/RequestServer/HttpRequest.h @@ -17,17 +17,17 @@ namespace RequestServer { class HttpRequest final : public Request { public: virtual ~HttpRequest() override; - static NonnullOwnPtr<HttpRequest> create_with_job(Badge<HttpProtocol>&&, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&); + static NonnullOwnPtr<HttpRequest> create_with_job(Badge<HttpProtocol>&&, ClientConnection&, NonnullRefPtr<HTTP::Job>, NonnullOwnPtr<Core::Stream::File>&&); - HTTP::HttpJob& job() { return m_job; } - HTTP::HttpJob const& job() const { return m_job; } + HTTP::Job& job() { return m_job; } + HTTP::Job const& job() const { return m_job; } virtual URL url() const override { return m_job->url(); } private: - explicit HttpRequest(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&); + explicit HttpRequest(ClientConnection&, NonnullRefPtr<HTTP::Job>, NonnullOwnPtr<Core::Stream::File>&&); - NonnullRefPtr<HTTP::HttpJob> m_job; + NonnullRefPtr<HTTP::Job> m_job; }; } diff --git a/Userland/Services/RequestServer/HttpsRequest.cpp b/Userland/Services/RequestServer/HttpsRequest.cpp index f3f537705c..6f0af9ba8b 100644 --- a/Userland/Services/RequestServer/HttpsRequest.cpp +++ b/Userland/Services/RequestServer/HttpsRequest.cpp @@ -11,7 +11,7 @@ namespace RequestServer { -HttpsRequest::HttpsRequest(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) +HttpsRequest::HttpsRequest(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<Core::Stream::File>&& output_stream) : Request(client, move(output_stream)) , m_job(job) { @@ -30,7 +30,7 @@ HttpsRequest::~HttpsRequest() m_job->cancel(); } -NonnullOwnPtr<HttpsRequest> HttpsRequest::create_with_job(Badge<HttpsProtocol>&&, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) +NonnullOwnPtr<HttpsRequest> HttpsRequest::create_with_job(Badge<HttpsProtocol>&&, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<Core::Stream::File>&& output_stream) { return adopt_own(*new HttpsRequest(client, move(job), move(output_stream))); } diff --git a/Userland/Services/RequestServer/HttpsRequest.h b/Userland/Services/RequestServer/HttpsRequest.h index 51187368dd..5020be1e9b 100644 --- a/Userland/Services/RequestServer/HttpsRequest.h +++ b/Userland/Services/RequestServer/HttpsRequest.h @@ -16,7 +16,7 @@ namespace RequestServer { class HttpsRequest final : public Request { public: virtual ~HttpsRequest() override; - static NonnullOwnPtr<HttpsRequest> create_with_job(Badge<HttpsProtocol>&&, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&); + static NonnullOwnPtr<HttpsRequest> create_with_job(Badge<HttpsProtocol>&&, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<Core::Stream::File>&&); HTTP::HttpsJob& job() { return m_job; } HTTP::HttpsJob const& job() const { return m_job; } @@ -24,7 +24,7 @@ public: virtual URL url() const override { return m_job->url(); } private: - explicit HttpsRequest(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&); + explicit HttpsRequest(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<Core::Stream::File>&&); virtual void set_certificate(String certificate, String key) override; diff --git a/Userland/Services/RequestServer/Request.cpp b/Userland/Services/RequestServer/Request.cpp index c842018db2..012da64525 100644 --- a/Userland/Services/RequestServer/Request.cpp +++ b/Userland/Services/RequestServer/Request.cpp @@ -12,7 +12,7 @@ namespace RequestServer { // FIXME: What about rollover? static i32 s_next_id = 1; -Request::Request(ClientConnection& client, NonnullOwnPtr<OutputFileStream>&& output_stream) +Request::Request(ClientConnection& client, NonnullOwnPtr<Core::Stream::File>&& output_stream) : m_client(client) , m_id(s_next_id++) , m_output_stream(move(output_stream)) diff --git a/Userland/Services/RequestServer/Request.h b/Userland/Services/RequestServer/Request.h index 9e83e047ed..969039d167 100644 --- a/Userland/Services/RequestServer/Request.h +++ b/Userland/Services/RequestServer/Request.h @@ -41,10 +41,10 @@ public: void did_request_certificates(); void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&); void set_downloaded_size(size_t size) { m_downloaded_size = size; } - const OutputFileStream& output_stream() const { return *m_output_stream; } + const Core::Stream::File& output_stream() const { return *m_output_stream; } protected: - explicit Request(ClientConnection&, NonnullOwnPtr<OutputFileStream>&&); + explicit Request(ClientConnection&, NonnullOwnPtr<Core::Stream::File>&&); private: ClientConnection& m_client; @@ -53,7 +53,7 @@ private: Optional<u32> m_status_code; Optional<u32> m_total_size {}; size_t m_downloaded_size { 0 }; - NonnullOwnPtr<OutputFileStream> m_output_stream; + NonnullOwnPtr<Core::Stream::File> m_output_stream; HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers; }; |