summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2023-03-25 16:30:31 +0000
committerAlex Orlenko <zxteam@protonmail.com>2023-03-25 16:30:31 +0000
commit742307a2676773e2908109d7c6f9e603593aa406 (patch)
tree97af42da0bf2c8532fbe009d047dcfbd56a50071
parent781ded573a8d8ecd36ab8c277d53f654b233e0ce (diff)
downloadmlua-742307a2676773e2908109d7c6f9e603593aa406.zip
Add &Lua to luau interrupt callback (fixes #197)
-rw-r--r--src/lua.rs7
-rw-r--r--src/types.rs4
-rw-r--r--tests/luau.rs6
3 files changed, 9 insertions, 8 deletions
diff --git a/src/lua.rs b/src/lua.rs
index 48d40b4..212bb7a 100644
--- a/src/lua.rs
+++ b/src/lua.rs
@@ -971,7 +971,7 @@ impl Lua {
/// # fn main() -> Result<()> {
/// let lua = Lua::new();
/// let count = Arc::new(AtomicU64::new(0));
- /// lua.set_interrupt(move || {
+ /// lua.set_interrupt(move |_| {
/// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 {
/// return Ok(VmState::Yield);
/// }
@@ -995,7 +995,7 @@ impl Lua {
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub fn set_interrupt<F>(&self, callback: F)
where
- F: 'static + MaybeSend + Fn() -> Result<VmState>,
+ F: Fn(&Lua) -> Result<VmState> + MaybeSend + 'static,
{
unsafe extern "C" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) {
if gc >= 0 {
@@ -1013,7 +1013,8 @@ impl Lua {
if Arc::strong_count(&interrupt_cb) > 2 {
return Ok(VmState::Continue); // Don't allow recursion
}
- interrupt_cb()
+ let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
+ interrupt_cb(lua)
});
match result {
VmState::Continue => {}
diff --git a/src/types.rs b/src/types.rs
index 1e81ea0..a0bf8e6 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -63,10 +63,10 @@ pub(crate) type HookCallback = Arc<dyn Fn(&Lua, Debug) -> Result<()> + Send>;
pub(crate) type HookCallback = Arc<dyn Fn(&Lua, Debug) -> Result<()>>;
#[cfg(all(feature = "luau", feature = "send"))]
-pub(crate) type InterruptCallback = Arc<dyn Fn() -> Result<VmState> + Send>;
+pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState> + Send>;
#[cfg(all(feature = "luau", not(feature = "send")))]
-pub(crate) type InterruptCallback = Arc<dyn Fn() -> Result<VmState>>;
+pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState>>;
#[cfg(all(feature = "send", feature = "lua54"))]
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &CStr, bool) -> Result<()> + Send>;
diff --git a/tests/luau.rs b/tests/luau.rs
index bb8bff0..05390c7 100644
--- a/tests/luau.rs
+++ b/tests/luau.rs
@@ -173,7 +173,7 @@ fn test_interrupts() -> Result<()> {
let interrupts_count = Arc::new(AtomicU64::new(0));
let interrupts_count2 = interrupts_count.clone();
- lua.set_interrupt(move || {
+ lua.set_interrupt(move |_| {
interrupts_count2.fetch_add(1, Ordering::Relaxed);
Ok(VmState::Continue)
});
@@ -195,7 +195,7 @@ fn test_interrupts() -> Result<()> {
//
let yield_count = Arc::new(AtomicU64::new(0));
let yield_count2 = yield_count.clone();
- lua.set_interrupt(move || {
+ lua.set_interrupt(move |_| {
if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 {
return Ok(VmState::Yield);
}
@@ -222,7 +222,7 @@ fn test_interrupts() -> Result<()> {
//
// Test errors in interrupts
//
- lua.set_interrupt(|| Err(Error::RuntimeError("error from interrupt".into())));
+ lua.set_interrupt(|_| Err(Error::RuntimeError("error from interrupt".into())));
match f.call::<_, ()>(()) {
Err(Error::CallbackError { cause, .. }) => match *cause {
Error::RuntimeError(ref m) if m == "error from interrupt" => {}