diff options
-rw-r--r-- | Kernel/IPv4Socket.cpp | 74 | ||||
-rw-r--r-- | Kernel/IPv4Socket.h | 14 | ||||
-rw-r--r-- | Kernel/LocalSocket.cpp | 5 | ||||
-rw-r--r-- | Kernel/LocalSocket.h | 1 | ||||
-rw-r--r-- | Kernel/NetworkTask.cpp | 11 | ||||
-rw-r--r-- | Kernel/Process.cpp | 27 | ||||
-rw-r--r-- | Kernel/Process.h | 5 | ||||
-rw-r--r-- | Kernel/Scheduler.cpp | 10 | ||||
-rw-r--r-- | Kernel/Socket.h | 1 | ||||
-rw-r--r-- | Kernel/Syscall.cpp | 2 | ||||
-rw-r--r-- | Kernel/Syscall.h | 10 | ||||
-rw-r--r-- | LibC/sys/socket.cpp | 6 | ||||
-rw-r--r-- | LibC/sys/socket.h | 1 | ||||
-rw-r--r-- | Userland/ping.cpp | 13 |
14 files changed, 178 insertions, 2 deletions
diff --git a/Kernel/IPv4Socket.cpp b/Kernel/IPv4Socket.cpp index 64a06609d7..239cf3edd3 100644 --- a/Kernel/IPv4Socket.cpp +++ b/Kernel/IPv4Socket.cpp @@ -2,8 +2,19 @@ #include <Kernel/UnixTypes.h> #include <Kernel/Process.h> #include <Kernel/NetworkAdapter.h> +#include <Kernel/IPv4.h> +#include <Kernel/ICMP.h> +#include <Kernel/ARP.h> #include <LibC/errno_numbers.h> +Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets() +{ + static Lockable<HashTable<IPv4Socket*>>* s_table; + if (!s_table) + s_table = new Lockable<HashTable<IPv4Socket*>>; + return *s_table; +} + Retained<IPv4Socket> IPv4Socket::create(int type, int protocol) { return adopt(*new IPv4Socket(type, protocol)); @@ -11,12 +22,17 @@ Retained<IPv4Socket> IPv4Socket::create(int type, int protocol) IPv4Socket::IPv4Socket(int type, int protocol) : Socket(AF_INET, type, protocol) + , m_lock("IPv4Socket") { kprintf("%s(%u) IPv4Socket{%p} created with type=%u\n", current->name().characters(), current->pid(), this, type); + LOCKER(all_sockets().lock()); + all_sockets().resource().set(this); } IPv4Socket::~IPv4Socket() { + LOCKER(all_sockets().lock()); + all_sockets().resource().remove(this); } bool IPv4Socket::get_address(sockaddr* address, socklen_t* address_size) @@ -63,7 +79,7 @@ void IPv4Socket::detach_fd(SocketRole) bool IPv4Socket::can_read(SocketRole) const { - ASSERT_NOT_REACHED(); + return m_can_read; } ssize_t IPv4Socket::read(SocketRole role, byte* buffer, ssize_t size) @@ -84,7 +100,6 @@ bool IPv4Socket::can_write(SocketRole role) const ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length) { (void)flags; - ASSERT(data_length); if (addr_length != sizeof(sockaddr_in)) return -EINVAL; // FIXME: Find the adapter some better way! @@ -109,3 +124,58 @@ ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, cons adapter->send_ipv4(mac_address, peer_address, (IPv4Protocol)protocol(), ByteBuffer::copy((const byte*)data, data_length)); return data_length; } + +ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, const sockaddr* addr, socklen_t addr_length) +{ + (void)flags; + if (addr_length != sizeof(sockaddr_in)) + return -EINVAL; + // FIXME: Find the adapter some better way! + auto* adapter = NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); + if (!adapter) { + // FIXME: Figure out which error code to return. + ASSERT_NOT_REACHED(); + } + + if (addr->sa_family != AF_INET) { + kprintf("recvfrom: Bad address family: %u is not AF_INET!\n", addr->sa_family); + return -EAFNOSUPPORT; + } + + auto peer_address = IPv4Address((const byte*)&((const sockaddr_in*)addr)->sin_addr.s_addr); + kprintf("recvfrom: peer_address=%s\n", peer_address.to_string().characters()); + + ByteBuffer packet_buffer; + + { + LOCKER(m_lock); + if (!m_receive_queue.is_empty()) { + packet_buffer = m_receive_queue.take_first(); + m_can_read = m_receive_queue.is_empty(); + } + } + if (packet_buffer.is_null()) { + current->set_blocked_socket(this); + block(Process::BlockedReceive); + Scheduler::yield(); + + LOCKER(m_lock); + ASSERT(m_can_read); + ASSERT(!m_receive_queue.is_empty()); + packet_buffer = m_receive_queue.take_first(); + m_can_read = m_receive_queue.is_empty(); + } + ASSERT(!packet_buffer.is_null()); + auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer()); + ASSERT(buffer_length >= ipv4_packet.payload_size()); + memcpy(buffer, ipv4_packet.payload(), ipv4_packet.payload_size()); + return ipv4_packet.payload_size(); +} + +void IPv4Socket::did_receive(ByteBuffer&& packet) +{ + LOCKER(m_lock); + kprintf("IPv4Socket(%p): did_receive %d bytes\n", packet.size()); + m_receive_queue.append(move(packet)); + m_can_read = true; +} diff --git a/Kernel/IPv4Socket.h b/Kernel/IPv4Socket.h index 98d4bda5d8..416798e606 100644 --- a/Kernel/IPv4Socket.h +++ b/Kernel/IPv4Socket.h @@ -3,12 +3,16 @@ #include <Kernel/Socket.h> #include <Kernel/DoubleBuffer.h> #include <Kernel/IPv4.h> +#include <AK/Lock.h> +#include <AK/SinglyLinkedList.h> class IPv4Socket final : public Socket { public: static Retained<IPv4Socket> create(int type, int protocol); virtual ~IPv4Socket() override; + static Lockable<HashTable<IPv4Socket*>>& all_sockets(); + virtual KResult bind(const sockaddr*, socklen_t) override; virtual KResult connect(const sockaddr*, socklen_t) override; virtual bool get_address(sockaddr*, socklen_t*) override; @@ -19,6 +23,11 @@ public: virtual ssize_t write(SocketRole, const byte*, ssize_t) override; virtual bool can_write(SocketRole) const override; virtual ssize_t sendto(const void*, size_t, int, const sockaddr*, socklen_t) override; + virtual ssize_t recvfrom(void*, size_t, int flags, const sockaddr*, socklen_t) override; + + void did_receive(ByteBuffer&&); + + Lock& lock() { return m_lock; } private: IPv4Socket(int type, int protocol); @@ -30,5 +39,10 @@ private: DoubleBuffer m_for_client; DoubleBuffer m_for_server; + + SinglyLinkedList<ByteBuffer> m_receive_queue; + + Lock m_lock; + bool m_can_read { false }; }; diff --git a/Kernel/LocalSocket.cpp b/Kernel/LocalSocket.cpp index 5cd7f9b58d..ab21aabfa7 100644 --- a/Kernel/LocalSocket.cpp +++ b/Kernel/LocalSocket.cpp @@ -168,3 +168,8 @@ ssize_t LocalSocket::sendto(const void*, size_t, int, const sockaddr*, socklen_t { ASSERT_NOT_REACHED(); } + +ssize_t LocalSocket::recvfrom(void*, size_t, int flags, const sockaddr*, socklen_t) +{ + ASSERT_NOT_REACHED(); +} diff --git a/Kernel/LocalSocket.h b/Kernel/LocalSocket.h index b4972c8035..e817a1c197 100644 --- a/Kernel/LocalSocket.h +++ b/Kernel/LocalSocket.h @@ -20,6 +20,7 @@ public: virtual ssize_t write(SocketRole, const byte*, ssize_t) override; virtual bool can_write(SocketRole) const override; virtual ssize_t sendto(const void*, size_t, int, const sockaddr*, socklen_t) override; + virtual ssize_t recvfrom(void*, size_t, int flags, const sockaddr*, socklen_t) override; private: explicit LocalSocket(int type); diff --git a/Kernel/NetworkTask.cpp b/Kernel/NetworkTask.cpp index b8984d0a94..b579bd8dfe 100644 --- a/Kernel/NetworkTask.cpp +++ b/Kernel/NetworkTask.cpp @@ -3,6 +3,7 @@ #include <Kernel/ARP.h> #include <Kernel/ICMP.h> #include <Kernel/IPv4.h> +#include <Kernel/IPv4Socket.h> #include <Kernel/Process.h> #include <Kernel/EtherType.h> #include <AK/Lock.h> @@ -165,6 +166,16 @@ void handle_icmp(const EthernetFrameHeader& eth, int frame_size) ); #endif + { + LOCKER(IPv4Socket::all_sockets().lock()); + for (RetainPtr<IPv4Socket> socket : IPv4Socket::all_sockets().resource()) { + LOCKER(socket->lock()); + if (socket->protocol() != (unsigned)IPv4Protocol::ICMP) + continue; + socket->did_receive(ByteBuffer::copy((const byte*)&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())); + } + } + auto* adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination()); if (!adapter) return; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index 13600f47f7..7b50ca2750 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -2541,6 +2541,32 @@ ssize_t Process::sys$sendto(const Syscall::SC_sendto_params* params) return socket.sendto(data, data_length, flags, addr, addr_length); } +ssize_t Process::sys$recvfrom(const Syscall::SC_recvfrom_params* params) +{ + if (!validate_read_typed(params)) + return -EFAULT; + + int sockfd = params->sockfd; + void* buffer = params->buffer; + size_t buffer_length = params->buffer_length; + int flags = params->flags; + auto* addr = (const sockaddr*)params->addr; + auto addr_length = (socklen_t)params->addr_length; + + if (!validate_write(buffer, buffer_length)) + return -EFAULT; + if (!validate_read(addr, addr_length)) + return -EFAULT; + auto* descriptor = file_descriptor(sockfd); + if (!descriptor) + return -EBADF; + if (!descriptor->is_socket()) + return -ENOTSOCK; + auto& socket = *descriptor->socket(); + kprintf("recvfrom %p (%u), flags=%u, addr: %p (%u)\n", buffer, buffer_length, flags, addr, addr_length); + return socket.recvfrom(buffer, buffer_length, flags, addr, addr_length); +} + struct SharedBuffer { SharedBuffer(pid_t pid1, pid_t pid2, int size) : m_pid1(pid1) @@ -2780,6 +2806,7 @@ const char* to_string(Process::State state) case Process::BlockedSelect: return "Select"; case Process::BlockedLurking: return "Lurking"; case Process::BlockedConnect: return "Connect"; + case Process::BlockedReceive: return "Receive"; case Process::BeingInspected: return "Inspect"; } kprintf("to_string(Process::State): Invalid state: %u\n", state); diff --git a/Kernel/Process.h b/Kernel/Process.h index af58a4339a..ae39d25398 100644 --- a/Kernel/Process.h +++ b/Kernel/Process.h @@ -75,6 +75,7 @@ public: BlockedSignal, BlockedSelect, BlockedConnect, + BlockedReceive, }; enum Priority { @@ -231,6 +232,7 @@ public: int sys$accept(int sockfd, sockaddr*, socklen_t*); int sys$connect(int sockfd, const sockaddr*, socklen_t); ssize_t sys$sendto(const Syscall::SC_sendto_params*); + ssize_t sys$recvfrom(const Syscall::SC_recvfrom_params*); int sys$restore_signal_mask(dword mask); int sys$create_shared_buffer(pid_t peer_pid, int, void** buffer); @@ -307,6 +309,8 @@ public: Region* allocate_region(LinearAddress, size_t, String&& name, bool is_readable = true, bool is_writable = true, bool commit = true); bool deallocate_region(Region& region); + void set_blocked_socket(Socket* socket) { m_blocked_socket = socket; } + private: friend class MemoryManager; friend class Scheduler; @@ -364,6 +368,7 @@ private: dword m_times_scheduled { 0 }; pid_t m_waitee_pid { -1 }; int m_blocked_fd { -1 }; + Socket* m_blocked_socket { nullptr }; Vector<int> m_select_read_fds; Vector<int> m_select_write_fds; Vector<int> m_select_exceptional_fds; diff --git a/Kernel/Scheduler.cpp b/Kernel/Scheduler.cpp index c2c49ecaa9..b7b0b89d5c 100644 --- a/Kernel/Scheduler.cpp +++ b/Kernel/Scheduler.cpp @@ -98,6 +98,16 @@ bool Scheduler::pick_next() return true; } + if (process.state() == Process::BlockedReceive) { + ASSERT(process.m_blocked_socket); + // FIXME: Block until the amount of data wanted is available. + if (process.m_blocked_socket->can_read(SocketRole::None)) { + process.unblock(); + process.m_blocked_socket = nullptr; + } + return true; + } + if (process.state() == Process::BlockedSelect) { if (process.m_select_has_timeout) { auto now_sec = RTC::now(); diff --git a/Kernel/Socket.h b/Kernel/Socket.h index 04575f616d..e83d0aeb9f 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -36,6 +36,7 @@ public: virtual ssize_t write(SocketRole, const byte*, ssize_t) = 0; virtual bool can_write(SocketRole) const = 0; virtual ssize_t sendto(const void*, size_t, int flags, const sockaddr*, socklen_t) = 0; + virtual ssize_t recvfrom(void*, size_t, int flags, const sockaddr*, socklen_t) = 0; pid_t origin_pid() const { return m_origin_pid; } diff --git a/Kernel/Syscall.cpp b/Kernel/Syscall.cpp index 3e38b66df8..1b2059a3dc 100644 --- a/Kernel/Syscall.cpp +++ b/Kernel/Syscall.cpp @@ -229,6 +229,8 @@ static dword handle(RegisterDump& regs, dword function, dword arg1, dword arg2, return current->sys$get_shared_buffer_size((int)arg1); case Syscall::SC_sendto: return current->sys$sendto((const SC_sendto_params*)arg1); + case Syscall::SC_recvfrom: + return current->sys$recvfrom((const SC_recvfrom_params*)arg1); default: kprintf("<%u> int0x82: Unknown function %u requested {%x, %x, %x}\n", current->pid(), function, arg1, arg2, arg3); break; diff --git a/Kernel/Syscall.h b/Kernel/Syscall.h index 6cb3c9ac59..b3586034a4 100644 --- a/Kernel/Syscall.h +++ b/Kernel/Syscall.h @@ -89,6 +89,7 @@ __ENUMERATE_SYSCALL(get_shared_buffer_size) \ __ENUMERATE_SYSCALL(seal_shared_buffer) \ __ENUMERATE_SYSCALL(sendto) \ + __ENUMERATE_SYSCALL(recvfrom) \ namespace Syscall { @@ -138,6 +139,15 @@ struct SC_sendto_params { size_t addr_length; // socklen_t }; +struct SC_recvfrom_params { + int sockfd; + void* buffer; + size_t buffer_length; + int flags; + const void* addr; // const sockaddr* + size_t addr_length; // socklen_t +}; + void initialize(); int sync(); diff --git a/LibC/sys/socket.cpp b/LibC/sys/socket.cpp index 1a85ffff96..10b9fb67b6 100644 --- a/LibC/sys/socket.cpp +++ b/LibC/sys/socket.cpp @@ -41,5 +41,11 @@ ssize_t sendto(int sockfd, const void* data, size_t data_length, int flags, cons __RETURN_WITH_ERRNO(rc, rc, -1); } +ssize_t recvfrom(int sockfd, void* buffer, size_t buffer_length, int flags, const struct sockaddr* addr, socklen_t addr_length) +{ + Syscall::SC_recvfrom_params params { sockfd, buffer, buffer_length, flags, addr, addr_length }; + int rc = syscall(SC_recvfrom, ¶ms); + __RETURN_WITH_ERRNO(rc, rc, -1); } +} diff --git a/LibC/sys/socket.h b/LibC/sys/socket.h index 59407872e1..4721205a19 100644 --- a/LibC/sys/socket.h +++ b/LibC/sys/socket.h @@ -51,6 +51,7 @@ int listen(int sockfd, int backlog); int accept(int sockfd, sockaddr*, socklen_t*); int connect(int sockfd, const sockaddr*, socklen_t); ssize_t sendto(int sockfd, const void*, size_t, int flags, const struct sockaddr*, socklen_t); +ssize_t recvfrom(int sockfd, void*, size_t, int flags, const struct sockaddr*, socklen_t); __END_DECLS diff --git a/Userland/ping.cpp b/Userland/ping.cpp index 321cdff48c..b35258e1a4 100644 --- a/Userland/ping.cpp +++ b/Userland/ping.cpp @@ -3,6 +3,7 @@ #include <netinet/ip_icmp.h> #include <stdio.h> #include <string.h> +#include <unistd.h> #include <Kernel/NetworkOrdered.h> NetworkOrdered<word> internet_checksum(const void* ptr, size_t count) @@ -41,9 +42,13 @@ int main(int argc, char** argv) }; PingPacket ping_packet; + PingPacket pong_packet; memset(&ping_packet, 0, sizeof(PingPacket)); ping_packet.header.type = 8; // Echo request + ping_packet.header.code = 0; + ping_packet.header.un.echo.id = htons(getpid()); + ping_packet.header.un.echo.sequence = htons(1); strcpy(ping_packet.msg, "Hello there!\n"); ping_packet.header.checksum = htons(internet_checksum(&ping_packet, sizeof(PingPacket))); @@ -54,5 +59,13 @@ int main(int argc, char** argv) return 1; } + rc = recvfrom(fd, &pong_packet, sizeof(PingPacket), 0, (const struct sockaddr*)&peer_address, sizeof(sockaddr_in)); + if (rc < 0) { + perror("recvfrom"); + return 1; + } + + printf("received %p (%d)\n", &pong_packet, rc); + return 0; } |