diff options
author | Alan Somers <asomers@gmail.com> | 2019-01-14 23:06:21 -0700 |
---|---|---|
committer | Alan Somers <asomers@gmail.com> | 2019-02-14 10:15:49 -0700 |
commit | ed1c90d6bc9871d9881e8fcb05370752b6a5d25b (patch) | |
tree | 730f4933304e1af72ef47787896da94a70f6aa73 /src/sys/socket/mod.rs | |
parent | 842f5acf3e003e5600746c9f4ee0a3c02fbf3dd0 (diff) | |
download | nix-ed1c90d6bc9871d9881e8fcb05370752b6a5d25b.zip |
Replace hand-rolled cmsg logic with libc's cmsg(3) functions.
Our hand-rolled logic had subtle alignment bugs that caused
test_scm_rights to fail on OpenBSD (and probably could cause problems on
other platforms too). Using cmsg(3) is much cleaner, shorter, and more
portable. No user-visible changes.
Diffstat (limited to 'src/sys/socket/mod.rs')
-rw-r--r-- | src/sys/socket/mod.rs | 488 |
1 files changed, 227 insertions, 261 deletions
diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 64d2fc12..9527d0b3 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -3,7 +3,7 @@ //! [Further reading](http://man7.org/linux/man-pages/man7/socket.7.html) use {Error, Result}; use errno::Errno; -use libc::{self, c_void, c_int, socklen_t, size_t}; +use libc::{self, c_void, c_int, iovec, socklen_t, size_t}; use std::{fmt, mem, ptr, slice}; use std::os::unix::io::RawFd; use sys::time::TimeVal; @@ -32,6 +32,11 @@ pub use self::addr::{ pub use ::sys::socket::addr::netlink::NetlinkAddr; pub use libc::{ + CMSG_FIRSTHDR, + CMSG_NXTHDR, + CMSG_DATA, + CMSG_SPACE, + CMSG_LEN, cmsghdr, msghdr, sa_family_t, @@ -303,39 +308,6 @@ impl fmt::Debug for Ipv6MembershipRequest { } } -/// Copy the in-memory representation of `src` into the byte slice `dst`. -/// -/// Returns the remainder of `dst`. -/// -/// Panics when `dst` is too small for `src` (more precisely, panics if -/// `mem::size_of_val(src) >= dst.len()`). -/// -/// Unsafe because it transmutes `src` to raw bytes, which is only safe for some -/// types `T`. Refer to the [Rustonomicon] for details. -/// -/// [Rustonomicon]: https://doc.rust-lang.org/nomicon/transmutes.html -unsafe fn copy_bytes<'a, T: ?Sized>(src: &T, dst: &'a mut [u8]) -> &'a mut [u8] { - let srclen = mem::size_of_val(src); - ptr::copy_nonoverlapping( - src as *const T as *const u8, - dst[..srclen].as_mut_ptr(), - srclen - ); - - &mut dst[srclen..] -} - -/// Fills `dst` with `len` zero bytes and returns the remainder of the slice. -/// -/// Panics when `len >= dst.len()`. -fn pad_bytes(len: usize, dst: &mut [u8]) -> &mut [u8] { - for pad in &mut dst[..len] { - *pad = 0; - } - - &mut dst[len..] -} - cfg_if! { // Darwin and DragonFly BSD always align struct cmsghdr to 32-bit only. if #[cfg(any(target_os = "dragonfly", target_os = "ios", target_os = "macos"))] { @@ -375,13 +347,12 @@ impl<T> CmsgSpace<T> { } } -#[derive(Debug)] +#[allow(missing_debug_implementations)] // msghdr isn't Debug pub struct RecvMsg<'a> { - // The number of bytes received. - pub bytes: usize, - cmsg_buffer: &'a [u8], + cmsghdr: Option<&'a cmsghdr>, pub address: Option<SockAddr>, pub flags: MsgFlags, + mhdr: msghdr, } impl<'a> RecvMsg<'a> { @@ -389,45 +360,37 @@ impl<'a> RecvMsg<'a> { /// msghdr. pub fn cmsgs(&self) -> CmsgIterator { CmsgIterator { - buf: self.cmsg_buffer, + cmsghdr: self.cmsghdr, + mhdr: &self.mhdr } } } -#[derive(Debug)] +#[allow(missing_debug_implementations)] // msghdr isn't Debug pub struct CmsgIterator<'a> { /// Control message buffer to decode from. Must adhere to cmsg alignment. - buf: &'a [u8], + cmsghdr: Option<&'a cmsghdr>, + mhdr: &'a msghdr } impl<'a> Iterator for CmsgIterator<'a> { type Item = ControlMessage<'a>; - // The implementation loosely follows CMSG_FIRSTHDR / CMSG_NXTHDR, - // although we handle the invariants in slightly different places to - // get a better iterator interface. fn next(&mut self) -> Option<ControlMessage<'a>> { - if self.buf.len() == 0 { - // The iterator assumes that `self.buf` always contains exactly the - // bytes we need, so we're at the end when the buffer is empty. - return None; - } - - // Safe if: `self.buf` is `cmsghdr`-aligned. - let cmsg: &'a cmsghdr = unsafe { - &*(self.buf[..mem::size_of::<cmsghdr>()].as_ptr() as *const cmsghdr) - }; - - let cmsg_len = cmsg.cmsg_len as usize; - - // Advance our internal pointer. - let cmsg_data = &self.buf[cmsg_align(mem::size_of::<cmsghdr>())..cmsg_len]; - self.buf = &self.buf[cmsg_align(cmsg_len)..]; - - // Safe if: `cmsg_data` contains the expected (amount of) content. This - // is verified by the kernel. - unsafe { - Some(ControlMessage::decode_from(cmsg, cmsg_data)) + match self.cmsghdr { + None => None, // No more messages + Some(hdr) => { + // Get the data. + // Safe if cmsghdr points to valid data returned by recvmsg(2) + let cm = unsafe { Some(ControlMessage::decode_from(hdr))}; + // Advance the internal pointer. Safe if mhdr and cmsghdr point + // to valid data returned by recvmsg(2) + self.cmsghdr = unsafe { + let p = CMSG_NXTHDR(self.mhdr as *const _, hdr as *const _); + p.as_ref() + }; + cm + } } } } @@ -562,33 +525,96 @@ pub enum ControlMessage<'a> { #[allow(missing_debug_implementations)] pub struct UnknownCmsg<'a>(&'a cmsghdr, &'a [u8]); -// Round `len` up to meet the platform's required alignment for -// `cmsghdr`s and trailing `cmsghdr` data. This should match the -// behaviour of CMSG_ALIGN from the Linux headers and do the correct -// thing on other platforms that don't usually provide CMSG_ALIGN. -#[inline] -fn cmsg_align(len: usize) -> usize { - let align_bytes = mem::size_of::<align_of_cmsg_data>() - 1; - (len + align_bytes) & !align_bytes -} - impl<'a> ControlMessage<'a> { /// The value of CMSG_SPACE on this message. + /// Safe because CMSG_SPACE is always safe fn space(&self) -> usize { - cmsg_align(self.len()) + unsafe{CMSG_SPACE(self.len() as libc::c_uint) as usize} } /// The value of CMSG_LEN on this message. + /// Safe because CMSG_LEN is always safe + #[cfg(any(target_os = "android", + all(target_os = "linux", not(target_env = "musl"))))] + fn cmsg_len(&self) -> usize { + unsafe{CMSG_LEN(self.len() as libc::c_uint) as usize} + } + + #[cfg(not(any(target_os = "android", + all(target_os = "linux", not(target_env = "musl")))))] + fn cmsg_len(&self) -> libc::c_uint { + unsafe{CMSG_LEN(self.len() as libc::c_uint)} + } + + /// Return a reference to the payload data as a byte pointer + fn data(&self) -> *const u8 { + match self { + &ControlMessage::ScmRights(fds) => { + fds as *const [RawFd] as *const u8 + }, + #[cfg(any(target_os = "android", target_os = "linux"))] + &ControlMessage::ScmCredentials(creds) => { + creds as *const libc::ucred as *const u8 + } + &ControlMessage::ScmTimestamp(t) => { + t as *const TimeVal as *const u8 + }, + #[cfg(any( + target_os = "android", + target_os = "ios", + target_os = "linux", + target_os = "macos" + ))] + &ControlMessage::Ipv4PacketInfo(pktinfo) => { + pktinfo as *const libc::in_pktinfo as *const u8 + }, + #[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "ios", + target_os = "linux", + target_os = "macos" + ))] + &ControlMessage::Ipv6PacketInfo(pktinfo) => { + pktinfo as *const libc::in6_pktinfo as *const u8 + }, + #[cfg(any( + target_os = "freebsd", + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + ))] + &ControlMessage::Ipv4RecvIf(dl) => { + dl as *const libc::sockaddr_dl as *const u8 + }, + #[cfg(any( + target_os = "freebsd", + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + ))] + &ControlMessage::Ipv4RecvDstAddr(in_addr) => { + in_addr as *const libc::in_addr as *const u8 + }, + &ControlMessage::Unknown(UnknownCmsg(_, bytes)) => { + bytes as *const _ as *const u8 + } + } + } + + /// The size of the payload, excluding its cmsghdr fn len(&self) -> usize { - cmsg_align(mem::size_of::<cmsghdr>()) + match *self { - ControlMessage::ScmRights(fds) => { + match self { + &ControlMessage::ScmRights(fds) => { mem::size_of_val(fds) }, #[cfg(any(target_os = "android", target_os = "linux"))] - ControlMessage::ScmCredentials(creds) => { + &ControlMessage::ScmCredentials(creds) => { mem::size_of_val(creds) } - ControlMessage::ScmTimestamp(t) => { + &ControlMessage::ScmTimestamp(t) => { mem::size_of_val(t) }, #[cfg(any( @@ -597,7 +623,7 @@ impl<'a> ControlMessage<'a> { target_os = "linux", target_os = "macos" ))] - ControlMessage::Ipv4PacketInfo(pktinfo) => { + &ControlMessage::Ipv4PacketInfo(pktinfo) => { mem::size_of_val(pktinfo) }, #[cfg(any( @@ -607,7 +633,7 @@ impl<'a> ControlMessage<'a> { target_os = "linux", target_os = "macos" ))] - ControlMessage::Ipv6PacketInfo(pktinfo) => { + &ControlMessage::Ipv6PacketInfo(pktinfo) => { mem::size_of_val(pktinfo) }, #[cfg(any( @@ -617,7 +643,7 @@ impl<'a> ControlMessage<'a> { target_os = "netbsd", target_os = "openbsd", ))] - ControlMessage::Ipv4RecvIf(dl) => { + &ControlMessage::Ipv4RecvIf(dl) => { mem::size_of_val(dl) }, #[cfg(any( @@ -627,10 +653,10 @@ impl<'a> ControlMessage<'a> { target_os = "netbsd", target_os = "openbsd", ))] - ControlMessage::Ipv4RecvDstAddr(inaddr) => { + &ControlMessage::Ipv4RecvDstAddr(inaddr) => { mem::size_of_val(inaddr) }, - ControlMessage::Unknown(UnknownCmsg(_, bytes)) => { + &ControlMessage::Unknown(UnknownCmsg(_, bytes)) => { mem::size_of_val(bytes) } } @@ -638,18 +664,18 @@ impl<'a> ControlMessage<'a> { /// Returns the value to put into the `cmsg_level` field of the header. fn cmsg_level(&self) -> libc::c_int { - match *self { - ControlMessage::ScmRights(_) => libc::SOL_SOCKET, + match self { + &ControlMessage::ScmRights(_) => libc::SOL_SOCKET, #[cfg(any(target_os = "android", target_os = "linux"))] - ControlMessage::ScmCredentials(_) => libc::SOL_SOCKET, - ControlMessage::ScmTimestamp(_) => libc::SOL_SOCKET, + &ControlMessage::ScmCredentials(_) => libc::SOL_SOCKET, + &ControlMessage::ScmTimestamp(_) => libc::SOL_SOCKET, #[cfg(any( target_os = "android", target_os = "ios", target_os = "linux", target_os = "macos" ))] - ControlMessage::Ipv4PacketInfo(_) => libc::IPPROTO_IP, + &ControlMessage::Ipv4PacketInfo(_) => libc::IPPROTO_IP, #[cfg(any( target_os = "android", target_os = "freebsd", @@ -657,7 +683,7 @@ impl<'a> ControlMessage<'a> { target_os = "linux", target_os = "macos" ))] - ControlMessage::Ipv6PacketInfo(_) => libc::IPPROTO_IPV6, + &ControlMessage::Ipv6PacketInfo(_) => libc::IPPROTO_IPV6, #[cfg(any( target_os = "freebsd", target_os = "ios", @@ -665,7 +691,7 @@ impl<'a> ControlMessage<'a> { target_os = "netbsd", target_os = "openbsd", ))] - ControlMessage::Ipv4RecvIf(_) => libc::IPPROTO_IP, + &ControlMessage::Ipv4RecvIf(_) => libc::IPPROTO_IP, #[cfg(any( target_os = "freebsd", target_os = "ios", @@ -673,25 +699,25 @@ impl<'a> ControlMessage<'a> { target_os = "netbsd", target_os = "openbsd", ))] - ControlMessage::Ipv4RecvDstAddr(_) => libc::IPPROTO_IP, - ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_level, + &ControlMessage::Ipv4RecvDstAddr(_) => libc::IPPROTO_IP, + &ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_level, } } /// Returns the value to put into the `cmsg_type` field of the header. fn cmsg_type(&self) -> libc::c_int { - match *self { - ControlMessage::ScmRights(_) => libc::SCM_RIGHTS, + match self { + &ControlMessage::ScmRights(_) => libc::SCM_RIGHTS, #[cfg(any(target_os = "android", target_os = "linux"))] - ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS, - ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP, + &ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS, + &ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP, #[cfg(any( target_os = "android", target_os = "ios", target_os = "linux", target_os = "macos" ))] - ControlMessage::Ipv4PacketInfo(_) => libc::IP_PKTINFO, + &ControlMessage::Ipv4PacketInfo(_) => libc::IP_PKTINFO, #[cfg(any( target_os = "android", target_os = "freebsd", @@ -699,7 +725,7 @@ impl<'a> ControlMessage<'a> { target_os = "linux", target_os = "macos" ))] - ControlMessage::Ipv6PacketInfo(_) => libc::IPV6_PKTINFO, + &ControlMessage::Ipv6PacketInfo(_) => libc::IPV6_PKTINFO, #[cfg(any( target_os = "freebsd", target_os = "ios", @@ -707,7 +733,7 @@ impl<'a> ControlMessage<'a> { target_os = "netbsd", target_os = "openbsd", ))] - ControlMessage::Ipv4RecvIf(_) => libc::IP_RECVIF, + &ControlMessage::Ipv4RecvIf(_) => libc::IP_RECVIF, #[cfg(any( target_os = "freebsd", target_os = "ios", @@ -715,93 +741,23 @@ impl<'a> ControlMessage<'a> { target_os = "netbsd", target_os = "openbsd", ))] - ControlMessage::Ipv4RecvDstAddr(_) => libc::IP_RECVDSTADDR, - ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type, + &ControlMessage::Ipv4RecvDstAddr(_) => libc::IP_RECVDSTADDR, + &ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type, } } - // Unsafe: start and end of buffer must be cmsg_align'd. Updates - // the provided slice; panics if the buffer is too small. - unsafe fn encode_into(&self, buf: &mut [u8]) { - let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self { - let &UnknownCmsg(orig_cmsg, bytes) = cmsg; - - let buf = copy_bytes(orig_cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - - mem::size_of_val(&orig_cmsg); - let buf = pad_bytes(padlen, buf); - - copy_bytes(bytes, buf) - } else { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: self.cmsg_level(), - cmsg_type: self.cmsg_type(), - ..mem::zeroed() // zero out platform-dependent padding fields - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - match *self { - ControlMessage::ScmRights(fds) => { - copy_bytes(fds, buf) - }, - #[cfg(any(target_os = "android", target_os = "linux"))] - ControlMessage::ScmCredentials(creds) => { - copy_bytes(creds, buf) - }, - ControlMessage::ScmTimestamp(t) => { - copy_bytes(t, buf) - }, - #[cfg(any( - target_os = "android", - target_os = "ios", - target_os = "linux", - target_os = "macos" - ))] - ControlMessage::Ipv4PacketInfo(pktinfo) => { - copy_bytes(pktinfo, buf) - }, - #[cfg(any( - target_os = "android", - target_os = "freebsd", - target_os = "ios", - target_os = "linux", - target_os = "macos" - ))] - ControlMessage::Ipv6PacketInfo(pktinfo) => { - copy_bytes(pktinfo, buf) - } - #[cfg(any( - target_os = "freebsd", - target_os = "ios", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd", - ))] - ControlMessage::Ipv4RecvIf(dl) => { - copy_bytes(dl, buf) - }, - #[cfg(any( - target_os = "freebsd", - target_os = "ios", - target_os = "macos", - target_os = "netbsd", - target_os = "openbsd", - ))] - ControlMessage::Ipv4RecvDstAddr(inaddr) => { - copy_bytes(inaddr, buf) - }, - ControlMessage::Unknown(_) => unreachable!(), - } - }; - - let padlen = self.space() - self.len(); - pad_bytes(padlen, final_buf); + // Unsafe: cmsg must point to a valid cmsghdr with enough space to + // encode self. + unsafe fn encode_into(&self, cmsg: *mut cmsghdr) { + (*cmsg).cmsg_level = self.cmsg_level(); + (*cmsg).cmsg_type = self.cmsg_type(); + (*cmsg).cmsg_len = self.cmsg_len(); + let data = self.data(); + ptr::copy_nonoverlapping( + data, + CMSG_DATA(cmsg), + self.len() + ); } /// Decodes a `ControlMessage` from raw bytes. @@ -810,22 +766,23 @@ impl<'a> ControlMessage<'a> { /// specified in the header. Normally, the kernel ensures that this is the /// case. "Correct" in this case includes correct length, alignment and /// actual content. - unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> { + unsafe fn decode_from(header: &'a cmsghdr) -> ControlMessage<'a> + { + let p = CMSG_DATA(header); + let len = header as *const _ as usize + header.cmsg_len as usize + - p as usize; match (header.cmsg_level, header.cmsg_type) { (libc::SOL_SOCKET, libc::SCM_RIGHTS) => { ControlMessage::ScmRights( - slice::from_raw_parts(data.as_ptr() as *const _, - data.len() / mem::size_of::<RawFd>())) + slice::from_raw_parts(p as *const RawFd, + len / mem::size_of::<RawFd>())) }, #[cfg(any(target_os = "android", target_os = "linux"))] (libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => { - ControlMessage::ScmCredentials( - &*(data.as_ptr() as *const _) - ) + ControlMessage::ScmCredentials(&*(p as *const _)) } (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => { - ControlMessage::ScmTimestamp( - &*(data.as_ptr() as *const _)) + ControlMessage::ScmTimestamp(&*(p as *const _)) }, #[cfg(any( target_os = "android", @@ -835,8 +792,7 @@ impl<'a> ControlMessage<'a> { target_os = "macos" ))] (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { - ControlMessage::Ipv6PacketInfo( - &*(data.as_ptr() as *const _)) + ControlMessage::Ipv6PacketInfo(&*(p as *const _)) } #[cfg(any( target_os = "android", @@ -845,8 +801,7 @@ impl<'a> ControlMessage<'a> { target_os = "macos" ))] (libc::IPPROTO_IP, libc::IP_PKTINFO) => { - ControlMessage::Ipv4PacketInfo( - &*(data.as_ptr() as *const _)) + ControlMessage::Ipv4PacketInfo(&*(p as *const _)) } #[cfg(any( target_os = "freebsd", @@ -856,8 +811,7 @@ impl<'a> ControlMessage<'a> { target_os = "openbsd", ))] (libc::IPPROTO_IP, libc::IP_RECVIF) => { - ControlMessage::Ipv4RecvIf( - &*(data.as_ptr() as *const _)) + ControlMessage::Ipv4RecvIf(&*(p as *const _)) } #[cfg(any( target_os = "freebsd", @@ -867,11 +821,11 @@ impl<'a> ControlMessage<'a> { target_os = "openbsd", ))] (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { - ControlMessage::Ipv4RecvDstAddr( - &*(data.as_ptr() as *const _)) + ControlMessage::Ipv4RecvDstAddr(&*(p as *const _)) } (_, _) => { + let data = slice::from_raw_parts(p, len); ControlMessage::Unknown(UnknownCmsg(header, data)) } } @@ -884,51 +838,60 @@ impl<'a> ControlMessage<'a> { /// as with sendto. /// /// Allocates if cmsgs is nonempty. -pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<'a>], flags: MsgFlags, addr: Option<&'a SockAddr>) -> Result<usize> { - let mut capacity = 0; - for cmsg in cmsgs { - capacity += cmsg.space(); - } - // Note that the resulting vector claims to have length == capacity, - // so it's presently uninitialized. - let mut cmsg_buffer = unsafe { - let mut vec = Vec::<u8>::with_capacity(capacity); - vec.set_len(capacity); - vec - }; - { - let mut ofs = 0; - for cmsg in cmsgs { - let ptr = &mut cmsg_buffer[ofs..]; - unsafe { - cmsg.encode_into(ptr); - } - ofs += cmsg.space(); - } - } +pub fn sendmsg(fd: RawFd, iov: &[IoVec<&[u8]>], cmsgs: &[ControlMessage], + flags: MsgFlags, addr: Option<&SockAddr>) -> Result<usize> +{ + let capacity = cmsgs.iter().map(|c| c.space()).sum(); + + // First size the buffer needed to hold the cmsgs. It must be zeroed, + // because subsequent code will not clear the padding bytes. + let cmsg_buffer = vec![0u8; capacity]; + // Next encode the sending address, if provided let (name, namelen) = match addr { - Some(addr) => { let (x, y) = unsafe { addr.as_ffi_pair() }; (x as *const _, y) } + Some(addr) => { + let (x, y) = unsafe { addr.as_ffi_pair() }; + (x as *const _, y) + }, None => (ptr::null(), 0), }; + // The message header must be initialized before the individual cmsgs. let cmsg_ptr = if capacity > 0 { - cmsg_buffer.as_ptr() as *const c_void + cmsg_buffer.as_ptr() as *mut c_void } else { - ptr::null() + ptr::null_mut() }; - let mhdr = unsafe { - let mut mhdr: msghdr = mem::uninitialized(); - mhdr.msg_name = name as *mut _; - mhdr.msg_namelen = namelen; - mhdr.msg_iov = iov.as_ptr() as *mut _; - mhdr.msg_iovlen = iov.len() as _; - mhdr.msg_control = cmsg_ptr as *mut _; - mhdr.msg_controllen = capacity as _; - mhdr.msg_flags = 0; + let mhdr = { + // Musl's msghdr has private fields, so this is the only way to + // initialize it. + let mut mhdr: msghdr = unsafe{mem::uninitialized()}; + mhdr.msg_name = name as *mut _; + mhdr.msg_namelen = namelen; + // transmute iov into a mutable pointer. sendmsg doesn't really mutate + // the buffer, but the standard says that it takes a mutable pointer + mhdr.msg_iov = iov.as_ptr() as *mut _; + mhdr.msg_iovlen = iov.len() as _; + mhdr.msg_control = cmsg_ptr; + mhdr.msg_controllen = capacity as _; + mhdr.msg_flags = 0; mhdr }; + + // Encode each cmsg. This must happen after initializing the header because + // CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields. + // CMSG_FIRSTHDR is always safe + let mut pmhdr: *mut cmsghdr = unsafe{CMSG_FIRSTHDR(&mhdr as *const msghdr)}; + for cmsg in cmsgs { + assert_ne!(pmhdr, ptr::null_mut()); + // Safe because we know that pmhdr is valid, and we initialized it with + // sufficient space + unsafe { cmsg.encode_into(pmhdr) }; + // Safe because mhdr is valid + pmhdr = unsafe{CMSG_NXTHDR(&mhdr as *const msghdr, pmhdr)}; + } + let ret = unsafe { libc::sendmsg(fd, &mhdr, flags.bits()) }; Errno::result(ret).map(|r| r as usize) @@ -937,45 +900,48 @@ pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<' /// Receive message in scatter-gather vectors from a socket, and /// optionally receive ancillary data into the provided buffer. /// If no ancillary data is desired, use () as the type parameter. -pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], cmsg_buffer: Option<&'a mut CmsgSpace<T>>, flags: MsgFlags) -> Result<RecvMsg<'a>> { +pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], + cmsg_buffer: Option<&'a mut CmsgSpace<T>>, + flags: MsgFlags) -> Result<RecvMsg<'a>> +{ let mut address: sockaddr_storage = unsafe { mem::uninitialized() }; let (msg_control, msg_controllen) = match cmsg_buffer { Some(cmsg_buffer) => (cmsg_buffer as *mut _, mem::size_of_val(cmsg_buffer)), None => (ptr::null_mut(), 0), }; - let mut mhdr = unsafe { - let mut mhdr: msghdr = mem::uninitialized(); - mhdr.msg_name = &mut address as *mut _ as *mut _; - mhdr.msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t; - mhdr.msg_iov = iov.as_ptr() as *mut _; - mhdr.msg_iovlen = iov.len() as _; - mhdr.msg_control = msg_control as *mut _; - mhdr.msg_controllen = msg_controllen as _; - mhdr.msg_flags = 0; + let mut mhdr = { + // Musl's msghdr has private fields, so this is the only way to + // initialize it. + let mut mhdr: msghdr = unsafe{mem::uninitialized()}; + mhdr.msg_name = &mut address as *mut sockaddr_storage as *mut c_void; + mhdr.msg_namelen = mem::size_of::<sockaddr_storage>() as socklen_t; + mhdr.msg_iov = iov.as_ptr() as *mut iovec; + mhdr.msg_iovlen = iov.len() as _; + mhdr.msg_control = msg_control as *mut c_void; + mhdr.msg_controllen = msg_controllen as _; + mhdr.msg_flags = 0; mhdr }; - let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; - - let cmsg_buffer = if msg_controllen > 0 { - // got control message(s) - debug_assert!(!mhdr.msg_control.is_null()); - unsafe { - // Safe: The pointer is not null and the length is correct as part of `recvmsg`s - // contract. - slice::from_raw_parts(mhdr.msg_control as *const u8, - mhdr.msg_controllen as usize) - } - } else { - // No control message, create an empty buffer to avoid creating a slice from a null pointer - &[] + + let _ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; + + let cmsghdr = unsafe { + if mhdr.msg_controllen > 0 { + // got control message(s) + debug_assert!(!mhdr.msg_control.is_null()); + debug_assert!(msg_controllen >= mhdr.msg_controllen as usize); + CMSG_FIRSTHDR(&mhdr as *const msghdr) + } else { + ptr::null() + }.as_ref() }; Ok(unsafe { RecvMsg { - bytes: Errno::result(ret)? as usize, - cmsg_buffer, + cmsghdr, address: sockaddr_storage_to_addr(&address, mhdr.msg_namelen as usize).ok(), flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), + mhdr, } }) } |