summaryrefslogtreecommitdiff
path: root/src/thread.rs
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2023-08-03 00:56:17 +0100
committerAlex Orlenko <zxteam@protonmail.com>2023-08-03 00:56:17 +0100
commitcd0c8a4584401a68dc1141fe3b654eb647be27d0 (patch)
tree8045cda444dfdec6c898f563ae14a146960ea7d3 /src/thread.rs
parent4fff14a14467c5cd95f85d5e6980e808ab82cffd (diff)
downloadmlua-cd0c8a4584401a68dc1141fe3b654eb647be27d0.zip
Optimize async functionality:
Rewrite using the new `push_into_stack()`/`from_stack()` methods. Also store thread state (pointer) in `Thread` struct to avoid getting it every time. Async userdata methods still need to have arguments stored in ref thread as stack is empty on every poll().
Diffstat (limited to 'src/thread.rs')
-rw-r--r--src/thread.rs245
1 files changed, 126 insertions, 119 deletions
diff --git a/src/thread.rs b/src/thread.rs
index f062957..cdce77c 100644
--- a/src/thread.rs
+++ b/src/thread.rs
@@ -1,4 +1,3 @@
-use std::cmp;
use std::os::raw::c_int;
use crate::error::{Error, Result};
@@ -16,10 +15,7 @@ use crate::{
#[cfg(feature = "async")]
use {
- crate::{
- lua::ASYNC_POLL_PENDING,
- value::{MultiValue, Value},
- },
+ crate::{lua::ASYNC_POLL_PENDING, value::MultiValue},
futures_util::stream::Stream,
std::{
future::Future,
@@ -47,7 +43,7 @@ pub enum ThreadStatus {
/// Handle to an internal Lua thread (or coroutine).
#[derive(Clone, Debug)]
-pub struct Thread<'lua>(pub(crate) LuaRef<'lua>);
+pub struct Thread<'lua>(pub(crate) LuaRef<'lua>, pub(crate) *mut ffi::lua_State);
/// Thread (coroutine) representation as an async [`Future`] or [`Stream`].
///
@@ -66,6 +62,16 @@ pub struct AsyncThread<'lua, R> {
}
impl<'lua> Thread<'lua> {
+ #[inline(always)]
+ pub(crate) fn new(r#ref: LuaRef<'lua>) -> Self {
+ let state = unsafe { ffi::lua_tothread(r#ref.lua.ref_thread(), r#ref.index) };
+ Thread(r#ref, state)
+ }
+
+ const fn state(&self) -> *mut ffi::lua_State {
+ self.1
+ }
+
/// Resumes execution of this thread.
///
/// Equivalent to `coroutine.resume`.
@@ -114,60 +120,59 @@ impl<'lua> Thread<'lua> {
{
let lua = self.0.lua;
let state = lua.state();
-
- let mut args = args.into_lua_multi(lua)?;
- let nargs = args.len() as c_int;
- let results = unsafe {
+ let thread_state = self.state();
+ unsafe {
let _sg = StackGuard::new(state);
- check_stack(state, cmp::max(nargs + 1, 3))?;
+ let _thread_sg = StackGuard::with_top(thread_state, 0);
- let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index);
+ let nresults = self.resume_inner(args)?;
+ check_stack(state, nresults + 1)?;
+ ffi::lua_xmove(thread_state, state, nresults);
- let status = ffi::lua_status(thread_state);
- if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
- return Err(Error::CoroutineInactive);
- }
+ R::from_stack_multi(nresults, lua)
+ }
+ }
+ /// Resumes execution of this thread.
+ ///
+ /// It's similar to `resume()` but leaves `nresults` values on the thread stack.
+ unsafe fn resume_inner<A: IntoLuaMulti<'lua>>(&self, args: A) -> Result<c_int> {
+ let lua = self.0.lua;
+ let state = lua.state();
+ let thread_state = self.state();
+
+ if self.status() != ThreadStatus::Resumable {
+ return Err(Error::CoroutineInactive);
+ }
+
+ let nargs = args.push_into_stack_multi(lua)?;
+ if nargs > 0 {
check_stack(thread_state, nargs)?;
- for arg in args.drain_all() {
- lua.push_value(arg)?;
- }
ffi::lua_xmove(state, thread_state, nargs);
+ }
- let mut nresults = 0;
-
- let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
- if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
- if ret == ffi::LUA_ERRMEM {
- // Don't call error handler for memory errors
- return Err(pop_error(thread_state, ret));
- }
- check_stack(state, 3)?;
- protect_lua!(state, 0, 1, |state| error_traceback_thread(
- state,
- thread_state
- ))?;
- return Err(pop_error(state, ret));
+ let mut nresults = 0;
+ let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
+ if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
+ if ret == ffi::LUA_ERRMEM {
+ // Don't call error handler for memory errors
+ return Err(pop_error(thread_state, ret));
}
+ check_stack(state, 3)?;
+ protect_lua!(state, 0, 1, |state| error_traceback_thread(
+ state,
+ thread_state
+ ))?;
+ return Err(pop_error(state, ret));
+ }
- let mut results = args; // Reuse MultiValue container
- check_stack(state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below
- ffi::lua_xmove(thread_state, state, nresults);
-
- for _ in 0..nresults {
- results.push_front(lua.pop_value());
- }
- results
- };
- R::from_lua_multi(results, lua)
+ Ok(nresults)
}
/// Gets the status of the thread.
pub fn status(&self) -> ThreadStatus {
- let lua = self.0.lua;
+ let thread_state = self.state();
unsafe {
- let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index);
-
let status = ffi::lua_status(thread_state);
if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
ThreadStatus::Error
@@ -191,8 +196,7 @@ impl<'lua> Thread<'lua> {
{
let lua = self.0.lua;
unsafe {
- let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index);
- lua.set_thread_hook(thread_state, triggers, callback);
+ lua.set_thread_hook(self.state(), triggers, callback);
}
}
@@ -214,18 +218,12 @@ impl<'lua> Thread<'lua> {
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn reset(&self, func: crate::function::Function<'lua>) -> Result<()> {
let lua = self.0.lua;
- let state = lua.state();
+ let thread_state = self.state();
unsafe {
- let _sg = StackGuard::new(state);
- check_stack(state, 2)?;
-
- lua.push_ref(&self.0);
- let thread_state = ffi::lua_tothread(state, -1);
-
#[cfg(all(feature = "lua54", not(feature = "vendored")))]
let status = ffi::lua_resetthread(thread_state);
#[cfg(all(feature = "lua54", feature = "vendored"))]
- let status = ffi::lua_closethread(thread_state, state);
+ let status = ffi::lua_closethread(thread_state, lua.state());
#[cfg(feature = "lua54")]
if status != ffi::LUA_OK {
return Err(pop_error(thread_state, status));
@@ -233,8 +231,8 @@ impl<'lua> Thread<'lua> {
#[cfg(feature = "luau")]
ffi::lua_resetthread(thread_state);
- lua.push_ref(&func.0);
- ffi::lua_xmove(state, thread_state, 1);
+ // Push function to the top of the thread stack
+ ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index);
#[cfg(feature = "luau")]
{
@@ -345,11 +343,11 @@ impl<'lua> Thread<'lua> {
pub fn sandbox(&self) -> Result<()> {
let lua = self.0.lua;
let state = lua.state();
+ let thread_state = self.state();
unsafe {
- let thread = ffi::lua_tothread(lua.ref_thread(), self.0.index);
- check_stack(thread, 3)?;
+ check_stack(thread_state, 3)?;
check_stack(state, 3)?;
- protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread))
+ protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state))
}
}
}
@@ -379,11 +377,10 @@ impl<'lua, R> Drop for AsyncThread<'lua, R> {
if !lua.recycle_thread(&mut self.thread) {
#[cfg(feature = "lua54")]
if self.thread.status() == ThreadStatus::Error {
- let thread_state = ffi::lua_tothread(lua.ref_thread(), self.thread.0.index);
#[cfg(not(feature = "vendored"))]
- ffi::lua_resetthread(thread_state);
+ ffi::lua_resetthread(self.thread.state());
#[cfg(feature = "vendored")]
- ffi::lua_closethread(thread_state, lua.state());
+ ffi::lua_closethread(self.thread.state(), lua.state());
}
}
}
@@ -399,29 +396,36 @@ where
type Item = Result<R>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ if self.thread.status() != ThreadStatus::Resumable {
+ return Poll::Ready(None);
+ }
+
let lua = self.thread.0.lua;
+ let state = lua.state();
+ let thread_state = self.thread.state();
+ unsafe {
+ let _sg = StackGuard::new(state);
+ let _thread_sg = StackGuard::with_top(thread_state, 0);
+ let _wg = WakerGuard::new(lua, cx.waker());
- match self.thread.status() {
- ThreadStatus::Resumable => {}
- _ => return Poll::Ready(None),
- };
+ // This is safe as we are not moving the whole struct
+ let this = self.get_unchecked_mut();
+ let nresults = if let Some(args) = this.init_args.take() {
+ this.thread.resume_inner(args?)?
+ } else {
+ this.thread.resume_inner(())?
+ };
- let _wg = WakerGuard::new(lua, cx.waker());
+ if nresults == 1 && is_poll_pending(thread_state) {
+ return Poll::Pending;
+ }
- // This is safe as we are not moving the whole struct
- let this = unsafe { self.get_unchecked_mut() };
- let ret: MultiValue = if let Some(args) = this.init_args.take() {
- this.thread.resume(args?)?
- } else {
- this.thread.resume(())?
- };
+ check_stack(state, nresults + 1)?;
+ ffi::lua_xmove(thread_state, state, nresults);
- if is_poll_pending(&ret) {
- return Poll::Pending;
+ cx.waker().wake_by_ref();
+ Poll::Ready(Some(R::from_stack_multi(nresults, lua)))
}
-
- cx.waker().wake_by_ref();
- Poll::Ready(Some(R::from_lua_multi(ret, lua)))
}
}
@@ -433,46 +437,53 @@ where
type Output = Result<R>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ if self.thread.status() != ThreadStatus::Resumable {
+ return Poll::Ready(Err(Error::CoroutineInactive));
+ }
+
let lua = self.thread.0.lua;
+ let state = lua.state();
+ let thread_state = self.thread.state();
+ unsafe {
+ let _sg = StackGuard::new(state);
+ let _thread_sg = StackGuard::with_top(thread_state, 0);
+ let _wg = WakerGuard::new(lua, cx.waker());
- match self.thread.status() {
- ThreadStatus::Resumable => {}
- _ => return Poll::Ready(Err(Error::CoroutineInactive)),
- };
+ // This is safe as we are not moving the whole struct
+ let this = self.get_unchecked_mut();
+ let nresults = if let Some(args) = this.init_args.take() {
+ this.thread.resume_inner(args?)?
+ } else {
+ this.thread.resume_inner(())?
+ };
- let _wg = WakerGuard::new(lua, cx.waker());
+ if nresults == 1 && is_poll_pending(thread_state) {
+ return Poll::Pending;
+ }
- // This is safe as we are not moving the whole struct
- let this = unsafe { self.get_unchecked_mut() };
- let ret: MultiValue = if let Some(args) = this.init_args.take() {
- this.thread.resume(args?)?
- } else {
- this.thread.resume(())?
- };
+ if ffi::lua_status(thread_state) == ffi::LUA_YIELD {
+ // Ignore value returned via yield()
+ cx.waker().wake_by_ref();
+ return Poll::Pending;
+ }
- if is_poll_pending(&ret) {
- return Poll::Pending;
- }
+ check_stack(state, nresults + 1)?;
+ ffi::lua_xmove(thread_state, state, nresults);
- if let ThreadStatus::Resumable = this.thread.status() {
- // Ignore value returned via yield()
- cx.waker().wake_by_ref();
- return Poll::Pending;
+ Poll::Ready(R::from_stack_multi(nresults, lua))
}
-
- Poll::Ready(R::from_lua_multi(ret, lua))
}
}
#[cfg(feature = "async")]
#[inline(always)]
-fn is_poll_pending(val: &MultiValue) -> bool {
- match val.iter().enumerate().last() {
- Some((0, Value::LightUserData(ud))) => {
- std::ptr::eq(ud.0 as *const u8, &ASYNC_POLL_PENDING as *const u8)
- }
- _ => false,
+unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool {
+ if ffi::lua_islightuserdata(state, -1) != 0 {
+ let stack_ptr = ffi::lua_touserdata(state, -1) as *const u8;
+ let pending_ptr = &ASYNC_POLL_PENDING as *const u8;
+ return std::ptr::eq(stack_ptr, pending_ptr);
}
+ false
}
#[cfg(feature = "async")]
@@ -486,23 +497,19 @@ struct WakerGuard<'lua, 'a> {
impl<'lua, 'a> WakerGuard<'lua, 'a> {
#[inline]
pub fn new(lua: &'lua Lua, waker: &'a Waker) -> Result<WakerGuard<'lua, 'a>> {
- unsafe {
- let prev = lua.set_waker(NonNull::from(waker));
- Ok(WakerGuard {
- lua,
- prev,
- _phantom: PhantomData,
- })
- }
+ let prev = unsafe { lua.set_waker(NonNull::from(waker)) };
+ Ok(WakerGuard {
+ lua,
+ prev,
+ _phantom: PhantomData,
+ })
}
}
#[cfg(feature = "async")]
impl<'lua, 'a> Drop for WakerGuard<'lua, 'a> {
fn drop(&mut self) {
- unsafe {
- self.lua.set_waker(self.prev);
- }
+ unsafe { self.lua.set_waker(self.prev) };
}
}