summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiav A <liavalb@gmail.com>2023-04-11 03:50:15 +0300
committerLinus Groh <mail@linusgroh.de>2023-04-14 19:27:56 +0200
commit7c1f645e27038eb12e2e46ea0103cc04c96edbb8 (patch)
tree4624c34cec0b0477dd0cf4fc9181163e17373d19
parentbd7d4513bfcc250c472c84365649f522b074816f (diff)
downloadserenity-7c1f645e27038eb12e2e46ea0103cc04c96edbb8.zip
Kernel/Net: Iron out the locking mechanism across the subsystem
There is a big mix of LockRefPtrs all over the Networking subsystem, as well as lots of room for improvements with our locking patterns, which this commit will not pursue, but will give a good start for such work. To deal with this situation, we change the following things: - Creating instances of NetworkAdapter should always yield a non-locking NonnullRefPtr. Acquiring an instance from the NetworkingManagement should give a simple RefPtr,as giving LockRefPtr does not really protect from concurrency problems in such case. - Since NetworkingManagement works with normal RefPtrs we should protect all instances of RefPtr<NetworkAdapter> with SpinlockProtected to ensure references are gone unexpectedly. - Protect the so_error class member with a proper spinlock. This happens to be important because the clear_so_error() method lacked any proper locking measures. It also helps preventing a possible TOCTOU when we might do a more fine-grained locking in the Socket code, so this could be definitely a start for this. - Change unnecessary LockRefPtr<PacketWithTimestamp> in the structure of OutgoingPacket to a simple RefPtr<PacketWithTimestamp> as the whole list should be MutexProtected.
-rw-r--r--Kernel/Net/IPv4Socket.cpp3
-rw-r--r--Kernel/Net/NetworkAdapter.cpp6
-rw-r--r--Kernel/Net/NetworkAdapter.h4
-rw-r--r--Kernel/Net/NetworkTask.cpp4
-rw-r--r--Kernel/Net/NetworkingManagement.cpp10
-rw-r--r--Kernel/Net/NetworkingManagement.h12
-rw-r--r--Kernel/Net/Routing.cpp12
-rw-r--r--Kernel/Net/Routing.h14
-rw-r--r--Kernel/Net/Socket.cpp50
-rw-r--r--Kernel/Net/Socket.h30
-rw-r--r--Kernel/Net/TCPSocket.cpp24
-rw-r--r--Kernel/Net/TCPSocket.h4
-rw-r--r--Kernel/Net/UDPSocket.cpp3
13 files changed, 93 insertions, 83 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp
index 67ab93c33d..61bdb59441 100644
--- a/Kernel/Net/IPv4Socket.cpp
+++ b/Kernel/Net/IPv4Socket.cpp
@@ -212,7 +212,8 @@ ErrorOr<size_t> IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons
return set_so_error(EPIPE);
auto allow_using_gateway = ((flags & MSG_DONTROUTE) || m_routing_disabled) ? AllowUsingGateway::No : AllowUsingGateway::Yes;
- auto routing_decision = route_to(m_peer_address, m_local_address, bound_interface(), allow_using_gateway);
+ auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
+ auto routing_decision = route_to(m_peer_address, m_local_address, adapter, allow_using_gateway);
if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH);
diff --git a/Kernel/Net/NetworkAdapter.cpp b/Kernel/Net/NetworkAdapter.cpp
index b1139ba8ad..9ccfd88d84 100644
--- a/Kernel/Net/NetworkAdapter.cpp
+++ b/Kernel/Net/NetworkAdapter.cpp
@@ -111,9 +111,9 @@ size_t NetworkAdapter::dequeue_packet(u8* buffer, size_t buffer_size, Time& pack
return packet_size;
}
-LockRefPtr<PacketWithTimestamp> NetworkAdapter::acquire_packet_buffer(size_t size)
+RefPtr<PacketWithTimestamp> NetworkAdapter::acquire_packet_buffer(size_t size)
{
- auto packet = m_unused_packets.with([size](auto& unused_packets) -> LockRefPtr<PacketWithTimestamp> {
+ auto packet = m_unused_packets.with([size](auto& unused_packets) -> RefPtr<PacketWithTimestamp> {
if (unused_packets.is_empty())
return nullptr;
@@ -135,7 +135,7 @@ LockRefPtr<PacketWithTimestamp> NetworkAdapter::acquire_packet_buffer(size_t siz
auto buffer_or_error = KBuffer::try_create_with_size("NetworkAdapter: Packet buffer"sv, size, Memory::Region::Access::ReadWrite, AllocationStrategy::AllocateNow);
if (buffer_or_error.is_error())
return {};
- packet = adopt_lock_ref_if_nonnull(new (nothrow) PacketWithTimestamp { buffer_or_error.release_value(), kgettimeofday() });
+ packet = adopt_ref_if_nonnull(new (nothrow) PacketWithTimestamp { buffer_or_error.release_value(), kgettimeofday() });
if (!packet)
return {};
packet->buffer->set_size(size);
diff --git a/Kernel/Net/NetworkAdapter.h b/Kernel/Net/NetworkAdapter.h
index bfc9314fd7..6fd6cabbd3 100644
--- a/Kernel/Net/NetworkAdapter.h
+++ b/Kernel/Net/NetworkAdapter.h
@@ -39,7 +39,7 @@ struct PacketWithTimestamp final : public AtomicRefCounted<PacketWithTimestamp>
NonnullOwnPtr<KBuffer> buffer;
Time timestamp;
- IntrusiveListNode<PacketWithTimestamp, LockRefPtr<PacketWithTimestamp>> packet_node;
+ IntrusiveListNode<PacketWithTimestamp, RefPtr<PacketWithTimestamp>> packet_node;
};
class NetworkingManagement;
@@ -91,7 +91,7 @@ public:
u32 packets_out() const { return m_packets_out; }
u32 bytes_out() const { return m_bytes_out; }
- LockRefPtr<PacketWithTimestamp> acquire_packet_buffer(size_t);
+ RefPtr<PacketWithTimestamp> acquire_packet_buffer(size_t);
void release_packet_buffer(PacketWithTimestamp&);
constexpr size_t layer3_payload_offset() const { return sizeof(EthernetFrameHeader); }
diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp
index b5de06dda8..a89860452c 100644
--- a/Kernel/Net/NetworkTask.cpp
+++ b/Kernel/Net/NetworkTask.cpp
@@ -31,7 +31,7 @@ static void handle_icmp(EthernetFrameHeader const&, IPv4Packet const&, Time cons
static void handle_udp(IPv4Packet const&, Time const& packet_timestamp);
static void handle_tcp(IPv4Packet const&, Time const& packet_timestamp);
static void send_delayed_tcp_ack(TCPSocket& socket);
-static void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, LockRefPtr<NetworkAdapter> adapter);
+static void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, RefPtr<NetworkAdapter> adapter);
static void flush_delayed_tcp_acks();
static void retransmit_tcp_packets();
@@ -333,7 +333,7 @@ void flush_delayed_tcp_acks()
}
}
-void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, LockRefPtr<NetworkAdapter> adapter)
+void send_tcp_rst(IPv4Packet const& ipv4_packet, TCPPacket const& tcp_packet, RefPtr<NetworkAdapter> adapter)
{
auto routing_decision = route_to(ipv4_packet.source(), ipv4_packet.destination(), adapter);
if (routing_decision.is_zero())
diff --git a/Kernel/Net/NetworkingManagement.cpp b/Kernel/Net/NetworkingManagement.cpp
index 70958c28fc..111696cbbd 100644
--- a/Kernel/Net/NetworkingManagement.cpp
+++ b/Kernel/Net/NetworkingManagement.cpp
@@ -35,7 +35,7 @@ UNMAP_AFTER_INIT NetworkingManagement::NetworkingManagement()
{
}
-NonnullLockRefPtr<NetworkAdapter> NetworkingManagement::loopback_adapter() const
+NonnullRefPtr<NetworkAdapter> NetworkingManagement::loopback_adapter() const
{
return *m_loopback_adapter;
}
@@ -56,13 +56,13 @@ ErrorOr<void> NetworkingManagement::try_for_each(Function<ErrorOr<void>(NetworkA
});
}
-LockRefPtr<NetworkAdapter> NetworkingManagement::from_ipv4_address(IPv4Address const& address) const
+RefPtr<NetworkAdapter> NetworkingManagement::from_ipv4_address(IPv4Address const& address) const
{
if (address[0] == 0 && address[1] == 0 && address[2] == 0 && address[3] == 0)
return m_loopback_adapter;
if (address[0] == 127)
return m_loopback_adapter;
- return m_adapters.with([&](auto& adapters) -> LockRefPtr<NetworkAdapter> {
+ return m_adapters.with([&](auto& adapters) -> RefPtr<NetworkAdapter> {
for (auto& adapter : adapters) {
if (adapter->ipv4_address() == address || adapter->ipv4_broadcast() == address)
return adapter;
@@ -71,9 +71,9 @@ LockRefPtr<NetworkAdapter> NetworkingManagement::from_ipv4_address(IPv4Address c
});
}
-LockRefPtr<NetworkAdapter> NetworkingManagement::lookup_by_name(StringView name) const
+RefPtr<NetworkAdapter> NetworkingManagement::lookup_by_name(StringView name) const
{
- return m_adapters.with([&](auto& adapters) -> LockRefPtr<NetworkAdapter> {
+ return m_adapters.with([&](auto& adapters) -> RefPtr<NetworkAdapter> {
for (auto& adapter : adapters) {
if (adapter->name() == name)
return adapter;
diff --git a/Kernel/Net/NetworkingManagement.h b/Kernel/Net/NetworkingManagement.h
index 587773030a..ad18a7513e 100644
--- a/Kernel/Net/NetworkingManagement.h
+++ b/Kernel/Net/NetworkingManagement.h
@@ -8,9 +8,9 @@
#include <AK/Function.h>
#include <AK/NonnullOwnPtr.h>
+#include <AK/RefPtr.h>
#include <AK/Types.h>
#include <Kernel/Bus/PCI/Definitions.h>
-#include <Kernel/Library/NonnullLockRefPtr.h>
#include <Kernel/Locking/SpinlockProtected.h>
#include <Kernel/Memory/Region.h>
#include <Kernel/Net/NetworkAdapter.h>
@@ -33,16 +33,16 @@ public:
void for_each(Function<void(NetworkAdapter&)>);
ErrorOr<void> try_for_each(Function<ErrorOr<void>(NetworkAdapter&)>);
- LockRefPtr<NetworkAdapter> from_ipv4_address(IPv4Address const&) const;
- LockRefPtr<NetworkAdapter> lookup_by_name(StringView) const;
+ RefPtr<NetworkAdapter> from_ipv4_address(IPv4Address const&) const;
+ RefPtr<NetworkAdapter> lookup_by_name(StringView) const;
- NonnullLockRefPtr<NetworkAdapter> loopback_adapter() const;
+ NonnullRefPtr<NetworkAdapter> loopback_adapter() const;
private:
ErrorOr<NonnullRefPtr<NetworkAdapter>> determine_network_device(PCI::DeviceIdentifier const&) const;
- SpinlockProtected<Vector<NonnullLockRefPtr<NetworkAdapter>>, LockRank::None> m_adapters {};
- LockRefPtr<NetworkAdapter> m_loopback_adapter;
+ SpinlockProtected<Vector<NonnullRefPtr<NetworkAdapter>>, LockRank::None> m_adapters {};
+ RefPtr<NetworkAdapter> m_loopback_adapter;
};
}
diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp
index 81ca0b000c..9ceaa71bc5 100644
--- a/Kernel/Net/Routing.cpp
+++ b/Kernel/Net/Routing.cpp
@@ -135,11 +135,11 @@ SpinlockProtected<Route::RouteList, LockRank::None>& routing_table()
return *s_routing_table;
}
-ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, LockRefPtr<NetworkAdapter> adapter, UpdateTable update)
+ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, RefPtr<NetworkAdapter> adapter, UpdateTable update)
{
dbgln_if(ROUTING_DEBUG, "update_routing_table {} {} {} {} {} {}", destination, gateway, netmask, flags, adapter, update == UpdateTable::Set ? "Set" : "Delete");
- auto route_entry = adopt_lock_ref_if_nonnull(new (nothrow) Route { destination, gateway, netmask, flags, adapter.release_nonnull() });
+ auto route_entry = adopt_ref_if_nonnull(new (nothrow) Route { destination, gateway, netmask, flags, adapter.release_nonnull() });
if (!route_entry)
return ENOMEM;
@@ -178,7 +178,7 @@ static MACAddress multicast_ethernet_address(IPv4Address const& address)
return MACAddress { 0x01, 0x00, 0x5e, (u8)(address[1] & 0x7f), address[2], address[3] };
}
-RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, LockRefPtr<NetworkAdapter> const through, AllowUsingGateway allow_using_gateway)
+RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, RefPtr<NetworkAdapter> const through, AllowUsingGateway allow_using_gateway)
{
auto matches = [&](auto& adapter) {
if (!through)
@@ -200,8 +200,8 @@ RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, L
auto target_addr = target.to_u32();
auto source_addr = source.to_u32();
- LockRefPtr<NetworkAdapter> local_adapter = nullptr;
- LockRefPtr<Route> chosen_route = nullptr;
+ RefPtr<NetworkAdapter> local_adapter = nullptr;
+ RefPtr<Route> chosen_route = nullptr;
NetworkingManagement::the().for_each([source_addr, &target_addr, &local_adapter, &matches, &through](NetworkAdapter& adapter) {
auto adapter_addr = adapter.ipv4_address().to_u32();
@@ -263,7 +263,7 @@ RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, L
return { nullptr, {} };
}
- LockRefPtr<NetworkAdapter> adapter = nullptr;
+ RefPtr<NetworkAdapter> adapter = nullptr;
IPv4Address next_hop_ip;
if (local_adapter) {
diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h
index 8732184901..2ad6a08d90 100644
--- a/Kernel/Net/Routing.h
+++ b/Kernel/Net/Routing.h
@@ -7,7 +7,7 @@
#pragma once
#include <AK/IPv4Address.h>
-#include <Kernel/Library/NonnullLockRefPtr.h>
+#include <AK/RefPtr.h>
#include <Kernel/Locking/MutexProtected.h>
#include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/Thread.h>
@@ -15,7 +15,7 @@
namespace Kernel {
struct Route final : public AtomicRefCounted<Route> {
- Route(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, NonnullLockRefPtr<NetworkAdapter> adapter)
+ Route(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, NonnullRefPtr<NetworkAdapter> adapter)
: destination(destination)
, gateway(gateway)
, netmask(netmask)
@@ -38,14 +38,14 @@ struct Route final : public AtomicRefCounted<Route> {
const IPv4Address gateway;
const IPv4Address netmask;
const u16 flags;
- NonnullLockRefPtr<NetworkAdapter> adapter;
+ NonnullRefPtr<NetworkAdapter> const adapter;
- IntrusiveListNode<Route, LockRefPtr<Route>> route_list_node {};
+ IntrusiveListNode<Route, RefPtr<Route>> route_list_node {};
using RouteList = IntrusiveList<&Route::route_list_node>;
};
struct RoutingDecision {
- LockRefPtr<NetworkAdapter> adapter;
+ RefPtr<NetworkAdapter> adapter;
MACAddress next_hop;
bool is_zero() const;
@@ -57,14 +57,14 @@ enum class UpdateTable {
};
void update_arp_table(IPv4Address const&, MACAddress const&, UpdateTable update);
-ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, LockRefPtr<NetworkAdapter> const adapter, UpdateTable update);
+ErrorOr<void> update_routing_table(IPv4Address const& destination, IPv4Address const& gateway, IPv4Address const& netmask, u16 flags, RefPtr<NetworkAdapter> const adapter, UpdateTable update);
enum class AllowUsingGateway {
Yes,
No,
};
-RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, LockRefPtr<NetworkAdapter> const through = nullptr, AllowUsingGateway = AllowUsingGateway::Yes);
+RoutingDecision route_to(IPv4Address const& target, IPv4Address const& source, RefPtr<NetworkAdapter> const through = nullptr, AllowUsingGateway = AllowUsingGateway::Yes);
SpinlockProtected<HashMap<IPv4Address, MACAddress>, LockRank::None>& arp_table();
SpinlockProtected<Route::RouteList, LockRank::None>& routing_table();
diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp
index 9ae1585b94..f5e8357a70 100644
--- a/Kernel/Net/Socket.cpp
+++ b/Kernel/Net/Socket.cpp
@@ -100,7 +100,9 @@ ErrorOr<void> Socket::setsockopt(int level, int option, Userspace<void const*> u
auto device = NetworkingManagement::the().lookup_by_name(ifname->view());
if (!device)
return ENODEV;
- m_bound_interface = move(device);
+ m_bound_interface.with([&device](auto& bound_device) {
+ bound_device = move(device);
+ });
return {};
}
case SO_DEBUG:
@@ -169,31 +171,35 @@ ErrorOr<void> Socket::getsockopt(OpenFileDescription&, int level, int option, Us
case SO_ERROR: {
if (size < sizeof(int))
return EINVAL;
- int errno = 0;
- if (auto const& error = so_error(); error.has_value())
- errno = error.value();
- TRY(copy_to_user(static_ptr_cast<int*>(value), &errno));
- size = sizeof(int);
- TRY(copy_to_user(value_size, &size));
- clear_so_error();
- return {};
+ return so_error().with([&size, value, value_size](auto& error) -> ErrorOr<void> {
+ int errno = 0;
+ if (error.has_value())
+ errno = error.value();
+ TRY(copy_to_user(static_ptr_cast<int*>(value), &errno));
+ size = sizeof(int);
+ TRY(copy_to_user(value_size, &size));
+ error = {};
+ return {};
+ });
}
case SO_BINDTODEVICE:
if (size < IFNAMSIZ)
return EINVAL;
- if (m_bound_interface) {
- auto name = m_bound_interface->name();
- auto length = name.length() + 1;
- auto characters = name.characters_without_null_termination();
- TRY(copy_to_user(static_ptr_cast<char*>(value), characters, length));
- size = length;
- return copy_to_user(value_size, &size);
- } else {
- size = 0;
- TRY(copy_to_user(value_size, &size));
- // FIXME: This return value looks suspicious.
- return EFAULT;
- }
+ return m_bound_interface.with([&](auto& bound_device) -> ErrorOr<void> {
+ if (bound_device) {
+ auto name = bound_device->name();
+ auto length = name.length() + 1;
+ auto characters = name.characters_without_null_termination();
+ TRY(copy_to_user(static_ptr_cast<char*>(value), characters, length));
+ size = length;
+ return copy_to_user(value_size, &size);
+ } else {
+ size = 0;
+ TRY(copy_to_user(value_size, &size));
+ // FIXME: This return value looks suspicious.
+ return EFAULT;
+ }
+ });
case SO_TIMESTAMP:
if (size < sizeof(int))
return EINVAL;
diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h
index 746274e0ff..ff071d2d83 100644
--- a/Kernel/Net/Socket.h
+++ b/Kernel/Net/Socket.h
@@ -7,9 +7,9 @@
#pragma once
#include <AK/Error.h>
+#include <AK/RefPtr.h>
#include <AK/Time.h>
#include <Kernel/FileSystem/File.h>
-#include <Kernel/Library/LockRefPtr.h>
#include <Kernel/Locking/Mutex.h>
#include <Kernel/Net/NetworkAdapter.h>
#include <Kernel/UnixTypes.h>
@@ -90,7 +90,7 @@ public:
ProcessID acceptor_pid() const { return m_acceptor.pid; }
UserID acceptor_uid() const { return m_acceptor.uid; }
GroupID acceptor_gid() const { return m_acceptor.gid; }
- LockRefPtr<NetworkAdapter> const bound_interface() const { return m_bound_interface; }
+ SpinlockProtected<RefPtr<NetworkAdapter>, LockRank::None> const& bound_interface() const { return m_bound_interface; }
Mutex& mutex() { return m_mutex; }
@@ -123,31 +123,29 @@ protected:
Role m_role { Role::None };
- Optional<ErrnoCode> const& so_error() const
- {
- VERIFY(m_mutex.is_exclusively_locked_by_current_thread());
- return m_so_error;
- }
+ SpinlockProtected<Optional<ErrnoCode>, LockRank::None>& so_error() { return m_so_error; }
Error set_so_error(ErrnoCode error_code)
{
- MutexLocker locker(mutex());
- m_so_error = error_code;
-
+ m_so_error.with([&error_code](auto& so_error) {
+ so_error = error_code;
+ });
return Error::from_errno(error_code);
}
Error set_so_error(Error error)
{
- MutexLocker locker(mutex());
- m_so_error = static_cast<ErrnoCode>(error.code());
-
+ m_so_error.with([&error](auto& so_error) {
+ so_error = static_cast<ErrnoCode>(error.code());
+ });
return error;
}
void clear_so_error()
{
- m_so_error = {};
+ m_so_error.with([](auto& so_error) {
+ so_error = {};
+ });
}
void set_origin(Process const&);
@@ -173,13 +171,13 @@ private:
bool m_shut_down_for_reading { false };
bool m_shut_down_for_writing { false };
- LockRefPtr<NetworkAdapter> m_bound_interface { nullptr };
+ SpinlockProtected<RefPtr<NetworkAdapter>, LockRank::None> m_bound_interface;
Time m_receive_timeout {};
Time m_send_timeout {};
int m_timestamp { 0 };
- Optional<ErrnoCode> m_so_error;
+ SpinlockProtected<Optional<ErrnoCode>, LockRank::None> m_so_error;
Vector<NonnullRefPtr<Socket>> m_pending;
};
diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp
index f238c47a7d..c42e9c6b65 100644
--- a/Kernel/Net/TCPSocket.cpp
+++ b/Kernel/Net/TCPSocket.cpp
@@ -202,7 +202,8 @@ ErrorOr<size_t> TCPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserO
ErrorOr<size_t> TCPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length)
{
- RoutingDecision routing_decision = route_to(peer_address(), local_address(), bound_interface());
+ auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
+ RoutingDecision routing_decision = route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH);
size_t mss = routing_decision.adapter->mtu() - sizeof(IPv4Packet) - sizeof(TCPPacket);
@@ -220,7 +221,8 @@ ErrorOr<void> TCPSocket::send_ack(bool allow_duplicate)
ErrorOr<void> TCPSocket::send_tcp_packet(u16 flags, UserOrKernelBuffer const* payload, size_t payload_size, RoutingDecision* user_routing_decision)
{
- RoutingDecision routing_decision = user_routing_decision ? *user_routing_decision : route_to(peer_address(), local_address(), bound_interface());
+ auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
+ RoutingDecision routing_decision = user_routing_decision ? *user_routing_decision : route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH);
@@ -409,13 +411,14 @@ NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(IPv4Address const& source, I
ErrorOr<void> TCPSocket::protocol_bind()
{
- if (has_specific_local_address() && !m_adapter) {
- m_adapter = NetworkingManagement::the().from_ipv4_address(local_address());
- if (!m_adapter)
- return set_so_error(EADDRNOTAVAIL);
- }
-
- return {};
+ return m_adapter.with([this](auto& adapter) -> ErrorOr<void> {
+ if (has_specific_local_address() && !adapter) {
+ adapter = NetworkingManagement::the().from_ipv4_address(local_address());
+ if (!adapter)
+ return set_so_error(EADDRNOTAVAIL);
+ }
+ return {};
+ });
}
ErrorOr<void> TCPSocket::protocol_listen(bool did_allocate_port)
@@ -598,7 +601,8 @@ void TCPSocket::retransmit_packets()
return;
}
- auto routing_decision = route_to(peer_address(), local_address(), bound_interface());
+ auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
+ auto routing_decision = route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero())
return;
diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h
index ae73102832..a949a0e92b 100644
--- a/Kernel/Net/TCPSocket.h
+++ b/Kernel/Net/TCPSocket.h
@@ -189,7 +189,7 @@ private:
HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept;
Direction m_direction { Direction::Unspecified };
Error m_error { Error::None };
- LockRefPtr<NetworkAdapter> m_adapter;
+ SpinlockProtected<RefPtr<NetworkAdapter>, LockRank::None> m_adapter;
u32 m_sequence_number { 0 };
u32 m_ack_number { 0 };
State m_state { State::Closed };
@@ -200,7 +200,7 @@ private:
struct OutgoingPacket {
u32 ack_number { 0 };
- LockRefPtr<PacketWithTimestamp> buffer;
+ RefPtr<PacketWithTimestamp> buffer;
size_t ipv4_payload_offset;
LockWeakPtr<NetworkAdapter> adapter;
int tx_counter { 0 };
diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp
index 88fc3388e2..2372422cc2 100644
--- a/Kernel/Net/UDPSocket.cpp
+++ b/Kernel/Net/UDPSocket.cpp
@@ -84,7 +84,8 @@ ErrorOr<size_t> UDPSocket::protocol_receive(ReadonlyBytes raw_ipv4_packet, UserO
ErrorOr<size_t> UDPSocket::protocol_send(UserOrKernelBuffer const& data, size_t data_length)
{
- auto routing_decision = route_to(peer_address(), local_address(), bound_interface());
+ auto adapter = bound_interface().with([](auto& bound_device) -> RefPtr<NetworkAdapter> { return bound_device; });
+ auto routing_decision = route_to(peer_address(), local_address(), adapter);
if (routing_decision.is_zero())
return set_so_error(EHOSTUNREACH);
auto ipv4_payload_offset = routing_decision.adapter->ipv4_payload_offset();