summaryrefslogtreecommitdiff
path: root/src/sys/socket/mod.rs
diff options
context:
space:
mode:
authorJonas Schievink <jonasschievink@gmail.com>2018-07-06 02:23:22 +0200
committerJonas Schievink <jonasschievink@gmail.com>2018-07-27 19:50:37 +0200
commit9f0af4479742386c4ce30d05ad20e2450bbd0d54 (patch)
treec970e39408b21bc1f6994dddf2415267d1b41ad7 /src/sys/socket/mod.rs
parent237ec7bc13d045f21ae653c74bfd41fe411860f9 (diff)
downloadnix-9f0af4479742386c4ce30d05ad20e2450bbd0d54.zip
Fix *decoding* of cmsgs and add `ScmCredentials`.
Diffstat (limited to 'src/sys/socket/mod.rs')
-rw-r--r--src/sys/socket/mod.rs214
1 files changed, 124 insertions, 90 deletions
diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs
index ecef05d8..e9537c4b 100644
--- a/src/sys/socket/mod.rs
+++ b/src/sys/socket/mod.rs
@@ -205,6 +205,18 @@ cfg_if! {
}
impl Eq for UnixCredentials {}
+ impl From<libc::ucred> for UnixCredentials {
+ fn from(cred: libc::ucred) -> Self {
+ UnixCredentials(cred)
+ }
+ }
+
+ impl Into<libc::ucred> for UnixCredentials {
+ fn into(self) -> libc::ucred {
+ self.0
+ }
+ }
+
impl fmt::Debug for UnixCredentials {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("UnixCredentials")
@@ -359,7 +371,7 @@ impl<T> CmsgSpace<T> {
}
}
-#[allow(missing_debug_implementations)]
+#[derive(Debug)]
pub struct RecvMsg<'a> {
// The number of bytes received.
pub bytes: usize,
@@ -374,15 +386,14 @@ impl<'a> RecvMsg<'a> {
pub fn cmsgs(&self) -> CmsgIterator {
CmsgIterator {
buf: self.cmsg_buffer,
- next: 0
}
}
}
-#[allow(missing_debug_implementations)]
+#[derive(Debug)]
pub struct CmsgIterator<'a> {
+ /// Control message buffer to decode from. Must adhere to cmsg alignment.
buf: &'a [u8],
- next: usize,
}
impl<'a> Iterator for CmsgIterator<'a> {
@@ -392,53 +403,27 @@ impl<'a> Iterator for CmsgIterator<'a> {
// although we handle the invariants in slightly different places to
// get a better iterator interface.
fn next(&mut self) -> Option<ControlMessage<'a>> {
- let sizeof_cmsghdr = mem::size_of::<cmsghdr>();
- if self.buf.len() < sizeof_cmsghdr {
+ if self.buf.len() == 0 {
+ // The iterator assumes that `self.buf` always contains exactly the
+ // bytes we need, so we're at the end when the buffer is empty.
return None;
}
- let cmsg: &'a cmsghdr = unsafe { &*(self.buf.as_ptr() as *const cmsghdr) };
- // This check is only in the glibc implementation of CMSG_NXTHDR
- // (although it claims the kernel header checks this), but such
- // a structure is clearly invalid, either way.
- let cmsg_len = cmsg.cmsg_len as usize;
- if cmsg_len < sizeof_cmsghdr {
- return None;
- }
- let len = cmsg_len - sizeof_cmsghdr;
- let aligned_cmsg_len = if self.next == 0 {
- // CMSG_FIRSTHDR
- cmsg_len
- } else {
- // CMSG_NXTHDR
- cmsg_align(cmsg_len)
+ // Safe if: `self.buf` is `cmsghdr`-aligned.
+ let cmsg: &'a cmsghdr = unsafe {
+ &*(self.buf[..mem::size_of::<cmsghdr>()].as_ptr() as *const cmsghdr)
};
+ let cmsg_len = cmsg.cmsg_len as usize;
+
// Advance our internal pointer.
- if aligned_cmsg_len > self.buf.len() {
- return None;
- }
- let cmsg_data = &self.buf[cmsg_align(sizeof_cmsghdr)..cmsg_len];
- self.buf = &self.buf[aligned_cmsg_len..];
- self.next += 1;
-
- match (cmsg.cmsg_level, cmsg.cmsg_type) {
- (libc::SOL_SOCKET, libc::SCM_RIGHTS) => unsafe {
- Some(ControlMessage::ScmRights(
- slice::from_raw_parts(cmsg_data.as_ptr() as *const _,
- cmsg_data.len() / mem::size_of::<RawFd>())))
- },
- (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => unsafe {
- Some(ControlMessage::ScmTimestamp(
- &*(cmsg_data.as_ptr() as *const _)))
- },
- (_, _) => unsafe {
- Some(ControlMessage::Unknown(UnknownCmsg(
- cmsg,
- slice::from_raw_parts(
- cmsg_data.as_ptr() as *const _,
- len))))
- }
+ let cmsg_data = &self.buf[cmsg_align(mem::size_of::<cmsghdr>())..cmsg_len];
+ self.buf = &self.buf[cmsg_align(cmsg_len)..];
+
+ // Safe if: `cmsg_data` contains the expected (amount of) content. This
+ // is verified by the kernel.
+ unsafe {
+ Some(ControlMessage::decode_from(cmsg, cmsg_data))
}
}
}
@@ -459,6 +444,20 @@ pub enum ControlMessage<'a> {
/// or fail with `EINVAL`. Instead, you can put all fds to be passed into a single `ScmRights`
/// message.
ScmRights(&'a [RawFd]),
+ /// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of
+ /// a process connected to the socket.
+ ///
+ /// This is similar to the socket option `SO_PEERCRED`, but requires a
+ /// process to explicitly send its credentials. A process running as root is
+ /// allowed to specify any credentials, while credentials sent by other
+ /// processes are verified by the kernel.
+ ///
+ /// For further information, please refer to the
+ /// [`unix(7)`](http://man7.org/linux/man-pages/man7/unix.7.html) man page.
+ // FIXME: When `#[repr(transparent)]` is stable, use it on `UnixCredentials`
+ // and put that in here instead of a raw ucred.
+ #[cfg(any(target_os = "android", target_os = "linux"))]
+ ScmCredentials(&'a libc::ucred),
/// A message of type `SCM_TIMESTAMP`, containing the time the
/// packet was received by the kernel.
///
@@ -527,6 +526,7 @@ pub enum ControlMessage<'a> {
/// nix::unistd::close(in_socket).unwrap();
/// ```
ScmTimestamp(&'a TimeVal),
+ /// Catch-all variant for unimplemented cmsg types.
#[doc(hidden)]
Unknown(UnknownCmsg<'a>),
}
@@ -558,6 +558,10 @@ impl<'a> ControlMessage<'a> {
ControlMessage::ScmRights(fds) => {
mem::size_of_val(fds)
},
+ #[cfg(any(target_os = "android", target_os = "linux"))]
+ ControlMessage::ScmCredentials(creds) => {
+ mem::size_of_val(creds)
+ }
ControlMessage::ScmTimestamp(t) => {
mem::size_of_val(t)
},
@@ -567,57 +571,87 @@ impl<'a> ControlMessage<'a> {
}
}
+ /// Returns the value to put into the `cmsg_type` field of the header.
+ fn cmsg_type(&self) -> libc::c_int {
+ match *self {
+ ControlMessage::ScmRights(_) => libc::SCM_RIGHTS,
+ #[cfg(any(target_os = "android", target_os = "linux"))]
+ ControlMessage::ScmCredentials(_) => libc::SCM_CREDENTIALS,
+ ControlMessage::ScmTimestamp(_) => libc::SCM_TIMESTAMP,
+ ControlMessage::Unknown(ref cmsg) => cmsg.0.cmsg_type,
+ }
+ }
+
// Unsafe: start and end of buffer must be cmsg_align'd. Updates
// the provided slice; panics if the buffer is too small.
unsafe fn encode_into(&self, buf: &mut [u8]) {
- match *self {
- ControlMessage::ScmRights(fds) => {
- let cmsg = cmsghdr {
- cmsg_len: self.len() as _,
- cmsg_level: libc::SOL_SOCKET,
- cmsg_type: libc::SCM_RIGHTS,
- ..mem::uninitialized()
- };
- let buf = copy_bytes(&cmsg, buf);
-
- let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
- mem::size_of_val(&cmsg);
- let buf = pad_bytes(padlen, buf);
-
- let buf = copy_bytes(fds, buf);
-
- let padlen = self.space() - self.len();
- pad_bytes(padlen, buf);
- },
- ControlMessage::ScmTimestamp(t) => {
- let cmsg = cmsghdr {
- cmsg_len: self.len() as _,
- cmsg_level: libc::SOL_SOCKET,
- cmsg_type: libc::SCM_TIMESTAMP,
- ..mem::uninitialized()
- };
- let buf = copy_bytes(&cmsg, buf);
-
- let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
- mem::size_of_val(&cmsg);
- let buf = pad_bytes(padlen, buf);
-
- let buf = copy_bytes(t, buf);
-
- let padlen = self.space() - self.len();
- pad_bytes(padlen, buf);
- },
- ControlMessage::Unknown(UnknownCmsg(orig_cmsg, bytes)) => {
- let buf = copy_bytes(orig_cmsg, buf);
+ let final_buf = if let ControlMessage::Unknown(ref cmsg) = *self {
+ let &UnknownCmsg(orig_cmsg, bytes) = cmsg;
+
+ let buf = copy_bytes(orig_cmsg, buf);
- let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
- mem::size_of_val(&orig_cmsg);
- let buf = pad_bytes(padlen, buf);
+ let padlen = cmsg_align(mem::size_of_val(&orig_cmsg)) -
+ mem::size_of_val(&orig_cmsg);
+ let buf = pad_bytes(padlen, buf);
- let buf = copy_bytes(bytes, buf);
+ copy_bytes(bytes, buf)
+ } else {
+ let cmsg = cmsghdr {
+ cmsg_len: self.len() as _,
+ cmsg_level: libc::SOL_SOCKET,
+ cmsg_type: self.cmsg_type(),
+ ..mem::zeroed() // zero out platform-dependent padding fields
+ };
+ let buf = copy_bytes(&cmsg, buf);
+
+ let padlen = cmsg_align(mem::size_of_val(&cmsg)) -
+ mem::size_of_val(&cmsg);
+ let buf = pad_bytes(padlen, buf);
+
+ match *self {
+ ControlMessage::ScmRights(fds) => {
+ copy_bytes(fds, buf)
+ },
+ #[cfg(any(target_os = "android", target_os = "linux"))]
+ ControlMessage::ScmCredentials(creds) => {
+ copy_bytes(creds, buf)
+ }
+ ControlMessage::ScmTimestamp(t) => {
+ copy_bytes(t, buf)
+ },
+ ControlMessage::Unknown(_) => unreachable!(),
+ }
+ };
- let padlen = self.space() - self.len();
- pad_bytes(padlen, buf);
+ let padlen = self.space() - self.len();
+ pad_bytes(padlen, final_buf);
+ }
+
+ /// Decodes a `ControlMessage` from raw bytes.
+ ///
+ /// This is only safe to call if the data is correct for the message type
+ /// specified in the header. Normally, the kernel ensures that this is the
+ /// case. "Correct" in this case includes correct length, alignment and
+ /// actual content.
+ unsafe fn decode_from(header: &'a cmsghdr, data: &'a [u8]) -> ControlMessage<'a> {
+ match (header.cmsg_level, header.cmsg_type) {
+ (libc::SOL_SOCKET, libc::SCM_RIGHTS) => {
+ ControlMessage::ScmRights(
+ slice::from_raw_parts(data.as_ptr() as *const _,
+ data.len() / mem::size_of::<RawFd>()))
+ },
+ #[cfg(any(target_os = "android", target_os = "linux"))]
+ (libc::SOL_SOCKET, libc::SCM_CREDENTIALS) => {
+ ControlMessage::ScmCredentials(
+ &*(data.as_ptr() as *const _)
+ )
+ }
+ (libc::SOL_SOCKET, libc::SCM_TIMESTAMP) => {
+ ControlMessage::ScmTimestamp(
+ &*(data.as_ptr() as *const _))
+ },
+ (_, _) => {
+ ControlMessage::Unknown(UnknownCmsg(header, data))
}
}
}