summaryrefslogtreecommitdiff
path: root/Kernel/Net
diff options
context:
space:
mode:
Diffstat (limited to 'Kernel/Net')
-rw-r--r--Kernel/Net/IPv4Socket.cpp30
-rw-r--r--Kernel/Net/IPv4Socket.h2
-rw-r--r--Kernel/Net/LocalSocket.cpp35
-rw-r--r--Kernel/Net/LocalSocket.h8
-rw-r--r--Kernel/Net/NetworkTask.cpp8
-rw-r--r--Kernel/Net/Routing.cpp103
-rw-r--r--Kernel/Net/Routing.h2
-rw-r--r--Kernel/Net/Socket.cpp4
-rw-r--r--Kernel/Net/TCPSocket.cpp13
9 files changed, 169 insertions, 36 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp
index a6011acc71..182a2078f6 100644
--- a/Kernel/Net/IPv4Socket.cpp
+++ b/Kernel/Net/IPv4Socket.cpp
@@ -138,6 +138,7 @@ KResult IPv4Socket::listen(size_t backlog)
set_backlog(backlog);
m_role = Role::Listener;
+ evaluate_block_conditions();
#ifdef IPV4_SOCKET_DEBUG
dbg() << "IPv4Socket{" << this << "} listening with backlog=" << backlog;
@@ -262,10 +263,11 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description
return KResult(-EAGAIN);
locker.unlock();
- auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description);
+ auto unblocked_flags = Thread::FileDescriptionBlocker::BlockFlags::None;
+ auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description, unblocked_flags);
locker.lock();
- if (!m_can_read) {
+ if (!((u32)unblocked_flags & (u32)Thread::FileDescriptionBlocker::BlockFlags::Read)) {
if (res.was_interrupted())
return KResult(-EINTR);
@@ -279,7 +281,7 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description
if (nreceived > 0)
Thread::current()->did_ipv4_socket_read((size_t)nreceived);
- m_can_read = !m_receive_buffer.is_empty();
+ set_can_read(!m_receive_buffer.is_empty());
return nreceived;
}
@@ -299,7 +301,7 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti
if (!m_receive_queue.is_empty()) {
packet = m_receive_queue.take_first();
- m_can_read = !m_receive_queue.is_empty();
+ set_can_read(!m_receive_queue.is_empty());
#ifdef IPV4_SOCKET_DEBUG
dbg() << "IPv4Socket(" << this << "): recvfrom without blocking " << packet.data.value().size() << " bytes, packets in queue: " << m_receive_queue.size();
#endif
@@ -312,10 +314,11 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti
}
locker.unlock();
- auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description);
+ auto unblocked_flags = Thread::FileDescriptionBlocker::BlockFlags::None;
+ auto res = Thread::current()->block<Thread::ReadBlocker>(nullptr, description, unblocked_flags);
locker.lock();
- if (!m_can_read) {
+ if (!((u32)unblocked_flags & (u32)Thread::FileDescriptionBlocker::BlockFlags::Read)) {
if (res.was_interrupted())
return KResult(-EINTR);
@@ -325,7 +328,7 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti
ASSERT(m_can_read);
ASSERT(!m_receive_queue.is_empty());
packet = m_receive_queue.take_first();
- m_can_read = !m_receive_queue.is_empty();
+ set_can_read(!m_receive_queue.is_empty());
#ifdef IPV4_SOCKET_DEBUG
dbg() << "IPv4Socket(" << this << "): recvfrom with blocking " << packet.data.value().size() << " bytes, packets in queue: " << m_receive_queue.size();
#endif
@@ -411,14 +414,14 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port,
ssize_t nwritten = m_receive_buffer.write(scratch_buffer, nreceived_or_error.value());
if (nwritten < 0)
return false;
- m_can_read = !m_receive_buffer.is_empty();
+ set_can_read(!m_receive_buffer.is_empty());
} else {
if (m_receive_queue.size() > 2000) {
dbg() << "IPv4Socket(" << this << "): did_receive refusing packet since queue is full.";
return false;
}
m_receive_queue.append({ source_address, source_port, packet_timestamp, move(packet) });
- m_can_read = true;
+ set_can_read(true);
}
m_bytes_received += packet_size;
#ifdef IPV4_SOCKET_DEBUG
@@ -625,7 +628,14 @@ KResult IPv4Socket::close()
void IPv4Socket::shut_down_for_reading()
{
Socket::shut_down_for_reading();
- m_can_read = true;
+ set_can_read(true);
+}
+
+void IPv4Socket::set_can_read(bool value)
+{
+ m_can_read = value;
+ if (value)
+ evaluate_block_conditions();
}
}
diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h
index 2bee154361..67c339a67a 100644
--- a/Kernel/Net/IPv4Socket.h
+++ b/Kernel/Net/IPv4Socket.h
@@ -113,6 +113,8 @@ private:
KResultOr<size_t> receive_byte_buffered(FileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>);
KResultOr<size_t> receive_packet_buffered(FileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>, timeval&);
+ void set_can_read(bool);
+
IPv4Address m_local_address;
IPv4Address m_peer_address;
diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp
index 0b4d4fc658..345cb7db05 100644
--- a/Kernel/Net/LocalSocket.cpp
+++ b/Kernel/Net/LocalSocket.cpp
@@ -68,6 +68,13 @@ LocalSocket::LocalSocket(int type)
m_prebind_gid = current_process->gid();
m_prebind_mode = 0666;
+ m_for_client.set_unblock_callback([this]() {
+ evaluate_block_conditions();
+ });
+ m_for_server.set_unblock_callback([this]() {
+ evaluate_block_conditions();
+ });
+
#ifdef DEBUG_LOCAL_SOCKET
dbg() << "LocalSocket{" << this << "} created with type=" << type;
#endif
@@ -170,22 +177,23 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka
memcpy(m_address.sun_path, safe_address, sizeof(m_address.sun_path));
ASSERT(m_connect_side_fd == &description);
- m_connect_side_role = Role::Connecting;
+ set_connect_side_role(Role::Connecting);
auto peer = m_file->inode()->socket();
auto result = peer->queue_connection_from(*this);
if (result.is_error()) {
- m_connect_side_role = Role::None;
+ set_connect_side_role(Role::None);
return result;
}
if (is_connected()) {
- m_connect_side_role = Role::Connected;
+ set_connect_side_role(Role::Connected);
return KSuccess;
}
- if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description).was_interrupted()) {
- m_connect_side_role = Role::None;
+ auto unblock_flags = Thread::FileDescriptionBlocker::BlockFlags::None;
+ if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description, unblock_flags).was_interrupted()) {
+ set_connect_side_role(Role::None);
return KResult(-EINTR);
}
@@ -193,11 +201,11 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka
dbg() << "LocalSocket{" << this << "} connect(" << safe_address << ") status is " << to_string(setup_state());
#endif
- if (!is_connected()) {
- m_connect_side_role = Role::None;
+ if (!((u32)unblock_flags & (u32)Thread::FileDescriptionBlocker::BlockFlags::Connect)) {
+ set_connect_side_role(Role::None);
return KResult(-ECONNREFUSED);
}
- m_connect_side_role = Role::Connected;
+ set_connect_side_role(Role::Connected);
return KSuccess;
}
@@ -207,7 +215,9 @@ KResult LocalSocket::listen(size_t backlog)
if (type() != SOCK_STREAM)
return KResult(-EOPNOTSUPP);
set_backlog(backlog);
- m_connect_side_role = m_role = Role::Listener;
+ auto previous_role = m_role;
+ m_role = Role::Listener;
+ set_connect_side_role(Role::Listener, previous_role != m_role);
#ifdef DEBUG_LOCAL_SOCKET
dbg() << "LocalSocket{" << this << "} listening with backlog=" << backlog;
#endif
@@ -224,6 +234,8 @@ void LocalSocket::attach(FileDescription& description)
ASSERT(m_connect_side_fd != &description);
m_accept_side_fd_open = true;
}
+
+ evaluate_block_conditions();
}
void LocalSocket::detach(FileDescription& description)
@@ -234,6 +246,8 @@ void LocalSocket::detach(FileDescription& description)
ASSERT(m_accept_side_fd_open);
m_accept_side_fd_open = false;
}
+
+ evaluate_block_conditions();
}
bool LocalSocket::can_read(const FileDescription& description, size_t) const
@@ -308,7 +322,8 @@ KResultOr<size_t> LocalSocket::recvfrom(FileDescription& description, UserOrKern
return KResult(-EAGAIN);
}
} else if (!can_read(description, 0)) {
- if (Thread::current()->block<Thread::ReadBlocker>(nullptr, description).was_interrupted())
+ auto unblock_flags = Thread::FileDescriptionBlocker::BlockFlags::None;
+ if (Thread::current()->block<Thread::ReadBlocker>(nullptr, description, unblock_flags).was_interrupted())
return KResult(-EINTR);
}
if (!has_attached_peer(description) && buffer_for_me.is_empty())
diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h
index 30693283a3..3b9395d1a1 100644
--- a/Kernel/Net/LocalSocket.h
+++ b/Kernel/Net/LocalSocket.h
@@ -77,6 +77,14 @@ private:
NonnullRefPtrVector<FileDescription>& sendfd_queue_for(const FileDescription&);
NonnullRefPtrVector<FileDescription>& recvfd_queue_for(const FileDescription&);
+ void set_connect_side_role(Role connect_side_role, bool force_evaluate_block_conditions = false)
+ {
+ auto previous = m_connect_side_role;
+ m_connect_side_role = connect_side_role;
+ if (previous != m_connect_side_role || force_evaluate_block_conditions)
+ evaluate_block_conditions();
+ }
+
// An open socket file on the filesystem.
RefPtr<FileDescription> m_file;
diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp
index cd7616f098..8a2e9c2c20 100644
--- a/Kernel/Net/NetworkTask.cpp
+++ b/Kernel/Net/NetworkTask.cpp
@@ -187,13 +187,7 @@ void handle_arp(const EthernetFrameHeader& eth, size_t frame_size)
// Someone has this IPv4 address. I guess we can try to remember that.
// FIXME: Protect against ARP spamming.
// FIXME: Support static ARP table entries.
- LOCKER(arp_table().lock());
- arp_table().resource().set(packet.sender_protocol_address(), packet.sender_hardware_address());
-
- klog() << "ARP table (" << arp_table().resource().size() << " entries):";
- for (auto& it : arp_table().resource()) {
- klog() << it.value.to_string().characters() << " :: " << it.key.to_string().characters();
- }
+ update_arp_table(packet.sender_protocol_address(), packet.sender_hardware_address());
}
if (packet.operation() == ARPOperation::Request) {
diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp
index 281f06ceb4..c974cb5537 100644
--- a/Kernel/Net/Routing.cpp
+++ b/Kernel/Net/Routing.cpp
@@ -36,11 +36,105 @@ namespace Kernel {
static AK::Singleton<Lockable<HashMap<IPv4Address, MACAddress>>> s_arp_table;
+class ARPTableBlocker : public Thread::Blocker {
+public:
+ ARPTableBlocker(IPv4Address ip_addr, Optional<MACAddress>& addr);
+
+ virtual const char* state_string() const override { return "Routing (ARP)"; }
+ virtual Type blocker_type() const override { return Type::Routing; }
+ virtual bool should_block() override { return m_should_block; }
+
+ virtual void not_blocking(bool) override;
+
+ bool unblock(bool from_add_blocker, const IPv4Address& ip_addr, const MACAddress& addr)
+ {
+ if (m_ip_addr != ip_addr)
+ return false;
+
+ {
+ ScopedSpinLock lock(m_lock);
+ if (m_did_unblock)
+ return false;
+ m_did_unblock = true;
+ m_addr = addr;
+ }
+
+ if (!from_add_blocker)
+ unblock_from_blocker();
+ return true;
+ }
+
+ const IPv4Address& ip_addr() const { return m_ip_addr; }
+
+private:
+ const IPv4Address m_ip_addr;
+ Optional<MACAddress>& m_addr;
+ bool m_did_unblock { false };
+ bool m_should_block { true };
+};
+
+class ARPTableBlockCondition : public Thread::BlockCondition {
+public:
+ void unblock(const IPv4Address& ip_addr, const MACAddress& addr)
+ {
+ unblock_all([&](auto& b, void*) {
+ ASSERT(b.blocker_type() == Thread::Blocker::Type::Routing);
+ auto& blocker = static_cast<ARPTableBlocker&>(b);
+ return blocker.unblock(false, ip_addr, addr);
+ });
+ }
+
+protected:
+ virtual bool should_add_blocker(Thread::Blocker& b, void*) override
+ {
+ ASSERT(b.blocker_type() == Thread::Blocker::Type::Routing);
+ auto& blocker = static_cast<ARPTableBlocker&>(b);
+ auto val = s_arp_table->resource().get(blocker.ip_addr());
+ if (!val.has_value())
+ return true;
+ return blocker.unblock(true, blocker.ip_addr(), val.value());
+ }
+};
+
+static AK::Singleton<ARPTableBlockCondition> s_arp_table_block_condition;
+
+ARPTableBlocker::ARPTableBlocker(IPv4Address ip_addr, Optional<MACAddress>& addr)
+ : m_ip_addr(ip_addr)
+ , m_addr(addr)
+{
+ if (!set_block_condition(*s_arp_table_block_condition))
+ m_should_block = false;
+}
+
+void ARPTableBlocker::not_blocking(bool timeout_in_past)
+{
+ ASSERT(timeout_in_past || !m_should_block);
+ auto addr = s_arp_table->resource().get(ip_addr());
+
+ ScopedSpinLock lock(m_lock);
+ if (!m_did_unblock) {
+ m_did_unblock = true;
+ m_addr = move(addr);
+ }
+}
+
Lockable<HashMap<IPv4Address, MACAddress>>& arp_table()
{
return *s_arp_table;
}
+void update_arp_table(const IPv4Address& ip_addr, const MACAddress& addr)
+{
+ LOCKER(arp_table().lock());
+ arp_table().resource().set(ip_addr, addr);
+ s_arp_table_block_condition->unblock(ip_addr, addr);
+
+ klog() << "ARP table (" << arp_table().resource().size() << " entries):";
+ for (auto& it : arp_table().resource()) {
+ klog() << it.value.to_string().characters() << " :: " << it.key.to_string().characters();
+ }
+}
+
bool RoutingDecision::is_zero() const
{
return adapter.is_null() || next_hop.is_zero();
@@ -135,13 +229,8 @@ RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source, c
request.set_sender_protocol_address(adapter->ipv4_address());
adapter->send({ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }, request);
- (void)Thread::current()->block_until("Routing (ARP)", [next_hop_ip] {
- return arp_table().resource().get(next_hop_ip).has_value();
- });
-
- {
- LOCKER(arp_table().lock());
- auto addr = arp_table().resource().get(next_hop_ip);
+ Optional<MACAddress> addr;
+ if (!Thread::current()->block<ARPTableBlocker>(nullptr, next_hop_ip, addr).was_interrupted()) {
if (addr.has_value()) {
#ifdef ROUTING_DEBUG
klog() << "Routing: Got ARP response using adapter " << adapter->name().characters() << " for " << next_hop_ip.to_string().characters() << " (" << addr.value().to_string().characters() << ")";
diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h
index 50c7d72ed3..d58cbb7215 100644
--- a/Kernel/Net/Routing.h
+++ b/Kernel/Net/Routing.h
@@ -27,6 +27,7 @@
#pragma once
#include <Kernel/Net/NetworkAdapter.h>
+#include <Kernel/Thread.h>
namespace Kernel {
@@ -37,6 +38,7 @@ struct RoutingDecision {
bool is_zero() const;
};
+void update_arp_table(const IPv4Address&, const MACAddress&);
RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source, const RefPtr<NetworkAdapter> through = nullptr);
Lockable<HashMap<IPv4Address, MACAddress>>& arp_table();
diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp
index 96bfb8a42c..7b6cd51a5c 100644
--- a/Kernel/Net/Socket.cpp
+++ b/Kernel/Net/Socket.cpp
@@ -70,6 +70,7 @@ void Socket::set_setup_state(SetupState new_setup_state)
#endif
m_setup_state = new_setup_state;
+ evaluate_block_conditions();
}
RefPtr<Socket> Socket::accept()
@@ -86,6 +87,8 @@ RefPtr<Socket> Socket::accept()
client->m_acceptor = { process.pid().value(), process.uid(), process.gid() };
client->m_connected = true;
client->m_role = Role::Accepted;
+ if (!m_pending.is_empty())
+ evaluate_block_conditions();
return client;
}
@@ -98,6 +101,7 @@ KResult Socket::queue_connection_from(NonnullRefPtr<Socket> peer)
if (m_pending.size() >= m_backlog)
return KResult(-ECONNREFUSED);
m_pending.append(peer);
+ evaluate_block_conditions();
return KSuccess;
}
diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp
index 8f66216dd4..91b5252207 100644
--- a/Kernel/Net/TCPSocket.cpp
+++ b/Kernel/Net/TCPSocket.cpp
@@ -52,6 +52,9 @@ void TCPSocket::set_state(State new_state)
dbg() << "TCPSocket{" << this << "} state moving from " << to_string(m_state) << " to " << to_string(new_state);
#endif
+ auto was_disconnected = protocol_is_disconnected();
+ auto previous_role = m_role;
+
m_state = new_state;
if (new_state == State::Established && m_direction == Direction::Outgoing)
@@ -61,6 +64,9 @@ void TCPSocket::set_state(State new_state)
LOCKER(closing_sockets().lock());
closing_sockets().resource().remove(tuple());
}
+
+ if (previous_role != m_role || was_disconnected != protocol_is_disconnected())
+ evaluate_block_conditions();
}
static AK::Singleton<Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>> s_socket_closing;
@@ -389,13 +395,16 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
m_role = Role::Connecting;
m_direction = Direction::Outgoing;
+ evaluate_block_conditions();
+
if (should_block == ShouldBlock::Yes) {
locker.unlock();
- if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description).was_interrupted())
+ auto unblock_flags = Thread::FileBlocker::BlockFlags::None;
+ if (Thread::current()->block<Thread::ConnectBlocker>(nullptr, description, unblock_flags).was_interrupted())
return KResult(-EINTR);
locker.lock();
ASSERT(setup_state() == SetupState::Completed);
- if (has_error()) {
+ if (has_error()) { // TODO: check unblock_flags
m_role = Role::None;
return KResult(-ECONNREFUSED);
}