summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Kernel/Net/IPv4Socket.cpp6
-rw-r--r--Kernel/Net/IPv4Socket.h4
-rw-r--r--Kernel/Net/LocalSocket.cpp2
-rw-r--r--Kernel/Net/LocalSocket.h2
-rw-r--r--Kernel/Net/NetworkTask.cpp1
-rw-r--r--Kernel/Net/Socket.h3
-rw-r--r--Kernel/Net/TCPSocket.cpp13
-rw-r--r--Kernel/Net/TCPSocket.h2
-rw-r--r--Kernel/Net/UDPSocket.cpp2
-rw-r--r--Kernel/Net/UDPSocket.h2
-rw-r--r--Kernel/Process.cpp2
11 files changed, 22 insertions, 17 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp
index 7c7b1fdb4f..04eb2953b9 100644
--- a/Kernel/Net/IPv4Socket.cpp
+++ b/Kernel/Net/IPv4Socket.cpp
@@ -66,7 +66,7 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
ASSERT_NOT_REACHED();
}
-KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size)
+KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
{
ASSERT(!m_bound);
if (address_size != sizeof(sockaddr_in))
@@ -78,7 +78,7 @@ KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size)
m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
m_destination_port = ntohs(ia.sin_port);
- return protocol_connect();
+ return protocol_connect(should_block);
}
void IPv4Socket::attach_fd(SocketRole)
@@ -110,7 +110,7 @@ ssize_t IPv4Socket::write(SocketRole, const byte* data, ssize_t size)
bool IPv4Socket::can_write(SocketRole) const
{
- return true;
+ return is_connected();
}
int IPv4Socket::allocate_source_port_if_needed()
diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h
index d090a7730e..a7aad4a500 100644
--- a/Kernel/Net/IPv4Socket.h
+++ b/Kernel/Net/IPv4Socket.h
@@ -21,7 +21,7 @@ public:
static Lockable<HashTable<IPv4Socket*>>& all_sockets();
virtual KResult bind(const sockaddr*, socklen_t) override;
- virtual KResult connect(const sockaddr*, socklen_t) override;
+ virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
virtual bool get_address(sockaddr*, socklen_t*) override;
virtual void attach_fd(SocketRole) override;
virtual void detach_fd(SocketRole) override;
@@ -49,7 +49,7 @@ protected:
virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; }
virtual int protocol_send(const void*, int) { return -ENOTIMPL; }
- virtual KResult protocol_connect() { return KSuccess; }
+ virtual KResult protocol_connect(ShouldBlock) { return KSuccess; }
virtual int protocol_allocate_source_port() { return 0; }
virtual bool protocol_is_disconnected() const { return false; }
diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp
index 906f9500fd..2754644ff5 100644
--- a/Kernel/Net/LocalSocket.cpp
+++ b/Kernel/Net/LocalSocket.cpp
@@ -65,7 +65,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size)
return KSuccess;
}
-KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size)
+KResult LocalSocket::connect(const sockaddr* address, socklen_t address_size, ShouldBlock)
{
ASSERT(!m_bound);
if (address_size != sizeof(sockaddr_un))
diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h
index dd161d36a0..9c5d7d6730 100644
--- a/Kernel/Net/LocalSocket.h
+++ b/Kernel/Net/LocalSocket.h
@@ -11,7 +11,7 @@ public:
virtual ~LocalSocket() override;
virtual KResult bind(const sockaddr*, socklen_t) override;
- virtual KResult connect(const sockaddr*, socklen_t) override;
+ virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
virtual bool get_address(sockaddr*, socklen_t*) override;
virtual void attach_fd(SocketRole) override;
virtual void detach_fd(SocketRole) override;
diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp
index 4aef31f6c3..4a238d47d2 100644
--- a/Kernel/Net/NetworkTask.cpp
+++ b/Kernel/Net/NetworkTask.cpp
@@ -342,6 +342,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size)
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
socket->set_state(TCPSocket::State::Disconnecting);
+ socket->set_connected(false);
return;
}
diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h
index 683ff53858..b0c946ab51 100644
--- a/Kernel/Net/Socket.h
+++ b/Kernel/Net/Socket.h
@@ -9,6 +9,7 @@
#include <Kernel/KResult.h>
enum class SocketRole { None, Listener, Accepted, Connected, Connecting };
+enum class ShouldBlock { No = 0, Yes = 1 };
class Socket : public Retainable<Socket> {
public:
@@ -25,7 +26,7 @@ public:
KResult listen(int backlog);
virtual KResult bind(const sockaddr*, socklen_t) = 0;
- virtual KResult connect(const sockaddr*, socklen_t) = 0;
+ virtual KResult connect(const sockaddr*, socklen_t, ShouldBlock) = 0;
virtual bool get_address(sockaddr*, socklen_t*) = 0;
virtual bool is_local() const { return false; }
virtual bool is_ipv4() const { return false; }
diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp
index 6356b8566d..373ee08979 100644
--- a/Kernel/Net/TCPSocket.cpp
+++ b/Kernel/Net/TCPSocket.cpp
@@ -152,7 +152,7 @@ NetworkOrdered<word> TCPSocket::compute_tcp_checksum(const IPv4Address& source,
return ~(checksum & 0xffff);
}
-KResult TCPSocket::protocol_connect()
+KResult TCPSocket::protocol_connect(ShouldBlock should_block)
{
auto* adapter = adapter_for_route_to(destination_address());
if (!adapter)
@@ -166,11 +166,14 @@ KResult TCPSocket::protocol_connect()
send_tcp_packet(TCPFlags::SYN);
m_state = State::Connecting;
- current->set_blocked_socket(this);
- current->block(Thread::BlockedConnect);
+ if (should_block == ShouldBlock::Yes) {
+ current->set_blocked_socket(this);
+ current->block(Thread::BlockedConnect);
+ ASSERT(is_connected());
+ return KSuccess;
+ }
- ASSERT(is_connected());
- return KSuccess;
+ return KResult(-EINPROGRESS);
}
int TCPSocket::protocol_allocate_source_port()
diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h
index 5a1d88b2f7..079bb5d64a 100644
--- a/Kernel/Net/TCPSocket.h
+++ b/Kernel/Net/TCPSocket.h
@@ -34,7 +34,7 @@ private:
virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
virtual int protocol_send(const void*, int) override;
- virtual KResult protocol_connect() override;
+ virtual KResult protocol_connect(ShouldBlock) override;
virtual int protocol_allocate_source_port() override;
virtual bool protocol_is_disconnected() const override;
diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp
index 6c4170d3e6..96c9a320ea 100644
--- a/Kernel/Net/UDPSocket.cpp
+++ b/Kernel/Net/UDPSocket.cpp
@@ -81,7 +81,7 @@ int UDPSocket::protocol_send(const void* data, int data_length)
return data_length;
}
-KResult UDPSocket::protocol_connect()
+KResult UDPSocket::protocol_connect(ShouldBlock)
{
return KSuccess;
}
diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h
index 904243db34..8675217c11 100644
--- a/Kernel/Net/UDPSocket.h
+++ b/Kernel/Net/UDPSocket.h
@@ -17,7 +17,7 @@ private:
virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override;
virtual int protocol_send(const void*, int) override;
- virtual KResult protocol_connect() override;
+ virtual KResult protocol_connect(ShouldBlock) override;
virtual int protocol_allocate_source_port() override;
};
diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp
index dca2c6ad85..9638f6399b 100644
--- a/Kernel/Process.cpp
+++ b/Kernel/Process.cpp
@@ -2038,7 +2038,7 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_
return -EISCONN;
auto& socket = *descriptor->socket();
descriptor->set_socket_role(SocketRole::Connecting);
- auto result = socket.connect(address, address_size);
+ auto result = socket.connect(address, address_size, descriptor->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No);
if (result.is_error()) {
descriptor->set_socket_role(SocketRole::None);
return result;