summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorConrad Pankoff <deoxxa@fknsrs.biz>2020-05-16 05:36:40 +1000
committerAndreas Kling <kling@serenityos.org>2020-05-17 12:41:38 +0200
commitf2621f37a4ce024a49e41755ea12d826e991fc1a (patch)
tree8d4648908e63e69e42a58c7dcc127c1b5bb9f4a0
parent184ee8ac77457e7e5bae07a3cb48f872e2063871 (diff)
downloadserenity-f2621f37a4ce024a49e41755ea12d826e991fc1a.zip
ProtocolServer: Attach downloads and their lifecycles to clients
Previously a download lived independently of the client connection it came from. This was the source of several undesirable behaviours, including the potential for clients to influence downloads they didn't start, and downloads living longer than their associated client connections. Now we attach downloads to client connections, which means they're cleaned up automatically when the client goes away, and there's significantly less risk of clients interfering with each other.
-rw-r--r--Services/ProtocolServer/Download.cpp32
-rw-r--r--Services/ProtocolServer/Download.h9
-rw-r--r--Services/ProtocolServer/GeminiDownload.cpp9
-rw-r--r--Services/ProtocolServer/GeminiDownload.h4
-rw-r--r--Services/ProtocolServer/GeminiProtocol.cpp2
-rw-r--r--Services/ProtocolServer/GeminiProtocol.h2
-rw-r--r--Services/ProtocolServer/HttpDownload.cpp9
-rw-r--r--Services/ProtocolServer/HttpDownload.h4
-rw-r--r--Services/ProtocolServer/HttpProtocol.cpp2
-rw-r--r--Services/ProtocolServer/HttpProtocol.h2
-rw-r--r--Services/ProtocolServer/HttpsDownload.cpp9
-rw-r--r--Services/ProtocolServer/HttpsDownload.h4
-rw-r--r--Services/ProtocolServer/HttpsProtocol.cpp2
-rw-r--r--Services/ProtocolServer/HttpsProtocol.h2
-rw-r--r--Services/ProtocolServer/PSClientConnection.cpp10
-rw-r--r--Services/ProtocolServer/PSClientConnection.h1
-rw-r--r--Services/ProtocolServer/Protocol.h2
17 files changed, 48 insertions, 57 deletions
diff --git a/Services/ProtocolServer/Download.cpp b/Services/ProtocolServer/Download.cpp
index e205d38a55..6d510374d3 100644
--- a/Services/ProtocolServer/Download.cpp
+++ b/Services/ProtocolServer/Download.cpp
@@ -31,22 +31,10 @@
// FIXME: What about rollover?
static i32 s_next_id = 1;
-static HashMap<i32, RefPtr<Download>>& all_downloads()
-{
- static HashMap<i32, RefPtr<Download>> map;
- return map;
-}
-
-Download* Download::find_by_id(i32 id)
-{
- return const_cast<Download*>(all_downloads().get(id).value_or(nullptr));
-}
-
Download::Download(PSClientConnection& client)
- : m_id(s_next_id++)
- , m_client(client.make_weak_ptr())
+ : m_client(client)
+ , m_id(s_next_id++)
{
- all_downloads().set(m_id, this);
}
Download::~Download()
@@ -55,7 +43,7 @@ Download::~Download()
void Download::stop()
{
- all_downloads().remove(m_id);
+ m_client.did_finish_download({}, *this, false);
}
void Download::set_payload(const ByteBuffer& payload)
@@ -71,22 +59,12 @@ void Download::set_response_headers(const HashMap<String, String, CaseInsensitiv
void Download::did_finish(bool success)
{
- if (!m_client) {
- dbg() << "Download::did_finish() after the client already disconnected.";
- return;
- }
- m_client->did_finish_download({}, *this, success);
- all_downloads().remove(m_id);
+ m_client.did_finish_download({}, *this, success);
}
void Download::did_progress(Optional<u32> total_size, u32 downloaded_size)
{
- if (!m_client) {
- // FIXME: We should also abort the download in this situation, I guess!
- dbg() << "Download::did_progress() after the client already disconnected.";
- return;
- }
m_total_size = total_size;
m_downloaded_size = downloaded_size;
- m_client->did_progress_download({}, *this);
+ m_client.did_progress_download({}, *this);
}
diff --git a/Services/ProtocolServer/Download.h b/Services/ProtocolServer/Download.h
index 0796765482..08b0605e03 100644
--- a/Services/ProtocolServer/Download.h
+++ b/Services/ProtocolServer/Download.h
@@ -31,16 +31,13 @@
#include <AK/Optional.h>
#include <AK/RefCounted.h>
#include <AK/URL.h>
-#include <AK/WeakPtr.h>
class PSClientConnection;
-class Download : public RefCounted<Download> {
+class Download {
public:
virtual ~Download();
- static Download* find_by_id(i32);
-
i32 id() const { return m_id; }
URL url() const { return m_url; }
@@ -60,11 +57,11 @@ protected:
void set_response_headers(const HashMap<String, String, CaseInsensitiveStringTraits>&);
private:
- i32 m_id;
+ PSClientConnection& m_client;
+ i32 m_id { 0 };
URL m_url;
Optional<u32> m_total_size {};
size_t m_downloaded_size { 0 };
ByteBuffer m_payload;
HashMap<String, String, CaseInsensitiveStringTraits> m_response_headers;
- WeakPtr<PSClientConnection> m_client;
};
diff --git a/Services/ProtocolServer/GeminiDownload.cpp b/Services/ProtocolServer/GeminiDownload.cpp
index a08945830c..54114bcd77 100644
--- a/Services/ProtocolServer/GeminiDownload.cpp
+++ b/Services/ProtocolServer/GeminiDownload.cpp
@@ -28,7 +28,7 @@
#include <LibGemini/GeminiJob.h>
#include <ProtocolServer/GeminiDownload.h>
-GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob>&& job)
+GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
: Download(client)
, m_job(job)
{
@@ -55,9 +55,12 @@ GeminiDownload::GeminiDownload(PSClientConnection& client, NonnullRefPtr<Gemini:
GeminiDownload::~GeminiDownload()
{
+ m_job->on_finish = nullptr;
+ m_job->on_progress = nullptr;
+ m_job->shutdown();
}
-NonnullRefPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob>&& job)
+NonnullOwnPtr<GeminiDownload> GeminiDownload::create_with_job(Badge<GeminiProtocol>, PSClientConnection& client, NonnullRefPtr<Gemini::GeminiJob> job)
{
- return adopt(*new GeminiDownload(client, move(job)));
+ return adopt_own(*new GeminiDownload(client, move(job)));
}
diff --git a/Services/ProtocolServer/GeminiDownload.h b/Services/ProtocolServer/GeminiDownload.h
index a5ef599051..cc7ed4f242 100644
--- a/Services/ProtocolServer/GeminiDownload.h
+++ b/Services/ProtocolServer/GeminiDownload.h
@@ -36,10 +36,10 @@ class GeminiProtocol;
class GeminiDownload final : public Download {
public:
virtual ~GeminiDownload() override;
- static NonnullRefPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>&&);
+ static NonnullOwnPtr<GeminiDownload> create_with_job(Badge<GeminiProtocol>, PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
private:
- explicit GeminiDownload(PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>&&);
+ explicit GeminiDownload(PSClientConnection&, NonnullRefPtr<Gemini::GeminiJob>);
NonnullRefPtr<Gemini::GeminiJob> m_job;
};
diff --git a/Services/ProtocolServer/GeminiProtocol.cpp b/Services/ProtocolServer/GeminiProtocol.cpp
index 8eda72eb5b..0bc413a463 100644
--- a/Services/ProtocolServer/GeminiProtocol.cpp
+++ b/Services/ProtocolServer/GeminiProtocol.cpp
@@ -38,7 +38,7 @@ GeminiProtocol::~GeminiProtocol()
{
}
-RefPtr<Download> GeminiProtocol::start_download(PSClientConnection& client, const URL& url)
+OwnPtr<Download> GeminiProtocol::start_download(PSClientConnection& client, const URL& url)
{
Gemini::GeminiRequest request;
request.set_url(url);
diff --git a/Services/ProtocolServer/GeminiProtocol.h b/Services/ProtocolServer/GeminiProtocol.h
index d15e64e327..9d3066e9b7 100644
--- a/Services/ProtocolServer/GeminiProtocol.h
+++ b/Services/ProtocolServer/GeminiProtocol.h
@@ -33,5 +33,5 @@ public:
GeminiProtocol();
virtual ~GeminiProtocol() override;
- virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
+ virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
};
diff --git a/Services/ProtocolServer/HttpDownload.cpp b/Services/ProtocolServer/HttpDownload.cpp
index 5ffa30362b..5bfadefb6d 100644
--- a/Services/ProtocolServer/HttpDownload.cpp
+++ b/Services/ProtocolServer/HttpDownload.cpp
@@ -28,7 +28,7 @@
#include <LibHTTP/HttpResponse.h>
#include <ProtocolServer/HttpDownload.h>
-HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob>&& job)
+HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
: Download(client)
, m_job(job)
{
@@ -52,9 +52,12 @@ HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpJ
HttpDownload::~HttpDownload()
{
+ m_job->on_finish = nullptr;
+ m_job->on_progress = nullptr;
+ m_job->shutdown();
}
-NonnullRefPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob>&& job)
+NonnullOwnPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpJob> job)
{
- return adopt(*new HttpDownload(client, move(job)));
+ return adopt_own(*new HttpDownload(client, move(job)));
}
diff --git a/Services/ProtocolServer/HttpDownload.h b/Services/ProtocolServer/HttpDownload.h
index 364fe6aef2..49da391ed9 100644
--- a/Services/ProtocolServer/HttpDownload.h
+++ b/Services/ProtocolServer/HttpDownload.h
@@ -36,10 +36,10 @@ class HttpProtocol;
class HttpDownload final : public Download {
public:
virtual ~HttpDownload() override;
- static NonnullRefPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>&&);
+ static NonnullOwnPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
private:
- explicit HttpDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>&&);
+ explicit HttpDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpJob>);
NonnullRefPtr<HTTP::HttpJob> m_job;
};
diff --git a/Services/ProtocolServer/HttpProtocol.cpp b/Services/ProtocolServer/HttpProtocol.cpp
index 1513a32e66..5e2f4803a7 100644
--- a/Services/ProtocolServer/HttpProtocol.cpp
+++ b/Services/ProtocolServer/HttpProtocol.cpp
@@ -38,7 +38,7 @@ HttpProtocol::~HttpProtocol()
{
}
-RefPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
+OwnPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url)
{
HTTP::HttpRequest request;
request.set_method(HTTP::HttpRequest::Method::GET);
diff --git a/Services/ProtocolServer/HttpProtocol.h b/Services/ProtocolServer/HttpProtocol.h
index ea6c15d2a4..0c6c391931 100644
--- a/Services/ProtocolServer/HttpProtocol.h
+++ b/Services/ProtocolServer/HttpProtocol.h
@@ -33,5 +33,5 @@ public:
HttpProtocol();
virtual ~HttpProtocol() override;
- virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
+ virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
};
diff --git a/Services/ProtocolServer/HttpsDownload.cpp b/Services/ProtocolServer/HttpsDownload.cpp
index 8e068c2b04..6a629b1bbc 100644
--- a/Services/ProtocolServer/HttpsDownload.cpp
+++ b/Services/ProtocolServer/HttpsDownload.cpp
@@ -28,7 +28,7 @@
#include <LibHTTP/HttpsJob.h>
#include <ProtocolServer/HttpsDownload.h>
-HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob>&& job)
+HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
: Download(client)
, m_job(job)
{
@@ -52,9 +52,12 @@ HttpsDownload::HttpsDownload(PSClientConnection& client, NonnullRefPtr<HTTP::Htt
HttpsDownload::~HttpsDownload()
{
+ m_job->on_finish = nullptr;
+ m_job->on_progress = nullptr;
+ m_job->shutdown();
}
-NonnullRefPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob>&& job)
+NonnullOwnPtr<HttpsDownload> HttpsDownload::create_with_job(Badge<HttpsProtocol>, PSClientConnection& client, NonnullRefPtr<HTTP::HttpsJob> job)
{
- return adopt(*new HttpsDownload(client, move(job)));
+ return adopt_own(*new HttpsDownload(client, move(job)));
}
diff --git a/Services/ProtocolServer/HttpsDownload.h b/Services/ProtocolServer/HttpsDownload.h
index a6c75bc4d7..8a5b6aceb2 100644
--- a/Services/ProtocolServer/HttpsDownload.h
+++ b/Services/ProtocolServer/HttpsDownload.h
@@ -36,10 +36,10 @@ class HttpsProtocol;
class HttpsDownload final : public Download {
public:
virtual ~HttpsDownload() override;
- static NonnullRefPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>&&);
+ static NonnullOwnPtr<HttpsDownload> create_with_job(Badge<HttpsProtocol>, PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
private:
- explicit HttpsDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>&&);
+ explicit HttpsDownload(PSClientConnection&, NonnullRefPtr<HTTP::HttpsJob>);
NonnullRefPtr<HTTP::HttpsJob> m_job;
};
diff --git a/Services/ProtocolServer/HttpsProtocol.cpp b/Services/ProtocolServer/HttpsProtocol.cpp
index 07020affe5..8796741c08 100644
--- a/Services/ProtocolServer/HttpsProtocol.cpp
+++ b/Services/ProtocolServer/HttpsProtocol.cpp
@@ -38,7 +38,7 @@ HttpsProtocol::~HttpsProtocol()
{
}
-RefPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const URL& url)
+OwnPtr<Download> HttpsProtocol::start_download(PSClientConnection& client, const URL& url)
{
HTTP::HttpRequest request;
request.set_method(HTTP::HttpRequest::Method::GET);
diff --git a/Services/ProtocolServer/HttpsProtocol.h b/Services/ProtocolServer/HttpsProtocol.h
index e446178751..ca6f8dafd2 100644
--- a/Services/ProtocolServer/HttpsProtocol.h
+++ b/Services/ProtocolServer/HttpsProtocol.h
@@ -33,5 +33,5 @@ public:
HttpsProtocol();
virtual ~HttpsProtocol() override;
- virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override;
+ virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) override;
};
diff --git a/Services/ProtocolServer/PSClientConnection.cpp b/Services/ProtocolServer/PSClientConnection.cpp
index 1d07fdb8b2..bb0c5c2940 100644
--- a/Services/ProtocolServer/PSClientConnection.cpp
+++ b/Services/ProtocolServer/PSClientConnection.cpp
@@ -63,12 +63,16 @@ OwnPtr<Messages::ProtocolServer::StartDownloadResponse> PSClientConnection::hand
if (!protocol)
return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
auto download = protocol->start_download(*this, url);
- return make<Messages::ProtocolServer::StartDownloadResponse>(download->id());
+ if (!download)
+ return make<Messages::ProtocolServer::StartDownloadResponse>(-1);
+ auto id = download->id();
+ m_downloads.set(id, move(download));
+ return make<Messages::ProtocolServer::StartDownloadResponse>(id);
}
OwnPtr<Messages::ProtocolServer::StopDownloadResponse> PSClientConnection::handle(const Messages::ProtocolServer::StopDownload& message)
{
- auto* download = Download::find_by_id(message.download_id());
+ auto* download = const_cast<Download*>(m_downloads.get(message.download_id()).value_or(nullptr));
bool success = false;
if (download) {
download->stop();
@@ -93,6 +97,8 @@ void PSClientConnection::did_finish_download(Badge<Download>, Download& download
for (auto& it : download.response_headers())
response_headers.add(it.key, it.value);
post_message(Messages::ProtocolClient::DownloadFinished(download.id(), success, download.total_size().value(), buffer ? buffer->shbuf_id() : -1, response_headers));
+
+ m_downloads.remove(download.id());
}
void PSClientConnection::did_progress_download(Badge<Download>, Download& download)
diff --git a/Services/ProtocolServer/PSClientConnection.h b/Services/ProtocolServer/PSClientConnection.h
index 56a7c63c0c..8bb0fd1aa5 100644
--- a/Services/ProtocolServer/PSClientConnection.h
+++ b/Services/ProtocolServer/PSClientConnection.h
@@ -51,5 +51,6 @@ private:
virtual OwnPtr<Messages::ProtocolServer::StopDownloadResponse> handle(const Messages::ProtocolServer::StopDownload&) override;
virtual OwnPtr<Messages::ProtocolServer::DisownSharedBufferResponse> handle(const Messages::ProtocolServer::DisownSharedBuffer&) override;
+ HashMap<i32, OwnPtr<Download>> m_downloads;
HashMap<i32, RefPtr<AK::SharedBuffer>> m_shared_buffers;
};
diff --git a/Services/ProtocolServer/Protocol.h b/Services/ProtocolServer/Protocol.h
index 19bd38f517..bc72751bd5 100644
--- a/Services/ProtocolServer/Protocol.h
+++ b/Services/ProtocolServer/Protocol.h
@@ -37,7 +37,7 @@ public:
virtual ~Protocol();
const String& name() const { return m_name; }
- virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
+ virtual OwnPtr<Download> start_download(PSClientConnection&, const URL&) = 0;
static Protocol* find_by_name(const String&);