summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeter Elliott <pelliott@ualberta.ca>2023-02-13 16:01:13 -0700
committerLinus Groh <mail@linusgroh.de>2023-02-19 00:37:37 +0100
commitf20902deb3c88a16b1cd9fe8cd9d9db16494ddc9 (patch)
tree7350f929dd6ab13a146e51983541c9816e7c1d28
parentae5d7f542c256f3f1b4c5003d666e8ef850516e9 (diff)
downloadserenity-f20902deb3c88a16b1cd9fe8cd9d9db16494ddc9.zip
Kernel: Support sending filedescriptors with sendmsg(2) and SCM_RIGHTS
This is necessary to support the wayland protocol. I also moved the CMSG_* macros to the kernel API since they are used in both kernel and userspace. this does not break ntpquery/SCM_TIMESTAMP.
-rw-r--r--Kernel/API/POSIX/sys/socket.h26
-rw-r--r--Kernel/Net/LocalSocket.cpp20
-rw-r--r--Kernel/Net/LocalSocket.h1
-rw-r--r--Kernel/Syscalls/socket.cpp61
-rw-r--r--Userland/Libraries/LibC/sys/socket.h26
5 files changed, 97 insertions, 37 deletions
diff --git a/Kernel/API/POSIX/sys/socket.h b/Kernel/API/POSIX/sys/socket.h
index 750a353a41..d8fc765b88 100644
--- a/Kernel/API/POSIX/sys/socket.h
+++ b/Kernel/API/POSIX/sys/socket.h
@@ -81,6 +81,32 @@ struct msghdr {
int msg_flags;
};
+// These three are non-POSIX, but common:
+#define CMSG_ALIGN(x) (((x) + sizeof(void*) - 1) & ~(sizeof(void*) - 1))
+#define CMSG_SPACE(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + CMSG_ALIGN(x))
+#define CMSG_LEN(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + (x))
+
+static inline struct cmsghdr* CMSG_FIRSTHDR(struct msghdr* msg)
+{
+ if (msg->msg_controllen < sizeof(struct cmsghdr))
+ return (struct cmsghdr*)0;
+ return (struct cmsghdr*)msg->msg_control;
+}
+
+static inline struct cmsghdr* CMSG_NXTHDR(struct msghdr* msg, struct cmsghdr* cmsg)
+{
+ struct cmsghdr* next = (struct cmsghdr*)((char*)cmsg + CMSG_ALIGN(cmsg->cmsg_len));
+ unsigned offset = (char*)next - (char*)msg->msg_control;
+ if (msg->msg_controllen < offset + sizeof(struct cmsghdr))
+ return (struct cmsghdr*)0;
+ return next;
+}
+
+static inline void* CMSG_DATA(struct cmsghdr* cmsg)
+{
+ return (void*)(cmsg + 1);
+}
+
struct sockaddr {
sa_family_t sa_family;
char sa_data[14];
diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp
index 4b55e8bea2..33db0b60b0 100644
--- a/Kernel/Net/LocalSocket.cpp
+++ b/Kernel/Net/LocalSocket.cpp
@@ -520,6 +520,26 @@ ErrorOr<NonnullLockRefPtr<OpenFileDescription>> LocalSocket::recvfd(OpenFileDesc
return queue.take_first();
}
+ErrorOr<NonnullLockRefPtrVector<OpenFileDescription>> LocalSocket::recvfds(OpenFileDescription const& socket_description, int n)
+{
+ MutexLocker locker(mutex());
+ NonnullLockRefPtrVector<OpenFileDescription> fds;
+
+ auto role = this->role(socket_description);
+ if (role != Role::Connected && role != Role::Accepted)
+ return set_so_error(EINVAL);
+ auto& queue = recvfd_queue_for(socket_description);
+
+ for (int i = 0; i < n; ++i) {
+ if (queue.is_empty())
+ break;
+
+ fds.append(queue.take_first());
+ }
+
+ return fds;
+}
+
ErrorOr<void> LocalSocket::try_set_path(StringView path)
{
m_path = TRY(KString::try_create(path));
diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h
index 5416bd5373..9698a49384 100644
--- a/Kernel/Net/LocalSocket.h
+++ b/Kernel/Net/LocalSocket.h
@@ -28,6 +28,7 @@ public:
ErrorOr<void> sendfd(OpenFileDescription const& socket_description, NonnullLockRefPtr<OpenFileDescription> passing_description);
ErrorOr<NonnullLockRefPtr<OpenFileDescription>> recvfd(OpenFileDescription const& socket_description);
+ ErrorOr<NonnullLockRefPtrVector<OpenFileDescription>> recvfds(OpenFileDescription const& socket_description, int n);
static void for_each(Function<void(LocalSocket const&)>);
static ErrorOr<void> try_for_each(Function<ErrorOr<void>(LocalSocket const&)>);
diff --git a/Kernel/Syscalls/socket.cpp b/Kernel/Syscalls/socket.cpp
index e02e78bb10..f29598a35d 100644
--- a/Kernel/Syscalls/socket.cpp
+++ b/Kernel/Syscalls/socket.cpp
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: BSD-2-Clause
*/
+#include <AK/ByteBuffer.h>
#include <Kernel/FileSystem/OpenFileDescription.h>
#include <Kernel/Net/LocalSocket.h>
#include <Kernel/Process.h>
@@ -199,6 +200,24 @@ ErrorOr<FlatPtr> Process::sys$sendmsg(int sockfd, Userspace<const struct msghdr*
Thread::current()->send_signal(SIGPIPE, &Process::current());
return EPIPE;
}
+
+ if (msg.msg_controllen > 0) {
+ // Handle command messages.
+ auto cmsg_buffer = TRY(ByteBuffer::create_uninitialized(msg.msg_controllen));
+ TRY(copy_from_user(cmsg_buffer.data(), msg.msg_control, msg.msg_controllen));
+ msg.msg_control = cmsg_buffer.data();
+ for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (socket.is_local() && cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+ auto& local_socket = static_cast<LocalSocket&>(socket);
+ int* fds = (int*)CMSG_DATA(cmsg);
+ size_t nfds = (cmsg->cmsg_len - CMSG_ALIGN(sizeof(struct cmsghdr))) / sizeof(int);
+ for (size_t i = 0; i < nfds; ++i) {
+ TRY(local_socket.sendfd(*description, TRY(open_file_description(fds[i]))));
+ }
+ }
+ }
+ }
+
auto data_buffer = TRY(UserOrKernelBuffer::for_user_buffer((u8*)iovs[0].iov_base, iovs[0].iov_len));
while (true) {
@@ -267,21 +286,41 @@ ErrorOr<FlatPtr> Process::sys$recvmsg(int sockfd, Userspace<struct msghdr*> user
msg_flags |= MSG_TRUNC;
}
- if (socket.wants_timestamp()) {
- struct {
- cmsghdr cmsg;
- timeval timestamp;
- } cmsg_timestamp;
- socklen_t control_length = sizeof(cmsg_timestamp);
- if (msg.msg_controllen < control_length) {
+ socklen_t current_cmsg_len = 0;
+ auto try_add_cmsg = [&](int level, int type, void const* data, socklen_t len) -> ErrorOr<bool> {
+ if (current_cmsg_len + len > msg.msg_controllen) {
msg_flags |= MSG_CTRUNC;
- } else {
- cmsg_timestamp = { { control_length, SOL_SOCKET, SCM_TIMESTAMP }, timestamp.to_timeval() };
- TRY(copy_to_user(msg.msg_control, &cmsg_timestamp, control_length));
+ return false;
}
- TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_controllen, &control_length));
+
+ cmsghdr cmsg = { (socklen_t)CMSG_LEN(len), level, type };
+ cmsghdr* target = (cmsghdr*)(((char*)msg.msg_control) + current_cmsg_len);
+ TRY(copy_to_user(target, &cmsg));
+ TRY(copy_to_user(CMSG_DATA(target), data, len));
+ current_cmsg_len += CMSG_ALIGN(cmsg.cmsg_len);
+ return true;
+ };
+
+ if (socket.wants_timestamp()) {
+ timeval time = timestamp.to_timeval();
+ TRY(try_add_cmsg(SOL_SOCKET, SCM_TIMESTAMP, &time, sizeof(time)));
+ }
+
+ int space_for_fds = (msg.msg_controllen - current_cmsg_len - sizeof(struct cmsghdr)) / sizeof(int);
+ if (space_for_fds > 0 && socket.is_local()) {
+ auto& local_socket = static_cast<LocalSocket&>(socket);
+ auto descriptions = TRY(local_socket.recvfds(description, space_for_fds));
+ Vector<int> fdnums;
+ for (auto& description : descriptions) {
+ auto fd_allocation = TRY(m_fds.with_exclusive([](auto& fds) { return fds.allocate(); }));
+ m_fds.with_exclusive([&](auto& fds) { fds[fd_allocation.fd].set(description, 0); });
+ fdnums.append(fd_allocation.fd);
+ }
+ TRY(try_add_cmsg(SOL_SOCKET, SCM_RIGHTS, fdnums.data(), fdnums.size() * sizeof(int)));
}
+ TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_controllen, &current_cmsg_len));
+
TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_flags, &msg_flags));
return result.value();
}
diff --git a/Userland/Libraries/LibC/sys/socket.h b/Userland/Libraries/LibC/sys/socket.h
index 158ec12303..f505992c02 100644
--- a/Userland/Libraries/LibC/sys/socket.h
+++ b/Userland/Libraries/LibC/sys/socket.h
@@ -33,30 +33,4 @@ int socketpair(int domain, int type, int protocol, int sv[2]);
int sendfd(int sockfd, int fd);
int recvfd(int sockfd, int options);
-// These three are non-POSIX, but common:
-#define CMSG_ALIGN(x) (((x) + sizeof(void*) - 1) & ~(sizeof(void*) - 1))
-#define CMSG_SPACE(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + CMSG_ALIGN(x))
-#define CMSG_LEN(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + (x))
-
-static inline struct cmsghdr* CMSG_FIRSTHDR(struct msghdr* msg)
-{
- if (msg->msg_controllen < sizeof(struct cmsghdr))
- return 0;
- return (struct cmsghdr*)msg->msg_control;
-}
-
-static inline struct cmsghdr* CMSG_NXTHDR(struct msghdr* msg, struct cmsghdr* cmsg)
-{
- struct cmsghdr* next = (struct cmsghdr*)((char*)cmsg + CMSG_ALIGN(cmsg->cmsg_len));
- unsigned offset = (char*)next - (char*)msg->msg_control;
- if (msg->msg_controllen < offset + sizeof(struct cmsghdr))
- return NULL;
- return next;
-}
-
-static inline void* CMSG_DATA(struct cmsghdr* cmsg)
-{
- return (void*)(cmsg + 1);
-}
-
__END_DECLS