From 5d6dc265b313aa45868cedc972a4b18114972da9 Mon Sep 17 00:00:00 2001 From: alecmocatta Date: Mon, 19 Mar 2018 20:07:31 +0000 Subject: Clean up cmsg code and fix passing multiple cmsgs --- src/sys/socket/mod.rs | 113 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 79 insertions(+), 34 deletions(-) (limited to 'src/sys') diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index b46fa8b0..0706618a 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -287,18 +287,37 @@ impl fmt::Debug for Ipv6MembershipRequest { } } -/// Copy the in-memory representation of src into the byte slice dst, -/// updating the slice to point to the remainder of dst only. Unsafe -/// because it exposes all bytes in src, which may be UB if some of them -/// are uninitialized (including padding). -unsafe fn copy_bytes<'a, 'b, T: ?Sized>(src: &T, dst: &'a mut &'b mut [u8]) { +/// Copy the in-memory representation of `src` into the byte slice `dst`. +/// +/// Returns the remainder of `dst`. +/// +/// Panics when `dst` is too small for `src` (more precisely, panics if +/// `mem::size_of_val(src) >= dst.len()`). +/// +/// Unsafe because it transmutes `src` to raw bytes, which is only safe for some +/// types `T`. Refer to the [Rustonomicon] for details. +/// +/// [Rustonomicon]: https://doc.rust-lang.org/nomicon/transmutes.html +unsafe fn copy_bytes<'a, T: ?Sized>(src: &T, dst: &'a mut [u8]) -> &'a mut [u8] { let srclen = mem::size_of_val(src); - let mut tmpdst = &mut [][..]; - mem::swap(&mut tmpdst, dst); - let (target, mut remainder) = tmpdst.split_at_mut(srclen); - // Safe because the mutable borrow of dst guarantees that src does not alias it. - ptr::copy_nonoverlapping(src as *const T as *const u8, target.as_mut_ptr(), srclen); - mem::swap(dst, &mut remainder); + ptr::copy_nonoverlapping( + src as *const T as *const u8, + dst[..srclen].as_mut_ptr(), + srclen + ); + + &mut dst[srclen..] +} + +/// Fills `dst` with `len` zero bytes and returns the remainder of the slice. +/// +/// Panics when `len >= dst.len()`. +fn pad_bytes(len: usize, dst: &mut [u8]) -> &mut [u8] { + for pad in &mut dst[..len] { + *pad = 0; + } + + &mut dst[len..] } cfg_if! { @@ -434,6 +453,11 @@ pub enum ControlMessage<'a> { /// /// See the description in the "Ancillary messages" section of the /// [unix(7) man page](http://man7.org/linux/man-pages/man7/unix.7.html). + /// + /// Using multiple `ScmRights` messages for a single `sendmsg` call isn't recommended since it + /// causes platform-dependent behaviour: It might swallow all but the first `ScmRights` message + /// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights` + /// message. ScmRights(&'a [RawFd]), /// A message of type `SCM_TIMESTAMP`, containing the time the /// packet was received by the kernel. @@ -545,7 +569,7 @@ impl<'a> ControlMessage<'a> { // Unsafe: start and end of buffer must be cmsg_align'd. Updates // the provided slice; panics if the buffer is too small. - unsafe fn encode_into<'b>(&self, buf: &mut &'b mut [u8]) { + unsafe fn encode_into(&self, buf: &mut [u8]) { match *self { ControlMessage::ScmRights(fds) => { let cmsg = cmsghdr { @@ -554,17 +578,16 @@ impl<'a> ControlMessage<'a> { cmsg_type: libc::SCM_RIGHTS, ..mem::uninitialized() }; - copy_bytes(&cmsg, buf); + let buf = copy_bytes(&cmsg, buf); let padlen = cmsg_align(mem::size_of_val(&cmsg)) - mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); - let mut tmpbuf = &mut [][..]; - mem::swap(&mut tmpbuf, buf); - let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen); - mem::swap(buf, &mut remainder); + let buf = copy_bytes(fds, buf); - copy_bytes(fds, buf); + let padlen = self.space() - self.len(); + pad_bytes(padlen, buf); }, ControlMessage::ScmTimestamp(t) => { let cmsg = cmsghdr { @@ -573,21 +596,28 @@ impl<'a> ControlMessage<'a> { cmsg_type: libc::SCM_TIMESTAMP, ..mem::uninitialized() }; - copy_bytes(&cmsg, buf); + let buf = copy_bytes(&cmsg, buf); let padlen = cmsg_align(mem::size_of_val(&cmsg)) - mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); - let mut tmpbuf = &mut [][..]; - mem::swap(&mut tmpbuf, buf); - let (_padding, mut remainder) = tmpbuf.split_at_mut(padlen); - mem::swap(buf, &mut remainder); + let buf = copy_bytes(t, buf); - copy_bytes(t, buf); + let padlen = self.space() - self.len(); + pad_bytes(padlen, buf); }, ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => { - copy_bytes(orig_cmsg, buf); - copy_bytes(bytes, buf); + let buf = copy_bytes(orig_cmsg, buf); + + let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - + mem::size_of_val(&orig_cmsg); + let buf = pad_bytes(padlen, buf); + + let buf = copy_bytes(bytes, buf); + + let padlen = self.space() - self.len(); + pad_bytes(padlen, buf); } } } @@ -600,23 +630,25 @@ impl<'a> ControlMessage<'a> { /// /// Allocates if cmsgs is nonempty. pub fn sendmsg<'a>(fd: RawFd, iov: &[IoVec<&'a [u8]>], cmsgs: &[ControlMessage<'a>], flags: MsgFlags, addr: Option<&'a SockAddr>) -> Result { - let mut len = 0; let mut capacity = 0; for cmsg in cmsgs { - len += cmsg.len(); capacity += cmsg.space(); } // Note that the resulting vector claims to have length == capacity, // so it's presently uninitialized. let mut cmsg_buffer = unsafe { - let mut vec = Vec::::with_capacity(len); - vec.set_len(len); + let mut vec = Vec::::with_capacity(capacity); + vec.set_len(capacity); vec }; { - let mut ptr = &mut cmsg_buffer[..]; + let mut ofs = 0; for cmsg in cmsgs { - unsafe { cmsg.encode_into(&mut ptr) }; + let mut ptr = &mut cmsg_buffer[ofs..]; + unsafe { + cmsg.encode_into(ptr); + } + ofs += cmsg.space(); } } @@ -669,10 +701,23 @@ pub fn recvmsg<'a, T>(fd: RawFd, iov: &[IoVec<&mut [u8]>], cmsg_buffer: Option<& }; let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) }; + let cmsg_buffer = if msg_controllen > 0 { + // got control message(s) + debug_assert!(!mhdr.msg_control.is_null()); + unsafe { + // Safe: The pointer is not null and the length is correct as part of `recvmsg`s + // contract. + slice::from_raw_parts(mhdr.msg_control as *const u8, + mhdr.msg_controllen as usize) + } + } else { + // No control message, create an empty buffer to avoid creating a slice from a null pointer + &[] + }; + Ok(unsafe { RecvMsg { bytes: try!(Errno::result(ret)) as usize, - cmsg_buffer: slice::from_raw_parts(mhdr.msg_control as *const u8, - mhdr.msg_controllen as usize), + cmsg_buffer, address: sockaddr_storage_to_addr(&address, mhdr.msg_namelen as usize).ok(), flags: MsgFlags::from_bits_truncate(mhdr.msg_flags), -- cgit v1.2.3