diff options
author | Jonas Schievink <jonasschievink@gmail.com> | 2018-07-06 02:23:22 +0200 |
---|---|---|
committer | Jonas Schievink <jonasschievink@gmail.com> | 2018-07-27 19:50:37 +0200 |
commit | 9f0af4479742386c4ce30d05ad20e2450bbd0d54 (patch) | |
tree | c970e39408b21bc1f6994dddf2415267d1b41ad7 /src/sys/socket/mod.rs | |
parent | 237ec7bc13d045f21ae653c74bfd41fe411860f9 (diff) | |
download | nix-9f0af4479742386c4ce30d05ad20e2450bbd0d54.zip |
Fix *decoding* of cmsgs and add `ScmCredentials`.
Diffstat (limited to 'src/sys/socket/mod.rs')
-rw-r--r-- | src/sys/socket/mod.rs | 214 |
1 files changed, 124 insertions, 90 deletions
diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index ecef05d8..e9537c4b 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -205,6 +205,18 @@ cfg_if! { } impl Eq for UnixCredentials {} + impl From<libc::ucred> for UnixCredentials { + fn from(cred: libc::ucred) -> Self { + UnixCredentials(cred) + } + } + + impl Into<libc::ucred> for UnixCredentials { + fn into(self) -> libc::ucred { + self.0 + } + } + impl fmt::Debug for UnixCredentials { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("UnixCredentials") @@ -359,7 +371,7 @@ impl<T> CmsgSpace<T> { } } -#[allow(missing_debug_implementations)] +#[derive(Debug)] pub struct RecvMsg<'a> { // The number of bytes received. pub bytes: usize, @@ -374,15 +386,14 @@ impl<'a> RecvMsg<'a> { pub fn cmsgs(&self) -> CmsgIterator { CmsgIterator { buf: self.cmsg_buffer, - next: 0 } } } -#[allow(missing_debug_implementations)] +#[derive(Debug)] pub struct CmsgIterator<'a> { + /// Control message buffer to decode from. Must adhere to cmsg alignment. buf: &'a [u8], - next: usize, } impl<'a> Iterator for CmsgIterator<'a> { @@ -392,53 +403,27 @@ impl<'a> Iterator for CmsgIterator<'a> { // although we handle the invariants in slightly different places to // get a better iterator interface. fn next(&mut self) -> Option<ControlMessage<'a>> { - let sizeof_cmsghdr = mem::size_of::<cmsghdr>(); - if self.buf.len() < sizeof_cmsghdr { + if self.buf.len() == 0 { + // The iterator assumes that `self.buf` always contains exactly the + // bytes we need, so we're at the end when the buffer is empty. return None; } - let cmsg: &'a cmsghdr = unsafe { &*(self.buf.as_ptr() as *const cmsghdr) }; - // This check is only in the glibc implementation of CMSG_NXTHDR - // (although it claims the kernel header checks this), but such - // a structure is clearly invalid, either way. - let cmsg_len = cmsg.cmsg_len as usize; - if cmsg_len < sizeof_cmsghdr { - return None; - } - let len = cmsg_len - sizeof_cmsghdr; - let aligned_cmsg_len = if self.next == 0 { - // CMSG_FIRSTHDR - cmsg_len - } else { - // CMSG_NXTHDR - cmsg_align(cmsg_len) + // Safe if: `self.buf` is `cmsghdr`-aligned. + let cmsg: &'a cmsghdr = unsafe { + &*(self.buf[..mem::size_of::<cmsghdr>()].as_ptr() as *const cmsghdr) }; + let cmsg_len = cmsg.cmsg_len as usize; + // Advance our internal pointer. - if aligned_cmsg_len > self.buf.len() { - return None; - } - let cmsg_data = &self.buf[cmsg_align(sizeof_cmsghdr)..cmsg_len]; - self.buf = &self.buf[aligned_cmsg_len..]; - self.next += 1; - - match (cmsg.cmsg_level, cmsg.cmsg_type) { - (libc::SOL_SOCKET, libc::SCM_RIGHTS) => unsafe { - Some(ControlMessage::ScmRights( - slice::from_raw_parts(cmsg_data.as_ptr() as *const _, - cmsg_data.len() / mem::size_of::<RawFd>()))) - }, - (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => unsafe { - Some(ControlMessage::ScmTimestamp( - &*(cmsg_data.as_ptr() as *const _))) - }, - (_, _) => unsafe { - Some(ControlMessage::Unknown(UnknownCmsg( - cmsg, - slice::from_raw_parts( - cmsg_data.as_ptr() as *const _, - len)))) - } + let cmsg_data = &self.buf[cmsg_align(mem::size_of::<cmsghdr>())..cmsg_len]; + self.buf = &self.buf[cmsg_align(cmsg_len)..]; + + // Safe if: `cmsg_data` contains the expected (amount of) content. This + // is verified by the kernel. + unsafe { + Some(ControlMessage::decode_from(cmsg, cmsg_data)) } } } @@ -459,6 +444,20 @@ pub enum ControlMessage<'a> { /// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights` /// message. ScmRights(&'a [RawFd]), + /// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of + /// a process connected to the socket. + /// + /// This is similar to the socket option `SO_PEERCRED`, but requires a + /// process to explicitly send its credentials. A process running as root is + /// allowed to specify any credentials, while credentials sent by other + /// processes are verified by the kernel. + /// + /// For further information, please refer to the + /// [`unix(7)`](http://man7.org/linux/man-pages/man7/unix.7.html) man page. + // FIXME: When `#[repr(transparent)]` is stable, use it on `UnixCredentials` + // and put that in here instead of a raw ucred. + #[cfg(any(target_os = "android", target_os = "linux"))] + ScmCredentials(&'a libc::ucred), /// A message of type `SCM_TIMESTAMP`, containing the time the /// packet was received by the kernel. /// @@ -527,6 +526,7 @@ pub enum ControlMessage<'a> { /// nix::unistd::close(in_socket).unwrap(); /// ``` ScmTimestamp(&'a TimeVal), + /// Catch-all variant for unimplemented cmsg types. #[doc(hidden)] Unknown(UnknownCmsg<'a>), } @@ -558,6 +558,10 @@ impl<'a> ControlMessage<'a> { ControlMessage::ScmRights(fds) => { mem::size_of_val(fds) }, + #[cfg(any(target_os = "android", target_os = "linux"))] + ControlMessage::ScmCredentials(creds) => { + mem::size_of_val(creds) + } ControlMessage::ScmTimestamp(t) => { mem::size_of_val(t) }, @@ -567,57 +571,87 @@ impl<'a> ControlMessage<'a> { } } + /// Returns the value to put into the `cmsg_type` field of the header. + fn cmsg_type(&self) -> libc::c_int { + match *self { + ControlMessage::ScmRights(_) => libc::SCM_RIGHTS, + #[cfg(any(target_os = "android", target_os = "linux"))] + ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS, + ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP, + ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type, + } + } + // Unsafe: start and end of buffer must be cmsg_align'd. Updates // the provided slice; panics if the buffer is too small. unsafe fn encode_into(&self, buf: &mut [u8]) { - match *self { - ControlMessage::ScmRights(fds) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_RIGHTS, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(fds, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - }, - ControlMessage::ScmTimestamp(t) => { - let cmsg = cmsghdr { - cmsg_len: self.len() as _, - cmsg_level: libc::SOL_SOCKET, - cmsg_type: libc::SCM_TIMESTAMP, - ..mem::uninitialized() - }; - let buf = copy_bytes(&cmsg, buf); - - let padlen = cmsg_align(mem::size_of_val(&cmsg)) - - mem::size_of_val(&cmsg); - let buf = pad_bytes(padlen, buf); - - let buf = copy_bytes(t, buf); - - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); - }, - ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => { - let buf = copy_bytes(orig_cmsg, buf); + let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self { + let &UnknownCmsg(orig_cmsg, bytes) = cmsg; + + let buf = copy_bytes(orig_cmsg, buf); - let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - - mem::size_of_val(&orig_cmsg); - let buf = pad_bytes(padlen, buf); + let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) - + mem::size_of_val(&orig_cmsg); + let buf = pad_bytes(padlen, buf); - let buf = copy_bytes(bytes, buf); + copy_bytes(bytes, buf) + } else { + let cmsg = cmsghdr { + cmsg_len: self.len() as _, + cmsg_level: libc::SOL_SOCKET, + cmsg_type: self.cmsg_type(), + ..mem::zeroed() // zero out platform-dependent padding fields + }; + let buf = copy_bytes(&cmsg, buf); + + let padlen = cmsg_align(mem::size_of_val(&cmsg)) - + mem::size_of_val(&cmsg); + let buf = pad_bytes(padlen, buf); + + match *self { + ControlMessage::ScmRights(fds) => { + copy_bytes(fds, buf) + }, + #[cfg(any(target_os = "android", target_os = "linux"))] + ControlMessage::ScmCredentials(creds) => { + copy_bytes(creds, buf) + } + ControlMessage::ScmTimestamp(t) => { + copy_bytes(t, buf) + }, + ControlMessage::Unknown(_) => unreachable!(), + } + }; - let padlen = self.space() - self.len(); - pad_bytes(padlen, buf); + let padlen = self.space() - self.len(); + pad_bytes(padlen, final_buf); + } + + /// Decodes a `ControlMessage` from raw bytes. + /// + /// This is only safe to call if the data is correct for the message type + /// specified in the header. Normally, the kernel ensures that this is the + /// case. "Correct" in this case includes correct length, alignment and + /// actual content. + unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> { + match (header.cmsg_level, header.cmsg_type) { + (libc::SOL_SOCKET, libc::SCM_RIGHTS) => { + ControlMessage::ScmRights( + slice::from_raw_parts(data.as_ptr() as *const _, + data.len() / mem::size_of::<RawFd>())) + }, + #[cfg(any(target_os = "android", target_os = "linux"))] + (libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => { + ControlMessage::ScmCredentials( + &*(data.as_ptr() as *const _) + ) + } + (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => { + ControlMessage::ScmTimestamp( + &*(data.as_ptr() as *const _)) + }, + (_, _) => { + ControlMessage::Unknown(UnknownCmsg(header, data)) } } } |