diff options
author | Alex Orlenko <zxteam@protonmail.com> | 2023-08-03 00:56:17 +0100 |
---|---|---|
committer | Alex Orlenko <zxteam@protonmail.com> | 2023-08-03 00:56:17 +0100 |
commit | cd0c8a4584401a68dc1141fe3b654eb647be27d0 (patch) | |
tree | 8045cda444dfdec6c898f563ae14a146960ea7d3 /src/thread.rs | |
parent | 4fff14a14467c5cd95f85d5e6980e808ab82cffd (diff) | |
download | mlua-cd0c8a4584401a68dc1141fe3b654eb647be27d0.zip |
Optimize async functionality:
Rewrite using the new `push_into_stack()`/`from_stack()` methods.
Also store thread state (pointer) in `Thread` struct to avoid getting it every time.
Async userdata methods still need to have arguments stored in ref thread as stack is empty on every poll().
Diffstat (limited to 'src/thread.rs')
-rw-r--r-- | src/thread.rs | 245 |
1 files changed, 126 insertions, 119 deletions
diff --git a/src/thread.rs b/src/thread.rs index f062957..cdce77c 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -1,4 +1,3 @@ -use std::cmp; use std::os::raw::c_int; use crate::error::{Error, Result}; @@ -16,10 +15,7 @@ use crate::{ #[cfg(feature = "async")] use { - crate::{ - lua::ASYNC_POLL_PENDING, - value::{MultiValue, Value}, - }, + crate::{lua::ASYNC_POLL_PENDING, value::MultiValue}, futures_util::stream::Stream, std::{ future::Future, @@ -47,7 +43,7 @@ pub enum ThreadStatus { /// Handle to an internal Lua thread (or coroutine). #[derive(Clone, Debug)] -pub struct Thread<'lua>(pub(crate) LuaRef<'lua>); +pub struct Thread<'lua>(pub(crate) LuaRef<'lua>, pub(crate) *mut ffi::lua_State); /// Thread (coroutine) representation as an async [`Future`] or [`Stream`]. /// @@ -66,6 +62,16 @@ pub struct AsyncThread<'lua, R> { } impl<'lua> Thread<'lua> { + #[inline(always)] + pub(crate) fn new(r#ref: LuaRef<'lua>) -> Self { + let state = unsafe { ffi::lua_tothread(r#ref.lua.ref_thread(), r#ref.index) }; + Thread(r#ref, state) + } + + const fn state(&self) -> *mut ffi::lua_State { + self.1 + } + /// Resumes execution of this thread. /// /// Equivalent to `coroutine.resume`. @@ -114,60 +120,59 @@ impl<'lua> Thread<'lua> { { let lua = self.0.lua; let state = lua.state(); - - let mut args = args.into_lua_multi(lua)?; - let nargs = args.len() as c_int; - let results = unsafe { + let thread_state = self.state(); + unsafe { let _sg = StackGuard::new(state); - check_stack(state, cmp::max(nargs + 1, 3))?; + let _thread_sg = StackGuard::with_top(thread_state, 0); - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); + let nresults = self.resume_inner(args)?; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - let status = ffi::lua_status(thread_state); - if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 { - return Err(Error::CoroutineInactive); - } + R::from_stack_multi(nresults, lua) + } + } + /// Resumes execution of this thread. + /// + /// It's similar to `resume()` but leaves `nresults` values on the thread stack. + unsafe fn resume_inner<A: IntoLuaMulti<'lua>>(&self, args: A) -> Result<c_int> { + let lua = self.0.lua; + let state = lua.state(); + let thread_state = self.state(); + + if self.status() != ThreadStatus::Resumable { + return Err(Error::CoroutineInactive); + } + + let nargs = args.push_into_stack_multi(lua)?; + if nargs > 0 { check_stack(thread_state, nargs)?; - for arg in args.drain_all() { - lua.push_value(arg)?; - } ffi::lua_xmove(state, thread_state, nargs); + } - let mut nresults = 0; - - let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); - if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { - if ret == ffi::LUA_ERRMEM { - // Don't call error handler for memory errors - return Err(pop_error(thread_state, ret)); - } - check_stack(state, 3)?; - protect_lua!(state, 0, 1, |state| error_traceback_thread( - state, - thread_state - ))?; - return Err(pop_error(state, ret)); + let mut nresults = 0; + let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); + if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { + if ret == ffi::LUA_ERRMEM { + // Don't call error handler for memory errors + return Err(pop_error(thread_state, ret)); } + check_stack(state, 3)?; + protect_lua!(state, 0, 1, |state| error_traceback_thread( + state, + thread_state + ))?; + return Err(pop_error(state, ret)); + } - let mut results = args; // Reuse MultiValue container - check_stack(state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below - ffi::lua_xmove(thread_state, state, nresults); - - for _ in 0..nresults { - results.push_front(lua.pop_value()); - } - results - }; - R::from_lua_multi(results, lua) + Ok(nresults) } /// Gets the status of the thread. pub fn status(&self) -> ThreadStatus { - let lua = self.0.lua; + let thread_state = self.state(); unsafe { - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); - let status = ffi::lua_status(thread_state); if status != ffi::LUA_OK && status != ffi::LUA_YIELD { ThreadStatus::Error @@ -191,8 +196,7 @@ impl<'lua> Thread<'lua> { { let lua = self.0.lua; unsafe { - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); - lua.set_thread_hook(thread_state, triggers, callback); + lua.set_thread_hook(self.state(), triggers, callback); } } @@ -214,18 +218,12 @@ impl<'lua> Thread<'lua> { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn reset(&self, func: crate::function::Function<'lua>) -> Result<()> { let lua = self.0.lua; - let state = lua.state(); + let thread_state = self.state(); unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - - lua.push_ref(&self.0); - let thread_state = ffi::lua_tothread(state, -1); - #[cfg(all(feature = "lua54", not(feature = "vendored")))] let status = ffi::lua_resetthread(thread_state); #[cfg(all(feature = "lua54", feature = "vendored"))] - let status = ffi::lua_closethread(thread_state, state); + let status = ffi::lua_closethread(thread_state, lua.state()); #[cfg(feature = "lua54")] if status != ffi::LUA_OK { return Err(pop_error(thread_state, status)); @@ -233,8 +231,8 @@ impl<'lua> Thread<'lua> { #[cfg(feature = "luau")] ffi::lua_resetthread(thread_state); - lua.push_ref(&func.0); - ffi::lua_xmove(state, thread_state, 1); + // Push function to the top of the thread stack + ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index); #[cfg(feature = "luau")] { @@ -345,11 +343,11 @@ impl<'lua> Thread<'lua> { pub fn sandbox(&self) -> Result<()> { let lua = self.0.lua; let state = lua.state(); + let thread_state = self.state(); unsafe { - let thread = ffi::lua_tothread(lua.ref_thread(), self.0.index); - check_stack(thread, 3)?; + check_stack(thread_state, 3)?; check_stack(state, 3)?; - protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread)) + protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state)) } } } @@ -379,11 +377,10 @@ impl<'lua, R> Drop for AsyncThread<'lua, R> { if !lua.recycle_thread(&mut self.thread) { #[cfg(feature = "lua54")] if self.thread.status() == ThreadStatus::Error { - let thread_state = ffi::lua_tothread(lua.ref_thread(), self.thread.0.index); #[cfg(not(feature = "vendored"))] - ffi::lua_resetthread(thread_state); + ffi::lua_resetthread(self.thread.state()); #[cfg(feature = "vendored")] - ffi::lua_closethread(thread_state, lua.state()); + ffi::lua_closethread(self.thread.state(), lua.state()); } } } @@ -399,29 +396,36 @@ where type Item = Result<R>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + if self.thread.status() != ThreadStatus::Resumable { + return Poll::Ready(None); + } + let lua = self.thread.0.lua; + let state = lua.state(); + let thread_state = self.thread.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::with_top(thread_state, 0); + let _wg = WakerGuard::new(lua, cx.waker()); - match self.thread.status() { - ThreadStatus::Resumable => {} - _ => return Poll::Ready(None), - }; + // This is safe as we are not moving the whole struct + let this = self.get_unchecked_mut(); + let nresults = if let Some(args) = this.init_args.take() { + this.thread.resume_inner(args?)? + } else { + this.thread.resume_inner(())? + }; - let _wg = WakerGuard::new(lua, cx.waker()); + if nresults == 1 && is_poll_pending(thread_state) { + return Poll::Pending; + } - // This is safe as we are not moving the whole struct - let this = unsafe { self.get_unchecked_mut() }; - let ret: MultiValue = if let Some(args) = this.init_args.take() { - this.thread.resume(args?)? - } else { - this.thread.resume(())? - }; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - if is_poll_pending(&ret) { - return Poll::Pending; + cx.waker().wake_by_ref(); + Poll::Ready(Some(R::from_stack_multi(nresults, lua))) } - - cx.waker().wake_by_ref(); - Poll::Ready(Some(R::from_lua_multi(ret, lua))) } } @@ -433,46 +437,53 @@ where type Output = Result<R>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + if self.thread.status() != ThreadStatus::Resumable { + return Poll::Ready(Err(Error::CoroutineInactive)); + } + let lua = self.thread.0.lua; + let state = lua.state(); + let thread_state = self.thread.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::with_top(thread_state, 0); + let _wg = WakerGuard::new(lua, cx.waker()); - match self.thread.status() { - ThreadStatus::Resumable => {} - _ => return Poll::Ready(Err(Error::CoroutineInactive)), - }; + // This is safe as we are not moving the whole struct + let this = self.get_unchecked_mut(); + let nresults = if let Some(args) = this.init_args.take() { + this.thread.resume_inner(args?)? + } else { + this.thread.resume_inner(())? + }; - let _wg = WakerGuard::new(lua, cx.waker()); + if nresults == 1 && is_poll_pending(thread_state) { + return Poll::Pending; + } - // This is safe as we are not moving the whole struct - let this = unsafe { self.get_unchecked_mut() }; - let ret: MultiValue = if let Some(args) = this.init_args.take() { - this.thread.resume(args?)? - } else { - this.thread.resume(())? - }; + if ffi::lua_status(thread_state) == ffi::LUA_YIELD { + // Ignore value returned via yield() + cx.waker().wake_by_ref(); + return Poll::Pending; + } - if is_poll_pending(&ret) { - return Poll::Pending; - } + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - if let ThreadStatus::Resumable = this.thread.status() { - // Ignore value returned via yield() - cx.waker().wake_by_ref(); - return Poll::Pending; + Poll::Ready(R::from_stack_multi(nresults, lua)) } - - Poll::Ready(R::from_lua_multi(ret, lua)) } } #[cfg(feature = "async")] #[inline(always)] -fn is_poll_pending(val: &MultiValue) -> bool { - match val.iter().enumerate().last() { - Some((0, Value::LightUserData(ud))) => { - std::ptr::eq(ud.0 as *const u8, &ASYNC_POLL_PENDING as *const u8) - } - _ => false, +unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool { + if ffi::lua_islightuserdata(state, -1) != 0 { + let stack_ptr = ffi::lua_touserdata(state, -1) as *const u8; + let pending_ptr = &ASYNC_POLL_PENDING as *const u8; + return std::ptr::eq(stack_ptr, pending_ptr); } + false } #[cfg(feature = "async")] @@ -486,23 +497,19 @@ struct WakerGuard<'lua, 'a> { impl<'lua, 'a> WakerGuard<'lua, 'a> { #[inline] pub fn new(lua: &'lua Lua, waker: &'a Waker) -> Result<WakerGuard<'lua, 'a>> { - unsafe { - let prev = lua.set_waker(NonNull::from(waker)); - Ok(WakerGuard { - lua, - prev, - _phantom: PhantomData, - }) - } + let prev = unsafe { lua.set_waker(NonNull::from(waker)) }; + Ok(WakerGuard { + lua, + prev, + _phantom: PhantomData, + }) } } #[cfg(feature = "async")] impl<'lua, 'a> Drop for WakerGuard<'lua, 'a> { fn drop(&mut self) { - unsafe { - self.lua.set_waker(self.prev); - } + unsafe { self.lua.set_waker(self.prev) }; } } |