diff options
author | Wez Furlong <wez@wezfurlong.org> | 2020-01-18 15:50:36 -0800 |
---|---|---|
committer | Wez Furlong <wez@wezfurlong.org> | 2020-01-18 15:50:36 -0800 |
commit | ebea1cce2a4c550ad11dbf10c20ab2379a37a0e2 (patch) | |
tree | 0281e763942c262b95bec85fd2b736d3cef8c342 /src/channel.rs | |
parent | 4011e518e6455bb4f0a2de0aaba3763b964d0e58 (diff) | |
download | ssh2-rs-ebea1cce2a4c550ad11dbf10c20ab2379a37a0e2.zip |
Make properly Send safe
In earlier iterations I accidentally removed Send from Session and then
later restored it in an unsafe way. This commit restructures the
bindings so that each of the objects holds a reference to the
appropriate thing to keep everything alive safely, without awkward
lifetimes to deal with.
The key to this is that the underlying Session is tracked by an
Arc<Mutex<>>, with the related objects ensuring that they lock this
before they call into the underlying API.
In order to make this work, I've had to adjust the API around iterating
both known hosts and agent identities: previously these would iterate
over internal references but with this shift there isn't a reasonable
way to make that safe. The strategy is instead to return a copy of the
host/identity data and then later look up the associated raw pointer
when needed. The purist in me feels that the copy feels slightly
wasteful, but the realist justifies this with the observation that the
cardinality of both known hosts and identities is typically small enough
that the cost of this is in the noise compared to actually doing the
crypto+network ops.
I've removed a couple of error code related helpers from some of
the objects: those were really internal APIs and were redundant
with methods exported by the Error type anyway.
Fixes: https://github.com/alexcrichton/ssh2-rs/issues/154
Refs: https://github.com/alexcrichton/ssh2-rs/issues/137
Diffstat (limited to 'src/channel.rs')
-rw-r--r-- | src/channel.rs | 183 |
1 files changed, 124 insertions, 59 deletions
diff --git a/src/channel.rs b/src/channel.rs index 5adc43b..1169a65 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,5 @@ use libc::{c_char, c_int, c_uchar, c_uint, c_ulong, c_void, size_t}; +use parking_lot::{Mutex, MutexGuard}; use std::cmp; use std::io; use std::io::prelude::*; @@ -7,6 +8,17 @@ use std::sync::Arc; use {raw, Error, ExtendedData, PtyModes, SessionInner}; +struct ChannelInner { + unsafe_raw: *mut raw::LIBSSH2_CHANNEL, + sess: Arc<Mutex<SessionInner>>, + read_limit: Mutex<Option<u64>>, +} + +struct LockedChannel<'a> { + raw: *mut raw::LIBSSH2_CHANNEL, + sess: MutexGuard<'a, SessionInner>, +} + /// A channel represents a portion of an SSH connection on which data can be /// read and written. /// @@ -16,35 +28,57 @@ use {raw, Error, ExtendedData, PtyModes, SessionInner}; /// Whether or not I/O operations are blocking is mandated by the `blocking` /// flag on a channel's corresponding `Session`. pub struct Channel { - raw: *mut raw::LIBSSH2_CHANNEL, - sess: Arc<SessionInner>, - read_limit: Option<u64>, + channel_inner: Arc<ChannelInner>, } impl Channel { pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_CHANNEL, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { - raw, - sess: Arc::clone(sess), - read_limit: None, + channel_inner: Arc::new(ChannelInner { + unsafe_raw: raw, + sess: Arc::clone(sess), + read_limit: Mutex::new(None), + }), }) } } + + fn lock(&self) -> LockedChannel { + let sess = self.channel_inner.sess.lock(); + LockedChannel { + sess, + raw: self.channel_inner.unsafe_raw, + } + } } /// A channel can have a number of streams, each identified by an id, each of /// which implements the `Read` and `Write` traits. -pub struct Stream<'channel> { - channel: &'channel mut Channel, +pub struct Stream { + channel_inner: Arc<ChannelInner>, id: i32, } +struct LockedStream<'a> { + raw: *mut raw::LIBSSH2_CHANNEL, + sess: MutexGuard<'a, SessionInner>, + id: i32, + read_limit: MutexGuard<'a, Option<u64>>, +} + +impl<'a> LockedStream<'a> { + pub fn eof(&self) -> bool { + *self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 } + } +} + /// Data received from when a program exits with a signal. pub struct ExitSignal { /// The exit signal received, if the program did not exit cleanly. Does not @@ -84,9 +118,10 @@ impl Channel { /// Note that this does not make sense for all channel types and may be /// ignored by the server despite returning success. pub fn setenv(&mut self, var: &str, val: &str) -> Result<(), Error> { + let locked = self.lock(); unsafe { - self.sess.rc(raw::libssh2_channel_setenv_ex( - self.raw, + locked.sess.rc(raw::libssh2_channel_setenv_ex( + locked.raw, var.as_ptr() as *const _, var.len() as c_uint, val.as_ptr() as *const _, @@ -120,12 +155,13 @@ impl Channel { mode: Option<PtyModes>, dim: Option<(u32, u32, u32, u32)>, ) -> Result<(), Error> { + let locked = self.lock(); let mode = mode.map(PtyModes::finish); let mode = mode.as_ref().map(Vec::as_slice).unwrap_or(&[]); - self.sess.rc(unsafe { + locked.sess.rc(unsafe { let (width, height, width_px, height_px) = dim.unwrap_or((80, 24, 0, 0)); raw::libssh2_channel_request_pty_ex( - self.raw, + locked.raw, term.as_ptr() as *const _, term.len() as c_uint, mode.as_ptr() as *const _, @@ -148,11 +184,12 @@ impl Channel { width_px: Option<u32>, height_px: Option<u32>, ) -> Result<(), Error> { + let locked = self.lock(); let width_px = width_px.unwrap_or(0); let height_px = height_px.unwrap_or(0); - self.sess.rc(unsafe { + locked.sess.rc(unsafe { raw::libssh2_channel_request_pty_size_ex( - self.raw, + locked.raw, width as c_int, height as c_int, width_px as c_int, @@ -205,22 +242,23 @@ impl Channel { pub fn process_startup(&mut self, request: &str, message: Option<&str>) -> Result<(), Error> { let message_len = message.map(|s| s.len()).unwrap_or(0); let message = message.map(|s| s.as_ptr()).unwrap_or(0 as *const _); + let locked = self.lock(); unsafe { let rc = raw::libssh2_channel_process_startup( - self.raw, + locked.raw, request.as_ptr() as *const _, request.len() as c_uint, message as *const _, message_len as c_uint, ); - self.sess.rc(rc) + locked.sess.rc(rc) } } /// Get a handle to the stderr stream of this channel. /// /// The returned handle implements the `Read` and `Write` traits. - pub fn stderr<'a>(&'a mut self) -> Stream<'a> { + pub fn stderr(&self) -> Stream { self.stream(::EXTENDED_DATA_STDERR) } @@ -233,18 +271,19 @@ impl Channel { /// /// * FLUSH_EXTENDED_DATA - Flush all extended data substreams /// * FLUSH_ALL - Flush all substreams - pub fn stream<'a>(&'a mut self, stream_id: i32) -> Stream<'a> { + pub fn stream(&self, stream_id: i32) -> Stream { Stream { - channel: self, + channel_inner: Arc::clone(&self.channel_inner), id: stream_id, } } /// Change how extended data (such as stderr) is handled pub fn handle_extended_data(&mut self, mode: ExtendedData) -> Result<(), Error> { + let locked = self.lock(); unsafe { - let rc = raw::libssh2_channel_handle_extended_data2(self.raw, mode as c_int); - self.sess.rc(rc) + let rc = raw::libssh2_channel_handle_extended_data2(locked.raw, mode as c_int); + locked.sess.rc(rc) } } @@ -254,15 +293,17 @@ impl Channel { /// Note that the exit status may not be available if the remote end has not /// yet set its status to closed. pub fn exit_status(&self) -> Result<i32, Error> { + let locked = self.lock(); // Should really store existing error, call function, check for error // after and restore previous error if no new one...but the only error // condition right now is a NULL pointer check on self.raw, so let's // assume that's not the case. - Ok(unsafe { raw::libssh2_channel_get_exit_status(self.raw) }) + Ok(unsafe { raw::libssh2_channel_get_exit_status(locked.raw) }) } /// Get the remote exit signal. pub fn exit_signal(&self) -> Result<ExitSignal, Error> { + let locked = self.lock(); unsafe { let mut sig = 0 as *mut _; let mut siglen = 0; @@ -271,7 +312,7 @@ impl Channel { let mut lang = 0 as *mut _; let mut langlen = 0; let rc = raw::libssh2_channel_get_exit_signal( - self.raw, + locked.raw, &mut sig, &mut siglen, &mut msg, @@ -279,31 +320,32 @@ impl Channel { &mut lang, &mut langlen, ); - self.sess.rc(rc)?; + locked.sess.rc(rc)?; return Ok(ExitSignal { - exit_signal: convert(self, sig, siglen), - error_message: convert(self, msg, msglen), - lang_tag: convert(self, lang, langlen), + exit_signal: convert(&locked, sig, siglen), + error_message: convert(&locked, msg, msglen), + lang_tag: convert(&locked, lang, langlen), }); } - unsafe fn convert(chan: &Channel, ptr: *mut c_char, len: size_t) -> Option<String> { + unsafe fn convert(locked: &LockedChannel, ptr: *mut c_char, len: size_t) -> Option<String> { if ptr.is_null() { return None; } let slice = slice::from_raw_parts(ptr as *const u8, len as usize); let ret = slice.to_vec(); - raw::libssh2_free(chan.sess.raw, ptr as *mut c_void); + raw::libssh2_free(locked.sess.raw, ptr as *mut c_void); String::from_utf8(ret).ok() } } /// Check the status of the read window. pub fn read_window(&self) -> ReadWindow { + let locked = self.lock(); unsafe { let mut avail = 0; let mut init = 0; - let remaining = raw::libssh2_channel_window_read_ex(self.raw, &mut avail, &mut init); + let remaining = raw::libssh2_channel_window_read_ex(locked.raw, &mut avail, &mut init); ReadWindow { remaining: remaining as u32, available: avail as u32, @@ -314,9 +356,10 @@ impl Channel { /// Check the status of the write window. pub fn write_window(&self) -> WriteWindow { + let locked = self.lock(); unsafe { let mut init = 0; - let remaining = raw::libssh2_channel_window_write_ex(self.raw, &mut init); + let remaining = raw::libssh2_channel_window_write_ex(locked.raw, &mut init); WriteWindow { remaining: remaining as u32, window_size_initial: init as u32, @@ -332,24 +375,25 @@ impl Channel { /// This function returns the new size of the receive window (as understood /// by remote end) on success. pub fn adjust_receive_window(&mut self, adjust: u64, force: bool) -> Result<u64, Error> { + let locked = self.lock(); let mut ret = 0; let rc = unsafe { raw::libssh2_channel_receive_window_adjust2( - self.raw, + locked.raw, adjust as c_ulong, force as c_uchar, &mut ret, ) }; - self.sess.rc(rc)?; + locked.sess.rc(rc)?; Ok(ret as u64) } /// Artificially limit the number of bytes that will be read from this /// channel. Hack intended for use by scp_recv only. #[doc(hidden)] - pub fn limit_read(&mut self, limit: u64) { - self.read_limit = Some(limit); + pub(crate) fn limit_read(&mut self, limit: u64) { + *self.channel_inner.read_limit.lock() = Some(limit); } /// Check if the remote host has sent an EOF status for the channel. @@ -357,7 +401,9 @@ impl Channel { /// because the reading from the channel reads only the stdout stream. /// unread, buffered, stderr data will cause eof() to return false. pub fn eof(&self) -> bool { - self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 } + let locked = self.lock(); + *self.channel_inner.read_limit.lock() == Some(0) + || unsafe { raw::libssh2_channel_eof(locked.raw) != 0 } } /// Tell the remote host that no further data will be sent on the specified @@ -365,7 +411,8 @@ impl Channel { /// /// Processes typically interpret this as a closed stdin descriptor. pub fn send_eof(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_send_eof(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_send_eof(locked.raw)) } } /// Wait for the remote end to send EOF. @@ -374,7 +421,8 @@ impl Channel { /// You should call the eof() function after calling this to check the /// status of the channel. pub fn wait_eof(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_wait_eof(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_wait_eof(locked.raw)) } } /// Close an active data channel. @@ -387,7 +435,8 @@ impl Channel { /// To wait for the remote end to close its connection as well, follow this /// command with `wait_closed` pub fn close(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_close(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_close(locked.raw)) } } /// Enter a temporary blocking state until the remote host closes the named @@ -395,7 +444,8 @@ impl Channel { /// /// Typically sent after `close` in order to examine the exit status. pub fn wait_close(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_wait_closed(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_wait_closed(locked.raw)) } } } @@ -415,40 +465,53 @@ impl Read for Channel { } } -impl Drop for Channel { +impl Drop for ChannelInner { fn drop(&mut self) { unsafe { - let _ = raw::libssh2_channel_free(self.raw); + let _ = raw::libssh2_channel_free(self.unsafe_raw); + } + } +} + +impl Stream { + fn lock(&self) -> LockedStream { + let sess = self.channel_inner.sess.lock(); + LockedStream { + sess, + raw: self.channel_inner.unsafe_raw, + id: self.id, + read_limit: self.channel_inner.read_limit.lock(), } } } -impl<'channel> Read for Stream<'channel> { +impl Read for Stream { fn read(&mut self, data: &mut [u8]) -> io::Result<usize> { - if self.channel.eof() { + let mut locked = self.lock(); + if locked.eof() { return Ok(0); } - let data = match self.channel.read_limit { + let data = match locked.read_limit.as_mut() { Some(amt) => { let len = data.len(); - &mut data[..cmp::min(amt as usize, len)] + &mut data[..cmp::min(*amt as usize, len)] } None => data, }; let ret = unsafe { let rc = raw::libssh2_channel_read_ex( - self.channel.raw, - self.id as c_int, + locked.raw, + locked.id as c_int, data.as_mut_ptr() as *mut _, data.len() as size_t, ); - self.channel.sess.rc(rc as c_int).map(|()| rc as usize) + locked.sess.rc(rc as c_int).map(|()| rc as usize) }; match ret { Ok(n) => { - if let Some(ref mut amt) = self.channel.read_limit { - *amt -= n as u64; + if let Some(ref mut amt) = locked.read_limit.as_mut() { + **amt -= n as u64; } Ok(n) } @@ -457,24 +520,26 @@ impl<'channel> Read for Stream<'channel> { } } -impl<'channel> Write for Stream<'channel> { +impl Write for Stream { fn write(&mut self, data: &[u8]) -> io::Result<usize> { + let locked = self.lock(); unsafe { let rc = raw::libssh2_channel_write_ex( - self.channel.raw, - self.id as c_int, + locked.raw, + locked.id as c_int, data.as_ptr() as *mut _, data.len() as size_t, ); - self.channel.sess.rc(rc as c_int).map(|()| rc as usize) + locked.sess.rc(rc as c_int).map(|()| rc as usize) } .map_err(Into::into) } fn flush(&mut self) -> io::Result<()> { + let locked = self.lock(); unsafe { - let rc = raw::libssh2_channel_flush_ex(self.channel.raw, self.id as c_int); - self.channel.sess.rc(rc) + let rc = raw::libssh2_channel_flush_ex(locked.raw, locked.id as c_int); + locked.sess.rc(rc) } .map_err(Into::into) } |