diff options
Diffstat (limited to 'src/session.rs')
-rw-r--r-- | src/session.rs | 219 |
1 files changed, 127 insertions, 92 deletions
diff --git a/src/session.rs b/src/session.rs index efb4b25..90a0333 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,8 +1,8 @@ #[cfg(unix)] use libc::size_t; use libc::{self, c_char, c_int, c_long, c_uint, c_void}; +use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; use std::borrow::Cow; -use std::cell::{Ref, RefCell}; use std::ffi::CString; use std::mem; use std::net::TcpStream; @@ -65,14 +65,9 @@ unsafe fn with_abstract<R, F: FnOnce() -> R>( pub(crate) struct SessionInner { pub(crate) raw: *mut raw::LIBSSH2_SESSION, - tcp: RefCell<Option<TcpStream>>, + tcp: Option<TcpStream>, } -// The compiler doesn't know that it is Send safe because of the raw -// pointer inside. We know that the way that it is used by libssh2 -// and this crate is Send safe. -unsafe impl Send for SessionInner {} - /// An SSH session, typically representing one TCP connection. /// /// All other structures are based on an SSH session and cannot outlive a @@ -80,14 +75,9 @@ unsafe impl Send for SessionInner {} /// (via the `set_tcp_stream` method). #[derive(Clone)] pub struct Session { - inner: Arc<SessionInner>, + inner: Arc<Mutex<SessionInner>>, } -// The compiler doesn't know that it is Send safe because of the raw -// pointer inside. We know that the way that it is used by libssh2 -// and this crate is Send safe. -unsafe impl Send for Session {} - /// Metadata returned about a remote file when received via `scp`. pub struct ScpFileStat { stat: libc::stat, @@ -123,18 +113,19 @@ impl Session { Err(Error::unknown()) } else { Ok(Session { - inner: Arc::new(SessionInner { + inner: Arc::new(Mutex::new(SessionInner { raw: ret, - tcp: RefCell::new(None), - }), + tcp: None, + })), }) } } } #[doc(hidden)] - pub fn raw(&self) -> *mut raw::LIBSSH2_SESSION { - self.inner.raw + pub fn raw(&self) -> MappedMutexGuard<raw::LIBSSH2_SESSION> { + let inner = self.inner(); + MutexGuard::map(inner, |inner| unsafe { &mut *inner.raw }) } /// Set the SSH protocol banner for the local client @@ -145,12 +136,8 @@ impl Session { /// default. pub fn set_banner(&self, banner: &str) -> Result<(), Error> { let banner = CString::new(banner)?; - unsafe { - self.rc(raw::libssh2_session_banner_set( - self.inner.raw, - banner.as_ptr(), - )) - } + let inner = self.inner(); + unsafe { inner.rc(raw::libssh2_session_banner_set(inner.raw, banner.as_ptr())) } } /// Flag indicating whether SIGPIPE signals will be allowed or blocked. @@ -160,9 +147,10 @@ impl Session { /// the library to not attempt to block SIGPIPE from the underlying socket /// layer. pub fn set_allow_sigpipe(&self, block: bool) { + let inner = self.inner(); let res = unsafe { - self.rc(raw::libssh2_session_flag( - self.inner.raw, + inner.rc(raw::libssh2_session_flag( + inner.raw, raw::LIBSSH2_FLAG_SIGPIPE as c_int, block as c_int, )) @@ -177,9 +165,10 @@ impl Session { /// try to negotiate compression enabling for this connection. By default /// libssh2 will not attempt to use compression. pub fn set_compress(&self, compress: bool) { + let inner = self.inner(); let res = unsafe { - self.rc(raw::libssh2_session_flag( - self.inner.raw, + inner.rc(raw::libssh2_session_flag( + inner.raw, raw::LIBSSH2_FLAG_COMPRESS as c_int, compress as c_int, )) @@ -197,12 +186,12 @@ impl Session { /// a blocking session will wait for room. A non-blocking session will /// return immediately without writing anything. pub fn set_blocking(&self, blocking: bool) { - self.inner.set_blocking(blocking); + self.inner().set_blocking(blocking); } /// Returns whether the session was previously set to nonblocking. pub fn is_blocking(&self) -> bool { - self.inner.is_blocking() + self.inner().is_blocking() } /// Set timeout for blocking functions. @@ -215,7 +204,8 @@ impl Session { /// for blocking functions. pub fn set_timeout(&self, timeout_ms: u32) { let timeout_ms = timeout_ms as c_long; - unsafe { raw::libssh2_session_set_timeout(self.inner.raw, timeout_ms) } + let inner = self.inner(); + unsafe { raw::libssh2_session_set_timeout(inner.raw, timeout_ms) } } /// Returns the timeout, in milliseconds, for how long blocking calls may @@ -223,7 +213,8 @@ impl Session { /// /// A timeout of 0 signifies no timeout. pub fn timeout(&self) -> u32 { - unsafe { raw::libssh2_session_get_timeout(self.inner.raw) as u32 } + let inner = self.inner(); + unsafe { raw::libssh2_session_get_timeout(inner.raw) as u32 } } /// Begin transport layer protocol negotiation with the connected host. @@ -243,17 +234,17 @@ impl Session { raw::libssh2_session_handshake(raw, stream.as_raw_fd()) } - unsafe { - let stream = self.inner.tcp.borrow(); + let inner = self.inner(); - let stream = stream.as_ref().ok_or_else(|| { + unsafe { + let stream = inner.tcp.as_ref().ok_or_else(|| { Error::new( raw::LIBSSH2_ERROR_BAD_SOCKET, "use set_tcp_stream() to associate with a TcpStream", ) })?; - self.rc(handshake(self.inner.raw, stream)) + inner.rc(handshake(inner.raw, stream)) } } @@ -265,13 +256,19 @@ impl Session { /// concurrently elsewhere for the duration of this session as it may /// interfere with the protocol. pub fn set_tcp_stream(&mut self, stream: TcpStream) { - *self.inner.tcp.borrow_mut() = Some(stream); + let mut inner = self.inner(); + let _ = inner.tcp.replace(stream); } /// Returns a reference to the stream that was associated with the Session /// by the Session::handshake method. - pub fn tcp_stream(&self) -> Ref<Option<TcpStream>> { - self.inner.tcp.borrow() + pub fn tcp_stream(&self) -> Option<MappedMutexGuard<TcpStream>> { + let inner = self.inner(); + if inner.tcp.is_some() { + Some(MutexGuard::map(inner, |inner| inner.tcp.as_mut().unwrap())) + } else { + None + } } /// Attempt basic password authentication. @@ -281,9 +278,10 @@ impl Session { /// authentication (routed via PAM or another authentication backed) /// instead. pub fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> { - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_password_ex( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, password.as_ptr() as *const _, @@ -394,10 +392,11 @@ impl Session { })); } + let inner = self.inner(); unsafe { - with_abstract(self.inner.raw, prompter as *mut P as *mut c_void, || { - self.rc(raw::libssh2_userauth_keyboard_interactive_ex( - self.inner.raw, + with_abstract(inner.raw, prompter as *mut P as *mut c_void, || { + inner.rc(raw::libssh2_userauth_keyboard_interactive_ex( + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, Some(prompt::<P>), @@ -416,8 +415,9 @@ impl Session { let mut agent = self.agent()?; agent.connect()?; agent.list_identities()?; - let identity = match agent.identities().next() { - Some(identity) => identity?, + let identities = agent.identities()?; + let identity = match identities.get(0) { + Some(identity) => identity, None => { return Err(Error::new( raw::LIBSSH2_ERROR_INVAL as c_int, @@ -446,9 +446,10 @@ impl Session { Some(s) => Some(CString::new(s)?), None => None, }; - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_publickey_fromfile_ex( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, pubkey.as_ref().map(|s| s.as_ptr()).unwrap_or(0 as *const _), @@ -484,9 +485,10 @@ impl Session { Some(s) => Some(CString::new(s)?), None => None, }; - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_publickey_frommemory( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as size_t, pubkeydata @@ -525,9 +527,10 @@ impl Session { Some(local) => local, None => username, }; - self.rc(unsafe { + let inner = self.inner(); + inner.rc(unsafe { raw::libssh2_userauth_hostbased_fromfile_ex( - self.inner.raw, + inner.raw, username.as_ptr() as *const _, username.len() as c_uint, publickey.as_ptr(), @@ -547,7 +550,8 @@ impl Session { /// Indicates whether or not the named session has been successfully /// authenticated. pub fn authenticated(&self) -> bool { - unsafe { raw::libssh2_userauth_authenticated(self.inner.raw) != 0 } + let inner = self.inner(); + unsafe { raw::libssh2_userauth_authenticated(inner.raw) != 0 } } /// Send a SSH_USERAUTH_NONE request to the remote host. @@ -564,10 +568,11 @@ impl Session { pub fn auth_methods(&self, username: &str) -> Result<&str, Error> { let len = username.len(); let username = CString::new(username)?; + let inner = self.inner(); unsafe { - let ret = raw::libssh2_userauth_list(self.inner.raw, username.as_ptr(), len as c_uint); + let ret = raw::libssh2_userauth_list(inner.raw, username.as_ptr(), len as c_uint); if ret.is_null() { - match Error::last_error(self) { + match inner.last_error() { Some(err) => Err(err), None => Ok(""), } @@ -586,9 +591,10 @@ impl Session { /// negotiation. pub fn method_pref(&self, method_type: MethodType, prefs: &str) -> Result<(), Error> { let prefs = CString::new(prefs)?; + let inner = self.inner(); unsafe { - self.rc(raw::libssh2_session_method_pref( - self.inner.raw, + inner.rc(raw::libssh2_session_method_pref( + inner.raw, method_type as c_int, prefs.as_ptr(), )) @@ -600,8 +606,9 @@ impl Session { /// Returns the actual method negotiated for a particular transport /// parameter. May return `None` if the session has not yet been started. pub fn methods(&self, method_type: MethodType) -> Option<&str> { + let inner = self.inner(); unsafe { - let ptr = raw::libssh2_session_methods(self.inner.raw, method_type as c_int); + let ptr = raw::libssh2_session_methods(inner.raw, method_type as c_int); ::opt_bytes(self, ptr).and_then(|s| str::from_utf8(s).ok()) } } @@ -611,18 +618,19 @@ impl Session { static STATIC: () = (); let method_type = method_type as c_int; let mut ret = Vec::new(); + let inner = self.inner(); unsafe { let mut ptr = 0 as *mut _; - let rc = raw::libssh2_session_supported_algs(self.inner.raw, method_type, &mut ptr); + let rc = raw::libssh2_session_supported_algs(inner.raw, method_type, &mut ptr); if rc <= 0 { - self.rc(rc)?; + inner.rc(rc)?; } for i in 0..(rc as isize) { let s = ::opt_bytes(&STATIC, *ptr.offset(i)).unwrap(); let s = str::from_utf8(s).unwrap(); ret.push(s); } - raw::libssh2_free(self.inner.raw, ptr as *mut c_void); + raw::libssh2_free(inner.raw, ptr as *mut c_void); } Ok(ret) } @@ -631,7 +639,12 @@ impl Session { /// /// The returned agent will still need to be connected manually before use. pub fn agent(&self) -> Result<Agent, Error> { - unsafe { Agent::from_raw_opt(raw::libssh2_agent_init(self.inner.raw), &self.inner) } + let inner = self.inner(); + unsafe { + let agent = raw::libssh2_agent_init(inner.raw); + let err = inner.last_error(); + Agent::from_raw_opt(agent, err, &self.inner) + } } /// Init a collection of known hosts for this session. @@ -639,9 +652,11 @@ impl Session { /// Returns the handle to an internal representation of a known host /// collection. pub fn known_hosts(&self) -> Result<KnownHosts, Error> { + let inner = self.inner(); unsafe { - let ptr = raw::libssh2_knownhost_init(self.inner.raw); - KnownHosts::from_raw_opt(ptr, &self.inner) + let ptr = raw::libssh2_knownhost_init(inner.raw); + let err = inner.last_error(); + KnownHosts::from_raw_opt(ptr, err, &self.inner) } } @@ -679,15 +694,17 @@ impl Session { let (shost, sport) = src.unwrap_or(("127.0.0.1", 22)); let host = CString::new(host)?; let shost = CString::new(shost)?; + let inner = self.inner(); unsafe { let ret = raw::libssh2_channel_direct_tcpip_ex( - self.inner.raw, + inner.raw, host.as_ptr(), port as c_int, shost.as_ptr(), sport as c_int, ); - Channel::from_raw_opt(ret, &self.inner) + let err = inner.last_error(); + Channel::from_raw_opt(ret, err, &self.inner) } } @@ -703,15 +720,17 @@ impl Session { queue_maxsize: Option<u32>, ) -> Result<(Listener, u16), Error> { let mut bound_port = 0; + let inner = self.inner(); unsafe { let ret = raw::libssh2_channel_forward_listen_ex( - self.inner.raw, + inner.raw, host.map(|s| s.as_ptr()).unwrap_or(0 as *const _) as *mut _, remote_port as c_int, &mut bound_port, queue_maxsize.unwrap_or(0) as c_int, ); - Listener::from_raw_opt(ret, &self.inner).map(|l| (l, bound_port as u16)) + let err = inner.last_error(); + Listener::from_raw_opt(ret, err, &self.inner).map(|l| (l, bound_port as u16)) } } @@ -722,10 +741,12 @@ impl Session { /// about the remote file to prepare for receiving the file. pub fn scp_recv(&self, path: &Path) -> Result<(Channel, ScpFileStat), Error> { let path = CString::new(util::path2bytes(path)?)?; + let inner = self.inner(); unsafe { let mut sb: raw::libssh2_struct_stat = mem::zeroed(); - let ret = raw::libssh2_scp_recv2(self.inner.raw, path.as_ptr(), &mut sb); - let mut c = Channel::from_raw_opt(ret, &self.inner)?; + let ret = raw::libssh2_scp_recv2(inner.raw, path.as_ptr(), &mut sb); + let err = inner.last_error(); + let mut c = Channel::from_raw_opt(ret, err, &self.inner)?; // Hm, apparently when we scp_recv() a file the actual channel // itself does not respond well to read_to_end(), and it also sends @@ -754,16 +775,18 @@ impl Session { ) -> Result<Channel, Error> { let path = CString::new(util::path2bytes(remote_path)?)?; let (mtime, atime) = times.unwrap_or((0, 0)); + let inner = self.inner(); unsafe { let ret = raw::libssh2_scp_send64( - self.inner.raw, + inner.raw, path.as_ptr(), mode as c_int, size as i64, mtime as libc::time_t, atime as libc::time_t, ); - Channel::from_raw_opt(ret, &self.inner) + let err = inner.last_error(); + Channel::from_raw_opt(ret, err, &self.inner) } } @@ -774,9 +797,11 @@ impl Session { /// own unique binary packet protocol which must be managed with the /// methods on `Sftp`. pub fn sftp(&self) -> Result<Sftp, Error> { + let inner = self.inner(); unsafe { - let ret = raw::libssh2_sftp_init(self.inner.raw); - Sftp::from_raw_opt(ret, &self.inner) + let ret = raw::libssh2_sftp_init(inner.raw); + let err = inner.last_error(); + Sftp::from_raw_opt(ret, err, &self.inner) } } @@ -792,9 +817,10 @@ impl Session { message: Option<&str>, ) -> Result<Channel, Error> { let message_len = message.map(|s| s.len()).unwrap_or(0); + let inner = self.inner(); unsafe { let ret = raw::libssh2_channel_open_ex( - self.inner.raw, + inner.raw, channel_type.as_ptr() as *const _, channel_type.len() as c_uint, window_size as c_uint, @@ -805,7 +831,8 @@ impl Session { .unwrap_or(0 as *const _) as *const _, message_len as c_uint, ); - Channel::from_raw_opt(ret, &self.inner) + let err = inner.last_error(); + Channel::from_raw_opt(ret, err, &self.inner) } } @@ -824,7 +851,8 @@ impl Session { /// /// Will only return `None` if an error has ocurred. pub fn banner_bytes(&self) -> Option<&[u8]> { - unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.inner.raw)) } + let inner = self.inner(); + unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(inner.raw)) } } /// Get the remote key. @@ -833,8 +861,9 @@ impl Session { pub fn host_key(&self) -> Option<(&[u8], HostKeyType)> { let mut len = 0; let mut kind = 0; + let inner = self.inner(); unsafe { - let ret = raw::libssh2_session_hostkey(self.inner.raw, &mut len, &mut kind); + let ret = raw::libssh2_session_hostkey(inner.raw, &mut len, &mut kind); if ret.is_null() { return None; } @@ -863,8 +892,9 @@ impl Session { HashType::Sha1 => 20, HashType::Sha256 => 32, }; + let inner = self.inner(); unsafe { - let ret = raw::libssh2_hostkey_hash(self.inner.raw, hash as c_int); + let ret = raw::libssh2_hostkey_hash(inner.raw, hash as c_int); if ret.is_null() { None } else { @@ -883,9 +913,8 @@ impl Session { /// I/O, use 0 (the default) to disable keepalives. To avoid some busy-loop /// corner-cases, if you specify an interval of 1 it will be treated as 2. pub fn set_keepalive(&self, want_reply: bool, interval: u32) { - unsafe { - raw::libssh2_keepalive_config(self.inner.raw, want_reply as c_int, interval as c_uint) - } + let inner = self.inner(); + unsafe { raw::libssh2_keepalive_config(inner.raw, want_reply as c_int, interval as c_uint) } } /// Send a keepalive message if needed. @@ -894,8 +923,9 @@ impl Session { /// to call it again. pub fn keepalive_send(&self) -> Result<u32, Error> { let mut ret = 0; - let rc = unsafe { raw::libssh2_keepalive_send(self.inner.raw, &mut ret) }; - self.rc(rc)?; + let inner = self.inner(); + let rc = unsafe { raw::libssh2_keepalive_send(inner.raw, &mut ret) }; + inner.rc(rc)?; Ok(ret as u32) } @@ -914,9 +944,10 @@ impl Session { let reason = reason.unwrap_or(ByApplication) as c_int; let description = CString::new(description)?; let lang = CString::new(lang.unwrap_or(""))?; + let inner = self.inner(); unsafe { - self.rc(raw::libssh2_session_disconnect_ex( - self.inner.raw, + inner.rc(raw::libssh2_session_disconnect_ex( + inner.raw, reason, description.as_ptr(), lang.as_ptr(), @@ -924,17 +955,13 @@ impl Session { } } - /// Translate a return code into a Rust-`Result`. - pub fn rc(&self, rc: c_int) -> Result<(), Error> { - self.inner.rc(rc) - } - /// Returns the blocked io directions that the application needs to wait for. /// /// This function should be used after an error of type `WouldBlock` is returned to /// find out the socket events the application has to wait for. pub fn block_directions(&self) -> BlockDirections { - let dir = unsafe { raw::libssh2_session_block_directions(self.inner.raw) }; + let inner = self.inner(); + let dir = unsafe { raw::libssh2_session_block_directions(inner.raw) }; match dir { raw::LIBSSH2_SESSION_BLOCK_INBOUND => BlockDirections::Inbound, raw::LIBSSH2_SESSION_BLOCK_OUTBOUND => BlockDirections::Outbound, @@ -944,6 +971,10 @@ impl Session { _ => BlockDirections::None, } } + + fn inner(&self) -> MutexGuard<SessionInner> { + self.inner.lock() + } } impl SessionInner { @@ -956,6 +987,10 @@ impl SessionInner { } } + pub fn last_error(&self) -> Option<Error> { + Error::last_error_raw(self.raw) + } + /// Set or clear blocking mode on session pub fn set_blocking(&self, blocking: bool) { unsafe { raw::libssh2_session_set_blocking(self.raw, blocking as c_int) } |