diff options
-rw-r--r-- | Kernel/Net/IPv4Socket.cpp | 6 | ||||
-rw-r--r-- | Kernel/Net/IPv4Socket.h | 4 | ||||
-rw-r--r-- | Kernel/Net/LocalSocket.cpp | 2 | ||||
-rw-r--r-- | Kernel/Net/LocalSocket.h | 2 | ||||
-rw-r--r-- | Kernel/Net/NetworkTask.cpp | 1 | ||||
-rw-r--r-- | Kernel/Net/Socket.h | 3 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.cpp | 13 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.h | 2 | ||||
-rw-r--r-- | Kernel/Net/UDPSocket.cpp | 2 | ||||
-rw-r--r-- | Kernel/Net/UDPSocket.h | 2 | ||||
-rw-r--r-- | Kernel/Process.cpp | 2 |
11 files changed, 22 insertions, 17 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 7c7b1fdb4f..04eb2953b9 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -66,7 +66,7 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) ASSERT_NOT_REACHED(); } -KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size) +KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock should_block) { ASSERT(!m_bound); if (address_size != sizeof(sockaddr_in)) @@ -78,7 +78,7 @@ KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size) m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr); m_destination_port = ntohs(ia.sin_port); - return protocol_connect(); + return protocol_connect(should_block); } void IPv4Socket::attach_fd(SocketRole) @@ -110,7 +110,7 @@ ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size) bool IPv4Socket::can_write(SocketRole) const { - return true; + return is_connected(); } int IPv4Socket::allocate_source_port_if_needed() diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index d090a7730e..a7aad4a500 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -21,7 +21,7 @@ public: static Lockable<HashTable<IPv4Socket*>>& all_sockets(); virtual KResult bind(const sockaddr*, socklen_t) override; - virtual KResult connect(const sockaddr*, socklen_t) override; + virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; virtual bool get_address(sockaddr*, socklen_t*) override; virtual void attach_fd(SocketRole) override; virtual void detach_fd(SocketRole) override; @@ -49,7 +49,7 @@ protected: virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; } virtual int protocol_send(const void*, int) { return -ENOTIMPL; } - virtual KResult protocol_connect() { return KSuccess; } + virtual KResult protocol_connect(ShouldBlock) { return KSuccess; } virtual int protocol_allocate_source_port() { return 0; } virtual bool protocol_is_disconnected() const { return false; } diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index 906f9500fd..2754644ff5 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -65,7 +65,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size) return KSuccess; } -KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size) +KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock) { ASSERT(!m_bound); if (address_size != sizeof(sockaddr_un)) diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index dd161d36a0..9c5d7d6730 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -11,7 +11,7 @@ public: virtual ~LocalSocket() override; virtual KResult bind(const sockaddr*, socklen_t) override; - virtual KResult connect(const sockaddr*, socklen_t) override; + virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override; virtual bool get_address(sockaddr*, socklen_t*) override; virtual void attach_fd(SocketRole) override; virtual void detach_fd(SocketRole) override; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index 4aef31f6c3..4a238d47d2 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -342,6 +342,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK); socket->set_state(TCPSocket::State::Disconnecting); + socket->set_connected(false); return; } diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index 683ff53858..b0c946ab51 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -9,6 +9,7 @@ #include <Kernel/KResult.h> enum class SocketRole { None, Listener, Accepted, Connected, Connecting }; +enum class ShouldBlock { No = 0, Yes = 1 }; class Socket : public Retainable<Socket> { public: @@ -25,7 +26,7 @@ public: KResult listen(int backlog); virtual KResult bind(const sockaddr*, socklen_t) = 0; - virtual KResult connect(const sockaddr*, socklen_t) = 0; + virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock) = 0; virtual bool get_address(sockaddr*, socklen_t*) = 0; virtual bool is_local() const { return false; } virtual bool is_ipv4() const { return false; } diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 6356b8566d..373ee08979 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -152,7 +152,7 @@ NetworkOrdered<word> TCPSocket::compute_tcp_checksum(const IPv4Address& source, return ~(checksum & 0xffff); } -KResult TCPSocket::protocol_connect() +KResult TCPSocket::protocol_connect(ShouldBlock should_block) { auto* adapter = adapter_for_route_to(destination_address()); if (!adapter) @@ -166,11 +166,14 @@ KResult TCPSocket::protocol_connect() send_tcp_packet(TCPFlags::SYN); m_state = State::Connecting; - current->set_blocked_socket(this); - current->block(Thread::BlockedConnect); + if (should_block == ShouldBlock::Yes) { + current->set_blocked_socket(this); + current->block(Thread::BlockedConnect); + ASSERT(is_connected()); + return KSuccess; + } - ASSERT(is_connected()); - return KSuccess; + return KResult(-EINPROGRESS); } int TCPSocket::protocol_allocate_source_port() diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 5a1d88b2f7..079bb5d64a 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -34,7 +34,7 @@ private: virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; virtual int protocol_send(const void*, int) override; - virtual KResult protocol_connect() override; + virtual KResult protocol_connect(ShouldBlock) override; virtual int protocol_allocate_source_port() override; virtual bool protocol_is_disconnected() const override; diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 6c4170d3e6..96c9a320ea 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -81,7 +81,7 @@ int UDPSocket::protocol_send(const void* data, int data_length) return data_length; } -KResult UDPSocket::protocol_connect() +KResult UDPSocket::protocol_connect(ShouldBlock) { return KSuccess; } diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index 904243db34..8675217c11 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -17,7 +17,7 @@ private: virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; virtual int protocol_send(const void*, int) override; - virtual KResult protocol_connect() override; + virtual KResult protocol_connect(ShouldBlock) override; virtual int protocol_allocate_source_port() override; }; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index dca2c6ad85..9638f6399b 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -2038,7 +2038,7 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_ return -EISCONN; auto& socket = *descriptor->socket(); descriptor->set_socket_role(SocketRole::Connecting); - auto result = socket.connect(address, address_size); + auto result = socket.connect(address, address_size, descriptor->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No); if (result.is_error()) { descriptor->set_socket_role(SocketRole::None); return result; |