diff options
-rw-r--r-- | Cargo.toml | 3 | ||||
-rw-r--r-- | src/agent.rs | 132 | ||||
-rw-r--r-- | src/channel.rs | 183 | ||||
-rw-r--r-- | src/error.rs | 12 | ||||
-rw-r--r-- | src/knownhosts.rs | 152 | ||||
-rw-r--r-- | src/lib.rs | 8 | ||||
-rw-r--r-- | src/listener.rs | 12 | ||||
-rw-r--r-- | src/session.rs | 219 | ||||
-rw-r--r-- | src/sftp.rs | 261 | ||||
-rw-r--r-- | tests/all/agent.rs | 5 | ||||
-rw-r--r-- | tests/all/knownhosts.rs | 22 | ||||
-rw-r--r-- | tests/all/main.rs | 3 | ||||
-rw-r--r-- | tests/all/session.rs | 2 |
13 files changed, 600 insertions, 414 deletions
@@ -1,6 +1,6 @@ [package] name = "ssh2" -version = "0.6.0" +version = "0.7.0" authors = ["Alex Crichton <alex@alexcrichton.com>", "Wez Furlong <wez@wezfurlong.org>"] license = "MIT/Apache-2.0" keywords = ["ssh"] @@ -20,6 +20,7 @@ vendored-openssl = ["libssh2-sys/vendored-openssl"] bitflags = "1.2" libc = "0.2" libssh2-sys = { path = "libssh2-sys", version = "0.2.13" } +parking_lot = "0.10" [dev-dependencies] tempdir = "0.3" diff --git a/src/agent.rs b/src/agent.rs index 5eb5bc4..dbafb1d 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,10 +1,9 @@ -use std::ffi::CString; -use std::marker; +use parking_lot::{Mutex, MutexGuard}; +use std::ffi::{CStr, CString}; use std::slice; use std::str; use std::sync::Arc; -use util::Binding; use {raw, Error, SessionInner}; /// A structure representing a connection to an SSH agent. @@ -12,28 +11,24 @@ use {raw, Error, SessionInner}; /// Agents can be used to authenticate a session. pub struct Agent { raw: *mut raw::LIBSSH2_AGENT, - sess: Arc<SessionInner>, -} - -/// An iterator over the identities found in an SSH agent. -pub struct Identities<'agent> { - prev: *mut raw::libssh2_agent_publickey, - agent: &'agent Agent, + sess: Arc<Mutex<SessionInner>>, } /// A public key which is extracted from an SSH agent. -pub struct PublicKey<'agent> { - raw: *mut raw::libssh2_agent_publickey, - _marker: marker::PhantomData<&'agent [u8]>, +#[derive(Debug, PartialEq, Eq)] +pub struct PublicKey { + blob: Vec<u8>, + comment: String, } impl Agent { pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_AGENT, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { raw, @@ -44,12 +39,14 @@ impl Agent { /// Connect to an ssh-agent running on the system. pub fn connect(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_agent_connect(self.raw)) } + let sess = self.sess.lock(); + unsafe { sess.rc(raw::libssh2_agent_connect(self.raw)) } } /// Close a connection to an ssh-agent. pub fn disconnect(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_agent_disconnect(self.raw)) } + let sess = self.sess.lock(); + unsafe { sess.rc(raw::libssh2_agent_disconnect(self.raw)) } } /// Request an ssh-agent to list of public keys, and stores them in the @@ -57,25 +54,64 @@ impl Agent { /// /// Call `identities` to get the public keys. pub fn list_identities(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_agent_list_identities(self.raw)) } + let sess = self.sess.lock(); + unsafe { sess.rc(raw::libssh2_agent_list_identities(self.raw)) } + } + + /// Get list of the identities of this agent. + pub fn identities(&self) -> Result<Vec<PublicKey>, Error> { + let sess = self.sess.lock(); + let mut res = vec![]; + let mut prev = 0 as *mut _; + let mut next = 0 as *mut _; + loop { + match unsafe { raw::libssh2_agent_get_identity(self.raw, &mut next, prev) } { + 0 => { + prev = next; + res.push(unsafe { PublicKey::from_raw(next) }); + } + 1 => break, + rc => return Err(Error::from_session_error_raw(sess.raw, rc)), + } + } + Ok(res) } - /// Get an iterator over the identities of this agent. - pub fn identities(&self) -> Identities { - Identities { - prev: 0 as *mut _, - agent: self, + fn resolve_raw_identity( + &self, + sess: &MutexGuard<SessionInner>, + identity: &PublicKey, + ) -> Result<Option<*mut raw::libssh2_agent_publickey>, Error> { + let mut prev = 0 as *mut _; + let mut next = 0 as *mut _; + loop { + match unsafe { raw::libssh2_agent_get_identity(self.raw, &mut next, prev) } { + 0 => { + prev = next; + let this_ident = unsafe { PublicKey::from_raw(next) }; + if this_ident == *identity { + return Ok(Some(next)); + } + } + 1 => break, + rc => return Err(Error::from_session_error_raw(sess.raw, rc)), + } } + Ok(None) } /// Attempt public key authentication with the help of ssh-agent. pub fn userauth(&self, username: &str, identity: &PublicKey) -> Result<(), Error> { let username = CString::new(username)?; + let sess = self.sess.lock(); + let raw_ident = self + .resolve_raw_identity(&sess, identity)? + .ok_or_else(|| Error::new(raw::LIBSSH2_ERROR_BAD_USE, "Identity not found in agent"))?; unsafe { - self.sess.rc(raw::libssh2_agent_userauth( + sess.rc(raw::libssh2_agent_userauth( self.raw, username.as_ptr(), - identity.raw, + raw_ident, )) } } @@ -87,46 +123,28 @@ impl Drop for Agent { } } -impl<'agent> Iterator for Identities<'agent> { - type Item = Result<PublicKey<'agent>, Error>; - fn next(&mut self) -> Option<Result<PublicKey<'agent>, Error>> { - unsafe { - let mut next = 0 as *mut _; - match raw::libssh2_agent_get_identity(self.agent.raw, &mut next, self.prev) { - 0 => { - self.prev = next; - Some(Ok(Binding::from_raw(next))) - } - 1 => None, - rc => Some(Err(self.agent.sess.rc(rc).err().unwrap())), - } +impl PublicKey { + unsafe fn from_raw(raw: *mut raw::libssh2_agent_publickey) -> Self { + let blob = slice::from_raw_parts_mut((*raw).blob, (*raw).blob_len as usize); + let comment = (*raw).comment; + let comment = if comment.is_null() { + String::new() + } else { + CStr::from_ptr(comment).to_string_lossy().into_owned() + }; + Self { + blob: blob.to_vec(), + comment, } } -} -impl<'agent> PublicKey<'agent> { /// Return the data of this public key. pub fn blob(&self) -> &[u8] { - unsafe { slice::from_raw_parts_mut((*self.raw).blob, (*self.raw).blob_len as usize) } + &self.blob } /// Returns the comment in a printable format pub fn comment(&self) -> &str { - unsafe { str::from_utf8(::opt_bytes(self, (*self.raw).comment).unwrap()).unwrap() } - } -} - -impl<'agent> Binding for PublicKey<'agent> { - type Raw = *mut raw::libssh2_agent_publickey; - - unsafe fn from_raw(raw: *mut raw::libssh2_agent_publickey) -> PublicKey<'agent> { - PublicKey { - raw: raw, - _marker: marker::PhantomData, - } - } - - fn raw(&self) -> *mut raw::libssh2_agent_publickey { - self.raw + &self.comment } } diff --git a/src/channel.rs b/src/channel.rs index 5adc43b..1169a65 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,5 @@ use libc::{c_char, c_int, c_uchar, c_uint, c_ulong, c_void, size_t}; +use parking_lot::{Mutex, MutexGuard}; use std::cmp; use std::io; use std::io::prelude::*; @@ -7,6 +8,17 @@ use std::sync::Arc; use {raw, Error, ExtendedData, PtyModes, SessionInner}; +struct ChannelInner { + unsafe_raw: *mut raw::LIBSSH2_CHANNEL, + sess: Arc<Mutex<SessionInner>>, + read_limit: Mutex<Option<u64>>, +} + +struct LockedChannel<'a> { + raw: *mut raw::LIBSSH2_CHANNEL, + sess: MutexGuard<'a, SessionInner>, +} + /// A channel represents a portion of an SSH connection on which data can be /// read and written. /// @@ -16,35 +28,57 @@ use {raw, Error, ExtendedData, PtyModes, SessionInner}; /// Whether or not I/O operations are blocking is mandated by the `blocking` /// flag on a channel's corresponding `Session`. pub struct Channel { - raw: *mut raw::LIBSSH2_CHANNEL, - sess: Arc<SessionInner>, - read_limit: Option<u64>, + channel_inner: Arc<ChannelInner>, } impl Channel { pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_CHANNEL, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { - raw, - sess: Arc::clone(sess), - read_limit: None, + channel_inner: Arc::new(ChannelInner { + unsafe_raw: raw, + sess: Arc::clone(sess), + read_limit: Mutex::new(None), + }), }) } } + + fn lock(&self) -> LockedChannel { + let sess = self.channel_inner.sess.lock(); + LockedChannel { + sess, + raw: self.channel_inner.unsafe_raw, + } + } } /// A channel can have a number of streams, each identified by an id, each of /// which implements the `Read` and `Write` traits. -pub struct Stream<'channel> { - channel: &'channel mut Channel, +pub struct Stream { + channel_inner: Arc<ChannelInner>, id: i32, } +struct LockedStream<'a> { + raw: *mut raw::LIBSSH2_CHANNEL, + sess: MutexGuard<'a, SessionInner>, + id: i32, + read_limit: MutexGuard<'a, Option<u64>>, +} + +impl<'a> LockedStream<'a> { + pub fn eof(&self) -> bool { + *self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 } + } +} + /// Data received from when a program exits with a signal. pub struct ExitSignal { /// The exit signal received, if the program did not exit cleanly. Does not @@ -84,9 +118,10 @@ impl Channel { /// Note that this does not make sense for all channel types and may be /// ignored by the server despite returning success. pub fn setenv(&mut self, var: &str, val: &str) -> Result<(), Error> { + let locked = self.lock(); unsafe { - self.sess.rc(raw::libssh2_channel_setenv_ex( - self.raw, + locked.sess.rc(raw::libssh2_channel_setenv_ex( + locked.raw, var.as_ptr() as *const _, var.len() as c_uint, val.as_ptr() as *const _, @@ -120,12 +155,13 @@ impl Channel { mode: Option<PtyModes>, dim: Option<(u32, u32, u32, u32)>, ) -> Result<(), Error> { + let locked = self.lock(); let mode = mode.map(PtyModes::finish); let mode = mode.as_ref().map(Vec::as_slice).unwrap_or(&[]); - self.sess.rc(unsafe { + locked.sess.rc(unsafe { let (width, height, width_px, height_px) = dim.unwrap_or((80, 24, 0, 0)); raw::libssh2_channel_request_pty_ex( - self.raw, + locked.raw, term.as_ptr() as *const _, term.len() as c_uint, mode.as_ptr() as *const _, @@ -148,11 +184,12 @@ impl Channel { width_px: Option<u32>, height_px: Option<u32>, ) -> Result<(), Error> { + let locked = self.lock(); let width_px = width_px.unwrap_or(0); let height_px = height_px.unwrap_or(0); - self.sess.rc(unsafe { + locked.sess.rc(unsafe { raw::libssh2_channel_request_pty_size_ex( - self.raw, + locked.raw, width as c_int, height as c_int, width_px as c_int, @@ -205,22 +242,23 @@ impl Channel { pub fn process_startup(&mut self, request: &str, message: Option<&str>) -> Result<(), Error> { let message_len = message.map(|s| s.len()).unwrap_or(0); let message = message.map(|s| s.as_ptr()).unwrap_or(0 as *const _); + let locked = self.lock(); unsafe { let rc = raw::libssh2_channel_process_startup( - self.raw, + locked.raw, request.as_ptr() as *const _, request.len() as c_uint, message as *const _, message_len as c_uint, ); - self.sess.rc(rc) + locked.sess.rc(rc) } } /// Get a handle to the stderr stream of this channel. /// /// The returned handle implements the `Read` and `Write` traits. - pub fn stderr<'a>(&'a mut self) -> Stream<'a> { + pub fn stderr(&self) -> Stream { self.stream(::EXTENDED_DATA_STDERR) } @@ -233,18 +271,19 @@ impl Channel { /// /// * FLUSH_EXTENDED_DATA - Flush all extended data substreams /// * FLUSH_ALL - Flush all substreams - pub fn stream<'a>(&'a mut self, stream_id: i32) -> Stream<'a> { + pub fn stream(&self, stream_id: i32) -> Stream { Stream { - channel: self, + channel_inner: Arc::clone(&self.channel_inner), id: stream_id, } } /// Change how extended data (such as stderr) is handled pub fn handle_extended_data(&mut self, mode: ExtendedData) -> Result<(), Error> { + let locked = self.lock(); unsafe { - let rc = raw::libssh2_channel_handle_extended_data2(self.raw, mode as c_int); - self.sess.rc(rc) + let rc = raw::libssh2_channel_handle_extended_data2(locked.raw, mode as c_int); + locked.sess.rc(rc) } } @@ -254,15 +293,17 @@ impl Channel { /// Note that the exit status may not be available if the remote end has not /// yet set its status to closed. pub fn exit_status(&self) -> Result<i32, Error> { + let locked = self.lock(); // Should really store existing error, call function, check for error // after and restore previous error if no new one...but the only error // condition right now is a NULL pointer check on self.raw, so let's // assume that's not the case. - Ok(unsafe { raw::libssh2_channel_get_exit_status(self.raw) }) + Ok(unsafe { raw::libssh2_channel_get_exit_status(locked.raw) }) } /// Get the remote exit signal. pub fn exit_signal(&self) -> Result<ExitSignal, Error> { + let locked = self.lock(); unsafe { let mut sig = 0 as *mut _; let mut siglen = 0; @@ -271,7 +312,7 @@ impl Channel { let mut lang = 0 as *mut _; let mut langlen = 0; let rc = raw::libssh2_channel_get_exit_signal( - self.raw, + locked.raw, &mut sig, &mut siglen, &mut msg, @@ -279,31 +320,32 @@ impl Channel { &mut lang, &mut langlen, ); - self.sess.rc(rc)?; + locked.sess.rc(rc)?; return Ok(ExitSignal { - exit_signal: convert(self, sig, siglen), - error_message: convert(self, msg, msglen), - lang_tag: convert(self, lang, langlen), + exit_signal: convert(&locked, sig, siglen), + error_message: convert(&locked, msg, msglen), + lang_tag: convert(&locked, lang, langlen), }); } - unsafe fn convert(chan: &Channel, ptr: *mut c_char, len: size_t) -> Option<String> { + unsafe fn convert(locked: &LockedChannel, ptr: *mut c_char, len: size_t) -> Option<String> { if ptr.is_null() { return None; } let slice = slice::from_raw_parts(ptr as *const u8, len as usize); let ret = slice.to_vec(); - raw::libssh2_free(chan.sess.raw, ptr as *mut c_void); + raw::libssh2_free(locked.sess.raw, ptr as *mut c_void); String::from_utf8(ret).ok() } } /// Check the status of the read window. pub fn read_window(&self) -> ReadWindow { + let locked = self.lock(); unsafe { let mut avail = 0; let mut init = 0; - let remaining = raw::libssh2_channel_window_read_ex(self.raw, &mut avail, &mut init); + let remaining = raw::libssh2_channel_window_read_ex(locked.raw, &mut avail, &mut init); ReadWindow { remaining: remaining as u32, available: avail as u32, @@ -314,9 +356,10 @@ impl Channel { /// Check the status of the write window. pub fn write_window(&self) -> WriteWindow { + let locked = self.lock(); unsafe { let mut init = 0; - let remaining = raw::libssh2_channel_window_write_ex(self.raw, &mut init); + let remaining = raw::libssh2_channel_window_write_ex(locked.raw, &mut init); WriteWindow { remaining: remaining as u32, window_size_initial: init as u32, @@ -332,24 +375,25 @@ impl Channel { /// This function returns the new size of the receive window (as understood /// by remote end) on success. pub fn adjust_receive_window(&mut self, adjust: u64, force: bool) -> Result<u64, Error> { + let locked = self.lock(); let mut ret = 0; let rc = unsafe { raw::libssh2_channel_receive_window_adjust2( - self.raw, + locked.raw, adjust as c_ulong, force as c_uchar, &mut ret, ) }; - self.sess.rc(rc)?; + locked.sess.rc(rc)?; Ok(ret as u64) } /// Artificially limit the number of bytes that will be read from this /// channel. Hack intended for use by scp_recv only. #[doc(hidden)] - pub fn limit_read(&mut self, limit: u64) { - self.read_limit = Some(limit); + pub(crate) fn limit_read(&mut self, limit: u64) { + *self.channel_inner.read_limit.lock() = Some(limit); } /// Check if the remote host has sent an EOF status for the channel. @@ -357,7 +401,9 @@ impl Channel { /// because the reading from the channel reads only the stdout stream. /// unread, buffered, stderr data will cause eof() to return false. pub fn eof(&self) -> bool { - self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 } + let locked = self.lock(); + *self.channel_inner.read_limit.lock() == Some(0) + || unsafe { raw::libssh2_channel_eof(locked.raw) != 0 } } /// Tell the remote host that no further data will be sent on the specified @@ -365,7 +411,8 @@ impl Channel { /// /// Processes typically interpret this as a closed stdin descriptor. pub fn send_eof(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_send_eof(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_send_eof(locked.raw)) } } /// Wait for the remote end to send EOF. @@ -374,7 +421,8 @@ impl Channel { /// You should call the eof() function after calling this to check the /// status of the channel. pub fn wait_eof(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_wait_eof(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_wait_eof(locked.raw)) } } /// Close an active data channel. @@ -387,7 +435,8 @@ impl Channel { /// To wait for the remote end to close its connection as well, follow this /// command with `wait_closed` pub fn close(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_close(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_close(locked.raw)) } } /// Enter a temporary blocking state until the remote host closes the named @@ -395,7 +444,8 @@ impl Channel { /// /// Typically sent after `close` in order to examine the exit status. pub fn wait_close(&mut self) -> Result<(), Error> { - unsafe { self.sess.rc(raw::libssh2_channel_wait_closed(self.raw)) } + let locked = self.lock(); + unsafe { locked.sess.rc(raw::libssh2_channel_wait_closed(locked.raw)) } } } @@ -415,40 +465,53 @@ impl Read for Channel { } } -impl Drop for Channel { +impl Drop for ChannelInner { fn drop(&mut self) { unsafe { - let _ = raw::libssh2_channel_free(self.raw); + let _ = raw::libssh2_channel_free(self.unsafe_raw); + } + } +} + +impl Stream { + fn lock(&self) -> LockedStream { + let sess = self.channel_inner.sess.lock(); + LockedStream { + sess, + raw: self.channel_inner.unsafe_raw, + id: self.id, + read_limit: self.channel_inner.read_limit.lock(), } } } -impl<'channel> Read for Stream<'channel> { +impl Read for Stream { fn read(&mut self, data: &mut [u8]) -> io::Result<usize> { - if self.channel.eof() { + let mut locked = self.lock(); + if locked.eof() { return Ok(0); } - let data = match self.channel.read_limit { + let data = match locked.read_limit.as_mut() { Some(amt) => { let len = data.len(); - &mut data[..cmp::min(amt as usize, len)] + &mut data[..cmp::min(*amt as usize, len)] } None => data, }; let ret = unsafe { let rc = raw::libssh2_channel_read_ex( - self.channel.raw, - self.id as c_int, + locked.raw, + locked.id as c_int, data.as_mut_ptr() as *mut _, data.len() as size_t, ); - self.channel.sess.rc(rc as c_int).map(|()| rc as usize) + locked.sess.rc(rc as c_int).map(|()| rc as usize) }; match ret { Ok(n) => { - if let Some(ref mut amt) = self.channel.read_limit { - *amt -= n as u64; + if let Some(ref mut amt) = locked.read_limit.as_mut() { + **amt -= n as u64; } Ok(n) } @@ -457,24 +520,26 @@ impl<'channel> Read for Stream<'channel> { } } -impl<'channel> Write for Stream<'channel> { +impl Write for Stream { fn write(&mut self, data: &[u8]) -> io::Result<usize> { + let locked = self.lock(); unsafe { let rc = raw::libssh2_channel_write_ex( - self.channel.raw, - self.id as c_int, + locked.raw, + locked.id as c_int, data.as_ptr() as *mut _, data.len() as size_t, ); - self.channel.sess.rc(rc as c_int).map(|()| rc as usize) + locked.sess.rc(rc as c_int).map(|()| rc as usize) } .map_err(Into::into) } fn flush(&mut self) -> io::Result<()> { + let locked = self.lock(); unsafe { - let rc = raw::libssh2_channel_flush_ex(self.channel.raw, self.id as c_int); - self.channel.sess.rc(rc) + let rc = raw::libssh2_channel_flush_ex(locked.raw, locked.id as c_int); + locked.sess.rc(rc) } .map_err(Into::into) } diff --git a/src/error.rs b/src/error.rs index 2499150..7ae0431 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,7 +38,7 @@ impl Error { /// If the error code doesn't match then an approximation of the error /// reason is used instead of the error message stored in the Session. pub fn from_session_error(sess: &Session, rc: libc::c_int) -> Error { - Self::from_session_error_raw(sess.raw(), rc) + Self::from_session_error_raw(&mut *sess.raw(), rc) } #[doc(hidden)] @@ -59,7 +59,7 @@ impl Error { /// /// Returns `None` if there was no last error. pub fn last_error(sess: &Session) -> Option<Error> { - Self::last_error_raw(sess.raw()) + Self::last_error_raw(&mut *sess.raw()) } /// Create a new error for the given code and message @@ -80,6 +80,14 @@ impl Error { Error::new(libc::c_int::min_value(), "no other error listed") } + pub(crate) fn rc(rc: libc::c_int) -> Result<(), Error> { + if rc == 0 { + Ok(()) + } else { + Err(Self::from_errno(rc)) + } + } + /// Construct an error from an error code from libssh2 pub fn from_errno(code: libc::c_int) -> Error { let msg = match code { diff --git a/src/knownhosts.rs b/src/knownhosts.rs index 5957369..2750fdd 100644 --- a/src/knownhosts.rs +++ b/src/knownhosts.rs @@ -1,11 +1,11 @@ use libc::{c_int, size_t}; +use parking_lot::{Mutex, MutexGuard}; use std::ffi::CString; -use std::marker; use std::path::Path; use std::str; use std::sync::Arc; -use util::{self, Binding}; +use util; use {raw, CheckResult, Error, KnownHostFileKind, SessionInner}; /// A set of known hosts which can be used to verify the identity of a remote @@ -46,28 +46,24 @@ use {raw, CheckResult, Error, KnownHostFileKind, SessionInner}; /// ``` pub struct KnownHosts { raw: *mut raw::LIBSSH2_KNOWNHOSTS, - sess: Arc<SessionInner>, -} - -/// Iterator over the hosts in a `KnownHosts` structure. -pub struct Hosts<'kh> { - prev: *mut raw::libssh2_knownhost, - hosts: &'kh KnownHosts, + sess: Arc<Mutex<SessionInner>>, } /// Structure representing a known host as part of a `KnownHosts` structure. -pub struct Host<'kh> { - raw: *mut raw::libssh2_knownhost, - _marker: marker::PhantomData<&'kh str>, +#[derive(Debug, PartialEq, Eq)] +pub struct Host { + name: Option<String>, + key: String, } impl KnownHosts { pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_KNOWNHOSTS, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { raw, @@ -80,16 +76,18 @@ impl KnownHosts { /// the collection of known hosts. pub fn read_file(&mut self, file: &Path, kind: KnownHostFileKind) -> Result<u32, Error> { let file = CString::new(util::path2bytes(file)?)?; + let sess = self.sess.lock(); let n = unsafe { raw::libssh2_knownhost_readfile(self.raw, file.as_ptr(), kind as c_int) }; if n < 0 { - self.sess.rc(n)? + sess.rc(n)? } Ok(n as u32) } /// Read a line as if it were from a known hosts file. pub fn read_str(&mut self, s: &str, kind: KnownHostFileKind) -> Result<(), Error> { - self.sess.rc(unsafe { + let sess = self.sess.lock(); + sess.rc(unsafe { raw::libssh2_knownhost_readline( self.raw, s.as_ptr() as *const _, @@ -103,20 +101,28 @@ impl KnownHosts { /// file format. pub fn write_file(&self, file: &Path, kind: KnownHostFileKind) -> Result<(), Error> { let file = CString::new(util::path2bytes(file)?)?; + let sess = self.sess.lock(); let n = unsafe { raw::libssh2_knownhost_writefile(self.raw, file.as_ptr(), kind as c_int) }; - self.sess.rc(n) + sess.rc(n) } /// Converts a single known host to a single line of output for storage, /// using the 'type' output format. pub fn write_string(&self, host: &Host, kind: KnownHostFileKind) -> Result<String, Error> { let mut v = Vec::with_capacity(128); + let sess = self.sess.lock(); + let raw_host = self.resolve_to_raw_host(&sess, host)?.ok_or_else(|| { + Error::new( + raw::LIBSSH2_ERROR_BAD_USE, + "Host is not in the set of known hosts", + ) + })?; loop { let mut outlen = 0; unsafe { let rc = raw::libssh2_knownhost_writeline( self.raw, - host.raw, + raw_host, v.as_mut_ptr() as *mut _, v.capacity() as size_t, &mut outlen, @@ -126,7 +132,7 @@ impl KnownHosts { // + 1 for the trailing zero v.reserve(outlen as usize + 1); } else { - self.sess.rc(rc)?; + sess.rc(rc)?; v.set_len(outlen as usize); break; } @@ -136,17 +142,66 @@ impl KnownHosts { } /// Create an iterator over all of the known hosts in this structure. - pub fn iter(&self) -> Hosts { - Hosts { - prev: 0 as *mut _, - hosts: self, + pub fn iter(&self) -> Result<Vec<Host>, Error> { + self.hosts() + } + + /// Retrieves the list of known hosts + pub fn hosts(&self) -> Result<Vec<Host>, Error> { + let mut next = 0 as *mut _; + let mut prev = 0 as *mut _; + let sess = self.sess.lock(); + let mut hosts = vec![]; + + loop { + match unsafe { raw::libssh2_knownhost_get(self.raw, &mut next, prev) } { + 0 => { + prev = next; + hosts.push(unsafe { Host::from_raw(next) }); + } + 1 => break, + rc => return Err(Error::from_session_error_raw(sess.raw, rc)), + } + } + + Ok(hosts) + } + + /// Given a Host object, find the matching raw node in the internal list. + /// The returned value is only valid while the session is locked. + fn resolve_to_raw_host( + &self, + sess: &MutexGuard<SessionInner>, + host: &Host, + ) -> Result<Option<*mut raw::libssh2_knownhost>, Error> { + let mut next = 0 as *mut _; + let mut prev = 0 as *mut _; + + loop { + match unsafe { raw::libssh2_knownhost_get(self.raw, &mut next, prev) } { + 0 => { + prev = next; + let current = unsafe { Host::from_raw(next) }; + if current == *host { + return Ok(Some(next)); + } + } + 1 => break, + rc => return Err(Error::from_session_error_raw(sess.raw, rc)), + } } + Ok(None) } /// Delete a known host entry from the collection of known hosts. - pub fn remove(&self, host: Host) -> Result<(), Error> { - self.sess - .rc(unsafe { raw::libssh2_knownhost_del(self.raw, host.raw) }) + pub fn remove(&self, host: &Host) -> Result<(), Error> { + let sess = self.sess.lock(); + + if let Some(raw_host) = self.resolve_to_raw_host(&sess, host)? { + return sess.rc(unsafe { raw::libssh2_knownhost_del(self.raw, raw_host) }); + } else { + Ok(()) + } } /// Checks a host and its associated key against the collection of known @@ -205,6 +260,7 @@ impl KnownHosts { let host = CString::new(host)?; let flags = raw::LIBSSH2_KNOWNHOST_TYPE_PLAIN | raw::LIBSSH2_KNOWNHOST_KEYENC_RAW | (fmt as c_int); + let sess = self.sess.lock(); unsafe { let rc = raw::libssh2_knownhost_addc( self.raw, @@ -217,57 +273,33 @@ impl KnownHosts { flags, 0 as *mut _, ); - self.sess.rc(rc) + sess.rc(rc) } } } impl Drop for KnownHosts { fn drop(&mut self) { + let _sess = self.sess.lock(); unsafe { raw::libssh2_knownhost_free(self.raw) } } } -impl<'kh> Iterator for Hosts<'kh> { - type Item = Result<Host<'kh>, Error>; - fn next(&mut self) -> Option<Result<Host<'kh>, Error>> { - unsafe { - let mut next = 0 as *mut _; - match raw::libssh2_knownhost_get(self.hosts.raw, &mut next, self.prev) { - 0 => { - self.prev = next; - Some(Ok(Binding::from_raw(next))) - } - 1 => None, - rc => Some(Err(self.hosts.sess.rc(rc).err().unwrap())), - } - } - } -} - -impl<'kh> Host<'kh> { +impl Host { /// This is `None` if no plain text host name exists. pub fn name(&self) -> Option<&str> { - unsafe { ::opt_bytes(self, (*self.raw).name).and_then(|s| str::from_utf8(s).ok()) } + self.name.as_ref().map(String::as_str) } /// Returns the key in base64/printable format pub fn key(&self) -> &str { - let bytes = unsafe { ::opt_bytes(self, (*self.raw).key).unwrap() }; - str::from_utf8(bytes).unwrap() + &self.key } -} - -impl<'kh> Binding for Host<'kh> { - type Raw = *mut raw::libssh2_knownhost; - unsafe fn from_raw(raw: *mut raw::libssh2_knownhost) -> Host<'kh> { - Host { - raw: raw, - _marker: marker::PhantomData, - } - } - fn raw(&self) -> *mut raw::libssh2_knownhost { - self.raw + unsafe fn from_raw(raw: *mut raw::libssh2_knownhost) -> Self { + let name = ::opt_bytes(&raw, (*raw).name).and_then(|s| String::from_utf8(s.to_vec()).ok()); + let key = ::opt_bytes(&raw, (*raw).key).unwrap(); + let key = String::from_utf8(key.to_vec()).unwrap(); + Self { name, key } } } @@ -22,8 +22,7 @@ //! agent.connect().unwrap(); //! agent.list_identities().unwrap(); //! -//! for identity in agent.identities() { -//! let identity = identity.unwrap(); // assume no I/O errors +//! for identity in agent.identities().unwrap() { //! println!("{}", identity.comment()); //! let pubkey = identity.blob(); //! } @@ -138,13 +137,14 @@ extern crate libc; extern crate libssh2_sys as raw; #[macro_use] extern crate bitflags; +extern crate parking_lot; use std::ffi::CStr; -pub use agent::{Agent, Identities, PublicKey}; +pub use agent::{Agent, PublicKey}; pub use channel::{Channel, ExitSignal, ReadWindow, Stream, WriteWindow}; pub use error::Error; -pub use knownhosts::{Host, Hosts, KnownHosts}; +pub use knownhosts::{Host, KnownHosts}; pub use listener::Listener; use session::SessionInner; pub use session::{BlockDirections, KeyboardInteractivePrompt, Prompt, ScpFileStat, Session}; diff --git a/src/listener.rs b/src/listener.rs index afd29fd..a1a4c8e 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -1,3 +1,4 @@ +use parking_lot::Mutex; use std::sync::Arc; use {raw, Channel, Error, SessionInner}; @@ -7,24 +8,27 @@ use {raw, Channel, Error, SessionInner}; /// the remote server's port. pub struct Listener { raw: *mut raw::LIBSSH2_LISTENER, - sess: Arc<SessionInner>, + sess: Arc<Mutex<SessionInner>>, } impl Listener { /// Accept a queued connection from this listener. pub fn accept(&mut self) -> Result<Channel, Error> { + let sess = self.sess.lock(); unsafe { let chan = raw::libssh2_channel_forward_accept(self.raw); - Channel::from_raw_opt(chan, &self.sess) + let err = sess.last_error(); + Channel::from_raw_opt(chan, err, &self.sess) } } pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_LISTENER, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { raw, diff --git a/src/session.rs b/src/session.rs index efb4b25..90a0333 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,8 +1,8 @@ #[cfg(unix)] use libc::size_t; use libc::{self, c_char, c_int, c_long, c_uint, c_void}; +use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; use std::borrow::Cow; -use std::cell::{Ref, RefCell}; use std::ffi::CString; use std::mem; use std::net::TcpStream; @@ -65,14 +65,9 @@ unsafe fn with_abstract<R, F: FnOnce() -> R>( pub(crate) struct SessionInner { pub(crate) raw: *mut raw::LIBSSH2_SESSION, - tcp: RefCell<Option<TcpStream>>, + tcp: Option<TcpStream>, } -// The compiler doesn't know that it is Send safe because of the raw -// pointer inside. We know that the way that it is used by libssh2 -// and this crate is Send safe. -unsafe impl Send for SessionInner {} - /// An SSH session, typically representing one TCP connection. /// /// All other structures are based on an SSH session and cannot outlive a @@ -80,14 +75,9 @@ unsafe impl Send for SessionInner {} /// (via the `set_tcp_stream` method). #[derive(Clone)] pub struct Session { - inner: Arc<SessionInner>, + inner: Arc<Mutex<SessionInner>>, } -// The compiler doesn't know that it is Send safe because of the raw -// pointer inside. We know that the way that it is used by libssh2 -// and this crate is Send safe. -unsafe impl Send for Session {} - /// Metadata returned about a remote file when received via `scp`. pub struct ScpFileStat { stat: libc::stat, @@ -123,18 +113,19 @@ impl Session { Err(Error::unknown()) } else { Ok(Session { - inner: Arc::new(SessionInner { + inner: Arc::new(Mutex::new(SessionInner { raw: ret, - tcp: RefCell::new(None), - }), + tcp: None, + })), }) } } } #[doc(hidden)] - pub fn raw(&self) -> *mut raw::LIBSSH2_SESSION { - self.inner.raw + pub fn raw(&self) -> MappedMutexGuard<raw::LIBSSH2_SESSION> { + let inner = self.inner(); + MutexGuard::map(inner, |inner| unsafe { &mut *inner.raw }) } /// Set the SSH protocol banner for the local client @@ -145,12 +136,8 @@ impl Session { /// default. pub fn set_banner(&self, banner: &str) -> Result<(), Error> { let banner = CString::new(banner)?; - unsafe { - self.rc(raw::libssh2_session_banner_set( - self.inner.raw, - banner.as_ptr(), - )) - } + let inner = self.inner(); + unsafe { inner.rc(raw::libssh2_session_banner_set(inner.raw, banner.as_ptr())) } } /// Flag indicating whether SIGPIPE signals will be allowed or blocked. @@ -160,9 +147,10 @@ impl Session { /// the library to not attempt to block SIGPIPE from the underlying socket /// layer. pub fn set_allow_sigpipe(&self, block: bool) { + let inner = self.inner(); let res = unsafe { - self.rc(raw::libssh2_session_flag( - self.inner.raw, + inner.rc(raw::libssh2_session_flag( + inner.raw, raw::LIBSSH2_FLAG_SIGPIPE as c_int, block as c_int, )) @@ -177,9 +165,10 @@ impl Session { /// try to negotiate compression enabling for this connection. By default /// libssh2 will not attempt to use compression. pub fn set_compress(&self, compress: bool) { + let inner = self.inner(); let res = unsafe { - self.rc(raw::libssh2_session_flag( - self.inner.raw, + inner.rc(raw::libssh2_session_flag( + inner.raw, raw::LIBSSH2_FLAG_COMPRESS as c_int, compress as c_int, )) @@ -197,12 +186,12 @@ impl Session { /// a blocking session will wait for room. A non-blocking session will /// return immediately without writing anything. pub fn set_blocking(&self, blocking: bool) { - self.inner.set_blocking(blocking); + self.inner().set_blocking(blocking); } /// Returns whether the session was previously set to nonblocking. pub fn is_blocking(&self) -> bool { - self.inner.is_blocking() + self.inner().is_blocking() } /// Set timeout for blocking functions. @@ -215,7 +204,8 @@ impl Session { /// for blocking functions. pub fn set_timeout(&self, timeout_ms: u32) { let timeout_ms = timeout_ms as c_long; - unsafe { raw::libssh2_session_set_timeout(self.inner.raw, timeout_ms) } + let inner = self.inner(); + unsafe { raw::libssh2_session_set_timeout(inner.raw, timeout_ms) } } /// Returns the timeout, in milliseconds, for how long blocking calls may @@ -223,7 +213,8 @@ impl Session { /// /// A timeout of 0 signifies no timeout. pub fn timeout(&self) -> u32 { - unsafe { raw::libssh2_session_get_timeout(self.inner.raw) as u32 } + let inner = self.inner(); + unsafe { raw::libssh2_session_get_timeout(inner.raw) as u32 } } /// Begin transport layer protocol negotiation with the connected host. @@ -243,17 +234,17 @@ impl Session { raw::libssh2_session_handshake(raw, stream.as_raw_fd()) } - unsafe { - let stream = self.inner.tcp.borrow(); + let inner = self.inner(); - let stream = stream.as_ref().ok_or_else(|| { + unsafe { + let stream = inner.tcp.as_ref().ok_or_else(|| { Error::new( raw::LIBSSH2_ERROR_BAD_SOCKET, "use set_tcp_stream() to associate with a TcpStream", ) })?; - self.rc(handshake(self.inner.raw, stream)) + inner.rc(handshake(inner.raw, stream)) } } @@ -265,13 +256,19 @@ impl Session { /// concurrently elsewhere for the duration of this session as it may /// interfere with the protocol. pub fn set_tcp_stream(&mut self, stream: TcpStream) { - *self.inner.tcp.borrow_mut() = Some(stream); + let mut inner = self.inner(); + let _ = inner.tcp.replace(stream); } /// Returns a reference to the stream that was associated with the Session /// by the Session::handshake method. - pub fn tcp_stream(&self) -> Ref<Option<TcpStream>> { - self.inner.tcp.borrow() + pub fn tcp_stream(&self) -> Option<MappedMutexGuard<TcpStream>> { + let inner = self.inner(); + if inner.tcp.is_some() { + Some(MutexGuard::map(inner, |inner| inner.tcp.as_mut().unwrap())) + } else { + None + } } /// Attempt basic password authentication. @@ -281,9 +278,10 @@ impl Session { /// authentication (routed via PAM or another authentication backed) /// instead. pub fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> { - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_password_ex( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, password.as_ptr() as *const _, @@ -394,10 +392,11 @@ impl Session { })); } + let inner = self.inner(); unsafe { - with_abstract(self.inner.raw, prompter as *mut P as *mut c_void, || { - self.rc(raw::libssh2_userauth_keyboard_interactive_ex( - self.inner.raw, + with_abstract(inner.raw, prompter as *mut P as *mut c_void, || { + inner.rc(raw::libssh2_userauth_keyboard_interactive_ex( + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, Some(prompt::<P>), @@ -416,8 +415,9 @@ impl Session { let mut agent = self.agent()?; agent.connect()?; agent.list_identities()?; - let identity = match agent.identities().next() { - Some(identity) => identity?, + let identities = agent.identities()?; + let identity = match identities.get(0) { + Some(identity) => identity, None => { return Err(Error::new( raw::LIBSSH2_ERROR_INVAL as c_int, @@ -446,9 +446,10 @@ impl Session { Some(s) => Some(CString::new(s)?), None => None, }; - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_publickey_fromfile_ex( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, pubkey.as_ref().map(|s| s.as_ptr()).unwrap_or(0 as *const _), @@ -484,9 +485,10 @@ impl Session { Some(s) => Some(CString::new(s)?), None => None, }; - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_publickey_frommemory( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as size_t, pubkeydata @@ -525,9 +527,10 @@ impl Session { Some(local) => local, None => username, }; - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_hostbased_fromfile_ex( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, publickey.as_ptr(), @@ -547,7 +550,8 @@ impl Session { /// Indicates whether or not the named session has been successfully /// authenticated. pub fn authenticated(&self) -> bool { - unsafe { raw::libssh2_userauth_authenticated(self.inner.raw) != 0 } + let inner = self.inner(); + unsafe { raw::libssh2_userauth_authenticated(inner.raw) != 0 } } /// Send a SSH_USERAUTH_NONE request to the remote host. @@ -564,10 +568,11 @@ impl Session { pub fn auth_methods(&self, username: &str) -> Result<&str, Error> { let len = username.len(); let username = CString::new(username)?; + let inner = self.inner(); unsafe { - let ret = raw::libssh2_userauth_list(self.inner.raw, username.as_ptr(), len as c_uint); + let ret = raw::libssh2_userauth_list(inner.raw, username.as_ptr(), len as c_uint); if ret.is_null() { - match Error::last_error(self) { + match inner.last_error() { Some(err) => Err(err), None => Ok(""), } @@ -586,9 +591,10 @@ impl Session { /// negotiation. pub fn method_pref(&self, method_type: MethodType, prefs: &str) -> Result<(), Error> { let prefs = CString::new(prefs)?; + let inner = self.inner(); unsafe { - self.rc(raw::libssh2_session_method_pref( - self.inner.raw, + inner.rc(raw::libssh2_session_method_pref( + inner.raw, method_type as c_int, prefs.as_ptr(), )) @@ -600,8 +606,9 @@ impl Session { /// Returns the actual method negotiated for a particular transport /// parameter. May return `None` if the session has not yet been started. pub fn methods(&self, method_type: MethodType) -> Option<&str> { + let inner = self.inner(); unsafe { - let ptr = raw::libssh2_session_methods(self.inner.raw, method_type as c_int); + let ptr = raw::libssh2_session_methods(inner.raw, method_type as c_int); ::opt_bytes(self, ptr).and_then(|s| str::from_utf8(s).ok()) } } @@ -611,18 +618,19 @@ impl Session { static STATIC: () = (); let method_type = method_type as c_int; let mut ret = Vec::new(); + let inner = self.inner(); unsafe { let mut ptr = 0 as *mut _; - let rc = raw::libssh2_session_supported_algs(self.inner.raw, method_type, &mut ptr); + let rc = raw::libssh2_session_supported_algs(inner.raw, method_type, &mut ptr); if rc <= 0 { - self.rc(rc)?; + inner.rc(rc)?; } for i in 0..(rc as isize) { let s = ::opt_bytes(&STATIC, *ptr.offset(i)).unwrap(); let s = str::from_utf8(s).unwrap(); ret.push(s); } - raw::libssh2_free(self.inner.raw, ptr as *mut c_void); + raw::libssh2_free(inner.raw, ptr as *mut c_void); } Ok(ret) } @@ -631,7 +639,12 @@ impl Session { /// /// The returned agent will still need to be connected manually before use. pub fn agent(&self) -> Result<Agent, Error> { - unsafe { Agent::from_raw_opt(raw::libssh2_agent_init(self.inner.raw), &self.inner) } + let inner = self.inner(); + unsafe { + let agent = raw::libssh2_agent_init(inner.raw); + let err = inner.last_error(); + Agent::from_raw_opt(agent, err, &self.inner) + } } /// Init a collection of known hosts for this session. @@ -639,9 +652,11 @@ impl Session { /// Returns the handle to an internal representation of a known host /// collection. pub fn known_hosts(&self) -> Result<KnownHosts, Error> { + let inner = self.inner(); unsafe { - let ptr = raw::libssh2_knownhost_init(self.inner.raw); - KnownHosts::from_raw_opt(ptr, &self.inner) + let ptr = raw::libssh2_knownhost_init(inner.raw); + let err = inner.last_error(); + KnownHosts::from_raw_opt(ptr, err, &self.inner) } } @@ -679,15 +694,17 @@ impl Session { let (shost, sport) = src.unwrap_or(("127.0.0.1", 22)); let host = CString::new(host)?; let shost = CString::new(shost)?; + let inner = self.inner(); unsafe { let ret = raw::libssh2_channel_direct_tcpip_ex( - self.inner.raw, + inner.raw, host.as_ptr(), port as c_int, shost.as_ptr(), sport as c_int, ); - Channel::from_raw_opt(ret, &self.inner) + let err = inner.last_error(); + Channel::from_raw_opt(ret, err, &self.inner) } } @@ -703,15 +720,17 @@ impl Session { queue_maxsize: Option<u32>, ) -> Result<(Listener, u16), Error> { let mut bound_port = 0; + let inner = self.inner(); unsafe { let ret = raw::libssh2_channel_forward_listen_ex( - self.inner.raw, + inner.raw, host.map(|s| s.as_ptr()).unwrap_or(0 as *const _) as *mut _, remote_port as c_int, &mut bound_port, queue_maxsize.unwrap_or(0) as c_int, ); - Listener::from_raw_opt(ret, &self.inner).map(|l| (l, bound_port as u16)) + let err = inner.last_error(); + Listener::from_raw_opt(ret, err, &self.inner).map(|l| (l, bound_port as u16)) } } @@ -722,10 +741,12 @@ impl Session { /// about the remote file to prepare for receiving the file. pub fn scp_recv(&self, path: &Path) -> Result<(Channel, ScpFileStat), Error> { let path = CString::new(util::path2bytes(path)?)?; + let inner = self.inner(); unsafe { let mut sb: raw::libssh2_struct_stat = mem::zeroed(); - let ret = raw::libssh2_scp_recv2(self.inner.raw, path.as_ptr(), &mut sb); - let mut c = Channel::from_raw_opt(ret, &self.inner)?; + let ret = raw::libssh2_scp_recv2(inner.raw, path.as_ptr(), &mut sb); + let err = inner.last_error(); + let mut c = Channel::from_raw_opt(ret, err, &self.inner)?; // Hm, apparently when we scp_recv() a file the actual channel // itself does not respond well to read_to_end(), and it also sends @@ -754,16 +775,18 @@ impl Session { ) -> Result<Channel, Error> { let path = CString::new(util::path2bytes(remote_path)?)?; let (mtime, atime) = times.unwrap_or((0, 0)); + let inner = self.inner(); unsafe { let ret = raw::libssh2_scp_send64( - self.inner.raw, + inner.raw, path.as_ptr(), mode as c_int, size as i64, mtime as libc::time_t, atime as libc::time_t, ); - Channel::from_raw_opt(ret, &self.inner) + let err = inner.last_error(); + Channel::from_raw_opt(ret, err, &self.inner) } } @@ -774,9 +797,11 @@ impl Session { /// own unique binary packet protocol which must be managed with the /// methods on `Sftp`. pub fn sftp(&self) -> Result<Sftp, Error> { + let inner = self.inner(); unsafe { - let ret = raw::libssh2_sftp_init(self.inner.raw); - Sftp::from_raw_opt(ret, &self.inner) + let ret = raw::libssh2_sftp_init(inner.raw); + let err = inner.last_error(); + Sftp::from_raw_opt(ret, err, &self.inner) } } @@ -792,9 +817,10 @@ impl Session { message: Option<&str>, ) -> Result<Channel, Error> { let message_len = message.map(|s| s.len()).unwrap_or(0); + let inner = self.inner(); unsafe { let ret = raw::libssh2_channel_open_ex( - self.inner.raw, + inner.raw, channel_type.as_ptr() as *const _, channel_type.len() as c_uint, window_size as c_uint, @@ -805,7 +831,8 @@ impl Session { .unwrap_or(0 as *const _) as *const _, message_len as c_uint, ); - Channel::from_raw_opt(ret, &self.inner) + let err = inner.last_error(); + Channel::from_raw_opt(ret, err, &self.inner) } } @@ -824,7 +851,8 @@ impl Session { /// /// Will only return `None` if an error has ocurred. pub fn banner_bytes(&self) -> Option<&[u8]> { - unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.inner.raw)) } + let inner = self.inner(); + unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(inner.raw)) } } /// Get the remote key. @@ -833,8 +861,9 @@ impl Session { pub fn host_key(&self) -> Option<(&[u8], HostKeyType)> { let mut len = 0; let mut kind = 0; + let inner = self.inner(); unsafe { - let ret = raw::libssh2_session_hostkey(self.inner.raw, &mut len, &mut kind); + let ret = raw::libssh2_session_hostkey(inner.raw, &mut len, &mut kind); if ret.is_null() { return None; } @@ -863,8 +892,9 @@ impl Session { HashType::Sha1 => 20, HashType::Sha256 => 32, }; + let inner = self.inner(); unsafe { - let ret = raw::libssh2_hostkey_hash(self.inner.raw, hash as c_int); + let ret = raw::libssh2_hostkey_hash(inner.raw, hash as c_int); if ret.is_null() { None } else { @@ -883,9 +913,8 @@ impl Session { /// I/O, use 0 (the default) to disable keepalives. To avoid some busy-loop /// corner-cases, if you specify an interval of 1 it will be treated as 2. pub fn set_keepalive(&self, want_reply: bool, interval: u32) { - unsafe { - raw::libssh2_keepalive_config(self.inner.raw, want_reply as c_int, interval as c_uint) - } + let inner = self.inner(); + unsafe { raw::libssh2_keepalive_config(inner.raw, want_reply as c_int, interval as c_uint) } } /// Send a keepalive message if needed. @@ -894,8 +923,9 @@ impl Session { /// to call it again. pub fn keepalive_send(&self) -> Result<u32, Error> { let mut ret = 0; - let rc = unsafe { raw::libssh2_keepalive_send(self.inner.raw, &mut ret) }; - self.rc(rc)?; + let inner = self.inner(); + let rc = unsafe { raw::libssh2_keepalive_send(inner.raw, &mut ret) }; + inner.rc(rc)?; Ok(ret as u32) } @@ -914,9 +944,10 @@ impl Session { let reason = reason.unwrap_or(ByApplication) as c_int; let description = CString::new(description)?; let lang = CString::new(lang.unwrap_or(""))?; + let inner = self.inner(); unsafe { - self.rc(raw::libssh2_session_disconnect_ex( - self.inner.raw, + inner.rc(raw::libssh2_session_disconnect_ex( + inner.raw, reason, description.as_ptr(), lang.as_ptr(), @@ -924,17 +955,13 @@ impl Session { } } - /// Translate a return code into a Rust-`Result`. - pub fn rc(&self, rc: c_int) -> Result<(), Error> { - self.inner.rc(rc) - } - /// Returns the blocked io directions that the application needs to wait for. /// /// This function should be used after an error of type `WouldBlock` is returned to /// find out the socket events the application has to wait for. pub fn block_directions(&self) -> BlockDirections { - let dir = unsafe { raw::libssh2_session_block_directions(self.inner.raw) }; + let inner = self.inner(); + let dir = unsafe { raw::libssh2_session_block_directions(inner.raw) }; match dir { raw::LIBSSH2_SESSION_BLOCK_INBOUND => BlockDirections::Inbound, raw::LIBSSH2_SESSION_BLOCK_OUTBOUND => BlockDirections::Outbound, @@ -944,6 +971,10 @@ impl Session { _ => BlockDirections::None, } } + + fn inner(&self) -> MutexGuard<SessionInner> { + self.inner.lock() + } } impl SessionInner { @@ -956,6 +987,10 @@ impl SessionInner { } } + pub fn last_error(&self) -> Option<Error> { + Error::last_error_raw(self.raw) + } + /// Set or clear blocking mode on session pub fn set_blocking(&self, blocking: bool) { unsafe { raw::libssh2_session_set_blocking(self.raw, blocking as c_int) } diff --git a/src/sftp.rs b/src/sftp.rs index b7ecf77..025c897 100644 --- a/src/sftp.rs +++ b/src/sftp.rs @@ -1,4 +1,5 @@ use libc::{c_int, c_long, c_uint, c_ulong, size_t}; +use parking_lot::{Mutex, MutexGuard}; use std::io::prelude::*; use std::io::{self, ErrorKind, SeekFrom}; use std::mem; @@ -10,19 +11,29 @@ use {raw, Error, SessionInner}; struct SftpInner { raw: *mut raw::LIBSSH2_SFTP, - sess: Arc<SessionInner>, + sess: Arc<Mutex<SessionInner>>, } /// A handle to a remote filesystem over SFTP. /// /// Instances are created through the `sftp` method on a `Session`. pub struct Sftp { - inner: Option<SftpInner>, + inner: Option<Arc<SftpInner>>, } -struct FileInner<'sftp> { +struct LockedSftp<'sftp> { + sess: MutexGuard<'sftp, SessionInner>, + raw: *mut raw::LIBSSH2_SFTP, +} + +struct FileInner { + raw: *mut raw::LIBSSH2_SFTP_HANDLE, + sftp: Arc<SftpInner>, +} + +struct LockedFile<'file> { + sess: MutexGuard<'file, SessionInner>, raw: *mut raw::LIBSSH2_SFTP_HANDLE, - sftp: &'sftp Sftp, } /// A file handle to an SFTP connection. @@ -32,8 +43,8 @@ struct FileInner<'sftp> { /// /// Files are created through `open`, `create`, and `open_mode` on an instance /// of `Sftp`. -pub struct File<'sftp> { - inner: Option<FileInner<'sftp>>, +pub struct File { + inner: Option<FileInner>, } /// Metadata information about a remote file. @@ -110,19 +121,27 @@ pub enum OpenType { Dir = raw::LIBSSH2_SFTP_OPENDIR as isize, } +impl<'sftp> LockedSftp<'sftp> { + pub fn last_error(&self) -> Error { + let code = unsafe { raw::libssh2_sftp_last_error(self.raw) }; + Error::from_errno(code as c_int) + } +} + impl Sftp { pub(crate) fn from_raw_opt( raw: *mut raw::LIBSSH2_SFTP, - sess: &Arc<SessionInner>, + err: Option<Error>, + sess: &Arc<Mutex<SessionInner>>, ) -> Result<Self, Error> { if raw.is_null() { - Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown)) + Err(err.unwrap_or_else(Error::unknown)) } else { Ok(Self { - inner: Some(SftpInner { + inner: Some(Arc::new(SftpInner { raw, sess: Arc::clone(sess), - }), + })), }) } } @@ -136,9 +155,11 @@ impl Sftp { open_type: OpenType, ) -> Result<File, Error> { let filename = util::path2bytes(filename)?; + + let locked = self.lock()?; unsafe { let ret = raw::libssh2_sftp_open_ex( - self.get_raw()?, + locked.raw, filename.as_ptr() as *const _, filename.len() as c_uint, flags.bits() as c_ulong, @@ -146,7 +167,7 @@ impl Sftp { open_type as c_int, ); if ret.is_null() { - Err(self.last_session_error()) + Err(locked.last_error()) } else { Ok(File::from_raw(self, ret)) } @@ -199,9 +220,10 @@ impl Sftp { /// Create a directory on the remote file system. pub fn mkdir(&self, filename: &Path, mode: i32) -> Result<(), Error> { let filename = util::path2bytes(filename)?; - self.rc(unsafe { + let locked = self.lock()?; + locked.sess.rc(unsafe { raw::libssh2_sftp_mkdir_ex( - self.get_raw()?, + locked.raw, filename.as_ptr() as *const _, filename.len() as c_uint, mode as c_long, @@ -212,9 +234,10 @@ impl Sftp { /// Remove a directory from the remote file system. pub fn rmdir(&self, filename: &Path) -> Result<(), Error> { let filename = util::path2bytes(filename)?; - self.rc(unsafe { + let locked = self.lock()?; + locked.sess.rc(unsafe { raw::libssh2_sftp_rmdir_ex( - self.get_raw()?, + locked.raw, filename.as_ptr() as *const _, filename.len() as c_uint, ) @@ -224,16 +247,17 @@ impl Sftp { /// Get the metadata for a file, performed by stat(2) pub fn stat(&self, filename: &Path) -> Result<FileStat, Error> { let filename = util::path2bytes(filename)?; + let locked = self.lock()?; unsafe { let mut ret = mem::zeroed(); let rc = raw::libssh2_sftp_stat_ex( - self.get_raw()?, + locked.raw, filename.as_ptr() as *const _, filename.len() as c_uint, raw::LIBSSH2_SFTP_STAT, &mut ret, ); - self.rc(rc)?; + locked.sess.rc(rc)?; Ok(FileStat::from_raw(&ret)) } } @@ -241,16 +265,17 @@ impl Sftp { /// Get the metadata for a file, performed by lstat(2) pub fn lstat(&self, filename: &Path) -> Result<FileStat, Error> { let filename = util::path2bytes(filename)?; + let locked = self.lock()?; unsafe { let mut ret = mem::zeroed(); let rc = raw::libssh2_sftp_stat_ex( - self.get_raw()?, + locked.raw, filename.as_ptr() as *const _, filename.len() as c_uint, raw::LIBSSH2_SFTP_LSTAT, &mut ret, ); - self.rc(rc)?; + locked.sess.rc(rc)?; Ok(FileStat::from_raw(&ret)) } } @@ -258,10 +283,11 @@ impl Sftp { /// Set the metadata for a file. pub fn setstat(&self, filename: &Path, stat: FileStat) -> Result<(), Error> { let filename = util::path2bytes(filename)?; - self.rc(unsafe { + let locked = self.lock()?; + locked.sess.rc(unsafe { let mut raw = stat.raw(); raw::libssh2_sftp_stat_ex( - self.get_raw()?, + locked.raw, filename.as_ptr() as *const _, filename.len() as c_uint, raw::LIBSSH2_SFTP_SETSTAT, @@ -274,9 +300,10 @@ impl Sftp { pub fn symlink(&self, path: &Path, target: &Path) -> Result<(), Error> { let path = util::path2bytes(path)?; let target = util::path2bytes(target)?; - self.rc(unsafe { + let locked = self.lock()?; + locked.sess.rc(unsafe { raw::libssh2_sftp_symlink_ex( - self.get_raw()?, + locked.raw, path.as_ptr() as *const _, path.len() as c_uint, target.as_ptr() as *mut _, @@ -300,10 +327,11 @@ impl Sftp { let path = util::path2bytes(path)?; let mut ret = Vec::<u8>::with_capacity(128); let mut rc; + let locked = self.lock()?; loop { rc = unsafe { raw::libssh2_sftp_symlink_ex( - self.get_raw()?, + locked.raw, path.as_ptr() as *const _, path.len() as c_uint, ret.as_ptr() as *mut _, @@ -319,7 +347,7 @@ impl Sftp { } } if rc < 0 { - Err(Error::from_errno(rc)) + Err(Error::from_session_error_raw(locked.sess.raw, rc)) } else { unsafe { ret.set_len(rc as usize) } Ok(mkpath(ret)) @@ -343,9 +371,10 @@ impl Sftp { flags.unwrap_or(RenameFlags::ATOMIC | RenameFlags::OVERWRITE | RenameFlags::NATIVE); let src = util::path2bytes(src)?; let dst = util::path2bytes(dst)?; - self.rc(unsafe { + let locked = self.lock()?; + locked.sess.rc(unsafe { raw::libssh2_sftp_rename_ex( - self.get_raw()?, + locked.raw, src.as_ptr() as *const _, src.len() as c_uint, dst.as_ptr() as *const _, @@ -358,53 +387,34 @@ impl Sftp { /// Remove a file on the remote filesystem pub fn unlink(&self, file: &Path) -> Result<(), Error> { let file = util::path2bytes(file)?; - self.rc(unsafe { - raw::libssh2_sftp_unlink_ex( - self.get_raw()?, - file.as_ptr() as *const _, - file.len() as c_uint, - ) + let locked = self.lock()?; + locked.sess.rc(unsafe { + raw::libssh2_sftp_unlink_ex(locked.raw, file.as_ptr() as *const _, file.len() as c_uint) }) } - /// Peel off the last error to happen on this SFTP instance. - pub fn last_error(&self) -> Error { - let raw = match self.get_raw() { - Ok(raw) => raw, - Err(e) => return e, - }; - let code = unsafe { raw::libssh2_sftp_last_error(raw) }; - Error::from_errno(code as c_int) - } - - fn last_session_error(&self) -> Error { - if let Some(inner) = self.inner.as_ref() { - Error::last_error_raw(inner.sess.raw).unwrap_or_else(Error::unknown) - } else { - Error::from_errno(raw::LIBSSH2_ERROR_BAD_USE) - } - } - - /// Translates a return code into a Rust-`Result` - pub fn rc(&self, rc: c_int) -> Result<(), Error> { - if rc == 0 { - Ok(()) - } else { - Err(Error::from_errno(rc)) - } - } - - fn get_raw(&self) -> Result<*mut raw::LIBSSH2_SFTP, Error> { + fn lock(&self) -> Result<LockedSftp, Error> { match self.inner.as_ref() { - Some(inner) => Ok(inner.raw), + Some(inner) => { + let sess = inner.sess.lock(); + Ok(LockedSftp { + sess, + raw: inner.raw, + }) + } None => Err(Error::from_errno(raw::LIBSSH2_ERROR_BAD_USE)), } } + // This method is used by the async ssh crate #[doc(hidden)] pub fn shutdown(&mut self) -> Result<(), Error> { - let raw = self.get_raw()?; - self.rc(unsafe { raw::libssh2_sftp_shutdown(raw) })?; + { + let locked = self.lock()?; + locked + .sess + .rc(unsafe { raw::libssh2_sftp_shutdown(locked.raw) })?; + } let _ = self.inner.take(); Ok(()) } @@ -414,57 +424,62 @@ impl Drop for Sftp { fn drop(&mut self) { // Set ssh2 to blocking if sftp was not shutdown yet. if let Some(inner) = self.inner.take() { - let was_blocking = inner.sess.is_blocking(); - inner.sess.set_blocking(true); + let sess = inner.sess.lock(); + let was_blocking = sess.is_blocking(); + sess.set_blocking(true); assert_eq!(unsafe { raw::libssh2_sftp_shutdown(inner.raw) }, 0); - inner.sess.set_blocking(was_blocking); + sess.set_blocking(was_blocking); } } } -impl<'sftp> File<'sftp> { +impl File { /// Wraps a raw pointer in a new File structure tied to the lifetime of the /// given session. /// /// This consumes ownership of `raw`. - unsafe fn from_raw(sftp: &'sftp Sftp, raw: *mut raw::LIBSSH2_SFTP_HANDLE) -> File<'sftp> { + unsafe fn from_raw(sftp: &Sftp, raw: *mut raw::LIBSSH2_SFTP_HANDLE) -> File { File { inner: Some(FileInner { raw: raw, - sftp: sftp, + sftp: Arc::clone( + sftp.inner + .as_ref() + .expect("we have a live option during construction"), + ), }), } } /// Set the metadata for this handle. pub fn setstat(&mut self, stat: FileStat) -> Result<(), Error> { - let inner = self.get_inner()?; - inner.sftp.rc(unsafe { + let locked = self.lock()?; + locked.sess.rc(unsafe { let mut raw = stat.raw(); - raw::libssh2_sftp_fstat_ex(inner.raw, &mut raw, 1) + raw::libssh2_sftp_fstat_ex(locked.raw, &mut raw, 1) }) } /// Get the metadata for this handle. pub fn stat(&mut self) -> Result<FileStat, Error> { + let locked = self.lock()?; unsafe { - let inner = self.get_inner()?; let mut ret = mem::zeroed(); - inner - .sftp - .rc(raw::libssh2_sftp_fstat_ex(inner.raw, &mut ret, 0))?; + locked + .sess + .rc(raw::libssh2_sftp_fstat_ex(locked.raw, &mut ret, 0))?; Ok(FileStat::from_raw(&ret)) } } #[allow(missing_docs)] // sure wish I knew what this did... pub fn statvfs(&mut self) -> Result<raw::LIBSSH2_SFTP_STATVFS, Error> { + let locked = self.lock()?; unsafe { - let inner = self.get_inner()?; let mut ret = mem::zeroed(); - inner - .sftp - .rc(raw::libssh2_sftp_fstatvfs(inner.raw, &mut ret))?; + locked + .sess + .rc(raw::libssh2_sftp_fstatvfs(locked.raw, &mut ret))?; Ok(ret) } } @@ -480,14 +495,15 @@ impl<'sftp> File<'sftp> { /// Also note that the return paths will not be absolute paths, they are /// the filenames of the files in this directory. pub fn readdir(&mut self) -> Result<(PathBuf, FileStat), Error> { - let inner = self.get_inner()?; + let locked = self.lock()?; + let mut buf = Vec::<u8>::with_capacity(128); let mut stat = unsafe { mem::zeroed() }; let mut rc; loop { rc = unsafe { raw::libssh2_sftp_readdir_ex( - inner.raw, + locked.raw, buf.as_mut_ptr() as *mut _, buf.capacity() as size_t, 0 as *mut _, @@ -503,7 +519,7 @@ impl<'sftp> File<'sftp> { } } if rc < 0 { - return Err(Error::from_errno(rc)); + return Err(Error::from_session_error_raw(locked.sess.raw, rc)); } else if rc == 0 { return Err(Error::new(raw::LIBSSH2_ERROR_FILE, "no more files")); } else { @@ -519,50 +535,59 @@ impl<'sftp> File<'sftp> { /// /// For this to work requires fsync@openssh.com support on the server. pub fn fsync(&mut self) -> Result<(), Error> { - let inner = self.get_inner()?; - inner.sftp.rc(unsafe { raw::libssh2_sftp_fsync(inner.raw) }) + let locked = self.lock()?; + locked + .sess + .rc(unsafe { raw::libssh2_sftp_fsync(locked.raw) }) } - fn get_inner(&self) -> Result<&FileInner, Error> { + fn lock(&self) -> Result<LockedFile, Error> { match self.inner.as_ref() { - Some(inner) => Ok(inner), + Some(file_inner) => { + let sess = file_inner.sftp.sess.lock(); + Ok(LockedFile { + sess, + raw: file_inner.raw, + }) + } None => Err(Error::from_errno(raw::LIBSSH2_ERROR_BAD_USE)), } } #[doc(hidden)] pub fn close(&mut self) -> Result<(), Error> { - let inner = self.get_inner()?; - inner - .sftp - .rc(unsafe { raw::libssh2_sftp_close_handle(inner.raw) })?; + { + let locked = self.lock()?; + Error::rc(unsafe { raw::libssh2_sftp_close_handle(locked.raw) })?; + } let _ = self.inner.take(); Ok(()) } } -impl<'sftp> Read for File<'sftp> { +impl Read for File { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let locked = self.lock()?; unsafe { - let inner = self.get_inner()?; let rc = - raw::libssh2_sftp_read(inner.raw, buf.as_mut_ptr() as *mut _, buf.len() as size_t); - match rc { - n if n < 0 => Err(io::Error::new(ErrorKind::Other, inner.sftp.last_error())), - n => Ok(n as usize), + raw::libssh2_sftp_read(locked.raw, buf.as_mut_ptr() as *mut _, buf.len() as size_t); + if rc < 0 { + Err(Error::from_session_error_raw(locked.sess.raw, rc as _).into()) + } else { + Ok(rc as usize) } } } } -impl<'sftp> Write for File<'sftp> { +impl Write for File { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - let inner = self.get_inner()?; + let locked = self.lock()?; let rc = unsafe { - raw::libssh2_sftp_write(inner.raw, buf.as_ptr() as *const _, buf.len() as size_t) + raw::libssh2_sftp_write(locked.raw, buf.as_ptr() as *const _, buf.len() as size_t) }; if rc < 0 { - Err(Error::from_errno(rc as _).into()) + Err(Error::from_session_error_raw(locked.sess.raw, rc as _).into()) } else { Ok(rc as usize) } @@ -572,7 +597,7 @@ impl<'sftp> Write for File<'sftp> { } } -impl<'sftp> Seek for File<'sftp> { +impl Seek for File { /// Move the file handle's internal pointer to an arbitrary location. /// /// libssh2 implements file pointers as a localized concept to make file @@ -587,8 +612,8 @@ impl<'sftp> Seek for File<'sftp> { let next = match how { SeekFrom::Start(pos) => pos, SeekFrom::Current(offset) => { - let inner = self.get_inner()?; - let cur = unsafe { raw::libssh2_sftp_tell64(inner.raw) }; + let locked = self.lock()?; + let cur = unsafe { raw::libssh2_sftp_tell64(locked.raw) }; (cur as i64 + offset) as u64 } SeekFrom::End(offset) => match self.stat() { @@ -599,27 +624,21 @@ impl<'sftp> Seek for File<'sftp> { Err(e) => return Err(io::Error::new(ErrorKind::Other, e)), }, }; - let inner = self.get_inner()?; - unsafe { raw::libssh2_sftp_seek64(inner.raw, next) } + let locked = self.lock()?; + unsafe { raw::libssh2_sftp_seek64(locked.raw, next) } Ok(next) } } -impl<'sftp> Drop for File<'sftp> { +impl Drop for File { fn drop(&mut self) { // Set ssh2 to blocking if the file was not closed yet. - if let Some(inner) = self.inner.take() { - // Normally sftp should still be available here. The only way it is `None` is when - // shutdown has been polled until success. In this case the async client should - // also properly poll `close` on `File` until success. - if let Some(sftp) = inner.sftp.inner.as_ref() { - let was_blocking = sftp.sess.is_blocking(); - sftp.sess.set_blocking(true); - assert_eq!(unsafe { raw::libssh2_sftp_close_handle(inner.raw) }, 0); - sftp.sess.set_blocking(was_blocking); - } else { - assert_eq!(unsafe { raw::libssh2_sftp_close_handle(inner.raw) }, 0); - } + if let Some(file_inner) = self.inner.take() { + let sess_inner = file_inner.sftp.sess.lock(); + let was_blocking = sess_inner.is_blocking(); + sess_inner.set_blocking(true); + assert_eq!(unsafe { raw::libssh2_sftp_close_handle(file_inner.raw) }, 0); + sess_inner.set_blocking(was_blocking); } } } diff --git a/tests/all/agent.rs b/tests/all/agent.rs index 8feee06..b3a7963 100644 --- a/tests/all/agent.rs +++ b/tests/all/agent.rs @@ -7,9 +7,8 @@ fn smoke() { agent.connect().unwrap(); agent.list_identities().unwrap(); { - let mut a = agent.identities(); - let i1 = a.next().unwrap().unwrap(); - a.count(); + let a = agent.identities().unwrap(); + let i1 = &a[0]; assert!(agent.userauth("foo", &i1).is_err()); } agent.disconnect().unwrap(); diff --git a/tests/all/knownhosts.rs b/tests/all/knownhosts.rs index 1847797..ab83577 100644 --- a/tests/all/knownhosts.rs +++ b/tests/all/knownhosts.rs @@ -3,8 +3,9 @@ use ssh2::{KnownHostFileKind, Session}; #[test] fn smoke() { let sess = Session::new().unwrap(); - let hosts = sess.known_hosts().unwrap(); - assert_eq!(hosts.iter().count(), 0); + let known_hosts = sess.known_hosts().unwrap(); + let hosts = known_hosts.hosts().unwrap(); + assert_eq!(hosts.len(), 0); } #[test] @@ -20,11 +21,14 @@ PW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi\ /w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ== "; let sess = Session::new().unwrap(); - let mut hosts = sess.known_hosts().unwrap(); - hosts.read_str(encoded, KnownHostFileKind::OpenSSH).unwrap(); + let mut known_hosts = sess.known_hosts().unwrap(); + known_hosts + .read_str(encoded, KnownHostFileKind::OpenSSH) + .unwrap(); - assert_eq!(hosts.iter().count(), 1); - let host = hosts.iter().next().unwrap().unwrap(); + let hosts = known_hosts.hosts().unwrap(); + assert_eq!(hosts.len(), 1); + let host = &hosts[0]; assert_eq!(host.name(), None); assert_eq!( host.key(), @@ -39,10 +43,10 @@ PW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi\ ); assert_eq!( - hosts - .write_string(&host, KnownHostFileKind::OpenSSH) + known_hosts + .write_string(host, KnownHostFileKind::OpenSSH) .unwrap(), encoded ); - hosts.remove(host).unwrap(); + known_hosts.remove(host).unwrap(); } diff --git a/tests/all/main.rs b/tests/all/main.rs index f74684a..87380eb 100644 --- a/tests/all/main.rs +++ b/tests/all/main.rs @@ -36,7 +36,8 @@ pub fn authed_session() -> ssh2::Session { let mut agent = sess.agent().unwrap(); agent.connect().unwrap(); agent.list_identities().unwrap(); - let identity = agent.identities().next().unwrap().unwrap(); + let identities = agent.identities().unwrap(); + let identity = &identities[0]; agent.userauth(&user, &identity).unwrap(); } assert!(sess.authenticated()); diff --git a/tests/all/session.rs b/tests/all/session.rs index 8894081..ee3eaa0 100644 --- a/tests/all/session.rs +++ b/tests/all/session.rs @@ -51,7 +51,7 @@ fn smoke_handshake() { agent.connect().unwrap(); agent.list_identities().unwrap(); { - let identity = agent.identities().next().unwrap().unwrap(); + let identity = &agent.identities().unwrap()[0]; agent.userauth(&user, &identity).unwrap(); } assert!(sess.authenticated()); |