diff options
Diffstat (limited to 'Kernel/Net')
-rw-r--r-- | Kernel/Net/IPv4Socket.cpp | 30 | ||||
-rw-r--r-- | Kernel/Net/IPv4Socket.h | 2 | ||||
-rw-r--r-- | Kernel/Net/LocalSocket.cpp | 35 | ||||
-rw-r--r-- | Kernel/Net/LocalSocket.h | 8 | ||||
-rw-r--r-- | Kernel/Net/NetworkTask.cpp | 8 | ||||
-rw-r--r-- | Kernel/Net/Routing.cpp | 103 | ||||
-rw-r--r-- | Kernel/Net/Routing.h | 2 | ||||
-rw-r--r-- | Kernel/Net/Socket.cpp | 4 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.cpp | 13 |
9 files changed, 169 insertions, 36 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index a6011acc71..182a2078f6 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -138,6 +138,7 @@ KResult IPv4Socket::listen(size_t backlog) set_backlog(backlog); m_role = Role::Listener; + evaluate_block_conditions(); #ifdef IPV4_SOCKET_DEBUG dbg() << "IPv4Socket{" << this << "} listening with backlog=" << backlog; @@ -262,10 +263,11 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description return KResult(-EAGAIN); locker.unlock(); - auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description); + auto unblocked_flags = Thread::FileDescriptionBlocker::BlockFlags::None; + auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description, unblocked_flags); locker.lock(); - if (!m_can_read) { + if (!((u32)unblocked_flags & (u32)Thread::FileDescriptionBlocker::BlockFlags::Read)) { if (res.was_interrupted()) return KResult(-EINTR); @@ -279,7 +281,7 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description if (nreceived > 0) Thread::current()->did_ipv4_socket_read((size_t)nreceived); - m_can_read = !m_receive_buffer.is_empty(); + set_can_read(!m_receive_buffer.is_empty()); return nreceived; } @@ -299,7 +301,7 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti if (!m_receive_queue.is_empty()) { packet = m_receive_queue.take_first(); - m_can_read = !m_receive_queue.is_empty(); + set_can_read(!m_receive_queue.is_empty()); #ifdef IPV4_SOCKET_DEBUG dbg() << "IPv4Socket(" << this << "): recvfrom without blocking " << packet.data.value().size() << " bytes, packets in queue: " << m_receive_queue.size(); #endif @@ -312,10 +314,11 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti } locker.unlock(); - auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description); + auto unblocked_flags = Thread::FileDescriptionBlocker::BlockFlags::None; + auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description, unblocked_flags); locker.lock(); - if (!m_can_read) { + if (!((u32)unblocked_flags & (u32)Thread::FileDescriptionBlocker::BlockFlags::Read)) { if (res.was_interrupted()) return KResult(-EINTR); @@ -325,7 +328,7 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti ASSERT(m_can_read); ASSERT(!m_receive_queue.is_empty()); packet = m_receive_queue.take_first(); - m_can_read = !m_receive_queue.is_empty(); + set_can_read(!m_receive_queue.is_empty()); #ifdef IPV4_SOCKET_DEBUG dbg() << "IPv4Socket(" << this << "): recvfrom with blocking " << packet.data.value().size() << " bytes, packets in queue: " << m_receive_queue.size(); #endif @@ -411,14 +414,14 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port, ssize_t nwritten = m_receive_buffer.write(scratch_buffer, nreceived_or_error.value()); if (nwritten < 0) return false; - m_can_read = !m_receive_buffer.is_empty(); + set_can_read(!m_receive_buffer.is_empty()); } else { if (m_receive_queue.size() > 2000) { dbg() << "IPv4Socket(" << this << "): did_receive refusing packet since queue is full."; return false; } m_receive_queue.append({ source_address, source_port, packet_timestamp, move(packet) }); - m_can_read = true; + set_can_read(true); } m_bytes_received += packet_size; #ifdef IPV4_SOCKET_DEBUG @@ -625,7 +628,14 @@ KResult IPv4Socket::close() void IPv4Socket::shut_down_for_reading() { Socket::shut_down_for_reading(); - m_can_read = true; + set_can_read(true); +} + +void IPv4Socket::set_can_read(bool value) +{ + m_can_read = value; + if (value) + evaluate_block_conditions(); } } diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index 2bee154361..67c339a67a 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -113,6 +113,8 @@ private: KResultOr<size_t> receive_byte_buffered(FileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>); KResultOr<size_t> receive_packet_buffered(FileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, timeval&); + void set_can_read(bool); + IPv4Address m_local_address; IPv4Address m_peer_address; diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index 0b4d4fc658..345cb7db05 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -68,6 +68,13 @@ LocalSocket::LocalSocket(int type) m_prebind_gid = current_process->gid(); m_prebind_mode = 0666; + m_for_client.set_unblock_callback([this]() { + evaluate_block_conditions(); + }); + m_for_server.set_unblock_callback([this]() { + evaluate_block_conditions(); + }); + #ifdef DEBUG_LOCAL_SOCKET dbg() << "LocalSocket{" << this << "} created with type=" << type; #endif @@ -170,22 +177,23 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka memcpy(m_address.sun_path, safe_address, sizeof(m_address.sun_path)); ASSERT(m_connect_side_fd == &description); - m_connect_side_role = Role::Connecting; + set_connect_side_role(Role::Connecting); auto peer = m_file->inode()->socket(); auto result = peer->queue_connection_from(*this); if (result.is_error()) { - m_connect_side_role = Role::None; + set_connect_side_role(Role::None); return result; } if (is_connected()) { - m_connect_side_role = Role::Connected; + set_connect_side_role(Role::Connected); return KSuccess; } - if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description).was_interrupted()) { - m_connect_side_role = Role::None; + auto unblock_flags = Thread::FileDescriptionBlocker::BlockFlags::None; + if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description, unblock_flags).was_interrupted()) { + set_connect_side_role(Role::None); return KResult(-EINTR); } @@ -193,11 +201,11 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka dbg() << "LocalSocket{" << this << "} connect(" << safe_address << ") status is " << to_string(setup_state()); #endif - if (!is_connected()) { - m_connect_side_role = Role::None; + if (!((u32)unblock_flags & (u32)Thread::FileDescriptionBlocker::BlockFlags::Connect)) { + set_connect_side_role(Role::None); return KResult(-ECONNREFUSED); } - m_connect_side_role = Role::Connected; + set_connect_side_role(Role::Connected); return KSuccess; } @@ -207,7 +215,9 @@ KResult LocalSocket::listen(size_t backlog) if (type() != SOCK_STREAM) return KResult(-EOPNOTSUPP); set_backlog(backlog); - m_connect_side_role = m_role = Role::Listener; + auto previous_role = m_role; + m_role = Role::Listener; + set_connect_side_role(Role::Listener, previous_role != m_role); #ifdef DEBUG_LOCAL_SOCKET dbg() << "LocalSocket{" << this << "} listening with backlog=" << backlog; #endif @@ -224,6 +234,8 @@ void LocalSocket::attach(FileDescription& description) ASSERT(m_connect_side_fd != &description); m_accept_side_fd_open = true; } + + evaluate_block_conditions(); } void LocalSocket::detach(FileDescription& description) @@ -234,6 +246,8 @@ void LocalSocket::detach(FileDescription& description) ASSERT(m_accept_side_fd_open); m_accept_side_fd_open = false; } + + evaluate_block_conditions(); } bool LocalSocket::can_read(const FileDescription& description, size_t) const @@ -308,7 +322,8 @@ KResultOr<size_t> LocalSocket::recvfrom(FileDescription& description, UserOrKern return KResult(-EAGAIN); } } else if (!can_read(description, 0)) { - if (Thread::current()->block<Thread::ReadBlocker>(nullptr, description).was_interrupted()) + auto unblock_flags = Thread::FileDescriptionBlocker::BlockFlags::None; + if (Thread::current()->block<Thread::ReadBlocker>(nullptr, description, unblock_flags).was_interrupted()) return KResult(-EINTR); } if (!has_attached_peer(description) && buffer_for_me.is_empty()) diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index 30693283a3..3b9395d1a1 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -77,6 +77,14 @@ private: NonnullRefPtrVector<FileDescription>& sendfd_queue_for(const FileDescription&); NonnullRefPtrVector<FileDescription>& recvfd_queue_for(const FileDescription&); + void set_connect_side_role(Role connect_side_role, bool force_evaluate_block_conditions = false) + { + auto previous = m_connect_side_role; + m_connect_side_role = connect_side_role; + if (previous != m_connect_side_role || force_evaluate_block_conditions) + evaluate_block_conditions(); + } + // An open socket file on the filesystem. RefPtr<FileDescription> m_file; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index cd7616f098..8a2e9c2c20 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -187,13 +187,7 @@ void handle_arp(const EthernetFrameHeader& eth, size_t frame_size) // Someone has this IPv4 address. I guess we can try to remember that. // FIXME: Protect against ARP spamming. // FIXME: Support static ARP table entries. - LOCKER(arp_table().lock()); - arp_table().resource().set(packet.sender_protocol_address(), packet.sender_hardware_address()); - - klog() << "ARP table (" << arp_table().resource().size() << " entries):"; - for (auto& it : arp_table().resource()) { - klog() << it.value.to_string().characters() << " :: " << it.key.to_string().characters(); - } + update_arp_table(packet.sender_protocol_address(), packet.sender_hardware_address()); } if (packet.operation() == ARPOperation::Request) { diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp index 281f06ceb4..c974cb5537 100644 --- a/Kernel/Net/Routing.cpp +++ b/Kernel/Net/Routing.cpp @@ -36,11 +36,105 @@ namespace Kernel { static AK::Singleton<Lockable<HashMap<IPv4Address, MACAddress>>> s_arp_table; +class ARPTableBlocker : public Thread::Blocker { +public: + ARPTableBlocker(IPv4Address ip_addr, Optional<MACAddress>& addr); + + virtual const char* state_string() const override { return "Routing (ARP)"; } + virtual Type blocker_type() const override { return Type::Routing; } + virtual bool should_block() override { return m_should_block; } + + virtual void not_blocking(bool) override; + + bool unblock(bool from_add_blocker, const IPv4Address& ip_addr, const MACAddress& addr) + { + if (m_ip_addr != ip_addr) + return false; + + { + ScopedSpinLock lock(m_lock); + if (m_did_unblock) + return false; + m_did_unblock = true; + m_addr = addr; + } + + if (!from_add_blocker) + unblock_from_blocker(); + return true; + } + + const IPv4Address& ip_addr() const { return m_ip_addr; } + +private: + const IPv4Address m_ip_addr; + Optional<MACAddress>& m_addr; + bool m_did_unblock { false }; + bool m_should_block { true }; +}; + +class ARPTableBlockCondition : public Thread::BlockCondition { +public: + void unblock(const IPv4Address& ip_addr, const MACAddress& addr) + { + unblock_all([&](auto& b, void*) { + ASSERT(b.blocker_type() == Thread::Blocker::Type::Routing); + auto& blocker = static_cast<ARPTableBlocker&>(b); + return blocker.unblock(false, ip_addr, addr); + }); + } + +protected: + virtual bool should_add_blocker(Thread::Blocker& b, void*) override + { + ASSERT(b.blocker_type() == Thread::Blocker::Type::Routing); + auto& blocker = static_cast<ARPTableBlocker&>(b); + auto val = s_arp_table->resource().get(blocker.ip_addr()); + if (!val.has_value()) + return true; + return blocker.unblock(true, blocker.ip_addr(), val.value()); + } +}; + +static AK::Singleton<ARPTableBlockCondition> s_arp_table_block_condition; + +ARPTableBlocker::ARPTableBlocker(IPv4Address ip_addr, Optional<MACAddress>& addr) + : m_ip_addr(ip_addr) + , m_addr(addr) +{ + if (!set_block_condition(*s_arp_table_block_condition)) + m_should_block = false; +} + +void ARPTableBlocker::not_blocking(bool timeout_in_past) +{ + ASSERT(timeout_in_past || !m_should_block); + auto addr = s_arp_table->resource().get(ip_addr()); + + ScopedSpinLock lock(m_lock); + if (!m_did_unblock) { + m_did_unblock = true; + m_addr = move(addr); + } +} + Lockable<HashMap<IPv4Address, MACAddress>>& arp_table() { return *s_arp_table; } +void update_arp_table(const IPv4Address& ip_addr, const MACAddress& addr) +{ + LOCKER(arp_table().lock()); + arp_table().resource().set(ip_addr, addr); + s_arp_table_block_condition->unblock(ip_addr, addr); + + klog() << "ARP table (" << arp_table().resource().size() << " entries):"; + for (auto& it : arp_table().resource()) { + klog() << it.value.to_string().characters() << " :: " << it.key.to_string().characters(); + } +} + bool RoutingDecision::is_zero() const { return adapter.is_null() || next_hop.is_zero(); @@ -135,13 +229,8 @@ RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source, c request.set_sender_protocol_address(adapter->ipv4_address()); adapter->send({ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }, request); - (void)Thread::current()->block_until("Routing (ARP)", [next_hop_ip] { - return arp_table().resource().get(next_hop_ip).has_value(); - }); - - { - LOCKER(arp_table().lock()); - auto addr = arp_table().resource().get(next_hop_ip); + Optional<MACAddress> addr; + if (!Thread::current()->block<ARPTableBlocker>(nullptr, next_hop_ip, addr).was_interrupted()) { if (addr.has_value()) { #ifdef ROUTING_DEBUG klog() << "Routing: Got ARP response using adapter " << adapter->name().characters() << " for " << next_hop_ip.to_string().characters() << " (" << addr.value().to_string().characters() << ")"; diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h index 50c7d72ed3..d58cbb7215 100644 --- a/Kernel/Net/Routing.h +++ b/Kernel/Net/Routing.h @@ -27,6 +27,7 @@ #pragma once #include <Kernel/Net/NetworkAdapter.h> +#include <Kernel/Thread.h> namespace Kernel { @@ -37,6 +38,7 @@ struct RoutingDecision { bool is_zero() const; }; +void update_arp_table(const IPv4Address&, const MACAddress&); RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source, const RefPtr<NetworkAdapter> through = nullptr); Lockable<HashMap<IPv4Address, MACAddress>>& arp_table(); diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp index 96bfb8a42c..7b6cd51a5c 100644 --- a/Kernel/Net/Socket.cpp +++ b/Kernel/Net/Socket.cpp @@ -70,6 +70,7 @@ void Socket::set_setup_state(SetupState new_setup_state) #endif m_setup_state = new_setup_state; + evaluate_block_conditions(); } RefPtr<Socket> Socket::accept() @@ -86,6 +87,8 @@ RefPtr<Socket> Socket::accept() client->m_acceptor = { process.pid().value(), process.uid(), process.gid() }; client->m_connected = true; client->m_role = Role::Accepted; + if (!m_pending.is_empty()) + evaluate_block_conditions(); return client; } @@ -98,6 +101,7 @@ KResult Socket::queue_connection_from(NonnullRefPtr<Socket> peer) if (m_pending.size() >= m_backlog) return KResult(-ECONNREFUSED); m_pending.append(peer); + evaluate_block_conditions(); return KSuccess; } diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 8f66216dd4..91b5252207 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -52,6 +52,9 @@ void TCPSocket::set_state(State new_state) dbg() << "TCPSocket{" << this << "} state moving from " << to_string(m_state) << " to " << to_string(new_state); #endif + auto was_disconnected = protocol_is_disconnected(); + auto previous_role = m_role; + m_state = new_state; if (new_state == State::Established && m_direction == Direction::Outgoing) @@ -61,6 +64,9 @@ void TCPSocket::set_state(State new_state) LOCKER(closing_sockets().lock()); closing_sockets().resource().remove(tuple()); } + + if (previous_role != m_role || was_disconnected != protocol_is_disconnected()) + evaluate_block_conditions(); } static AK::Singleton<Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>> s_socket_closing; @@ -389,13 +395,16 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh m_role = Role::Connecting; m_direction = Direction::Outgoing; + evaluate_block_conditions(); + if (should_block == ShouldBlock::Yes) { locker.unlock(); - if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description).was_interrupted()) + auto unblock_flags = Thread::FileBlocker::BlockFlags::None; + if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description, unblock_flags).was_interrupted()) return KResult(-EINTR); locker.lock(); ASSERT(setup_state() == SetupState::Completed); - if (has_error()) { + if (has_error()) { // TODO: check unblock_flags m_role = Role::None; return KResult(-ECONNREFUSED); } |