diff options
m--------- | libssh2-sys/libssh2 | 0 | ||||
-rw-r--r-- | src/session.rs | 37 | ||||
-rw-r--r-- | tests/all/main.rs | 3 | ||||
-rw-r--r-- | tests/all/session.rs | 6 |
4 files changed, 31 insertions, 15 deletions
diff --git a/libssh2-sys/libssh2 b/libssh2-sys/libssh2 -Subproject 6d70b26ab602d112707890608455caf29ec9a35 +Subproject 1bbb96e41bf298fb525e5bcf64566f43e8fd4c8 diff --git a/src/session.rs b/src/session.rs index ab27a4f..acf9522 100644 --- a/src/session.rs +++ b/src/session.rs @@ -205,14 +205,9 @@ impl Session { /// Begin transport layer protocol negotiation with the connected host. /// - /// This session takes ownership of the socket provided. - /// You may use the tcp_stream() method to obtain a reference - /// to it later. - /// - /// It is also highly recommended that the stream provided is not used - /// concurrently elsewhere for the duration of this session as it may - /// interfere with the protocol. - pub fn handshake(&mut self, stream: TcpStream) -> Result<(), Error> { + /// You must call this after associating the session with a tcp stream + /// via the `set_tcp_stream` function. + pub fn handshake(&mut self) -> Result<(), Error> { #[cfg(windows)] unsafe fn handshake(raw: *mut raw::LIBSSH2_SESSION, stream: &TcpStream) -> libc::c_int { use std::os::windows::prelude::*; @@ -226,7 +221,16 @@ impl Session { } unsafe { - let res = handshake(self.inner.raw, &stream); + let stream = self.inner.tcp.borrow(); + + let stream = stream.as_ref().ok_or_else(|| { + Error::new( + raw::LIBSSH2_ERROR_BAD_SOCKET, + "use set_tcp_stream() to associate with a TcpStream", + ) + })?; + + let res = handshake(self.inner.raw, stream); self.rc(res)?; if res < 0 { // There are some kex related errors that don't set the @@ -235,11 +239,21 @@ impl Session { // Let's ensure that we indicate an error in this situation. return Err(Error::new(res, "Error during handshake")); } - *self.inner.tcp.borrow_mut() = Some(stream); Ok(()) } } + /// The session takes ownership of the socket provided. + /// You may use the tcp_stream() method to obtain a reference + /// to it later. + /// + /// It is also highly recommended that the stream provided is not used + /// concurrently elsewhere for the duration of this session as it may + /// interfere with the protocol. + pub fn set_tcp_stream(&mut self, stream: TcpStream) { + *self.inner.tcp.borrow_mut() = Some(stream); + } + /// Returns a reference to the stream that was associated with the Session /// by the Session::handshake method. pub fn tcp_stream(&self) -> Ref<Option<TcpStream>> { @@ -308,8 +322,7 @@ impl Session { let instruction = String::from_utf8_lossy(instruction); let prompts = unsafe { slice::from_raw_parts(prompts, num_prompts as usize) }; - let responses = - unsafe { slice::from_raw_parts_mut(responses, num_prompts as usize) }; + let responses = unsafe { slice::from_raw_parts_mut(responses, num_prompts as usize) }; let prompts: Vec<Prompt> = prompts .iter() diff --git a/tests/all/main.rs b/tests/all/main.rs index 75f2a29..a021b4c 100644 --- a/tests/all/main.rs +++ b/tests/all/main.rs @@ -23,7 +23,8 @@ pub fn authed_session() -> ssh2::Session { let user = env::var("USER").unwrap(); let socket = socket(); let mut sess = ssh2::Session::new().unwrap(); - sess.handshake(socket).unwrap(); + sess.set_tcp_stream(socket); + sess.handshake().unwrap(); assert!(!sess.authenticated()); { diff --git a/tests/all/session.rs b/tests/all/session.rs index 5e45d53..cf64f1d 100644 --- a/tests/all/session.rs +++ b/tests/all/session.rs @@ -30,7 +30,8 @@ fn smoke_handshake() { let user = env::var("USER").unwrap(); let socket = ::socket(); let mut sess = Session::new().unwrap(); - sess.handshake(socket).unwrap(); + sess.set_tcp_stream(socket); + sess.handshake().unwrap(); sess.host_key().unwrap(); let methods = sess.auth_methods(&user).unwrap(); assert!(methods.contains("publickey"), "{}", methods); @@ -52,7 +53,8 @@ fn keyboard_interactive() { let user = env::var("USER").unwrap(); let socket = ::socket(); let mut sess = Session::new().unwrap(); - sess.handshake(socket).unwrap(); + sess.set_tcp_stream(socket); + sess.handshake().unwrap(); sess.host_key().unwrap(); let methods = sess.auth_methods(&user).unwrap(); assert!(methods.contains("keyboard-interactive"), "{}", methods); |