diff options
Diffstat (limited to 'Kernel')
-rw-r--r-- | Kernel/Net/IPv4Socket.cpp | 2 | ||||
-rw-r--r-- | Kernel/Net/NetworkAdapter.cpp | 4 | ||||
-rw-r--r-- | Kernel/Net/NetworkAdapter.h | 6 | ||||
-rw-r--r-- | Kernel/Net/NetworkTask.cpp | 10 | ||||
-rw-r--r-- | Kernel/Net/Routing.cpp | 4 | ||||
-rw-r--r-- | Kernel/Net/Routing.h | 2 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.cpp | 11 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.h | 3 | ||||
-rw-r--r-- | Kernel/Net/UDPSocket.cpp | 2 |
9 files changed, 28 insertions, 16 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index a97005bb11..f33e90224c 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -169,7 +169,7 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt m_peer_port = ntohs(ia.sin_port); } - auto* adapter = adapter_for_route_to(m_peer_address); + auto adapter = adapter_for_route_to(m_peer_address); if (!adapter) return -EHOSTUNREACH; diff --git a/Kernel/Net/NetworkAdapter.cpp b/Kernel/Net/NetworkAdapter.cpp index f77804a914..14ed365f1a 100644 --- a/Kernel/Net/NetworkAdapter.cpp +++ b/Kernel/Net/NetworkAdapter.cpp @@ -22,12 +22,12 @@ void NetworkAdapter::for_each(Function<void(NetworkAdapter&)> callback) callback(*it); } -NetworkAdapter* NetworkAdapter::from_ipv4_address(const IPv4Address& address) +WeakPtr<NetworkAdapter> NetworkAdapter::from_ipv4_address(const IPv4Address& address) { LOCKER(all_adapters().lock()); for (auto* adapter : all_adapters().resource()) { if (adapter->ipv4_address() == address) - return adapter; + return adapter->make_weak_ptr(); } return nullptr; } diff --git a/Kernel/Net/NetworkAdapter.h b/Kernel/Net/NetworkAdapter.h index c1d99b3f9e..4860e8b647 100644 --- a/Kernel/Net/NetworkAdapter.h +++ b/Kernel/Net/NetworkAdapter.h @@ -4,6 +4,8 @@ #include <AK/Function.h> #include <AK/SinglyLinkedList.h> #include <AK/Types.h> +#include <AK/Weakable.h> +#include <AK/WeakPtr.h> #include <Kernel/KBuffer.h> #include <Kernel/Net/ARP.h> #include <Kernel/Net/ICMP.h> @@ -12,10 +14,10 @@ class NetworkAdapter; -class NetworkAdapter { +class NetworkAdapter : public Weakable<NetworkAdapter> { public: static void for_each(Function<void(NetworkAdapter&)>); - static NetworkAdapter* from_ipv4_address(const IPv4Address&); + static WeakPtr<NetworkAdapter> from_ipv4_address(const IPv4Address&); virtual ~NetworkAdapter(); virtual const char* class_name() const = 0; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index c57562b1aa..38bd3b93d4 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -38,7 +38,7 @@ void NetworkTask_main() { LoopbackAdapter::the(); - auto* adapter = E1000NetworkAdapter::the(); + auto adapter = E1000NetworkAdapter::the(); if (!adapter) dbgprintf("E1000 network card not found!\n"); @@ -150,7 +150,7 @@ void handle_arp(const EthernetFrameHeader& eth, int frame_size) if (packet.operation() == ARPOperation::Request) { // Who has this IP address? - if (auto* adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) { + if (auto adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) { // We do! kprintf("handle_arp: Responding to ARP request for my IPv4 address (%s)\n", adapter->ipv4_address().to_string().characters()); @@ -231,7 +231,7 @@ void handle_icmp(const EthernetFrameHeader& eth, int frame_size) } } - auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); + auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) return; @@ -260,7 +260,7 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size) (void)frame_size; auto& ipv4_packet = *static_cast<const IPv4Packet*>(eth.payload()); - auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); + auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) { kprintf("handle_udp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters()); return; @@ -292,7 +292,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) (void)frame_size; auto& ipv4_packet = *static_cast<const IPv4Packet*>(eth.payload()); - auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); + auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) { kprintf("handle_tcp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters()); return; diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp index b2c3d8a782..c9bf54f67e 100644 --- a/Kernel/Net/Routing.cpp +++ b/Kernel/Net/Routing.cpp @@ -1,10 +1,10 @@ #include <Kernel/Net/LoopbackAdapter.h> #include <Kernel/Net/Routing.h> -NetworkAdapter* adapter_for_route_to(const IPv4Address& ipv4_address) +WeakPtr<NetworkAdapter> adapter_for_route_to(const IPv4Address& ipv4_address) { // FIXME: Have an actual routing table. if (ipv4_address == IPv4Address(127, 0, 0, 1)) - return &LoopbackAdapter::the(); + return LoopbackAdapter::the().make_weak_ptr(); return NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); } diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h index 48143520eb..0feed2cf64 100644 --- a/Kernel/Net/Routing.h +++ b/Kernel/Net/Routing.h @@ -2,4 +2,4 @@ #include <Kernel/Net/NetworkAdapter.h> -NetworkAdapter* adapter_for_route_to(const IPv4Address&); +WeakPtr<NetworkAdapter> adapter_for_route_to(const IPv4Address&); diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index b6c1b726af..df998389f4 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -80,7 +80,16 @@ int TCPSocket::protocol_send(const void* data, int data_length) void TCPSocket::send_tcp_packet(u16 flags, const void* payload, int payload_size) { - ASSERT(m_adapter); + if (!m_adapter) { + if (has_specific_local_address()) { + m_adapter = NetworkAdapter::from_ipv4_address(local_address()); + } else { + m_adapter = adapter_for_route_to(peer_address()); + if (m_adapter) + set_local_address(m_adapter->ipv4_address()); + } + } + ASSERT(!!m_adapter); auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size); auto& tcp_packet = *(TCPPacket*)(buffer.pointer()); diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index c6450e850b..48a47807b3 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -1,6 +1,7 @@ #pragma once #include <AK/Function.h> +#include <AK/WeakPtr.h> #include <Kernel/Net/IPv4Socket.h> class TCPSocket final : public IPv4Socket { @@ -86,7 +87,7 @@ private: virtual KResult protocol_bind() override; virtual KResult protocol_listen() override; - NetworkAdapter* m_adapter { nullptr }; + WeakPtr<NetworkAdapter> m_adapter; u32 m_sequence_number { 0 }; u32 m_ack_number { 0 }; State m_state { State::Closed }; diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index c3acf1261f..7ceebbb1d9 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -56,7 +56,7 @@ int UDPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size int UDPSocket::protocol_send(const void* data, int data_length) { - auto* adapter = adapter_for_route_to(peer_address()); + auto adapter = adapter_for_route_to(peer_address()); if (!adapter) return -EHOSTUNREACH; auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length); |