diff options
author | Jean-Baptiste Boric <jblbeurope@gmail.com> | 2021-07-18 12:24:34 +0200 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2021-08-07 11:48:00 +0200 |
commit | 9517100672bc01867e40b64753a429b9e7bff054 (patch) | |
tree | e1b477712a8857d1f2d0cab488cd496d9d525c93 /Kernel/Net | |
parent | 9216c72bfe47d58ccbee524f3006cfbe5a7d77e0 (diff) | |
download | serenity-9517100672bc01867e40b64753a429b9e7bff054.zip |
Kernel: Migrate UDP socket table locking to ProtectedValue
Diffstat (limited to 'Kernel/Net')
-rw-r--r-- | Kernel/Net/UDPSocket.cpp | 69 | ||||
-rw-r--r-- | Kernel/Net/UDPSocket.h | 4 |
2 files changed, 37 insertions, 36 deletions
diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 1c719bb516..ef5783137f 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -6,7 +6,6 @@ #include <AK/Singleton.h> #include <Kernel/Devices/RandomDevice.h> -#include <Kernel/Locking/Mutex.h> #include <Kernel/Net/NetworkAdapter.h> #include <Kernel/Net/Routing.h> #include <Kernel/Net/UDP.h> @@ -18,30 +17,29 @@ namespace Kernel { void UDPSocket::for_each(Function<void(const UDPSocket&)> callback) { - MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared); - for (auto it : sockets_by_port().resource()) - callback(*it.value); + sockets_by_port().for_each_shared([&](const auto& socket) { + callback(*socket.value); + }); } -static AK::Singleton<Lockable<HashMap<u16, UDPSocket*>>> s_map; +static AK::Singleton<ProtectedValue<HashMap<u16, UDPSocket*>>> s_map; -Lockable<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port() +ProtectedValue<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port() { return *s_map; } SocketHandle<UDPSocket> UDPSocket::from_port(u16 port) { - RefPtr<UDPSocket> socket; - { - MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared); - auto it = sockets_by_port().resource().find(port); - if (it == sockets_by_port().resource().end()) + return sockets_by_port().with_shared([&](const auto& table) -> SocketHandle<UDPSocket> { + RefPtr<UDPSocket> socket; + auto it = table.find(port); + if (it == table.end()) return {}; socket = (*it).value; VERIFY(socket); - } - return { *socket }; + return { *socket }; + }); } UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer) @@ -51,8 +49,9 @@ UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer) UDPSocket::~UDPSocket() { - MutexLocker locker(sockets_by_port().lock()); - sockets_by_port().resource().remove(local_port()); + sockets_by_port().with_exclusive([&](auto& table) { + table.remove(local_port()); + }); } KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer) @@ -113,30 +112,32 @@ KResultOr<u16> UDPSocket::protocol_allocate_local_port() constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port; u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size; - MutexLocker locker(sockets_by_port().lock()); - for (u16 port = first_scan_port;;) { - auto it = sockets_by_port().resource().find(port); - if (it == sockets_by_port().resource().end()) { - set_local_port(port); - sockets_by_port().resource().set(port, this); - return port; + return sockets_by_port().with_exclusive([&](auto& table) -> KResultOr<u16> { + for (u16 port = first_scan_port;;) { + auto it = table.find(port); + if (it == table.end()) { + set_local_port(port); + table.set(port, this); + return port; + } + ++port; + if (port > last_ephemeral_port) + port = first_ephemeral_port; + if (port == first_scan_port) + break; } - ++port; - if (port > last_ephemeral_port) - port = first_ephemeral_port; - if (port == first_scan_port) - break; - } - return EADDRINUSE; + return EADDRINUSE; + }); } KResult UDPSocket::protocol_bind() { - MutexLocker locker(sockets_by_port().lock()); - if (sockets_by_port().resource().contains(local_port())) - return EADDRINUSE; - sockets_by_port().resource().set(local_port(), this); - return KSuccess; + return sockets_by_port().with_exclusive([&](auto& table) -> KResult { + if (table.contains(local_port())) + return EADDRINUSE; + table.set(local_port(), this); + return KSuccess; + }); } } diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index b13b07ec91..d57ff39445 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -7,7 +7,7 @@ #pragma once #include <Kernel/KResult.h> -#include <Kernel/Locking/Lockable.h> +#include <Kernel/Locking/ProtectedValue.h> #include <Kernel/Net/IPv4Socket.h> namespace Kernel { @@ -23,7 +23,7 @@ public: private: explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer); virtual StringView class_name() const override { return "UDPSocket"; } - static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port(); + static ProtectedValue<HashMap<u16, UDPSocket*>>& sockets_by_port(); virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override; virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override; |