summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2022-04-13 13:41:13 +0100
committerAlex Orlenko <zxteam@protonmail.com>2022-04-13 13:44:12 +0100
commit0215c31a3a05737a6cd0d985d8d8bea966d63325 (patch)
tree26e28e492850bbb8a726fda187119bb7e368adb8 /src
parent5cd82d0f6b6c462d05a2016028348c7fdcbae460 (diff)
downloadmlua-0215c31a3a05737a6cd0d985d8d8bea966d63325.zip
Refactor Lua instance structure.
The idea is to keep same Lua instance across all calls and only change context inside callbacks. This should solve #104.
Diffstat (limited to 'src')
-rw-r--r--src/lua.rs259
-rw-r--r--src/types.rs9
2 files changed, 130 insertions, 138 deletions
diff --git a/src/lua.rs b/src/lua.rs
index 81ac37d..15af0a7 100644
--- a/src/lua.rs
+++ b/src/lua.rs
@@ -5,6 +5,8 @@ use std::collections::HashMap;
use std::ffi::CString;
use std::fmt;
use std::marker::PhantomData;
+use std::mem::ManuallyDrop;
+use std::ops::{Deref, DerefMut};
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location};
use std::sync::{Arc, Mutex};
@@ -66,19 +68,25 @@ use {
#[cfg(feature = "serialize")]
use serde::Serialize;
-/// Top level Lua struct which holds the Lua state itself.
-pub struct Lua {
+/// Top level Lua struct which represents an instance of Lua VM.
+#[repr(transparent)]
+pub struct Lua(Arc<UnsafeCell<LuaInner>>);
+
+/// An inner Lua struct which holds a raw Lua state.
+pub struct LuaInner {
pub(crate) state: *mut ffi::lua_State,
main_state: *mut ffi::lua_State,
extra: Arc<UnsafeCell<ExtraData>>,
- ephemeral: bool,
safe: bool,
// Lua has lots of interior mutability, should not be RefUnwindSafe
_no_ref_unwind_safe: PhantomData<UnsafeCell<()>>,
}
// Data associated with the Lua.
-struct ExtraData {
+pub(crate) struct ExtraData {
+ // Same layout as `Lua`
+ inner: Option<ManuallyDrop<Arc<UnsafeCell<LuaInner>>>>,
+
registered_userdata: FxHashMap<TypeId, c_int>,
registered_userdata_mt: FxHashMap<*const c_void, Option<TypeId>>,
registry_unref_list: Arc<Mutex<Option<Vec<c_int>>>>,
@@ -90,7 +98,6 @@ struct ExtraData {
libs: StdLib,
mem_info: Option<ptr::NonNull<MemoryInfo>>,
- safe: bool, // Same as in the Lua struct
ref_thread: *mut ffi::lua_State,
ref_stack_size: c_int,
@@ -221,48 +228,52 @@ const MULTIVALUE_CACHE_SIZE: usize = 32;
/// Requires `feature = "send"`
#[cfg(feature = "send")]
#[cfg_attr(docsrs, doc(cfg(feature = "send")))]
-unsafe impl Send for Lua {}
+unsafe impl Send for LuaInner {}
-impl Drop for Lua {
+#[cfg(not(feature = "module"))]
+impl Drop for LuaInner {
fn drop(&mut self) {
unsafe {
- if !self.ephemeral {
- let extra = &mut *self.extra.get();
- let drain_iter = extra.wrapped_failures_cache.drain(..);
- #[cfg(feature = "async")]
- let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..));
- for index in drain_iter {
- ffi::lua_pushnil(extra.ref_thread);
- ffi::lua_replace(extra.ref_thread, index);
- extra.ref_free.push(index);
- }
- #[cfg(feature = "async")]
- {
- // Destroy Waker slot
- ffi::lua_pushnil(extra.ref_thread);
- ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
- extra.ref_free.push(extra.ref_waker_idx);
- }
- #[cfg(feature = "luau")]
- {
- let callbacks = ffi::lua_callbacks(self.state);
- let extra_ptr = (*callbacks).userdata as *mut Arc<UnsafeCell<ExtraData>>;
- drop(Box::from_raw(extra_ptr));
- (*callbacks).userdata = ptr::null_mut();
- }
- mlua_debug_assert!(
- ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
- && extra.ref_stack_top as usize == extra.ref_free.len(),
- "reference leak detected"
- );
- ffi::lua_close(self.main_state);
+ let extra = &mut *self.extra.get();
+ let drain_iter = extra.wrapped_failures_cache.drain(..);
+ #[cfg(feature = "async")]
+ let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..));
+ for index in drain_iter {
+ ffi::lua_pushnil(extra.ref_thread);
+ ffi::lua_replace(extra.ref_thread, index);
+ extra.ref_free.push(index);
+ }
+ #[cfg(feature = "async")]
+ {
+ // Destroy Waker slot
+ ffi::lua_pushnil(extra.ref_thread);
+ ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
+ extra.ref_free.push(extra.ref_waker_idx);
}
+ #[cfg(feature = "luau")]
+ {
+ let callbacks = ffi::lua_callbacks(self.state);
+ let extra_ptr = (*callbacks).userdata as *mut Arc<UnsafeCell<ExtraData>>;
+ drop(Box::from_raw(extra_ptr));
+ (*callbacks).userdata = ptr::null_mut();
+ }
+ mlua_debug_assert!(
+ ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
+ && extra.ref_stack_top as usize == extra.ref_free.len(),
+ "reference leak detected"
+ );
+ ffi::lua_close(self.main_state);
}
}
}
impl Drop for ExtraData {
fn drop(&mut self) {
+ #[cfg(feature = "module")]
+ unsafe {
+ ManuallyDrop::drop(&mut self.inner.take().unwrap())
+ };
+
*mlua_expect!(self.registry_unref_list.lock(), "unref list poisoned") = None;
if let Some(mem_info) = self.mem_info {
drop(unsafe { Box::from_raw(mem_info.as_ptr()) });
@@ -276,6 +287,20 @@ impl fmt::Debug for Lua {
}
}
+impl Deref for Lua {
+ type Target = LuaInner;
+
+ fn deref(&self) -> &Self::Target {
+ unsafe { &*(*self.0).get() }
+ }
+}
+
+impl DerefMut for Lua {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ unsafe { &mut *(*self.0).get() }
+ }
+}
+
impl Lua {
/// Creates a new Lua state and loads the **safe** subset of the standard libraries.
///
@@ -336,7 +361,6 @@ impl Lua {
mlua_expect!(lua.disable_c_modules(), "Error during disabling C modules");
}
lua.safe = true;
- unsafe { (*lua.extra.get()).safe = true };
Ok(lua)
}
@@ -430,9 +454,7 @@ impl Lua {
ffi::luaL_requiref(state, cstr!("_G"), ffi::luaopen_base, 1);
ffi::lua_pop(state, 1);
- let mut lua = Lua::init_from_ptr(state);
- lua.ephemeral = false;
-
+ let lua = Lua::init_from_ptr(state);
let extra = &mut *lua.extra.get();
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
@@ -544,6 +566,7 @@ impl Lua {
// Create ExtraData
let extra = Arc::new(UnsafeCell::new(ExtraData {
+ inner: None,
registered_userdata: FxHashMap::default(),
registered_userdata_mt: FxHashMap::default(),
registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))),
@@ -551,7 +574,6 @@ impl Lua {
ref_thread,
libs: StdLib::NONE,
mem_info: None,
- safe: false,
// We need 1 extra stack space to move values in and out of the ref stack.
ref_stack_size: ffi::LUA_MINSTACK - 1,
ref_stack_top,
@@ -606,14 +628,19 @@ impl Lua {
(*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void;
}
- Lua {
+ let inner = Arc::new(UnsafeCell::new(LuaInner {
state,
main_state,
- extra,
- ephemeral: true,
+ extra: Arc::clone(&extra),
safe: false,
_no_ref_unwind_safe: PhantomData,
- }
+ }));
+
+ (*extra.get()).inner = Some(ManuallyDrop::new(Arc::clone(&inner)));
+ #[cfg(not(feature = "module"))]
+ Arc::decrement_strong_count(Arc::as_ptr(&inner));
+
+ Lua(inner)
}
/// Loads the specified subset of the standard libraries into an existing Lua state.
@@ -1476,12 +1503,11 @@ impl Lua {
///
/// [`ToLua`]: crate::ToLua
/// [`ToLuaMulti`]: crate::ToLuaMulti
- pub fn create_function<'lua, 'callback, A, R, F>(&'lua self, func: F) -> Result<Function<'lua>>
+ pub fn create_function<'lua, A, R, F>(&'lua self, func: F) -> Result<Function<'lua>>
where
- 'lua: 'callback,
- A: FromLuaMulti<'callback>,
- R: ToLuaMulti<'callback>,
- F: 'static + MaybeSend + Fn(&'callback Lua, A) -> Result<R>,
+ A: FromLuaMulti<'lua>,
+ R: ToLuaMulti<'lua>,
+ F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result<R>,
{
self.create_callback(Box::new(move |lua, args| {
func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
@@ -1494,15 +1520,11 @@ impl Lua {
/// [`create_function`] for more information about the implementation.
///
/// [`create_function`]: #method.create_function
- pub fn create_function_mut<'lua, 'callback, A, R, F>(
- &'lua self,
- func: F,
- ) -> Result<Function<'lua>>
+ pub fn create_function_mut<'lua, A, R, F>(&'lua self, func: F) -> Result<Function<'lua>>
where
- 'lua: 'callback,
- A: FromLuaMulti<'callback>,
- R: ToLuaMulti<'callback>,
- F: 'static + MaybeSend + FnMut(&'callback Lua, A) -> Result<R>,
+ A: FromLuaMulti<'lua>,
+ R: ToLuaMulti<'lua>,
+ F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result<R>,
{
let func = RefCell::new(func);
self.create_function(move |lua, args| {
@@ -1564,15 +1586,11 @@ impl Lua {
/// [`AsyncThread`]: crate::AsyncThread
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
- pub fn create_async_function<'lua, 'callback, A, R, F, FR>(
- &'lua self,
- func: F,
- ) -> Result<Function<'lua>>
+ pub fn create_async_function<'lua, A, R, F, FR>(&'lua self, func: F) -> Result<Function<'lua>>
where
- 'lua: 'callback,
- A: FromLuaMulti<'callback>,
- R: ToLuaMulti<'callback>,
- F: 'static + MaybeSend + Fn(&'callback Lua, A) -> FR,
+ A: FromLuaMulti<'lua>,
+ R: ToLuaMulti<'lua>,
+ F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR,
FR: 'lua + Future<Output = Result<R>>,
{
self.create_async_callback(Box::new(move |lua, args| {
@@ -2459,25 +2477,22 @@ impl Lua {
}
// Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the
- // Fn is 'static, otherwise it could capture 'callback arguments improperly. Without ATCs, we
+ // Fn is 'static, otherwise it could capture 'lua arguments improperly. Without ATCs, we
// cannot easily deal with the "correct" callback type of:
//
// Box<for<'lua> Fn(&'lua Lua, MultiValue<'lua>) -> Result<MultiValue<'lua>>)>
//
// So we instead use a caller provided lifetime, which without the 'static requirement would be
// unsafe.
- pub(crate) fn create_callback<'lua, 'callback>(
+ pub(crate) fn create_callback<'lua>(
&'lua self,
- func: Callback<'callback, 'static>,
- ) -> Result<Function<'lua>>
- where
- 'lua: 'callback,
- {
+ func: Callback<'lua, 'static>,
+ ) -> Result<Function<'lua>> {
unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int {
let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) {
ffi::LUA_TUSERDATA => {
let upvalue = get_userdata::<CallbackUpvalue>(state, ffi::lua_upvalueindex(1));
- (*upvalue).lua.extra.get()
+ (*upvalue).extra.get()
}
_ => ptr::null_mut(),
};
@@ -2492,10 +2507,10 @@ impl Lua {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}
- let mut lua = (*upvalue).lua.clone();
- lua.state = state;
+ let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
+ let _guard = StateGuard::new(&mut *lua.0.get(), state);
- let mut args = MultiValue::new_or_cached(&lua);
+ let mut args = MultiValue::new_or_cached(lua);
args.reserve(nargs as usize);
for _ in 0..nargs {
args.push_front(lua.pop_value());
@@ -2518,9 +2533,9 @@ impl Lua {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 4)?;
- let lua = self.clone();
let func = mem::transmute(func);
- push_gc_userdata(self.state, CallbackUpvalue { lua, func })?;
+ let extra = Arc::clone(&self.extra);
+ push_gc_userdata(self.state, CallbackUpvalue { extra, func })?;
protect_lua!(self.state, 1, 1, fn(state) {
ffi::lua_pushcclosure(state, call_callback, 1);
})?;
@@ -2530,13 +2545,10 @@ impl Lua {
}
#[cfg(feature = "async")]
- pub(crate) fn create_async_callback<'lua, 'callback>(
+ pub(crate) fn create_async_callback<'lua>(
&'lua self,
- func: AsyncCallback<'callback, 'static>,
- ) -> Result<Function<'lua>>
- where
- 'lua: 'callback,
- {
+ func: AsyncCallback<'lua, 'static>,
+ ) -> Result<Function<'lua>> {
#[cfg(any(
feature = "lua54",
feature = "lua53",
@@ -2550,28 +2562,12 @@ impl Lua {
}
}
- struct StateGuard(*mut Lua, *mut ffi::lua_State);
-
- impl StateGuard {
- unsafe fn new(lua: *mut Lua, state: *mut ffi::lua_State) -> Self {
- let orig_state = (*lua).state;
- (*lua).state = state;
- Self(lua, orig_state)
- }
- }
-
- impl Drop for StateGuard {
- fn drop(&mut self) {
- unsafe { (*self.0).state = self.1 }
- }
- }
-
unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int {
let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) {
ffi::LUA_TUSERDATA => {
let upvalue =
get_userdata::<AsyncCallbackUpvalue>(state, ffi::lua_upvalueindex(1));
- (*upvalue).lua.extra.get()
+ (*upvalue).extra.get()
}
_ => ptr::null_mut(),
};
@@ -2586,8 +2582,8 @@ impl Lua {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}
- let lua = &mut (*upvalue).lua;
- let _guard = StateGuard::new(lua, state);
+ let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
+ let _guard = StateGuard::new(&mut *lua.0.get(), state);
let mut args = MultiValue::new_or_cached(lua);
args.reserve(nargs as usize);
@@ -2596,8 +2592,8 @@ impl Lua {
}
let fut = ((*upvalue).func)(lua, args);
- let lua = lua.clone();
- push_gc_userdata(state, AsyncPollUpvalue { lua, fut })?;
+ let extra = Arc::clone(&(*upvalue).extra);
+ push_gc_userdata(state, AsyncPollUpvalue { extra, fut })?;
protect_lua!(state, 1, 1, fn(state) {
ffi::lua_pushcclosure(state, poll_future, 1);
})?;
@@ -2610,7 +2606,7 @@ impl Lua {
let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) {
ffi::LUA_TUSERDATA => {
let upvalue = get_userdata::<AsyncPollUpvalue>(state, ffi::lua_upvalueindex(1));
- (*upvalue).lua.extra.get()
+ (*upvalue).extra.get()
}
_ => ptr::null_mut(),
};
@@ -2625,8 +2621,8 @@ impl Lua {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}
- let lua = &mut (*upvalue).lua;
- lua.state = state;
+ let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
+ let _guard = StateGuard::new(&mut *lua.0.get(), state);
// Try to get an outer poll waker
let waker = lua.waker().unwrap_or_else(noop_waker);
@@ -2657,9 +2653,9 @@ impl Lua {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 4)?;
- let lua = self.clone();
let func = mem::transmute(func);
- push_gc_userdata(self.state, AsyncCallbackUpvalue { lua, func })?;
+ let extra = Arc::clone(&self.extra);
+ push_gc_userdata(self.state, AsyncCallbackUpvalue { extra, func })?;
protect_lua!(self.state, 1, 1, fn(state) {
ffi::lua_pushcclosure(state, call_callback, 1);
})?;
@@ -2752,18 +2748,6 @@ impl Lua {
Ok(AnyUserData(self.pop_ref()))
}
- #[inline]
- pub(crate) fn clone(&self) -> Self {
- Lua {
- state: self.state,
- main_state: self.main_state,
- extra: Arc::clone(&self.extra),
- ephemeral: true,
- safe: self.safe,
- _no_ref_unwind_safe: PhantomData,
- }
- }
-
#[cfg(not(feature = "luau"))]
fn disable_c_modules(&self) -> Result<()> {
let package: Table = self.globals().get("package")?;
@@ -2794,17 +2778,9 @@ impl Lua {
pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Option<Self> {
let _sg = StackGuard::new(state);
assert_stack(state, 1);
-
let extra = extra_data(state)?;
- let safe = (*extra.get()).safe;
- Some(Lua {
- state,
- main_state: get_main_state(state).unwrap_or(state),
- extra,
- ephemeral: true,
- safe,
- _no_ref_unwind_safe: PhantomData,
- })
+ let inner = &*(*extra.get()).inner.as_ref().unwrap();
+ Some(Lua(Arc::clone(inner)))
}
#[inline]
@@ -2827,6 +2803,21 @@ impl Lua {
}
}
+struct StateGuard<'a>(&'a mut LuaInner, *mut ffi::lua_State);
+
+impl<'a> StateGuard<'a> {
+ fn new(inner: &'a mut LuaInner, mut state: *mut ffi::lua_State) -> Self {
+ mem::swap(&mut (*inner).state, &mut state);
+ Self(inner, state)
+ }
+}
+
+impl<'a> Drop for StateGuard<'a> {
+ fn drop(&mut self) {
+ mem::swap(&mut (*self.0).state, &mut self.1);
+ }
+}
+
#[cfg(feature = "luau")]
unsafe fn extra_data(state: *mut ffi::lua_State) -> Option<Arc<UnsafeCell<ExtraData>>> {
let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc<UnsafeCell<ExtraData>>;
diff --git a/src/types.rs b/src/types.rs
index a14ff3f..2d88f72 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -1,3 +1,4 @@
+use std::cell::UnsafeCell;
use std::hash::{Hash, Hasher};
use std::os::raw::{c_int, c_void};
use std::sync::{Arc, Mutex};
@@ -13,7 +14,7 @@ use crate::error::Result;
use crate::ffi;
#[cfg(not(feature = "luau"))]
use crate::hook::Debug;
-use crate::lua::Lua;
+use crate::lua::{ExtraData, Lua};
use crate::util::{assert_stack, StackGuard};
use crate::value::MultiValue;
@@ -30,7 +31,7 @@ pub(crate) type Callback<'lua, 'a> =
Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> Result<MultiValue<'lua>> + 'a>;
pub(crate) struct CallbackUpvalue<'lua> {
- pub(crate) lua: Lua,
+ pub(crate) extra: Arc<UnsafeCell<ExtraData>>,
pub(crate) func: Callback<'lua, 'static>,
}
@@ -40,13 +41,13 @@ pub(crate) type AsyncCallback<'lua, 'a> =
#[cfg(feature = "async")]
pub(crate) struct AsyncCallbackUpvalue<'lua> {
- pub(crate) lua: Lua,
+ pub(crate) extra: Arc<UnsafeCell<ExtraData>>,
pub(crate) func: AsyncCallback<'lua, 'static>,
}
#[cfg(feature = "async")]
pub(crate) struct AsyncPollUpvalue<'lua> {
- pub(crate) lua: Lua,
+ pub(crate) extra: Arc<UnsafeCell<ExtraData>>,
pub(crate) fut: LocalBoxFuture<'lua, Result<MultiValue<'lua>>>,
}