summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSteve Lau <stevelauc@outlook.com>2022-12-11 14:04:13 +0800
committerSteve Lau <stevelauc@outlook.com>2022-12-11 14:04:13 +0800
commit47ecc9a1d0f26ef35f333d3f17304fa6e832d3f5 (patch)
treeebf809489b470257c06ad9704606f44a60af99d7
parent3d3e6b9fa0d182bc04866373ed0b077a5c77560b (diff)
downloadnix-47ecc9a1d0f26ef35f333d3f17304fa6e832d3f5.zip
feat: I/O safety for 'sys/poll'
-rw-r--r--src/poll.rs52
-rw-r--r--test/test_poll.rs22
2 files changed, 54 insertions, 20 deletions
diff --git a/src/poll.rs b/src/poll.rs
index 6f227fee..9181bf7f 100644
--- a/src/poll.rs
+++ b/src/poll.rs
@@ -1,5 +1,5 @@
//! Wait for events to trigger on specific file descriptors
-use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd};
use crate::errno::Errno;
use crate::Result;
@@ -14,20 +14,36 @@ use crate::Result;
/// retrieved by calling [`revents()`](#method.revents) on the `PollFd`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
-pub struct PollFd {
+pub struct PollFd<'fd> {
pollfd: libc::pollfd,
+ _fd: std::marker::PhantomData<BorrowedFd<'fd>>,
}
-impl PollFd {
+impl<'fd> PollFd<'fd> {
/// Creates a new `PollFd` specifying the events of interest
/// for a given file descriptor.
- pub const fn new(fd: RawFd, events: PollFlags) -> PollFd {
+ //
+ // Different from other I/O-safe interfaces, here, we have to take `AsFd`
+ // by reference to prevent the case where the `fd` is closed but it is
+ // still in use. For example:
+ //
+ // ```rust
+ // let (reader, _) = pipe().unwrap();
+ //
+ // // If `PollFd::new()` takes `AsFd` by value, then `reader` will be consumed,
+ // // but the file descriptor of `reader` will still be in use.
+ // let pollfd = PollFd::new(reader, flag);
+ //
+ // // Do something with `pollfd`, which uses the CLOSED fd.
+ // ```
+ pub fn new<Fd: AsFd>(fd: &'fd Fd, events: PollFlags) -> PollFd<'fd> {
PollFd {
pollfd: libc::pollfd {
- fd,
+ fd: fd.as_fd().as_raw_fd(),
events: events.bits(),
revents: PollFlags::empty().bits(),
},
+ _fd: std::marker::PhantomData,
}
}
@@ -68,9 +84,29 @@ impl PollFd {
}
}
-impl AsRawFd for PollFd {
- fn as_raw_fd(&self) -> RawFd {
- self.pollfd.fd
+impl<'fd> AsFd for PollFd<'fd> {
+ fn as_fd(&self) -> BorrowedFd<'_> {
+ // Safety:
+ //
+ // BorrowedFd::borrow_raw(RawFd) requires that the raw fd being passed
+ // must remain open for the duration of the returned BorrowedFd, this is
+ // guaranteed as the returned BorrowedFd has the lifetime parameter same
+ // as `self`:
+ // "fn as_fd<'self>(&'self self) -> BorrowedFd<'self>"
+ // which means that `self` (PollFd) is guaranteed to outlive the returned
+ // BorrowedFd. (Lifetime: PollFd > BorrowedFd)
+ //
+ // And the lifetime parameter of PollFd::new(fd, ...) ensures that `fd`
+ // (an owned file descriptor) must outlive the returned PollFd:
+ // "pub fn new<Fd: AsFd>(fd: &'fd Fd, events: PollFlags) -> PollFd<'fd>"
+ // (Lifetime: Owned fd > PollFd)
+ //
+ // With two above relationships, we can conclude that the `Owned file
+ // descriptor` will outlive the returned BorrowedFd,
+ // (Lifetime: Owned fd > BorrowedFd)
+ // i.e., the raw fd being passed will remain valid for the lifetime of
+ // the returned BorrowedFd.
+ unsafe { BorrowedFd::borrow_raw(self.pollfd.fd) }
}
}
diff --git a/test/test_poll.rs b/test/test_poll.rs
index 53964e26..045ccd3d 100644
--- a/test/test_poll.rs
+++ b/test/test_poll.rs
@@ -1,8 +1,9 @@
use nix::{
errno::Errno,
poll::{poll, PollFd, PollFlags},
- unistd::{pipe, write},
+ unistd::{close, pipe, write},
};
+use std::os::unix::io::{BorrowedFd, FromRawFd, OwnedFd};
macro_rules! loop_while_eintr {
($poll_expr: expr) => {
@@ -19,7 +20,8 @@ macro_rules! loop_while_eintr {
#[test]
fn test_poll() {
let (r, w) = pipe().unwrap();
- let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
+ let r = unsafe { OwnedFd::from_raw_fd(r) };
+ let mut fds = [PollFd::new(&r, PollFlags::POLLIN)];
// Poll an idle pipe. Should timeout
let nfds = loop_while_eintr!(poll(&mut fds, 100));
@@ -32,6 +34,7 @@ fn test_poll() {
let nfds = poll(&mut fds, 100).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
+ close(w).unwrap();
}
// ppoll(2) is the same as poll except for how it handles timeouts and signals.
@@ -51,7 +54,8 @@ fn test_ppoll() {
let timeout = TimeSpec::milliseconds(1);
let (r, w) = pipe().unwrap();
- let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
+ let r = unsafe { OwnedFd::from_raw_fd(r) };
+ let mut fds = [PollFd::new(&r, PollFlags::POLLIN)];
// Poll an idle pipe. Should timeout
let sigset = SigSet::empty();
@@ -65,19 +69,13 @@ fn test_ppoll() {
let nfds = ppoll(&mut fds, Some(timeout), None).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
-}
-
-#[test]
-fn test_pollfd_fd() {
- use std::os::unix::io::AsRawFd;
-
- let pfd = PollFd::new(0x1234, PollFlags::empty());
- assert_eq!(pfd.as_raw_fd(), 0x1234);
+ close(w).unwrap();
}
#[test]
fn test_pollfd_events() {
- let mut pfd = PollFd::new(-1, PollFlags::POLLIN);
+ let fd_zero = unsafe { BorrowedFd::borrow_raw(0) };
+ let mut pfd = PollFd::new(&fd_zero, PollFlags::POLLIN);
assert_eq!(pfd.events(), PollFlags::POLLIN);
pfd.set_events(PollFlags::POLLOUT);
assert_eq!(pfd.events(), PollFlags::POLLOUT);