diff options
author | AnotherTest <ali.mpfard@gmail.com> | 2020-12-26 17:14:12 +0330 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2020-12-30 13:31:55 +0100 |
commit | 4a2da10e38c2b413c9e4db411e47d1b90d98d8ee (patch) | |
tree | 893aef9f1f760622e739183050a9172b717025f0 /Services | |
parent | 36d642ee7585801e1abe8a421b30f1a779be3bab (diff) | |
download | serenity-4a2da10e38c2b413c9e4db411e47d1b90d98d8ee.zip |
ProtocolServer: Stream the downloaded data if possible
This patchset makes ProtocolServer stream the downloads to its client
(LibProtocol), and as such changes the download API; a possible
download lifecycle could be as such:
notation = client->server:'>', server->client:'<', pipe activity:'*'
```
> StartDownload(GET, url, headers, {})
< Response(0, fd 8)
* {data, 1024b}
< HeadersBecameAvailable(0, response_headers, 200)
< DownloadProgress(0, 4K, 1024)
* {data, 1024b}
* {data, 1024b}
< DownloadProgress(0, 4K, 2048)
* {data, 1024b}
< DownloadProgress(0, 4K, 1024)
< DownloadFinished(0, true, 4K)
```
Since managing the received file descriptor is a pain, LibProtocol
implements `Download::stream_into(OutputStream)`, which can be used to
stream the download into any given output stream (be it a file, or
memory, or writing stuff with a delay, etc.).
Also, as some of the users of this API require all the downloaded data
upfront, LibProtocol also implements `set_should_buffer_all_input()`,
which causes the download instance to buffer all the data until the
download is complete, and to call the `on_buffered_download_finish`
hook.
Diffstat (limited to 'Services')
20 files changed, 117 insertions, 81 deletions
diff --git a/Services/ProtocolServer/ClientConnection.cpp b/Services/ProtocolServer/ClientConnection.cpp index 93f02ae2e6..eee7277869 100644 --- a/Services/ProtocolServer/ClientConnection.cpp +++ b/Services/ProtocolServer/ClientConnection.cpp @@ -62,16 +62,17 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> ClientConnection::handle { URL url(message.url()); if (!url.is_valid()) - return make<Messages::ProtocolServer::StartDownloadResponse>(-1); + return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1); auto* protocol = Protocol::find_by_name(url.protocol()); if (!protocol) - return make<Messages::ProtocolServer::StartDownloadResponse>(-1); - auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body().to_byte_buffer()); + return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1); + auto download = protocol->start_download(*this, message.method(), url, message.request_headers().entries(), message.request_body()); if (!download) - return make<Messages::ProtocolServer::StartDownloadResponse>(-1); + return make<Messages::ProtocolServer::StartDownloadResponse>(-1, -1); auto id = download->id(); + auto fd = download->download_fd(); m_downloads.set(id, move(download)); - return make<Messages::ProtocolServer::StartDownloadResponse>(id); + return make<Messages::ProtocolServer::StartDownloadResponse>(id, fd); } OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle(const Messages::ProtocolServer::StopDownload& message) @@ -86,22 +87,20 @@ OwnPtr<Messages::ProtocolServer::StopDownloadResponse> ClientConnection::handle( return make<Messages::ProtocolServer::StopDownloadResponse>(success); } -void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success) +void ClientConnection::did_receive_headers(Badge<Download>, Download& download) { - RefPtr<SharedBuffer> buffer; - if (success && download.payload().size() > 0 && !download.payload().is_null()) { - buffer = SharedBuffer::create_with_size(download.payload().size()); - memcpy(buffer->data<void>(), download.payload().data(), download.payload().size()); - buffer->seal(); - buffer->share_with(client_pid()); - m_shared_buffers.set(buffer->shbuf_id(), buffer); - } - ASSERT(download.total_size().has_value()); - IPC::Dictionary response_headers; for (auto& it : download.response_headers()) response_headers.add(it.key, it.value); - post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.status_code(), download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers)); + + post_message(Messages::ProtocolClient::HeadersBecameAvailable(download.id(), move(response_headers), download.status_code())); +} + +void ClientConnection::did_finish_download(Badge<Download>, Download& download, bool success) +{ + ASSERT(download.total_size().has_value()); + + post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value())); m_downloads.remove(download.id()); } @@ -121,12 +120,6 @@ OwnPtr<Messages::ProtocolServer::GreetResponse> ClientConnection::handle(const M return make<Messages::ProtocolServer::GreetResponse>(client_id()); } -OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> ClientConnection::handle(const Messages::ProtocolServer::DisownSharedBuffer& message) -{ - m_shared_buffers.remove(message.shbuf_id()); - return make<Messages::ProtocolServer::DisownSharedBufferResponse>(); -} - OwnPtr<Messages::ProtocolServer::SetCertificateResponse> ClientConnection::handle(const Messages::ProtocolServer::SetCertificate& message) { auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr)); diff --git a/Services/ProtocolServer/ClientConnection.h b/Services/ProtocolServer/ClientConnection.h index 4439fdbd87..778f3eff81 100644 --- a/Services/ProtocolServer/ClientConnection.h +++ b/Services/ProtocolServer/ClientConnection.h @@ -45,6 +45,7 @@ public: virtual void die() override; + void did_receive_headers(Badge<Download>, Download&); void did_finish_download(Badge<Download>, Download&, bool success); void did_progress_download(Badge<Download>, Download&); void did_request_certificates(Badge<Download>, Download&); @@ -54,11 +55,9 @@ private: virtual OwnPtr<Messages::ProtocolServer::IsSupportedProtocolResponse> handle(const Messages::ProtocolServer::IsSupportedProtocol&) override; virtual OwnPtr<Messages::ProtocolServer::StartDownloadResponse> handle(const Messages::ProtocolServer::StartDownload&) override; virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override; - virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override; - virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&); + virtual OwnPtr<Messages::ProtocolServer::SetCertificateResponse> handle(const Messages::ProtocolServer::SetCertificate&) override; HashMap<i32, OwnPtr<Download>> m_downloads; - HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers; }; } diff --git a/Services/ProtocolServer/Download.cpp b/Services/ProtocolServer/Download.cpp index d0d9aa2ab8..11a7d28933 100644 --- a/Services/ProtocolServer/Download.cpp +++ b/Services/ProtocolServer/Download.cpp @@ -33,9 +33,10 @@ namespace ProtocolServer { // FIXME: What about rollover? static i32 s_next_id = 1; -Download::Download(ClientConnection& client) +Download::Download(ClientConnection& client, NonnullOwnPtr<OutputFileStream>&& output_stream) : m_client(client) , m_id(s_next_id++) + , m_output_stream(move(output_stream)) { } @@ -48,15 +49,10 @@ void Download::stop() m_client.did_finish_download({}, *this, false); } -void Download::set_payload(const ByteBuffer& payload) -{ - m_payload = payload; - m_total_size = payload.size(); -} - void Download::set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers) { m_response_headers = response_headers; + m_client.did_receive_headers({}, *this); } void Download::set_certificate(String, String) diff --git a/Services/ProtocolServer/Download.h b/Services/ProtocolServer/Download.h index f0d0342006..35c60269f4 100644 --- a/Services/ProtocolServer/Download.h +++ b/Services/ProtocolServer/Download.h @@ -26,8 +26,9 @@ #pragma once -#include <AK/ByteBuffer.h> +#include <AK/FileStream.h> #include <AK/HashMap.h> +#include <AK/NonnullOwnPtr.h> #include <AK/Optional.h> #include <AK/RefCounted.h> #include <AK/URL.h> @@ -45,30 +46,35 @@ public: Optional<u32> status_code() const { return m_status_code; } Optional<u32> total_size() const { return m_total_size; } size_t downloaded_size() const { return m_downloaded_size; } - const ByteBuffer& payload() const { return m_payload; } const HashMap<String, String, CaseInsensitiveStringTraits>& response_headers() const { return m_response_headers; } void stop(); virtual void set_certificate(String, String); + // FIXME: Want Badge<Protocol>, but can't make one from HttpProtocol, etc. + void set_download_fd(int fd) { m_download_fd = fd; } + int download_fd() const { return m_download_fd; } + protected: - explicit Download(ClientConnection&); + explicit Download(ClientConnection&, NonnullOwnPtr<OutputFileStream>&&); void did_finish(bool success); void did_progress(Optional<u32> total_size, u32 downloaded_size); void set_status_code(u32 status_code) { m_status_code = status_code; } void did_request_certificates(); - void set_payload(const ByteBuffer&); void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&); + void set_downloaded_size(size_t size) { m_downloaded_size = size; } + const OutputFileStream& output_stream() const { return *m_output_stream; } private: ClientConnection& m_client; i32 m_id { 0 }; + int m_download_fd { -1 }; // Passed to client. URL m_url; Optional<u32> m_status_code; Optional<u32> m_total_size {}; size_t m_downloaded_size { 0 }; - ByteBuffer m_payload; + NonnullOwnPtr<OutputFileStream> m_output_stream; HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers; }; diff --git a/Services/ProtocolServer/GeminiDownload.cpp b/Services/ProtocolServer/GeminiDownload.cpp index a504aaca7f..0bba75519d 100644 --- a/Services/ProtocolServer/GeminiDownload.cpp +++ b/Services/ProtocolServer/GeminiDownload.cpp @@ -30,13 +30,13 @@ namespace ProtocolServer { -GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job) - : Download(client) +GeminiDownload::GeminiDownload(ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) + : Download(client, move(output_stream)) , m_job(job) { m_job->on_finish = [this](bool success) { if (auto* response = m_job->response()) { - set_payload(response->payload()); + set_downloaded_size(this->output_stream().size()); if (!response->meta().is_empty()) { HashMap<String, String, CaseInsensitiveStringTraits> headers; headers.set("meta", response->meta()); @@ -76,9 +76,9 @@ GeminiDownload::~GeminiDownload() m_job->shutdown(); } -NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job) +NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, ClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) { - return adopt_own(*new GeminiDownload(client, move(job))); + return adopt_own(*new GeminiDownload(client, move(job), move(output_stream))); } } diff --git a/Services/ProtocolServer/GeminiDownload.h b/Services/ProtocolServer/GeminiDownload.h index c429bac7b1..fcd81c121f 100644 --- a/Services/ProtocolServer/GeminiDownload.h +++ b/Services/ProtocolServer/GeminiDownload.h @@ -36,10 +36,10 @@ namespace ProtocolServer { class GeminiDownload final : public Download { public: virtual ~GeminiDownload() override; - static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>); + static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&); private: - explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>); + explicit GeminiDownload(ClientConnection&, NonnullRefPtr<Gemini::GeminiJob>, NonnullOwnPtr<OutputFileStream>&&); virtual void set_certificate(String certificate, String key) override; diff --git a/Services/ProtocolServer/GeminiProtocol.cpp b/Services/ProtocolServer/GeminiProtocol.cpp index f1167ecd61..ff4380de7f 100644 --- a/Services/ProtocolServer/GeminiProtocol.cpp +++ b/Services/ProtocolServer/GeminiProtocol.cpp @@ -40,12 +40,22 @@ GeminiProtocol::~GeminiProtocol() { } -OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, const ByteBuffer&) +OwnPtr<Download> GeminiProtocol::start_download(ClientConnection& client, const String&, const URL& url, const HashMap<String, String>&, ReadonlyBytes) { Gemini::GeminiRequest request; request.set_url(url); - auto job = Gemini::GeminiJob::construct(request); - auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job); + + int fd_pair[2] { 0 }; + if (pipe(fd_pair) != 0) { + auto saved_errno = errno; + dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); + return nullptr; + } + auto output_stream = make<OutputFileStream>(fd_pair[1]); + output_stream->make_unbuffered(); + auto job = Gemini::GeminiJob::construct(request, *output_stream); + auto download = GeminiDownload::create_with_job({}, client, (Gemini::GeminiJob&)*job, move(output_stream)); + download->set_download_fd(fd_pair[0]); job->start(); return download; } diff --git a/Services/ProtocolServer/GeminiProtocol.h b/Services/ProtocolServer/GeminiProtocol.h index f9ed21cca3..23a4d7c717 100644 --- a/Services/ProtocolServer/GeminiProtocol.h +++ b/Services/ProtocolServer/GeminiProtocol.h @@ -35,7 +35,7 @@ public: GeminiProtocol(); virtual ~GeminiProtocol() override; - virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, const ByteBuffer& request_body) override; + virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>&, ReadonlyBytes body) override; }; } diff --git a/Services/ProtocolServer/HttpDownload.cpp b/Services/ProtocolServer/HttpDownload.cpp index bfa22351d3..8ba945d335 100644 --- a/Services/ProtocolServer/HttpDownload.cpp +++ b/Services/ProtocolServer/HttpDownload.cpp @@ -30,15 +30,21 @@ namespace ProtocolServer { -HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job) - : Download(client) +HttpDownload::HttpDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) + : Download(client, move(output_stream)) , m_job(job) { + m_job->on_headers_received = [this](auto& headers, auto response_code) { + if (response_code.has_value()) + set_status_code(response_code.value()); + set_response_headers(headers); + }; + m_job->on_finish = [this](bool success) { if (auto* response = m_job->response()) { set_status_code(response->code()); - set_payload(response->payload()); set_response_headers(response->headers()); + set_downloaded_size(this->output_stream().size()); } // if we didn't know the total size, pretend that the download finished successfully @@ -60,9 +66,9 @@ HttpDownload::~HttpDownload() m_job->shutdown(); } -NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job) +NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) { - return adopt_own(*new HttpDownload(client, move(job))); + return adopt_own(*new HttpDownload(client, move(job), move(output_stream))); } } diff --git a/Services/ProtocolServer/HttpDownload.h b/Services/ProtocolServer/HttpDownload.h index d0d745ef0c..50095bd0e5 100644 --- a/Services/ProtocolServer/HttpDownload.h +++ b/Services/ProtocolServer/HttpDownload.h @@ -36,10 +36,10 @@ namespace ProtocolServer { class HttpDownload final : public Download { public: virtual ~HttpDownload() override; - static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>); + static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&); private: - explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>); + explicit HttpDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpJob>, NonnullOwnPtr<OutputFileStream>&&); NonnullRefPtr<HTTP::HttpJob> m_job; }; diff --git a/Services/ProtocolServer/HttpProtocol.cpp b/Services/ProtocolServer/HttpProtocol.cpp index b0e74e766a..e8c7a3203c 100644 --- a/Services/ProtocolServer/HttpProtocol.cpp +++ b/Services/ProtocolServer/HttpProtocol.cpp @@ -28,6 +28,7 @@ #include <LibHTTP/HttpRequest.h> #include <ProtocolServer/HttpDownload.h> #include <ProtocolServer/HttpProtocol.h> +#include <fcntl.h> namespace ProtocolServer { @@ -40,7 +41,7 @@ HttpProtocol::~HttpProtocol() { } -OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body) +OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body) { HTTP::HttpRequest request; if (method.equals_ignoring_case("post")) @@ -49,9 +50,20 @@ OwnPtr<Download> HttpProtocol::start_download(ClientConnection& client, const St request.set_method(HTTP::HttpRequest::Method::GET); request.set_url(url); request.set_headers(headers); - request.set_body(request_body); - auto job = HTTP::HttpJob::construct(request); - auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job); + request.set_body(body); + + int fd_pair[2] { 0 }; + if (pipe(fd_pair) != 0) { + auto saved_errno = errno; + dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); + return nullptr; + } + + auto output_stream = make<OutputFileStream>(fd_pair[1]); + output_stream->make_unbuffered(); + auto job = HTTP::HttpJob::construct(request, *output_stream); + auto download = HttpDownload::create_with_job({}, client, (HTTP::HttpJob&)*job, move(output_stream)); + download->set_download_fd(fd_pair[0]); job->start(); return download; } diff --git a/Services/ProtocolServer/HttpProtocol.h b/Services/ProtocolServer/HttpProtocol.h index aa9601b8ce..8c4a564f37 100644 --- a/Services/ProtocolServer/HttpProtocol.h +++ b/Services/ProtocolServer/HttpProtocol.h @@ -35,7 +35,7 @@ public: HttpProtocol(); virtual ~HttpProtocol() override; - virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override; + virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override; }; } diff --git a/Services/ProtocolServer/HttpsDownload.cpp b/Services/ProtocolServer/HttpsDownload.cpp index fe381d216e..991dd730d8 100644 --- a/Services/ProtocolServer/HttpsDownload.cpp +++ b/Services/ProtocolServer/HttpsDownload.cpp @@ -30,15 +30,21 @@ namespace ProtocolServer { -HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job) - : Download(client) +HttpsDownload::HttpsDownload(ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) + : Download(client, move(output_stream)) , m_job(job) { + m_job->on_headers_received = [this](auto& headers, auto response_code) { + if (response_code.has_value()) + set_status_code(response_code.value()); + set_response_headers(headers); + }; + m_job->on_finish = [this](bool success) { if (auto* response = m_job->response()) { set_status_code(response->code()); - set_payload(response->payload()); set_response_headers(response->headers()); + set_downloaded_size(this->output_stream().size()); } // if we didn't know the total size, pretend that the download finished successfully @@ -68,9 +74,9 @@ HttpsDownload::~HttpsDownload() m_job->shutdown(); } -NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job) +NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, ClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job, NonnullOwnPtr<OutputFileStream>&& output_stream) { - return adopt_own(*new HttpsDownload(client, move(job))); + return adopt_own(*new HttpsDownload(client, move(job), move(output_stream))); } } diff --git a/Services/ProtocolServer/HttpsDownload.h b/Services/ProtocolServer/HttpsDownload.h index 48f255b2fa..254172b3f7 100644 --- a/Services/ProtocolServer/HttpsDownload.h +++ b/Services/ProtocolServer/HttpsDownload.h @@ -36,10 +36,10 @@ namespace ProtocolServer { class HttpsDownload final : public Download { public: virtual ~HttpsDownload() override; - static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>); + static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&); private: - explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>); + explicit HttpsDownload(ClientConnection&, NonnullRefPtr<HTTP::HttpsJob>, NonnullOwnPtr<OutputFileStream>&&); virtual void set_certificate(String certificate, String key) override; diff --git a/Services/ProtocolServer/HttpsProtocol.cpp b/Services/ProtocolServer/HttpsProtocol.cpp index 3de9ca8e2b..e34ff32422 100644 --- a/Services/ProtocolServer/HttpsProtocol.cpp +++ b/Services/ProtocolServer/HttpsProtocol.cpp @@ -40,7 +40,7 @@ HttpsProtocol::~HttpsProtocol() { } -OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, const ByteBuffer& request_body) +OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const String& method, const URL& url, const HashMap<String, String>& headers, ReadonlyBytes body) { HTTP::HttpRequest request; if (method.equals_ignoring_case("post")) @@ -49,9 +49,19 @@ OwnPtr<Download> HttpsProtocol::start_download(ClientConnection& client, const S request.set_method(HTTP::HttpRequest::Method::GET); request.set_url(url); request.set_headers(headers); - request.set_body(request_body); - auto job = HTTP::HttpsJob::construct(request); - auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job); + request.set_body(body); + + int fd_pair[2] { 0 }; + if (pipe(fd_pair) != 0) { + auto saved_errno = errno; + dbgln("Protocol: pipe() failed: {}", strerror(saved_errno)); + return nullptr; + } + auto output_stream = make<OutputFileStream>(fd_pair[1]); + output_stream->make_unbuffered(); + auto job = HTTP::HttpsJob::construct(request, *output_stream); + auto download = HttpsDownload::create_with_job({}, client, (HTTP::HttpsJob&)*job, move(output_stream)); + download->set_download_fd(fd_pair[0]); job->start(); return download; } diff --git a/Services/ProtocolServer/HttpsProtocol.h b/Services/ProtocolServer/HttpsProtocol.h index 9cb0ce190b..40e59aa271 100644 --- a/Services/ProtocolServer/HttpsProtocol.h +++ b/Services/ProtocolServer/HttpsProtocol.h @@ -35,7 +35,7 @@ public: HttpsProtocol(); virtual ~HttpsProtocol() override; - virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) override; + virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) override; }; } diff --git a/Services/ProtocolServer/Protocol.h b/Services/ProtocolServer/Protocol.h index 035b56cb6e..609362f548 100644 --- a/Services/ProtocolServer/Protocol.h +++ b/Services/ProtocolServer/Protocol.h @@ -37,7 +37,7 @@ public: virtual ~Protocol(); const String& name() const { return m_name; } - virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, const ByteBuffer& request_body) = 0; + virtual OwnPtr<Download> start_download(ClientConnection&, const String& method, const URL&, const HashMap<String, String>& headers, ReadonlyBytes body) = 0; static Protocol* find_by_name(const String&); diff --git a/Services/ProtocolServer/ProtocolClient.ipc b/Services/ProtocolServer/ProtocolClient.ipc index ef00d760ce..88f4cfc96d 100644 --- a/Services/ProtocolServer/ProtocolClient.ipc +++ b/Services/ProtocolServer/ProtocolClient.ipc @@ -2,7 +2,8 @@ endpoint ProtocolClient = 13 { // Download notifications DownloadProgress(i32 download_id, Optional<u32> total_size, u32 downloaded_size) =| - DownloadFinished(i32 download_id, bool success, Optional<u32> status_code, u32 total_size, i32 shbuf_id, IPC::Dictionary response_headers) =| + DownloadFinished(i32 download_id, bool success, u32 total_size) =| + HeadersBecameAvailable(i32 download_id, IPC::Dictionary response_headers, Optional<u32> status_code) =| // Certificate requests CertificateRequested(i32 download_id) => () diff --git a/Services/ProtocolServer/ProtocolServer.ipc b/Services/ProtocolServer/ProtocolServer.ipc index 4cf1204520..0707afb733 100644 --- a/Services/ProtocolServer/ProtocolServer.ipc +++ b/Services/ProtocolServer/ProtocolServer.ipc @@ -3,14 +3,11 @@ endpoint ProtocolServer = 9 // Basic protocol Greet() => (i32 client_id) - // FIXME: It would be nice if the kernel provided a way to avoid this - DisownSharedBuffer(i32 shbuf_id) => () - // Test if a specific protocol is supported, e.g "http" IsSupportedProtocol(String protocol) => (bool supported) // Download API - StartDownload(String method, URL url, IPC::Dictionary request_headers, String request_body) => (i32 download_id) + StartDownload(String method, URL url, IPC::Dictionary request_headers, ByteBuffer request_body) => (i32 download_id, IPC::File response_fd) StopDownload(i32 download_id) => (bool success) SetCertificate(i32 download_id, String certificate, String key) => (bool success) } diff --git a/Services/ProtocolServer/main.cpp b/Services/ProtocolServer/main.cpp index 765dbe5056..62fc908dd1 100644 --- a/Services/ProtocolServer/main.cpp +++ b/Services/ProtocolServer/main.cpp @@ -35,7 +35,7 @@ int main(int, char**) { - if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr", nullptr) < 0) { + if (pledge("stdio inet shared_buffer accept unix rpath cpath fattr sendfd recvfd", nullptr) < 0) { perror("pledge"); return 1; } @@ -45,7 +45,7 @@ int main(int, char**) Core::EventLoop event_loop; // FIXME: Establish a connection to LookupServer and then drop "unix"? - if (pledge("stdio inet shared_buffer accept unix", nullptr) < 0) { + if (pledge("stdio inet shared_buffer accept unix sendfd recvfd", nullptr) < 0) { perror("pledge"); return 1; } |