summaryrefslogtreecommitdiff
path: root/tests/luau.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tests/luau.rs')
-rw-r--r--tests/luau.rs74
1 files changed, 73 insertions, 1 deletions
diff --git a/tests/luau.rs b/tests/luau.rs
index ac6c290..608ea98 100644
--- a/tests/luau.rs
+++ b/tests/luau.rs
@@ -2,8 +2,10 @@
use std::env;
use std::fs;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Arc;
-use mlua::{Error, Lua, Result, Table, Value};
+use mlua::{Error, Lua, Result, Table, ThreadStatus, Value, VmState};
#[test]
fn test_require() -> Result<()> {
@@ -125,3 +127,73 @@ fn test_sandbox_threads() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_interrupts() -> Result<()> {
+ let lua = Lua::new();
+
+ let interrupts_count = Arc::new(AtomicU64::new(0));
+ let interrupts_count2 = interrupts_count.clone();
+
+ lua.set_interrupt(move |_lua| {
+ interrupts_count2.fetch_add(1, Ordering::Relaxed);
+ Ok(VmState::Continue)
+ });
+ let f = lua
+ .load(
+ r#"
+ local x = 2 + 3
+ local y = x * 63
+ local z = string.len(x..", "..y)
+ "#,
+ )
+ .into_function()?;
+ f.call(())?;
+
+ assert!(interrupts_count.load(Ordering::Relaxed) > 0);
+
+ //
+ // Test yields from interrupt
+ //
+ let yield_count = Arc::new(AtomicU64::new(0));
+ let yield_count2 = yield_count.clone();
+ lua.set_interrupt(move |_lua| {
+ if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 {
+ return Ok(VmState::Yield);
+ }
+ Ok(VmState::Continue)
+ });
+ let co = lua.create_thread(
+ lua.load(
+ r#"
+ local a = {1, 2, 3}
+ local b = 0
+ for _, x in ipairs(a) do b += x end
+ return b
+ "#,
+ )
+ .into_function()?,
+ )?;
+ co.resume(())?;
+ assert_eq!(co.status(), ThreadStatus::Resumable);
+ let result: i32 = co.resume(())?;
+ assert_eq!(result, 6);
+ assert_eq!(yield_count.load(Ordering::Relaxed), 7);
+ assert_eq!(co.status(), ThreadStatus::Unresumable);
+
+ //
+ // Test errors in interrupts
+ //
+ lua.set_interrupt(|_| Err(Error::RuntimeError("error from interrupt".into())));
+ match f.call::<_, ()>(()) {
+ Err(Error::CallbackError { cause, .. }) => match *cause {
+ Error::RuntimeError(ref m) if m == "error from interrupt" => {}
+ ref e => panic!("expected RuntimeError with a specific message, got {:?}", e),
+ },
+ r => panic!("expected CallbackError, got {:?}", r),
+ }
+
+ lua.remove_interrupt();
+
+ Ok(())
+}