summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/conversion.rs25
-rw-r--r--src/function.rs69
-rw-r--r--tests/async.rs14
-rw-r--r--tests/function.rs33
4 files changed, 140 insertions, 1 deletions
diff --git a/src/conversion.rs b/src/conversion.rs
index cc9227c..c4ff6c2 100644
--- a/src/conversion.rs
+++ b/src/conversion.rs
@@ -19,7 +19,14 @@ use crate::userdata::{AnyUserData, UserData};
use crate::value::{FromLua, IntoLua, Nil, Value};
#[cfg(feature = "unstable")]
-use crate::{function::OwnedFunction, table::OwnedTable, userdata::OwnedAnyUserData};
+use crate::{
+ function::{OwnedFunction, WrappedFunction},
+ table::OwnedTable,
+ userdata::OwnedAnyUserData,
+};
+
+#[cfg(all(feature = "async", feature = "unstable"))]
+use crate::function::WrappedAsyncFunction;
impl<'lua> IntoLua<'lua> for Value<'lua> {
#[inline]
@@ -129,6 +136,22 @@ impl<'lua> FromLua<'lua> for OwnedFunction {
}
}
+#[cfg(feature = "unstable")]
+impl<'lua> IntoLua<'lua> for WrappedFunction<'lua> {
+ #[inline]
+ fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
+ lua.create_callback(self.0).map(Value::Function)
+ }
+}
+
+#[cfg(all(feature = "async", feature = "unstable"))]
+impl<'lua> IntoLua<'lua> for WrappedAsyncFunction<'lua> {
+ #[inline]
+ fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
+ lua.create_async_callback(self.0).map(Value::Function)
+ }
+}
+
impl<'lua> IntoLua<'lua> for Thread<'lua> {
#[inline]
fn into_lua(self, _: &'lua Lua) -> Result<Value<'lua>> {
diff --git a/src/function.rs b/src/function.rs
index d282a50..3a0aa36 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -11,9 +11,20 @@ use crate::util::{
};
use crate::value::{FromLuaMulti, IntoLuaMulti};
+#[cfg(feature = "unstable")]
+use {
+ crate::lua::Lua,
+ crate::types::{Callback, MaybeSend},
+ crate::value::IntoLua,
+ std::cell::RefCell,
+};
+
#[cfg(feature = "async")]
use {futures_core::future::LocalBoxFuture, futures_util::future};
+#[cfg(all(feature = "async", feature = "unstable"))]
+use {crate::types::AsyncCallback, futures_core::Future, futures_util::TryFutureExt};
+
/// Handle to an internal Lua function.
#[derive(Clone, Debug)]
pub struct Function<'lua>(pub(crate) LuaRef<'lua>);
@@ -408,6 +419,64 @@ impl<'lua> PartialEq for Function<'lua> {
}
}
+#[cfg(feature = "unstable")]
+pub(crate) struct WrappedFunction<'lua>(pub(crate) Callback<'lua, 'static>);
+
+#[cfg(all(feature = "async", feature = "unstable"))]
+pub(crate) struct WrappedAsyncFunction<'lua>(pub(crate) AsyncCallback<'lua, 'static>);
+
+#[cfg(feature = "unstable")]
+#[cfg_attr(docsrs, doc(cfg(feature = "unstable")))]
+impl<'lua> Function<'lua> {
+ /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`] trait.
+ #[inline]
+ pub fn wrap<F, A, R>(func: F) -> impl IntoLua<'lua>
+ where
+ F: Fn(&'lua Lua, A) -> Result<R> + MaybeSend + 'static,
+ A: FromLuaMulti<'lua>,
+ R: IntoLuaMulti<'lua>,
+ {
+ WrappedFunction(Box::new(move |lua, args| {
+ func(lua, A::from_lua_multi(args, lua)?)?.into_lua_multi(lua)
+ }))
+ }
+
+ /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait.
+ #[inline]
+ pub fn wrap_mut<F, A, R>(func: F) -> impl IntoLua<'lua>
+ where
+ F: FnMut(&'lua Lua, A) -> Result<R> + MaybeSend + 'static,
+ A: FromLuaMulti<'lua>,
+ R: IntoLuaMulti<'lua>,
+ {
+ let func = RefCell::new(func);
+ WrappedFunction(Box::new(move |lua, args| {
+ let mut func = func
+ .try_borrow_mut()
+ .map_err(|_| Error::RecursiveMutCallback)?;
+ func(lua, A::from_lua_multi(args, lua)?)?.into_lua_multi(lua)
+ }))
+ }
+
+ #[cfg(feature = "async")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
+ pub fn wrap_async<F, A, FR, R>(func: F) -> impl IntoLua<'lua>
+ where
+ F: Fn(&'lua Lua, A) -> FR + MaybeSend + 'static,
+ A: FromLuaMulti<'lua>,
+ FR: Future<Output = Result<R>> + 'lua,
+ R: IntoLuaMulti<'lua>,
+ {
+ WrappedAsyncFunction(Box::new(move |lua, args| {
+ let args = match A::from_lua_multi(args, lua) {
+ Ok(args) => args,
+ Err(e) => return Box::pin(future::err(e)),
+ };
+ Box::pin(func(lua, args).and_then(move |ret| future::ready(ret.into_lua_multi(lua))))
+ }))
+ }
+}
+
#[cfg(test)]
mod assertions {
use super::*;
diff --git a/tests/async.rs b/tests/async.rs
index fc7a631..48f3cb6 100644
--- a/tests/async.rs
+++ b/tests/async.rs
@@ -30,6 +30,20 @@ async fn test_async_function() -> Result<()> {
Ok(())
}
+#[cfg(feature = "unstable")]
+#[tokio::test]
+async fn test_async_function_wrap() -> Result<()> {
+ let lua = Lua::new();
+
+ let f = Function::wrap_async(|_, s: String| async move { Ok(s) });
+ lua.globals().set("f", f)?;
+
+ let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
+ assert_eq!(res, "hello");
+
+ Ok(())
+}
+
#[tokio::test]
async fn test_async_sleep() -> Result<()> {
let lua = Lua::new();
diff --git a/tests/function.rs b/tests/function.rs
index 6d4848b..b93fac8 100644
--- a/tests/function.rs
+++ b/tests/function.rs
@@ -167,3 +167,36 @@ fn test_function_info() -> Result<()> {
Ok(())
}
+
+#[cfg(feature = "unstable")]
+#[test]
+fn test_function_wrap() -> Result<()> {
+ use mlua::Error;
+
+ let lua = Lua::new();
+
+ lua.globals()
+ .set("f", Function::wrap(|_, s: String| Ok(s)))?;
+ lua.load(r#"assert(f("hello") == "hello")"#).exec().unwrap();
+
+ let mut _i = false;
+ lua.globals().set(
+ "f",
+ Function::wrap_mut(move |lua, ()| {
+ _i = true;
+ lua.globals().get::<_, Function>("f")?.call::<_, ()>(())
+ }),
+ )?;
+ match lua.globals().get::<_, Function>("f")?.call::<_, ()>(()) {
+ Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() {
+ Error::CallbackError { ref cause, .. } => match *cause.as_ref() {
+ Error::RecursiveMutCallback { .. } => {}
+ ref other => panic!("incorrect result: {other:?}"),
+ },
+ ref other => panic!("incorrect result: {other:?}"),
+ },
+ other => panic!("incorrect result: {other:?}"),
+ };
+
+ Ok(())
+}