diff options
-rwxr-xr-x | Kernel/build-root-filesystem.sh | 1 | ||||
-rwxr-xr-x | Kernel/makeall.sh | 2 | ||||
-rw-r--r-- | Libraries/LibProtocol/Client.cpp | 45 | ||||
-rw-r--r-- | Libraries/LibProtocol/Client.h | 29 | ||||
-rw-r--r-- | Libraries/LibProtocol/Makefile | 20 | ||||
-rw-r--r-- | Makefile.common | 1 | ||||
-rw-r--r-- | Servers/ProtocolServer/Download.cpp | 55 | ||||
-rw-r--r-- | Servers/ProtocolServer/Download.h | 35 | ||||
-rw-r--r-- | Servers/ProtocolServer/HttpDownload.cpp | 20 | ||||
-rw-r--r-- | Servers/ProtocolServer/HttpDownload.h | 18 | ||||
-rw-r--r-- | Servers/ProtocolServer/HttpProtocol.cpp | 24 | ||||
-rw-r--r-- | Servers/ProtocolServer/HttpProtocol.h | 11 | ||||
-rw-r--r-- | Servers/ProtocolServer/Makefile | 35 | ||||
-rw-r--r-- | Servers/ProtocolServer/PSClientConnection.cpp | 63 | ||||
-rw-r--r-- | Servers/ProtocolServer/PSClientConnection.h | 26 | ||||
-rw-r--r-- | Servers/ProtocolServer/Protocol.cpp | 23 | ||||
-rw-r--r-- | Servers/ProtocolServer/Protocol.h | 23 | ||||
-rw-r--r-- | Servers/ProtocolServer/ProtocolClient.ipc | 6 | ||||
-rw-r--r-- | Servers/ProtocolServer/ProtocolServer.ipc | 12 | ||||
-rw-r--r-- | Servers/ProtocolServer/main.cpp | 25 | ||||
-rw-r--r-- | Servers/SystemServer/main.cpp | 1 |
21 files changed, 475 insertions, 0 deletions
diff --git a/Kernel/build-root-filesystem.sh b/Kernel/build-root-filesystem.sh index a017123e20..afdd529d7b 100755 --- a/Kernel/build-root-filesystem.sh +++ b/Kernel/build-root-filesystem.sh @@ -107,6 +107,7 @@ cp ../Servers/WindowServer/WindowServer mnt/bin/WindowServer cp ../Servers/AudioServer/AudioServer mnt/bin/AudioServer cp ../Servers/TTYServer/TTYServer mnt/bin/TTYServer cp ../Servers/TelnetServer/TelnetServer mnt/bin/TelnetServer +cp ../Servers/ProtocolServer/ProtocolServer mnt/bin/ProtocolServer cp ../Shell/Shell mnt/bin/Shell echo "done" diff --git a/Kernel/makeall.sh b/Kernel/makeall.sh index 6d09b8623d..dd8f021f7e 100755 --- a/Kernel/makeall.sh +++ b/Kernel/makeall.sh @@ -31,6 +31,7 @@ build_targets="$build_targets ../Libraries/LibPthread" # Build IPC servers before their client code to ensure the IPC definitions are available. build_targets="$build_targets ../Servers/AudioServer" build_targets="$build_targets ../Servers/LookupServer" +build_targets="$build_targets ../Servers/ProtocolServer" build_targets="$build_targets ../AK" @@ -42,6 +43,7 @@ build_targets="$build_targets ../Libraries/LibM" build_targets="$build_targets ../Libraries/LibPCIDB" build_targets="$build_targets ../Libraries/LibVT" build_targets="$build_targets ../Libraries/LibMarkdown" +build_targets="$build_targets ../Libraries/LibProtocol" build_targets="$build_targets ../Applications/About" build_targets="$build_targets ../Applications/Calculator" diff --git a/Libraries/LibProtocol/Client.cpp b/Libraries/LibProtocol/Client.cpp new file mode 100644 index 0000000000..d37317d9dd --- /dev/null +++ b/Libraries/LibProtocol/Client.cpp @@ -0,0 +1,45 @@ +#include <LibProtocol/Client.h> +#include <SharedBuffer.h> + +namespace LibProtocol { + +Client::Client() + : ConnectionNG(*this, "/tmp/psportal") +{ +} + +void Client::handshake() +{ + auto response = send_sync<ProtocolServer::Greet>(getpid()); + set_server_pid(response->server_pid()); + set_my_client_id(response->client_id()); +} + +bool Client::is_supported_protocol(const String& protocol) +{ + return send_sync<ProtocolServer::IsSupportedProtocol>(protocol)->supported(); +} + +i32 Client::start_download(const String& url) +{ + return send_sync<ProtocolServer::StartDownload>(url)->download_id(); +} + +bool Client::stop_download(i32 download_id) +{ + return send_sync<ProtocolServer::StopDownload>(download_id)->success(); +} + +void Client::handle(const ProtocolClient::DownloadFinished& message) +{ + if (on_download_finish) + on_download_finish(message.download_id(), message.success()); +} + +void Client::handle(const ProtocolClient::DownloadProgress& message) +{ + if (on_download_progress) + on_download_progress(message.download_id(), message.total_size(), message.downloaded_size()); +} + +} diff --git a/Libraries/LibProtocol/Client.h b/Libraries/LibProtocol/Client.h new file mode 100644 index 0000000000..86b4867f60 --- /dev/null +++ b/Libraries/LibProtocol/Client.h @@ -0,0 +1,29 @@ +#pragma once + +#include <LibCore/CoreIPCClient.h> +#include <ProtocolServer/ProtocolClientEndpoint.h> +#include <ProtocolServer/ProtocolServerEndpoint.h> + +namespace LibProtocol { + +class Client : public IPC::Client::ConnectionNG<ProtocolClientEndpoint, ProtocolServerEndpoint> + , public ProtocolClientEndpoint { + C_OBJECT(Client) +public: + Client(); + + virtual void handshake() override; + + bool is_supported_protocol(const String&); + i32 start_download(const String& url); + bool stop_download(i32 download_id); + + Function<void(i32 download_id, bool success)> on_download_finish; + Function<void(i32 download_id, u64 total_size, u64 downloaded_size)> on_download_progress; + +private: + virtual void handle(const ProtocolClient::DownloadProgress&) override; + virtual void handle(const ProtocolClient::DownloadFinished&) override; +}; + +} diff --git a/Libraries/LibProtocol/Makefile b/Libraries/LibProtocol/Makefile new file mode 100644 index 0000000000..16c751bbde --- /dev/null +++ b/Libraries/LibProtocol/Makefile @@ -0,0 +1,20 @@ +include ../../Makefile.common + +OBJS = \ + Client.o + +LIBRARY = libprotocol.a +DEFINES += -DUSERLAND + +all: $(LIBRARY) + +$(LIBRARY): $(OBJS) + @echo "LIB $@"; $(AR) rcs $@ $(OBJS) $(LIBS) + +.cpp.o: + @echo "CXX $<"; $(CXX) $(CXXFLAGS) -o $@ -c $< + +-include $(OBJS:%.o=%.d) + +clean: + @echo "CLEAN"; rm -f $(LIBRARY) $(OBJS) *.d diff --git a/Makefile.common b/Makefile.common index 8f6997cace..c00bab00c0 100644 --- a/Makefile.common +++ b/Makefile.common @@ -28,6 +28,7 @@ LDFLAGS = \ -L$(SERENITY_BASE_DIR)/Libraries/LibMarkdown \ -L$(SERENITY_BASE_DIR)/Libraries/LibThread \ -L$(SERENITY_BASE_DIR)/Libraries/LibVT \ + -L$(SERENITY_BASE_DIR)/Libraries/LibProtocol \ -L$(SERENITY_BASE_DIR)/Libraries/LibAudio CLANG_FLAGS = -Wconsumed -m32 -ffreestanding -march=i686 diff --git a/Servers/ProtocolServer/Download.cpp b/Servers/ProtocolServer/Download.cpp new file mode 100644 index 0000000000..d6fe75bef4 --- /dev/null +++ b/Servers/ProtocolServer/Download.cpp @@ -0,0 +1,55 @@ +#include <ProtocolServer/Download.h> +#include <ProtocolServer/PSClientConnection.h> + +// FIXME: What about rollover? +static i32 s_next_id = 1; + +static HashMap<i32, RefPtr<Download>>& all_downloads() +{ + static HashMap<i32, RefPtr<Download>> map; + return map; +} + +Download* Download::find_by_id(i32 id) +{ + return all_downloads().get(id).value_or(nullptr); +} + +Download::Download(PSClientConnection& client) + : m_id(s_next_id++) + , m_client(client.make_weak_ptr()) +{ + all_downloads().set(m_id, this); +} + +Download::~Download() +{ +} + +void Download::stop() +{ + all_downloads().remove(m_id); +} + +void Download::did_finish(bool success) +{ + if (!m_client) { + dbg() << "Download::did_finish() after the client already disconnected."; + return; + } + m_client->did_finish_download({}, *this, success); + all_downloads().remove(m_id); +} + +void Download::did_progress(size_t total_size, size_t downloaded_size) +{ + if (!m_client) { + // FIXME: We should also abort the download in this situation, I guess! + dbg() << "Download::did_progress() after the client already disconnected."; + return; + } + m_total_size = total_size; + m_downloaded_size = downloaded_size; + m_client->did_progress_download({}, *this); +} + diff --git a/Servers/ProtocolServer/Download.h b/Servers/ProtocolServer/Download.h new file mode 100644 index 0000000000..d12a6b7e52 --- /dev/null +++ b/Servers/ProtocolServer/Download.h @@ -0,0 +1,35 @@ +#pragma once + +#include <AK/RefCounted.h> +#include <AK/URL.h> +#include <AK/WeakPtr.h> + +class PSClientConnection; + +class Download : public RefCounted<Download> { +public: + virtual ~Download(); + + static Download* find_by_id(i32); + + i32 id() const { return m_id; } + URL url() const { return m_url; } + + size_t total_size() const { return m_total_size; } + size_t downloaded_size() const { return m_downloaded_size; } + + void stop(); + +protected: + explicit Download(PSClientConnection&); + + void did_finish(bool success); + void did_progress(size_t total_size, size_t downloaded_size); + +private: + i32 m_id; + URL m_url; + size_t m_total_size { 0 }; + size_t m_downloaded_size { 0 }; + WeakPtr<PSClientConnection> m_client; +}; diff --git a/Servers/ProtocolServer/HttpDownload.cpp b/Servers/ProtocolServer/HttpDownload.cpp new file mode 100644 index 0000000000..76dba63491 --- /dev/null +++ b/Servers/ProtocolServer/HttpDownload.cpp @@ -0,0 +1,20 @@ +#include <LibCore/CHttpJob.h> +#include <ProtocolServer/HttpDownload.h> + +HttpDownload::HttpDownload(PSClientConnection& client, NonnullRefPtr<CHttpJob>&& job) + : Download(client) + , m_job(job) +{ + m_job->on_finish = [this](bool success) { + did_finish(success); + }; +} + +HttpDownload::~HttpDownload() +{ +} + +NonnullRefPtr<HttpDownload> HttpDownload::create_with_job(Badge<HttpProtocol>, PSClientConnection& client, NonnullRefPtr<CHttpJob>&& job) +{ + return adopt(*new HttpDownload(client, move(job))); +} diff --git a/Servers/ProtocolServer/HttpDownload.h b/Servers/ProtocolServer/HttpDownload.h new file mode 100644 index 0000000000..0de9cb1191 --- /dev/null +++ b/Servers/ProtocolServer/HttpDownload.h @@ -0,0 +1,18 @@ +#pragma once + +#include <AK/Badge.h> +#include <ProtocolServer/Download.h> + +class CHttpJob; +class HttpProtocol; + +class HttpDownload final : public Download { +public: + virtual ~HttpDownload() override; + static NonnullRefPtr<HttpDownload> create_with_job(Badge<HttpProtocol>, PSClientConnection&, NonnullRefPtr<CHttpJob>&&); + +private: + explicit HttpDownload(PSClientConnection&, NonnullRefPtr<CHttpJob>&&); + + NonnullRefPtr<CHttpJob> m_job; +}; diff --git a/Servers/ProtocolServer/HttpProtocol.cpp b/Servers/ProtocolServer/HttpProtocol.cpp new file mode 100644 index 0000000000..b1bc8d6298 --- /dev/null +++ b/Servers/ProtocolServer/HttpProtocol.cpp @@ -0,0 +1,24 @@ +#include <LibCore/CHttpJob.h> +#include <LibCore/CHttpRequest.h> +#include <ProtocolServer/HttpDownload.h> +#include <ProtocolServer/HttpProtocol.h> + +HttpProtocol::HttpProtocol() + : Protocol("http") +{ +} + +HttpProtocol::~HttpProtocol() +{ +} + +RefPtr<Download> HttpProtocol::start_download(PSClientConnection& client, const URL& url) +{ + CHttpRequest request; + request.set_method(CHttpRequest::Method::GET); + request.set_url(url); + auto job = request.schedule(); + if (!job) + return nullptr; + return HttpDownload::create_with_job({}, client, (CHttpJob&)*job); +} diff --git a/Servers/ProtocolServer/HttpProtocol.h b/Servers/ProtocolServer/HttpProtocol.h new file mode 100644 index 0000000000..67a5d9ecf5 --- /dev/null +++ b/Servers/ProtocolServer/HttpProtocol.h @@ -0,0 +1,11 @@ +#pragma once + +#include <ProtocolServer/Protocol.h> + +class HttpProtocol final : public Protocol { +public: + HttpProtocol(); + virtual ~HttpProtocol() override; + + virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) override; +}; diff --git a/Servers/ProtocolServer/Makefile b/Servers/ProtocolServer/Makefile new file mode 100644 index 0000000000..3e84b1f12f --- /dev/null +++ b/Servers/ProtocolServer/Makefile @@ -0,0 +1,35 @@ +include ../../Makefile.common + +OBJS = \ + PSClientConnection.o \ + Protocol.o \ + Download.o \ + HttpProtocol.o \ + HttpDownload.o \ + main.o + +APP = ProtocolServer + +DEFINES += -DUSERLAND + +all: $(APP) + +*.cpp: ProtocolServerEndpoint.h ProtocolClientEndpoint.h + +ProtocolServerEndpoint.h: ProtocolServer.ipc + @echo "IPC $<"; $(IPCCOMPILER) $< > $@ + +ProtocolClientEndpoint.h: ProtocolClient.ipc + @echo "IPC $<"; $(IPCCOMPILER) $< > $@ + +$(APP): $(OBJS) + $(LD) -o $(APP) $(LDFLAGS) $(OBJS) -lc -lcore -lipc -ldraw + +.cpp.o: + @echo "CXX $<"; $(CXX) $(CXXFLAGS) -o $@ -c $< + +-include $(OBJS:%.o=%.d) + +clean: + @echo "CLEAN"; rm -f $(APP) $(OBJS) *.d ProtocolClientEndpoint.h ProtocolServerEndpoint.h + diff --git a/Servers/ProtocolServer/PSClientConnection.cpp b/Servers/ProtocolServer/PSClientConnection.cpp new file mode 100644 index 0000000000..fb93268efe --- /dev/null +++ b/Servers/ProtocolServer/PSClientConnection.cpp @@ -0,0 +1,63 @@ +#include <ProtocolServer/Download.h> +#include <ProtocolServer/PSClientConnection.h> +#include <ProtocolServer/Protocol.h> +#include <ProtocolServer/ProtocolClientEndpoint.h> + +static HashMap<int, RefPtr<PSClientConnection>> s_connections; + +PSClientConnection::PSClientConnection(CLocalSocket& socket, int client_id) + : ConnectionNG(*this, socket, client_id) +{ + s_connections.set(client_id, *this); +} + +PSClientConnection::~PSClientConnection() +{ +} + +void PSClientConnection::die() +{ + s_connections.remove(client_id()); +} + +OwnPtr<ProtocolServer::IsSupportedProtocolResponse> PSClientConnection::handle(const ProtocolServer::IsSupportedProtocol& message) +{ + bool supported = Protocol::find_by_name(message.protocol().to_lowercase()); + return make<ProtocolServer::IsSupportedProtocolResponse>(supported); +} + +OwnPtr<ProtocolServer::StartDownloadResponse> PSClientConnection::handle(const ProtocolServer::StartDownload& message) +{ + URL url(message.url()); + ASSERT(url.is_valid()); + auto* protocol = Protocol::find_by_name(url.protocol()); + ASSERT(protocol); + auto download = protocol->start_download(*this, url); + return make<ProtocolServer::StartDownloadResponse>(download->id()); +} + +OwnPtr<ProtocolServer::StopDownloadResponse> PSClientConnection::handle(const ProtocolServer::StopDownload& message) +{ + auto* download = Download::find_by_id(message.download_id()); + bool success = false; + if (download) { + download->stop(); + } + return make<ProtocolServer::StopDownloadResponse>(success); +} + +void PSClientConnection::did_finish_download(Badge<Download>, Download& download, bool success) +{ + post_message(ProtocolClient::DownloadFinished(download.id(), success)); +} + +void PSClientConnection::did_progress_download(Badge<Download>, Download& download) +{ + post_message(ProtocolClient::DownloadProgress(download.id(), download.total_size(), download.downloaded_size())); +} + +OwnPtr<ProtocolServer::GreetResponse> PSClientConnection::handle(const ProtocolServer::Greet& message) +{ + set_client_pid(message.client_pid()); + return make<ProtocolServer::GreetResponse>(getpid(), client_id()); +} diff --git a/Servers/ProtocolServer/PSClientConnection.h b/Servers/ProtocolServer/PSClientConnection.h new file mode 100644 index 0000000000..190e9bd062 --- /dev/null +++ b/Servers/ProtocolServer/PSClientConnection.h @@ -0,0 +1,26 @@ +#pragma once + +#include <AK/Badge.h> +#include <LibCore/CoreIPCServer.h> +#include <ProtocolServer/ProtocolServerEndpoint.h> + +class Download; + +class PSClientConnection final : public IPC::Server::ConnectionNG<ProtocolServerEndpoint> + , public ProtocolServerEndpoint { + C_OBJECT(PSClientConnection) +public: + explicit PSClientConnection(CLocalSocket&, int client_id); + ~PSClientConnection() override; + + virtual void die() override; + + void did_finish_download(Badge<Download>, Download&, bool success); + void did_progress_download(Badge<Download>, Download&); + +private: + virtual OwnPtr<ProtocolServer::GreetResponse> handle(const ProtocolServer::Greet&) override; + virtual OwnPtr<ProtocolServer::IsSupportedProtocolResponse> handle(const ProtocolServer::IsSupportedProtocol&) override; + virtual OwnPtr<ProtocolServer::StartDownloadResponse> handle(const ProtocolServer::StartDownload&) override; + virtual OwnPtr<ProtocolServer::StopDownloadResponse> handle(const ProtocolServer::StopDownload&) override; +}; diff --git a/Servers/ProtocolServer/Protocol.cpp b/Servers/ProtocolServer/Protocol.cpp new file mode 100644 index 0000000000..7d29458eb5 --- /dev/null +++ b/Servers/ProtocolServer/Protocol.cpp @@ -0,0 +1,23 @@ +#include <AK/HashMap.h> +#include <ProtocolServer/Protocol.h> + +static HashMap<String, Protocol*>& all_protocols() +{ + static HashMap<String, Protocol*> map; + return map; +} + +Protocol* Protocol::find_by_name(const String& name) +{ + return all_protocols().get(name).value_or(nullptr); +} + +Protocol::Protocol(const String& name) +{ + all_protocols().set(name, this); +} + +Protocol::~Protocol() +{ + ASSERT_NOT_REACHED(); +} diff --git a/Servers/ProtocolServer/Protocol.h b/Servers/ProtocolServer/Protocol.h new file mode 100644 index 0000000000..d828352bbc --- /dev/null +++ b/Servers/ProtocolServer/Protocol.h @@ -0,0 +1,23 @@ +#pragma once + +#include <AK/RefPtr.h> +#include <AK/URL.h> + +class Download; +class PSClientConnection; + +class Protocol { +public: + virtual ~Protocol(); + + const String& name() const { return m_name; } + virtual RefPtr<Download> start_download(PSClientConnection&, const URL&) = 0; + + static Protocol* find_by_name(const String&); + +protected: + explicit Protocol(const String& name); + +private: + String m_name; +}; diff --git a/Servers/ProtocolServer/ProtocolClient.ipc b/Servers/ProtocolServer/ProtocolClient.ipc new file mode 100644 index 0000000000..df88714f50 --- /dev/null +++ b/Servers/ProtocolServer/ProtocolClient.ipc @@ -0,0 +1,6 @@ +endpoint ProtocolClient = 13 +{ + // Download notifications + DownloadProgress(i32 download_id, u32 total_size, u32 downloaded_size) =| + DownloadFinished(i32 download_id, bool success) =| +} diff --git a/Servers/ProtocolServer/ProtocolServer.ipc b/Servers/ProtocolServer/ProtocolServer.ipc new file mode 100644 index 0000000000..90af2d33ec --- /dev/null +++ b/Servers/ProtocolServer/ProtocolServer.ipc @@ -0,0 +1,12 @@ +endpoint ProtocolServer = 9 +{ + // Basic protocol + Greet(i32 client_pid) => (i32 server_pid, i32 client_id) + + // Test if a specific protocol is supported, e.g "http" + IsSupportedProtocol(String protocol) => (bool supported) + + // Download API + StartDownload(String url) => (i32 download_id) + StopDownload(i32 download_id) => (bool success) +} diff --git a/Servers/ProtocolServer/main.cpp b/Servers/ProtocolServer/main.cpp new file mode 100644 index 0000000000..568509384b --- /dev/null +++ b/Servers/ProtocolServer/main.cpp @@ -0,0 +1,25 @@ +#include <LibCore/CEventLoop.h> +#include <LibCore/CLocalServer.h> +#include <LibCore/CoreIPCServer.h> +#include <ProtocolServer/HttpProtocol.h> +#include <ProtocolServer/PSClientConnection.h> + +int main(int, char**) +{ + CEventLoop event_loop; + (void)*new HttpProtocol; + auto server = CLocalServer::construct(); + unlink("/tmp/psportal"); + server->listen("/tmp/psportal"); + server->on_ready_to_accept = [&] { + auto client_socket = server->accept(); + if (!client_socket) { + dbg() << "ProtocolServer: accept failed."; + return; + } + static int s_next_client_id = 0; + int client_id = ++s_next_client_id; + IPC::Server::new_connection_ng_for_client<PSClientConnection>(*client_socket, client_id); + }; + return event_loop.exec(); +} diff --git a/Servers/SystemServer/main.cpp b/Servers/SystemServer/main.cpp index 6f4f742496..e84fa4ed99 100644 --- a/Servers/SystemServer/main.cpp +++ b/Servers/SystemServer/main.cpp @@ -106,6 +106,7 @@ int main(int, char**) signal(SIGCHLD, sigchld_handler); + start_process("/bin/ProtocolServer", {}, lowest_prio); start_process("/bin/LookupServer", {}, lowest_prio); start_process("/bin/WindowServer", {}, highest_prio); start_process("/bin/AudioServer", {}, highest_prio); |