diff options
author | Alex Orlenko <zxteam@protonmail.com> | 2023-03-25 16:30:31 +0000 |
---|---|---|
committer | Alex Orlenko <zxteam@protonmail.com> | 2023-03-25 16:30:31 +0000 |
commit | 742307a2676773e2908109d7c6f9e603593aa406 (patch) | |
tree | 97af42da0bf2c8532fbe009d047dcfbd56a50071 | |
parent | 781ded573a8d8ecd36ab8c277d53f654b233e0ce (diff) | |
download | mlua-742307a2676773e2908109d7c6f9e603593aa406.zip |
Add &Lua to luau interrupt callback (fixes #197)
-rw-r--r-- | src/lua.rs | 7 | ||||
-rw-r--r-- | src/types.rs | 4 | ||||
-rw-r--r-- | tests/luau.rs | 6 |
3 files changed, 9 insertions, 8 deletions
@@ -971,7 +971,7 @@ impl Lua { /// # fn main() -> Result<()> { /// let lua = Lua::new(); /// let count = Arc::new(AtomicU64::new(0)); - /// lua.set_interrupt(move || { + /// lua.set_interrupt(move |_| { /// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 { /// return Ok(VmState::Yield); /// } @@ -995,7 +995,7 @@ impl Lua { #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_interrupt<F>(&self, callback: F) where - F: 'static + MaybeSend + Fn() -> Result<VmState>, + F: Fn(&Lua) -> Result<VmState> + MaybeSend + 'static, { unsafe extern "C" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) { if gc >= 0 { @@ -1013,7 +1013,8 @@ impl Lua { if Arc::strong_count(&interrupt_cb) > 2 { return Ok(VmState::Continue); // Don't allow recursion } - interrupt_cb() + let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); + interrupt_cb(lua) }); match result { VmState::Continue => {} diff --git a/src/types.rs b/src/types.rs index 1e81ea0..a0bf8e6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -63,10 +63,10 @@ pub(crate) type HookCallback = Arc<dyn Fn(&Lua, Debug) -> Result<()> + Send>; pub(crate) type HookCallback = Arc<dyn Fn(&Lua, Debug) -> Result<()>>; #[cfg(all(feature = "luau", feature = "send"))] -pub(crate) type InterruptCallback = Arc<dyn Fn() -> Result<VmState> + Send>; +pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState> + Send>; #[cfg(all(feature = "luau", not(feature = "send")))] -pub(crate) type InterruptCallback = Arc<dyn Fn() -> Result<VmState>>; +pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState>>; #[cfg(all(feature = "send", feature = "lua54"))] pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &CStr, bool) -> Result<()> + Send>; diff --git a/tests/luau.rs b/tests/luau.rs index bb8bff0..05390c7 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -173,7 +173,7 @@ fn test_interrupts() -> Result<()> { let interrupts_count = Arc::new(AtomicU64::new(0)); let interrupts_count2 = interrupts_count.clone(); - lua.set_interrupt(move || { + lua.set_interrupt(move |_| { interrupts_count2.fetch_add(1, Ordering::Relaxed); Ok(VmState::Continue) }); @@ -195,7 +195,7 @@ fn test_interrupts() -> Result<()> { // let yield_count = Arc::new(AtomicU64::new(0)); let yield_count2 = yield_count.clone(); - lua.set_interrupt(move || { + lua.set_interrupt(move |_| { if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 { return Ok(VmState::Yield); } @@ -222,7 +222,7 @@ fn test_interrupts() -> Result<()> { // // Test errors in interrupts // - lua.set_interrupt(|| Err(Error::RuntimeError("error from interrupt".into()))); + lua.set_interrupt(|_| Err(Error::RuntimeError("error from interrupt".into()))); match f.call::<_, ()>(()) { Err(Error::CallbackError { cause, .. }) => match *cause { Error::RuntimeError(ref m) if m == "error from interrupt" => {} |