diff options
author | Brian Gianforcaro <bgianf@serenityos.org> | 2021-05-13 01:01:38 -0700 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2021-05-13 16:21:53 +0200 |
commit | 46ce7adf7b9e0101e1f8e4cb63757cf522d1ec62 (patch) | |
tree | 0ac4cc95877c99dbb1b8bffe050cb9a5dbdb5965 | |
parent | 9375f3dc094fda351a7bb5bc6f78438b73beffdb (diff) | |
download | serenity-46ce7adf7b9e0101e1f8e4cb63757cf522d1ec62.zip |
Kernel: Make TCPSocket::create API OOM safe
Note that the changes to IPv4Socket::create are unfortunately needed as
the return type of TCPSocket::create and IPv4Socket::create don't match.
- KResultOr<NonnullRefPtr<TcpSocket>>>
vs
- KResultOr<NonnullRefPtr<Socket>>>
To handle this we are forced to manually decompose the KResultOr<T> and
return the value() and error() separately.
-rw-r--r-- | Kernel/Net/IPv4Socket.cpp | 8 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.cpp | 12 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.h | 3 |
3 files changed, 17 insertions, 6 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 610274c295..728b4b1c37 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -36,8 +36,12 @@ Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets() KResultOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol) { - if (type == SOCK_STREAM) - return TCPSocket::create(protocol); + if (type == SOCK_STREAM) { + auto tcp_socket = TCPSocket::create(protocol); + if (tcp_socket.is_error()) + return tcp_socket.error(); + return tcp_socket.release_value(); + } if (type == SOCK_DGRAM) return UDPSocket::create(protocol); if (type == SOCK_RAW) diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 035e76d480..353a59fbc1 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -98,8 +98,11 @@ RefPtr<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_address, if (sockets_by_tuple().resource().contains(tuple)) return {}; - auto client = TCPSocket::create(protocol()); + auto result = TCPSocket::create(protocol()); + if (result.is_error()) + return {}; + auto client = result.release_value(); client->set_setup_state(SetupState::InProgress); client->set_local_address(new_local_address); client->set_local_port(new_local_port); @@ -142,9 +145,12 @@ TCPSocket::~TCPSocket() dbgln_if(TCP_SOCKET_DEBUG, "~TCPSocket in state {}", to_string(state())); } -NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol) +KResultOr<NonnullRefPtr<TCPSocket>> TCPSocket::create(int protocol) { - return adopt_ref(*new TCPSocket(protocol)); + auto socket = adopt_ref_if_nonnull(new TCPSocket(protocol)); + if (socket) + return socket.release_nonnull(); + return ENOMEM; } KResultOr<size_t> TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, [[maybe_unused]] int flags) diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 28621a98f0..ac3489ef96 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -10,6 +10,7 @@ #include <AK/HashMap.h> #include <AK/SinglyLinkedList.h> #include <AK/WeakPtr.h> +#include <Kernel/KResult.h> #include <Kernel/Net/IPv4Socket.h> namespace Kernel { @@ -17,7 +18,7 @@ namespace Kernel { class TCPSocket final : public IPv4Socket { public: static void for_each(Function<void(const TCPSocket&)>); - static NonnullRefPtr<TCPSocket> create(int protocol); + static KResultOr<NonnullRefPtr<TCPSocket>> create(int protocol); virtual ~TCPSocket() override; enum class Direction { |