diff options
author | AnotherTest <ali.mpfard@gmail.com> | 2020-04-05 01:16:45 +0430 |
---|---|---|
committer | Andreas Kling <kling@serenityos.org> | 2020-04-05 09:50:48 +0200 |
commit | 77191d82dc07319458442ff9905139e5789c9168 (patch) | |
tree | 45743d9a15bfae09be1eec6cffb1cdb38738270c | |
parent | 7d0bf9b5a9b635eb54f34843d17ea38d7c18092e (diff) | |
download | serenity-77191d82dc07319458442ff9905139e5789c9168.zip |
Kernel: Add the SO_BINDTODEVICE socket option
This patch adds a way for a socket to ask to be routed through a
specific interface.
Currently, this option only applies to sending, however, it should also
apply to receiving...somehow :^)
-rw-r--r-- | Kernel/Net/IPv4Socket.cpp | 2 | ||||
-rw-r--r-- | Kernel/Net/Routing.cpp | 22 | ||||
-rw-r--r-- | Kernel/Net/Routing.h | 2 | ||||
-rw-r--r-- | Kernel/Net/Socket.cpp | 24 | ||||
-rw-r--r-- | Kernel/Net/Socket.h | 10 | ||||
-rw-r--r-- | Kernel/Net/TCPSocket.cpp | 4 | ||||
-rw-r--r-- | Kernel/Net/UDPSocket.cpp | 2 | ||||
-rw-r--r-- | Kernel/UnixTypes.h | 1 | ||||
-rw-r--r-- | Libraries/LibC/sys/socket.h | 1 | ||||
-rw-r--r-- | Userland/test-bindtodevice.cpp | 131 |
10 files changed, 186 insertions, 13 deletions
diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 99b522b1ed..4d4f7c4ac2 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -210,7 +210,7 @@ ssize_t IPv4Socket::sendto(FileDescription&, const void* data, size_t data_lengt m_peer_port = ntohs(ia.sin_port); } - auto routing_decision = route_to(m_peer_address, m_local_address); + auto routing_decision = route_to(m_peer_address, m_local_address, bound_interface()); if (routing_decision.is_zero()) return -EHOSTUNREACH; diff --git a/Kernel/Net/Routing.cpp b/Kernel/Net/Routing.cpp index 84b5fee6ba..a147957898 100644 --- a/Kernel/Net/Routing.cpp +++ b/Kernel/Net/Routing.cpp @@ -46,10 +46,22 @@ bool RoutingDecision::is_zero() const return adapter.is_null() || next_hop.is_zero(); } -RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source) +RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source, const RefPtr<NetworkAdapter> through) { + auto matches = [&](auto& adapter) { + if (!through) + return true; + + return through == adapter; + }; + auto if_matches = [&](auto& adapter, const auto& mac) -> RoutingDecision { + if (!matches(adapter)) + return { nullptr, {} }; + return { adapter, mac }; + }; + if (target[0] == 127) - return { LoopbackAdapter::the(), LoopbackAdapter::the().mac_address() }; + return if_matches(LoopbackAdapter::the(), LoopbackAdapter::the().mac_address()); auto target_addr = target.to_u32(); auto source_addr = source.to_u32(); @@ -57,17 +69,17 @@ RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source) RefPtr<NetworkAdapter> local_adapter = nullptr; RefPtr<NetworkAdapter> gateway_adapter = nullptr; - NetworkAdapter::for_each([source_addr, &target_addr, &local_adapter, &gateway_adapter](auto& adapter) { + NetworkAdapter::for_each([source_addr, &target_addr, &local_adapter, &gateway_adapter, &matches](auto& adapter) { auto adapter_addr = adapter.ipv4_address().to_u32(); auto adapter_mask = adapter.ipv4_netmask().to_u32(); if (source_addr != 0 && source_addr != adapter_addr) return; - if ((target_addr & adapter_mask) == (adapter_addr & adapter_mask)) + if ((target_addr & adapter_mask) == (adapter_addr & adapter_mask) && matches(adapter)) local_adapter = adapter; - if (adapter.ipv4_gateway().to_u32() != 0) + if (adapter.ipv4_gateway().to_u32() != 0 && matches(adapter)) gateway_adapter = adapter; }); diff --git a/Kernel/Net/Routing.h b/Kernel/Net/Routing.h index 40d61bcf6b..50c7d72ed3 100644 --- a/Kernel/Net/Routing.h +++ b/Kernel/Net/Routing.h @@ -37,7 +37,7 @@ struct RoutingDecision { bool is_zero() const; }; -RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source); +RoutingDecision route_to(const IPv4Address& target, const IPv4Address& source, const RefPtr<NetworkAdapter> through = nullptr); Lockable<HashMap<IPv4Address, MACAddress>>& arp_table(); diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp index ed1092902a..0b063ada6e 100644 --- a/Kernel/Net/Socket.cpp +++ b/Kernel/Net/Socket.cpp @@ -25,6 +25,7 @@ */ #include <AK/StringBuilder.h> +#include <AK/StringView.h> #include <Kernel/FileSystem/FileDescription.h> #include <Kernel/Net/IPv4Socket.h> #include <Kernel/Net/LocalSocket.h> @@ -114,6 +115,16 @@ KResult Socket::setsockopt(int level, int option, const void* value, socklen_t v return KResult(-EINVAL); m_receive_timeout = *(const timeval*)value; return KSuccess; + case SO_BINDTODEVICE: { + if (value_size != IFNAMSIZ) + return KResult(-EINVAL); + StringView ifname { (const char*)value }; + auto device = NetworkAdapter::lookup_by_name(ifname); + if (!device) + return KResult(-ENODEV); + m_bound_interface = device; + return KSuccess; + } default: dbg() << "setsockopt(" << option << ") at SOL_SOCKET not implemented."; return KResult(-ENOPROTOOPT); @@ -143,6 +154,19 @@ KResult Socket::getsockopt(FileDescription&, int level, int option, void* value, *(int*)value = 0; *value_size = sizeof(int); return KSuccess; + case SO_BINDTODEVICE: + if (*value_size < IFNAMSIZ) + return KResult(-EINVAL); + if (m_bound_interface) { + const auto& name = m_bound_interface->name(); + auto length = name.length() + 1; + memcpy(value, name.characters(), length); + *value_size = length; + return KSuccess; + } else { + *value_size = 0; + return KResult(-EFAULT); + } default: dbg() << "getsockopt(" << option << ") at SOL_SOCKET not implemented."; return KResult(-ENOPROTOOPT); diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index 5e4feecdba..510b50eb37 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -33,6 +33,7 @@ #include <Kernel/FileSystem/File.h> #include <Kernel/KResult.h> #include <Kernel/Lock.h> +#include <Kernel/Net/NetworkAdapter.h> #include <Kernel/UnixTypes.h> namespace Kernel { @@ -57,9 +58,9 @@ public: bool is_shut_down_for_reading() const { return m_shut_down_for_reading; } enum class SetupState { - Unstarted, // we haven't tried to set the socket up yet + Unstarted, // we haven't tried to set the socket up yet InProgress, // we're in the process of setting things up - for TCP maybe we've sent a SYN packet - Completed, // the setup process is complete, but not necessarily successful + Completed, // the setup process is complete, but not necessarily successful }; enum class Role : u8 { @@ -118,6 +119,7 @@ public: pid_t acceptor_pid() const { return m_acceptor.pid; } uid_t acceptor_uid() const { return m_acceptor.uid; } gid_t acceptor_gid() const { return m_acceptor.gid; } + const RefPtr<NetworkAdapter> bound_interface() const { return m_bound_interface; } Lock& lock() { return m_lock; } @@ -165,13 +167,15 @@ private: bool m_shut_down_for_reading { false }; bool m_shut_down_for_writing { false }; + RefPtr<NetworkAdapter> m_bound_interface { nullptr }; + timeval m_receive_timeout { 0, 0 }; timeval m_send_timeout { 0, 0 }; NonnullRefPtrVector<Socket> m_pending; }; -template<typename SocketType> +template <typename SocketType> class SocketHandle { public: SocketHandle() {} diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index 9e8fed62cc..d29fac55db 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -212,7 +212,7 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, size_t payload_s return; } - auto routing_decision = route_to(peer_address(), local_address()); + auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); ASSERT(!routing_decision.is_zero()); routing_decision.adapter->send_ipv4( @@ -225,7 +225,7 @@ void TCPSocket::send_tcp_packet(u16 flags, const void* payload, size_t payload_s void TCPSocket::send_outgoing_packets() { - auto routing_decision = route_to(peer_address(), local_address()); + auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); ASSERT(!routing_decision.is_zero()); auto now = kgettimeofday(); diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 8794b9b7a6..2afd24bd89 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -92,7 +92,7 @@ int UDPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size int UDPSocket::protocol_send(const void* data, size_t data_length) { - auto routing_decision = route_to(peer_address(), local_address()); + auto routing_decision = route_to(peer_address(), local_address(), bound_interface()); if (routing_decision.is_zero()) return -EHOSTUNREACH; auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length); diff --git a/Kernel/UnixTypes.h b/Kernel/UnixTypes.h index 9706cbd9f2..d7eab8cf5f 100644 --- a/Kernel/UnixTypes.h +++ b/Kernel/UnixTypes.h @@ -401,6 +401,7 @@ struct pollfd { #define SO_ERROR 4 #define SO_PEERCRED 5 #define SO_REUSEADDR 6 +#define SO_BINDTODEVICE 7 #define IPPROTO_IP 0 #define IPPROTO_ICMP 1 diff --git a/Libraries/LibC/sys/socket.h b/Libraries/LibC/sys/socket.h index 2207b8c787..528c409f7e 100644 --- a/Libraries/LibC/sys/socket.h +++ b/Libraries/LibC/sys/socket.h @@ -80,6 +80,7 @@ struct ucred { #define SO_ERROR 4 #define SO_PEERCRED 5 #define SO_REUSEADDR 6 +#define SO_BINDTODEVICE 7 int socket(int domain, int type, int protocol); int bind(int sockfd, const struct sockaddr* addr, socklen_t); diff --git a/Userland/test-bindtodevice.cpp b/Userland/test-bindtodevice.cpp new file mode 100644 index 0000000000..c17373c235 --- /dev/null +++ b/Userland/test-bindtodevice.cpp @@ -0,0 +1,131 @@ +#include <AK/Function.h> +#include <AK/IPv4Address.h> +#include <net/if.h> +#include <netinet/in.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> + +void test_invalid(int); +void test_no_route(int); +void test_valid(int); +void test_send(int); + +void test(AK::Function<void(int)> test_fn) +{ + + int fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (fd < 0) { + perror("socket"); + return; + } + + test_fn(fd); + + // be a responsible boi + close(fd); +} + +auto main() -> int +{ + test(test_invalid); + test(test_valid); + test(test_no_route); + test(test_send); +} + +void test_invalid(int fd) +{ + // bind to an interface that does not exist + char buf[IFNAMSIZ]; + socklen_t buflen = IFNAMSIZ; + memcpy(buf, "foodev", 7); + + if (setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, buf, buflen) < 0) { + perror("setsockopt(SO_BINDTODEVICE) :: invalid (Should fail with ENODEV)"); + puts("PASS invalid"); + } else { + puts("FAIL invalid"); + } +} + +void test_valid(int fd) +{ + // bind to an interface that exists + char buf[IFNAMSIZ]; + socklen_t buflen = IFNAMSIZ; + memcpy(buf, "loop0", 6); + + if (setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, buf, buflen) < 0) { + perror("setsockopt(SO_BINDTODEVICE) :: valid"); + puts("FAIL valid"); + } else { + puts("PASS valid"); + } +} + +void test_no_route(int fd) +{ + // bind to an interface that cannot deliver + char buf[IFNAMSIZ]; + socklen_t buflen = IFNAMSIZ; + memcpy(buf, "loop0", 6); + + if (setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, buf, buflen) < 0) { + perror("setsockopt(SO_BINDTODEVICE) :: no_route"); + puts("FAIL no_route"); + return; + } + sockaddr_in sin; + memset(&sin, 0, sizeof(sin)); + + sin.sin_addr.s_addr = IPv4Address { 10, 0, 2, 15 }.to_u32(); + sin.sin_port = 8080; + sin.sin_family = AF_INET; + + if (bind(fd, (sockaddr*)&sin, sizeof(sin)) < 0) { + perror("bind() :: no_route"); + puts("FAIL no_route"); + return; + } + + if (sendto(fd, "TEST", 4, 0, (sockaddr*)&sin, sizeof(sin)) < 0) { + perror("sendto() :: no_route (Should fail with EHOSTUNREACH)"); + puts("PASS no_route"); + } else + puts("FAIL no_route"); +} + +void test_send(int fd) +{ + // bind to an interface that cannot deliver + char buf[IFNAMSIZ]; + socklen_t buflen = IFNAMSIZ; + memcpy(buf, "e1k0", 5); + + if (setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, buf, buflen) < 0) { + perror("setsockopt(SO_BINDTODEVICE) :: send"); + puts("FAIL send"); + return; + } + sockaddr_in sin; + memset(&sin, 0, sizeof(sin)); + + sin.sin_addr.s_addr = IPv4Address { 10, 0, 2, 15 }.to_u32(); + sin.sin_port = 8080; + sin.sin_family = AF_INET; + + if (bind(fd, (sockaddr*)&sin, sizeof(sin)) < 0) { + perror("bind() :: send"); + puts("FAIL send"); + return; + } + + if (sendto(fd, "TEST", 4, 0, (sockaddr*)&sin, sizeof(sin)) < 0) { + perror("sendto() :: send"); + puts("FAIL send"); + return; + } + puts("PASS send"); +} |