diff options
Diffstat (limited to 'src/thread.rs')
-rw-r--r-- | src/thread.rs | 73 |
1 files changed, 44 insertions, 29 deletions
diff --git a/src/thread.rs b/src/thread.rs index 811c056..4e5bd9e 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -22,7 +22,6 @@ use { }, futures_core::{future::Future, stream::Stream}, std::{ - cell::RefCell, marker::PhantomData, pin::Pin, task::{Context, Poll, Waker}, @@ -67,10 +66,9 @@ impl OwnedThread { /// [`Stream`]: futures_core::stream::Stream #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] -#[derive(Debug)] pub struct AsyncThread<'lua, R> { thread: Thread<'lua>, - args0: RefCell<Option<Result<MultiValue<'lua>>>>, + args0: Option<Result<MultiValue<'lua>>>, ret: PhantomData<R>, recycle: bool, } @@ -123,11 +121,13 @@ impl<'lua> Thread<'lua> { R: FromLuaMulti<'lua>, { let lua = self.0.lua; + let state = lua.state(); + let mut args = args.to_lua_multi(lua)?; let nargs = args.len() as c_int; let results = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, cmp::max(nargs + 1, 3))?; + let _sg = StackGuard::new(state); + check_stack(state, cmp::max(nargs + 1, 3))?; let thread_state = ffi::lua_tothread(lua.ref_thread(), self.0.index); @@ -140,23 +140,23 @@ impl<'lua> Thread<'lua> { for arg in args.drain_all() { lua.push_value(arg)?; } - ffi::lua_xmove(lua.state, thread_state, nargs); + ffi::lua_xmove(state, thread_state, nargs); let mut nresults = 0; - let ret = ffi::lua_resume(thread_state, lua.state, nargs, &mut nresults as *mut c_int); + let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { - check_stack(lua.state, 3)?; - protect_lua!(lua.state, 0, 1, |state| error_traceback_thread( + check_stack(state, 3)?; + protect_lua!(state, 0, 1, |state| error_traceback_thread( state, thread_state ))?; - return Err(pop_error(lua.state, ret)); + return Err(pop_error(state, ret)); } let mut results = args; // Reuse MultiValue container - check_stack(lua.state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below - ffi::lua_xmove(thread_state, lua.state, nresults); + check_stack(state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below + ffi::lua_xmove(thread_state, state, nresults); for _ in 0..nresults { results.push_front(lua.pop_value()); @@ -205,12 +205,13 @@ impl<'lua> Thread<'lua> { ))] pub fn reset(&self, func: Function<'lua>) -> Result<()> { let lua = self.0.lua; + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; + let _sg = StackGuard::new(state); + check_stack(state, 2)?; lua.push_ref(&self.0); - let thread_state = ffi::lua_tothread(lua.state, -1); + let thread_state = ffi::lua_tothread(state, -1); #[cfg(feature = "lua54")] let status = ffi::lua_resetthread(thread_state); @@ -219,17 +220,17 @@ impl<'lua> Thread<'lua> { return Err(pop_error(thread_state, status)); } #[cfg(all(feature = "luajit", feature = "vendored"))] - ffi::lua_resetthread(lua.state, thread_state); + ffi::lua_resetthread(state, thread_state); #[cfg(feature = "luau")] ffi::lua_resetthread(thread_state); lua.push_ref(&func.0); - ffi::lua_xmove(lua.state, thread_state, 1); + ffi::lua_xmove(state, thread_state, 1); #[cfg(feature = "luau")] { // Inherit `LUA_GLOBALSINDEX` from the caller - ffi::lua_xpush(lua.state, thread_state, ffi::LUA_GLOBALSINDEX); + ffi::lua_xpush(state, thread_state, ffi::LUA_GLOBALSINDEX); ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); } @@ -292,7 +293,7 @@ impl<'lua> Thread<'lua> { let args = args.to_lua_multi(self.0.lua); AsyncThread { thread: self, - args0: RefCell::new(Some(args)), + args0: Some(args), ret: PhantomData, recycle: false, } @@ -334,14 +335,15 @@ impl<'lua> Thread<'lua> { #[doc(hidden)] pub fn sandbox(&self) -> Result<()> { let lua = self.0.lua; + let state = lua.state(); unsafe { let thread = ffi::lua_tothread(lua.ref_thread(), self.0.index); check_stack(thread, 1)?; - check_stack(lua.state, 3)?; + check_stack(state, 3)?; // Inherit `LUA_GLOBALSINDEX` from the caller - ffi::lua_xpush(lua.state, thread, ffi::LUA_GLOBALSINDEX); + ffi::lua_xpush(state, thread, ffi::LUA_GLOBALSINDEX); ffi::lua_replace(thread, ffi::LUA_GLOBALSINDEX); - protect_lua!(lua.state, 0, 0, |_| ffi::luaL_sandboxthread(thread)) + protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread)) } } @@ -406,10 +408,13 @@ where }; let _wg = WakerGuard::new(lua, cx.waker().clone()); - let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() { - self.thread.resume(args?)? + + // This is safe as we are not moving the whole struct + let this = unsafe { self.get_unchecked_mut() }; + let ret: MultiValue = if let Some(args) = this.args0.take() { + this.thread.resume(args?)? } else { - self.thread.resume(())? + this.thread.resume(())? }; if is_poll_pending(&ret) { @@ -437,17 +442,20 @@ where }; let _wg = WakerGuard::new(lua, cx.waker().clone()); - let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() { - self.thread.resume(args?)? + + // This is safe as we are not moving the whole struct + let this = unsafe { self.get_unchecked_mut() }; + let ret: MultiValue = if let Some(args) = this.args0.take() { + this.thread.resume(args?)? } else { - self.thread.resume(())? + this.thread.resume(())? }; if is_poll_pending(&ret) { return Poll::Pending; } - if let ThreadStatus::Resumable = self.thread.status() { + if let ThreadStatus::Resumable = this.thread.status() { // Ignore value returned via yield() cx.waker().wake_by_ref(); return Poll::Pending; @@ -493,3 +501,10 @@ impl<'lua> Drop for WakerGuard<'lua> { } } } + +#[cfg(test)] +mod assertions { + use super::*; + + static_assertions::assert_not_impl_any!(Thread: Send); +} |