summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml3
-rw-r--r--src/agent.rs132
-rw-r--r--src/channel.rs183
-rw-r--r--src/error.rs12
-rw-r--r--src/knownhosts.rs152
-rw-r--r--src/lib.rs8
-rw-r--r--src/listener.rs12
-rw-r--r--src/session.rs219
-rw-r--r--src/sftp.rs261
-rw-r--r--tests/all/agent.rs5
-rw-r--r--tests/all/knownhosts.rs22
-rw-r--r--tests/all/main.rs3
-rw-r--r--tests/all/session.rs2
13 files changed, 600 insertions, 414 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 0d32a82..23959f5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "ssh2"
-version = "0.6.0"
+version = "0.7.0"
authors = ["Alex Crichton <alex@alexcrichton.com>", "Wez Furlong <wez@wezfurlong.org>"]
license = "MIT/Apache-2.0"
keywords = ["ssh"]
@@ -20,6 +20,7 @@ vendored-openssl = ["libssh2-sys/vendored-openssl"]
bitflags = "1.2"
libc = "0.2"
libssh2-sys = { path = "libssh2-sys", version = "0.2.13" }
+parking_lot = "0.10"
[dev-dependencies]
tempdir = "0.3"
diff --git a/src/agent.rs b/src/agent.rs
index 5eb5bc4..dbafb1d 100644
--- a/src/agent.rs
+++ b/src/agent.rs
@@ -1,10 +1,9 @@
-use std::ffi::CString;
-use std::marker;
+use parking_lot::{Mutex, MutexGuard};
+use std::ffi::{CStr, CString};
use std::slice;
use std::str;
use std::sync::Arc;
-use util::Binding;
use {raw, Error, SessionInner};
/// A structure representing a connection to an SSH agent.
@@ -12,28 +11,24 @@ use {raw, Error, SessionInner};
/// Agents can be used to authenticate a session.
pub struct Agent {
raw: *mut raw::LIBSSH2_AGENT,
- sess: Arc<SessionInner>,
-}
-
-/// An iterator over the identities found in an SSH agent.
-pub struct Identities<'agent> {
- prev: *mut raw::libssh2_agent_publickey,
- agent: &'agent Agent,
+ sess: Arc<Mutex<SessionInner>>,
}
/// A public key which is extracted from an SSH agent.
-pub struct PublicKey<'agent> {
- raw: *mut raw::libssh2_agent_publickey,
- _marker: marker::PhantomData<&'agent [u8]>,
+#[derive(Debug, PartialEq, Eq)]
+pub struct PublicKey {
+ blob: Vec<u8>,
+ comment: String,
}
impl Agent {
pub(crate) fn from_raw_opt(
raw: *mut raw::LIBSSH2_AGENT,
- sess: &Arc<SessionInner>,
+ err: Option<Error>,
+ sess: &Arc<Mutex<SessionInner>>,
) -> Result<Self, Error> {
if raw.is_null() {
- Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown))
+ Err(err.unwrap_or_else(Error::unknown))
} else {
Ok(Self {
raw,
@@ -44,12 +39,14 @@ impl Agent {
/// Connect to an ssh-agent running on the system.
pub fn connect(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_agent_connect(self.raw)) }
+ let sess = self.sess.lock();
+ unsafe { sess.rc(raw::libssh2_agent_connect(self.raw)) }
}
/// Close a connection to an ssh-agent.
pub fn disconnect(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_agent_disconnect(self.raw)) }
+ let sess = self.sess.lock();
+ unsafe { sess.rc(raw::libssh2_agent_disconnect(self.raw)) }
}
/// Request an ssh-agent to list of public keys, and stores them in the
@@ -57,25 +54,64 @@ impl Agent {
///
/// Call `identities` to get the public keys.
pub fn list_identities(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_agent_list_identities(self.raw)) }
+ let sess = self.sess.lock();
+ unsafe { sess.rc(raw::libssh2_agent_list_identities(self.raw)) }
+ }
+
+ /// Get list of the identities of this agent.
+ pub fn identities(&self) -> Result<Vec<PublicKey>, Error> {
+ let sess = self.sess.lock();
+ let mut res = vec![];
+ let mut prev = 0 as *mut _;
+ let mut next = 0 as *mut _;
+ loop {
+ match unsafe { raw::libssh2_agent_get_identity(self.raw, &mut next, prev) } {
+ 0 => {
+ prev = next;
+ res.push(unsafe { PublicKey::from_raw(next) });
+ }
+ 1 => break,
+ rc => return Err(Error::from_session_error_raw(sess.raw, rc)),
+ }
+ }
+ Ok(res)
}
- /// Get an iterator over the identities of this agent.
- pub fn identities(&self) -> Identities {
- Identities {
- prev: 0 as *mut _,
- agent: self,
+ fn resolve_raw_identity(
+ &self,
+ sess: &MutexGuard<SessionInner>,
+ identity: &PublicKey,
+ ) -> Result<Option<*mut raw::libssh2_agent_publickey>, Error> {
+ let mut prev = 0 as *mut _;
+ let mut next = 0 as *mut _;
+ loop {
+ match unsafe { raw::libssh2_agent_get_identity(self.raw, &mut next, prev) } {
+ 0 => {
+ prev = next;
+ let this_ident = unsafe { PublicKey::from_raw(next) };
+ if this_ident == *identity {
+ return Ok(Some(next));
+ }
+ }
+ 1 => break,
+ rc => return Err(Error::from_session_error_raw(sess.raw, rc)),
+ }
}
+ Ok(None)
}
/// Attempt public key authentication with the help of ssh-agent.
pub fn userauth(&self, username: &str, identity: &PublicKey) -> Result<(), Error> {
let username = CString::new(username)?;
+ let sess = self.sess.lock();
+ let raw_ident = self
+ .resolve_raw_identity(&sess, identity)?
+ .ok_or_else(|| Error::new(raw::LIBSSH2_ERROR_BAD_USE, "Identity not found in agent"))?;
unsafe {
- self.sess.rc(raw::libssh2_agent_userauth(
+ sess.rc(raw::libssh2_agent_userauth(
self.raw,
username.as_ptr(),
- identity.raw,
+ raw_ident,
))
}
}
@@ -87,46 +123,28 @@ impl Drop for Agent {
}
}
-impl<'agent> Iterator for Identities<'agent> {
- type Item = Result<PublicKey<'agent>, Error>;
- fn next(&mut self) -> Option<Result<PublicKey<'agent>, Error>> {
- unsafe {
- let mut next = 0 as *mut _;
- match raw::libssh2_agent_get_identity(self.agent.raw, &mut next, self.prev) {
- 0 => {
- self.prev = next;
- Some(Ok(Binding::from_raw(next)))
- }
- 1 => None,
- rc => Some(Err(self.agent.sess.rc(rc).err().unwrap())),
- }
+impl PublicKey {
+ unsafe fn from_raw(raw: *mut raw::libssh2_agent_publickey) -> Self {
+ let blob = slice::from_raw_parts_mut((*raw).blob, (*raw).blob_len as usize);
+ let comment = (*raw).comment;
+ let comment = if comment.is_null() {
+ String::new()
+ } else {
+ CStr::from_ptr(comment).to_string_lossy().into_owned()
+ };
+ Self {
+ blob: blob.to_vec(),
+ comment,
}
}
-}
-impl<'agent> PublicKey<'agent> {
/// Return the data of this public key.
pub fn blob(&self) -> &[u8] {
- unsafe { slice::from_raw_parts_mut((*self.raw).blob, (*self.raw).blob_len as usize) }
+ &self.blob
}
/// Returns the comment in a printable format
pub fn comment(&self) -> &str {
- unsafe { str::from_utf8(::opt_bytes(self, (*self.raw).comment).unwrap()).unwrap() }
- }
-}
-
-impl<'agent> Binding for PublicKey<'agent> {
- type Raw = *mut raw::libssh2_agent_publickey;
-
- unsafe fn from_raw(raw: *mut raw::libssh2_agent_publickey) -> PublicKey<'agent> {
- PublicKey {
- raw: raw,
- _marker: marker::PhantomData,
- }
- }
-
- fn raw(&self) -> *mut raw::libssh2_agent_publickey {
- self.raw
+ &self.comment
}
}
diff --git a/src/channel.rs b/src/channel.rs
index 5adc43b..1169a65 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -1,4 +1,5 @@
use libc::{c_char, c_int, c_uchar, c_uint, c_ulong, c_void, size_t};
+use parking_lot::{Mutex, MutexGuard};
use std::cmp;
use std::io;
use std::io::prelude::*;
@@ -7,6 +8,17 @@ use std::sync::Arc;
use {raw, Error, ExtendedData, PtyModes, SessionInner};
+struct ChannelInner {
+ unsafe_raw: *mut raw::LIBSSH2_CHANNEL,
+ sess: Arc<Mutex<SessionInner>>,
+ read_limit: Mutex<Option<u64>>,
+}
+
+struct LockedChannel<'a> {
+ raw: *mut raw::LIBSSH2_CHANNEL,
+ sess: MutexGuard<'a, SessionInner>,
+}
+
/// A channel represents a portion of an SSH connection on which data can be
/// read and written.
///
@@ -16,35 +28,57 @@ use {raw, Error, ExtendedData, PtyModes, SessionInner};
/// Whether or not I/O operations are blocking is mandated by the `blocking`
/// flag on a channel's corresponding `Session`.
pub struct Channel {
- raw: *mut raw::LIBSSH2_CHANNEL,
- sess: Arc<SessionInner>,
- read_limit: Option<u64>,
+ channel_inner: Arc<ChannelInner>,
}
impl Channel {
pub(crate) fn from_raw_opt(
raw: *mut raw::LIBSSH2_CHANNEL,
- sess: &Arc<SessionInner>,
+ err: Option<Error>,
+ sess: &Arc<Mutex<SessionInner>>,
) -> Result<Self, Error> {
if raw.is_null() {
- Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown))
+ Err(err.unwrap_or_else(Error::unknown))
} else {
Ok(Self {
- raw,
- sess: Arc::clone(sess),
- read_limit: None,
+ channel_inner: Arc::new(ChannelInner {
+ unsafe_raw: raw,
+ sess: Arc::clone(sess),
+ read_limit: Mutex::new(None),
+ }),
})
}
}
+
+ fn lock(&self) -> LockedChannel {
+ let sess = self.channel_inner.sess.lock();
+ LockedChannel {
+ sess,
+ raw: self.channel_inner.unsafe_raw,
+ }
+ }
}
/// A channel can have a number of streams, each identified by an id, each of
/// which implements the `Read` and `Write` traits.
-pub struct Stream<'channel> {
- channel: &'channel mut Channel,
+pub struct Stream {
+ channel_inner: Arc<ChannelInner>,
id: i32,
}
+struct LockedStream<'a> {
+ raw: *mut raw::LIBSSH2_CHANNEL,
+ sess: MutexGuard<'a, SessionInner>,
+ id: i32,
+ read_limit: MutexGuard<'a, Option<u64>>,
+}
+
+impl<'a> LockedStream<'a> {
+ pub fn eof(&self) -> bool {
+ *self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 }
+ }
+}
+
/// Data received from when a program exits with a signal.
pub struct ExitSignal {
/// The exit signal received, if the program did not exit cleanly. Does not
@@ -84,9 +118,10 @@ impl Channel {
/// Note that this does not make sense for all channel types and may be
/// ignored by the server despite returning success.
pub fn setenv(&mut self, var: &str, val: &str) -> Result<(), Error> {
+ let locked = self.lock();
unsafe {
- self.sess.rc(raw::libssh2_channel_setenv_ex(
- self.raw,
+ locked.sess.rc(raw::libssh2_channel_setenv_ex(
+ locked.raw,
var.as_ptr() as *const _,
var.len() as c_uint,
val.as_ptr() as *const _,
@@ -120,12 +155,13 @@ impl Channel {
mode: Option<PtyModes>,
dim: Option<(u32, u32, u32, u32)>,
) -> Result<(), Error> {
+ let locked = self.lock();
let mode = mode.map(PtyModes::finish);
let mode = mode.as_ref().map(Vec::as_slice).unwrap_or(&[]);
- self.sess.rc(unsafe {
+ locked.sess.rc(unsafe {
let (width, height, width_px, height_px) = dim.unwrap_or((80, 24, 0, 0));
raw::libssh2_channel_request_pty_ex(
- self.raw,
+ locked.raw,
term.as_ptr() as *const _,
term.len() as c_uint,
mode.as_ptr() as *const _,
@@ -148,11 +184,12 @@ impl Channel {
width_px: Option<u32>,
height_px: Option<u32>,
) -> Result<(), Error> {
+ let locked = self.lock();
let width_px = width_px.unwrap_or(0);
let height_px = height_px.unwrap_or(0);
- self.sess.rc(unsafe {
+ locked.sess.rc(unsafe {
raw::libssh2_channel_request_pty_size_ex(
- self.raw,
+ locked.raw,
width as c_int,
height as c_int,
width_px as c_int,
@@ -205,22 +242,23 @@ impl Channel {
pub fn process_startup(&mut self, request: &str, message: Option<&str>) -> Result<(), Error> {
let message_len = message.map(|s| s.len()).unwrap_or(0);
let message = message.map(|s| s.as_ptr()).unwrap_or(0 as *const _);
+ let locked = self.lock();
unsafe {
let rc = raw::libssh2_channel_process_startup(
- self.raw,
+ locked.raw,
request.as_ptr() as *const _,
request.len() as c_uint,
message as *const _,
message_len as c_uint,
);
- self.sess.rc(rc)
+ locked.sess.rc(rc)
}
}
/// Get a handle to the stderr stream of this channel.
///
/// The returned handle implements the `Read` and `Write` traits.
- pub fn stderr<'a>(&'a mut self) -> Stream<'a> {
+ pub fn stderr(&self) -> Stream {
self.stream(::EXTENDED_DATA_STDERR)
}
@@ -233,18 +271,19 @@ impl Channel {
///
/// * FLUSH_EXTENDED_DATA - Flush all extended data substreams
/// * FLUSH_ALL - Flush all substreams
- pub fn stream<'a>(&'a mut self, stream_id: i32) -> Stream<'a> {
+ pub fn stream(&self, stream_id: i32) -> Stream {
Stream {
- channel: self,
+ channel_inner: Arc::clone(&self.channel_inner),
id: stream_id,
}
}
/// Change how extended data (such as stderr) is handled
pub fn handle_extended_data(&mut self, mode: ExtendedData) -> Result<(), Error> {
+ let locked = self.lock();
unsafe {
- let rc = raw::libssh2_channel_handle_extended_data2(self.raw, mode as c_int);
- self.sess.rc(rc)
+ let rc = raw::libssh2_channel_handle_extended_data2(locked.raw, mode as c_int);
+ locked.sess.rc(rc)
}
}
@@ -254,15 +293,17 @@ impl Channel {
/// Note that the exit status may not be available if the remote end has not
/// yet set its status to closed.
pub fn exit_status(&self) -> Result<i32, Error> {
+ let locked = self.lock();
// Should really store existing error, call function, check for error
// after and restore previous error if no new one...but the only error
// condition right now is a NULL pointer check on self.raw, so let's
// assume that's not the case.
- Ok(unsafe { raw::libssh2_channel_get_exit_status(self.raw) })
+ Ok(unsafe { raw::libssh2_channel_get_exit_status(locked.raw) })
}
/// Get the remote exit signal.
pub fn exit_signal(&self) -> Result<ExitSignal, Error> {
+ let locked = self.lock();
unsafe {
let mut sig = 0 as *mut _;
let mut siglen = 0;
@@ -271,7 +312,7 @@ impl Channel {
let mut lang = 0 as *mut _;
let mut langlen = 0;
let rc = raw::libssh2_channel_get_exit_signal(
- self.raw,
+ locked.raw,
&mut sig,
&mut siglen,
&mut msg,
@@ -279,31 +320,32 @@ impl Channel {
&mut lang,
&mut langlen,
);
- self.sess.rc(rc)?;
+ locked.sess.rc(rc)?;
return Ok(ExitSignal {
- exit_signal: convert(self, sig, siglen),
- error_message: convert(self, msg, msglen),
- lang_tag: convert(self, lang, langlen),
+ exit_signal: convert(&locked, sig, siglen),
+ error_message: convert(&locked, msg, msglen),
+ lang_tag: convert(&locked, lang, langlen),
});
}
- unsafe fn convert(chan: &Channel, ptr: *mut c_char, len: size_t) -> Option<String> {
+ unsafe fn convert(locked: &LockedChannel, ptr: *mut c_char, len: size_t) -> Option<String> {
if ptr.is_null() {
return None;
}
let slice = slice::from_raw_parts(ptr as *const u8, len as usize);
let ret = slice.to_vec();
- raw::libssh2_free(chan.sess.raw, ptr as *mut c_void);
+ raw::libssh2_free(locked.sess.raw, ptr as *mut c_void);
String::from_utf8(ret).ok()
}
}
/// Check the status of the read window.
pub fn read_window(&self) -> ReadWindow {
+ let locked = self.lock();
unsafe {
let mut avail = 0;
let mut init = 0;
- let remaining = raw::libssh2_channel_window_read_ex(self.raw, &mut avail, &mut init);
+ let remaining = raw::libssh2_channel_window_read_ex(locked.raw, &mut avail, &mut init);
ReadWindow {
remaining: remaining as u32,
available: avail as u32,
@@ -314,9 +356,10 @@ impl Channel {
/// Check the status of the write window.
pub fn write_window(&self) -> WriteWindow {
+ let locked = self.lock();
unsafe {
let mut init = 0;
- let remaining = raw::libssh2_channel_window_write_ex(self.raw, &mut init);
+ let remaining = raw::libssh2_channel_window_write_ex(locked.raw, &mut init);
WriteWindow {
remaining: remaining as u32,
window_size_initial: init as u32,
@@ -332,24 +375,25 @@ impl Channel {
/// This function returns the new size of the receive window (as understood
/// by remote end) on success.
pub fn adjust_receive_window(&mut self, adjust: u64, force: bool) -> Result<u64, Error> {
+ let locked = self.lock();
let mut ret = 0;
let rc = unsafe {
raw::libssh2_channel_receive_window_adjust2(
- self.raw,
+ locked.raw,
adjust as c_ulong,
force as c_uchar,
&mut ret,
)
};
- self.sess.rc(rc)?;
+ locked.sess.rc(rc)?;
Ok(ret as u64)
}
/// Artificially limit the number of bytes that will be read from this
/// channel. Hack intended for use by scp_recv only.
#[doc(hidden)]
- pub fn limit_read(&mut self, limit: u64) {
- self.read_limit = Some(limit);
+ pub(crate) fn limit_read(&mut self, limit: u64) {
+ *self.channel_inner.read_limit.lock() = Some(limit);
}
/// Check if the remote host has sent an EOF status for the channel.
@@ -357,7 +401,9 @@ impl Channel {
/// because the reading from the channel reads only the stdout stream.
/// unread, buffered, stderr data will cause eof() to return false.
pub fn eof(&self) -> bool {
- self.read_limit == Some(0) || unsafe { raw::libssh2_channel_eof(self.raw) != 0 }
+ let locked = self.lock();
+ *self.channel_inner.read_limit.lock() == Some(0)
+ || unsafe { raw::libssh2_channel_eof(locked.raw) != 0 }
}
/// Tell the remote host that no further data will be sent on the specified
@@ -365,7 +411,8 @@ impl Channel {
///
/// Processes typically interpret this as a closed stdin descriptor.
pub fn send_eof(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_channel_send_eof(self.raw)) }
+ let locked = self.lock();
+ unsafe { locked.sess.rc(raw::libssh2_channel_send_eof(locked.raw)) }
}
/// Wait for the remote end to send EOF.
@@ -374,7 +421,8 @@ impl Channel {
/// You should call the eof() function after calling this to check the
/// status of the channel.
pub fn wait_eof(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_channel_wait_eof(self.raw)) }
+ let locked = self.lock();
+ unsafe { locked.sess.rc(raw::libssh2_channel_wait_eof(locked.raw)) }
}
/// Close an active data channel.
@@ -387,7 +435,8 @@ impl Channel {
/// To wait for the remote end to close its connection as well, follow this
/// command with `wait_closed`
pub fn close(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_channel_close(self.raw)) }
+ let locked = self.lock();
+ unsafe { locked.sess.rc(raw::libssh2_channel_close(locked.raw)) }
}
/// Enter a temporary blocking state until the remote host closes the named
@@ -395,7 +444,8 @@ impl Channel {
///
/// Typically sent after `close` in order to examine the exit status.
pub fn wait_close(&mut self) -> Result<(), Error> {
- unsafe { self.sess.rc(raw::libssh2_channel_wait_closed(self.raw)) }
+ let locked = self.lock();
+ unsafe { locked.sess.rc(raw::libssh2_channel_wait_closed(locked.raw)) }
}
}
@@ -415,40 +465,53 @@ impl Read for Channel {
}
}
-impl Drop for Channel {
+impl Drop for ChannelInner {
fn drop(&mut self) {
unsafe {
- let _ = raw::libssh2_channel_free(self.raw);
+ let _ = raw::libssh2_channel_free(self.unsafe_raw);
+ }
+ }
+}
+
+impl Stream {
+ fn lock(&self) -> LockedStream {
+ let sess = self.channel_inner.sess.lock();
+ LockedStream {
+ sess,
+ raw: self.channel_inner.unsafe_raw,
+ id: self.id,
+ read_limit: self.channel_inner.read_limit.lock(),
}
}
}
-impl<'channel> Read for Stream<'channel> {
+impl Read for Stream {
fn read(&mut self, data: &mut [u8]) -> io::Result<usize> {
- if self.channel.eof() {
+ let mut locked = self.lock();
+ if locked.eof() {
return Ok(0);
}
- let data = match self.channel.read_limit {
+ let data = match locked.read_limit.as_mut() {
Some(amt) => {
let len = data.len();
- &mut data[..cmp::min(amt as usize, len)]
+ &mut data[..cmp::min(*amt as usize, len)]
}
None => data,
};
let ret = unsafe {
let rc = raw::libssh2_channel_read_ex(
- self.channel.raw,
- self.id as c_int,
+ locked.raw,
+ locked.id as c_int,
data.as_mut_ptr() as *mut _,
data.len() as size_t,
);
- self.channel.sess.rc(rc as c_int).map(|()| rc as usize)
+ locked.sess.rc(rc as c_int).map(|()| rc as usize)
};
match ret {
Ok(n) => {
- if let Some(ref mut amt) = self.channel.read_limit {
- *amt -= n as u64;
+ if let Some(ref mut amt) = locked.read_limit.as_mut() {
+ **amt -= n as u64;
}
Ok(n)
}
@@ -457,24 +520,26 @@ impl<'channel> Read for Stream<'channel> {
}
}
-impl<'channel> Write for Stream<'channel> {
+impl Write for Stream {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
+ let locked = self.lock();
unsafe {
let rc = raw::libssh2_channel_write_ex(
- self.channel.raw,
- self.id as c_int,
+ locked.raw,
+ locked.id as c_int,
data.as_ptr() as *mut _,
data.len() as size_t,
);
- self.channel.sess.rc(rc as c_int).map(|()| rc as usize)
+ locked.sess.rc(rc as c_int).map(|()| rc as usize)
}
.map_err(Into::into)
}
fn flush(&mut self) -> io::Result<()> {
+ let locked = self.lock();
unsafe {
- let rc = raw::libssh2_channel_flush_ex(self.channel.raw, self.id as c_int);
- self.channel.sess.rc(rc)
+ let rc = raw::libssh2_channel_flush_ex(locked.raw, locked.id as c_int);
+ locked.sess.rc(rc)
}
.map_err(Into::into)
}
diff --git a/src/error.rs b/src/error.rs
index 2499150..7ae0431 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -38,7 +38,7 @@ impl Error {
/// If the error code doesn't match then an approximation of the error
/// reason is used instead of the error message stored in the Session.
pub fn from_session_error(sess: &Session, rc: libc::c_int) -> Error {
- Self::from_session_error_raw(sess.raw(), rc)
+ Self::from_session_error_raw(&mut *sess.raw(), rc)
}
#[doc(hidden)]
@@ -59,7 +59,7 @@ impl Error {
///
/// Returns `None` if there was no last error.
pub fn last_error(sess: &Session) -> Option<Error> {
- Self::last_error_raw(sess.raw())
+ Self::last_error_raw(&mut *sess.raw())
}
/// Create a new error for the given code and message
@@ -80,6 +80,14 @@ impl Error {
Error::new(libc::c_int::min_value(), "no other error listed")
}
+ pub(crate) fn rc(rc: libc::c_int) -> Result<(), Error> {
+ if rc == 0 {
+ Ok(())
+ } else {
+ Err(Self::from_errno(rc))
+ }
+ }
+
/// Construct an error from an error code from libssh2
pub fn from_errno(code: libc::c_int) -> Error {
let msg = match code {
diff --git a/src/knownhosts.rs b/src/knownhosts.rs
index 5957369..2750fdd 100644
--- a/src/knownhosts.rs
+++ b/src/knownhosts.rs
@@ -1,11 +1,11 @@
use libc::{c_int, size_t};
+use parking_lot::{Mutex, MutexGuard};
use std::ffi::CString;
-use std::marker;
use std::path::Path;
use std::str;
use std::sync::Arc;
-use util::{self, Binding};
+use util;
use {raw, CheckResult, Error, KnownHostFileKind, SessionInner};
/// A set of known hosts which can be used to verify the identity of a remote
@@ -46,28 +46,24 @@ use {raw, CheckResult, Error, KnownHostFileKind, SessionInner};
/// ```
pub struct KnownHosts {
raw: *mut raw::LIBSSH2_KNOWNHOSTS,
- sess: Arc<SessionInner>,
-}
-
-/// Iterator over the hosts in a `KnownHosts` structure.
-pub struct Hosts<'kh> {
- prev: *mut raw::libssh2_knownhost,
- hosts: &'kh KnownHosts,
+ sess: Arc<Mutex<SessionInner>>,
}
/// Structure representing a known host as part of a `KnownHosts` structure.
-pub struct Host<'kh> {
- raw: *mut raw::libssh2_knownhost,
- _marker: marker::PhantomData<&'kh str>,
+#[derive(Debug, PartialEq, Eq)]
+pub struct Host {
+ name: Option<String>,
+ key: String,
}
impl KnownHosts {
pub(crate) fn from_raw_opt(
raw: *mut raw::LIBSSH2_KNOWNHOSTS,
- sess: &Arc<SessionInner>,
+ err: Option<Error>,
+ sess: &Arc<Mutex<SessionInner>>,
) -> Result<Self, Error> {
if raw.is_null() {
- Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown))
+ Err(err.unwrap_or_else(Error::unknown))
} else {
Ok(Self {
raw,
@@ -80,16 +76,18 @@ impl KnownHosts {
/// the collection of known hosts.
pub fn read_file(&mut self, file: &Path, kind: KnownHostFileKind) -> Result<u32, Error> {
let file = CString::new(util::path2bytes(file)?)?;
+ let sess = self.sess.lock();
let n = unsafe { raw::libssh2_knownhost_readfile(self.raw, file.as_ptr(), kind as c_int) };
if n < 0 {
- self.sess.rc(n)?
+ sess.rc(n)?
}
Ok(n as u32)
}
/// Read a line as if it were from a known hosts file.
pub fn read_str(&mut self, s: &str, kind: KnownHostFileKind) -> Result<(), Error> {
- self.sess.rc(unsafe {
+ let sess = self.sess.lock();
+ sess.rc(unsafe {
raw::libssh2_knownhost_readline(
self.raw,
s.as_ptr() as *const _,
@@ -103,20 +101,28 @@ impl KnownHosts {
/// file format.
pub fn write_file(&self, file: &Path, kind: KnownHostFileKind) -> Result<(), Error> {
let file = CString::new(util::path2bytes(file)?)?;
+ let sess = self.sess.lock();
let n = unsafe { raw::libssh2_knownhost_writefile(self.raw, file.as_ptr(), kind as c_int) };
- self.sess.rc(n)
+ sess.rc(n)
}
/// Converts a single known host to a single line of output for storage,
/// using the 'type' output format.
pub fn write_string(&self, host: &Host, kind: KnownHostFileKind) -> Result<String, Error> {
let mut v = Vec::with_capacity(128);
+ let sess = self.sess.lock();
+ let raw_host = self.resolve_to_raw_host(&sess, host)?.ok_or_else(|| {
+ Error::new(
+ raw::LIBSSH2_ERROR_BAD_USE,
+ "Host is not in the set of known hosts",
+ )
+ })?;
loop {
let mut outlen = 0;
unsafe {
let rc = raw::libssh2_knownhost_writeline(
self.raw,
- host.raw,
+ raw_host,
v.as_mut_ptr() as *mut _,
v.capacity() as size_t,
&mut outlen,
@@ -126,7 +132,7 @@ impl KnownHosts {
// + 1 for the trailing zero
v.reserve(outlen as usize + 1);
} else {
- self.sess.rc(rc)?;
+ sess.rc(rc)?;
v.set_len(outlen as usize);
break;
}
@@ -136,17 +142,66 @@ impl KnownHosts {
}
/// Create an iterator over all of the known hosts in this structure.
- pub fn iter(&self) -> Hosts {
- Hosts {
- prev: 0 as *mut _,
- hosts: self,
+ pub fn iter(&self) -> Result<Vec<Host>, Error> {
+ self.hosts()
+ }
+
+ /// Retrieves the list of known hosts
+ pub fn hosts(&self) -> Result<Vec<Host>, Error> {
+ let mut next = 0 as *mut _;
+ let mut prev = 0 as *mut _;
+ let sess = self.sess.lock();
+ let mut hosts = vec![];
+
+ loop {
+ match unsafe { raw::libssh2_knownhost_get(self.raw, &mut next, prev) } {
+ 0 => {
+ prev = next;
+ hosts.push(unsafe { Host::from_raw(next) });
+ }
+ 1 => break,
+ rc => return Err(Error::from_session_error_raw(sess.raw, rc)),
+ }
+ }
+
+ Ok(hosts)
+ }
+
+ /// Given a Host object, find the matching raw node in the internal list.
+ /// The returned value is only valid while the session is locked.
+ fn resolve_to_raw_host(
+ &self,
+ sess: &MutexGuard<SessionInner>,
+ host: &Host,
+ ) -> Result<Option<*mut raw::libssh2_knownhost>, Error> {
+ let mut next = 0 as *mut _;
+ let mut prev = 0 as *mut _;
+
+ loop {
+ match unsafe { raw::libssh2_knownhost_get(self.raw, &mut next, prev) } {
+ 0 => {
+ prev = next;
+ let current = unsafe { Host::from_raw(next) };
+ if current == *host {
+ return Ok(Some(next));
+ }
+ }
+ 1 => break,
+ rc => return Err(Error::from_session_error_raw(sess.raw, rc)),
+ }
}
+ Ok(None)
}
/// Delete a known host entry from the collection of known hosts.
- pub fn remove(&self, host: Host) -> Result<(), Error> {
- self.sess
- .rc(unsafe { raw::libssh2_knownhost_del(self.raw, host.raw) })
+ pub fn remove(&self, host: &Host) -> Result<(), Error> {
+ let sess = self.sess.lock();
+
+ if let Some(raw_host) = self.resolve_to_raw_host(&sess, host)? {
+ return sess.rc(unsafe { raw::libssh2_knownhost_del(self.raw, raw_host) });
+ } else {
+ Ok(())
+ }
}
/// Checks a host and its associated key against the collection of known
@@ -205,6 +260,7 @@ impl KnownHosts {
let host = CString::new(host)?;
let flags =
raw::LIBSSH2_KNOWNHOST_TYPE_PLAIN | raw::LIBSSH2_KNOWNHOST_KEYENC_RAW | (fmt as c_int);
+ let sess = self.sess.lock();
unsafe {
let rc = raw::libssh2_knownhost_addc(
self.raw,
@@ -217,57 +273,33 @@ impl KnownHosts {
flags,
0 as *mut _,
);
- self.sess.rc(rc)
+ sess.rc(rc)
}
}
}
impl Drop for KnownHosts {
fn drop(&mut self) {
+ let _sess = self.sess.lock();
unsafe { raw::libssh2_knownhost_free(self.raw) }
}
}
-impl<'kh> Iterator for Hosts<'kh> {
- type Item = Result<Host<'kh>, Error>;
- fn next(&mut self) -> Option<Result<Host<'kh>, Error>> {
- unsafe {
- let mut next = 0 as *mut _;
- match raw::libssh2_knownhost_get(self.hosts.raw, &mut next, self.prev) {
- 0 => {
- self.prev = next;
- Some(Ok(Binding::from_raw(next)))
- }
- 1 => None,
- rc => Some(Err(self.hosts.sess.rc(rc).err().unwrap())),
- }
- }
- }
-}
-
-impl<'kh> Host<'kh> {
+impl Host {
/// This is `None` if no plain text host name exists.
pub fn name(&self) -> Option<&str> {
- unsafe { ::opt_bytes(self, (*self.raw).name).and_then(|s| str::from_utf8(s).ok()) }
+ self.name.as_ref().map(String::as_str)
}
/// Returns the key in base64/printable format
pub fn key(&self) -> &str {
- let bytes = unsafe { ::opt_bytes(self, (*self.raw).key).unwrap() };
- str::from_utf8(bytes).unwrap()
+ &self.key
}
-}
-
-impl<'kh> Binding for Host<'kh> {
- type Raw = *mut raw::libssh2_knownhost;
- unsafe fn from_raw(raw: *mut raw::libssh2_knownhost) -> Host<'kh> {
- Host {
- raw: raw,
- _marker: marker::PhantomData,
- }
- }
- fn raw(&self) -> *mut raw::libssh2_knownhost {
- self.raw
+ unsafe fn from_raw(raw: *mut raw::libssh2_knownhost) -> Self {
+ let name = ::opt_bytes(&raw, (*raw).name).and_then(|s| String::from_utf8(s.to_vec()).ok());
+ let key = ::opt_bytes(&raw, (*raw).key).unwrap();
+ let key = String::from_utf8(key.to_vec()).unwrap();
+ Self { name, key }
}
}
diff --git a/src/lib.rs b/src/lib.rs
index 23da2f7..036a769 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,8 +22,7 @@
//! agent.connect().unwrap();
//! agent.list_identities().unwrap();
//!
-//! for identity in agent.identities() {
-//! let identity = identity.unwrap(); // assume no I/O errors
+//! for identity in agent.identities().unwrap() {
//! println!("{}", identity.comment());
//! let pubkey = identity.blob();
//! }
@@ -138,13 +137,14 @@ extern crate libc;
extern crate libssh2_sys as raw;
#[macro_use]
extern crate bitflags;
+extern crate parking_lot;
use std::ffi::CStr;
-pub use agent::{Agent, Identities, PublicKey};
+pub use agent::{Agent, PublicKey};
pub use channel::{Channel, ExitSignal, ReadWindow, Stream, WriteWindow};
pub use error::Error;
-pub use knownhosts::{Host, Hosts, KnownHosts};
+pub use knownhosts::{Host, KnownHosts};
pub use listener::Listener;
use session::SessionInner;
pub use session::{BlockDirections, KeyboardInteractivePrompt, Prompt, ScpFileStat, Session};
diff --git a/src/listener.rs b/src/listener.rs
index afd29fd..a1a4c8e 100644
--- a/src/listener.rs
+++ b/src/listener.rs
@@ -1,3 +1,4 @@
+use parking_lot::Mutex;
use std::sync::Arc;
use {raw, Channel, Error, SessionInner};
@@ -7,24 +8,27 @@ use {raw, Channel, Error, SessionInner};
/// the remote server's port.
pub struct Listener {
raw: *mut raw::LIBSSH2_LISTENER,
- sess: Arc<SessionInner>,
+ sess: Arc<Mutex<SessionInner>>,
}
impl Listener {
/// Accept a queued connection from this listener.
pub fn accept(&mut self) -> Result<Channel, Error> {
+ let sess = self.sess.lock();
unsafe {
let chan = raw::libssh2_channel_forward_accept(self.raw);
- Channel::from_raw_opt(chan, &self.sess)
+ let err = sess.last_error();
+ Channel::from_raw_opt(chan, err, &self.sess)
}
}
pub(crate) fn from_raw_opt(
raw: *mut raw::LIBSSH2_LISTENER,
- sess: &Arc<SessionInner>,
+ err: Option<Error>,
+ sess: &Arc<Mutex<SessionInner>>,
) -> Result<Self, Error> {
if raw.is_null() {
- Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown))
+ Err(err.unwrap_or_else(Error::unknown))
} else {
Ok(Self {
raw,
diff --git a/src/session.rs b/src/session.rs
index efb4b25..90a0333 100644
--- a/src/session.rs
+++ b/src/session.rs
@@ -1,8 +1,8 @@
#[cfg(unix)]
use libc::size_t;
use libc::{self, c_char, c_int, c_long, c_uint, c_void};
+use parking_lot::{MappedMutexGuard, Mutex, MutexGuard};
use std::borrow::Cow;
-use std::cell::{Ref, RefCell};
use std::ffi::CString;
use std::mem;
use std::net::TcpStream;
@@ -65,14 +65,9 @@ unsafe fn with_abstract<R, F: FnOnce() -> R>(
pub(crate) struct SessionInner {
pub(crate) raw: *mut raw::LIBSSH2_SESSION,
- tcp: RefCell<Option<TcpStream>>,
+ tcp: Option<TcpStream>,
}
-// The compiler doesn't know that it is Send safe because of the raw
-// pointer inside. We know that the way that it is used by libssh2
-// and this crate is Send safe.
-unsafe impl Send for SessionInner {}
-
/// An SSH session, typically representing one TCP connection.
///
/// All other structures are based on an SSH session and cannot outlive a
@@ -80,14 +75,9 @@ unsafe impl Send for SessionInner {}
/// (via the `set_tcp_stream` method).
#[derive(Clone)]
pub struct Session {
- inner: Arc<SessionInner>,
+ inner: Arc<Mutex<SessionInner>>,
}
-// The compiler doesn't know that it is Send safe because of the raw
-// pointer inside. We know that the way that it is used by libssh2
-// and this crate is Send safe.
-unsafe impl Send for Session {}
-
/// Metadata returned about a remote file when received via `scp`.
pub struct ScpFileStat {
stat: libc::stat,
@@ -123,18 +113,19 @@ impl Session {
Err(Error::unknown())
} else {
Ok(Session {
- inner: Arc::new(SessionInner {
+ inner: Arc::new(Mutex::new(SessionInner {
raw: ret,
- tcp: RefCell::new(None),
- }),
+ tcp: None,
+ })),
})
}
}
}
#[doc(hidden)]
- pub fn raw(&self) -> *mut raw::LIBSSH2_SESSION {
- self.inner.raw
+ pub fn raw(&self) -> MappedMutexGuard<raw::LIBSSH2_SESSION> {
+ let inner = self.inner();
+ MutexGuard::map(inner, |inner| unsafe { &mut *inner.raw })
}
/// Set the SSH protocol banner for the local client
@@ -145,12 +136,8 @@ impl Session {
/// default.
pub fn set_banner(&self, banner: &str) -> Result<(), Error> {
let banner = CString::new(banner)?;
- unsafe {
- self.rc(raw::libssh2_session_banner_set(
- self.inner.raw,
- banner.as_ptr(),
- ))
- }
+ let inner = self.inner();
+ unsafe { inner.rc(raw::libssh2_session_banner_set(inner.raw, banner.as_ptr())) }
}
/// Flag indicating whether SIGPIPE signals will be allowed or blocked.
@@ -160,9 +147,10 @@ impl Session {
/// the library to not attempt to block SIGPIPE from the underlying socket
/// layer.
pub fn set_allow_sigpipe(&self, block: bool) {
+ let inner = self.inner();
let res = unsafe {
- self.rc(raw::libssh2_session_flag(
- self.inner.raw,
+ inner.rc(raw::libssh2_session_flag(
+ inner.raw,
raw::LIBSSH2_FLAG_SIGPIPE as c_int,
block as c_int,
))
@@ -177,9 +165,10 @@ impl Session {
/// try to negotiate compression enabling for this connection. By default
/// libssh2 will not attempt to use compression.
pub fn set_compress(&self, compress: bool) {
+ let inner = self.inner();
let res = unsafe {
- self.rc(raw::libssh2_session_flag(
- self.inner.raw,
+ inner.rc(raw::libssh2_session_flag(
+ inner.raw,
raw::LIBSSH2_FLAG_COMPRESS as c_int,
compress as c_int,
))
@@ -197,12 +186,12 @@ impl Session {
/// a blocking session will wait for room. A non-blocking session will
/// return immediately without writing anything.
pub fn set_blocking(&self, blocking: bool) {
- self.inner.set_blocking(blocking);
+ self.inner().set_blocking(blocking);
}
/// Returns whether the session was previously set to nonblocking.
pub fn is_blocking(&self) -> bool {
- self.inner.is_blocking()
+ self.inner().is_blocking()
}
/// Set timeout for blocking functions.
@@ -215,7 +204,8 @@ impl Session {
/// for blocking functions.
pub fn set_timeout(&self, timeout_ms: u32) {
let timeout_ms = timeout_ms as c_long;
- unsafe { raw::libssh2_session_set_timeout(self.inner.raw, timeout_ms) }
+ let inner = self.inner();
+ unsafe { raw::libssh2_session_set_timeout(inner.raw, timeout_ms) }
}
/// Returns the timeout, in milliseconds, for how long blocking calls may
@@ -223,7 +213,8 @@ impl Session {
///
/// A timeout of 0 signifies no timeout.
pub fn timeout(&self) -> u32 {
- unsafe { raw::libssh2_session_get_timeout(self.inner.raw) as u32 }
+ let inner = self.inner();
+ unsafe { raw::libssh2_session_get_timeout(inner.raw) as u32 }
}
/// Begin transport layer protocol negotiation with the connected host.
@@ -243,17 +234,17 @@ impl Session {
raw::libssh2_session_handshake(raw, stream.as_raw_fd())
}
- unsafe {
- let stream = self.inner.tcp.borrow();
+ let inner = self.inner();
- let stream = stream.as_ref().ok_or_else(|| {
+ unsafe {
+ let stream = inner.tcp.as_ref().ok_or_else(|| {
Error::new(
raw::LIBSSH2_ERROR_BAD_SOCKET,
"use set_tcp_stream() to associate with a TcpStream",
)
})?;
- self.rc(handshake(self.inner.raw, stream))
+ inner.rc(handshake(inner.raw, stream))
}
}
@@ -265,13 +256,19 @@ impl Session {
/// concurrently elsewhere for the duration of this session as it may
/// interfere with the protocol.
pub fn set_tcp_stream(&mut self, stream: TcpStream) {
- *self.inner.tcp.borrow_mut() = Some(stream);
+ let mut inner = self.inner();
+ let _ = inner.tcp.replace(stream);
}
/// Returns a reference to the stream that was associated with the Session
/// by the Session::handshake method.
- pub fn tcp_stream(&self) -> Ref<Option<TcpStream>> {
- self.inner.tcp.borrow()
+ pub fn tcp_stream(&self) -> Option<MappedMutexGuard<TcpStream>> {
+ let inner = self.inner();
+ if inner.tcp.is_some() {
+ Some(MutexGuard::map(inner, |inner| inner.tcp.as_mut().unwrap()))
+ } else {
+ None
+ }
}
/// Attempt basic password authentication.
@@ -281,9 +278,10 @@ impl Session {
/// authentication (routed via PAM or another authentication backed)
/// instead.
pub fn userauth_password(&self, username: &str, password: &str) -> Result<(), Error> {
- self.rc(unsafe {
+ let inner = self.inner();
+ inner.rc(unsafe {
raw::libssh2_userauth_password_ex(
- self.inner.raw,
+ inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
password.as_ptr() as *const _,
@@ -394,10 +392,11 @@ impl Session {
}));
}
+ let inner = self.inner();
unsafe {
- with_abstract(self.inner.raw, prompter as *mut P as *mut c_void, || {
- self.rc(raw::libssh2_userauth_keyboard_interactive_ex(
- self.inner.raw,
+ with_abstract(inner.raw, prompter as *mut P as *mut c_void, || {
+ inner.rc(raw::libssh2_userauth_keyboard_interactive_ex(
+ inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
Some(prompt::<P>),
@@ -416,8 +415,9 @@ impl Session {
let mut agent = self.agent()?;
agent.connect()?;
agent.list_identities()?;
- let identity = match agent.identities().next() {
- Some(identity) => identity?,
+ let identities = agent.identities()?;
+ let identity = match identities.get(0) {
+ Some(identity) => identity,
None => {
return Err(Error::new(
raw::LIBSSH2_ERROR_INVAL as c_int,
@@ -446,9 +446,10 @@ impl Session {
Some(s) => Some(CString::new(s)?),
None => None,
};
- self.rc(unsafe {
+ let inner = self.inner();
+ inner.rc(unsafe {
raw::libssh2_userauth_publickey_fromfile_ex(
- self.inner.raw,
+ inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
pubkey.as_ref().map(|s| s.as_ptr()).unwrap_or(0 as *const _),
@@ -484,9 +485,10 @@ impl Session {
Some(s) => Some(CString::new(s)?),
None => None,
};
- self.rc(unsafe {
+ let inner = self.inner();
+ inner.rc(unsafe {
raw::libssh2_userauth_publickey_frommemory(
- self.inner.raw,
+ inner.raw,
username.as_ptr() as *const _,
username.len() as size_t,
pubkeydata
@@ -525,9 +527,10 @@ impl Session {
Some(local) => local,
None => username,
};
- self.rc(unsafe {
+ let inner = self.inner();
+ inner.rc(unsafe {
raw::libssh2_userauth_hostbased_fromfile_ex(
- self.inner.raw,
+ inner.raw,
username.as_ptr() as *const _,
username.len() as c_uint,
publickey.as_ptr(),
@@ -547,7 +550,8 @@ impl Session {
/// Indicates whether or not the named session has been successfully
/// authenticated.
pub fn authenticated(&self) -> bool {
- unsafe { raw::libssh2_userauth_authenticated(self.inner.raw) != 0 }
+ let inner = self.inner();
+ unsafe { raw::libssh2_userauth_authenticated(inner.raw) != 0 }
}
/// Send a SSH_USERAUTH_NONE request to the remote host.
@@ -564,10 +568,11 @@ impl Session {
pub fn auth_methods(&self, username: &str) -> Result<&str, Error> {
let len = username.len();
let username = CString::new(username)?;
+ let inner = self.inner();
unsafe {
- let ret = raw::libssh2_userauth_list(self.inner.raw, username.as_ptr(), len as c_uint);
+ let ret = raw::libssh2_userauth_list(inner.raw, username.as_ptr(), len as c_uint);
if ret.is_null() {
- match Error::last_error(self) {
+ match inner.last_error() {
Some(err) => Err(err),
None => Ok(""),
}
@@ -586,9 +591,10 @@ impl Session {
/// negotiation.
pub fn method_pref(&self, method_type: MethodType, prefs: &str) -> Result<(), Error> {
let prefs = CString::new(prefs)?;
+ let inner = self.inner();
unsafe {
- self.rc(raw::libssh2_session_method_pref(
- self.inner.raw,
+ inner.rc(raw::libssh2_session_method_pref(
+ inner.raw,
method_type as c_int,
prefs.as_ptr(),
))
@@ -600,8 +606,9 @@ impl Session {
/// Returns the actual method negotiated for a particular transport
/// parameter. May return `None` if the session has not yet been started.
pub fn methods(&self, method_type: MethodType) -> Option<&str> {
+ let inner = self.inner();
unsafe {
- let ptr = raw::libssh2_session_methods(self.inner.raw, method_type as c_int);
+ let ptr = raw::libssh2_session_methods(inner.raw, method_type as c_int);
::opt_bytes(self, ptr).and_then(|s| str::from_utf8(s).ok())
}
}
@@ -611,18 +618,19 @@ impl Session {
static STATIC: () = ();
let method_type = method_type as c_int;
let mut ret = Vec::new();
+ let inner = self.inner();
unsafe {
let mut ptr = 0 as *mut _;
- let rc = raw::libssh2_session_supported_algs(self.inner.raw, method_type, &mut ptr);
+ let rc = raw::libssh2_session_supported_algs(inner.raw, method_type, &mut ptr);
if rc <= 0 {
- self.rc(rc)?;
+ inner.rc(rc)?;
}
for i in 0..(rc as isize) {
let s = ::opt_bytes(&STATIC, *ptr.offset(i)).unwrap();
let s = str::from_utf8(s).unwrap();
ret.push(s);
}
- raw::libssh2_free(self.inner.raw, ptr as *mut c_void);
+ raw::libssh2_free(inner.raw, ptr as *mut c_void);
}
Ok(ret)
}
@@ -631,7 +639,12 @@ impl Session {
///
/// The returned agent will still need to be connected manually before use.
pub fn agent(&self) -> Result<Agent, Error> {
- unsafe { Agent::from_raw_opt(raw::libssh2_agent_init(self.inner.raw), &self.inner) }
+ let inner = self.inner();
+ unsafe {
+ let agent = raw::libssh2_agent_init(inner.raw);
+ let err = inner.last_error();
+ Agent::from_raw_opt(agent, err, &self.inner)
+ }
}
/// Init a collection of known hosts for this session.
@@ -639,9 +652,11 @@ impl Session {
/// Returns the handle to an internal representation of a known host
/// collection.
pub fn known_hosts(&self) -> Result<KnownHosts, Error> {
+ let inner = self.inner();
unsafe {
- let ptr = raw::libssh2_knownhost_init(self.inner.raw);
- KnownHosts::from_raw_opt(ptr, &self.inner)
+ let ptr = raw::libssh2_knownhost_init(inner.raw);
+ let err = inner.last_error();
+ KnownHosts::from_raw_opt(ptr, err, &self.inner)
}
}
@@ -679,15 +694,17 @@ impl Session {
let (shost, sport) = src.unwrap_or(("127.0.0.1", 22));
let host = CString::new(host)?;
let shost = CString::new(shost)?;
+ let inner = self.inner();
unsafe {
let ret = raw::libssh2_channel_direct_tcpip_ex(
- self.inner.raw,
+ inner.raw,
host.as_ptr(),
port as c_int,
shost.as_ptr(),
sport as c_int,
);
- Channel::from_raw_opt(ret, &self.inner)
+ let err = inner.last_error();
+ Channel::from_raw_opt(ret, err, &self.inner)
}
}
@@ -703,15 +720,17 @@ impl Session {
queue_maxsize: Option<u32>,
) -> Result<(Listener, u16), Error> {
let mut bound_port = 0;
+ let inner = self.inner();
unsafe {
let ret = raw::libssh2_channel_forward_listen_ex(
- self.inner.raw,
+ inner.raw,
host.map(|s| s.as_ptr()).unwrap_or(0 as *const _) as *mut _,
remote_port as c_int,
&mut bound_port,
queue_maxsize.unwrap_or(0) as c_int,
);
- Listener::from_raw_opt(ret, &self.inner).map(|l| (l, bound_port as u16))
+ let err = inner.last_error();
+ Listener::from_raw_opt(ret, err, &self.inner).map(|l| (l, bound_port as u16))
}
}
@@ -722,10 +741,12 @@ impl Session {
/// about the remote file to prepare for receiving the file.
pub fn scp_recv(&self, path: &Path) -> Result<(Channel, ScpFileStat), Error> {
let path = CString::new(util::path2bytes(path)?)?;
+ let inner = self.inner();
unsafe {
let mut sb: raw::libssh2_struct_stat = mem::zeroed();
- let ret = raw::libssh2_scp_recv2(self.inner.raw, path.as_ptr(), &mut sb);
- let mut c = Channel::from_raw_opt(ret, &self.inner)?;
+ let ret = raw::libssh2_scp_recv2(inner.raw, path.as_ptr(), &mut sb);
+ let err = inner.last_error();
+ let mut c = Channel::from_raw_opt(ret, err, &self.inner)?;
// Hm, apparently when we scp_recv() a file the actual channel
// itself does not respond well to read_to_end(), and it also sends
@@ -754,16 +775,18 @@ impl Session {
) -> Result<Channel, Error> {
let path = CString::new(util::path2bytes(remote_path)?)?;
let (mtime, atime) = times.unwrap_or((0, 0));
+ let inner = self.inner();
unsafe {
let ret = raw::libssh2_scp_send64(
- self.inner.raw,
+ inner.raw,
path.as_ptr(),
mode as c_int,
size as i64,
mtime as libc::time_t,
atime as libc::time_t,
);
- Channel::from_raw_opt(ret, &self.inner)
+ let err = inner.last_error();
+ Channel::from_raw_opt(ret, err, &self.inner)
}
}
@@ -774,9 +797,11 @@ impl Session {
/// own unique binary packet protocol which must be managed with the
/// methods on `Sftp`.
pub fn sftp(&self) -> Result<Sftp, Error> {
+ let inner = self.inner();
unsafe {
- let ret = raw::libssh2_sftp_init(self.inner.raw);
- Sftp::from_raw_opt(ret, &self.inner)
+ let ret = raw::libssh2_sftp_init(inner.raw);
+ let err = inner.last_error();
+ Sftp::from_raw_opt(ret, err, &self.inner)
}
}
@@ -792,9 +817,10 @@ impl Session {
message: Option<&str>,
) -> Result<Channel, Error> {
let message_len = message.map(|s| s.len()).unwrap_or(0);
+ let inner = self.inner();
unsafe {
let ret = raw::libssh2_channel_open_ex(
- self.inner.raw,
+ inner.raw,
channel_type.as_ptr() as *const _,
channel_type.len() as c_uint,
window_size as c_uint,
@@ -805,7 +831,8 @@ impl Session {
.unwrap_or(0 as *const _) as *const _,
message_len as c_uint,
);
- Channel::from_raw_opt(ret, &self.inner)
+ let err = inner.last_error();
+ Channel::from_raw_opt(ret, err, &self.inner)
}
}
@@ -824,7 +851,8 @@ impl Session {
///
/// Will only return `None` if an error has ocurred.
pub fn banner_bytes(&self) -> Option<&[u8]> {
- unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(self.inner.raw)) }
+ let inner = self.inner();
+ unsafe { ::opt_bytes(self, raw::libssh2_session_banner_get(inner.raw)) }
}
/// Get the remote key.
@@ -833,8 +861,9 @@ impl Session {
pub fn host_key(&self) -> Option<(&[u8], HostKeyType)> {
let mut len = 0;
let mut kind = 0;
+ let inner = self.inner();
unsafe {
- let ret = raw::libssh2_session_hostkey(self.inner.raw, &mut len, &mut kind);
+ let ret = raw::libssh2_session_hostkey(inner.raw, &mut len, &mut kind);
if ret.is_null() {
return None;
}
@@ -863,8 +892,9 @@ impl Session {
HashType::Sha1 => 20,
HashType::Sha256 => 32,
};
+ let inner = self.inner();
unsafe {
- let ret = raw::libssh2_hostkey_hash(self.inner.raw, hash as c_int);
+ let ret = raw::libssh2_hostkey_hash(inner.raw, hash as c_int);
if ret.is_null() {
None
} else {
@@ -883,9 +913,8 @@ impl Session {
/// I/O, use 0 (the default) to disable keepalives. To avoid some busy-loop
/// corner-cases, if you specify an interval of 1 it will be treated as 2.
pub fn set_keepalive(&self, want_reply: bool, interval: u32) {
- unsafe {
- raw::libssh2_keepalive_config(self.inner.raw, want_reply as c_int, interval as c_uint)
- }
+ let inner = self.inner();
+ unsafe { raw::libssh2_keepalive_config(inner.raw, want_reply as c_int, interval as c_uint) }
}
/// Send a keepalive message if needed.
@@ -894,8 +923,9 @@ impl Session {
/// to call it again.
pub fn keepalive_send(&self) -> Result<u32, Error> {
let mut ret = 0;
- let rc = unsafe { raw::libssh2_keepalive_send(self.inner.raw, &mut ret) };
- self.rc(rc)?;
+ let inner = self.inner();
+ let rc = unsafe { raw::libssh2_keepalive_send(inner.raw, &mut ret) };
+ inner.rc(rc)?;
Ok(ret as u32)
}
@@ -914,9 +944,10 @@ impl Session {
let reason = reason.unwrap_or(ByApplication) as c_int;
let description = CString::new(description)?;
let lang = CString::new(lang.unwrap_or(""))?;
+ let inner = self.inner();
unsafe {
- self.rc(raw::libssh2_session_disconnect_ex(
- self.inner.raw,
+ inner.rc(raw::libssh2_session_disconnect_ex(
+ inner.raw,
reason,
description.as_ptr(),
lang.as_ptr(),
@@ -924,17 +955,13 @@ impl Session {
}
}
- /// Translate a return code into a Rust-`Result`.
- pub fn rc(&self, rc: c_int) -> Result<(), Error> {
- self.inner.rc(rc)
- }
-
/// Returns the blocked io directions that the application needs to wait for.
///
/// This function should be used after an error of type `WouldBlock` is returned to
/// find out the socket events the application has to wait for.
pub fn block_directions(&self) -> BlockDirections {
- let dir = unsafe { raw::libssh2_session_block_directions(self.inner.raw) };
+ let inner = self.inner();
+ let dir = unsafe { raw::libssh2_session_block_directions(inner.raw) };
match dir {
raw::LIBSSH2_SESSION_BLOCK_INBOUND => BlockDirections::Inbound,
raw::LIBSSH2_SESSION_BLOCK_OUTBOUND => BlockDirections::Outbound,
@@ -944,6 +971,10 @@ impl Session {
_ => BlockDirections::None,
}
}
+
+ fn inner(&self) -> MutexGuard<SessionInner> {
+ self.inner.lock()
+ }
}
impl SessionInner {
@@ -956,6 +987,10 @@ impl SessionInner {
}
}
+ pub fn last_error(&self) -> Option<Error> {
+ Error::last_error_raw(self.raw)
+ }
+
/// Set or clear blocking mode on session
pub fn set_blocking(&self, blocking: bool) {
unsafe { raw::libssh2_session_set_blocking(self.raw, blocking as c_int) }
diff --git a/src/sftp.rs b/src/sftp.rs
index b7ecf77..025c897 100644
--- a/src/sftp.rs
+++ b/src/sftp.rs
@@ -1,4 +1,5 @@
use libc::{c_int, c_long, c_uint, c_ulong, size_t};
+use parking_lot::{Mutex, MutexGuard};
use std::io::prelude::*;
use std::io::{self, ErrorKind, SeekFrom};
use std::mem;
@@ -10,19 +11,29 @@ use {raw, Error, SessionInner};
struct SftpInner {
raw: *mut raw::LIBSSH2_SFTP,
- sess: Arc<SessionInner>,
+ sess: Arc<Mutex<SessionInner>>,
}
/// A handle to a remote filesystem over SFTP.
///
/// Instances are created through the `sftp` method on a `Session`.
pub struct Sftp {
- inner: Option<SftpInner>,
+ inner: Option<Arc<SftpInner>>,
}
-struct FileInner<'sftp> {
+struct LockedSftp<'sftp> {
+ sess: MutexGuard<'sftp, SessionInner>,
+ raw: *mut raw::LIBSSH2_SFTP,
+}
+
+struct FileInner {
+ raw: *mut raw::LIBSSH2_SFTP_HANDLE,
+ sftp: Arc<SftpInner>,
+}
+
+struct LockedFile<'file> {
+ sess: MutexGuard<'file, SessionInner>,
raw: *mut raw::LIBSSH2_SFTP_HANDLE,
- sftp: &'sftp Sftp,
}
/// A file handle to an SFTP connection.
@@ -32,8 +43,8 @@ struct FileInner<'sftp> {
///
/// Files are created through `open`, `create`, and `open_mode` on an instance
/// of `Sftp`.
-pub struct File<'sftp> {
- inner: Option<FileInner<'sftp>>,
+pub struct File {
+ inner: Option<FileInner>,
}
/// Metadata information about a remote file.
@@ -110,19 +121,27 @@ pub enum OpenType {
Dir = raw::LIBSSH2_SFTP_OPENDIR as isize,
}
+impl<'sftp> LockedSftp<'sftp> {
+ pub fn last_error(&self) -> Error {
+ let code = unsafe { raw::libssh2_sftp_last_error(self.raw) };
+ Error::from_errno(code as c_int)
+ }
+}
+
impl Sftp {
pub(crate) fn from_raw_opt(
raw: *mut raw::LIBSSH2_SFTP,
- sess: &Arc<SessionInner>,
+ err: Option<Error>,
+ sess: &Arc<Mutex<SessionInner>>,
) -> Result<Self, Error> {
if raw.is_null() {
- Err(Error::last_error_raw(sess.raw).unwrap_or_else(Error::unknown))
+ Err(err.unwrap_or_else(Error::unknown))
} else {
Ok(Self {
- inner: Some(SftpInner {
+ inner: Some(Arc::new(SftpInner {
raw,
sess: Arc::clone(sess),
- }),
+ })),
})
}
}
@@ -136,9 +155,11 @@ impl Sftp {
open_type: OpenType,
) -> Result<File, Error> {
let filename = util::path2bytes(filename)?;
+
+ let locked = self.lock()?;
unsafe {
let ret = raw::libssh2_sftp_open_ex(
- self.get_raw()?,
+ locked.raw,
filename.as_ptr() as *const _,
filename.len() as c_uint,
flags.bits() as c_ulong,
@@ -146,7 +167,7 @@ impl Sftp {
open_type as c_int,
);
if ret.is_null() {
- Err(self.last_session_error())
+ Err(locked.last_error())
} else {
Ok(File::from_raw(self, ret))
}
@@ -199,9 +220,10 @@ impl Sftp {
/// Create a directory on the remote file system.
pub fn mkdir(&self, filename: &Path, mode: i32) -> Result<(), Error> {
let filename = util::path2bytes(filename)?;
- self.rc(unsafe {
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
raw::libssh2_sftp_mkdir_ex(
- self.get_raw()?,
+ locked.raw,
filename.as_ptr() as *const _,
filename.len() as c_uint,
mode as c_long,
@@ -212,9 +234,10 @@ impl Sftp {
/// Remove a directory from the remote file system.
pub fn rmdir(&self, filename: &Path) -> Result<(), Error> {
let filename = util::path2bytes(filename)?;
- self.rc(unsafe {
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
raw::libssh2_sftp_rmdir_ex(
- self.get_raw()?,
+ locked.raw,
filename.as_ptr() as *const _,
filename.len() as c_uint,
)
@@ -224,16 +247,17 @@ impl Sftp {
/// Get the metadata for a file, performed by stat(2)
pub fn stat(&self, filename: &Path) -> Result<FileStat, Error> {
let filename = util::path2bytes(filename)?;
+ let locked = self.lock()?;
unsafe {
let mut ret = mem::zeroed();
let rc = raw::libssh2_sftp_stat_ex(
- self.get_raw()?,
+ locked.raw,
filename.as_ptr() as *const _,
filename.len() as c_uint,
raw::LIBSSH2_SFTP_STAT,
&mut ret,
);
- self.rc(rc)?;
+ locked.sess.rc(rc)?;
Ok(FileStat::from_raw(&ret))
}
}
@@ -241,16 +265,17 @@ impl Sftp {
/// Get the metadata for a file, performed by lstat(2)
pub fn lstat(&self, filename: &Path) -> Result<FileStat, Error> {
let filename = util::path2bytes(filename)?;
+ let locked = self.lock()?;
unsafe {
let mut ret = mem::zeroed();
let rc = raw::libssh2_sftp_stat_ex(
- self.get_raw()?,
+ locked.raw,
filename.as_ptr() as *const _,
filename.len() as c_uint,
raw::LIBSSH2_SFTP_LSTAT,
&mut ret,
);
- self.rc(rc)?;
+ locked.sess.rc(rc)?;
Ok(FileStat::from_raw(&ret))
}
}
@@ -258,10 +283,11 @@ impl Sftp {
/// Set the metadata for a file.
pub fn setstat(&self, filename: &Path, stat: FileStat) -> Result<(), Error> {
let filename = util::path2bytes(filename)?;
- self.rc(unsafe {
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
let mut raw = stat.raw();
raw::libssh2_sftp_stat_ex(
- self.get_raw()?,
+ locked.raw,
filename.as_ptr() as *const _,
filename.len() as c_uint,
raw::LIBSSH2_SFTP_SETSTAT,
@@ -274,9 +300,10 @@ impl Sftp {
pub fn symlink(&self, path: &Path, target: &Path) -> Result<(), Error> {
let path = util::path2bytes(path)?;
let target = util::path2bytes(target)?;
- self.rc(unsafe {
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
raw::libssh2_sftp_symlink_ex(
- self.get_raw()?,
+ locked.raw,
path.as_ptr() as *const _,
path.len() as c_uint,
target.as_ptr() as *mut _,
@@ -300,10 +327,11 @@ impl Sftp {
let path = util::path2bytes(path)?;
let mut ret = Vec::<u8>::with_capacity(128);
let mut rc;
+ let locked = self.lock()?;
loop {
rc = unsafe {
raw::libssh2_sftp_symlink_ex(
- self.get_raw()?,
+ locked.raw,
path.as_ptr() as *const _,
path.len() as c_uint,
ret.as_ptr() as *mut _,
@@ -319,7 +347,7 @@ impl Sftp {
}
}
if rc < 0 {
- Err(Error::from_errno(rc))
+ Err(Error::from_session_error_raw(locked.sess.raw, rc))
} else {
unsafe { ret.set_len(rc as usize) }
Ok(mkpath(ret))
@@ -343,9 +371,10 @@ impl Sftp {
flags.unwrap_or(RenameFlags::ATOMIC | RenameFlags::OVERWRITE | RenameFlags::NATIVE);
let src = util::path2bytes(src)?;
let dst = util::path2bytes(dst)?;
- self.rc(unsafe {
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
raw::libssh2_sftp_rename_ex(
- self.get_raw()?,
+ locked.raw,
src.as_ptr() as *const _,
src.len() as c_uint,
dst.as_ptr() as *const _,
@@ -358,53 +387,34 @@ impl Sftp {
/// Remove a file on the remote filesystem
pub fn unlink(&self, file: &Path) -> Result<(), Error> {
let file = util::path2bytes(file)?;
- self.rc(unsafe {
- raw::libssh2_sftp_unlink_ex(
- self.get_raw()?,
- file.as_ptr() as *const _,
- file.len() as c_uint,
- )
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
+ raw::libssh2_sftp_unlink_ex(locked.raw, file.as_ptr() as *const _, file.len() as c_uint)
})
}
- /// Peel off the last error to happen on this SFTP instance.
- pub fn last_error(&self) -> Error {
- let raw = match self.get_raw() {
- Ok(raw) => raw,
- Err(e) => return e,
- };
- let code = unsafe { raw::libssh2_sftp_last_error(raw) };
- Error::from_errno(code as c_int)
- }
-
- fn last_session_error(&self) -> Error {
- if let Some(inner) = self.inner.as_ref() {
- Error::last_error_raw(inner.sess.raw).unwrap_or_else(Error::unknown)
- } else {
- Error::from_errno(raw::LIBSSH2_ERROR_BAD_USE)
- }
- }
-
- /// Translates a return code into a Rust-`Result`
- pub fn rc(&self, rc: c_int) -> Result<(), Error> {
- if rc == 0 {
- Ok(())
- } else {
- Err(Error::from_errno(rc))
- }
- }
-
- fn get_raw(&self) -> Result<*mut raw::LIBSSH2_SFTP, Error> {
+ fn lock(&self) -> Result<LockedSftp, Error> {
match self.inner.as_ref() {
- Some(inner) => Ok(inner.raw),
+ Some(inner) => {
+ let sess = inner.sess.lock();
+ Ok(LockedSftp {
+ sess,
+ raw: inner.raw,
+ })
+ }
None => Err(Error::from_errno(raw::LIBSSH2_ERROR_BAD_USE)),
}
}
+ // This method is used by the async ssh crate
#[doc(hidden)]
pub fn shutdown(&mut self) -> Result<(), Error> {
- let raw = self.get_raw()?;
- self.rc(unsafe { raw::libssh2_sftp_shutdown(raw) })?;
+ {
+ let locked = self.lock()?;
+ locked
+ .sess
+ .rc(unsafe { raw::libssh2_sftp_shutdown(locked.raw) })?;
+ }
let _ = self.inner.take();
Ok(())
}
@@ -414,57 +424,62 @@ impl Drop for Sftp {
fn drop(&mut self) {
// Set ssh2 to blocking if sftp was not shutdown yet.
if let Some(inner) = self.inner.take() {
- let was_blocking = inner.sess.is_blocking();
- inner.sess.set_blocking(true);
+ let sess = inner.sess.lock();
+ let was_blocking = sess.is_blocking();
+ sess.set_blocking(true);
assert_eq!(unsafe { raw::libssh2_sftp_shutdown(inner.raw) }, 0);
- inner.sess.set_blocking(was_blocking);
+ sess.set_blocking(was_blocking);
}
}
}
-impl<'sftp> File<'sftp> {
+impl File {
/// Wraps a raw pointer in a new File structure tied to the lifetime of the
/// given session.
///
/// This consumes ownership of `raw`.
- unsafe fn from_raw(sftp: &'sftp Sftp, raw: *mut raw::LIBSSH2_SFTP_HANDLE) -> File<'sftp> {
+ unsafe fn from_raw(sftp: &Sftp, raw: *mut raw::LIBSSH2_SFTP_HANDLE) -> File {
File {
inner: Some(FileInner {
raw: raw,
- sftp: sftp,
+ sftp: Arc::clone(
+ sftp.inner
+ .as_ref()
+ .expect("we have a live option during construction"),
+ ),
}),
}
}
/// Set the metadata for this handle.
pub fn setstat(&mut self, stat: FileStat) -> Result<(), Error> {
- let inner = self.get_inner()?;
- inner.sftp.rc(unsafe {
+ let locked = self.lock()?;
+ locked.sess.rc(unsafe {
let mut raw = stat.raw();
- raw::libssh2_sftp_fstat_ex(inner.raw, &mut raw, 1)
+ raw::libssh2_sftp_fstat_ex(locked.raw, &mut raw, 1)
})
}
/// Get the metadata for this handle.
pub fn stat(&mut self) -> Result<FileStat, Error> {
+ let locked = self.lock()?;
unsafe {
- let inner = self.get_inner()?;
let mut ret = mem::zeroed();
- inner
- .sftp
- .rc(raw::libssh2_sftp_fstat_ex(inner.raw, &mut ret, 0))?;
+ locked
+ .sess
+ .rc(raw::libssh2_sftp_fstat_ex(locked.raw, &mut ret, 0))?;
Ok(FileStat::from_raw(&ret))
}
}
#[allow(missing_docs)] // sure wish I knew what this did...
pub fn statvfs(&mut self) -> Result<raw::LIBSSH2_SFTP_STATVFS, Error> {
+ let locked = self.lock()?;
unsafe {
- let inner = self.get_inner()?;
let mut ret = mem::zeroed();
- inner
- .sftp
- .rc(raw::libssh2_sftp_fstatvfs(inner.raw, &mut ret))?;
+ locked
+ .sess
+ .rc(raw::libssh2_sftp_fstatvfs(locked.raw, &mut ret))?;
Ok(ret)
}
}
@@ -480,14 +495,15 @@ impl<'sftp> File<'sftp> {
/// Also note that the return paths will not be absolute paths, they are
/// the filenames of the files in this directory.
pub fn readdir(&mut self) -> Result<(PathBuf, FileStat), Error> {
- let inner = self.get_inner()?;
+ let locked = self.lock()?;
+
let mut buf = Vec::<u8>::with_capacity(128);
let mut stat = unsafe { mem::zeroed() };
let mut rc;
loop {
rc = unsafe {
raw::libssh2_sftp_readdir_ex(
- inner.raw,
+ locked.raw,
buf.as_mut_ptr() as *mut _,
buf.capacity() as size_t,
0 as *mut _,
@@ -503,7 +519,7 @@ impl<'sftp> File<'sftp> {
}
}
if rc < 0 {
- return Err(Error::from_errno(rc));
+ return Err(Error::from_session_error_raw(locked.sess.raw, rc));
} else if rc == 0 {
return Err(Error::new(raw::LIBSSH2_ERROR_FILE, "no more files"));
} else {
@@ -519,50 +535,59 @@ impl<'sftp> File<'sftp> {
///
/// For this to work requires fsync@openssh.com support on the server.
pub fn fsync(&mut self) -> Result<(), Error> {
- let inner = self.get_inner()?;
- inner.sftp.rc(unsafe { raw::libssh2_sftp_fsync(inner.raw) })
+ let locked = self.lock()?;
+ locked
+ .sess
+ .rc(unsafe { raw::libssh2_sftp_fsync(locked.raw) })
}
- fn get_inner(&self) -> Result<&FileInner, Error> {
+ fn lock(&self) -> Result<LockedFile, Error> {
match self.inner.as_ref() {
- Some(inner) => Ok(inner),
+ Some(file_inner) => {
+ let sess = file_inner.sftp.sess.lock();
+ Ok(LockedFile {
+ sess,
+ raw: file_inner.raw,
+ })
+ }
None => Err(Error::from_errno(raw::LIBSSH2_ERROR_BAD_USE)),
}
}
#[doc(hidden)]
pub fn close(&mut self) -> Result<(), Error> {
- let inner = self.get_inner()?;
- inner
- .sftp
- .rc(unsafe { raw::libssh2_sftp_close_handle(inner.raw) })?;
+ {
+ let locked = self.lock()?;
+ Error::rc(unsafe { raw::libssh2_sftp_close_handle(locked.raw) })?;
+ }
let _ = self.inner.take();
Ok(())
}
}
-impl<'sftp> Read for File<'sftp> {
+impl Read for File {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let locked = self.lock()?;
unsafe {
- let inner = self.get_inner()?;
let rc =
- raw::libssh2_sftp_read(inner.raw, buf.as_mut_ptr() as *mut _, buf.len() as size_t);
- match rc {
- n if n < 0 => Err(io::Error::new(ErrorKind::Other, inner.sftp.last_error())),
- n => Ok(n as usize),
+ raw::libssh2_sftp_read(locked.raw, buf.as_mut_ptr() as *mut _, buf.len() as size_t);
+ if rc < 0 {
+ Err(Error::from_session_error_raw(locked.sess.raw, rc as _).into())
+ } else {
+ Ok(rc as usize)
}
}
}
}
-impl<'sftp> Write for File<'sftp> {
+impl Write for File {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
- let inner = self.get_inner()?;
+ let locked = self.lock()?;
let rc = unsafe {
- raw::libssh2_sftp_write(inner.raw, buf.as_ptr() as *const _, buf.len() as size_t)
+ raw::libssh2_sftp_write(locked.raw, buf.as_ptr() as *const _, buf.len() as size_t)
};
if rc < 0 {
- Err(Error::from_errno(rc as _).into())
+ Err(Error::from_session_error_raw(locked.sess.raw, rc as _).into())
} else {
Ok(rc as usize)
}
@@ -572,7 +597,7 @@ impl<'sftp> Write for File<'sftp> {
}
}
-impl<'sftp> Seek for File<'sftp> {
+impl Seek for File {
/// Move the file handle's internal pointer to an arbitrary location.
///
/// libssh2 implements file pointers as a localized concept to make file
@@ -587,8 +612,8 @@ impl<'sftp> Seek for File<'sftp> {
let next = match how {
SeekFrom::Start(pos) => pos,
SeekFrom::Current(offset) => {
- let inner = self.get_inner()?;
- let cur = unsafe { raw::libssh2_sftp_tell64(inner.raw) };
+ let locked = self.lock()?;
+ let cur = unsafe { raw::libssh2_sftp_tell64(locked.raw) };
(cur as i64 + offset) as u64
}
SeekFrom::End(offset) => match self.stat() {
@@ -599,27 +624,21 @@ impl<'sftp> Seek for File<'sftp> {
Err(e) => return Err(io::Error::new(ErrorKind::Other, e)),
},
};
- let inner = self.get_inner()?;
- unsafe { raw::libssh2_sftp_seek64(inner.raw, next) }
+ let locked = self.lock()?;
+ unsafe { raw::libssh2_sftp_seek64(locked.raw, next) }
Ok(next)
}
}
-impl<'sftp> Drop for File<'sftp> {
+impl Drop for File {
fn drop(&mut self) {
// Set ssh2 to blocking if the file was not closed yet.
- if let Some(inner) = self.inner.take() {
- // Normally sftp should still be available here. The only way it is `None` is when
- // shutdown has been polled until success. In this case the async client should
- // also properly poll `close` on `File` until success.
- if let Some(sftp) = inner.sftp.inner.as_ref() {
- let was_blocking = sftp.sess.is_blocking();
- sftp.sess.set_blocking(true);
- assert_eq!(unsafe { raw::libssh2_sftp_close_handle(inner.raw) }, 0);
- sftp.sess.set_blocking(was_blocking);
- } else {
- assert_eq!(unsafe { raw::libssh2_sftp_close_handle(inner.raw) }, 0);
- }
+ if let Some(file_inner) = self.inner.take() {
+ let sess_inner = file_inner.sftp.sess.lock();
+ let was_blocking = sess_inner.is_blocking();
+ sess_inner.set_blocking(true);
+ assert_eq!(unsafe { raw::libssh2_sftp_close_handle(file_inner.raw) }, 0);
+ sess_inner.set_blocking(was_blocking);
}
}
}
diff --git a/tests/all/agent.rs b/tests/all/agent.rs
index 8feee06..b3a7963 100644
--- a/tests/all/agent.rs
+++ b/tests/all/agent.rs
@@ -7,9 +7,8 @@ fn smoke() {
agent.connect().unwrap();
agent.list_identities().unwrap();
{
- let mut a = agent.identities();
- let i1 = a.next().unwrap().unwrap();
- a.count();
+ let a = agent.identities().unwrap();
+ let i1 = &a[0];
assert!(agent.userauth("foo", &i1).is_err());
}
agent.disconnect().unwrap();
diff --git a/tests/all/knownhosts.rs b/tests/all/knownhosts.rs
index 1847797..ab83577 100644
--- a/tests/all/knownhosts.rs
+++ b/tests/all/knownhosts.rs
@@ -3,8 +3,9 @@ use ssh2::{KnownHostFileKind, Session};
#[test]
fn smoke() {
let sess = Session::new().unwrap();
- let hosts = sess.known_hosts().unwrap();
- assert_eq!(hosts.iter().count(), 0);
+ let known_hosts = sess.known_hosts().unwrap();
+ let hosts = known_hosts.hosts().unwrap();
+ assert_eq!(hosts.len(), 0);
}
#[test]
@@ -20,11 +21,14 @@ PW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi\
/w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ==
";
let sess = Session::new().unwrap();
- let mut hosts = sess.known_hosts().unwrap();
- hosts.read_str(encoded, KnownHostFileKind::OpenSSH).unwrap();
+ let mut known_hosts = sess.known_hosts().unwrap();
+ known_hosts
+ .read_str(encoded, KnownHostFileKind::OpenSSH)
+ .unwrap();
- assert_eq!(hosts.iter().count(), 1);
- let host = hosts.iter().next().unwrap().unwrap();
+ let hosts = known_hosts.hosts().unwrap();
+ assert_eq!(hosts.len(), 1);
+ let host = &hosts[0];
assert_eq!(host.name(), None);
assert_eq!(
host.key(),
@@ -39,10 +43,10 @@ PW3RcT0eOzQgqlJL3RKrTJvdsjE3JEAvGq3lGHSZXy28G3skua2SmVi\
);
assert_eq!(
- hosts
- .write_string(&host, KnownHostFileKind::OpenSSH)
+ known_hosts
+ .write_string(host, KnownHostFileKind::OpenSSH)
.unwrap(),
encoded
);
- hosts.remove(host).unwrap();
+ known_hosts.remove(host).unwrap();
}
diff --git a/tests/all/main.rs b/tests/all/main.rs
index f74684a..87380eb 100644
--- a/tests/all/main.rs
+++ b/tests/all/main.rs
@@ -36,7 +36,8 @@ pub fn authed_session() -> ssh2::Session {
let mut agent = sess.agent().unwrap();
agent.connect().unwrap();
agent.list_identities().unwrap();
- let identity = agent.identities().next().unwrap().unwrap();
+ let identities = agent.identities().unwrap();
+ let identity = &identities[0];
agent.userauth(&user, &identity).unwrap();
}
assert!(sess.authenticated());
diff --git a/tests/all/session.rs b/tests/all/session.rs
index 8894081..ee3eaa0 100644
--- a/tests/all/session.rs
+++ b/tests/all/session.rs
@@ -51,7 +51,7 @@ fn smoke_handshake() {
agent.connect().unwrap();
agent.list_identities().unwrap();
{
- let identity = agent.identities().next().unwrap().unwrap();
+ let identity = &agent.identities().unwrap()[0];
agent.userauth(&user, &identity).unwrap();
}
assert!(sess.authenticated());