summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/function.rs6
-rw-r--r--src/table.rs69
-rw-r--r--src/thread.rs6
-rw-r--r--src/types.rs14
-rw-r--r--src/userdata.rs51
-rw-r--r--src/value.rs46
-rw-r--r--tests/memory.rs18
-rw-r--r--tests/table.rs34
-rw-r--r--tests/tests.rs12
-rw-r--r--tests/thread.rs20
-rw-r--r--tests/userdata.rs37
-rw-r--r--tests/value.rs57
12 files changed, 332 insertions, 38 deletions
diff --git a/src/function.rs b/src/function.rs
index 1b7e6fe..141fcbe 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -160,3 +160,9 @@ impl<'lua> Function<'lua> {
}
}
}
+
+impl<'lua> PartialEq for Function<'lua> {
+ fn eq(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
+}
diff --git a/src/table.rs b/src/table.rs
index 0c3e143..bb45ea5 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -166,6 +166,62 @@ impl<'lua> Table<'lua> {
self.get::<_, Function>(key)?.call(args)
}
+ /// Compares two tables for equality.
+ ///
+ /// Tables are compared by reference first.
+ /// If they are not primitively equals, then mlua will try to invoke the `__eq` metamethod.
+ /// mlua will check `self` first for the metamethod, then `other` if not found.
+ ///
+ /// # Examples
+ ///
+ /// Compare two tables using `__eq` metamethod:
+ ///
+ /// ```
+ /// # use mlua::{Lua, Result, Table};
+ /// # fn main() -> Result<()> {
+ /// # let lua = Lua::new();
+ /// let table1 = lua.create_table()?;
+ /// table1.set(1, "value")?;
+ ///
+ /// let table2 = lua.create_table()?;
+ /// table2.set(2, "value")?;
+ ///
+ /// let always_equals_mt = lua.create_table()?;
+ /// always_equals_mt.set("__eq", lua.create_function(|_, (_t1, _t2): (Table, Table)| Ok(true))?)?;
+ /// table2.set_metatable(Some(always_equals_mt));
+ ///
+ /// assert!(table1.equals(&table1.clone())?);
+ /// assert!(table1.equals(&table2)?);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
+ let other = other.as_ref();
+ if self == other {
+ return Ok(true);
+ }
+
+ // Compare using __eq metamethod if exists
+ // First, check the self for the metamethod.
+ // If self does not define it, then check the other table.
+ if let Some(mt) = self.get_metatable() {
+ if mt.contains_key("__eq")? {
+ return mt
+ .get::<_, Function>("__eq")?
+ .call((self.clone(), other.clone()));
+ }
+ }
+ if let Some(mt) = other.get_metatable() {
+ if mt.contains_key("__eq")? {
+ return mt
+ .get::<_, Function>("__eq")?
+ .call((self.clone(), other.clone()));
+ }
+ }
+
+ Ok(false)
+ }
+
/// Removes a key from the table, returning the value at the key
/// if the key was previously in the table.
pub fn raw_remove<K: ToLua<'lua>>(&self, key: K) -> Result<()> {
@@ -368,6 +424,19 @@ impl<'lua> Table<'lua> {
}
}
+impl<'lua> PartialEq for Table<'lua> {
+ fn eq(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
+}
+
+impl<'lua> AsRef<Table<'lua>> for Table<'lua> {
+ #[inline]
+ fn as_ref(&self) -> &Self {
+ self
+ }
+}
+
/// An iterator over the pairs of a Lua table.
///
/// This struct is created by the [`Table::pairs`] method.
diff --git a/src/thread.rs b/src/thread.rs
index 97b0362..2452fec 100644
--- a/src/thread.rs
+++ b/src/thread.rs
@@ -143,3 +143,9 @@ impl<'lua> Thread<'lua> {
}
}
}
+
+impl<'lua> PartialEq for Thread<'lua> {
+ fn eq(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
+}
diff --git a/src/types.rs b/src/types.rs
index 29e55b1..e88d0c1 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -5,6 +5,7 @@ use std::{fmt, mem, ptr};
use crate::error::Result;
use crate::ffi;
use crate::lua::Lua;
+use crate::util::{assert_stack, StackGuard};
use crate::value::MultiValue;
/// Type of Lua integer numbers.
@@ -92,3 +93,16 @@ impl<'lua> Drop for LuaRef<'lua> {
self.lua.drop_ref(self)
}
}
+
+impl<'lua> PartialEq for LuaRef<'lua> {
+ fn eq(&self, other: &Self) -> bool {
+ let lua = self.lua;
+ unsafe {
+ let _sg = StackGuard::new(lua.state);
+ assert_stack(lua.state, 2);
+ lua.push_ref(&self);
+ lua.push_ref(&other);
+ ffi::lua_rawequal(lua.state, -1, -2) == 1
+ }
+ }
+}
diff --git a/src/userdata.rs b/src/userdata.rs
index 8177ad9..32817b6 100644
--- a/src/userdata.rs
+++ b/src/userdata.rs
@@ -2,7 +2,9 @@ use std::cell::{Ref, RefCell, RefMut};
use crate::error::{Error, Result};
use crate::ffi;
+use crate::function::Function;
use crate::lua::Lua;
+use crate::table::Table;
use crate::types::LuaRef;
use crate::util::{assert_stack, get_userdata, StackGuard};
use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti};
@@ -398,6 +400,42 @@ impl<'lua> AnyUserData<'lua> {
V::from_lua(res, lua)
}
+ fn get_metatable(&self) -> Result<Table<'lua>> {
+ unsafe {
+ let lua = self.0.lua;
+ let _sg = StackGuard::new(lua.state);
+ assert_stack(lua.state, 3);
+
+ lua.push_ref(&self.0);
+
+ if ffi::lua_getmetatable(lua.state, -1) == 0 {
+ return Err(Error::UserDataTypeMismatch);
+ }
+
+ Ok(Table(lua.pop_ref()))
+ }
+ }
+
+ pub(crate) fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
+ let other = other.as_ref();
+ if self == other {
+ return Ok(true);
+ }
+
+ let mt = self.get_metatable()?;
+ if mt != other.get_metatable()? {
+ return Ok(false);
+ }
+
+ if mt.contains_key("__eq")? {
+ return mt
+ .get::<_, Function>("__eq")?
+ .call((self.clone(), other.clone()));
+ }
+
+ Ok(false)
+ }
+
fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R>
where
T: 'static + UserData,
@@ -428,3 +466,16 @@ impl<'lua> AnyUserData<'lua> {
}
}
}
+
+impl<'lua> PartialEq for AnyUserData<'lua> {
+ fn eq(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
+}
+
+impl<'lua> AsRef<AnyUserData<'lua>> for AnyUserData<'lua> {
+ #[inline]
+ fn as_ref(&self) -> &Self {
+ self
+ }
+}
diff --git a/src/value.rs b/src/value.rs
index c3eab89..1ab5399 100644
--- a/src/value.rs
+++ b/src/value.rs
@@ -2,6 +2,7 @@ use std::iter::{self, FromIterator};
use std::{slice, str, vec};
use crate::error::{Error, Result};
+use crate::ffi;
use crate::function::Function;
use crate::lua::Lua;
use crate::string::String;
@@ -61,6 +62,51 @@ impl<'lua> Value<'lua> {
Value::UserData(_) | Value::Error(_) => "userdata",
}
}
+
+ /// Compares two values for equality.
+ ///
+ /// Equality comparisons do not convert strings to numbers or vice versa.
+ /// Tables, Functions, Threads, and Userdata are compared by reference:
+ /// two objects are considered equal only if they are the same object.
+ ///
+ /// If Tables or Userdata have `__eq` metamethod then mlua will try to invoke it.
+ /// The first value is checked first. If that value does not define a metamethod
+ /// for `__eq`, then mlua will check the second value.
+ /// Then mlua calls the metamethod with the two values as arguments, if found.
+ pub fn equals<T: AsRef<Self>>(&self, other: T) -> Result<bool> {
+ match (self, other.as_ref()) {
+ (Value::Table(a), Value::Table(b)) => a.equals(b),
+ (Value::UserData(a), Value::UserData(b)) => a.equals(b),
+ _ => Ok(self == other.as_ref()),
+ }
+ }
+}
+
+impl<'lua> PartialEq for Value<'lua> {
+ fn eq(&self, other: &Self) -> bool {
+ match (self, other) {
+ (Value::Nil, Value::Nil) => true,
+ (Value::Boolean(a), Value::Boolean(b)) => a == b,
+ (Value::LightUserData(a), Value::LightUserData(b)) => a == b,
+ (Value::Integer(a), Value::Integer(b)) => *a == *b,
+ (Value::Integer(a), Value::Number(b)) => *a as ffi::lua_Number == *b,
+ (Value::Number(a), Value::Integer(b)) => *a == *b as ffi::lua_Number,
+ (Value::Number(a), Value::Number(b)) => *a == *b,
+ (Value::String(a), Value::String(b)) => a == b,
+ (Value::Table(a), Value::Table(b)) => a == b,
+ (Value::Function(a), Value::Function(b)) => a == b,
+ (Value::Thread(a), Value::Thread(b)) => a == b,
+ (Value::UserData(a), Value::UserData(b)) => a == b,
+ _ => false,
+ }
+ }
+}
+
+impl<'lua> AsRef<Value<'lua>> for Value<'lua> {
+ #[inline]
+ fn as_ref(&self) -> &Self {
+ self
+ }
}
/// Trait for types convertible to `Value`.
diff --git a/tests/memory.rs b/tests/memory.rs
index 7c141cd..0bcdeb6 100644
--- a/tests/memory.rs
+++ b/tests/memory.rs
@@ -51,15 +51,15 @@ fn test_gc_error() {
match lua
.load(
r#"
- val = nil
- table = {}
- setmetatable(table, {
- __gc = function()
- error("gcwascalled")
- end
- })
- table = nil
- collectgarbage("collect")
+ val = nil
+ table = {}
+ setmetatable(table, {
+ __gc = function()
+ error("gcwascalled")
+ end
+ })
+ table = nil
+ collectgarbage("collect")
"#,
)
.exec()
diff --git a/tests/table.rs b/tests/table.rs
index 14ccb87..1d788e2 100644
--- a/tests/table.rs
+++ b/tests/table.rs
@@ -149,6 +149,40 @@ fn test_metatable() -> Result<()> {
}
#[test]
+fn test_table_eq() -> Result<()> {
+ let lua = Lua::new();
+ let globals = lua.globals();
+
+ lua.load(
+ r#"
+ table1 = {1}
+ table2 = {1}
+ table3 = table1
+ table4 = {1}
+
+ setmetatable(table4, {
+ __eq = function(a, b) return a[1] == b[1] end
+ })
+ "#,
+ )
+ .exec()?;
+
+ let table1 = globals.get::<_, Table>("table1")?;
+ let table2 = globals.get::<_, Table>("table2")?;
+ let table3 = globals.get::<_, Table>("table3")?;
+ let table4 = globals.get::<_, Table>("table4")?;
+
+ assert!(table1 != table2);
+ assert!(!table1.equals(&table2)?);
+ assert!(table1 == table3);
+ assert!(table1.equals(&table3)?);
+ assert!(table1 != table4);
+ assert!(table1.equals(&table4)?);
+
+ Ok(())
+}
+
+#[test]
fn test_table_error() -> Result<()> {
let lua = Lua::new();
diff --git a/tests/tests.rs b/tests/tests.rs
index 2e71dc1..8ef9d48 100644
--- a/tests/tests.rs
+++ b/tests/tests.rs
@@ -92,13 +92,13 @@ fn test_lua_multi() -> Result<()> {
lua.load(
r#"
- function concat(arg1, arg2)
- return arg1 .. arg2
- end
+ function concat(arg1, arg2)
+ return arg1 .. arg2
+ end
- function mreturn()
- return 1, 2, 3, 4, 5, 6
- end
+ function mreturn()
+ return 1, 2, 3, 4, 5, 6
+ end
"#,
)
.exec()?;
diff --git a/tests/thread.rs b/tests/thread.rs
index 3344b7c..d3266ba 100644
--- a/tests/thread.rs
+++ b/tests/thread.rs
@@ -20,13 +20,13 @@ fn test_thread() -> Result<()> {
let thread = lua.create_thread(
lua.load(
r#"
- function (s)
- local sum = s
- for i = 1,4 do
- sum = sum + coroutine.yield(sum)
- end
- return sum
+ function (s)
+ local sum = s
+ for i = 1,4 do
+ sum = sum + coroutine.yield(sum)
end
+ return sum
+ end
"#,
)
.eval()?,
@@ -47,11 +47,11 @@ fn test_thread() -> Result<()> {
let accumulate = lua.create_thread(
lua.load(
r#"
- function (sum)
- while true do
- sum = sum + coroutine.yield(sum)
- end
+ function (sum)
+ while true do
+ sum = sum + coroutine.yield(sum)
end
+ end
"#,
)
.eval::<Function>()?,
diff --git a/tests/userdata.rs b/tests/userdata.rs
index 4f01669..34aa8b4 100644
--- a/tests/userdata.rs
+++ b/tests/userdata.rs
@@ -13,7 +13,7 @@ use std::sync::Arc;
use mlua::{
AnyUserData, ExternalError, Function, Lua, MetaMethod, Result, String, UserData,
- UserDataMethods,
+ UserDataMethods, Value,
};
#[test]
@@ -96,6 +96,9 @@ fn test_metamethods() -> Result<()> {
MetaMethod::Sub,
|_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)),
);
+ methods.add_meta_function(MetaMethod::Eq, |_, (lhs, rhs): (MyUserData, MyUserData)| {
+ Ok(lhs.0 == rhs.0)
+ });
methods.add_meta_method(MetaMethod::Index, |_, data, index: String| {
if index.to_str()? == "inner" {
Ok(data.0)
@@ -122,6 +125,7 @@ fn test_metamethods() -> Result<()> {
let globals = lua.globals();
globals.set("userdata1", MyUserData(7))?;
globals.set("userdata2", MyUserData(3))?;
+ globals.set("userdata3", MyUserData(3))?;
assert_eq!(
lua.load("userdata1 + userdata2").eval::<MyUserData>()?.0,
10
@@ -151,6 +155,13 @@ fn test_metamethods() -> Result<()> {
assert_eq!(ipairs_it.call::<_, i64>(())?, 28);
assert!(lua.load("userdata2.nonexist_field").eval::<()>().is_err());
+ let userdata2: Value = globals.get("userdata2")?;
+ let userdata3: Value = globals.get("userdata3")?;
+
+ assert!(lua.load("userdata2 == userdata3").eval::<bool>()?);
+ assert!(userdata2 != userdata3); // because references are differ
+ assert!(userdata2.equals(userdata3)?);
+
Ok(())
}
@@ -175,18 +186,18 @@ fn test_gc_userdata() -> Result<()> {
assert!(lua
.load(
r#"
- local tbl = setmetatable({
- userdata = userdata
- }, { __gc = function(self)
- -- resurrect userdata
- hatch = self.userdata
- end })
-
- tbl = nil
- userdata = nil -- make table and userdata collectable
- collectgarbage("collect")
- hatch:access()
- "#
+ local tbl = setmetatable({
+ userdata = userdata
+ }, { __gc = function(self)
+ -- resurrect userdata
+ hatch = self.userdata
+ end })
+
+ tbl = nil
+ userdata = nil -- make table and userdata collectable
+ collectgarbage("collect")
+ hatch:access()
+ "#
)
.exec()
.is_err());
diff --git a/tests/value.rs b/tests/value.rs
new file mode 100644
index 0000000..9beb68c
--- /dev/null
+++ b/tests/value.rs
@@ -0,0 +1,57 @@
+use mlua::{Lua, Result, Value};
+
+#[test]
+fn test_value_eq() -> Result<()> {
+ let lua = Lua::new();
+ let globals = lua.globals();
+
+ lua.load(
+ r#"
+ table1 = {1}
+ table2 = {1}
+ string1 = "hello"
+ string2 = "hello"
+ num1 = 1
+ num2 = 1.0
+ num3 = "1"
+ func1 = function() end
+ func2 = func1
+ func3 = function() end
+ thread1 = coroutine.create(function() end)
+ thread2 = thread1
+
+ setmetatable(table1, {
+ __eq = function(a, b) return a[1] == b[1] end
+ })
+ "#,
+ )
+ .exec()?;
+
+ let table1: Value = globals.get("table1")?;
+ let table2: Value = globals.get("table2")?;
+ let string1: Value = globals.get("string1")?;
+ let string2: Value = globals.get("string2")?;
+ let num1: Value = globals.get("num1")?;
+ let num2: Value = globals.get("num2")?;
+ let num3: Value = globals.get("num3")?;
+ let func1: Value = globals.get("func1")?;
+ let func2: Value = globals.get("func2")?;
+ let func3: Value = globals.get("func3")?;
+ let thread1: Value = globals.get("thread1")?;
+ let thread2: Value = globals.get("thread2")?;
+
+ assert!(table1 != table2);
+ assert!(table1.equals(table2)?);
+ assert!(string1 == string2);
+ assert!(string1.equals(string2)?);
+ assert!(num1 == num2);
+ assert!(num1.equals(num2)?);
+ assert!(num1 != num3);
+ assert!(func1 == func2);
+ assert!(func1 != func3);
+ assert!(!func1.equals(func3)?);
+ assert!(thread1 == thread2);
+ assert!(thread1.equals(thread2)?);
+
+ Ok(())
+}