From 8200bee467bfe3f6c9f55813080d4d4dedd46d0b Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 21 Jan 2024 23:38:02 +0000 Subject: Implement IntoLua for ref to String/Table/Function/AnyUserData This would prevent cloning plus has better performance when pushing values to Lua stack (`IntoLua::push_into_stack` method) --- src/conversion.rs | 102 ++++++++++++++++++++++++++++++ src/lua.rs | 9 +++ tests/async.rs | 2 +- tests/conversion.rs | 175 +++++++++++++++++++++++++++++++++++++++++++++++++++- tests/serde.rs | 2 +- tests/tests.rs | 37 ++--------- tests/thread.rs | 2 +- tests/userdata.rs | 6 +- 8 files changed, 295 insertions(+), 40 deletions(-) diff --git a/src/conversion.rs b/src/conversion.rs index 1321b42..74d532e 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -44,6 +44,18 @@ impl<'lua> IntoLua<'lua> for String<'lua> { } } +impl<'lua> IntoLua<'lua> for &String<'lua> { + #[inline] + fn into_lua(self, _: &'lua Lua) -> Result> { + Ok(Value::String(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_ref(&self.0)) + } +} + impl<'lua> FromLua<'lua> for String<'lua> { #[inline] fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result> { @@ -64,6 +76,18 @@ impl<'lua> IntoLua<'lua> for Table<'lua> { } } +impl<'lua> IntoLua<'lua> for &Table<'lua> { + #[inline] + fn into_lua(self, _: &'lua Lua) -> Result> { + Ok(Value::Table(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_ref(&self.0)) + } +} + impl<'lua> FromLua<'lua> for Table<'lua> { #[inline] fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { @@ -87,6 +111,20 @@ impl<'lua> IntoLua<'lua> for OwnedTable { } } +#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))] +impl<'lua> IntoLua<'lua> for &OwnedTable { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + OwnedTable::into_lua(self.clone(), lua) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_owned_ref(&self.0)) + } +} + #[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))] #[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))] impl<'lua> FromLua<'lua> for OwnedTable { @@ -103,6 +141,18 @@ impl<'lua> IntoLua<'lua> for Function<'lua> { } } +impl<'lua> IntoLua<'lua> for &Function<'lua> { + #[inline] + fn into_lua(self, _: &'lua Lua) -> Result> { + Ok(Value::Function(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_ref(&self.0)) + } +} + impl<'lua> FromLua<'lua> for Function<'lua> { #[inline] fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { @@ -126,6 +176,20 @@ impl<'lua> IntoLua<'lua> for OwnedFunction { } } +#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))] +impl<'lua> IntoLua<'lua> for &OwnedFunction { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + OwnedFunction::into_lua(self.clone(), lua) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_owned_ref(&self.0)) + } +} + #[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))] #[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))] impl<'lua> FromLua<'lua> for OwnedFunction { @@ -142,6 +206,18 @@ impl<'lua> IntoLua<'lua> for Thread<'lua> { } } +impl<'lua> IntoLua<'lua> for &Thread<'lua> { + #[inline] + fn into_lua(self, _: &'lua Lua) -> Result> { + Ok(Value::Thread(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_ref(&self.0)) + } +} + impl<'lua> FromLua<'lua> for Thread<'lua> { #[inline] fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { @@ -163,6 +239,18 @@ impl<'lua> IntoLua<'lua> for AnyUserData<'lua> { } } +impl<'lua> IntoLua<'lua> for &AnyUserData<'lua> { + #[inline] + fn into_lua(self, _: &'lua Lua) -> Result> { + Ok(Value::UserData(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_ref(&self.0)) + } +} + impl<'lua> FromLua<'lua> for AnyUserData<'lua> { #[inline] fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { @@ -189,6 +277,20 @@ impl<'lua> IntoLua<'lua> for OwnedAnyUserData { } } +#[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))] +impl<'lua> IntoLua<'lua> for &OwnedAnyUserData { + #[inline] + fn into_lua(self, lua: &'lua Lua) -> Result> { + OwnedAnyUserData::into_lua(self.clone(), lua) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &'lua Lua) -> Result<()> { + Ok(lua.push_owned_ref(&self.0)) + } +} + #[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))] #[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))] impl<'lua> FromLua<'lua> for OwnedAnyUserData { diff --git a/src/lua.rs b/src/lua.rs index 6d22e05..358a34e 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -2558,6 +2558,15 @@ impl Lua { ffi::lua_xpush(self.ref_thread(), self.state(), lref.index); } + #[cfg(all(feature = "unstable", not(feature = "send")))] + pub(crate) unsafe fn push_owned_ref(&self, loref: &crate::types::LuaOwnedRef) { + assert!( + Arc::ptr_eq(&loref.inner, &self.0), + "Lua instance passed Value created from a different main Lua state" + ); + ffi::lua_xpush(self.ref_thread(), self.state(), loref.index); + } + // Pops the topmost element of the stack and stores a reference to it. This pins the object, // preventing garbage collection until the returned `LuaRef` is dropped. // diff --git a/tests/async.rs b/tests/async.rs index 7055914..9c4c865 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -443,7 +443,7 @@ async fn test_async_userdata() -> Result<()> { let globals = lua.globals(); let userdata = lua.create_userdata(MyUserData(11))?; - globals.set("userdata", userdata.clone())?; + globals.set("userdata", &userdata)?; lua.load( r#" diff --git a/tests/conversion.rs b/tests/conversion.rs index 6d17d0a..8dc0ad2 100644 --- a/tests/conversion.rs +++ b/tests/conversion.rs @@ -3,7 +3,180 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::ffi::{CStr, CString}; use maplit::{btreemap, btreeset, hashmap, hashset}; -use mlua::{Error, Lua, Result}; +use mlua::{AnyUserData, Error, Function, IntoLua, Lua, Result, Table, Thread, UserDataRef, Value}; + +#[test] +fn test_string_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world!")?; + let s2 = (&s).into_lua(&lua)?; + assert_eq!(s, s2.as_string().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("s", &s)?; + assert_eq!(s, table.get::<_, String>("s")?); + + Ok(()) +} + +#[test] +fn test_table_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let t = lua.create_table()?; + let t2 = (&t).into_lua(&lua)?; + assert_eq!(&t, t2.as_table().unwrap()); + + // Push into stack + let f = lua.create_function(|_, (t, s): (Table, String)| t.set("s", s))?; + f.call((&t, "hello"))?; + assert_eq!("hello", t.get::<_, String>("s")?); + + Ok(()) +} + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_table_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let t = lua.create_table()?.into_owned(); + let t2 = (&t).into_lua(&lua)?; + assert_eq!(t.to_ref(), *t2.as_table().unwrap()); + + // Push into stack + let f = lua.create_function(|_, (t, s): (Table, String)| t.set("s", s))?; + f.call((&t, "hello"))?; + assert_eq!("hello", t.to_ref().get::<_, String>("s")?); + + Ok(()) +} + +#[test] +fn test_function_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let f = lua.create_function(|_, ()| Ok::<_, Error>(()))?; + let f2 = (&f).into_lua(&lua)?; + assert_eq!(&f, f2.as_function().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("f", &f)?; + assert_eq!(f, table.get::<_, Function>("f")?); + + Ok(()) +} + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_function_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let f = lua + .create_function(|_, ()| Ok::<_, Error>(()))? + .into_owned(); + let f2 = (&f).into_lua(&lua)?; + assert_eq!(f.to_ref(), *f2.as_function().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("f", &f)?; + assert_eq!(f.to_ref(), table.get::<_, Function>("f")?); + + Ok(()) +} + +#[test] +fn test_thread_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let f = lua.create_function(|_, ()| Ok::<_, Error>(()))?; + let th = lua.create_thread(f)?; + let th2 = (&th).into_lua(&lua)?; + assert_eq!(&th, th2.as_thread().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("th", &th)?; + assert_eq!(th, table.get::<_, Thread>("th")?); + + Ok(()) +} + +#[test] +fn test_anyuserdata_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let ud = lua.create_any_userdata(String::from("hello"))?; + let ud2 = (&ud).into_lua(&lua)?; + assert_eq!(&ud, ud2.as_userdata().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("ud", &ud)?; + assert_eq!(ud, table.get::<_, AnyUserData>("ud")?); + assert_eq!("hello", *table.get::<_, UserDataRef>("ud")?); + + Ok(()) +} + +#[cfg(all(feature = "unstable", not(feature = "send")))] +#[test] +fn test_owned_anyuserdata_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let ud = lua.create_any_userdata(String::from("hello"))?.into_owned(); + let ud2 = (&ud).into_lua(&lua)?; + assert_eq!(ud.to_ref(), *ud2.as_userdata().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("ud", &ud)?; + assert_eq!(ud.to_ref(), table.get::<_, AnyUserData>("ud")?); + assert_eq!("hello", *table.get::<_, UserDataRef>("ud")?); + + Ok(()) +} + +#[test] +fn test_registry_value_into_lua() -> Result<()> { + let lua = Lua::new(); + + let t = lua.create_table()?; + let r = lua.create_registry_value(t)?; + let f = lua.create_function(|_, t: Table| t.raw_set("hello", "world"))?; + + f.call(&r)?; + let v = r.into_lua(&lua)?; + let t = v.as_table().unwrap(); + assert_eq!(t.get::<_, String>("hello")?, "world"); + + // Try to set nil registry key + let r_nil = lua.create_registry_value(Value::Nil)?; + t.set("hello", &r_nil)?; + assert_eq!(t.get::<_, Value>("hello")?, Value::Nil); + + // Check non-owned registry key + let lua2 = Lua::new(); + let r2 = lua2.create_registry_value("abc")?; + assert!(matches!( + f.call::<_, ()>(&r2), + Err(Error::MismatchedRegistryKey) + )); + + Ok(()) +} #[test] fn test_conv_vec() -> Result<()> { diff --git a/tests/serde.rs b/tests/serde.rs index 1fdd709..168407f 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -598,7 +598,7 @@ fn test_from_value_with_options() -> Result<(), Box> { // Check recursion when using `Serialize` impl let t = lua.create_table()?; - t.set("t", t.clone())?; + t.set("t", &t)?; assert!(serde_json::to_string(&t).is_err()); // Serialize Lua globals table diff --git a/tests/tests.rs b/tests/tests.rs index 55f283d..0310a82 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -8,8 +8,8 @@ use std::sync::Arc; use std::{error, f32, f64, fmt}; use mlua::{ - ChunkMode, Error, ExternalError, Function, IntoLua, Lua, LuaOptions, Nil, Result, StdLib, - String, Table, UserData, Value, Variadic, + ChunkMode, Error, ExternalError, Function, Lua, LuaOptions, Nil, Result, StdLib, String, Table, + UserData, Value, Variadic, }; #[cfg(not(feature = "luau"))] @@ -779,35 +779,6 @@ fn test_registry_value() -> Result<()> { Ok(()) } -#[test] -fn test_registry_value_into_lua() -> Result<()> { - let lua = Lua::new(); - - let t = lua.create_table()?; - let r = lua.create_registry_value(t)?; - let f = lua.create_function(|_, t: Table| t.raw_set("hello", "world"))?; - - f.call(&r)?; - let v = r.into_lua(&lua)?; - let t = v.as_table().unwrap(); - assert_eq!(t.get::<_, String>("hello")?, "world"); - - // Try to set nil registry key - let r_nil = lua.create_registry_value(Value::Nil)?; - t.set("hello", &r_nil)?; - assert_eq!(t.get::<_, Value>("hello")?, Value::Nil); - - // Check non-owned registry key - let lua2 = Lua::new(); - let r2 = lua2.create_registry_value("abc")?; - assert!(matches!( - f.call::<_, ()>(&r2), - Err(Error::MismatchedRegistryKey) - )); - - Ok(()) -} - #[test] fn test_drop_registry_value() -> Result<()> { struct MyUserdata(Arc<()>); @@ -994,7 +965,7 @@ fn test_recursion() -> Result<()> { Ok(()) })?; - lua.globals().set("f", f.clone())?; + lua.globals().set("f", &f)?; f.call::<_, ()>(1)?; Ok(()) @@ -1032,7 +1003,7 @@ fn test_too_many_recursions() -> Result<()> { let f = lua .create_function(move |lua, ()| lua.globals().get::<_, Function>("f")?.call::<_, ()>(()))?; - lua.globals().set("f", f.clone())?; + lua.globals().set("f", &f)?; assert!(f.call::<_, ()>(()).is_err()); Ok(()) diff --git a/tests/thread.rs b/tests/thread.rs index 9a84a2c..0473e88 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -182,7 +182,7 @@ fn test_coroutine_panic() { let thrd_main = lua.create_function(|_, ()| -> Result<()> { panic!("test_panic"); })?; - lua.globals().set("main", thrd_main.clone())?; + lua.globals().set("main", &thrd_main)?; let thrd: Thread = lua.create_thread(thrd_main)?; thrd.resume(()) }) { diff --git a/tests/userdata.rs b/tests/userdata.rs index da6c40d..141e329 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -58,7 +58,7 @@ fn test_methods() -> Result<()> { fn check_methods(lua: &Lua, userdata: AnyUserData) -> Result<()> { let globals = lua.globals(); - globals.set("userdata", userdata.clone())?; + globals.set("userdata", &userdata)?; lua.load( r#" function get_it() @@ -342,7 +342,7 @@ fn test_userdata_take() -> Result<()> { } fn check_userdata_take(lua: &Lua, userdata: AnyUserData, rc: Arc) -> Result<()> { - lua.globals().set("userdata", userdata.clone())?; + lua.globals().set("userdata", &userdata)?; assert_eq!(Arc::strong_count(&rc), 3); { let _value = userdata.borrow::()?; @@ -474,7 +474,7 @@ fn test_functions() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); let userdata = lua.create_userdata(MyUserData(42))?; - globals.set("userdata", userdata.clone())?; + globals.set("userdata", &userdata)?; lua.load( r#" function get_it() -- cgit v1.2.3