diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/lua.rs | 2 | ||||
-rw-r--r-- | src/luau.rs | 192 | ||||
-rw-r--r-- | src/util/mod.rs | 2 |
3 files changed, 143 insertions, 53 deletions
@@ -438,7 +438,7 @@ impl Lua { } #[cfg(feature = "luau")] - mlua_expect!(lua.prepare_luau_state(), "Error preparing Luau state"); + mlua_expect!(lua.prepare_luau_state(), "Error configuring Luau"); lua } diff --git a/src/luau.rs b/src/luau.rs index e17296e..61b846b 100644 --- a/src/luau.rs +++ b/src/luau.rs @@ -1,16 +1,22 @@ use std::ffi::CStr; +use std::fmt::Write; use std::os::raw::{c_float, c_int}; +use std::path::{PathBuf, MAIN_SEPARATOR_STR}; use std::string::String as StdString; +use std::{env, fs}; use crate::chunk::ChunkMode; -use crate::error::{Error, Result}; +use crate::error::Result; use crate::lua::Lua; use crate::table::Table; -use crate::util::{check_stack, StackGuard}; -use crate::value::Value; +use crate::types::RegistryKey; +use crate::value::{IntoLua, Value}; // Since Luau has some missing standard function, we re-implement them here +// We keep reference to the `package` table in registry under this key +struct PackageKey(RegistryKey); + impl Lua { pub(crate) unsafe fn prepare_luau_state(&self) -> Result<()> { let globals = self.globals(); @@ -19,7 +25,8 @@ impl Lua { "collectgarbage", self.create_c_function(lua_collectgarbage)?, )?; - globals.raw_set("require", self.create_function(lua_require)?)?; + globals.raw_set("require", self.create_c_function(lua_require)?)?; + globals.raw_set("package", create_package_table(self)?)?; globals.raw_set("vector", self.create_c_function(lua_vector)?)?; // Set `_VERSION` global to include version number @@ -69,56 +76,57 @@ unsafe extern "C-unwind" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_ } } -fn lua_require(lua: &Lua, name: Option<StdString>) -> Result<Value> { - let name = name.ok_or_else(|| Error::runtime("invalid module name"))?; - - // Find module in the cache - let state = lua.state(); - let loaded = unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - protect_lua!(state, 0, 1, fn(state) { - ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); - })?; - Table(lua.pop_ref()) - }; - if let Some(v) = loaded.raw_get(name.clone())? { - return Ok(v); - } - - // Load file from filesystem - let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default(); - if search_path.is_empty() { - search_path = "?.luau;?.lua".into(); +unsafe extern "C-unwind" fn lua_require(state: *mut ffi::lua_State) -> c_int { + ffi::lua_settop(state, 1); + let name = ffi::luaL_checkstring(state, 1); + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); // _LOADED is at index 2 + if ffi::lua_rawgetfield(state, 2, name) != ffi::LUA_TNIL { + return 1; // module is already loaded } - - let (mut source, mut source_name) = (None, String::new()); - for path in search_path.split(';') { - let file_path = path.replacen('?', &name, 1); - if let Ok(buf) = std::fs::read(&file_path) { - source = Some(buf); - source_name = file_path; - break; + ffi::lua_pop(state, 1); // remove nil + + // load the module + let err_buf = ffi::lua_newuserdata_t::<StdString>(state); + err_buf.write(StdString::new()); + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADERS")); // _LOADERS is at index 3 + for i in 1.. { + if ffi::lua_rawgeti(state, -1, i) == ffi::LUA_TNIL { + // no more loaders? + if (*err_buf).is_empty() { + ffi::luaL_error(state, cstr!("module '%s' not found"), name); + } else { + let bytes = (*err_buf).as_bytes(); + let extra = ffi::lua_pushlstring(state, bytes.as_ptr() as *const _, bytes.len()); + ffi::luaL_error(state, cstr!("module '%s' not found:%s"), name, extra); + } } + ffi::lua_pushvalue(state, 1); // name arg + ffi::lua_call(state, 1, 2); // call loader + match ffi::lua_type(state, -2) { + ffi::LUA_TFUNCTION => break, // loader found + ffi::LUA_TSTRING => { + // error message + let msg = ffi::lua_tostring(state, -2); + let msg = CStr::from_ptr(msg).to_string_lossy(); + _ = write!(&mut *err_buf, "\n\t{msg}"); + } + _ => {} + } + ffi::lua_pop(state, 2); // remove both results + } + ffi::lua_pushvalue(state, 1); // name is 1st argument to module loader + ffi::lua_rotate(state, -2, 1); // loader data <-> name + + // stack: ...; loader function; module name; loader data + ffi::lua_call(state, 2, 1); + // stack: ...; result from loader function + if ffi::lua_isnil(state, -1) != 0 { + ffi::lua_pop(state, 1); + ffi::lua_pushboolean(state, 1); // use true as result } - let source = source.ok_or_else(|| Error::runtime(format!("cannot find '{name}'")))?; - - let value = lua - .load(&source) - .set_name(&format!("={source_name}")) - .set_mode(ChunkMode::Text) - .call::<_, Value>(())?; - - // Save in the cache - loaded.raw_set( - name, - match value.clone() { - Value::Nil => Value::Boolean(true), - v => v, - }, - )?; - - Ok(value) + ffi::lua_pushvalue(state, -1); // make copy of entrypoint result + ffi::lua_setfield(state, 2, name); /* _LOADED[name] = returned value */ + 1 } // Luau vector datatype constructor @@ -135,3 +143,85 @@ unsafe extern "C-unwind" fn lua_vector(state: *mut ffi::lua_State) -> c_int { ffi::lua_pushvector(state, x, y, z, w); 1 } + +// +// package module +// + +fn create_package_table(lua: &Lua) -> Result<Table> { + // Create the package table and store it in app_data for later use (bypassing globals lookup) + let package = lua.create_table()?; + lua.set_app_data(PackageKey(lua.create_registry_value(package.clone())?)); + + // Set `package.path` + let mut search_path = env::var("LUAU_PATH") + .or_else(|_| env::var("LUA_PATH")) + .unwrap_or_default(); + if search_path.is_empty() { + search_path = "?.luau;?.lua".to_string(); + } + package.raw_set("path", search_path)?; + + // Set `package.loaded` (table with a list of loaded modules) + let loaded = lua.create_table()?; + package.raw_set("loaded", loaded.clone())?; + lua.set_named_registry_value("_LOADED", loaded)?; + + // Set `package.loaders` + let loaders = lua.create_sequence_from([lua.create_function(lua_loader)?])?; + package.raw_set("loaders", loaders.clone())?; + lua.set_named_registry_value("_LOADERS", loaders)?; + + Ok(package) +} + +/// Searches for the given `name` in the given `path`. +/// +/// `path` is a string containing a sequence of templates separated by semicolons. +fn package_searchpath(name: &str, search_path: &str, try_prefix: bool) -> Option<PathBuf> { + let mut names = vec![name.replace('.', MAIN_SEPARATOR_STR)]; + if try_prefix && name.contains('.') { + let prefix = name.split_once('.').map(|(prefix, _)| prefix).unwrap(); + names.push(prefix.to_string()); + } + for path in search_path.split(';') { + for name in &names { + let file_path = PathBuf::from(path.replace('?', name)); + if let Ok(true) = fs::metadata(&file_path).map(|m| m.is_file()) { + return Some(file_path); + } + } + } + None +} + +// +// Module loaders +// + +/// Tries to load a lua (text) file +fn lua_loader(lua: &Lua, modname: StdString) -> Result<Value> { + let package = { + let key = lua.app_data_ref::<PackageKey>().unwrap(); + lua.registry_value::<Table>(&key.0) + }?; + let search_path = package.get::<_, StdString>("path").unwrap_or_default(); + + if let Some(file_path) = package_searchpath(&modname, &search_path, false) { + match fs::read(&file_path) { + Ok(buf) => { + return lua + .load(&buf) + .set_name(&format!("={}", file_path.display())) + .set_mode(ChunkMode::Text) + .into_function() + .map(Value::Function); + } + Err(err) => { + return format!("cannot open '{}': {err}", file_path.display()).into_lua(lua); + } + } + } + + Ok(Value::Nil) +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 6043c24..46960bd 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -997,7 +997,7 @@ impl WrappedFailure { #[cfg(feature = "luau")] let ud = ffi::lua_newuserdata_t::<Self>(state); #[cfg(not(feature = "luau"))] - let ud = ffi::lua_newuserdata(state, std::mem::size_of::<WrappedFailure>()) as *mut Self; + let ud = ffi::lua_newuserdata(state, std::mem::size_of::<Self>()) as *mut Self; ptr::write(ud, WrappedFailure::None); ud } |