diff options
Diffstat (limited to 'src/channel.rs')
-rw-r--r-- | src/channel.rs | 183 |
1 files changed, 124 insertions, 59 deletions
diff --git a/src/channel.rs b/src/channel.rs index 5adc43b..1169a65 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,5 @@ use libc::{c_char, c_int, c_uchar, c_uint, c_ulong, c_void, size_t}; +use parking_lot::{Mutex, MutexGuard}; use std::cmp; use std::io; use std::io::prelude::*; @@ -7,6 +8,17 @@ use std::sync::Arc; use {raw, Error, ExtendedData, PtyModes, SessionInner}; +struct ChannelInner { + unsafe_raw: *mut raw::LIBSSH2_CHANNEL, + sess: Arc<Mutex<SessionInner>>, + read_limit: Mutex<Option<u64>>, +} + +struct LockedChannel<'a> { + raw: *mut raw::LIBSSH2_CHANNEL, + sess: MutexGuard<'a, SessionInner>, +} + /// A channel represents a portion of an SSH connection on which data can be /// read and written. /// @@ -16,35 +28,57 @@ use {raw, Error, ExtendedData, PtyModes, SessionInner}; /// Whether or not I/O operations are blocking is mandated by the `blocking` /// flag on a channel's corresponding `Session`. pub struct Channel { - raw: *mut raw::LIBSSH2_CHANNEL, - sess: Arc<SessionInner>, - read_limit: Option<u64>, + channel_inner: Arc<ChannelInner>, } impl Channel { pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_CHANNEL, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { - raw, - sess: Arc::clone(sess), - read_limit: None, + channel_inner: Arc::new(ChannelInner { + unsafe_raw: raw, + sess: Arc::clone(sess), + read_limit: Mutex::new(None), + }), }) } } + + fn lock(&self) -> LockedChannel { + let sess = self.channel_inner.sess.lock(); + LockedChannel { + sess, + raw: self.channel_inner.unsafe_raw, + } + } } /// A channel can have a number of streams, each identified by an id, each of /// which implements the `Read` and `Write` traits. -pub struct Stream<'channel> { - channel: &'channel mut Channel, +pub struct Stream { + channel_inner: Arc<ChannelInner>, id: i32, } +struct LockedStream<'a> { + raw: *mut raw::LIBSSH2_CHANNEL, + sess: MutexGuard<'a, SessionInner>, + id: i32, + read_limit: MutexGuard<'a, Option<u64>>, +} + +impl<'a> LockedStream<'a> { + pub fn eof(&self) -> bool { + *self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 } + } +} + /// Data received from when a program exits with a signal. pub struct ExitSignal { /// The exit signal received, if the program did not exit cleanly. Does not @@ -84,9 +118,10 @@ impl Channel { /// Note that this does not make sense for all channel types and may be /// ignored by the server despite returning success. pub fn setenv(&mut self, var: &str, val: &str) -> Result<(), Error> { + let locked = self.lock(); unsafe { - self.sess.rc(raw::libssh2_channel_setenv_ex( - self.raw, + locked.sess.rc(raw::libssh2_channel_setenv_ex( + locked.raw, var.as_ptr() as *const _, var.len() as c_uint, val.as_ptr() as *const _, @@ -120,12 +155,13 @@ impl Channel { mode: Option<PtyModes>, dim: Option<(u32, u32, u32, u32)>, ) -> Result<(), Error> { + let locked = self.lock(); let mode = mode.map(PtyModes::finish); let mode = mode.as_ref().map(Vec::as_slice).unwrap_or(&[]); - self.sess.rc(unsafe { + locked.sess.rc(unsafe { let (width, height, width_px, height_px) = dim.unwrap_or((80, 24, 0, 0)); raw::libssh2_channel_request_pty_ex( - self.raw, + locked.raw, term.as_ptr() as *const _, term.len() as c_uint, mode.as_ptr() as *const _, @@ -148,11 +184,12 @@ impl Channel { width_px: Option<u32>, height_px: Option<u32>, ) -> Result<(), Error> { + let locked = self.lock(); let width_px = width_px.unwrap_or(0); let height_px = height_px.unwrap_or(0); - self.sess.rc(unsafe { + locked.sess.rc(unsafe { raw::libssh2_channel_request_pty_size_ex( - self.raw, + locked.raw, width as c_int, height as c_int, width_px as c_int, @@ -205,22 +242,23 @@ impl Channel { pub fn process_startup(&mut self, request: &str, message: Option<&str>) -> Result<(), Error> { let message_len = message.map(|s| s.len()).unwrap_or(0); let message = message.map(|s| s.as_ptr()).unwrap_or(0 as *const _); + let locked = self.lock(); unsafe { let rc = raw::libssh2_channel_process_startup( - self.raw, + locked.raw, request.as_ptr() as *const _, request.len() as c_uint, message as *const _, message_len as c_uint, ); - self.sess.rc(rc) + locked.sess.rc(rc) } } /// Get a handle to the stderr stream of this channel. /// /// The returned handle implements the `Read` and `Write` traits. - pub fn stderr<'a>(&'a mut self) -> Stream<'a> { + pub fn stderr(&self) -> Stream { self.stream(::EXTENDED_DATA_STDERR) } @@ -233,18 +271,19 @@ impl Channel { /// /// * FLUSH_EXTENDED_DATA - Flush all extended data substreams /// * FLUSH_ALL - Flush all substreams - pub fn stream<'a>(&'a mut self, stream_id: i32) -> Stream<'a> { + pub fn stream(&self, stream_id: i32) -> Stream { Stream { - channel: self, + channel_inner: Arc::clone(&self.channel_inner), id: stream_id, } } /// Change how extended data (such as stderr) is handled pub fn handle_extended_data(&mut self, mode: ExtendedData) -> Result<(), Error> { + let locked = self.lock(); unsafe { - let rc = raw::libssh2_channel_handle_extended_data2(self.raw, mode as c_int); - self.sess.rc(rc) + let rc = raw::libssh2_channel_handle_extended_data2(locked.raw, mode as c_int); + locked.sess.rc(rc) } } @@ -254,15 +293,17 @@ impl Channel { /// Note that the exit status may not be available if the remote end has not /// yet set its status to closed. pub fn exit_status(&self) -> Result<i32, Error> { + let locked = self.lock(); // Should really store existing error, call function, check for error // after and restore previous error if no new one...but the only error // condition right now is a NULL pointer check on self.raw, so let's // assume that's not the case. - Ok(unsafe { raw::libssh2_channel_get_exit_status(self.raw) }) + Ok(unsafe { raw::libssh2_channel_get_exit_status(locked.raw) }) } /// Get the remote exit signal. pub fn exit_signal(&self) -> Result<ExitSignal, Error> { + let locked = self.lock(); unsafe { let mut sig = 0 as *mut _; let mut siglen = 0; @@ -271,7 +312,7 @@ impl Channel { let mut lang = 0 as *mut _; let mut langlen = 0; let rc = raw::libssh2_channel_get_exit_signal( - self.raw, + locked.raw, &mut sig, &mut siglen, &mut msg, @@ -279,31 +320,32 @@ impl Channel { &mut lang, &mut langlen, ); - self.sess.rc(rc)?; + locked.sess.rc(rc)?; return Ok(ExitSignal { - exit_signal: convert(self, sig, siglen), - error_message: convert(self, msg, msglen), - lang_tag: convert(self, lang, langlen), + exit_signal: convert(&locked, sig, siglen), + error_message: convert(&locked, msg, msglen), + lang_tag: convert(&locked, lang, langlen), }); } - unsafe fn convert(chan: &Channel, ptr: *mut c_char, len: size_t) -> Option<String> { + unsafe fn convert(locked: &LockedChannel, ptr: *mut c_char, len: size_t) -> Option<String> { if ptr.is_null() { return None; } let slice = slice::from_raw_parts(ptr as *const u8, len as usize); let ret = slice.to_vec(); - raw::libssh2_free(chan.sess.raw, ptr as *mut c_void); + raw::libssh2_free(locked.sess.raw, ptr as *mut c_void); String::from_utf8(ret).ok() } } /// Check the status of the read window. pub fn read_window(&self) -> ReadWindow { + let locked = self.lock(); unsafe { let mut avail = 0; let mut init = 0; - let remaining = raw::libssh2_channel_window_read_ex(self.raw, &mut avail, &mut init); + let remaining = raw::libssh2_channel_window_read_ex(locked.raw, &mut avail, &mut init); ReadWindow { remaining: remaining as u32, available: avail as u32, @@ -314,9 +356,10 @@ impl Channel { /// Check the status of the write window. pub fn write_window(&self) -> WriteWindow { + let locked = self.lock(); unsafe { let mut init = 0; - let remaining = raw::libssh2_channel_window_write_ex(self.raw, &mut init); + let remaining = raw::libssh2_channel_window_write_ex(locked.raw, &mut init); WriteWindow { remaining: remaining as u32, window_size_initial: init as u32, @@ -332,24 +375,25 @@ impl Channel { /// This function returns the new size of the receive window (as understood /// by remote end) on success. pub fn adjust_receive_window(&mut self, adjust: u64, force: bool) -> Result<u64, Error> { + let locked = self.lock(); let mut ret = 0; let rc = unsafe { raw::libssh2_channel_receive_window_adjust2( - self.raw, + locked.raw, adjust as c_ulong, force as c_uchar, &mut ret, ) }; - self.sess.rc(rc)?; + locked.sess.rc(rc)?; Ok(ret as u64) } /// Artificially limit the number of bytes that will be read from this /// channel. Hack intended for use by scp_recv only. #[doc(hidden)] - pub fn limit_read(&mut self, limit: u64) { - self.read_limit = Some(limit); + pub(crate) fn limit_read(&mut self, limit: u64) { + *self.channel_inner.read_limit.lock() = Some(limit); } /// Check if the remote host has sent an EOF status for the channel. @@ -357,7 +401,9 @@ impl Channel { /// because the reading from the channel reads only the stdout stream. /// unread, buffered, stderr data will cause eof() to return false. pub fn eof(&self) -> bool { - self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 } + let locked = self.lock(); + *self.channel_inner.read_limit.lock() == Some(0) + || unsafe { raw::libssh2_channel_eof(locked.raw) != 0 } } /// Tell the remote host that no further data will be sent on the specified @@ -365,7 +411,8 @@ impl Channel { /// /// Processes typically interpret this as a closed stdin descriptor. pub fn send_eof(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_send_eof(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_send_eof(locked.raw)) } } /// Wait for the remote end to send EOF. @@ -374,7 +421,8 @@ impl Channel { /// You should call the eof() function after calling this to check the /// status of the channel. pub fn wait_eof(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_wait_eof(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_wait_eof(locked.raw)) } } /// Close an active data channel. @@ -387,7 +435,8 @@ impl Channel { /// To wait for the remote end to close its connection as well, follow this /// command with `wait_closed` pub fn close(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_close(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_close(locked.raw)) } } /// Enter a temporary blocking state until the remote host closes the named @@ -395,7 +444,8 @@ impl Channel { /// /// Typically sent after `close` in order to examine the exit status. pub fn wait_close(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_wait_closed(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_wait_closed(locked.raw)) } } } @@ -415,40 +465,53 @@ impl Read for Channel { } } -impl Drop for Channel { +impl Drop for ChannelInner { fn drop(&mut self) { unsafe { - let _ = raw::libssh2_channel_free(self.raw); + let _ = raw::libssh2_channel_free(self.unsafe_raw); + } + } +} + +impl Stream { + fn lock(&self) -> LockedStream { + let sess = self.channel_inner.sess.lock(); + LockedStream { + sess, + raw: self.channel_inner.unsafe_raw, + id: self.id, + read_limit: self.channel_inner.read_limit.lock(), } } } -impl<'channel> Read for Stream<'channel> { +impl Read for Stream { fn read(&mut self, data: &mut [u8]) -> io::Result<usize> { - if self.channel.eof() { + let mut locked = self.lock(); + if locked.eof() { return Ok(0); } - let data = match self.channel.read_limit { + let data = match locked.read_limit.as_mut() { Some(amt) => { let len = data.len(); - &mut data[..cmp::min(amt as usize, len)] + &mut data[..cmp::min(*amt as usize, len)] } None => data, }; let ret = unsafe { let rc = raw::libssh2_channel_read_ex( - self.channel.raw, - self.id as c_int, + locked.raw, + locked.id as c_int, data.as_mut_ptr() as *mut _, data.len() as size_t, ); - self.channel.sess.rc(rc as c_int).map(|()| rc as usize) + locked.sess.rc(rc as c_int).map(|()| rc as usize) }; match ret { Ok(n) => { - if let Some(ref mut amt) = self.channel.read_limit { - *amt -= n as u64; + if let Some(ref mut amt) = locked.read_limit.as_mut() { + **amt -= n as u64; } Ok(n) } @@ -457,24 +520,26 @@ impl<'channel> Read for Stream<'channel> { } } -impl<'channel> Write for Stream<'channel> { +impl Write for Stream { fn write(&mut self, data: &[u8]) -> io::Result<usize> { + let locked = self.lock(); unsafe { let rc = raw::libssh2_channel_write_ex( - self.channel.raw, - self.id as c_int, + locked.raw, + locked.id as c_int, data.as_ptr() as *mut _, data.len() as size_t, ); - self.channel.sess.rc(rc as c_int).map(|()| rc as usize) + locked.sess.rc(rc as c_int).map(|()| rc as usize) } .map_err(Into::into) } fn flush(&mut self) -> io::Result<()> { + let locked = self.lock(); unsafe { - let rc = raw::libssh2_channel_flush_ex(self.channel.raw, self.id as c_int); - self.channel.sess.rc(rc) + let rc = raw::libssh2_channel_flush_ex(locked.raw, locked.id as c_int); + locked.sess.rc(rc) } .map_err(Into::into) } |