summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2021-07-04 23:51:51 +0100
committerAlex Orlenko <zxteam@protonmail.com>2021-07-05 00:03:18 +0100
commit1fe583027bce76a4b980242a61ad641e6df30a16 (patch)
treecc8043e021a7b148913aefb562a176beb74b7643
parent7b5b78fa3d187ea6346e7c532af255bf88c690fe (diff)
downloadmlua-1fe583027bce76a4b980242a61ad641e6df30a16.zip
Add new functions: `lua.load_from_function()` and `lua.create_c_function()`
This should be useful to register embedded C modules to Lua state. Provides a solution for #61
-rw-r--r--src/ffi/lua.rs5
-rw-r--r--src/lib.rs3
-rw-r--r--src/lua.rs59
-rw-r--r--tests/function.rs17
-rw-r--r--tests/tests.rs28
5 files changed, 108 insertions, 4 deletions
diff --git a/src/ffi/lua.rs b/src/ffi/lua.rs
index 82516b4..ac0ba33 100644
--- a/src/ffi/lua.rs
+++ b/src/ffi/lua.rs
@@ -80,6 +80,7 @@ pub const LUA_ERRERR: c_int = 5;
#[cfg(any(feature = "lua53", feature = "lua52"))]
pub const LUA_ERRERR: c_int = 6;
+/// A raw Lua Lua state associated with a thread.
pub type lua_State = c_void;
// basic types
@@ -121,14 +122,14 @@ pub type lua_Number = luaconf::LUA_NUMBER;
/// A Lua integer, usually equivalent to `i64`.
pub type lua_Integer = luaconf::LUA_INTEGER;
-// unsigned integer type
+/// A Lua unsigned integer, usually equivalent to `u64`.
pub type lua_Unsigned = luaconf::LUA_UNSIGNED;
// type for continuation-function contexts
#[cfg(any(feature = "lua54", feature = "lua53"))]
pub type lua_KContext = luaconf::LUA_KCONTEXT;
-/// Type for native functions that can be passed to Lua.
+/// Type for native C functions that can be passed to Lua.
pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int;
// Type for continuation functions
diff --git a/src/lib.rs b/src/lib.rs
index 80af7f4..91d002f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -98,8 +98,7 @@ mod userdata;
mod util;
mod value;
-#[doc(hidden)]
-pub use crate::ffi::lua_State;
+pub use crate::{ffi::lua_CFunction, ffi::lua_State};
pub use crate::error::{Error, ExternalError, ExternalResult, Result};
pub use crate::function::Function;
diff --git a/src/lua.rs b/src/lua.rs
index 735485d..b81e442 100644
--- a/src/lua.rs
+++ b/src/lua.rs
@@ -510,6 +510,52 @@ impl Lua {
res
}
+ /// Calls the Lua function `func` with the string `modname` as an argument, sets
+ /// the call result to `package.loaded[modname]` and returns copy of the result.
+ ///
+ /// If `package.loaded[modname]` value is not nil, returns copy of the value without
+ /// calling the function.
+ ///
+ /// If the function does not return a non-nil value then this method assigns true to
+ /// `package.loaded[modname]`.
+ ///
+ /// Behavior is similar to Lua's [`require`] function.
+ ///
+ /// [`require`]: https://www.lua.org/manual/5.3/manual.html#pdf-require
+ pub fn load_from_function<'lua, S, T>(
+ &'lua self,
+ modname: &S,
+ func: Function<'lua>,
+ ) -> Result<T>
+ where
+ S: AsRef<[u8]> + ?Sized,
+ T: FromLua<'lua>,
+ {
+ unsafe {
+ let _sg = StackGuard::new(self.state);
+ check_stack(self.state, 3)?;
+
+ protect_lua(self.state, 0, 1, |state| {
+ ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED"));
+ })?;
+ let loaded = Table(self.pop_ref());
+
+ let modname = self.create_string(modname)?;
+ let value = match loaded.raw_get(modname.clone())? {
+ Value::Nil => {
+ let result = match func.call(modname.clone())? {
+ Value::Nil => Value::Boolean(true),
+ res => res,
+ };
+ loaded.raw_set(modname, result.clone())?;
+ result
+ }
+ res => res,
+ };
+ T::from_lua(value, self)
+ }
+ }
+
/// Consumes and leaks `Lua` object, returning a static reference `&'static Lua`.
///
/// This function is useful when the `Lua` object is supposed to live for the remainder
@@ -1034,6 +1080,19 @@ impl Lua {
})
}
+ /// Wraps a C function, creating a callable Lua function handle to it.
+ ///
+ /// # Safety
+ /// This function is unsafe because provides a way to execute unsafe C function.
+ pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result<Function> {
+ let _sg = StackGuard::new(self.state);
+ check_stack(self.state, 3)?;
+ protect_lua(self.state, 0, 1, |state| {
+ ffi::lua_pushcfunction(state, func);
+ })?;
+ Ok(Function(self.pop_ref()))
+ }
+
/// Wraps a Rust async function or closure, creating a callable Lua function handle to it.
///
/// While executing the function Rust will poll Future and if the result is not ready, call
diff --git a/tests/function.rs b/tests/function.rs
index b29fb81..a7bdd9f 100644
--- a/tests/function.rs
+++ b/tests/function.rs
@@ -77,6 +77,23 @@ fn test_rust_function() -> Result<()> {
}
#[test]
+fn test_c_function() -> Result<()> {
+ let lua = Lua::new();
+
+ unsafe extern "C" fn c_function(state: *mut mlua::lua_State) -> std::os::raw::c_int {
+ let lua = Lua::init_from_ptr(state);
+ lua.globals().set("c_function", true).unwrap();
+ 0
+ }
+
+ let func = unsafe { lua.create_c_function(c_function)? };
+ func.call(())?;
+ assert_eq!(lua.globals().get::<_, bool>("c_function")?, true);
+
+ Ok(())
+}
+
+#[test]
fn test_dump() -> Result<()> {
let lua = unsafe { Lua::unsafe_new() };
diff --git a/tests/tests.rs b/tests/tests.rs
index d3c2d1a..9b450ae 100644
--- a/tests/tests.rs
+++ b/tests/tests.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::iter::FromIterator;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::string::String as StdString;
+use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::{error, f32, f64, fmt};
@@ -1086,3 +1087,30 @@ fn test_jit_version() -> Result<()> {
.contains("LuaJIT"));
Ok(())
}
+
+#[test]
+fn test_load_from_function() -> Result<()> {
+ let lua = Lua::new();
+
+ let i = Arc::new(AtomicU32::new(0));
+ let i2 = i.clone();
+ let func = lua.create_function(move |lua, modname: String| {
+ i2.fetch_add(1, Ordering::Relaxed);
+ let t = lua.create_table()?;
+ t.set("__name", modname)?;
+ Ok(t)
+ })?;
+
+ let t: Table = lua.load_from_function("my_module", func.clone())?;
+ assert_eq!(t.get::<_, String>("__name")?, "my_module");
+ assert_eq!(i.load(Ordering::Relaxed), 1);
+
+ let _: Value = lua.load_from_function("my_module", func)?;
+ assert_eq!(i.load(Ordering::Relaxed), 1);
+
+ let func_nil = lua.create_function(move |_, _: String| Ok(Value::Nil))?;
+ let v: Value = lua.load_from_function("my_module2", func_nil)?;
+ assert_eq!(v, Value::Boolean(true));
+
+ Ok(())
+}