diff options
author | Ali Mohammad Pur <ali.mpfard@gmail.com> | 2022-02-04 14:13:30 +0330 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2022-02-06 13:10:10 +0100 |
commit | 3f614a8fca216a09c3b441c81b4075ab6fb0630d (patch) | |
tree | 81d0a642f741faad09654467a37469e046d2d04b /Userland/Libraries/LibWebSocket | |
parent | d66c513131f19f492ca3eb5b475950076925b239 (diff) | |
download | serenity-3f614a8fca216a09c3b441c81b4075ab6fb0630d.zip |
LibWebSocket: Switch to using Core::Stream
As LibTLS now supports the Core::Stream APIs, we can get rid of the
split paths for TCP/TLS and significantly simplify the code as well.
Provided to you free of charge by the Core::Stream-ification team :^)
Diffstat (limited to 'Userland/Libraries/LibWebSocket')
11 files changed, 142 insertions, 348 deletions
diff --git a/Userland/Libraries/LibWebSocket/CMakeLists.txt b/Userland/Libraries/LibWebSocket/CMakeLists.txt index b238fa6a69..c90654a68e 100644 --- a/Userland/Libraries/LibWebSocket/CMakeLists.txt +++ b/Userland/Libraries/LibWebSocket/CMakeLists.txt @@ -1,8 +1,6 @@ set(SOURCES ConnectionInfo.cpp - Impl/AbstractWebSocketImpl.cpp - Impl/TCPWebSocketConnectionImpl.cpp - Impl/TLSv12WebSocketConnectionImpl.cpp + Impl/WebSocketImpl.cpp WebSocket.cpp ) diff --git a/Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.cpp deleted file mode 100644 index 263eefb3e2..0000000000 --- a/Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.cpp +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#include <LibWebSocket/Impl/AbstractWebSocketImpl.h> - -namespace WebSocket { - -AbstractWebSocketImpl::AbstractWebSocketImpl(Core::Object* parent) - : Object(parent) -{ -} - -AbstractWebSocketImpl::~AbstractWebSocketImpl() -{ -} - -} diff --git a/Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.h b/Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.h deleted file mode 100644 index 2bb8a3f88d..0000000000 --- a/Userland/Libraries/LibWebSocket/Impl/AbstractWebSocketImpl.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#pragma once - -#include <AK/ByteBuffer.h> -#include <AK/Span.h> -#include <AK/String.h> -#include <LibCore/Object.h> -#include <LibWebSocket/ConnectionInfo.h> - -namespace WebSocket { - -class AbstractWebSocketImpl : public Core::Object { - C_OBJECT_ABSTRACT(AbstractWebSocketImpl); - -public: - virtual ~AbstractWebSocketImpl() override; - explicit AbstractWebSocketImpl(Core::Object* parent = nullptr); - - virtual void connect(ConnectionInfo const&) = 0; - - virtual bool can_read_line() = 0; - virtual String read_line(size_t size) = 0; - - virtual bool can_read() = 0; - virtual ByteBuffer read(int max_size) = 0; - - virtual bool send(ReadonlyBytes) = 0; - - virtual bool eof() = 0; - - virtual void discard_connection() = 0; - - Function<void()> on_connected; - Function<void()> on_connection_error; - Function<void()> on_ready_to_read; -}; - -} diff --git a/Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.cpp deleted file mode 100644 index 709f28fc7f..0000000000 --- a/Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#include <LibWebSocket/Impl/TCPWebSocketConnectionImpl.h> - -namespace WebSocket { - -TCPWebSocketConnectionImpl::TCPWebSocketConnectionImpl(Core::Object* parent) - : AbstractWebSocketImpl(parent) -{ -} - -TCPWebSocketConnectionImpl::~TCPWebSocketConnectionImpl() -{ - discard_connection(); -} - -void TCPWebSocketConnectionImpl::connect(ConnectionInfo const& connection) -{ - VERIFY(!m_socket); - VERIFY(on_connected); - VERIFY(on_connection_error); - VERIFY(on_ready_to_read); - m_socket = Core::TCPSocket::construct(this); - - m_notifier = Core::Notifier::construct(m_socket->fd(), Core::Notifier::Read); - m_notifier->on_ready_to_read = [this] { - on_ready_to_read(); - }; - - m_socket->on_connected = [this] { - on_connected(); - }; - bool success = m_socket->connect(connection.url().host(), connection.url().port_or_default()); - if (!success) { - deferred_invoke([this] { - on_connection_error(); - }); - } -} - -bool TCPWebSocketConnectionImpl::send(ReadonlyBytes data) -{ - return m_socket->write(data); -} - -bool TCPWebSocketConnectionImpl::can_read_line() -{ - return m_socket->can_read_line(); -} - -String TCPWebSocketConnectionImpl::read_line(size_t size) -{ - return m_socket->read_line(size); -} - -bool TCPWebSocketConnectionImpl::can_read() -{ - return m_socket->can_read(); -} - -ByteBuffer TCPWebSocketConnectionImpl::read(int max_size) -{ - return m_socket->read(max_size); -} - -bool TCPWebSocketConnectionImpl::eof() -{ - return m_socket->eof(); -} - -void TCPWebSocketConnectionImpl::discard_connection() -{ - if (!m_socket) - return; - m_socket->on_ready_to_read = nullptr; - remove_child(*m_socket); - m_socket = nullptr; -} - -} diff --git a/Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.h b/Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.h deleted file mode 100644 index 1ece3f8bef..0000000000 --- a/Userland/Libraries/LibWebSocket/Impl/TCPWebSocketConnectionImpl.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#pragma once - -#include <AK/ByteBuffer.h> -#include <AK/Span.h> -#include <AK/String.h> -#include <LibCore/Notifier.h> -#include <LibCore/Object.h> -#include <LibCore/TCPSocket.h> -#include <LibWebSocket/ConnectionInfo.h> -#include <LibWebSocket/Impl/AbstractWebSocketImpl.h> - -namespace WebSocket { - -class TCPWebSocketConnectionImpl final : public AbstractWebSocketImpl { - C_OBJECT(TCPWebSocketConnectionImpl); - -public: - virtual ~TCPWebSocketConnectionImpl() override; - - virtual void connect(ConnectionInfo const& connection) override; - - virtual bool can_read_line() override; - virtual String read_line(size_t size) override; - - virtual bool can_read() override; - virtual ByteBuffer read(int max_size) override; - - virtual bool send(ReadonlyBytes data) override; - - virtual bool eof() override; - - virtual void discard_connection() override; - -private: - explicit TCPWebSocketConnectionImpl(Core::Object* parent = nullptr); - - RefPtr<Core::Notifier> m_notifier; - RefPtr<Core::TCPSocket> m_socket; -}; - -} diff --git a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp deleted file mode 100644 index 1d8150a470..0000000000 --- a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#include <LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h> - -namespace WebSocket { - -TLSv12WebSocketConnectionImpl::TLSv12WebSocketConnectionImpl(Core::Object* parent) - : AbstractWebSocketImpl(parent) -{ -} - -TLSv12WebSocketConnectionImpl::~TLSv12WebSocketConnectionImpl() -{ - discard_connection(); -} - -void TLSv12WebSocketConnectionImpl::connect(ConnectionInfo const& connection) -{ - VERIFY(!m_socket); - VERIFY(on_connected); - VERIFY(on_connection_error); - VERIFY(on_ready_to_read); - m_socket = TLS::TLSv12::connect(connection.url().host(), connection.url().port_or_default()).release_value_but_fixme_should_propagate_errors(); - - m_socket->on_tls_error = [this](TLS::AlertDescription) { - on_connection_error(); - }; - m_socket->on_ready_to_read = [this] { - on_ready_to_read(); - }; - m_socket->on_tls_finished = [this] { - on_connection_error(); - }; - m_socket->on_tls_certificate_request = [](auto&) { - // FIXME : Once we handle TLS certificate requests, handle it here as well. - }; - on_connected(); -} - -bool TLSv12WebSocketConnectionImpl::send(ReadonlyBytes data) -{ - return m_socket->write_or_error(data); -} - -bool TLSv12WebSocketConnectionImpl::can_read_line() -{ - return m_socket->can_read_line(); -} - -String TLSv12WebSocketConnectionImpl::read_line(size_t size) -{ - return m_socket->read_line(size); -} - -bool TLSv12WebSocketConnectionImpl::can_read() -{ - return m_socket->can_read(); -} - -ByteBuffer TLSv12WebSocketConnectionImpl::read(int max_size) -{ - auto buffer = ByteBuffer::create_uninitialized(max_size).release_value_but_fixme_should_propagate_errors(); - auto nread = m_socket->read(buffer).release_value_but_fixme_should_propagate_errors(); - return buffer.slice(0, nread); -} - -bool TLSv12WebSocketConnectionImpl::eof() -{ - return m_socket->is_eof(); -} - -void TLSv12WebSocketConnectionImpl::discard_connection() -{ - if (!m_socket) - return; - m_socket->on_tls_error = nullptr; - m_socket->on_tls_finished = nullptr; - m_socket->on_tls_certificate_request = nullptr; - m_socket->on_ready_to_read = nullptr; - m_socket = nullptr; -} - -} diff --git a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h b/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h deleted file mode 100644 index db39b2bc63..0000000000 --- a/Userland/Libraries/LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> - * - * SPDX-License-Identifier: BSD-2-Clause - */ - -#pragma once - -#include <AK/ByteBuffer.h> -#include <AK/Span.h> -#include <AK/String.h> -#include <LibCore/Object.h> -#include <LibTLS/TLSv12.h> -#include <LibWebSocket/ConnectionInfo.h> -#include <LibWebSocket/Impl/AbstractWebSocketImpl.h> - -namespace WebSocket { - -class TLSv12WebSocketConnectionImpl final : public AbstractWebSocketImpl { - C_OBJECT(TLSv12WebSocketConnectionImpl); - -public: - virtual ~TLSv12WebSocketConnectionImpl() override; - - void connect(ConnectionInfo const& connection) override; - - virtual bool can_read_line() override; - virtual String read_line(size_t size) override; - - virtual bool can_read() override; - virtual ByteBuffer read(int max_size) override; - - virtual bool send(ReadonlyBytes data) override; - - virtual bool eof() override; - - virtual void discard_connection() override; - -private: - explicit TLSv12WebSocketConnectionImpl(Core::Object* parent = nullptr); - - OwnPtr<TLS::TLSv12> m_socket; -}; - -} diff --git a/Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.cpp b/Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.cpp new file mode 100644 index 0000000000..762177aeaa --- /dev/null +++ b/Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> + * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include <LibWebSocket/Impl/WebSocketImpl.h> + +namespace WebSocket { + +WebSocketImpl::WebSocketImpl(Core::Object* parent) + : Object(parent) +{ +} + +WebSocketImpl::~WebSocketImpl() +{ +} + +void WebSocketImpl::connect(ConnectionInfo const& connection_info) +{ + VERIFY(!m_socket); + VERIFY(on_connected); + VERIFY(on_connection_error); + VERIFY(on_ready_to_read); + auto socket_result = [&]() -> ErrorOr<NonnullOwnPtr<Core::Stream::BufferedSocketBase>> { + if (connection_info.is_secure()) { + TLS::Options options; + options.set_alert_handler([this](auto) { + on_connection_error(); + }); + return TRY(Core::Stream::BufferedSocket<TLS::TLSv12>::create( + TRY(TLS::TLSv12::connect(connection_info.url().host(), connection_info.url().port_or_default(), move(options))))); + } + + return TRY(Core::Stream::BufferedTCPSocket::create( + TRY(Core::Stream::TCPSocket::connect(connection_info.url().host(), connection_info.url().port_or_default())))); + }(); + + if (socket_result.is_error()) { + deferred_invoke([this] { + on_connection_error(); + }); + return; + } + + m_socket = socket_result.release_value(); + + m_socket->on_ready_to_read = [this] { + on_ready_to_read(); + }; + + deferred_invoke([this] { + on_connected(); + }); +} + +ErrorOr<ByteBuffer> WebSocketImpl::read(int max_size) +{ + auto buffer = TRY(ByteBuffer::create_uninitialized(max_size)); + auto nread = TRY(m_socket->read(buffer)); + return buffer.slice(0, nread); +} + +ErrorOr<String> WebSocketImpl::read_line(size_t size) +{ + auto buffer = TRY(ByteBuffer::create_uninitialized(size)); + auto nread = TRY(m_socket->read_line(buffer)); + return String::copy(buffer.span().slice(0, nread)); +} + +} diff --git a/Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.h b/Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.h new file mode 100644 index 0000000000..78f349e247 --- /dev/null +++ b/Userland/Libraries/LibWebSocket/Impl/WebSocketImpl.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021, Dex♪ <dexes.ttp@gmail.com> + * Copyright (c) 2022, Ali Mohammad Pur <mpfard@serenityos.org> + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include <AK/ByteBuffer.h> +#include <AK/Span.h> +#include <AK/String.h> +#include <LibCore/Object.h> +#include <LibWebSocket/ConnectionInfo.h> + +namespace WebSocket { + +class WebSocketImpl : public Core::Object { + C_OBJECT(WebSocketImpl); + +public: + virtual ~WebSocketImpl() override; + explicit WebSocketImpl(Core::Object* parent = nullptr); + + void connect(ConnectionInfo const&); + + bool can_read_line() { return MUST(m_socket->can_read_line()); } + ErrorOr<String> read_line(size_t size); + + bool can_read() { return MUST(m_socket->can_read_without_blocking()); } + ErrorOr<ByteBuffer> read(int max_size); + + bool send(ReadonlyBytes bytes) { return m_socket->write_or_error(bytes); } + + bool eof() { return m_socket->is_eof(); } + + void discard_connection() + { + m_socket.clear(); + } + + Function<void()> on_connected; + Function<void()> on_connection_error; + Function<void()> on_ready_to_read; + +private: + OwnPtr<Core::Stream::BufferedSocketBase> m_socket; +}; + +} diff --git a/Userland/Libraries/LibWebSocket/WebSocket.cpp b/Userland/Libraries/LibWebSocket/WebSocket.cpp index 6fd1873660..a38ea4f1f4 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.cpp +++ b/Userland/Libraries/LibWebSocket/WebSocket.cpp @@ -7,8 +7,6 @@ #include <AK/Base64.h> #include <AK/Random.h> #include <LibCrypto/Hash/HashManager.h> -#include <LibWebSocket/Impl/TCPWebSocketConnectionImpl.h> -#include <LibWebSocket/Impl/TLSv12WebSocketConnectionImpl.h> #include <LibWebSocket/WebSocket.h> #include <unistd.h> @@ -35,10 +33,7 @@ void WebSocket::start() { VERIFY(m_state == WebSocket::InternalState::NotStarted); VERIFY(!m_impl); - if (m_connection.is_secure()) - m_impl = TLSv12WebSocketConnectionImpl::construct(); - else - m_impl = TCPWebSocketConnectionImpl::construct(); + m_impl = WebSocketImpl::construct(); m_impl->on_connection_error = [this] { dbgln("WebSocket: Connection error (underlying socket)"); @@ -117,7 +112,8 @@ void WebSocket::drain_read() case InternalState::EstablishingProtocolConnection: case InternalState::SendingClientHandshake: { auto initializing_bytes = m_impl->read(1024); - dbgln("drain_read() was called on a websocket that isn't opened yet. Read {} bytes from the socket.", initializing_bytes.size()); + if (!initializing_bytes.is_error()) + dbgln("drain_read() was called on a websocket that isn't opened yet. Read {} bytes from the socket.", initializing_bytes.value().size()); } break; case InternalState::WaitingForServerHandshake: { read_server_handshake(); @@ -129,7 +125,8 @@ void WebSocket::drain_read() case InternalState::Closed: case InternalState::Errored: { auto closed_bytes = m_impl->read(1024); - dbgln("drain_read() was called on a closed websocket. Read {} bytes from the socket.", closed_bytes.size()); + if (!closed_bytes.is_error()) + dbgln("drain_read() was called on a closed websocket. Read {} bytes from the socket.", closed_bytes.value().size()); } break; default: VERIFY_NOT_REACHED(); @@ -209,7 +206,7 @@ void WebSocket::read_server_handshake() return; if (!m_has_read_server_handshake_first_line) { - auto header = m_impl->read_line(PAGE_SIZE); + auto header = m_impl->read_line(PAGE_SIZE).release_value_but_fixme_should_propagate_errors(); auto parts = header.split(' '); if (parts.size() < 2) { dbgln("WebSocket: Server HTTP Handshake contained HTTP header was malformed"); @@ -235,7 +232,7 @@ void WebSocket::read_server_handshake() // Read the rest of the reply until we find an empty line while (m_impl->can_read_line()) { - auto line = m_impl->read_line(PAGE_SIZE); + auto line = m_impl->read_line(PAGE_SIZE).release_value_but_fixme_should_propagate_errors(); if (line.is_whitespace()) { // We're done with the HTTP headers. // Fail the connection if we're missing any of the following: @@ -364,14 +361,15 @@ void WebSocket::read_frame() VERIFY(m_impl); VERIFY(m_state == WebSocket::InternalState::Open || m_state == WebSocket::InternalState::Closing); - auto head_bytes = m_impl->read(2); - if (head_bytes.size() == 0) { + auto head_bytes_result = m_impl->read(2); + if (head_bytes_result.is_error() || head_bytes_result.value().is_empty()) { // The connection got closed. m_state = WebSocket::InternalState::Closed; notify_close(m_last_close_code, m_last_close_message, true); discard_connection(); return; } + auto head_bytes = head_bytes_result.release_value(); VERIFY(head_bytes.size() == 2); bool is_final_frame = head_bytes[0] & 0x80; @@ -388,7 +386,7 @@ void WebSocket::read_frame() auto payload_length_bits = head_bytes[1] & 0x7f; if (payload_length_bits == 127) { // A code of 127 means that the next 8 bytes contains the payload length - auto actual_bytes = m_impl->read(8); + auto actual_bytes = MUST(m_impl->read(8)); VERIFY(actual_bytes.size() == 8); u64 full_payload_length = (u64)((u64)(actual_bytes[0] & 0xff) << 56) | (u64)((u64)(actual_bytes[1] & 0xff) << 48) @@ -402,7 +400,7 @@ void WebSocket::read_frame() payload_length = (size_t)full_payload_length; } else if (payload_length_bits == 126) { // A code of 126 means that the next 2 bytes contains the payload length - auto actual_bytes = m_impl->read(2); + auto actual_bytes = MUST(m_impl->read(2)); VERIFY(actual_bytes.size() == 2); payload_length = (size_t)((size_t)(actual_bytes[0] & 0xff) << 8) | (size_t)((size_t)(actual_bytes[1] & 0xff) << 0); @@ -418,7 +416,7 @@ void WebSocket::read_frame() // But because it doesn't cost much, we can support receiving masked frames anyways. u8 masking_key[4]; if (is_masked) { - auto masking_key_data = m_impl->read(4); + auto masking_key_data = MUST(m_impl->read(4)); VERIFY(masking_key_data.size() == 4); masking_key[0] = masking_key_data[0]; masking_key[1] = masking_key_data[1]; @@ -429,13 +427,14 @@ void WebSocket::read_frame() auto payload = ByteBuffer::create_uninitialized(payload_length).release_value_but_fixme_should_propagate_errors(); // FIXME: Handle possible OOM situation. u64 read_length = 0; while (read_length < payload_length) { - auto payload_part = m_impl->read(payload_length - read_length); - if (payload_part.size() == 0) { + auto payload_part_result = m_impl->read(payload_length - read_length); + if (payload_part_result.is_error() || payload_part_result.value().is_empty()) { // We got disconnected, somehow. dbgln("Websocket: Server disconnected while sending payload ({} bytes read out of {})", read_length, payload_length); fatal_error(WebSocket::Error::ServerClosedSocket); return; } + auto payload_part = payload_part_result.release_value(); // We read at most "actual_length - read" bytes, so this is safe to do. payload.overwrite(read_length, payload_part.data(), payload_part.size()); read_length -= payload_part.size(); diff --git a/Userland/Libraries/LibWebSocket/WebSocket.h b/Userland/Libraries/LibWebSocket/WebSocket.h index 9fd18893c2..75b778f04a 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.h +++ b/Userland/Libraries/LibWebSocket/WebSocket.h @@ -9,7 +9,7 @@ #include <AK/Span.h> #include <LibCore/Object.h> #include <LibWebSocket/ConnectionInfo.h> -#include <LibWebSocket/Impl/AbstractWebSocketImpl.h> +#include <LibWebSocket/Impl/WebSocketImpl.h> #include <LibWebSocket/Message.h> namespace WebSocket { @@ -104,7 +104,7 @@ private: String m_last_close_message; ConnectionInfo m_connection; - RefPtr<AbstractWebSocketImpl> m_impl; + RefPtr<WebSocketImpl> m_impl; }; } |