diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/session.rs | 37 |
1 files changed, 25 insertions, 12 deletions
diff --git a/src/session.rs b/src/session.rs index ab27a4f..acf9522 100644 --- a/src/session.rs +++ b/src/session.rs @@ -205,14 +205,9 @@ impl Session { /// Begin transport layer protocol negotiation with the connected host. /// - /// This session takes ownership of the socket provided. - /// You may use the tcp_stream() method to obtain a reference - /// to it later. - /// - /// It is also highly recommended that the stream provided is not used - /// concurrently elsewhere for the duration of this session as it may - /// interfere with the protocol. - pub fn handshake(&mut self, stream: TcpStream) -> Result<(), Error> { + /// You must call this after associating the session with a tcp stream + /// via the `set_tcp_stream` function. + pub fn handshake(&mut self) -> Result<(), Error> { #[cfg(windows)] unsafe fn handshake(raw: *mut raw::LIBSSH2_SESSION, stream: &TcpStream) -> libc::c_int { use std::os::windows::prelude::*; @@ -226,7 +221,16 @@ impl Session { } unsafe { - let res = handshake(self.inner.raw, &stream); + let stream = self.inner.tcp.borrow(); + + let stream = stream.as_ref().ok_or_else(|| { + Error::new( + raw::LIBSSH2_ERROR_BAD_SOCKET, + "use set_tcp_stream() to associate with a TcpStream", + ) + })?; + + let res = handshake(self.inner.raw, stream); self.rc(res)?; if res < 0 { // There are some kex related errors that don't set the @@ -235,11 +239,21 @@ impl Session { // Let's ensure that we indicate an error in this situation. return Err(Error::new(res, "Error during handshake")); } - *self.inner.tcp.borrow_mut() = Some(stream); Ok(()) } } + /// The session takes ownership of the socket provided. + /// You may use the tcp_stream() method to obtain a reference + /// to it later. + /// + /// It is also highly recommended that the stream provided is not used + /// concurrently elsewhere for the duration of this session as it may + /// interfere with the protocol. + pub fn set_tcp_stream(&mut self, stream: TcpStream) { + *self.inner.tcp.borrow_mut() = Some(stream); + } + /// Returns a reference to the stream that was associated with the Session /// by the Session::handshake method. pub fn tcp_stream(&self) -> Ref<Option<TcpStream>> { @@ -308,8 +322,7 @@ impl Session { let instruction = String::from_utf8_lossy(instruction); let prompts = unsafe { slice::from_raw_parts(prompts, num_prompts as usize) }; - let responses = - unsafe { slice::from_raw_parts_mut(responses, num_prompts as usize) }; + let responses = unsafe { slice::from_raw_parts_mut(responses, num_prompts as usize) }; let prompts: Vec<Prompt> = prompts .iter() |