diff options
author | Alex Orlenko <zxteam@protonmail.com> | 2023-08-03 00:56:17 +0100 |
---|---|---|
committer | Alex Orlenko <zxteam@protonmail.com> | 2023-08-03 00:56:17 +0100 |
commit | cd0c8a4584401a68dc1141fe3b654eb647be27d0 (patch) | |
tree | 8045cda444dfdec6c898f563ae14a146960ea7d3 /src | |
parent | 4fff14a14467c5cd95f85d5e6980e808ab82cffd (diff) | |
download | mlua-cd0c8a4584401a68dc1141fe3b654eb647be27d0.zip |
Optimize async functionality:
Rewrite using the new `push_into_stack()`/`from_stack()` methods.
Also store thread state (pointer) in `Thread` struct to avoid getting it every time.
Async userdata methods still need to have arguments stored in ref thread as stack is empty on every poll().
Diffstat (limited to 'src')
-rw-r--r-- | src/function.rs | 6 | ||||
-rw-r--r-- | src/lua.rs | 52 | ||||
-rw-r--r-- | src/scope.rs | 8 | ||||
-rw-r--r-- | src/thread.rs | 245 | ||||
-rw-r--r-- | src/types.rs | 9 | ||||
-rw-r--r-- | src/userdata.rs | 10 | ||||
-rw-r--r-- | src/userdata_impl.rs | 199 | ||||
-rw-r--r-- | src/util/mod.rs | 5 | ||||
-rw-r--r-- | src/value.rs | 10 |
9 files changed, 259 insertions, 285 deletions
diff --git a/src/function.rs b/src/function.rs index b4c7965..ce6970b 100644 --- a/src/function.rs +++ b/src/function.rs @@ -598,13 +598,13 @@ impl<'lua> Function<'lua> { F: Fn(&'lua Lua, A) -> FR + MaybeSend + 'static, FR: Future<Output = Result<R>> + 'lua, { - WrappedAsyncFunction(Box::new(move |lua, args| { - let args = match A::from_lua_multi(args, lua) { + WrappedAsyncFunction(Box::new(move |lua, args| unsafe { + let args = match A::from_lua_args(args, 1, None, lua) { Ok(args) => args, Err(e) => return Box::pin(future::err(e)), }; let fut = func(lua, args); - Box::pin(async move { fut.await?.into_lua_multi(lua) }) + Box::pin(async move { fut.await?.push_into_stack_multi(lua) }) })) } } @@ -1600,13 +1600,13 @@ impl Lua { F: Fn(&'lua Lua, A) -> FR + MaybeSend + 'static, FR: Future<Output = Result<R>> + 'lua, { - self.create_async_callback(Box::new(move |lua, args| { + self.create_async_callback(Box::new(move |lua, args| unsafe { let args = match A::from_lua_args(args, 1, None, lua) { Ok(args) => args, Err(e) => return Box::pin(future::err(e)), }; let fut = func(lua, args); - Box::pin(async move { fut.await?.into_lua_multi(lua) }) + Box::pin(async move { fut.await?.push_into_stack_multi(lua) }) })) } @@ -1634,7 +1634,7 @@ impl Lua { self.push_ref(&func.0); ffi::lua_xmove(state, thread_state, 1); - Ok(Thread(self.pop_ref())) + Ok(Thread::new(self.pop_ref())) } } @@ -1662,7 +1662,7 @@ impl Lua { ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); } - return Ok(Thread(LuaRef::new(self, index))); + return Ok(Thread::new(LuaRef::new(self, index))); } }; self.create_thread_inner(func) @@ -1817,7 +1817,7 @@ impl Lua { let _sg = StackGuard::new(state); assert_stack(state, 1); ffi::lua_pushthread(state); - Thread(self.pop_ref()) + Thread::new(self.pop_ref()) } } @@ -2412,7 +2412,7 @@ impl Lua { } } - ffi::LUA_TTHREAD => Value::Thread(Thread(self.pop_ref())), + ffi::LUA_TTHREAD => Value::Thread(Thread::new(self.pop_ref())), #[cfg(feature = "luajit")] ffi::LUA_TCDATA => { @@ -2510,7 +2510,7 @@ impl Lua { ffi::LUA_TTHREAD => { ffi::lua_xpush(state, self.ref_thread(), idx); - Value::Thread(Thread(self.pop_ref_thread())) + Value::Thread(Thread::new(self.pop_ref_thread())) } #[cfg(feature = "luajit")] @@ -2763,16 +2763,15 @@ impl Lua { } } - // Returns `TypeId` for the LuaRef, checking that it's a registered - // and not destructed UserData. + // Returns `TypeId` for the `lref` userdata, checking that it's registered and not destructed. // // Returns `None` if the userdata is registered but non-static. - pub(crate) unsafe fn get_userdata_type_id(&self, lref: &LuaRef) -> Result<Option<TypeId>> { + pub(crate) unsafe fn get_userdata_ref_type_id(&self, lref: &LuaRef) -> Result<Option<TypeId>> { self.get_userdata_type_id_inner(self.ref_thread(), lref.index) } - // Same as `get_userdata_type_id` but assumes the value is already on the current stack - pub(crate) unsafe fn get_userdata_type_id_stack(&self, idx: c_int) -> Result<Option<TypeId>> { + // Same as `get_userdata_ref_type_id` but assumes the userdata is already on the stack. + pub(crate) unsafe fn get_userdata_type_id(&self, idx: c_int) -> Result<Option<TypeId>> { self.get_userdata_type_id_inner(self.state(), idx) } @@ -2808,7 +2807,7 @@ impl Lua { // Pushes a LuaRef (userdata) value onto the stack, returning their `TypeId`. // Uses 1 stack space, does not call checkstack. pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef) -> Result<Option<TypeId>> { - let type_id = self.get_userdata_type_id(lref)?; + let type_id = self.get_userdata_type_id_inner(self.ref_thread(), lref.index)?; self.push_ref(lref); Ok(type_id) } @@ -2898,11 +2897,7 @@ impl Lua { let lua: &Lua = mem::transmute((*extra).inner.assume_init_ref()); let _guard = StateGuard::new(&lua.0, state); - let mut args = MultiValue::with_lua_and_capacity(lua, nargs as usize); - for _ in 0..nargs { - args.push_front(lua.pop_value()); - } - + let args = MultiValue::from_stack_multi(nargs, lua)?; let func = &*(*upvalue).data; let fut = func(lua, args); let extra = Arc::clone(&(*upvalue).extra); @@ -2924,6 +2919,7 @@ impl Lua { let upvalue = get_userdata::<AsyncPollUpvalue>(state, ffi::lua_upvalueindex(1)); let extra = (*upvalue).extra.get(); callback_error_ext(state, extra, |_| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) let lua: &Lua = mem::transmute((*extra).inner.assume_init_ref()); let _guard = StateGuard::new(&lua.0, state); @@ -2931,20 +2927,20 @@ impl Lua { let mut ctx = Context::from_waker(lua.waker()); match fut.as_mut().poll(&mut ctx) { Poll::Pending => Ok(0), - Poll::Ready(results) => { - let mut results = results?; - let nresults = results.len(); - lua.push_value(Value::Integer(nresults as _))?; + Poll::Ready(nresults) => { + let nresults = nresults?; match nresults { - 0 => Ok(1), - 1 | 2 => { - // Fast path for 1 or 2 results without creating a table - for r in results.drain_all() { - lua.push_value(r)?; + 0..=2 => { + // Fast path for up to 2 results without creating a table + ffi::lua_pushinteger(state, nresults as _); + if nresults > 0 { + ffi::lua_insert(state, -nresults - 1); } - Ok(nresults as c_int + 1) + Ok(nresults + 1) } _ => { + let results = MultiValue::from_stack_multi(nresults, lua)?; + ffi::lua_pushinteger(state, nresults as _); lua.push_value(Value::Table(lua.create_sequence_from(results)?))?; Ok(2) } diff --git a/src/scope.rs b/src/scope.rs index 2126f43..fa37877 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -338,13 +338,13 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { }); scope.create_callback(f) } - NonStaticMethod::Function(function) => unsafe { + NonStaticMethod::Function(function) => { scope.create_callback(Box::new(move |lua, nargs| { let args = MultiValue::from_stack_args(nargs, 1, None, lua)?; function(lua, args)?.push_into_stack_multi(lua) })) - }, - NonStaticMethod::FunctionMut(function) => unsafe { + } + NonStaticMethod::FunctionMut(function) => { let function = RefCell::new(function); let f = Box::new(move |lua, nargs| { let mut func = function @@ -354,7 +354,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { func(lua, args)?.push_into_stack_multi(lua) }); scope.create_callback(f) - }, + } } } diff --git a/src/thread.rs b/src/thread.rs index f062957..cdce77c 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -1,4 +1,3 @@ -use std::cmp; use std::os::raw::c_int; use crate::error::{Error, Result}; @@ -16,10 +15,7 @@ use crate::{ #[cfg(feature = "async")] use { - crate::{ - lua::ASYNC_POLL_PENDING, - value::{MultiValue, Value}, - }, + crate::{lua::ASYNC_POLL_PENDING, value::MultiValue}, futures_util::stream::Stream, std::{ future::Future, @@ -47,7 +43,7 @@ pub enum ThreadStatus { /// Handle to an internal Lua thread (or coroutine). #[derive(Clone, Debug)] -pub struct Thread<'lua>(pub(crate) LuaRef<'lua>); +pub struct Thread<'lua>(pub(crate) LuaRef<'lua>, pub(crate) *mut ffi::lua_State); /// Thread (coroutine) representation as an async [`Future`] or [`Stream`]. /// @@ -66,6 +62,16 @@ pub struct AsyncThread<'lua, R> { } impl<'lua> Thread<'lua> { + #[inline(always)] + pub(crate) fn new(r#ref: LuaRef<'lua>) -> Self { + let state = unsafe { ffi::lua_tothread(r#ref.lua.ref_thread(), r#ref.index) }; + Thread(r#ref, state) + } + + const fn state(&self) -> *mut ffi::lua_State { + self.1 + } + /// Resumes execution of this thread. /// /// Equivalent to `coroutine.resume`. @@ -114,60 +120,59 @@ impl<'lua> Thread<'lua> { { let lua = self.0.lua; let state = lua.state(); - - let mut args = args.into_lua_multi(lua)?; - let nargs = args.len() as c_int; - let results = unsafe { + let thread_state = self.state(); + unsafe { let _sg = StackGuard::new(state); - check_stack(state, cmp::max(nargs + 1, 3))?; + let _thread_sg = StackGuard::with_top(thread_state, 0); - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); + let nresults = self.resume_inner(args)?; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - let status = ffi::lua_status(thread_state); - if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 { - return Err(Error::CoroutineInactive); - } + R::from_stack_multi(nresults, lua) + } + } + /// Resumes execution of this thread. + /// + /// It's similar to `resume()` but leaves `nresults` values on the thread stack. + unsafe fn resume_inner<A: IntoLuaMulti<'lua>>(&self, args: A) -> Result<c_int> { + let lua = self.0.lua; + let state = lua.state(); + let thread_state = self.state(); + + if self.status() != ThreadStatus::Resumable { + return Err(Error::CoroutineInactive); + } + + let nargs = args.push_into_stack_multi(lua)?; + if nargs > 0 { check_stack(thread_state, nargs)?; - for arg in args.drain_all() { - lua.push_value(arg)?; - } ffi::lua_xmove(state, thread_state, nargs); + } - let mut nresults = 0; - - let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); - if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { - if ret == ffi::LUA_ERRMEM { - // Don't call error handler for memory errors - return Err(pop_error(thread_state, ret)); - } - check_stack(state, 3)?; - protect_lua!(state, 0, 1, |state| error_traceback_thread( - state, - thread_state - ))?; - return Err(pop_error(state, ret)); + let mut nresults = 0; + let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); + if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { + if ret == ffi::LUA_ERRMEM { + // Don't call error handler for memory errors + return Err(pop_error(thread_state, ret)); } + check_stack(state, 3)?; + protect_lua!(state, 0, 1, |state| error_traceback_thread( + state, + thread_state + ))?; + return Err(pop_error(state, ret)); + } - let mut results = args; // Reuse MultiValue container - check_stack(state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below - ffi::lua_xmove(thread_state, state, nresults); - - for _ in 0..nresults { - results.push_front(lua.pop_value()); - } - results - }; - R::from_lua_multi(results, lua) + Ok(nresults) } /// Gets the status of the thread. pub fn status(&self) -> ThreadStatus { - let lua = self.0.lua; + let thread_state = self.state(); unsafe { - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); - let status = ffi::lua_status(thread_state); if status != ffi::LUA_OK && status != ffi::LUA_YIELD { ThreadStatus::Error @@ -191,8 +196,7 @@ impl<'lua> Thread<'lua> { { let lua = self.0.lua; unsafe { - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); - lua.set_thread_hook(thread_state, triggers, callback); + lua.set_thread_hook(self.state(), triggers, callback); } } @@ -214,18 +218,12 @@ impl<'lua> Thread<'lua> { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn reset(&self, func: crate::function::Function<'lua>) -> Result<()> { let lua = self.0.lua; - let state = lua.state(); + let thread_state = self.state(); unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - - lua.push_ref(&self.0); - let thread_state = ffi::lua_tothread(state, -1); - #[cfg(all(feature = "lua54", not(feature = "vendored")))] let status = ffi::lua_resetthread(thread_state); #[cfg(all(feature = "lua54", feature = "vendored"))] - let status = ffi::lua_closethread(thread_state, state); + let status = ffi::lua_closethread(thread_state, lua.state()); #[cfg(feature = "lua54")] if status != ffi::LUA_OK { return Err(pop_error(thread_state, status)); @@ -233,8 +231,8 @@ impl<'lua> Thread<'lua> { #[cfg(feature = "luau")] ffi::lua_resetthread(thread_state); - lua.push_ref(&func.0); - ffi::lua_xmove(state, thread_state, 1); + // Push function to the top of the thread stack + ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index); #[cfg(feature = "luau")] { @@ -345,11 +343,11 @@ impl<'lua> Thread<'lua> { pub fn sandbox(&self) -> Result<()> { let lua = self.0.lua; let state = lua.state(); + let thread_state = self.state(); unsafe { - let thread = ffi::lua_tothread(lua.ref_thread(), self.0.index); - check_stack(thread, 3)?; + check_stack(thread_state, 3)?; check_stack(state, 3)?; - protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread)) + protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state)) } } } @@ -379,11 +377,10 @@ impl<'lua, R> Drop for AsyncThread<'lua, R> { if !lua.recycle_thread(&mut self.thread) { #[cfg(feature = "lua54")] if self.thread.status() == ThreadStatus::Error { - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.thread.0.index); #[cfg(not(feature = "vendored"))] - ffi::lua_resetthread(thread_state); + ffi::lua_resetthread(self.thread.state()); #[cfg(feature = "vendored")] - ffi::lua_closethread(thread_state, lua.state()); + ffi::lua_closethread(self.thread.state(), lua.state()); } } } @@ -399,29 +396,36 @@ where type Item = Result<R>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + if self.thread.status() != ThreadStatus::Resumable { + return Poll::Ready(None); + } + let lua = self.thread.0.lua; + let state = lua.state(); + let thread_state = self.thread.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::with_top(thread_state, 0); + let _wg = WakerGuard::new(lua, cx.waker()); - match self.thread.status() { - ThreadStatus::Resumable => {} - _ => return Poll::Ready(None), - }; + // This is safe as we are not moving the whole struct + let this = self.get_unchecked_mut(); + let nresults = if let Some(args) = this.init_args.take() { + this.thread.resume_inner(args?)? + } else { + this.thread.resume_inner(())? + }; - let _wg = WakerGuard::new(lua, cx.waker()); + if nresults == 1 && is_poll_pending(thread_state) { + return Poll::Pending; + } - // This is safe as we are not moving the whole struct - let this = unsafe { self.get_unchecked_mut() }; - let ret: MultiValue = if let Some(args) = this.init_args.take() { - this.thread.resume(args?)? - } else { - this.thread.resume(())? - }; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - if is_poll_pending(&ret) { - return Poll::Pending; + cx.waker().wake_by_ref(); + Poll::Ready(Some(R::from_stack_multi(nresults, lua))) } - - cx.waker().wake_by_ref(); - Poll::Ready(Some(R::from_lua_multi(ret, lua))) } } @@ -433,46 +437,53 @@ where type Output = Result<R>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + if self.thread.status() != ThreadStatus::Resumable { + return Poll::Ready(Err(Error::CoroutineInactive)); + } + let lua = self.thread.0.lua; + let state = lua.state(); + let thread_state = self.thread.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::with_top(thread_state, 0); + let _wg = WakerGuard::new(lua, cx.waker()); - match self.thread.status() { - ThreadStatus::Resumable => {} - _ => return Poll::Ready(Err(Error::CoroutineInactive)), - }; + // This is safe as we are not moving the whole struct + let this = self.get_unchecked_mut(); + let nresults = if let Some(args) = this.init_args.take() { + this.thread.resume_inner(args?)? + } else { + this.thread.resume_inner(())? + }; - let _wg = WakerGuard::new(lua, cx.waker()); + if nresults == 1 && is_poll_pending(thread_state) { + return Poll::Pending; + } - // This is safe as we are not moving the whole struct - let this = unsafe { self.get_unchecked_mut() }; - let ret: MultiValue = if let Some(args) = this.init_args.take() { - this.thread.resume(args?)? - } else { - this.thread.resume(())? - }; + if ffi::lua_status(thread_state) == ffi::LUA_YIELD { + // Ignore value returned via yield() + cx.waker().wake_by_ref(); + return Poll::Pending; + } - if is_poll_pending(&ret) { - return Poll::Pending; - } + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - if let ThreadStatus::Resumable = this.thread.status() { - // Ignore value returned via yield() - cx.waker().wake_by_ref(); - return Poll::Pending; + Poll::Ready(R::from_stack_multi(nresults, lua)) } - - Poll::Ready(R::from_lua_multi(ret, lua)) } } #[cfg(feature = "async")] #[inline(always)] -fn is_poll_pending(val: &MultiValue) -> bool { - match val.iter().enumerate().last() { - Some((0, Value::LightUserData(ud))) => { - std::ptr::eq(ud.0 as *const u8, &ASYNC_POLL_PENDING as *const u8) - } - _ => false, +unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool { + if ffi::lua_islightuserdata(state, -1) != 0 { + let stack_ptr = ffi::lua_touserdata(state, -1) as *const u8; + let pending_ptr = &ASYNC_POLL_PENDING as *const u8; + return std::ptr::eq(stack_ptr, pending_ptr); } + false } #[cfg(feature = "async")] @@ -486,23 +497,19 @@ struct WakerGuard<'lua, 'a> { impl<'lua, 'a> WakerGuard<'lua, 'a> { #[inline] pub fn new(lua: &'lua Lua, waker: &'a Waker) -> Result<WakerGuard<'lua, 'a>> { - unsafe { - let prev = lua.set_waker(NonNull::from(waker)); - Ok(WakerGuard { - lua, - prev, - _phantom: PhantomData, - }) - } + let prev = unsafe { lua.set_waker(NonNull::from(waker)) }; + Ok(WakerGuard { + lua, + prev, + _phantom: PhantomData, + }) } } #[cfg(feature = "async")] impl<'lua, 'a> Drop for WakerGuard<'lua, 'a> { fn drop(&mut self) { - unsafe { - self.lua.set_waker(self.prev); - } + unsafe { self.lua.set_waker(self.prev) }; } } diff --git a/src/types.rs b/src/types.rs index baf31dc..a2b38e9 100644 --- a/src/types.rs +++ b/src/types.rs @@ -10,16 +10,13 @@ use std::{fmt, mem, ptr}; use rustc_hash::FxHashMap; -#[cfg(feature = "async")] -use futures_util::future::LocalBoxFuture; - use crate::error::Result; #[cfg(not(feature = "luau"))] use crate::hook::Debug; use crate::lua::{ExtraData, Lua}; #[cfg(feature = "async")] -use crate::value::MultiValue; +use {crate::value::MultiValue, futures_util::future::LocalBoxFuture}; #[cfg(feature = "unstable")] use {crate::lua::LuaInner, std::marker::PhantomData}; @@ -47,13 +44,13 @@ pub(crate) type CallbackUpvalue = Upvalue<Callback<'static, 'static>>; #[cfg(feature = "async")] pub(crate) type AsyncCallback<'lua, 'a> = - Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> LocalBoxFuture<'lua, Result<MultiValue<'lua>>> + 'a>; + Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> LocalBoxFuture<'lua, Result<c_int>> + 'a>; #[cfg(feature = "async")] pub(crate) type AsyncCallbackUpvalue = Upvalue<AsyncCallback<'static, 'static>>; #[cfg(feature = "async")] -pub(crate) type AsyncPollUpvalue = Upvalue<LocalBoxFuture<'static, Result<MultiValue<'static>>>>; +pub(crate) type AsyncPollUpvalue = Upvalue<LocalBoxFuture<'static, Result<c_int>>>; /// Type to set next Luau VM action after executing interrupt function. #[cfg(any(feature = "luau", doc))] diff --git a/src/userdata.rs b/src/userdata.rs index 9d9c318..2a29d7a 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1105,9 +1105,9 @@ impl<'lua> AnyUserData<'lua> { } #[cfg(feature = "async")] - #[inline(always)] + #[inline] pub(crate) fn type_id(&self) -> Result<Option<TypeId>> { - unsafe { self.0.lua.get_userdata_type_id(&self.0) } + unsafe { self.0.lua.get_userdata_ref_type_id(&self.0) } } /// Returns a type name of this `UserData` (from a metatable field). @@ -1161,7 +1161,7 @@ impl<'lua> AnyUserData<'lua> { let lua = self.0.lua; let is_serializable = || unsafe { // Userdata can be unregistered or destructed - let _ = lua.get_userdata_type_id(&self.0)?; + let _ = lua.get_userdata_ref_type_id(&self.0)?; let ud = &*get_userdata::<UserDataCell<()>>(lua.ref_thread(), self.0.index); match &*ud.0.try_borrow().map_err(|_| Error::UserDataBorrowError)? { @@ -1179,7 +1179,7 @@ impl<'lua> AnyUserData<'lua> { { let lua = self.0.lua; unsafe { - let type_id = lua.get_userdata_type_id(&self.0)?; + let type_id = lua.get_userdata_ref_type_id(&self.0)?; match type_id { Some(type_id) if type_id == TypeId::of::<T>() => { let ref_thread = lua.ref_thread(); @@ -1328,7 +1328,7 @@ impl<'lua> Serialize for AnyUserData<'lua> { let lua = self.0.lua; let data = unsafe { let _ = lua - .get_userdata_type_id(&self.0) + .get_userdata_ref_type_id(&self.0) .map_err(ser::Error::custom)?; let ud = &*get_userdata::<UserDataCell<()>>(lua.ref_thread(), self.0.index); ud.0.try_borrow() diff --git a/src/userdata_impl.rs b/src/userdata_impl.rs index 565dafe..af1dbaf 100644 --- a/src/userdata_impl.rs +++ b/src/userdata_impl.rs @@ -20,11 +20,7 @@ use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Value}; use std::rc::Rc; #[cfg(feature = "async")] -use { - crate::types::AsyncCallback, - futures_util::future::{self, TryFutureExt}, - std::future::Future, -}; +use {crate::types::AsyncCallback, futures_util::future, std::future::Future}; /// Handle to registry for userdata methods and metamethods. pub struct UserDataRegistry<'lua, T: 'static> { @@ -80,62 +76,56 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { Box::new(move |lua, nargs| unsafe { if nargs == 0 { - try_self_arg!(Err(Error::from_lua_conversion( - "missing argument", - "userdata", - None, - ))); + let err = Error::from_lua_conversion("missing argument", "userdata", None); + try_self_arg!(Err(err)); } - let call = |ud| { - // Self was at index 1, so we pass 2 here - let args = A::from_stack_args(nargs - 1, 2, Some(&name), lua)?; - method(lua, ud, args)?.push_into_stack_multi(lua) - }; + // Self was at index 1, so we pass 2 here + let args = A::from_stack_args(nargs - 1, 2, Some(&name), lua); let (state, index) = (lua.state(), -nargs); - match try_self_arg!(lua.get_userdata_type_id_stack(index)) { + match try_self_arg!(lua.get_userdata_type_id(index)) { Some(id) if id == TypeId::of::<T>() => { let ud = try_self_arg!(get_userdata_ref::<T>(state, index)); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } #[cfg(not(feature = "send"))] Some(id) if id == TypeId::of::<Rc<T>>() => { let ud = try_self_arg!(get_userdata_ref::<Rc<T>>(state, index)); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } #[cfg(not(feature = "send"))] Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => { let ud = try_self_arg!(get_userdata_ref::<Rc<RefCell<T>>>(state, index)); let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } Some(id) if id == TypeId::of::<Arc<T>>() => { let ud = try_self_arg!(get_userdata_ref::<Arc<T>>(state, index)); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => { let ud = try_self_arg!(get_userdata_ref::<Arc<Mutex<T>>>(state, index)); let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } #[cfg(feature = "parking_lot")] Some(id) if id == TypeId::of::<Arc<parking_lot::Mutex<T>>>() => { let ud = get_userdata_ref::<Arc<parking_lot::Mutex<T>>>(state, index); let ud = try_self_arg!(ud); let ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowError)); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => { let ud = try_self_arg!(get_userdata_ref::<Arc<RwLock<T>>>(state, index)); let ud = try_self_arg!(ud.try_read(), Error::UserDataBorrowError); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } #[cfg(feature = "parking_lot")] Some(id) if id == TypeId::of::<Arc<parking_lot::RwLock<T>>>() => { let ud = get_userdata_ref::<Arc<parking_lot::RwLock<T>>>(state, index); let ud = try_self_arg!(ud); let ud = try_self_arg!(ud.try_read().ok_or(Error::UserDataBorrowError)); - call(&ud) + method(lua, &ud, args?)?.push_into_stack_multi(lua) } _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } @@ -164,23 +154,17 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { .try_borrow_mut() .map_err(|_| Error::RecursiveMutCallback)?; if nargs == 0 { - try_self_arg!(Err(Error::from_lua_conversion( - "missing argument", - "userdata", - None, - ))); + let err = Error::from_lua_conversion("missing argument", "userdata", None); + try_self_arg!(Err(err)); } - let mut call = |ud| { - // Self was at index 1, so we pass 2 here - let args = A::from_stack_args(nargs - 1, 2, Some(&name), lua)?; - method(lua, ud, args)?.push_into_stack_multi(lua) - }; + // Self was at index 1, so we pass 2 here + let args = A::from_stack_args(nargs - 1, 2, Some(&name), lua); let (state, index) = (lua.state(), -nargs); - match try_self_arg!(lua.get_userdata_type_id_stack(index)) { + match try_self_arg!(lua.get_userdata_type_id(index)) { Some(id) if id == TypeId::of::<T>() => { let mut ud = try_self_arg!(get_userdata_mut::<T>(state, index)); - call(&mut ud) + method(lua, &mut ud, args?)?.push_into_stack_multi(lua) } #[cfg(not(feature = "send"))] Some(id) if id == TypeId::of::<Rc<T>>() => Err(Error::UserDataBorrowMutError), @@ -188,32 +172,32 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => { let ud = try_self_arg!(get_userdata_mut::<Rc<RefCell<T>>>(state, index)); let mut ud = try_self_arg!(ud.try_borrow_mut(), Error::UserDataBorrowMutError); - call(&mut ud) + method(lua, &mut ud, args?)?.push_into_stack_multi(lua) } Some(id) if id == TypeId::of::<Arc<T>>() => Err(Error::UserDataBorrowMutError), Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => { let ud = try_self_arg!(get_userdata_mut::<Arc<Mutex<T>>>(state, index)); let mut ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowMutError); - call(&mut ud) + method(lua, &mut ud, args?)?.push_into_stack_multi(lua) } #[cfg(feature = "parking_lot")] Some(id) if id == TypeId::of::<Arc<parking_lot::Mutex<T>>>() => { let ud = get_userdata_mut::<Arc<parking_lot::Mutex<T>>>(state, index); let ud = try_self_arg!(ud); let mut ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowMutError)); - call(&mut ud) + method(lua, &mut ud, args?)?.push_into_stack_multi(lua) } Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => { let ud = try_self_arg!(get_userdata_mut::<Arc<RwLock<T>>>(state, index)); let mut ud = try_self_arg!(ud.try_write(), Error::UserDataBorrowMutError); - call(&mut ud) + method(lua, &mut ud, args?)?.push_into_stack_multi(lua) } #[cfg(feature = "parking_lot")] Some(id) if id == TypeId::of::<Arc<parking_lot::RwLock<T>>>() => { let ud = get_userdata_mut::<Arc<parking_lot::RwLock<T>>>(state, index); let ud = try_self_arg!(ud); let mut ud = try_self_arg!(ud.try_write().ok_or(Error::UserDataBorrowMutError)); - call(&mut ud) + method(lua, &mut ud, args?)?.push_into_stack_multi(lua) } _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } @@ -233,7 +217,7 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { let name = get_function_name::<T>(name); let method = Arc::new(method); - Box::new(move |lua, mut args| { + Box::new(move |lua, mut args| unsafe { let name = name.clone(); let method = method.clone(); macro_rules! try_self_arg { @@ -246,76 +230,68 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { } Box::pin(async move { - let front = args.pop_front().ok_or_else(|| { + let this = args.pop_front().ok_or_else(|| { Error::from_lua_conversion("missing argument", "userdata", None) }); - let front = try_self_arg!(front); - let userdata: AnyUserData = try_self_arg!(AnyUserData::from_lua(front, lua)); - let (ref_thread, index) = (lua.ref_thread(), userdata.0.index); - match try_self_arg!(userdata.type_id()) { - Some(id) if id == TypeId::of::<T>() => unsafe { + let this = try_self_arg!(AnyUserData::from_lua(try_self_arg!(this), lua)); + let args = A::from_lua_args(args, 2, Some(&name), lua); + + let (ref_thread, index) = (lua.ref_thread(), this.0.index); + match try_self_arg!(this.type_id()) { + Some(id) if id == TypeId::of::<T>() => { let ud = try_self_arg!(get_userdata_ref::<T>(ref_thread, index)); let ud = std::mem::transmute::<&T, &T>(&ud); - // Self was at index 1, so we pass 2 here - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::<Rc<T>>() => unsafe { + Some(id) if id == TypeId::of::<Rc<T>>() => { let ud = try_self_arg!(get_userdata_ref::<Rc<T>>(ref_thread, index)); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => { let ud = try_self_arg!(get_userdata_ref::<Rc<RefCell<T>>>(ref_thread, index)); let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, - Some(id) if id == TypeId::of::<Arc<T>>() => unsafe { + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } + Some(id) if id == TypeId::of::<Arc<T>>() => { let ud = try_self_arg!(get_userdata_ref::<Arc<T>>(ref_thread, index)); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, - Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => unsafe { + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } + Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => { let ud = try_self_arg!(get_userdata_ref::<Arc<Mutex<T>>>(ref_thread, index)); let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::<Arc<parking_lot::Mutex<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Arc<parking_lot::Mutex<T>>>() => { let ud = get_userdata_ref::<Arc<parking_lot::Mutex<T>>>(ref_thread, index); let ud = try_self_arg!(ud); let ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowError)); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, - Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => unsafe { + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } + Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => { let ud = try_self_arg!(get_userdata_ref::<Arc<RwLock<T>>>(ref_thread, index)); let ud = try_self_arg!(ud.try_read(), Error::UserDataBorrowError); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::<Arc<parking_lot::RwLock<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Arc<parking_lot::RwLock<T>>>() => { let ud = get_userdata_ref::<Arc<parking_lot::RwLock<T>>>(ref_thread, index); let ud = try_self_arg!(ud); let ud = try_self_arg!(ud.try_read().ok_or(Error::UserDataBorrowError)); let ud = std::mem::transmute::<&T, &T>(&ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } }) @@ -335,7 +311,7 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { let name = get_function_name::<T>(name); let method = Arc::new(method); - Box::new(move |lua, mut args| { + Box::new(move |lua, mut args| unsafe { let name = name.clone(); let method = method.clone(); macro_rules! try_self_arg { @@ -348,72 +324,66 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { } Box::pin(async move { - let front = args.pop_front().ok_or_else(|| { + let this = args.pop_front().ok_or_else(|| { Error::from_lua_conversion("missing argument", "userdata", None) }); - let front = try_self_arg!(front); - let userdata: AnyUserData = try_self_arg!(AnyUserData::from_lua(front, lua)); - let (ref_thread, index) = (lua.ref_thread(), userdata.0.index); - match try_self_arg!(userdata.type_id()) { - Some(id) if id == TypeId::of::<T>() => unsafe { + let this = try_self_arg!(AnyUserData::from_lua(try_self_arg!(this), lua)); + let args = A::from_lua_args(args, 2, Some(&name), lua); + + let (ref_thread, index) = (lua.ref_thread(), this.0.index); + match try_self_arg!(this.type_id()) { + Some(id) if id == TypeId::of::<T>() => { let mut ud = try_self_arg!(get_userdata_mut::<T>(ref_thread, index)); let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); - // Self was at index 1, so we pass 2 here - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(not(feature = "send"))] Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => { Err(Error::UserDataBorrowMutError) } #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Rc<RefCell<T>>>() => { let ud = try_self_arg!(get_userdata_mut::<Rc<RefCell<T>>>(ref_thread, index)); let mut ud = try_self_arg!(ud.try_borrow_mut(), Error::UserDataBorrowMutError); let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(not(feature = "send"))] Some(id) if id == TypeId::of::<Arc<T>>() => Err(Error::UserDataBorrowMutError), - Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Arc<Mutex<T>>>() => { let ud = try_self_arg!(get_userdata_mut::<Arc<Mutex<T>>>(ref_thread, index)); let mut ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowMutError); let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::<Arc<parking_lot::Mutex<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Arc<parking_lot::Mutex<T>>>() => { let ud = get_userdata_mut::<Arc<parking_lot::Mutex<T>>>(ref_thread, index); let ud = try_self_arg!(ud); let mut ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowMutError)); let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, - Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => unsafe { + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } + Some(id) if id == TypeId::of::<Arc<RwLock<T>>>() => { let ud = try_self_arg!(get_userdata_mut::<Arc<RwLock<T>>>(ref_thread, index)); let mut ud = try_self_arg!(ud.try_write(), Error::UserDataBorrowMutError); let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::<Arc<parking_lot::RwLock<T>>>() => unsafe { + Some(id) if id == TypeId::of::<Arc<parking_lot::RwLock<T>>>() => { let ud = get_userdata_mut::<Arc<parking_lot::RwLock<T>>>(ref_thread, index); let ud = try_self_arg!(ud); let mut ud = try_self_arg!(ud.try_write().ok_or(Error::UserDataBorrowMutError)); let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); - let args = A::from_lua_args(args, 2, Some(&name), lua)?; - method(lua, ud, args).await?.into_lua_multi(lua) - }, + method(lua, ud, args?).await?.push_into_stack_multi(lua) + } _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } }) @@ -459,14 +429,13 @@ impl<'lua, T: 'static> UserDataRegistry<'lua, T> { R: IntoLuaMulti<'lua>, { let name = get_function_name::<T>(name); - Box::new(move |lua, args| { + Box::new(move |lua, args| unsafe { let args = match A::from_lua_args(args, 1, Some(&name), lua) { Ok(args) => args, Err(e) => return Box::pin(future::err(e)), }; - Box::pin( - function(lua, args).and_then(move |ret| future::ready(ret.into_lua_multi(lua))), - ) + let fut = function(lua, args); + Box::pin(async move { fut.await?.push_into_stack_multi(lua) }) }) } diff --git a/src/util/mod.rs b/src/util/mod.rs index d52119f..1351fbf 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -63,6 +63,11 @@ impl StackGuard { top: ffi::lua_gettop(state), } } + + // Same as `new()`, but allows specifying the expected stack size at the end of the scope. + pub const fn with_top(state: *mut ffi::lua_State, top: c_int) -> StackGuard { + StackGuard { state, top } + } } impl Drop for StackGuard { diff --git a/src/value.rs b/src/value.rs index 2caa98e..0931351 100644 --- a/src/value.rs +++ b/src/value.rs @@ -121,7 +121,7 @@ impl<'lua> Value<'lua> { Value::String(String(r)) | Value::Table(Table(r)) | Value::Function(Function(r)) - | Value::Thread(Thread(r)) + | Value::Thread(Thread(r, ..)) | Value::UserData(AnyUserData(r)) => r.to_pointer(), _ => ptr::null(), } @@ -143,7 +143,7 @@ impl<'lua> Value<'lua> { Value::String(s) => Ok(s.to_str()?.to_string()), Value::Table(Table(r)) | Value::Function(Function(r)) - | Value::Thread(Thread(r)) + | Value::Thread(Thread(r, ..)) | Value::UserData(AnyUserData(r)) => unsafe { let state = r.lua.state(); let _guard = StackGuard::new(state); @@ -387,7 +387,7 @@ pub struct MultiValue<'lua> { impl Drop for MultiValue<'_> { fn drop(&mut self) { if let Some(lua) = self.lua { - let vec = mem::replace(&mut self.vec, Vec::new()); + let vec = mem::take(&mut self.vec); lua.push_multivalue_to_pool(vec); } } @@ -439,7 +439,7 @@ impl<'lua> IntoIterator for MultiValue<'lua> { #[inline] fn into_iter(mut self) -> Self::IntoIter { - let vec = mem::replace(&mut self.vec, Vec::new()); + let vec = mem::take(&mut self.vec); mem::forget(self); vec.into_iter().rev() } @@ -481,7 +481,7 @@ impl<'lua> MultiValue<'lua> { #[inline] pub fn into_vec(mut self) -> Vec<Value<'lua>> { - let mut vec = mem::replace(&mut self.vec, Vec::new()); + let mut vec = mem::take(&mut self.vec); mem::forget(self); vec.reverse(); vec |