diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/sys/socket/addr.rs | 103 | ||||
-rw-r--r-- | src/sys/socket/mod.rs | 13 |
2 files changed, 108 insertions, 8 deletions
diff --git a/src/sys/socket/addr.rs b/src/sys/socket/addr.rs index 48309749..12cdc7a4 100644 --- a/src/sys/socket/addr.rs +++ b/src/sys/socket/addr.rs @@ -763,6 +763,22 @@ impl SockaddrLike for UnixAddr { { mem::size_of::<libc::sockaddr_un>() as libc::socklen_t } + + unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + cfg_if! { + if #[cfg(any(target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "redox", + ))] { + self.sun_len = new_length as u8; + } else { + self.sun.sun_len = new_length as u8; + } + }; + Ok(()) + } } impl AsRef<libc::sockaddr_un> for UnixAddr { @@ -912,8 +928,30 @@ pub trait SockaddrLike: private::SockaddrLikePriv { { mem::size_of::<Self>() as libc::socklen_t } + + /// Set the length of this socket address + /// + /// This method may only be called on socket addresses whose lenghts are dynamic, and it + /// returns an error if called on a type whose length is static. + /// + /// # Safety + /// + /// `new_length` must be a valid length for this type of address. Specifically, reads of that + /// length from `self` must be valid. + unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic>; } +/// The error returned by [`SockaddrLike::set_length`] on an address whose length is statically +/// fixed. +#[derive(Copy, Clone, Debug)] +pub struct SocketAddressLengthNotDynamic; +impl fmt::Display for SocketAddressLengthNotDynamic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Attempted to set length on socket whose length is statically fixed") + } +} +impl std::error::Error for SocketAddressLengthNotDynamic {} + impl private::SockaddrLikePriv for () { fn as_mut_ptr(&mut self) -> *mut libc::sockaddr { ptr::null_mut() @@ -946,6 +984,10 @@ impl SockaddrLike for () { fn len(&self) -> libc::socklen_t { 0 } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } /// An IPv4 socket address @@ -1015,6 +1057,10 @@ impl SockaddrLike for SockaddrIn { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } #[cfg(feature = "net")] @@ -1134,6 +1180,10 @@ impl SockaddrLike for SockaddrIn6 { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } #[cfg(feature = "net")] @@ -1361,6 +1411,27 @@ impl SockaddrLike for SockaddrStorage { None => mem::size_of_val(self) as libc::socklen_t, } } + + unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + match self.as_unix_addr_mut() { + Some(addr) => { + cfg_if! { + if #[cfg(any(target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "redox", + ))] { + addr.sun_len = new_length as u8; + } else { + addr.sun.sun_len = new_length as u8; + } + } + Ok(()) + }, + None => Err(SocketAddressLengthNotDynamic), + } + } } macro_rules! accessors { @@ -1679,7 +1750,7 @@ impl PartialEq for SockaddrStorage { } } -mod private { +pub(super) mod private { pub trait SockaddrLikePriv { /// Returns a mutable raw pointer to the inner structure. /// @@ -1754,6 +1825,10 @@ pub mod netlink { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } impl AsRef<libc::sockaddr_nl> for NetlinkAddr { @@ -1803,6 +1878,10 @@ pub mod alg { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } impl AsRef<libc::sockaddr_alg> for AlgAddr { @@ -1902,7 +1981,7 @@ pub mod sys_control { use std::{fmt, mem, ptr}; use std::os::unix::io::RawFd; use crate::{Errno, Result}; - use super::{private, SockaddrLike}; + use super::{private, SockaddrLike, SocketAddressLengthNotDynamic}; // FIXME: Move type into `libc` #[repr(C)] @@ -1943,6 +2022,10 @@ pub mod sys_control { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } impl AsRef<libc::sockaddr_ctl> for SysControlAddr { @@ -2007,7 +2090,7 @@ pub mod sys_control { mod datalink { feature! { #![feature = "net"] - use super::{fmt, mem, private, ptr, SockaddrLike}; + use super::{fmt, mem, private, ptr, SockaddrLike, SocketAddressLengthNotDynamic}; /// Hardware Address #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] @@ -2085,6 +2168,10 @@ mod datalink { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } impl AsRef<libc::sockaddr_ll> for LinkAddr { @@ -2110,7 +2197,7 @@ mod datalink { mod datalink { feature! { #![feature = "net"] - use super::{fmt, mem, private, ptr, SockaddrLike}; + use super::{fmt, mem, private, ptr, SockaddrLike, SocketAddressLengthNotDynamic}; /// Hardware Address #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] @@ -2209,6 +2296,10 @@ mod datalink { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } impl AsRef<libc::sockaddr_dl> for LinkAddr { @@ -2257,6 +2348,10 @@ pub mod vsock { } Some(Self(ptr::read_unaligned(addr as *const _))) } + + unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> { + Err(SocketAddressLengthNotDynamic) + } } impl AsRef<libc::sockaddr_vm> for VsockAddr { diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 1e3438ea..c77bc961 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -1609,7 +1609,7 @@ impl<S> MultiHeaders<S> { { // we will be storing pointers to addresses inside mhdr - convert it into boxed // slice so it can'be changed later by pushing anything into self.addresses - let mut addresses = vec![std::mem::MaybeUninit::uninit(); num_slices].into_boxed_slice(); + let mut addresses = vec![std::mem::MaybeUninit::<S>::uninit(); num_slices].into_boxed_slice(); let msg_controllen = cmsg_buffer.as_ref().map_or(0, |v| v.capacity()); @@ -1914,7 +1914,7 @@ unsafe fn read_mhdr<'a, 'i, S>( mhdr: msghdr, r: isize, msg_controllen: usize, - address: S, + mut address: S, ) -> RecvMsg<'a, 'i, S> where S: SockaddrLike { @@ -1930,6 +1930,11 @@ unsafe fn read_mhdr<'a, 'i, S>( }.as_ref() }; + // Ignore errors if this socket address has statically-known length + // + // This is to ensure that unix socket addresses have their length set appropriately. + let _ = address.set_length(mhdr.msg_namelen as usize); + RecvMsg { bytes: r as usize, cmsghdr, @@ -1965,7 +1970,7 @@ unsafe fn pack_mhdr_to_receive<S>( // initialize it. let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed(); let p = mhdr.as_mut_ptr(); - (*p).msg_name = (*address).as_mut_ptr() as *mut c_void; + (*p).msg_name = address as *mut c_void; (*p).msg_namelen = S::size(); (*p).msg_iov = iov_buffer as *mut iovec; (*p).msg_iovlen = iov_buffer_len as _; @@ -2048,7 +2053,7 @@ pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'i where S: SockaddrLike + 'a, 'inner: 'outer { - let mut address = mem::MaybeUninit::uninit(); + let mut address = mem::MaybeUninit::zeroed(); let (msg_control, msg_controllen) = cmsg_buffer.as_mut() .map(|v| (v.as_mut_ptr(), v.capacity())) |