summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWez Furlong <wez@wezfurlong.org>2019-07-23 17:12:18 -0700
committerWez Furlong <wez@wezfurlong.org>2019-07-29 08:55:06 -0700
commit1bbdfca88957747a9ad2d6010c84abe4189b2439 (patch)
tree7e09eae032a5cf49a5f57d6bfcff1f0be5223f08
parent19c95e8c79723e22a35de302cb8f768f9d133dbc (diff)
downloadssh2-rs-1bbdfca88957747a9ad2d6010c84abe4189b2439.zip
Session::handshake now takes ownership of TcpStream
Refs: https://github.com/alexcrichton/ssh2-rs/issues/17
-rw-r--r--src/channel.rs2
-rw-r--r--src/error.rs1
-rw-r--r--src/session.rs133
-rw-r--r--tests/all/channel.rs24
-rw-r--r--tests/all/main.rs6
-rw-r--r--tests/all/session.rs8
-rw-r--r--tests/all/sftp.rs4
7 files changed, 99 insertions, 79 deletions
diff --git a/src/channel.rs b/src/channel.rs
index 6328938..85c6576 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -4,7 +4,7 @@ use std::io::prelude::*;
use std::io::{self, ErrorKind};
use std::slice;
-use util::{Binding, SessionBinding};
+use util::SessionBinding;
use {raw, Error, Session};
/// A channel represents a portion of an SSH connection on which data can be
diff --git a/src/error.rs b/src/error.rs
index 1d2c5da..5eb92f9 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -4,7 +4,6 @@ use std::ffi::NulError;
use std::fmt;
use std::str;
-use util::Binding;
use {raw, Session};
/// Representation of an error that can occur within libssh2
diff --git a/src/session.rs b/src/session.rs
index 10a0703..117af4f 100644
--- a/src/session.rs
+++ b/src/session.rs
@@ -1,28 +1,35 @@
#[cfg(unix)]
use libc::size_t;
use libc::{self, c_int, c_long, c_uint, c_void};
+use std::cell::{Ref, RefCell};
use std::ffi::CString;
use std::mem;
use std::net::TcpStream;
use std::path::Path;
+use std::rc::Rc;
use std::slice;
use std::str;
-use util::{self, Binding, SessionBinding};
+use util::{self, SessionBinding};
use {raw, ByApplication, DisconnectCode, Error, HostKeyType};
use {Agent, Channel, HashType, KnownHosts, Listener, MethodType, Sftp};
+struct SessionInner {
+ raw: *mut raw::LIBSSH2_SESSION,
+ tcp: RefCell<Option<TcpStream>>,
+}
+
+unsafe impl Send for SessionInner {}
+
/// An SSH session, typically representing one TCP connection.
///
/// All other structures are based on an SSH session and cannot outlive a
/// session. Sessions are created and then have the TCP socket handed to them
/// (via the `handshake` method).
pub struct Session {
- raw: *mut raw::LIBSSH2_SESSION,
+ inner: Rc<SessionInner>,
}
-unsafe impl Send for Session {}
-
/// Metadata returned about a remote file when received via `scp`.
pub struct ScpFileStat {
stat: libc::stat,
@@ -43,11 +50,21 @@ impl Session {
if ret.is_null() {
Err(Error::unknown())
} else {
- Ok(Binding::from_raw(ret))
+ Ok(Session {
+ inner: Rc::new(SessionInner {
+ raw: ret,
+ tcp: RefCell::new(None),
+ }),
+ })
}
}
}
+ #[doc(hidden)]
+ pub fn raw(&self) -> *mut raw::LIBSSH2_SESSION {
+ self.inner.raw
+ }
+
/// Set the SSH protocol banner for the local client
///
/// Set the banner that will be sent to the remote host when the SSH session
@@ -56,7 +73,12 @@ impl Session {
/// default.
pub fn set_banner(&self, banner: &str) -> Result<(), Error> {
let banner = try!(CString::new(banner));
- unsafe { self.rc(raw::libssh2_session_banner_set(self.raw, banner.as_ptr())) }
+ unsafe {
+ self.rc(raw::libssh2_session_banner_set(
+ self.inner.raw,
+ banner.as_ptr(),
+ ))
+ }
}
/// Flag indicating whether SIGPIPE signals will be allowed or blocked.
@@ -68,7 +90,7 @@ impl Session {
pub fn set_allow_sigpipe(&self, block: bool) {
let res = unsafe {
self.rc(raw::libssh2_session_flag(
- self.raw,
+ self.inner.raw,
raw::LIBSSH2_FLAG_SIGPIPE as c_int,
block as c_int,
))
@@ -85,7 +107,7 @@ impl Session {
pub fn set_compress(&self, compress: bool) {
let res = unsafe {
self.rc(raw::libssh2_session_flag(
- self.raw,
+ self.inner.raw,
raw::LIBSSH2_FLAG_COMPRESS as c_int,
compress as c_int,
))
@@ -103,12 +125,12 @@ impl Session {
/// a blocking session will wait for room. A non-blocking session will
/// return immediately without writing anything.
pub fn set_blocking(&self, blocking: bool) {
- unsafe { raw::libssh2_session_set_blocking(self.raw, blocking as c_int) }
+ unsafe { raw::libssh2_session_set_blocking(self.inner.raw, blocking as c_int) }
}
/// Returns whether the session was previously set to nonblocking.
pub fn is_blocking(&self) -> bool {
- unsafe { raw::libssh2_session_get_blocking(self.raw) != 0 }
+ unsafe { raw::libssh2_session_get_blocking(self.inner.raw) != 0 }
}
/// Set timeout for blocking functions.
@@ -121,7 +143,7 @@ impl Session {
/// for blocking functions.
pub fn set_timeout(&self, timeout_ms: u32) {
let timeout_ms = timeout_ms as c_long;
- unsafe { raw::libssh2_session_set_timeout(self.raw, timeout_ms) }
+ unsafe { raw::libssh2_session_set_timeout(self.inner.raw, timeout_ms) }
}
/// Returns the timeout, in milliseconds, for how long blocking calls may
@@ -129,23 +151,19 @@ impl Session {
///
/// A timeout of 0 signifies no timeout.
pub fn timeout(&self) -> u32 {
- unsafe { raw::libssh2_session_get_timeout(self.raw) as u32 }
+ unsafe { raw::libssh2_session_get_timeout(self.inner.raw) as u32 }
}
/// Begin transport layer protocol negotiation with the connected host.
///
- /// This session does *not* take ownership of the socket provided, it is
- /// recommended to ensure that the socket persists the lifetime of this
- /// session to ensure that communication is correctly performed.
+ /// This session takes ownership of the socket provided.
+ /// You may use the tcp_stream() method to obtain a reference
+ /// to it later.
///
/// It is also highly recommended that the stream provided is not used
/// concurrently elsewhere for the duration of this session as it may
/// interfere with the protocol.
- pub fn handshake(&mut self, stream: &TcpStream) -> Result<(), Error> {
- unsafe {
- return self.rc(handshake(self.raw, stream));
- }
-
+ pub fn handshake(&mut self, stream: TcpStream) -> Result<(), Error> {
#[cfg(windows)]
unsafe fn handshake(raw: *mut raw::LIBSSH2_SESSION, stream: &TcpStream) -> libc::c_int {
use std::os::windows::prelude::*;
@@ -157,6 +175,18 @@ impl Session {
use std::os::unix::prelude::*;
raw::libssh2_session_handshake(raw, stream.as_raw_fd())
}
+
+ unsafe {
+ self.rc(handshake(self.inner.raw, &stream))?;
+ *self.inner.tcp.borrow_mut() = Some(stream);
+ Ok(())
+ }
+ }
+
+ /// Returns a reference to the stream that was associated with the Session
+ /// by the Session::handshake method.
+ pub fn tcp_stream(&self) -> Ref<Option<TcpStream>> {
+ self.inner.tcp.borrow()
}
/// Attempt basic password authentication.
@@ -168,7 +198,7 @@ impl Session {
pub fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> {
self.rc(unsafe {
raw::libssh2_userauth_password_ex(
- self.raw,
+ self.inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
password.as_ptr() as *const _,
@@ -220,7 +250,7 @@ impl Session {
};
self.rc(unsafe {
raw::libssh2_userauth_publickey_fromfile_ex(
- self.raw,
+ self.inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
pubkey.as_ref().map(|s| s.as_ptr()).unwrap_or(0 as *const _),
@@ -258,7 +288,7 @@ impl Session {
};
self.rc(unsafe {
raw::libssh2_userauth_publickey_frommemory(
- self.raw,
+ self.inner.raw,
username.as_ptr() as *const _,
username.len() as size_t,
pubkeydata
@@ -299,7 +329,7 @@ impl Session {
};
self.rc(unsafe {
raw::libssh2_userauth_hostbased_fromfile_ex(
- self.raw,
+ self.inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
publickey.as_ptr(),
@@ -319,7 +349,7 @@ impl Session {
/// Indicates whether or not the named session has been successfully
/// authenticated.
pub fn authenticated(&self) -> bool {
- unsafe { raw::libssh2_userauth_authenticated(self.raw) != 0 }
+ unsafe { raw::libssh2_userauth_authenticated(self.inner.raw) != 0 }
}
/// Send a SSH_USERAUTH_NONE request to the remote host.
@@ -336,7 +366,7 @@ impl Session {
let len = username.len();
let username = try!(CString::new(username));
unsafe {
- let ret = raw::libssh2_userauth_list(self.raw, username.as_ptr(), len as c_uint);
+ let ret = raw::libssh2_userauth_list(self.inner.raw, username.as_ptr(), len as c_uint);
if ret.is_null() {
Err(Error::last_error(self).unwrap())
} else {
@@ -356,7 +386,7 @@ impl Session {
let prefs = try!(CString::new(prefs));
unsafe {
self.rc(raw::libssh2_session_method_pref(
- self.raw,
+ self.inner.raw,
method_type as c_int,
prefs.as_ptr(),
))
@@ -369,7 +399,7 @@ impl Session {
/// parameter. May return `None` if the session has not yet been started.
pub fn methods(&self, method_type: MethodType) -> Option<&str> {
unsafe {
- let ptr = raw::libssh2_session_methods(self.raw, method_type as c_int);
+ let ptr = raw::libssh2_session_methods(self.inner.raw, method_type as c_int);
::opt_bytes(self, ptr).and_then(|s| str::from_utf8(s).ok())
}
}
@@ -381,7 +411,7 @@ impl Session {
let mut ret = Vec::new();
unsafe {
let mut ptr = 0 as *mut _;
- let rc = raw::libssh2_session_supported_algs(self.raw, method_type, &mut ptr);
+ let rc = raw::libssh2_session_supported_algs(self.inner.raw, method_type, &mut ptr);
if rc <= 0 {
try!(self.rc(rc))
}
@@ -390,7 +420,7 @@ impl Session {
let s = str::from_utf8(s).unwrap();
ret.push(s);
}
- raw::libssh2_free(self.raw, ptr as *mut c_void);
+ raw::libssh2_free(self.inner.raw, ptr as *mut c_void);
}
Ok(ret)
}
@@ -399,7 +429,7 @@ impl Session {
///
/// The returned agent will still need to be connected manually before use.
pub fn agent(&self) -> Result<Agent, Error> {
- unsafe { SessionBinding::from_raw_opt(self, raw::libssh2_agent_init(self.raw)) }
+ unsafe { SessionBinding::from_raw_opt(self, raw::libssh2_agent_init(self.inner.raw)) }
}
/// Init a collection of known hosts for this session.
@@ -408,7 +438,7 @@ impl Session {
/// collection.
pub fn known_hosts(&self) -> Result<KnownHosts, Error> {
unsafe {
- let ptr = raw::libssh2_knownhost_init(self.raw);
+ let ptr = raw::libssh2_knownhost_init(self.inner.raw);
SessionBinding::from_raw_opt(self, ptr)
}
}
@@ -449,7 +479,7 @@ impl Session {
let shost = try!(CString::new(shost));
unsafe {
let ret = raw::libssh2_channel_direct_tcpip_ex(
- self.raw,
+ self.inner.raw,
host.as_ptr(),
port as c_int,
shost.as_ptr(),
@@ -473,7 +503,7 @@ impl Session {
let mut bound_port = 0;
unsafe {
let ret = raw::libssh2_channel_forward_listen_ex(
- self.raw,
+ self.inner.raw,
host.map(|s| s.as_ptr()).unwrap_or(0 as *const _) as *mut _,
remote_port as c_int,
&mut bound_port,
@@ -492,7 +522,7 @@ impl Session {
let path = try!(CString::new(try!(util::path2bytes(path))));
unsafe {
let mut sb: libc::stat = mem::zeroed();
- let ret = raw::libssh2_scp_recv(self.raw, path.as_ptr(), &mut sb);
+ let ret = raw::libssh2_scp_recv(self.inner.raw, path.as_ptr(), &mut sb);
let mut c: Channel = try!(SessionBinding::from_raw_opt(self, ret));
// Hm, apparently when we scp_recv() a file the actual channel
@@ -524,7 +554,7 @@ impl Session {
let (mtime, atime) = times.unwrap_or((0, 0));
unsafe {
let ret = raw::libssh2_scp_send64(
- self.raw,
+ self.inner.raw,
path.as_ptr(),
mode as c_int,
size as i64,
@@ -543,7 +573,7 @@ impl Session {
/// methods on `Sftp`.
pub fn sftp(&self) -> Result<Sftp, Error> {
unsafe {
- let ret = raw::libssh2_sftp_init(self.raw);
+ let ret = raw::libssh2_sftp_init(self.inner.raw);
SessionBinding::from_raw_opt(self, ret)
}
}
@@ -562,7 +592,7 @@ impl Session {
let message_len = message.map(|s| s.len()).unwrap_or(0);
unsafe {
let ret = raw::libssh2_channel_open_ex(
- self.raw,
+ self.inner.raw,
channel_type.as_ptr() as *const _,
channel_type.len() as c_uint,
window_size as c_uint,
@@ -592,7 +622,7 @@ impl Session {
///
/// Will only return `None` if an error has ocurred.
pub fn banner_bytes(&self) -> Option<&[u8]> {
- unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.raw)) }
+ unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.inner.raw)) }
}
/// Get the remote key.
@@ -602,7 +632,7 @@ impl Session {
let mut len = 0;
let mut kind = 0;
unsafe {
- let ret = raw::libssh2_session_hostkey(self.raw, &mut len, &mut kind);
+ let ret = raw::libssh2_session_hostkey(self.inner.raw, &mut len, &mut kind);
if ret.is_null() {
return None;
}
@@ -632,7 +662,7 @@ impl Session {
HashType::Sha256 => 32,
};
unsafe {
- let ret = raw::libssh2_hostkey_hash(self.raw, hash as c_int);
+ let ret = raw::libssh2_hostkey_hash(self.inner.raw, hash as c_int);
if ret.is_null() {
None
} else {
@@ -651,7 +681,9 @@ impl Session {
/// I/O, use 0 (the default) to disable keepalives. To avoid some busy-loop
/// corner-cases, if you specify an interval of 1 it will be treated as 2.
pub fn set_keepalive(&self, want_reply: bool, interval: u32) {
- unsafe { raw::libssh2_keepalive_config(self.raw, want_reply as c_int, interval as c_uint) }
+ unsafe {
+ raw::libssh2_keepalive_config(self.inner.raw, want_reply as c_int, interval as c_uint)
+ }
}
/// Send a keepalive message if needed.
@@ -660,7 +692,7 @@ impl Session {
/// to call it again.
pub fn keepalive_send(&self) -> Result<u32, Error> {
let mut ret = 0;
- let rc = unsafe { raw::libssh2_keepalive_send(self.raw, &mut ret) };
+ let rc = unsafe { raw::libssh2_keepalive_send(self.inner.raw, &mut ret) };
try!(self.rc(rc));
Ok(ret as u32)
}
@@ -682,7 +714,7 @@ impl Session {
let lang = try!(CString::new(lang.unwrap_or("")));
unsafe {
self.rc(raw::libssh2_session_disconnect_ex(
- self.raw,
+ self.inner.raw,
reason,
description.as_ptr(),
lang.as_ptr(),
@@ -703,18 +735,7 @@ impl Session {
}
}
-impl Binding for Session {
- type Raw = *mut raw::LIBSSH2_SESSION;
-
- unsafe fn from_raw(raw: *mut raw::LIBSSH2_SESSION) -> Session {
- Session { raw: raw }
- }
- fn raw(&self) -> *mut raw::LIBSSH2_SESSION {
- self.raw
- }
-}
-
-impl Drop for Session {
+impl Drop for SessionInner {
fn drop(&mut self) {
unsafe {
let _rc = raw::libssh2_session_free(self.raw);
diff --git a/tests/all/channel.rs b/tests/all/channel.rs
index 294eecf..2a9dfdf 100644
--- a/tests/all/channel.rs
+++ b/tests/all/channel.rs
@@ -4,7 +4,7 @@ use std::thread;
#[test]
fn smoke() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.flush().unwrap();
channel.exec("true").unwrap();
@@ -18,7 +18,7 @@ fn smoke() {
#[test]
fn bad_smoke() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.flush().unwrap();
channel.exec("false").unwrap();
@@ -32,7 +32,7 @@ fn bad_smoke() {
#[test]
fn reading_data() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.exec("echo foo").unwrap();
let mut output = String::new();
@@ -42,7 +42,7 @@ fn reading_data() {
#[test]
fn writing_data() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.exec("read foo && echo $foo").unwrap();
channel.write_all(b"foo\n").unwrap();
@@ -53,7 +53,7 @@ fn writing_data() {
#[test]
fn eof() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.adjust_receive_window(10, false).unwrap();
channel.exec("read foo").unwrap();
@@ -65,7 +65,7 @@ fn eof() {
#[test]
fn shell() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.request_pty("xterm", None, None).unwrap();
channel.shell().unwrap();
@@ -74,7 +74,7 @@ fn shell() {
#[test]
fn setenv() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
let _ = channel.setenv("FOO", "BAR");
channel.close().unwrap();
@@ -91,7 +91,7 @@ fn direct() {
assert_eq!(b, [1, 2, 3]);
s.write_all(&[4, 5, 6]).unwrap();
});
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess
.channel_direct_tcpip("127.0.0.1", addr.port(), None)
.unwrap();
@@ -104,7 +104,7 @@ fn direct() {
#[test]
fn forward() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let (mut listen, port) = sess.channel_forward_listen(39249, None, None).unwrap();
let t = thread::spawn(move || {
let mut s = TcpStream::connect(&("127.0.0.1", port)).unwrap();
@@ -127,7 +127,7 @@ fn drop_nonblocking() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
sess.set_blocking(false);
thread::spawn(move || {
@@ -140,7 +140,7 @@ fn drop_nonblocking() {
#[test]
fn nonblocking_before_exit_code() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.send_eof().unwrap();
let mut output = String::new();
@@ -165,7 +165,7 @@ fn nonblocking_before_exit_code() {
#[test]
fn exit_code_ignores_other_errors() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut channel = sess.channel_session().unwrap();
channel.exec("true").unwrap();
channel.wait_eof().unwrap();
diff --git a/tests/all/main.rs b/tests/all/main.rs
index 7f150bc..6f4335d 100644
--- a/tests/all/main.rs
+++ b/tests/all/main.rs
@@ -16,11 +16,11 @@ pub fn socket() -> TcpStream {
TcpStream::connect("127.0.0.1:22").unwrap()
}
-pub fn authed_session() -> (TcpStream, ssh2::Session) {
+pub fn authed_session() -> ssh2::Session {
let user = env::var("USER").unwrap();
let socket = socket();
let mut sess = ssh2::Session::new().unwrap();
- sess.handshake(&socket).unwrap();
+ sess.handshake(socket).unwrap();
assert!(!sess.authenticated());
{
@@ -31,5 +31,5 @@ pub fn authed_session() -> (TcpStream, ssh2::Session) {
agent.userauth(&user, &identity).unwrap();
}
assert!(sess.authenticated());
- (socket, sess)
+ sess
}
diff --git a/tests/all/session.rs b/tests/all/session.rs
index 42f1dd5..77fb824 100644
--- a/tests/all/session.rs
+++ b/tests/all/session.rs
@@ -30,7 +30,7 @@ fn smoke_handshake() {
let user = env::var("USER").unwrap();
let socket = ::socket();
let mut sess = Session::new().unwrap();
- sess.handshake(&socket).unwrap();
+ sess.handshake(socket).unwrap();
sess.host_key().unwrap();
let methods = sess.auth_methods(&user).unwrap();
assert!(methods.contains("publickey"), "{}", methods);
@@ -49,14 +49,14 @@ fn smoke_handshake() {
#[test]
fn keepalive() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
sess.set_keepalive(false, 10);
sess.keepalive_send().unwrap();
}
#[test]
fn scp_recv() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let (mut ch, _) = sess.scp_recv(Path::new(".ssh/authorized_keys")).unwrap();
let mut data = String::new();
ch.read_to_string(&mut data).unwrap();
@@ -72,7 +72,7 @@ fn scp_recv() {
#[test]
fn scp_send() {
let td = TempDir::new("test").unwrap();
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let mut ch = sess
.scp_send(&td.path().join("foo"), 0o644, 6, None)
.unwrap();
diff --git a/tests/all/sftp.rs b/tests/all/sftp.rs
index 415bcae..ad9eed5 100644
--- a/tests/all/sftp.rs
+++ b/tests/all/sftp.rs
@@ -4,7 +4,7 @@ use tempdir::TempDir;
#[test]
fn smoke() {
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
sess.sftp().unwrap();
}
@@ -14,7 +14,7 @@ fn ops() {
File::create(&td.path().join("foo")).unwrap();
fs::create_dir(&td.path().join("bar")).unwrap();
- let (_tcp, sess) = ::authed_session();
+ let sess = ::authed_session();
let sftp = sess.sftp().unwrap();
sftp.opendir(&td.path().join("bar")).unwrap();
let mut foo = sftp.open(&td.path().join("foo")).unwrap();