diff options
author | Ali Mohammad Pur <ali.mpfard@gmail.com> | 2022-04-06 04:14:18 +0430 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2022-04-09 12:21:43 +0200 |
commit | cd9d7401073e8d0307d4e656e5c926936759e38e (patch) | |
tree | b75a01ffa1ebba44c48f312a0c207c916af47c33 | |
parent | bd5403adf1e5427bf6932dc6ce84c3ab29ab2205 (diff) | |
download | serenity-cd9d7401073e8d0307d4e656e5c926936759e38e.zip |
LibCore+RequestServer: Add support for SOCKS5 proxies
-rw-r--r-- | Userland/Libraries/LibCore/CMakeLists.txt | 1 | ||||
-rw-r--r-- | Userland/Libraries/LibCore/Proxy.h | 31 | ||||
-rw-r--r-- | Userland/Libraries/LibCore/SOCKSProxyClient.cpp | 325 | ||||
-rw-r--r-- | Userland/Libraries/LibCore/SOCKSProxyClient.h | 64 | ||||
-rw-r--r-- | Userland/Services/RequestServer/ConnectionCache.cpp | 9 | ||||
-rw-r--r-- | Userland/Services/RequestServer/ConnectionCache.h | 55 |
6 files changed, 470 insertions, 15 deletions
diff --git a/Userland/Libraries/LibCore/CMakeLists.txt b/Userland/Libraries/LibCore/CMakeLists.txt index eb3332cab0..fcce16726a 100644 --- a/Userland/Libraries/LibCore/CMakeLists.txt +++ b/Userland/Libraries/LibCore/CMakeLists.txt @@ -26,6 +26,7 @@ set(SOURCES ProcessStatisticsReader.cpp Property.cpp SecretString.cpp + SOCKSProxyClient.cpp Stream.cpp StandardPaths.cpp System.cpp diff --git a/Userland/Libraries/LibCore/Proxy.h b/Userland/Libraries/LibCore/Proxy.h new file mode 100644 index 0000000000..d2fe8f7551 --- /dev/null +++ b/Userland/Libraries/LibCore/Proxy.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include <AK/Error.h> +#include <AK/Types.h> +#include <LibIPC/Forward.h> + +namespace Core { +// FIXME: Username/password support. +struct ProxyData { + enum Type { + Direct, + SOCKS5, + } type { Type::Direct }; + + u32 host_ipv4 { 0 }; + int port { 0 }; + + bool operator==(ProxyData const& other) const = default; +}; +} + +namespace IPC { +bool encode(Encoder&, Core::ProxyData const&); +ErrorOr<void> decode(Decoder&, Core::ProxyData&); +} diff --git a/Userland/Libraries/LibCore/SOCKSProxyClient.cpp b/Userland/Libraries/LibCore/SOCKSProxyClient.cpp new file mode 100644 index 0000000000..6368a07f9d --- /dev/null +++ b/Userland/Libraries/LibCore/SOCKSProxyClient.cpp @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include <AK/MemoryStream.h> +#include <LibCore/SOCKSProxyClient.h> + +enum class Method : u8 { + NoAuth = 0x00, + GSSAPI = 0x01, + UsernamePassword = 0x02, + NoAcceptableMethods = 0xFF, +}; + +enum class AddressType : u8 { + IPV4 = 0x01, + DomainName = 0x03, + IPV6 = 0x04, +}; + +enum class Reply { + Succeeded = 0x00, + GeneralSocksServerFailure = 0x01, + ConnectionNotAllowedByRuleset = 0x02, + NetworkUnreachable = 0x03, + HostUnreachable = 0x04, + ConnectionRefused = 0x05, + TTLExpired = 0x06, + CommandNotSupported = 0x07, + AddressTypeNotSupported = 0x08, +}; + +struct [[gnu::packed]] Socks5VersionIdentifierAndMethodSelectionMessage { + u8 version_identifier; + u8 method_count; + // NOTE: We only send a single method, so we don't need to make this variable-length. + u8 methods[1]; +}; + +struct [[gnu::packed]] Socks5InitialResponse { + u8 version_identifier; + u8 method; +}; + +struct [[gnu::packed]] Socks5ConnectRequestHeader { + u8 version_identifier; + u8 command; + u8 reserved; +}; + +struct [[gnu::packed]] Socks5ConnectRequestTrailer { + u16 port; +}; + +struct [[gnu::packed]] Socks5ConnectResponseHeader { + u8 version_identifier; + u8 status; + u8 reserved; +}; + +struct [[gnu::packed]] Socks5ConnectResponseTrailer { + u8 bind_port; +}; + +struct [[gnu::packed]] Socks5UsernamePasswordResponse { + u8 version_identifier; + u8 status; +}; + +namespace { +StringView reply_response_name(Reply reply) +{ + switch (reply) { + case Reply::Succeeded: + return "Succeeded"; + case Reply::GeneralSocksServerFailure: + return "GeneralSocksServerFailure"; + case Reply::ConnectionNotAllowedByRuleset: + return "ConnectionNotAllowedByRuleset"; + case Reply::NetworkUnreachable: + return "NetworkUnreachable"; + case Reply::HostUnreachable: + return "HostUnreachable"; + case Reply::ConnectionRefused: + return "ConnectionRefused"; + case Reply::TTLExpired: + return "TTLExpired"; + case Reply::CommandNotSupported: + return "CommandNotSupported"; + case Reply::AddressTypeNotSupported: + return "AddressTypeNotSupported"; + } + VERIFY_NOT_REACHED(); +} + +ErrorOr<void> send_version_identifier_and_method_selection_message(Core::Stream::Socket& socket, Core::SOCKSProxyClient::Version version, Method method) +{ + Socks5VersionIdentifierAndMethodSelectionMessage message { + .version_identifier = to_underlying(version), + .method_count = 1, + .methods = { to_underlying(method) }, + }; + auto size = TRY(socket.write({ &message, sizeof(message) })); + if (size != sizeof(message)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send version identifier and method selection message"); + + Socks5InitialResponse response; + size = TRY(socket.read({ &response, sizeof(response) })); + if (size != sizeof(response)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive initial response"); + + if (response.version_identifier != to_underlying(version)) + return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); + + if (response.method != to_underlying(method)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to negotiate a method"); + + return {}; +} + +ErrorOr<Reply> send_connect_request_message(Core::Stream::Socket& socket, Core::SOCKSProxyClient::Version version, Core::SOCKSProxyClient::HostOrIPV4 target, int port, Core::SOCKSProxyClient::Command command) +{ + DuplexMemoryStream stream; + + Socks5ConnectRequestHeader header { + .version_identifier = to_underlying(version), + .command = to_underlying(command), + .reserved = 0, + }; + Socks5ConnectRequestTrailer trailer { + .port = htons(port), + }; + + auto size = stream.write({ &header, sizeof(header) }); + if (size != sizeof(header)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send connect request header"); + + TRY(target.visit( + [&](String const& hostname) -> ErrorOr<void> { + u8 address_data[2]; + address_data[0] = to_underlying(AddressType::DomainName); + address_data[1] = hostname.length(); + auto size = stream.write({ address_data, sizeof(address_data) }); + if (size != array_size(address_data)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send connect request address data"); + stream.write({ hostname.characters(), hostname.length() }); + return {}; + }, + [&](u32 ipv4) -> ErrorOr<void> { + u8 address_data[5]; + address_data[0] = to_underlying(AddressType::IPV4); + u32 network_ordered_ipv4 = NetworkOrdered<u32>(ipv4); + memcpy(address_data + 1, &network_ordered_ipv4, sizeof(network_ordered_ipv4)); + auto size = stream.write({ address_data, sizeof(address_data) }); + if (size != array_size(address_data)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send connect request address data"); + return {}; + })); + + size = stream.write({ &trailer, sizeof(trailer) }); + if (size != sizeof(trailer)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send connect request trailer"); + + auto buffer = stream.copy_into_contiguous_buffer(); + size = TRY(socket.write({ buffer.data(), buffer.size() })); + if (size != buffer.size()) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send connect request"); + + Socks5ConnectResponseHeader response_header; + size = TRY(socket.read({ &response_header, sizeof(response_header) })); + if (size != sizeof(response_header)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive connect response header"); + + if (response_header.version_identifier != to_underlying(version)) + return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); + + u8 response_address_type; + size = TRY(socket.read({ &response_address_type, sizeof(response_address_type) })); + if (size != sizeof(response_address_type)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive connect response address type"); + + switch (AddressType(response_address_type)) { + case AddressType::IPV4: { + u8 response_address_data[4]; + size = TRY(socket.read({ response_address_data, sizeof(response_address_data) })); + if (size != sizeof(response_address_data)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive connect response address data"); + break; + } + case AddressType::DomainName: { + u8 response_address_length; + size = TRY(socket.read({ &response_address_length, sizeof(response_address_length) })); + if (size != sizeof(response_address_length)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive connect response address length"); + ByteBuffer buffer; + buffer.resize(response_address_length); + size = TRY(socket.read(buffer)); + if (size != response_address_length) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive connect response address data"); + break; + } + case AddressType::IPV6: + default: + return Error::from_string_literal("SOCKS negotiation failed: Invalid connect response address type"); + } + + u16 bound_port; + size = TRY(socket.read({ &bound_port, sizeof(bound_port) })); + if (size != sizeof(bound_port)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive connect response bound port"); + + return Reply(response_header.status); +} + +ErrorOr<u8> send_username_password_authentication_message(Core::Stream::Socket& socket, Core::SOCKSProxyClient::UsernamePasswordAuthenticationData const& auth_data) +{ + DuplexMemoryStream stream; + + u8 version = 0x01; + auto size = stream.write({ &version, sizeof(version) }); + if (size != sizeof(version)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send username/password authentication message"); + + u8 username_length = auth_data.username.length(); + size = stream.write({ &username_length, sizeof(username_length) }); + if (size != sizeof(username_length)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send username/password authentication message"); + + size = stream.write({ auth_data.username.characters(), auth_data.username.length() }); + if (size != auth_data.username.length()) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send username/password authentication message"); + + u8 password_length = auth_data.password.length(); + size = stream.write({ &password_length, sizeof(password_length) }); + if (size != sizeof(password_length)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send username/password authentication message"); + + size = stream.write({ auth_data.password.characters(), auth_data.password.length() }); + if (size != auth_data.password.length()) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send username/password authentication message"); + + auto buffer = stream.copy_into_contiguous_buffer(); + size = TRY(socket.write(buffer)); + if (size != buffer.size()) + return Error::from_string_literal("SOCKS negotiation failed: Failed to send username/password authentication message"); + + Socks5UsernamePasswordResponse response; + size = TRY(socket.read({ &response, sizeof(response) })); + if (size != sizeof(response)) + return Error::from_string_literal("SOCKS negotiation failed: Failed to receive username/password authentication response"); + + if (response.version_identifier != version) + return Error::from_string_literal("SOCKS negotiation failed: Invalid version identifier"); + + return response.status; +} +} + +namespace Core { + +SOCKSProxyClient::~SOCKSProxyClient() +{ + close(); + m_socket.on_ready_to_read = nullptr; +} + +ErrorOr<NonnullOwnPtr<SOCKSProxyClient>> SOCKSProxyClient::connect(Socket& underlying, Version version, HostOrIPV4 const& target, int target_port, Variant<UsernamePasswordAuthenticationData, Empty> const& auth_data, Command command) +{ + if (version != Version::V5) + return Error::from_string_literal("SOCKS version not supported"); + + return auth_data.visit( + [&](Empty) -> ErrorOr<NonnullOwnPtr<SOCKSProxyClient>> { + TRY(send_version_identifier_and_method_selection_message(underlying, version, Method::NoAuth)); + auto reply = TRY(send_connect_request_message(underlying, version, target, target_port, command)); + if (reply != Reply::Succeeded) { + underlying.close(); + return Error::from_string_literal(reply_response_name(reply)); + } + + return adopt_nonnull_own_or_enomem(new SOCKSProxyClient { + underlying, + nullptr, + }); + }, + [&](UsernamePasswordAuthenticationData const& auth_data) -> ErrorOr<NonnullOwnPtr<SOCKSProxyClient>> { + TRY(send_version_identifier_and_method_selection_message(underlying, version, Method::UsernamePassword)); + auto auth_response = TRY(send_username_password_authentication_message(underlying, auth_data)); + if (auth_response != 0) { + underlying.close(); + return Error::from_string_literal("SOCKS authentication failed"); + } + + auto reply = TRY(send_connect_request_message(underlying, version, target, target_port, command)); + if (reply != Reply::Succeeded) { + underlying.close(); + return Error::from_string_literal(reply_response_name(reply)); + } + + return adopt_nonnull_own_or_enomem(new SOCKSProxyClient { + underlying, + nullptr, + }); + }); +} + +ErrorOr<NonnullOwnPtr<SOCKSProxyClient>> SOCKSProxyClient::connect(HostOrIPV4 const& server, int server_port, Version version, HostOrIPV4 const& target, int target_port, Variant<UsernamePasswordAuthenticationData, Empty> const& auth_data, Command command) +{ + auto underlying = TRY(server.visit( + [&](u32 ipv4) { + return Core::Stream::TCPSocket::connect({ IPv4Address(ipv4), static_cast<u16>(server_port) }); + }, + [&](String const& hostname) { + return Core::Stream::TCPSocket::connect(hostname, static_cast<u16>(server_port)); + })); + + auto socket = TRY(connect(*underlying, version, target, target_port, auth_data, command)); + socket->m_own_underlying_socket = move(underlying); + dbgln("SOCKS proxy connected, have {} available bytes", TRY(socket->m_socket.pending_bytes())); + return socket; +} + +} diff --git a/Userland/Libraries/LibCore/SOCKSProxyClient.h b/Userland/Libraries/LibCore/SOCKSProxyClient.h new file mode 100644 index 0000000000..4b011be0a3 --- /dev/null +++ b/Userland/Libraries/LibCore/SOCKSProxyClient.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include <AK/OwnPtr.h> +#include <LibCore/Proxy.h> +#include <LibCore/Stream.h> + +namespace Core { +class SOCKSProxyClient final : public Stream::Socket { +public: + enum class Version : u8 { + V4 = 0x04, + V5 = 0x05, + }; + + struct UsernamePasswordAuthenticationData { + String username; + String password; + }; + + enum class Command : u8 { + Connect = 0x01, + Bind = 0x02, + UDPAssociate = 0x03, + }; + + using HostOrIPV4 = Variant<String, u32>; + + static ErrorOr<NonnullOwnPtr<SOCKSProxyClient>> connect(Socket& underlying, Version, HostOrIPV4 const& target, int target_port, Variant<UsernamePasswordAuthenticationData, Empty> const& auth_data = {}, Command = Command::Connect); + static ErrorOr<NonnullOwnPtr<SOCKSProxyClient>> connect(HostOrIPV4 const& server, int server_port, Version, HostOrIPV4 const& target, int target_port, Variant<UsernamePasswordAuthenticationData, Empty> const& auth_data = {}, Command = Command::Connect); + + virtual ~SOCKSProxyClient() override; + + // ^Stream::Stream + virtual ErrorOr<size_t> read(Bytes bytes) override { return m_socket.read(bytes); } + virtual ErrorOr<size_t> write(ReadonlyBytes bytes) override { return m_socket.write(bytes); } + virtual bool is_eof() const override { return m_socket.is_eof(); } + virtual bool is_open() const override { return m_socket.is_open(); } + virtual void close() override { m_socket.close(); } + + // ^Stream::Socket + virtual ErrorOr<size_t> pending_bytes() const override { return m_socket.pending_bytes(); } + virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_socket.can_read_without_blocking(timeout); } + virtual ErrorOr<void> set_blocking(bool enabled) override { return m_socket.set_blocking(enabled); } + virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_socket.set_close_on_exec(enabled); } + virtual void set_notifications_enabled(bool enabled) override { m_socket.set_notifications_enabled(enabled); } + +private: + SOCKSProxyClient(Socket& socket, OwnPtr<Socket> own_socket) + : m_socket(socket) + , m_own_underlying_socket(move(own_socket)) + { + m_socket.on_ready_to_read = [this] { on_ready_to_read(); }; + } + + Socket& m_socket; + OwnPtr<Socket> m_own_underlying_socket; +}; +} diff --git a/Userland/Services/RequestServer/ConnectionCache.cpp b/Userland/Services/RequestServer/ConnectionCache.cpp index fc5aea4c43..7ae813322f 100644 --- a/Userland/Services/RequestServer/ConnectionCache.cpp +++ b/Userland/Services/RequestServer/ConnectionCache.cpp @@ -6,11 +6,12 @@ #include "ConnectionCache.h" #include <AK/Debug.h> +#include <AK/Find.h> #include <LibCore/EventLoop.h> namespace RequestServer::ConnectionCache { -HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::Stream::TCPSocket>>>> g_tcp_connection_cache {}; +HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::Stream::TCPSocket, Core::Stream::Socket>>>> g_tcp_connection_cache {}; HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<TLS::TLSv12>>>> g_tls_connection_cache {}; void request_did_finish(URL const& url, Core::Stream::Socket const* socket) @@ -22,9 +23,9 @@ void request_did_finish(URL const& url, Core::Stream::Socket const* socket) dbgln_if(REQUESTSERVER_DEBUG, "Request for {} finished", url); - ConnectionKey key { url.host(), url.port_or_default() }; + ConnectionKey partial_key { url.host(), url.port_or_default() }; auto fire_off_next_job = [&](auto& cache) { - auto it = cache.find(key); + auto it = find_if(cache.begin(), cache.end(), [&](auto& connection) { return connection.key.hostname == partial_key.hostname && connection.key.port == partial_key.port; }); if (it == cache.end()) { dbgln("Request for URL {} finished, but we don't own that!", url); return; @@ -72,7 +73,7 @@ void request_did_finish(URL const& url, Core::Stream::Socket const* socket) if (is<Core::Stream::BufferedSocket<TLS::TLSv12>>(socket)) fire_off_next_job(g_tls_connection_cache); - else if (is<Core::Stream::BufferedSocket<Core::Stream::TCPSocket>>(socket)) + else if (is<Core::Stream::BufferedSocket<Core::Stream::Socket>>(socket)) fire_off_next_job(g_tcp_connection_cache); else dbgln("Unknown socket {} finished for URL {}", socket, url); diff --git a/Userland/Services/RequestServer/ConnectionCache.h b/Userland/Services/RequestServer/ConnectionCache.h index 3784a91cec..a88bcc4358 100644 --- a/Userland/Services/RequestServer/ConnectionCache.h +++ b/Userland/Services/RequestServer/ConnectionCache.h @@ -15,6 +15,7 @@ #include <LibCore/ElapsedTimer.h> #include <LibCore/EventLoop.h> #include <LibCore/NetworkJob.h> +#include <LibCore/SOCKSProxyClient.h> #include <LibCore/Timer.h> #include <LibTLS/TLSv12.h> @@ -29,7 +30,31 @@ enum class CacheLevel { namespace RequestServer::ConnectionCache { -template<typename Socket> +struct Proxy { + Core::ProxyData data; + OwnPtr<Core::SOCKSProxyClient> proxy_client_storage {}; + + template<typename SocketType, typename StorageType, typename... Args> + ErrorOr<NonnullOwnPtr<StorageType>> tunnel(URL const& url, Args&&... args) + { + if (data.type == Core::ProxyData::Direct) { + return TRY(SocketType::connect(url.host(), url.port_or_default(), forward<Args>(args)...)); + } + if (data.type == Core::ProxyData::SOCKS5) { + if constexpr (requires { SocketType::connect(declval<String>(), *proxy_client_storage, forward<Args>(args)...); }) { + proxy_client_storage = TRY(Core::SOCKSProxyClient::connect(data.host_ipv4, data.port, Core::SOCKSProxyClient::Version::V5, url.host(), url.port_or_default())); + return TRY(SocketType::connect(url.host(), *proxy_client_storage, forward<Args>(args)...)); + } else if constexpr (IsSame<SocketType, Core::Stream::TCPSocket>) { + return TRY(Core::SOCKSProxyClient::connect(data.host_ipv4, data.port, Core::SOCKSProxyClient::Version::V5, url.host(), url.port_or_default())); + } else { + return Error::from_string_literal("SOCKS5 not supported for this socket type"); + } + } + VERIFY_NOT_REACHED(); + } +}; + +template<typename Socket, typename SocketStorageType = Socket> struct Connection { struct JobData { Function<void(Core::Stream::Socket&)> start {}; @@ -64,19 +89,22 @@ struct Connection { }; using QueueType = Vector<JobData>; using SocketType = Socket; + using StorageType = SocketStorageType; - NonnullOwnPtr<Core::Stream::BufferedSocket<Socket>> socket; + NonnullOwnPtr<Core::Stream::BufferedSocket<SocketStorageType>> socket; QueueType request_queue; NonnullRefPtr<Core::Timer> removal_timer; bool has_started { false }; URL current_url {}; Core::ElapsedTimer timer {}; JobData job_data {}; + Proxy proxy {}; }; struct ConnectionKey { String hostname; u16 port { 0 }; + Core::ProxyData proxy_data {}; bool operator==(ConnectionKey const&) const = default; }; @@ -87,13 +115,13 @@ template<> struct AK::Traits<RequestServer::ConnectionCache::ConnectionKey> : public AK::GenericTraits<RequestServer::ConnectionCache::ConnectionKey> { static u32 hash(RequestServer::ConnectionCache::ConnectionKey const& key) { - return pair_int_hash(key.hostname.hash(), key.port); + return pair_int_hash(pair_int_hash(key.proxy_data.host_ipv4, key.proxy_data.port), pair_int_hash(key.hostname.hash(), key.port)); } }; namespace RequestServer::ConnectionCache { -extern HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::Stream::TCPSocket>>>> g_tcp_connection_cache; +extern HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<Core::Stream::TCPSocket, Core::Stream::Socket>>>> g_tcp_connection_cache; extern HashMap<ConnectionKey, NonnullOwnPtr<NonnullOwnPtrVector<Connection<TLS::TLSv12>>>> g_tls_connection_cache; void request_did_finish(URL const&, Core::Stream::Socket const*); @@ -106,10 +134,12 @@ template<typename T> ErrorOr<void> recreate_socket_if_needed(T& connection, URL const& url) { using SocketType = typename T::SocketType; + using SocketStorageType = typename T::StorageType; + if (!connection.socket->is_open() || connection.socket->is_eof()) { // Create another socket for the connection. auto set_socket = [&](auto socket) -> ErrorOr<void> { - connection.socket = TRY(Core::Stream::BufferedSocket<SocketType>::create(move(socket))); + connection.socket = TRY(Core::Stream::BufferedSocket<SocketStorageType>::create(move(socket))); return {}; }; @@ -132,19 +162,21 @@ ErrorOr<void> recreate_socket_if_needed(T& connection, URL const& url) return connection.job_data.provide_client_certificates(); return {}; }); - TRY(set_socket(TRY(SocketType::connect(url.host(), url.port_or_default(), move(options))))); + TRY(set_socket(TRY((connection.proxy.template tunnel<SocketType, SocketStorageType>(url, move(options)))))); } else { - TRY(set_socket(TRY(SocketType::connect(url.host(), url.port_or_default())))); + TRY(set_socket(TRY((connection.proxy.template tunnel<SocketType, SocketStorageType>(url))))); } dbgln_if(REQUESTSERVER_DEBUG, "Creating a new socket for {} -> {}", url, connection.socket); } return {}; } -decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) +decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job, Core::ProxyData proxy_data = {}) { using CacheEntryType = RemoveCVReference<decltype(*cache.begin()->value)>; - auto& sockets_for_url = *cache.ensure({ url.host(), url.port_or_default() }, [] { return make<CacheEntryType>(); }); + auto& sockets_for_url = *cache.ensure({ url.host(), url.port_or_default(), proxy_data }, [] { return make<CacheEntryType>(); }); + + Proxy proxy { proxy_data }; using ReturnType = decltype(&sockets_for_url[0]); auto it = sockets_for_url.find_if([](auto& connection) { return connection->request_queue.is_empty(); }); @@ -152,7 +184,7 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) auto failed_to_find_a_socket = it.is_end(); if (failed_to_find_a_socket && sockets_for_url.size() < ConnectionCache::MaxConcurrentConnectionsPerURL) { using ConnectionType = RemoveCVReference<decltype(cache.begin()->value->at(0))>; - auto connection_result = ConnectionType::SocketType::connect(url.host(), url.port_or_default()); + auto connection_result = proxy.tunnel<typename ConnectionType::SocketType, typename ConnectionType::StorageType>(url); if (connection_result.is_error()) { dbgln("ConnectionCache: Connection to {} failed: {}", url, connection_result.error()); Core::deferred_invoke([&job] { @@ -160,7 +192,7 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) }); return ReturnType { nullptr }; } - auto socket_result = Core::Stream::BufferedSocket<typename ConnectionType::SocketType>::create(connection_result.release_value()); + auto socket_result = Core::Stream::BufferedSocket<typename ConnectionType::StorageType>::create(connection_result.release_value()); if (socket_result.is_error()) { dbgln("ConnectionCache: Failed to make a buffered socket for {}: {}", url, socket_result.error()); Core::deferred_invoke([&job] { @@ -172,6 +204,7 @@ decltype(auto) get_or_create_connection(auto& cache, URL const& url, auto& job) socket_result.release_value(), typename ConnectionType::QueueType {}, Core::Timer::create_single_shot(ConnectionKeepAliveTimeMilliseconds, nullptr))); + sockets_for_url.last().proxy = move(proxy); did_add_new_connection = true; } size_t index; |