summaryrefslogtreecommitdiff
path: root/Kernel/Net
diff options
context:
space:
mode:
authorTom <tomut@yahoo.com>2020-09-11 21:11:07 -0600
committerAndreas Kling <kling@serenityos.org>2020-09-13 21:19:15 +0200
commitc8d9f1b9c920e0314bb9ca67f183c8df743e845a (patch)
tree773e9fb30d26602598ba309291b8e6d2e80475e6 /Kernel/Net
parent7d1b8417bdf5d2818d1c9d310786cdf59650b104 (diff)
downloadserenity-c8d9f1b9c920e0314bb9ca67f183c8df743e845a.zip
Kernel: Make copy_to/from_user safe and remove unnecessary checks
Since the CPU already does almost all necessary validation steps for us, we don't really need to attempt to do this. Doing it ourselves doesn't really work very reliably, because we'd have to account for other processors modifying virtual memory, and we'd have to account for e.g. pages not being able to be allocated due to insufficient resources. So change the copy_to/from_user (and associated helper functions) to use the new safe_memcpy, which will return whether it succeeded or not. The only manual validation step needed (which the CPU can't perform for us) is making sure the pointers provided by user mode aren't pointing to kernel mappings. To make it easier to read/write from/to either kernel or user mode data add the UserOrKernelBuffer helper class, which will internally either use copy_from/to_user or directly memcpy, or pass the data through directly using a temporary buffer on the stack. Last but not least we need to keep syscall params trivial as we need to copy them from/to user mode using copy_from/to_user.
Diffstat (limited to 'Kernel/Net')
-rw-r--r--Kernel/Net/IPv4Socket.cpp122
-rw-r--r--Kernel/Net/IPv4Socket.h14
-rw-r--r--Kernel/Net/LocalSocket.cpp42
-rw-r--r--Kernel/Net/LocalSocket.h6
-rw-r--r--Kernel/Net/NetworkAdapter.cpp29
-rw-r--r--Kernel/Net/NetworkAdapter.h5
-rw-r--r--Kernel/Net/NetworkTask.cpp39
-rw-r--r--Kernel/Net/Socket.cpp41
-rw-r--r--Kernel/Net/Socket.h10
-rw-r--r--Kernel/Net/TCPSocket.cpp50
-rw-r--r--Kernel/Net/TCPSocket.h6
-rw-r--r--Kernel/Net/UDPSocket.cpp13
-rw-r--r--Kernel/Net/UDPSocket.h4
13 files changed, 225 insertions, 156 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp
index 7ddead3e15..c64c036c4d 100644
--- a/Kernel/Net/IPv4Socket.cpp
+++ b/Kernel/Net/IPv4Socket.cpp
@@ -105,7 +105,8 @@ KResult IPv4Socket::bind(Userspace<const sockaddr*> user_address, socklen_t addr
return KResult(-EINVAL);
sockaddr_in address;
- copy_from_user(&address, user_address, sizeof(sockaddr_in));
+ if (!copy_from_user(&address, user_address, sizeof(sockaddr_in)))
+ return KResult(-EFAULT);
if (address.sin_family != AF_INET)
return KResult(-EINVAL);
@@ -144,18 +145,25 @@ KResult IPv4Socket::listen(size_t backlog)
return protocol_listen();
}
-KResult IPv4Socket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock should_block)
+KResult IPv4Socket::connect(FileDescription& description, Userspace<const sockaddr*> address, socklen_t address_size, ShouldBlock should_block)
{
if (address_size != sizeof(sockaddr_in))
return KResult(-EINVAL);
- if (address->sa_family != AF_INET)
+ u16 sa_family_copy;
+ auto* user_address = reinterpret_cast<const sockaddr*>(address.unsafe_userspace_ptr());
+ if (!copy_from_user(&sa_family_copy, &user_address->sa_family, sizeof(u16)))
+ return KResult(-EFAULT);
+ if (sa_family_copy != AF_INET)
return KResult(-EINVAL);
if (m_role == Role::Connected)
return KResult(-EISCONN);
- auto& ia = *(const sockaddr_in*)address;
- m_peer_address = IPv4Address((const u8*)&ia.sin_addr.s_addr);
- m_peer_port = ntohs(ia.sin_port);
+ sockaddr_in safe_address;
+ if (!copy_from_user(&safe_address, (const sockaddr_in*)user_address, sizeof(sockaddr_in)))
+ return KResult(-EFAULT);
+
+ m_peer_address = IPv4Address((const u8*)&safe_address.sin_addr.s_addr);
+ m_peer_port = ntohs(safe_address.sin_port);
return protocol_connect(description, should_block);
}
@@ -193,7 +201,7 @@ int IPv4Socket::allocate_local_port_if_needed()
return port;
}
-KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const void* data, size_t data_length, int flags, Userspace<const sockaddr*> addr, socklen_t addr_length)
+KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer& data, size_t data_length, int flags, Userspace<const sockaddr*> addr, socklen_t addr_length)
{
(void)flags;
if (addr && addr_length != sizeof(sockaddr_in))
@@ -201,7 +209,7 @@ KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const void* data, size_t
if (addr) {
sockaddr_in ia;
- if (!Process::current()->validate_read_and_copy_typed(&ia, Userspace<const sockaddr_in*>(addr.ptr())))
+ if (!copy_from_user(&ia, Userspace<const sockaddr_in*>(addr.ptr())))
return KResult(-EFAULT);
if (ia.sin_family != AF_INET) {
@@ -229,7 +237,9 @@ KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const void* data, size_t
#endif
if (type() == SOCK_RAW) {
- routing_decision.adapter->send_ipv4(routing_decision.next_hop, m_peer_address, (IPv4Protocol)protocol(), { (const u8*)data, data_length }, m_ttl);
+ int err = routing_decision.adapter->send_ipv4(routing_decision.next_hop, m_peer_address, (IPv4Protocol)protocol(), data, data_length, m_ttl);
+ if (err < 0)
+ return KResult(err);
return data_length;
}
@@ -239,7 +249,7 @@ KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const void* data, size_t
return nsent_or_error;
}
-KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description, void* buffer, size_t buffer_length, int, Userspace<sockaddr*>, Userspace<socklen_t*>)
+KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int, Userspace<sockaddr*>, Userspace<socklen_t*>)
{
Locker locker(lock());
if (m_receive_buffer.is_empty()) {
@@ -262,7 +272,7 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description
}
ASSERT(!m_receive_buffer.is_empty());
- int nreceived = m_receive_buffer.read((u8*)buffer, buffer_length);
+ int nreceived = m_receive_buffer.read(buffer, buffer_length);
if (nreceived > 0)
Thread::current()->did_ipv4_socket_read((size_t)nreceived);
@@ -270,7 +280,7 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description
return nreceived;
}
-KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& description, void* buffer, size_t buffer_length, int flags, Userspace<sockaddr*> addr, Userspace<socklen_t*> addr_length)
+KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> addr, Userspace<socklen_t*> addr_length)
{
Locker locker(lock());
ReceivedPacket packet;
@@ -330,27 +340,30 @@ KResultOr<size_t> IPv4Socket::receive_packet_buffered(FileDescription& descripti
out_addr.sin_port = htons(packet.peer_port);
out_addr.sin_family = AF_INET;
Userspace<sockaddr_in*> dest_addr = addr.ptr();
- copy_to_user(dest_addr, &out_addr);
+ if (!copy_to_user(dest_addr, &out_addr))
+ return KResult(-EFAULT);
socklen_t out_length = sizeof(sockaddr_in);
ASSERT(addr_length);
- copy_to_user(addr_length, &out_length);
+ if (!copy_to_user(addr_length, &out_length))
+ return KResult(-EFAULT);
}
if (type() == SOCK_RAW) {
size_t bytes_written = min((size_t) ipv4_packet.payload_size(), buffer_length);
- memcpy(buffer, ipv4_packet.payload(), bytes_written);
+ if (!buffer.write(ipv4_packet.payload(), bytes_written))
+ return KResult(-EFAULT);
return bytes_written;
}
return protocol_receive(packet.data.value(), buffer, buffer_length, flags);
}
-KResultOr<size_t> IPv4Socket::recvfrom(FileDescription& description, void* buffer, size_t buffer_length, int flags, Userspace<sockaddr*> user_addr, Userspace<socklen_t*> user_addr_length)
+KResultOr<size_t> IPv4Socket::recvfrom(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*> user_addr, Userspace<socklen_t*> user_addr_length)
{
if (user_addr_length) {
socklen_t addr_length;
- if (!Process::current()->validate_read_and_copy_typed(&addr_length, user_addr_length))
+ if (!copy_from_user(&addr_length, user_addr_length.unsafe_userspace_ptr()))
return KResult(-EFAULT);
if (addr_length < sizeof(sockaddr_in))
return KResult(-EINVAL);
@@ -387,10 +400,13 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port,
ASSERT(m_can_read);
return false;
}
- auto nreceived_or_error = protocol_receive(packet, m_scratch_buffer.value().data(), m_scratch_buffer.value().size(), 0);
+ auto scratch_buffer = UserOrKernelBuffer::for_kernel_buffer(m_scratch_buffer.value().data());
+ auto nreceived_or_error = protocol_receive(packet, scratch_buffer, m_scratch_buffer.value().size(), 0);
if (nreceived_or_error.is_error())
return false;
- m_receive_buffer.write(m_scratch_buffer.value().data(), nreceived_or_error.value());
+ ssize_t nwritten = m_receive_buffer.write(scratch_buffer, nreceived_or_error.value());
+ if (nwritten < 0)
+ return false;
m_can_read = !m_receive_buffer.is_empty();
} else {
if (m_receive_queue.size() > 2000) {
@@ -452,7 +468,7 @@ KResult IPv4Socket::setsockopt(int level, int option, Userspace<const void*> use
if (user_value_size < sizeof(int))
return KResult(-EINVAL);
int value;
- if (!Process::current()->validate_read_and_copy_typed(&value, static_ptr_cast<const int*>(user_value)))
+ if (!copy_from_user(&value, static_ptr_cast<const int*>(user_value)))
return KResult(-EFAULT);
if (value < 0 || value > 255)
return KResult(-EINVAL);
@@ -470,16 +486,18 @@ KResult IPv4Socket::getsockopt(FileDescription& description, int level, int opti
return Socket::getsockopt(description, level, option, value, value_size);
socklen_t size;
- if (!Process::current()->validate_read_and_copy_typed(&size, value_size))
+ if (!copy_from_user(&size, value_size.unsafe_userspace_ptr()))
return KResult(-EFAULT);
switch (option) {
case IP_TTL:
if (size < sizeof(int))
return KResult(-EINVAL);
- copy_to_user(static_ptr_cast<int*>(value), (int*)&m_ttl);
+ if (!copy_to_user(static_ptr_cast<int*>(value), (int*)&m_ttl))
+ return KResult(-EFAULT);
size = sizeof(int);
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
default:
return KResult(-ENOPROTOOPT);
@@ -493,15 +511,15 @@ int IPv4Socket::ioctl(FileDescription&, unsigned request, FlatPtr arg)
SmapDisabler disabler;
auto ioctl_route = [request, arg]() {
- auto* route = (rtentry*)arg;
- if (!Process::current()->validate_read_typed(route))
+ rtentry route;
+ if (!copy_from_user(&route, (rtentry*)arg))
return -EFAULT;
- char namebuf[IFNAMSIZ + 1];
- memcpy(namebuf, route->rt_dev, IFNAMSIZ);
- namebuf[sizeof(namebuf) - 1] = '\0';
+ auto copied_ifname = copy_string_from_user(route.rt_dev, IFNAMSIZ);
+ if (copied_ifname.is_null())
+ return -EFAULT;
- auto adapter = NetworkAdapter::lookup_by_name(namebuf);
+ auto adapter = NetworkAdapter::lookup_by_name(copied_ifname);
if (!adapter)
return -ENODEV;
@@ -509,11 +527,11 @@ int IPv4Socket::ioctl(FileDescription&, unsigned request, FlatPtr arg)
case SIOCADDRT:
if (!Process::current()->is_superuser())
return -EPERM;
- if (route->rt_gateway.sa_family != AF_INET)
+ if (route.rt_gateway.sa_family != AF_INET)
return -EAFNOSUPPORT;
- if ((route->rt_flags & (RTF_UP | RTF_GATEWAY)) != (RTF_UP | RTF_GATEWAY))
+ if ((route.rt_flags & (RTF_UP | RTF_GATEWAY)) != (RTF_UP | RTF_GATEWAY))
return -EINVAL; // FIXME: Find the correct value to return
- adapter->set_ipv4_gateway(IPv4Address(((sockaddr_in&)route->rt_gateway).sin_addr.s_addr));
+ adapter->set_ipv4_gateway(IPv4Address(((sockaddr_in&)route.rt_gateway).sin_addr.s_addr));
return 0;
case SIOCDELRT:
@@ -525,12 +543,13 @@ int IPv4Socket::ioctl(FileDescription&, unsigned request, FlatPtr arg)
};
auto ioctl_interface = [request, arg]() {
- auto* ifr = (ifreq*)arg;
- if (!Process::current()->validate_read_typed(ifr))
+ ifreq* user_ifr = (ifreq*)arg;
+ ifreq ifr;
+ if (!copy_from_user(&ifr, user_ifr))
return -EFAULT;
char namebuf[IFNAMSIZ + 1];
- memcpy(namebuf, ifr->ifr_name, IFNAMSIZ);
+ memcpy(namebuf, ifr.ifr_name, IFNAMSIZ);
namebuf[sizeof(namebuf) - 1] = '\0';
auto adapter = NetworkAdapter::lookup_by_name(namebuf);
@@ -541,36 +560,39 @@ int IPv4Socket::ioctl(FileDescription&, unsigned request, FlatPtr arg)
case SIOCSIFADDR:
if (!Process::current()->is_superuser())
return -EPERM;
- if (ifr->ifr_addr.sa_family != AF_INET)
+ if (ifr.ifr_addr.sa_family != AF_INET)
return -EAFNOSUPPORT;
- adapter->set_ipv4_address(IPv4Address(((sockaddr_in&)ifr->ifr_addr).sin_addr.s_addr));
+ adapter->set_ipv4_address(IPv4Address(((sockaddr_in&)ifr.ifr_addr).sin_addr.s_addr));
return 0;
case SIOCSIFNETMASK:
if (!Process::current()->is_superuser())
return -EPERM;
- if (ifr->ifr_addr.sa_family != AF_INET)
+ if (ifr.ifr_addr.sa_family != AF_INET)
return -EAFNOSUPPORT;
- adapter->set_ipv4_netmask(IPv4Address(((sockaddr_in&)ifr->ifr_netmask).sin_addr.s_addr));
+ adapter->set_ipv4_netmask(IPv4Address(((sockaddr_in&)ifr.ifr_netmask).sin_addr.s_addr));
return 0;
- case SIOCGIFADDR:
- if (!Process::current()->validate_write_typed(ifr))
+ case SIOCGIFADDR: {
+ u16 sa_family = AF_INET;
+ if (!copy_to_user(&user_ifr->ifr_addr.sa_family, &sa_family))
+ return -EFAULT;
+ auto ip4_addr = adapter->ipv4_address().to_u32();
+ if (!copy_to_user(&((sockaddr_in&)user_ifr->ifr_addr).sin_addr.s_addr, &ip4_addr, sizeof(ip4_addr)))
return -EFAULT;
- ifr->ifr_addr.sa_family = AF_INET;
- ((sockaddr_in&)ifr->ifr_addr).sin_addr.s_addr = adapter->ipv4_address().to_u32();
return 0;
+ }
- case SIOCGIFHWADDR:
- if (!Process::current()->validate_write_typed(ifr))
+ case SIOCGIFHWADDR: {
+ u16 sa_family = AF_INET;
+ if (!copy_to_user(&user_ifr->ifr_hwaddr.sa_family, &sa_family))
+ return -EFAULT;
+ auto mac_address = adapter->mac_address();
+ if (!copy_to_user(ifr.ifr_hwaddr.sa_data, &mac_address, sizeof(MACAddress)))
return -EFAULT;
- ifr->ifr_hwaddr.sa_family = AF_INET;
- {
- auto mac_address = adapter->mac_address();
- memcpy(ifr->ifr_hwaddr.sa_data, &mac_address, sizeof(MACAddress));
- }
return 0;
}
+ }
return -EINVAL;
};
diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h
index cb513247ac..675cfbfcd1 100644
--- a/Kernel/Net/IPv4Socket.h
+++ b/Kernel/Net/IPv4Socket.h
@@ -50,7 +50,7 @@ public:
virtual KResult close() override;
virtual KResult bind(Userspace<const sockaddr*>, socklen_t) override;
- virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
+ virtual KResult connect(FileDescription&, Userspace<const sockaddr*>, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
virtual KResult listen(size_t) override;
virtual void get_local_address(sockaddr*, socklen_t*) override;
virtual void get_peer_address(sockaddr*, socklen_t*) override;
@@ -58,8 +58,8 @@ public:
virtual void detach(FileDescription&) override;
virtual bool can_read(const FileDescription&, size_t) const override;
virtual bool can_write(const FileDescription&, size_t) const override;
- virtual KResultOr<size_t> sendto(FileDescription&, const void*, size_t, int, Userspace<const sockaddr*>, socklen_t) override;
- virtual KResultOr<size_t> recvfrom(FileDescription&, void*, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>) override;
+ virtual KResultOr<size_t> sendto(FileDescription&, const UserOrKernelBuffer&, size_t, int, Userspace<const sockaddr*>, socklen_t) override;
+ virtual KResultOr<size_t> recvfrom(FileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>) override;
virtual KResult setsockopt(int level, int option, Userspace<const void*>, socklen_t) override;
virtual KResult getsockopt(FileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>) override;
@@ -96,8 +96,8 @@ protected:
virtual KResult protocol_bind() { return KSuccess; }
virtual KResult protocol_listen() { return KSuccess; }
- virtual KResultOr<size_t> protocol_receive(const KBuffer&, void*, size_t, int) { return -ENOTIMPL; }
- virtual KResultOr<size_t> protocol_send(const void*, size_t) { return -ENOTIMPL; }
+ virtual KResultOr<size_t> protocol_receive(const KBuffer&, UserOrKernelBuffer&, size_t, int) { return -ENOTIMPL; }
+ virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) { return -ENOTIMPL; }
virtual KResult protocol_connect(FileDescription&, ShouldBlock) { return KSuccess; }
virtual int protocol_allocate_local_port() { return 0; }
virtual bool protocol_is_disconnected() const { return false; }
@@ -110,8 +110,8 @@ protected:
private:
virtual bool is_ipv4() const override { return true; }
- KResultOr<size_t> receive_byte_buffered(FileDescription&, void* buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>);
- KResultOr<size_t> receive_packet_buffered(FileDescription&, void* buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>);
+ KResultOr<size_t> receive_byte_buffered(FileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>);
+ KResultOr<size_t> receive_packet_buffered(FileDescription&, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>);
IPv4Address m_local_address;
IPv4Address m_peer_address;
diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp
index a401af1e9f..ee08ca0c26 100644
--- a/Kernel/Net/LocalSocket.cpp
+++ b/Kernel/Net/LocalSocket.cpp
@@ -98,7 +98,8 @@ KResult LocalSocket::bind(Userspace<const sockaddr*> user_address, socklen_t add
return KResult(-EINVAL);
sockaddr_un address;
- copy_from_user(&address, user_address, sizeof(sockaddr_un));
+ if (!copy_from_user(&address, user_address, sizeof(sockaddr_un)))
+ return KResult(-EFAULT);
if (address.sun_family != AF_LOCAL)
return KResult(-EINVAL);
@@ -131,19 +132,25 @@ KResult LocalSocket::bind(Userspace<const sockaddr*> user_address, socklen_t add
return KSuccess;
}
-KResult LocalSocket::connect(FileDescription& description, const sockaddr* address, socklen_t address_size, ShouldBlock)
+KResult LocalSocket::connect(FileDescription& description, Userspace<const sockaddr*> address, socklen_t address_size, ShouldBlock)
{
ASSERT(!m_bound);
if (address_size != sizeof(sockaddr_un))
return KResult(-EINVAL);
- if (address->sa_family != AF_LOCAL)
+ u16 sa_family_copy;
+ auto* user_address = reinterpret_cast<const sockaddr*>(address.unsafe_userspace_ptr());
+ if (!copy_from_user(&sa_family_copy, &user_address->sa_family, sizeof(u16)))
+ return KResult(-EFAULT);
+ if (sa_family_copy != AF_LOCAL)
return KResult(-EINVAL);
if (is_connected())
return KResult(-EISCONN);
- const sockaddr_un& local_address = *reinterpret_cast<const sockaddr_un*>(address);
+ const auto& local_address = *reinterpret_cast<const sockaddr_un*>(user_address);
char safe_address[sizeof(local_address.sun_path) + 1] = { 0 };
- memcpy(safe_address, local_address.sun_path, sizeof(local_address.sun_path));
+ if (!copy_from_user(&safe_address[0], &local_address.sun_path[0], sizeof(safe_address) - 1))
+ return KResult(-EFAULT);
+ safe_address[sizeof(safe_address) - 1] = '\0';
#ifdef DEBUG_LOCAL_SOCKET
dbg() << "LocalSocket{" << this << "} connect(" << safe_address << ")";
@@ -159,7 +166,8 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre
if (!m_file->inode()->socket())
return KResult(-ECONNREFUSED);
- m_address = local_address;
+ m_address.sun_family = sa_family_copy;
+ memcpy(m_address.sun_path, safe_address, sizeof(m_address.sun_path));
ASSERT(m_connect_side_fd == &description);
m_connect_side_role = Role::Connecting;
@@ -260,11 +268,11 @@ bool LocalSocket::can_write(const FileDescription& description, size_t) const
return false;
}
-KResultOr<size_t> LocalSocket::sendto(FileDescription& description, const void* data, size_t data_size, int, Userspace<const sockaddr*>, socklen_t)
+KResultOr<size_t> LocalSocket::sendto(FileDescription& description, const UserOrKernelBuffer& data, size_t data_size, int, Userspace<const sockaddr*>, socklen_t)
{
if (!has_attached_peer(description))
return KResult(-EPIPE);
- ssize_t nwritten = send_buffer_for(description).write((const u8*)data, data_size);
+ ssize_t nwritten = send_buffer_for(description).write(data, data_size);
if (nwritten > 0)
Thread::current()->did_unix_socket_write(nwritten);
return nwritten;
@@ -290,7 +298,7 @@ DoubleBuffer& LocalSocket::send_buffer_for(FileDescription& description)
ASSERT_NOT_REACHED();
}
-KResultOr<size_t> LocalSocket::recvfrom(FileDescription& description, void* buffer, size_t buffer_size, int, Userspace<sockaddr*>, Userspace<socklen_t*>)
+KResultOr<size_t> LocalSocket::recvfrom(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_size, int, Userspace<sockaddr*>, Userspace<socklen_t*>)
{
auto& buffer_for_me = receive_buffer_for(description);
if (!description.is_blocking()) {
@@ -306,7 +314,7 @@ KResultOr<size_t> LocalSocket::recvfrom(FileDescription& description, void* buff
if (!has_attached_peer(description) && buffer_for_me.is_empty())
return 0;
ASSERT(!buffer_for_me.is_empty());
- int nread = buffer_for_me.read((u8*)buffer, buffer_size);
+ int nread = buffer_for_me.read(buffer, buffer_size);
if (nread > 0)
Thread::current()->did_unix_socket_read(nread);
return nread;
@@ -350,7 +358,7 @@ KResult LocalSocket::getsockopt(FileDescription& description, int level, int opt
return Socket::getsockopt(description, level, option, value, value_size);
socklen_t size;
- if (!Process::current()->validate_read_and_copy_typed(&size, value_size))
+ if (!copy_from_user(&size, value_size.unsafe_userspace_ptr()))
return KResult(-EFAULT);
switch (option) {
@@ -359,14 +367,18 @@ KResult LocalSocket::getsockopt(FileDescription& description, int level, int opt
return KResult(-EINVAL);
switch (role(description)) {
case Role::Accepted:
- copy_to_user(static_ptr_cast<ucred*>(value), &m_origin);
+ if (!copy_to_user(static_ptr_cast<ucred*>(value), &m_origin))
+ return KResult(-EFAULT);
size = sizeof(ucred);
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
case Role::Connected:
- copy_to_user(static_ptr_cast<ucred*>(value), &m_acceptor);
+ if (!copy_to_user(static_ptr_cast<ucred*>(value), &m_acceptor))
+ return KResult(-EFAULT);
size = sizeof(ucred);
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
case Role::Connecting:
return KResult(-ENOTCONN);
diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h
index cafdf599a4..d39e78d20b 100644
--- a/Kernel/Net/LocalSocket.h
+++ b/Kernel/Net/LocalSocket.h
@@ -52,7 +52,7 @@ public:
// ^Socket
virtual KResult bind(Userspace<const sockaddr*>, socklen_t) override;
- virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
+ virtual KResult connect(FileDescription&, Userspace<const sockaddr*>, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
virtual KResult listen(size_t) override;
virtual void get_local_address(sockaddr*, socklen_t*) override;
virtual void get_peer_address(sockaddr*, socklen_t*) override;
@@ -60,8 +60,8 @@ public:
virtual void detach(FileDescription&) override;
virtual bool can_read(const FileDescription&, size_t) const override;
virtual bool can_write(const FileDescription&, size_t) const override;
- virtual KResultOr<size_t> sendto(FileDescription&, const void*, size_t, int, Userspace<const sockaddr*>, socklen_t) override;
- virtual KResultOr<size_t> recvfrom(FileDescription&, void*, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>) override;
+ virtual KResultOr<size_t> sendto(FileDescription&, const UserOrKernelBuffer&, size_t, int, Userspace<const sockaddr*>, socklen_t) override;
+ virtual KResultOr<size_t> recvfrom(FileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>) override;
virtual KResult getsockopt(FileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>) override;
virtual KResult chown(FileDescription&, uid_t, gid_t) override;
virtual KResult chmod(FileDescription&, mode_t) override;
diff --git a/Kernel/Net/NetworkAdapter.cpp b/Kernel/Net/NetworkAdapter.cpp
index 4cfdd1cd43..4a81fa77af 100644
--- a/Kernel/Net/NetworkAdapter.cpp
+++ b/Kernel/Net/NetworkAdapter.cpp
@@ -100,15 +100,13 @@ void NetworkAdapter::send(const MACAddress& destination, const ARPPacket& packet
send_raw({ (const u8*)eth, size_in_bytes });
}
-void NetworkAdapter::send_ipv4(const MACAddress& destination_mac, const IPv4Address& destination_ipv4, IPv4Protocol protocol, ReadonlyBytes payload, u8 ttl)
+int NetworkAdapter::send_ipv4(const MACAddress& destination_mac, const IPv4Address& destination_ipv4, IPv4Protocol protocol, const UserOrKernelBuffer& payload, size_t payload_size, u8 ttl)
{
- size_t ipv4_packet_size = sizeof(IPv4Packet) + payload.size();
- if (ipv4_packet_size > mtu()) {
- send_ipv4_fragmented(destination_mac, destination_ipv4, protocol, payload, ttl);
- return;
- }
+ size_t ipv4_packet_size = sizeof(IPv4Packet) + payload_size;
+ if (ipv4_packet_size > mtu())
+ return send_ipv4_fragmented(destination_mac, destination_ipv4, protocol, payload, payload_size, ttl);
- size_t ethernet_frame_size = sizeof(EthernetFrameHeader) + sizeof(IPv4Packet) + payload.size();
+ size_t ethernet_frame_size = sizeof(EthernetFrameHeader) + sizeof(IPv4Packet) + payload_size;
auto buffer = ByteBuffer::create_zeroed(ethernet_frame_size);
auto& eth = *(EthernetFrameHeader*)buffer.data();
eth.set_source(mac_address());
@@ -120,22 +118,25 @@ void NetworkAdapter::send_ipv4(const MACAddress& destination_mac, const IPv4Addr
ipv4.set_source(ipv4_address());
ipv4.set_destination(destination_ipv4);
ipv4.set_protocol((u8)protocol);
- ipv4.set_length(sizeof(IPv4Packet) + payload.size());
+ ipv4.set_length(sizeof(IPv4Packet) + payload_size);
ipv4.set_ident(1);
ipv4.set_ttl(ttl);
ipv4.set_checksum(ipv4.compute_checksum());
m_packets_out++;
m_bytes_out += ethernet_frame_size;
- memcpy(ipv4.payload(), payload.data(), payload.size());
+
+ if (!payload.read(ipv4.payload(), payload_size))
+ return -EFAULT;
send_raw({ (const u8*)&eth, ethernet_frame_size });
+ return 0;
}
-void NetworkAdapter::send_ipv4_fragmented(const MACAddress& destination_mac, const IPv4Address& destination_ipv4, IPv4Protocol protocol, ReadonlyBytes payload, u8 ttl)
+int NetworkAdapter::send_ipv4_fragmented(const MACAddress& destination_mac, const IPv4Address& destination_ipv4, IPv4Protocol protocol, const UserOrKernelBuffer& payload, size_t payload_size, u8 ttl)
{
// packets must be split on the 64-bit boundary
auto packet_boundary_size = (mtu() - sizeof(IPv4Packet) - sizeof(EthernetFrameHeader)) & 0xfffffff8;
- auto fragment_block_count = (payload.size() + packet_boundary_size) / packet_boundary_size;
- auto last_block_size = payload.size() - packet_boundary_size * (fragment_block_count - 1);
+ auto fragment_block_count = (payload_size + packet_boundary_size) / packet_boundary_size;
+ auto last_block_size = payload_size - packet_boundary_size * (fragment_block_count - 1);
auto number_of_blocks_in_fragment = packet_boundary_size / 8;
auto identification = get_good_random<u16>();
@@ -163,9 +164,11 @@ void NetworkAdapter::send_ipv4_fragmented(const MACAddress& destination_mac, con
ipv4.set_checksum(ipv4.compute_checksum());
m_packets_out++;
m_bytes_out += ethernet_frame_size;
- memcpy(ipv4.payload(), payload.data() + packet_index * packet_boundary_size, packet_payload_size);
+ if (!payload.read(ipv4.payload(), packet_index * packet_boundary_size, packet_payload_size))
+ return -EFAULT;
send_raw({ (const u8*)&eth, ethernet_frame_size });
}
+ return 0;
}
void NetworkAdapter::did_receive(ReadonlyBytes payload)
diff --git a/Kernel/Net/NetworkAdapter.h b/Kernel/Net/NetworkAdapter.h
index 7fc178bf7d..619b25f281 100644
--- a/Kernel/Net/NetworkAdapter.h
+++ b/Kernel/Net/NetworkAdapter.h
@@ -37,6 +37,7 @@
#include <Kernel/Net/ARP.h>
#include <Kernel/Net/ICMP.h>
#include <Kernel/Net/IPv4.h>
+#include <Kernel/UserOrKernelBuffer.h>
namespace Kernel {
@@ -63,8 +64,8 @@ public:
void set_ipv4_gateway(const IPv4Address&);
void send(const MACAddress&, const ARPPacket&);
- void send_ipv4(const MACAddress&, const IPv4Address&, IPv4Protocol, ReadonlyBytes payload, u8 ttl);
- void send_ipv4_fragmented(const MACAddress&, const IPv4Address&, IPv4Protocol, ReadonlyBytes payload, u8 ttl);
+ int send_ipv4(const MACAddress&, const IPv4Address&, IPv4Protocol, const UserOrKernelBuffer& payload, size_t payload_size, u8 ttl);
+ int send_ipv4_fragmented(const MACAddress&, const IPv4Address&, IPv4Protocol, const UserOrKernelBuffer& payload, size_t payload_size, u8 ttl);
size_t dequeue_packet(u8* buffer, size_t buffer_size);
diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp
index f198bc0607..0098d686fd 100644
--- a/Kernel/Net/NetworkTask.cpp
+++ b/Kernel/Net/NetworkTask.cpp
@@ -285,7 +285,8 @@ void handle_icmp(const EthernetFrameHeader& eth, const IPv4Packet& ipv4_packet)
memcpy(response.payload(), request.payload(), icmp_payload_size);
response.header.set_checksum(internet_checksum(&response, icmp_packet_size));
// FIXME: What is the right TTL value here? Is 64 ok? Should we use the same TTL as the echo request?
- adapter->send_ipv4(eth.source(), ipv4_packet.source(), IPv4Protocol::ICMP, buffer, 64);
+ auto response_buffer = UserOrKernelBuffer::for_kernel_buffer((u8*)&response);
+ adapter->send_ipv4(eth.source(), ipv4_packet.source(), IPv4Protocol::ICMP, response_buffer, buffer.size(), 64);
}
}
@@ -379,7 +380,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
case TCPSocket::State::TimeWait:
klog() << "handle_tcp: unexpected flags in TimeWait state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
case TCPSocket::State::Listen:
@@ -400,46 +401,46 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
#endif
client->set_sequence_number(1000);
client->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
- client->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK);
+ (void)client->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK);
client->set_state(TCPSocket::State::SynReceived);
return;
}
default:
klog() << "handle_tcp: unexpected flags in Listen state";
- // socket->send_tcp_packet(TCPFlags::RST);
+ // (void)socket->send_tcp_packet(TCPFlags::RST);
return;
}
case TCPSocket::State::SynSent:
switch (tcp_packet.flags()) {
case TCPFlags::SYN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
- socket->send_tcp_packet(TCPFlags::ACK);
+ (void)socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::SynReceived);
return;
case TCPFlags::ACK | TCPFlags::SYN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
- socket->send_tcp_packet(TCPFlags::ACK);
+ (void)socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::Established);
socket->set_setup_state(Socket::SetupState::Completed);
socket->set_connected(true);
return;
case TCPFlags::ACK | TCPFlags::FIN:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
- socket->send_tcp_packet(TCPFlags::ACK);
+ (void)socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::Closed);
socket->set_error(TCPSocket::Error::FINDuringConnect);
socket->set_setup_state(Socket::SetupState::Completed);
return;
case TCPFlags::ACK | TCPFlags::RST:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
- socket->send_tcp_packet(TCPFlags::ACK);
+ (void)socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::Closed);
socket->set_error(TCPSocket::Error::RSTDuringConnect);
socket->set_setup_state(Socket::SetupState::Completed);
return;
default:
klog() << "handle_tcp: unexpected flags in SynSent state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
socket->set_error(TCPSocket::Error::UnexpectedFlagsDuringConnect);
socket->set_setup_state(Socket::SetupState::Completed);
@@ -454,7 +455,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
case TCPSocket::Direction::Incoming:
if (!socket->has_originator()) {
klog() << "handle_tcp: connection doesn't have an originating socket; maybe it went away?";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -470,7 +471,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
default:
klog() << "handle_tcp: got ACK in SynReceived state but direction is invalid (" << TCPSocket::to_string(socket->direction()) << ")";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -478,7 +479,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
default:
klog() << "handle_tcp: unexpected flags in SynReceived state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -486,7 +487,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
switch (tcp_packet.flags()) {
default:
klog() << "handle_tcp: unexpected flags in CloseWait state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -498,7 +499,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
default:
klog() << "handle_tcp: unexpected flags in LastAck state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -514,7 +515,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
default:
klog() << "handle_tcp: unexpected flags in FinWait1 state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -529,7 +530,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
default:
klog() << "handle_tcp: unexpected flags in FinWait2 state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -541,7 +542,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
return;
default:
klog() << "handle_tcp: unexpected flags in Closing state";
- socket->send_tcp_packet(TCPFlags::RST);
+ (void)socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed);
return;
}
@@ -551,7 +552,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
- socket->send_tcp_packet(TCPFlags::ACK);
+ (void)socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::CloseWait);
socket->set_connected(false);
return;
@@ -565,7 +566,7 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
if (payload_size) {
if (socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size())))
- socket->send_tcp_packet(TCPFlags::ACK);
+ (void)socket->send_tcp_packet(TCPFlags::ACK);
}
}
}
diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp
index 061068e3ec..1caa3c5328 100644
--- a/Kernel/Net/Socket.cpp
+++ b/Kernel/Net/Socket.cpp
@@ -108,18 +108,20 @@ KResult Socket::setsockopt(int level, int option, Userspace<const void*> user_va
case SO_SNDTIMEO:
if (user_value_size != sizeof(timeval))
return KResult(-EINVAL);
- copy_from_user(&m_send_timeout, static_ptr_cast<const timeval*>(user_value));
+ if (!copy_from_user(&m_send_timeout, static_ptr_cast<const timeval*>(user_value)))
+ return KResult(-EFAULT);
return KSuccess;
case SO_RCVTIMEO:
if (user_value_size != sizeof(timeval))
return KResult(-EINVAL);
- copy_from_user(&m_receive_timeout, static_ptr_cast<const timeval*>(user_value));
+ if (!copy_from_user(&m_receive_timeout, static_ptr_cast<const timeval*>(user_value)))
+ return KResult(-EFAULT);
return KSuccess;
case SO_BINDTODEVICE: {
if (user_value_size != IFNAMSIZ)
return KResult(-EINVAL);
auto user_string = static_ptr_cast<const char*>(user_value);
- auto ifname = Process::current()->validate_and_copy_string_from_user(user_string, user_value_size);
+ auto ifname = copy_string_from_user(user_string, user_value_size);
if (ifname.is_null())
return KResult(-EFAULT);
auto device = NetworkAdapter::lookup_by_name(ifname);
@@ -140,7 +142,7 @@ KResult Socket::setsockopt(int level, int option, Userspace<const void*> user_va
KResult Socket::getsockopt(FileDescription&, int level, int option, Userspace<void*> value, Userspace<socklen_t*> value_size)
{
socklen_t size;
- if (!Process::current()->validate_read_and_copy_typed(&size, value_size))
+ if (!copy_from_user(&size, value_size.unsafe_userspace_ptr()))
return KResult(-EFAULT);
ASSERT(level == SOL_SOCKET);
@@ -148,25 +150,31 @@ KResult Socket::getsockopt(FileDescription&, int level, int option, Userspace<vo
case SO_SNDTIMEO:
if (size < sizeof(timeval))
return KResult(-EINVAL);
- copy_to_user(static_ptr_cast<timeval*>(value), &m_send_timeout);
+ if (!copy_to_user(static_ptr_cast<timeval*>(value), &m_send_timeout))
+ return KResult(-EFAULT);
size = sizeof(timeval);
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
case SO_RCVTIMEO:
if (size < sizeof(timeval))
return KResult(-EINVAL);
- copy_to_user(static_ptr_cast<timeval*>(value), &m_receive_timeout);
+ if (!copy_to_user(static_ptr_cast<timeval*>(value), &m_receive_timeout))
+ return KResult(-EFAULT);
size = sizeof(timeval);
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
case SO_ERROR: {
if (size < sizeof(int))
return KResult(-EINVAL);
dbg() << "getsockopt(SO_ERROR): FIXME!";
int errno = 0;
- copy_to_user(static_ptr_cast<int*>(value), &errno);
+ if (!copy_to_user(static_ptr_cast<int*>(value), &errno))
+ return KResult(-EFAULT);
size = sizeof(int);
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
}
case SO_BINDTODEVICE:
@@ -175,13 +183,16 @@ KResult Socket::getsockopt(FileDescription&, int level, int option, Userspace<vo
if (m_bound_interface) {
const auto& name = m_bound_interface->name();
auto length = name.length() + 1;
- copy_to_user(static_ptr_cast<char*>(value), name.characters(), length);
+ if (!copy_to_user(static_ptr_cast<char*>(value), name.characters(), length))
+ return KResult(-EFAULT);
size = length;
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KSuccess;
} else {
size = 0;
- copy_to_user(value_size, &size);
+ if (!copy_to_user(value_size, &size))
+ return KResult(-EFAULT);
return KResult(-EFAULT);
}
@@ -191,14 +202,14 @@ KResult Socket::getsockopt(FileDescription&, int level, int option, Userspace<vo
}
}
-KResultOr<size_t> Socket::read(FileDescription& description, size_t, u8* buffer, size_t size)
+KResultOr<size_t> Socket::read(FileDescription& description, size_t, UserOrKernelBuffer& buffer, size_t size)
{
if (is_shut_down_for_reading())
return 0;
return recvfrom(description, buffer, size, 0, {}, 0);
}
-KResultOr<size_t> Socket::write(FileDescription& description, size_t, const u8* data, size_t size)
+KResultOr<size_t> Socket::write(FileDescription& description, size_t, const UserOrKernelBuffer& data, size_t size)
{
if (is_shut_down_for_writing())
return -EPIPE;
diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h
index b72fa5f837..c768793292 100644
--- a/Kernel/Net/Socket.h
+++ b/Kernel/Net/Socket.h
@@ -99,7 +99,7 @@ public:
KResult shutdown(int how);
virtual KResult bind(Userspace<const sockaddr*>, socklen_t) = 0;
- virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0;
+ virtual KResult connect(FileDescription&, Userspace<const sockaddr*>, socklen_t, ShouldBlock) = 0;
virtual KResult listen(size_t) = 0;
virtual void get_local_address(sockaddr*, socklen_t*) = 0;
virtual void get_peer_address(sockaddr*, socklen_t*) = 0;
@@ -107,8 +107,8 @@ public:
virtual bool is_ipv4() const { return false; }
virtual void attach(FileDescription&) = 0;
virtual void detach(FileDescription&) = 0;
- virtual KResultOr<size_t> sendto(FileDescription&, const void*, size_t, int flags, Userspace<const sockaddr*>, socklen_t) = 0;
- virtual KResultOr<size_t> recvfrom(FileDescription&, void*, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>) = 0;
+ virtual KResultOr<size_t> sendto(FileDescription&, const UserOrKernelBuffer&, size_t, int flags, Userspace<const sockaddr*>, socklen_t) = 0;
+ virtual KResultOr<size_t> recvfrom(FileDescription&, UserOrKernelBuffer&, size_t, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>) = 0;
virtual KResult setsockopt(int level, int option, Userspace<const void*>, socklen_t);
virtual KResult getsockopt(FileDescription&, int level, int option, Userspace<void*>, Userspace<socklen_t*>);
@@ -124,8 +124,8 @@ public:
Lock& lock() { return m_lock; }
// ^File
- virtual KResultOr<size_t> read(FileDescription&, size_t, u8*, size_t) override final;
- virtual KResultOr<size_t> write(FileDescription&, size_t, const u8*, size_t) override final;
+ virtual KResultOr<size_t> read(FileDescription&, size_t, UserOrKernelBuffer&, size_t) override final;
+ virtual KResultOr<size_t> write(FileDescription&, size_t, const UserOrKernelBuffer&, size_t) override final;
virtual KResult stat(::stat&) const override;
virtual String absolute_path(const FileDescription&) const override = 0;
diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp
index b7b615c951..e14ecf2770 100644
--- a/Kernel/Net/TCPSocket.cpp
+++ b/Kernel/Net/TCPSocket.cpp
@@ -161,7 +161,7 @@ NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol)
return adopt(*new TCPSocket(protocol));
}
-KResultOr<size_t> TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size_t buffer_size, int flags)
+KResultOr<size_t> TCPSocket::protocol_receive(const KBuffer& packet_buffer, UserOrKernelBuffer& buffer, size_t buffer_size, int flags)
{
(void)flags;
auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.data());
@@ -171,17 +171,20 @@ KResultOr<size_t> TCPSocket::protocol_receive(const KBuffer& packet_buffer, void
klog() << "payload_size " << payload_size << ", will it fit in " << buffer_size << "?";
#endif
ASSERT(buffer_size >= payload_size);
- memcpy(buffer, tcp_packet.payload(), payload_size);
+ if (!buffer.write(tcp_packet.payload(), payload_size))
+ return KResult(-EFAULT);
return payload_size;
}
-KResultOr<size_t> TCPSocket::protocol_send(const void* data, size_t data_length)
+KResultOr<size_t> TCPSocket::protocol_send(const UserOrKernelBuffer& data, size_t data_length)
{
- send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length);
+ int err = send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, &data, data_length);
+ if (err < 0)
+ return KResult(err);
return data_length;
}
-void TCPSocket::send_tcp_packet(u16 flags, const void* payload, size_t payload_size)
+int TCPSocket::send_tcp_packet(u16 flags, const UserOrKernelBuffer* payload, size_t payload_size)
{
auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size);
auto& tcp_packet = *(TCPPacket*)(buffer.data());
@@ -196,31 +199,37 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, size_t payload_s
if (flags & TCPFlags::ACK)
tcp_packet.set_ack_number(m_ack_number);
+ if (payload && !payload->read(tcp_packet.payload(), payload_size))
+ return -EFAULT;
+
if (flags & TCPFlags::SYN) {
++m_sequence_number;
} else {
m_sequence_number += payload_size;
}
- memcpy(tcp_packet.payload(), payload, payload_size);
tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size));
if (tcp_packet.has_syn() || payload_size > 0) {
LOCKER(m_not_acked_lock);
m_not_acked.append({ m_sequence_number, move(buffer) });
send_outgoing_packets();
- return;
+ return 0;
}
auto routing_decision = route_to(peer_address(), local_address(), bound_interface());
ASSERT(!routing_decision.is_zero());
- routing_decision.adapter->send_ipv4(
+ auto packet_buffer = UserOrKernelBuffer::for_kernel_buffer(buffer.data());
+ int err = routing_decision.adapter->send_ipv4(
routing_decision.next_hop, peer_address(), IPv4Protocol::TCP,
- buffer, ttl());
+ packet_buffer, buffer.size(), ttl());
+ if (err < 0)
+ return err;
m_packets_out++;
m_bytes_out += buffer.size();
+ return 0;
}
void TCPSocket::send_outgoing_packets()
@@ -243,12 +252,17 @@ void TCPSocket::send_outgoing_packets()
auto& tcp_packet = *(TCPPacket*)(packet.buffer.data());
klog() << "sending tcp packet from " << local_address().to_string().characters() << ":" << local_port() << " to " << peer_address().to_string().characters() << ":" << peer_port() << " with (" << (tcp_packet.has_syn() ? "SYN " : "") << (tcp_packet.has_ack() ? "ACK " : "") << (tcp_packet.has_fin() ? "FIN " : "") << (tcp_packet.has_rst() ? "RST " : "") << ") seq_no=" << tcp_packet.sequence_number() << ", ack_no=" << tcp_packet.ack_number() << ", tx_counter=" << packet.tx_counter;
#endif
- routing_decision.adapter->send_ipv4(
+ auto packet_buffer = UserOrKernelBuffer::for_kernel_buffer(packet.buffer.data());
+ int err = routing_decision.adapter->send_ipv4(
routing_decision.next_hop, peer_address(), IPv4Protocol::TCP,
- packet.buffer, ttl());
-
- m_packets_out++;
- m_bytes_out += packet.buffer.size();
+ packet_buffer, packet.buffer.size(), ttl());
+ if (err < 0) {
+ auto& tcp_packet = *(TCPPacket*)(packet.buffer.data());
+ klog() << "Error (" << err << ") sending tcp packet from " << local_address().to_string().characters() << ":" << local_port() << " to " << peer_address().to_string().characters() << ":" << peer_port() << " with (" << (tcp_packet.has_syn() ? "SYN " : "") << (tcp_packet.has_ack() ? "ACK " : "") << (tcp_packet.has_fin() ? "FIN " : "") << (tcp_packet.has_rst() ? "RST " : "") << ") seq_no=" << tcp_packet.sequence_number() << ", ack_no=" << tcp_packet.ack_number() << ", tx_counter=" << packet.tx_counter;
+ } else {
+ m_packets_out++;
+ m_bytes_out += packet.buffer.size();
+ }
}
}
@@ -366,7 +380,9 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
m_ack_number = 0;
set_setup_state(SetupState::InProgress);
- send_tcp_packet(TCPFlags::SYN);
+ int err = send_tcp_packet(TCPFlags::SYN);
+ if (err < 0)
+ return KResult(err);
m_state = State::SynSent;
m_role = Role::Connecting;
m_direction = Direction::Outgoing;
@@ -433,7 +449,7 @@ void TCPSocket::shut_down_for_writing()
#ifdef TCP_SOCKET_DEBUG
dbg() << " Sending FIN/ACK from Established and moving into FinWait1";
#endif
- send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
+ (void)send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
set_state(State::FinWait1);
} else {
dbg() << " Shutting down TCPSocket for writing but not moving to FinWait1 since state is " << to_string(state());
@@ -447,7 +463,7 @@ KResult TCPSocket::close()
#ifdef TCP_SOCKET_DEBUG
dbg() << " Sending FIN from CloseWait and moving into LastAck";
#endif
- send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
+ (void)send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
set_state(State::LastAck);
}
diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h
index bff5ac14c2..0a8e25119d 100644
--- a/Kernel/Net/TCPSocket.h
+++ b/Kernel/Net/TCPSocket.h
@@ -148,7 +148,7 @@ public:
u32 packets_out() const { return m_packets_out; }
u32 bytes_out() const { return m_bytes_out; }
- void send_tcp_packet(u16 flags, const void* = nullptr, size_t = 0);
+ [[nodiscard]] int send_tcp_packet(u16 flags, const UserOrKernelBuffer* = nullptr, size_t = 0);
void send_outgoing_packets();
void receive_tcp_packet(const TCPPacket&, u16 size);
@@ -177,8 +177,8 @@ private:
virtual void shut_down_for_writing() override;
- virtual KResultOr<size_t> protocol_receive(const KBuffer&, void* buffer, size_t buffer_size, int flags) override;
- virtual KResultOr<size_t> protocol_send(const void*, size_t) override;
+ virtual KResultOr<size_t> protocol_receive(const KBuffer&, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
+ virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
virtual int protocol_allocate_local_port() override;
virtual bool protocol_is_disconnected() const override;
diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp
index c770a4e184..22db799af7 100644
--- a/Kernel/Net/UDPSocket.cpp
+++ b/Kernel/Net/UDPSocket.cpp
@@ -79,18 +79,19 @@ NonnullRefPtr<UDPSocket> UDPSocket::create(int protocol)
return adopt(*new UDPSocket(protocol));
}
-KResultOr<size_t> UDPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size_t buffer_size, int flags)
+KResultOr<size_t> UDPSocket::protocol_receive(const KBuffer& packet_buffer, UserOrKernelBuffer& buffer, size_t buffer_size, int flags)
{
(void)flags;
auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.data());
auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier.
ASSERT(buffer_size >= (udp_packet.length() - sizeof(UDPPacket)));
- memcpy(buffer, udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket));
+ if (!buffer.write(udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket)))
+ return KResult(-EFAULT);
return udp_packet.length() - sizeof(UDPPacket);
}
-KResultOr<size_t> UDPSocket::protocol_send(const void* data, size_t data_length)
+KResultOr<size_t> UDPSocket::protocol_send(const UserOrKernelBuffer& data, size_t data_length)
{
auto routing_decision = route_to(peer_address(), local_address(), bound_interface());
if (routing_decision.is_zero())
@@ -100,9 +101,11 @@ KResultOr<size_t> UDPSocket::protocol_send(const void* data, size_t data_length)
udp_packet.set_source_port(local_port());
udp_packet.set_destination_port(peer_port());
udp_packet.set_length(sizeof(UDPPacket) + data_length);
- memcpy(udp_packet.payload(), data, data_length);
+ if (!data.read(udp_packet.payload(), data_length))
+ return KResult(-EFAULT);
klog() << "sending as udp packet from " << routing_decision.adapter->ipv4_address().to_string().characters() << ":" << local_port() << " to " << peer_address().to_string().characters() << ":" << peer_port() << "!";
- routing_decision.adapter->send_ipv4(routing_decision.next_hop, peer_address(), IPv4Protocol::UDP, buffer, ttl());
+ auto udp_packet_buffer = UserOrKernelBuffer::for_kernel_buffer((u8*)&udp_packet);
+ routing_decision.adapter->send_ipv4(routing_decision.next_hop, peer_address(), IPv4Protocol::UDP, udp_packet_buffer, buffer.size(), ttl());
return data_length;
}
diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h
index ab711a9dc8..d640be316e 100644
--- a/Kernel/Net/UDPSocket.h
+++ b/Kernel/Net/UDPSocket.h
@@ -43,8 +43,8 @@ private:
virtual const char* class_name() const override { return "UDPSocket"; }
static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
- virtual KResultOr<size_t> protocol_receive(const KBuffer&, void* buffer, size_t buffer_size, int flags) override;
- virtual KResultOr<size_t> protocol_send(const void*, size_t) override;
+ virtual KResultOr<size_t> protocol_receive(const KBuffer&, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
+ virtual KResultOr<size_t> protocol_send(const UserOrKernelBuffer&, size_t) override;
virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
virtual int protocol_allocate_local_port() override;
virtual KResult protocol_bind() override;