diff options
-rw-r--r-- | src/error.rs | 38 | ||||
-rw-r--r-- | tests/error.rs | 17 |
2 files changed, 41 insertions, 14 deletions
diff --git a/src/error.rs b/src/error.rs index 3ccff0e..cfdd19a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -335,6 +335,10 @@ impl StdError for Error { // Given that we include source to fmt::Display implementation for `CallbackError`, this call returns nothing. Error::CallbackError { .. } => None, Error::ExternalError(ref err) => err.source(), + Error::WithContext { ref cause, .. } => match cause.as_ref() { + Error::ExternalError(err) => err.source(), + _ => None, + }, _ => None, } } @@ -353,6 +357,10 @@ impl Error { { match self { Error::ExternalError(err) => err.downcast_ref(), + Error::WithContext { cause, .. } => match cause.as_ref() { + Error::ExternalError(err) => err.downcast_ref(), + _ => None, + }, _ => None, } } @@ -414,33 +422,35 @@ pub trait ErrorContext: Sealed { impl ErrorContext for Error { fn context<C: fmt::Display>(self, context: C) -> Self { - Error::WithContext { - context: context.to_string(), - cause: Arc::new(self), + let context = context.to_string(); + match self { + Error::WithContext { cause, .. } => Error::WithContext { context, cause }, + _ => Error::WithContext { + context, + cause: Arc::new(self), + }, } } fn with_context<C: fmt::Display>(self, f: impl FnOnce(&Error) -> C) -> Self { - Error::WithContext { - context: f(&self).to_string(), - cause: Arc::new(self), + let context = f(&self).to_string(); + match self { + Error::WithContext { cause, .. } => Error::WithContext { context, cause }, + _ => Error::WithContext { + context, + cause: Arc::new(self), + }, } } } impl<T> ErrorContext for StdResult<T, Error> { fn context<C: fmt::Display>(self, context: C) -> Self { - self.map_err(|err| Error::WithContext { - context: context.to_string(), - cause: Arc::new(err), - }) + self.map_err(|err| err.context(context)) } fn with_context<C: fmt::Display>(self, f: impl FnOnce(&Error) -> C) -> Self { - self.map_err(|err| Error::WithContext { - context: f(&err).to_string(), - cause: Arc::new(err), - }) + self.map_err(|err| err.with_context(f)) } } diff --git a/tests/error.rs b/tests/error.rs index 2e88736..18c89c5 100644 --- a/tests/error.rs +++ b/tests/error.rs @@ -1,3 +1,5 @@ +use std::io; + use mlua::{Error, ErrorContext, Lua, Result}; #[test] @@ -29,5 +31,20 @@ fn test_error_context() -> Result<()> { println!("{msg2}"); assert!(msg2.contains("error converting Lua nil to String")); + // Rewrite context message and test `downcast_ref` + let func3 = lua.create_function(|_, ()| { + Err::<(), _>(Error::external(io::Error::new( + io::ErrorKind::Other, + "other", + ))) + .context("some context") + .context("some new context") + })?; + let res = func3.call::<_, ()>(()).err().unwrap(); + let Error::CallbackError { cause, .. } = &res else { unreachable!() }; + assert!(!res.to_string().contains("some context")); + assert!(res.to_string().contains("some new context")); + assert!(cause.downcast_ref::<io::Error>().is_some()); + Ok(()) } |