summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2021-09-28 16:10:09 +0100
committerAlex Orlenko <zxteam@protonmail.com>2021-09-28 16:41:39 +0100
commitbdd3c923ba5494e0733d0d7af7c952e2216d535e (patch)
tree2360406cb86aaa2dfd7f8e76fbb482c2445021c6 /src
parentd586eef0f54a29af61bbf799335c0aaeea69737e (diff)
downloadmlua-bdd3c923ba5494e0733d0d7af7c952e2216d535e.zip
Fix table traversal used in recursion detection.
This fixes serializing same table multiple times within a parent table.
Diffstat (limited to 'src')
-rw-r--r--src/lua.rs1
-rw-r--r--src/serde/de.rs46
2 files changed, 33 insertions, 14 deletions
diff --git a/src/lua.rs b/src/lua.rs
index 9d46912..f0bda85 100644
--- a/src/lua.rs
+++ b/src/lua.rs
@@ -1721,6 +1721,7 @@ impl Lua {
}
#[cfg(feature = "serialize")]
+ #[inline]
pub(crate) unsafe fn get_ref_ptr(&self, lref: &LuaRef) -> *const c_void {
ffi::lua_topointer((*self.extra.get()).ref_thread, lref.index)
}
diff --git a/src/serde/de.rs b/src/serde/de.rs
index d0d99dd..fcbd1d2 100644
--- a/src/serde/de.rs
+++ b/src/serde/de.rs
@@ -7,7 +7,7 @@ use std::string::String as StdString;
use serde::de::{self, IntoDeserializer};
use crate::error::{Error, Result};
-use crate::table::{TablePairs, TableSequence};
+use crate::table::{Table, TablePairs, TableSequence};
use crate::value::Value;
/// A struct for deserializing Lua values into Rust values.
@@ -158,11 +158,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
where
V: de::Visitor<'de>,
{
- let (variant, value) = match self.value {
+ let (variant, value, _guard) = match self.value {
Value::Table(table) => {
- let lua = table.0.lua;
- let ptr = unsafe { lua.get_ref_ptr(&table.0) };
- self.visited.borrow_mut().insert(ptr);
+ let _guard = RecursionGuard::new(&table, &self.visited);
let mut iter = table.pairs::<StdString, Value>();
let (variant, value) = match iter.next() {
@@ -185,9 +183,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
return Err(de::Error::custom("bad enum value"));
}
- (variant, Some(value))
+ (variant, Some(value), Some(_guard))
}
- Value::String(variant) => (variant.to_str()?.to_owned(), None),
+ Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
_ => return Err(de::Error::custom("bad enum value")),
};
@@ -206,9 +204,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
{
match self.value {
Value::Table(t) => {
- let lua = t.0.lua;
- let ptr = unsafe { lua.get_ref_ptr(&t.0) };
- self.visited.borrow_mut().insert(ptr);
+ let _guard = RecursionGuard::new(&t, &self.visited);
let len = t.raw_len() as usize;
let mut deserializer = SeqDeserializer {
@@ -261,9 +257,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
{
match self.value {
Value::Table(t) => {
- let lua = t.0.lua;
- let ptr = unsafe { lua.get_ref_ptr(&t.0) };
- self.visited.borrow_mut().insert(ptr);
+ let _guard = RecursionGuard::new(&t, &self.visited);
let mut deserializer = MapDeserializer {
pairs: t.pairs(),
@@ -495,10 +489,34 @@ impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
}
}
+// Adds `ptr` to the `visited` map and removes on drop
+// Used to track recursive tables but allow to traverse same tables multiple times
+struct RecursionGuard {
+ ptr: *const c_void,
+ visited: Rc<RefCell<HashSet<*const c_void>>>,
+}
+
+impl RecursionGuard {
+ #[inline]
+ fn new(table: &Table, visited: &Rc<RefCell<HashSet<*const c_void>>>) -> Self {
+ let visited = Rc::clone(visited);
+ let ptr = unsafe { table.0.lua.get_ref_ptr(&table.0) };
+ visited.borrow_mut().insert(ptr);
+ RecursionGuard { ptr, visited }
+ }
+}
+
+impl Drop for RecursionGuard {
+ fn drop(&mut self) {
+ self.visited.borrow_mut().remove(&self.ptr);
+ }
+}
+
+// Checks `options` and decides should we emit an error or skip next element
fn check_value_if_skip(
value: &Value,
options: Options,
- visited: &Rc<RefCell<HashSet<*const c_void>>>,
+ visited: &RefCell<HashSet<*const c_void>>,
) -> Result<bool> {
match value {
Value::Table(table) => {