summaryrefslogtreecommitdiff
path: root/Kernel/Net
diff options
context:
space:
mode:
authorJean-Baptiste Boric <jblbeurope@gmail.com>2021-07-18 12:24:34 +0200
committerAndreas Kling <kling@serenityos.org>2021-08-07 11:48:00 +0200
commit9517100672bc01867e40b64753a429b9e7bff054 (patch)
treee1b477712a8857d1f2d0cab488cd496d9d525c93 /Kernel/Net
parent9216c72bfe47d58ccbee524f3006cfbe5a7d77e0 (diff)
downloadserenity-9517100672bc01867e40b64753a429b9e7bff054.zip
Kernel: Migrate UDP socket table locking to ProtectedValue
Diffstat (limited to 'Kernel/Net')
-rw-r--r--Kernel/Net/UDPSocket.cpp69
-rw-r--r--Kernel/Net/UDPSocket.h4
2 files changed, 37 insertions, 36 deletions
diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp
index 1c719bb516..ef5783137f 100644
--- a/Kernel/Net/UDPSocket.cpp
+++ b/Kernel/Net/UDPSocket.cpp
@@ -6,7 +6,6 @@
#include <AK/Singleton.h>
#include <Kernel/Devices/RandomDevice.h>
-#include <Kernel/Locking/Mutex.h>
#include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Net/Routing.h>
#include <Kernel/Net/UDP.h>
@@ -18,30 +17,29 @@ namespace Kernel {
void UDPSocket::for_each(Function<void(const UDPSocket&)> callback)
{
- MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared);
- for (auto it : sockets_by_port().resource())
- callback(*it.value);
+ sockets_by_port().for_each_shared([&](const auto& socket) {
+ callback(*socket.value);
+ });
}
-static AK::Singleton<Lockable<HashMap<u16, UDPSocket*>>> s_map;
+static AK::Singleton<ProtectedValue<HashMap<u16, UDPSocket*>>> s_map;
-Lockable<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
+ProtectedValue<HashMap<u16, UDPSocket*>>& UDPSocket::sockets_by_port()
{
return *s_map;
}
SocketHandle<UDPSocket> UDPSocket::from_port(u16 port)
{
- RefPtr<UDPSocket> socket;
- {
- MutexLocker locker(sockets_by_port().lock(), Mutex::Mode::Shared);
- auto it = sockets_by_port().resource().find(port);
- if (it == sockets_by_port().resource().end())
+ return sockets_by_port().with_shared([&](const auto& table) -> SocketHandle<UDPSocket> {
+ RefPtr<UDPSocket> socket;
+ auto it = table.find(port);
+ if (it == table.end())
return {};
socket = (*it).value;
VERIFY(socket);
- }
- return { *socket };
+ return { *socket };
+ });
}
UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@@ -51,8 +49,9 @@ UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
UDPSocket::~UDPSocket()
{
- MutexLocker locker(sockets_by_port().lock());
- sockets_by_port().resource().remove(local_port());
+ sockets_by_port().with_exclusive([&](auto& table) {
+ table.remove(local_port());
+ });
}
KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
@@ -113,30 +112,32 @@ KResultOr<u16> UDPSocket::protocol_allocate_local_port()
constexpr u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
- MutexLocker locker(sockets_by_port().lock());
- for (u16 port = first_scan_port;;) {
- auto it = sockets_by_port().resource().find(port);
- if (it == sockets_by_port().resource().end()) {
- set_local_port(port);
- sockets_by_port().resource().set(port, this);
- return port;
+ return sockets_by_port().with_exclusive([&](auto& table) -> KResultOr<u16> {
+ for (u16 port = first_scan_port;;) {
+ auto it = table.find(port);
+ if (it == table.end()) {
+ set_local_port(port);
+ table.set(port, this);
+ return port;
+ }
+ ++port;
+ if (port > last_ephemeral_port)
+ port = first_ephemeral_port;
+ if (port == first_scan_port)
+ break;
}
- ++port;
- if (port > last_ephemeral_port)
- port = first_ephemeral_port;
- if (port == first_scan_port)
- break;
- }
- return EADDRINUSE;
+ return EADDRINUSE;
+ });
}
KResult UDPSocket::protocol_bind()
{
- MutexLocker locker(sockets_by_port().lock());
- if (sockets_by_port().resource().contains(local_port()))
- return EADDRINUSE;
- sockets_by_port().resource().set(local_port(), this);
- return KSuccess;
+ return sockets_by_port().with_exclusive([&](auto& table) -> KResult {
+ if (table.contains(local_port()))
+ return EADDRINUSE;
+ table.set(local_port(), this);
+ return KSuccess;
+ });
}
}
diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h
index b13b07ec91..d57ff39445 100644
--- a/Kernel/Net/UDPSocket.h
+++ b/Kernel/Net/UDPSocket.h
@@ -7,7 +7,7 @@
#pragma once
#include <Kernel/KResult.h>
-#include <Kernel/Locking/Lockable.h>
+#include <Kernel/Locking/ProtectedValue.h>
#include <Kernel/Net/IPv4Socket.h>
namespace Kernel {
@@ -23,7 +23,7 @@ public:
private:
explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
virtual StringView class_name() const override { return "UDPSocket"; }
- static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
+ static ProtectedValue<HashMap<u16, UDPSocket*>>& sockets_by_port();
virtual KResultOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;