summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2022-03-30 22:01:06 +0100
committerAlex Orlenko <zxteam@protonmail.com>2022-03-30 22:01:06 +0100
commit595bc3a2b3bd44323d7bfdb0afbadcdb9b854868 (patch)
tree43bd7d1efcf523702823df77dbbe4125488bd138
parent87c10ca93dec7189330bdb0aad2daed5a1cd78fe (diff)
downloadmlua-595bc3a2b3bd44323d7bfdb0afbadcdb9b854868.zip
Support Luau interrupts (closes #138)
-rw-r--r--src/lib.rs2
-rw-r--r--src/lua.rs148
-rw-r--r--src/prelude.rs4
-rw-r--r--src/types.rs14
-rw-r--r--tests/luau.rs74
5 files changed, 232 insertions, 10 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 0ea54ac..a591fd6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -126,7 +126,7 @@ pub use crate::hook::HookTriggers;
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
-pub use crate::chunk::Compiler;
+pub use crate::{chunk::Compiler, types::VmState};
#[cfg(feature = "async")]
pub use crate::thread::AsyncThread;
diff --git a/src/lua.rs b/src/lua.rs
index 9bfbebb..e82a4ca 100644
--- a/src/lua.rs
+++ b/src/lua.rs
@@ -47,6 +47,9 @@ use {
#[cfg(not(feature = "luau"))]
use crate::{hook::HookTriggers, types::HookCallback};
+#[cfg(feature = "luau")]
+use crate::types::{InterruptCallback, VmState};
+
#[cfg(feature = "async")]
use {
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
@@ -108,6 +111,8 @@ struct ExtraData {
hook_callback: Option<HookCallback>,
#[cfg(feature = "lua54")]
warn_callback: Option<WarnCallback>,
+ #[cfg(feature = "luau")]
+ interrupt_callback: Option<InterruptCallback>,
#[cfg(feature = "luau")]
sandboxed: bool,
@@ -235,6 +240,13 @@ impl Drop for Lua {
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
extra.ref_free.push(extra.ref_waker_idx);
}
+ #[cfg(feature = "luau")]
+ {
+ let callbacks = ffi::lua_callbacks(self.state);
+ let extra_ptr = (*callbacks).userdata as *mut Arc<UnsafeCell<ExtraData>>;
+ drop(Box::from_raw(extra_ptr));
+ (*callbacks).userdata = ptr::null_mut();
+ }
mlua_debug_assert!(
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
&& extra.ref_stack_top as usize == extra.ref_free.len(),
@@ -552,6 +564,8 @@ impl Lua {
#[cfg(feature = "lua54")]
warn_callback: None,
#[cfg(feature = "luau")]
+ interrupt_callback: None,
+ #[cfg(feature = "luau")]
sandboxed: false,
}));
@@ -581,6 +595,14 @@ impl Lua {
);
assert_stack(main_state, ffi::LUA_MINSTACK);
+ // Set Luau callbacks userdata to extra data
+ // We can use global callbacks userdata since we don't allow C modules in Luau
+ #[cfg(feature = "luau")]
+ {
+ let extra_raw = Box::into_raw(Box::new(Arc::clone(&extra)));
+ (*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void;
+ }
+
Lua {
state,
main_state: maybe_main_state,
@@ -895,6 +917,102 @@ impl Lua {
}
}
+ /// Sets an 'interrupt' function that will periodically be called by Luau VM.
+ ///
+ /// Any Luau code is guaranteed to call this handler "eventually"
+ /// (in practice this can happen at any function call or at any loop iteration).
+ ///
+ /// The provided interrupt function can error, and this error will be propagated through
+ /// the Luau code that was executing at the time the interrupt was triggered.
+ /// Also this can be used to implement continuous execution limits by instructing Luau VM to yield
+ /// by returning [`VmState::Yield`].
+ ///
+ /// This is similar to [`Lua::set_hook`] but in more simplified form.
+ ///
+ /// # Example
+ ///
+ /// Periodically yield Luau VM to suspend execution.
+ ///
+ /// ```
+ /// # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
+ /// # use mlua::{Lua, Result, ThreadStatus, VmState};
+ /// # fn main() -> Result<()> {
+ /// let lua = Lua::new();
+ /// let count = Arc::new(AtomicU64::new(0));
+ /// lua.set_interrupt(move |_lua| {
+ /// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 {
+ /// return Ok(VmState::Yield);
+ /// }
+ /// Ok(VmState::Continue)
+ /// });
+ ///
+ /// let co = lua.create_thread(
+ /// lua.load(r#"
+ /// local b = 0
+ /// for _, x in ipairs({1, 2, 3}) do b += x end
+ /// "#)
+ /// .into_function()?,
+ /// )?;
+ /// while co.status() == ThreadStatus::Resumable {
+ /// co.resume(())?;
+ /// }
+ /// # Ok(())
+ /// # }
+ /// ```
+ #[cfg(feature = "luau")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
+ pub fn set_interrupt<'lua, F>(&'lua self, callback: F)
+ where
+ F: 'static + MaybeSend + Fn(&Lua) -> Result<VmState>,
+ {
+ unsafe extern "C" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) {
+ if gc != -1 {
+ // We don't support GC interrupts since they cannot survive Lua exceptions
+ return;
+ }
+ // TODO: think about not using drop types here
+ let lua = match Lua::make_from_ptr(state) {
+ Some(lua) => lua,
+ None => return,
+ };
+ let extra = lua.extra.get();
+ let result = callback_error_ext(state, extra, move |_| {
+ let interrupt_cb = (*extra).interrupt_callback.clone();
+ let interrupt_cb =
+ mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
+ if Arc::strong_count(&interrupt_cb) > 2 {
+ return Ok(VmState::Continue); // Don't allow recursion
+ }
+ interrupt_cb(&lua)
+ });
+ match result {
+ VmState::Continue => {}
+ VmState::Yield => {
+ ffi::lua_yield(state, 0);
+ }
+ }
+ }
+
+ let state = mlua_expect!(self.main_state, "Luau should always has main state");
+ unsafe {
+ (*self.extra.get()).interrupt_callback = Some(Arc::new(callback));
+ (*ffi::lua_callbacks(state)).interrupt = Some(interrupt_proc);
+ }
+ }
+
+ /// Removes any 'interrupt' previously set by `set_interrupt`.
+ ///
+ /// This function has no effect if an 'interrupt' was not previously set.
+ #[cfg(feature = "luau")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
+ pub fn remove_interrupt(&self) {
+ let state = mlua_expect!(self.main_state, "Luau should always has main state");
+ unsafe {
+ (*self.extra.get()).interrupt_callback = None;
+ (*ffi::lua_callbacks(state)).interrupt = None;
+ }
+ }
+
/// Sets the warning function to be used by Lua to emit warnings.
///
/// Requires `feature = "lua54"`
@@ -2759,14 +2877,7 @@ impl Lua {
let _sg = StackGuard::new(state);
assert_stack(state, 1);
- let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void;
- if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA {
- return None;
- }
- let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc<UnsafeCell<ExtraData>>;
- let extra = Arc::clone(&*extra_ptr);
- ffi::lua_pop(state, 1);
-
+ let extra = extra_data(state)?;
let safe = (*extra.get()).safe;
Some(Lua {
state,
@@ -2798,6 +2909,27 @@ impl Lua {
}
}
+#[cfg(feature = "luau")]
+unsafe fn extra_data(state: *mut ffi::lua_State) -> Option<Arc<UnsafeCell<ExtraData>>> {
+ let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc<UnsafeCell<ExtraData>>;
+ if extra_ptr.is_null() {
+ return None;
+ }
+ Some(Arc::clone(&*extra_ptr))
+}
+
+#[cfg(not(feature = "luau"))]
+unsafe fn extra_data(state: *mut ffi::lua_State) -> Option<Arc<UnsafeCell<ExtraData>>> {
+ let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void;
+ if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA {
+ return None;
+ }
+ let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc<UnsafeCell<ExtraData>>;
+ let extra = Arc::clone(&*extra_ptr);
+ ffi::lua_pop(state, 1);
+ Some(extra)
+}
+
// Creates required entries in the metatable cache (see `util::METATABLE_CACHE`)
pub(crate) fn init_metatable_cache(cache: &mut FxHashMap<TypeId, u8>) {
cache.insert(TypeId::of::<Arc<UnsafeCell<ExtraData>>>(), 0);
diff --git a/src/prelude.rs b/src/prelude.rs
index 1abc7e7..28a90e5 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -15,6 +15,10 @@ pub use crate::{
Value as LuaValue,
};
+#[cfg(feature = "luau")]
+#[doc(no_inline)]
+pub use crate::VmState as LuaVmState;
+
#[cfg(feature = "async")]
#[doc(no_inline)]
pub use crate::AsyncThread as LuaAsyncThread;
diff --git a/src/types.rs b/src/types.rs
index ed7305b..6ff7a21 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -55,6 +55,20 @@ pub(crate) type HookCallback = Arc<Mutex<dyn FnMut(&Lua, Debug) -> Result<()> +
#[cfg(all(not(feature = "send"), not(feature = "luau")))]
pub(crate) type HookCallback = Arc<Mutex<dyn FnMut(&Lua, Debug) -> Result<()>>>;
+/// Type to set next Lua VM action after executing interrupt function.
+#[cfg(any(feature = "luau", doc))]
+#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
+pub enum VmState {
+ Continue,
+ Yield,
+}
+
+#[cfg(all(feature = "luau", feature = "send"))]
+pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState> + Send>;
+
+#[cfg(all(feature = "luau", not(feature = "send")))]
+pub(crate) type InterruptCallback = Arc<dyn Fn(&Lua) -> Result<VmState>>;
+
#[cfg(all(feature = "send", feature = "lua54"))]
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &CStr, bool) -> Result<()> + Send>;
diff --git a/tests/luau.rs b/tests/luau.rs
index ac6c290..608ea98 100644
--- a/tests/luau.rs
+++ b/tests/luau.rs
@@ -2,8 +2,10 @@
use std::env;
use std::fs;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Arc;
-use mlua::{Error, Lua, Result, Table, Value};
+use mlua::{Error, Lua, Result, Table, ThreadStatus, Value, VmState};
#[test]
fn test_require() -> Result<()> {
@@ -125,3 +127,73 @@ fn test_sandbox_threads() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_interrupts() -> Result<()> {
+ let lua = Lua::new();
+
+ let interrupts_count = Arc::new(AtomicU64::new(0));
+ let interrupts_count2 = interrupts_count.clone();
+
+ lua.set_interrupt(move |_lua| {
+ interrupts_count2.fetch_add(1, Ordering::Relaxed);
+ Ok(VmState::Continue)
+ });
+ let f = lua
+ .load(
+ r#"
+ local x = 2 + 3
+ local y = x * 63
+ local z = string.len(x..", "..y)
+ "#,
+ )
+ .into_function()?;
+ f.call(())?;
+
+ assert!(interrupts_count.load(Ordering::Relaxed) > 0);
+
+ //
+ // Test yields from interrupt
+ //
+ let yield_count = Arc::new(AtomicU64::new(0));
+ let yield_count2 = yield_count.clone();
+ lua.set_interrupt(move |_lua| {
+ if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 {
+ return Ok(VmState::Yield);
+ }
+ Ok(VmState::Continue)
+ });
+ let co = lua.create_thread(
+ lua.load(
+ r#"
+ local a = {1, 2, 3}
+ local b = 0
+ for _, x in ipairs(a) do b += x end
+ return b
+ "#,
+ )
+ .into_function()?,
+ )?;
+ co.resume(())?;
+ assert_eq!(co.status(), ThreadStatus::Resumable);
+ let result: i32 = co.resume(())?;
+ assert_eq!(result, 6);
+ assert_eq!(yield_count.load(Ordering::Relaxed), 7);
+ assert_eq!(co.status(), ThreadStatus::Unresumable);
+
+ //
+ // Test errors in interrupts
+ //
+ lua.set_interrupt(|_| Err(Error::RuntimeError("error from interrupt".into())));
+ match f.call::<_, ()>(()) {
+ Err(Error::CallbackError { cause, .. }) => match *cause {
+ Error::RuntimeError(ref m) if m == "error from interrupt" => {}
+ ref e => panic!("expected RuntimeError with a specific message, got {:?}", e),
+ },
+ r => panic!("expected CallbackError, got {:?}", r),
+ }
+
+ lua.remove_interrupt();
+
+ Ok(())
+}