From 1bbdfca88957747a9ad2d6010c84abe4189b2439 Mon Sep 17 00:00:00 2001 From: Wez Furlong Date: Tue, 23 Jul 2019 17:12:18 -0700 Subject: Session::handshake now takes ownership of TcpStream Refs: https://github.com/alexcrichton/ssh2-rs/issues/17 --- src/channel.rs | 2 +- src/error.rs | 1 - src/session.rs | 133 +++++++++++++++++++++++++++++++++------------------------ 3 files changed, 78 insertions(+), 58 deletions(-) (limited to 'src') diff --git a/src/channel.rs b/src/channel.rs index 6328938..85c6576 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -4,7 +4,7 @@ use std::io::prelude::*; use std::io::{self, ErrorKind}; use std::slice; -use util::{Binding, SessionBinding}; +use util::SessionBinding; use {raw, Error, Session}; /// A channel represents a portion of an SSH connection on which data can be diff --git a/src/error.rs b/src/error.rs index 1d2c5da..5eb92f9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,7 +4,6 @@ use std::ffi::NulError; use std::fmt; use std::str; -use util::Binding; use {raw, Session}; /// Representation of an error that can occur within libssh2 diff --git a/src/session.rs b/src/session.rs index 10a0703..117af4f 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,28 +1,35 @@ #[cfg(unix)] use libc::size_t; use libc::{self, c_int, c_long, c_uint, c_void}; +use std::cell::{Ref, RefCell}; use std::ffi::CString; use std::mem; use std::net::TcpStream; use std::path::Path; +use std::rc::Rc; use std::slice; use std::str; -use util::{self, Binding, SessionBinding}; +use util::{self, SessionBinding}; use {raw, ByApplication, DisconnectCode, Error, HostKeyType}; use {Agent, Channel, HashType, KnownHosts, Listener, MethodType, Sftp}; +struct SessionInner { + raw: *mut raw::LIBSSH2_SESSION, + tcp: RefCell>, +} + +unsafe impl Send for SessionInner {} + /// An SSH session, typically representing one TCP connection. /// /// All other structures are based on an SSH session and cannot outlive a /// session. Sessions are created and then have the TCP socket handed to them /// (via the `handshake` method). pub struct Session { - raw: *mut raw::LIBSSH2_SESSION, + inner: Rc, } -unsafe impl Send for Session {} - /// Metadata returned about a remote file when received via `scp`. pub struct ScpFileStat { stat: libc::stat, @@ -43,11 +50,21 @@ impl Session { if ret.is_null() { Err(Error::unknown()) } else { - Ok(Binding::from_raw(ret)) + Ok(Session { + inner: Rc::new(SessionInner { + raw: ret, + tcp: RefCell::new(None), + }), + }) } } } + #[doc(hidden)] + pub fn raw(&self) -> *mut raw::LIBSSH2_SESSION { + self.inner.raw + } + /// Set the SSH protocol banner for the local client /// /// Set the banner that will be sent to the remote host when the SSH session @@ -56,7 +73,12 @@ impl Session { /// default. pub fn set_banner(&self, banner: &str) -> Result<(), Error> { let banner = try!(CString::new(banner)); - unsafe { self.rc(raw::libssh2_session_banner_set(self.raw, banner.as_ptr())) } + unsafe { + self.rc(raw::libssh2_session_banner_set( + self.inner.raw, + banner.as_ptr(), + )) + } } /// Flag indicating whether SIGPIPE signals will be allowed or blocked. @@ -68,7 +90,7 @@ impl Session { pub fn set_allow_sigpipe(&self, block: bool) { let res = unsafe { self.rc(raw::libssh2_session_flag( - self.raw, + self.inner.raw, raw::LIBSSH2_FLAG_SIGPIPE as c_int, block as c_int, )) @@ -85,7 +107,7 @@ impl Session { pub fn set_compress(&self, compress: bool) { let res = unsafe { self.rc(raw::libssh2_session_flag( - self.raw, + self.inner.raw, raw::LIBSSH2_FLAG_COMPRESS as c_int, compress as c_int, )) @@ -103,12 +125,12 @@ impl Session { /// a blocking session will wait for room. A non-blocking session will /// return immediately without writing anything. pub fn set_blocking(&self, blocking: bool) { - unsafe { raw::libssh2_session_set_blocking(self.raw, blocking as c_int) } + unsafe { raw::libssh2_session_set_blocking(self.inner.raw, blocking as c_int) } } /// Returns whether the session was previously set to nonblocking. pub fn is_blocking(&self) -> bool { - unsafe { raw::libssh2_session_get_blocking(self.raw) != 0 } + unsafe { raw::libssh2_session_get_blocking(self.inner.raw) != 0 } } /// Set timeout for blocking functions. @@ -121,7 +143,7 @@ impl Session { /// for blocking functions. pub fn set_timeout(&self, timeout_ms: u32) { let timeout_ms = timeout_ms as c_long; - unsafe { raw::libssh2_session_set_timeout(self.raw, timeout_ms) } + unsafe { raw::libssh2_session_set_timeout(self.inner.raw, timeout_ms) } } /// Returns the timeout, in milliseconds, for how long blocking calls may @@ -129,23 +151,19 @@ impl Session { /// /// A timeout of 0 signifies no timeout. pub fn timeout(&self) -> u32 { - unsafe { raw::libssh2_session_get_timeout(self.raw) as u32 } + unsafe { raw::libssh2_session_get_timeout(self.inner.raw) as u32 } } /// Begin transport layer protocol negotiation with the connected host. /// - /// This session does *not* take ownership of the socket provided, it is - /// recommended to ensure that the socket persists the lifetime of this - /// session to ensure that communication is correctly performed. + /// This session takes ownership of the socket provided. + /// You may use the tcp_stream() method to obtain a reference + /// to it later. /// /// It is also highly recommended that the stream provided is not used /// concurrently elsewhere for the duration of this session as it may /// interfere with the protocol. - pub fn handshake(&mut self, stream: &TcpStream) -> Result<(), Error> { - unsafe { - return self.rc(handshake(self.raw, stream)); - } - + pub fn handshake(&mut self, stream: TcpStream) -> Result<(), Error> { #[cfg(windows)] unsafe fn handshake(raw: *mut raw::LIBSSH2_SESSION, stream: &TcpStream) -> libc::c_int { use std::os::windows::prelude::*; @@ -157,6 +175,18 @@ impl Session { use std::os::unix::prelude::*; raw::libssh2_session_handshake(raw, stream.as_raw_fd()) } + + unsafe { + self.rc(handshake(self.inner.raw, &stream))?; + *self.inner.tcp.borrow_mut() = Some(stream); + Ok(()) + } + } + + /// Returns a reference to the stream that was associated with the Session + /// by the Session::handshake method. + pub fn tcp_stream(&self) -> Ref> { + self.inner.tcp.borrow() } /// Attempt basic password authentication. @@ -168,7 +198,7 @@ impl Session { pub fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> { self.rc(unsafe { raw::libssh2_userauth_password_ex( - self.raw, + self.inner.raw, username.as_ptr() as *const _, username.len() as c_uint, password.as_ptr() as *const _, @@ -220,7 +250,7 @@ impl Session { }; self.rc(unsafe { raw::libssh2_userauth_publickey_fromfile_ex( - self.raw, + self.inner.raw, username.as_ptr() as *const _, username.len() as c_uint, pubkey.as_ref().map(|s| s.as_ptr()).unwrap_or(0 as *const _), @@ -258,7 +288,7 @@ impl Session { }; self.rc(unsafe { raw::libssh2_userauth_publickey_frommemory( - self.raw, + self.inner.raw, username.as_ptr() as *const _, username.len() as size_t, pubkeydata @@ -299,7 +329,7 @@ impl Session { }; self.rc(unsafe { raw::libssh2_userauth_hostbased_fromfile_ex( - self.raw, + self.inner.raw, username.as_ptr() as *const _, username.len() as c_uint, publickey.as_ptr(), @@ -319,7 +349,7 @@ impl Session { /// Indicates whether or not the named session has been successfully /// authenticated. pub fn authenticated(&self) -> bool { - unsafe { raw::libssh2_userauth_authenticated(self.raw) != 0 } + unsafe { raw::libssh2_userauth_authenticated(self.inner.raw) != 0 } } /// Send a SSH_USERAUTH_NONE request to the remote host. @@ -336,7 +366,7 @@ impl Session { let len = username.len(); let username = try!(CString::new(username)); unsafe { - let ret = raw::libssh2_userauth_list(self.raw, username.as_ptr(), len as c_uint); + let ret = raw::libssh2_userauth_list(self.inner.raw, username.as_ptr(), len as c_uint); if ret.is_null() { Err(Error::last_error(self).unwrap()) } else { @@ -356,7 +386,7 @@ impl Session { let prefs = try!(CString::new(prefs)); unsafe { self.rc(raw::libssh2_session_method_pref( - self.raw, + self.inner.raw, method_type as c_int, prefs.as_ptr(), )) @@ -369,7 +399,7 @@ impl Session { /// parameter. May return `None` if the session has not yet been started. pub fn methods(&self, method_type: MethodType) -> Option<&str> { unsafe { - let ptr = raw::libssh2_session_methods(self.raw, method_type as c_int); + let ptr = raw::libssh2_session_methods(self.inner.raw, method_type as c_int); ::opt_bytes(self, ptr).and_then(|s| str::from_utf8(s).ok()) } } @@ -381,7 +411,7 @@ impl Session { let mut ret = Vec::new(); unsafe { let mut ptr = 0 as *mut _; - let rc = raw::libssh2_session_supported_algs(self.raw, method_type, &mut ptr); + let rc = raw::libssh2_session_supported_algs(self.inner.raw, method_type, &mut ptr); if rc <= 0 { try!(self.rc(rc)) } @@ -390,7 +420,7 @@ impl Session { let s = str::from_utf8(s).unwrap(); ret.push(s); } - raw::libssh2_free(self.raw, ptr as *mut c_void); + raw::libssh2_free(self.inner.raw, ptr as *mut c_void); } Ok(ret) } @@ -399,7 +429,7 @@ impl Session { /// /// The returned agent will still need to be connected manually before use. pub fn agent(&self) -> Result { - unsafe { SessionBinding::from_raw_opt(self, raw::libssh2_agent_init(self.raw)) } + unsafe { SessionBinding::from_raw_opt(self, raw::libssh2_agent_init(self.inner.raw)) } } /// Init a collection of known hosts for this session. @@ -408,7 +438,7 @@ impl Session { /// collection. pub fn known_hosts(&self) -> Result { unsafe { - let ptr = raw::libssh2_knownhost_init(self.raw); + let ptr = raw::libssh2_knownhost_init(self.inner.raw); SessionBinding::from_raw_opt(self, ptr) } } @@ -449,7 +479,7 @@ impl Session { let shost = try!(CString::new(shost)); unsafe { let ret = raw::libssh2_channel_direct_tcpip_ex( - self.raw, + self.inner.raw, host.as_ptr(), port as c_int, shost.as_ptr(), @@ -473,7 +503,7 @@ impl Session { let mut bound_port = 0; unsafe { let ret = raw::libssh2_channel_forward_listen_ex( - self.raw, + self.inner.raw, host.map(|s| s.as_ptr()).unwrap_or(0 as *const _) as *mut _, remote_port as c_int, &mut bound_port, @@ -492,7 +522,7 @@ impl Session { let path = try!(CString::new(try!(util::path2bytes(path)))); unsafe { let mut sb: libc::stat = mem::zeroed(); - let ret = raw::libssh2_scp_recv(self.raw, path.as_ptr(), &mut sb); + let ret = raw::libssh2_scp_recv(self.inner.raw, path.as_ptr(), &mut sb); let mut c: Channel = try!(SessionBinding::from_raw_opt(self, ret)); // Hm, apparently when we scp_recv() a file the actual channel @@ -524,7 +554,7 @@ impl Session { let (mtime, atime) = times.unwrap_or((0, 0)); unsafe { let ret = raw::libssh2_scp_send64( - self.raw, + self.inner.raw, path.as_ptr(), mode as c_int, size as i64, @@ -543,7 +573,7 @@ impl Session { /// methods on `Sftp`. pub fn sftp(&self) -> Result { unsafe { - let ret = raw::libssh2_sftp_init(self.raw); + let ret = raw::libssh2_sftp_init(self.inner.raw); SessionBinding::from_raw_opt(self, ret) } } @@ -562,7 +592,7 @@ impl Session { let message_len = message.map(|s| s.len()).unwrap_or(0); unsafe { let ret = raw::libssh2_channel_open_ex( - self.raw, + self.inner.raw, channel_type.as_ptr() as *const _, channel_type.len() as c_uint, window_size as c_uint, @@ -592,7 +622,7 @@ impl Session { /// /// Will only return `None` if an error has ocurred. pub fn banner_bytes(&self) -> Option<&[u8]> { - unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.raw)) } + unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.inner.raw)) } } /// Get the remote key. @@ -602,7 +632,7 @@ impl Session { let mut len = 0; let mut kind = 0; unsafe { - let ret = raw::libssh2_session_hostkey(self.raw, &mut len, &mut kind); + let ret = raw::libssh2_session_hostkey(self.inner.raw, &mut len, &mut kind); if ret.is_null() { return None; } @@ -632,7 +662,7 @@ impl Session { HashType::Sha256 => 32, }; unsafe { - let ret = raw::libssh2_hostkey_hash(self.raw, hash as c_int); + let ret = raw::libssh2_hostkey_hash(self.inner.raw, hash as c_int); if ret.is_null() { None } else { @@ -651,7 +681,9 @@ impl Session { /// I/O, use 0 (the default) to disable keepalives. To avoid some busy-loop /// corner-cases, if you specify an interval of 1 it will be treated as 2. pub fn set_keepalive(&self, want_reply: bool, interval: u32) { - unsafe { raw::libssh2_keepalive_config(self.raw, want_reply as c_int, interval as c_uint) } + unsafe { + raw::libssh2_keepalive_config(self.inner.raw, want_reply as c_int, interval as c_uint) + } } /// Send a keepalive message if needed. @@ -660,7 +692,7 @@ impl Session { /// to call it again. pub fn keepalive_send(&self) -> Result { let mut ret = 0; - let rc = unsafe { raw::libssh2_keepalive_send(self.raw, &mut ret) }; + let rc = unsafe { raw::libssh2_keepalive_send(self.inner.raw, &mut ret) }; try!(self.rc(rc)); Ok(ret as u32) } @@ -682,7 +714,7 @@ impl Session { let lang = try!(CString::new(lang.unwrap_or(""))); unsafe { self.rc(raw::libssh2_session_disconnect_ex( - self.raw, + self.inner.raw, reason, description.as_ptr(), lang.as_ptr(), @@ -703,18 +735,7 @@ impl Session { } } -impl Binding for Session { - type Raw = *mut raw::LIBSSH2_SESSION; - - unsafe fn from_raw(raw: *mut raw::LIBSSH2_SESSION) -> Session { - Session { raw: raw } - } - fn raw(&self) -> *mut raw::LIBSSH2_SESSION { - self.raw - } -} - -impl Drop for Session { +impl Drop for SessionInner { fn drop(&mut self) { unsafe { let _rc = raw::libssh2_session_free(self.raw); -- cgit v1.2.3