diff options
author | Alex Orlenko <zxteam@protonmail.com> | 2021-09-28 16:10:09 +0100 |
---|---|---|
committer | Alex Orlenko <zxteam@protonmail.com> | 2021-09-28 16:41:39 +0100 |
commit | bdd3c923ba5494e0733d0d7af7c952e2216d535e (patch) | |
tree | 2360406cb86aaa2dfd7f8e76fbb482c2445021c6 /src | |
parent | d586eef0f54a29af61bbf799335c0aaeea69737e (diff) | |
download | mlua-bdd3c923ba5494e0733d0d7af7c952e2216d535e.zip |
Fix table traversal used in recursion detection.
This fixes serializing same table multiple times within a parent table.
Diffstat (limited to 'src')
-rw-r--r-- | src/lua.rs | 1 | ||||
-rw-r--r-- | src/serde/de.rs | 46 |
2 files changed, 33 insertions, 14 deletions
@@ -1721,6 +1721,7 @@ impl Lua { } #[cfg(feature = "serialize")] + #[inline] pub(crate) unsafe fn get_ref_ptr(&self, lref: &LuaRef) -> *const c_void { ffi::lua_topointer((*self.extra.get()).ref_thread, lref.index) } diff --git a/src/serde/de.rs b/src/serde/de.rs index d0d99dd..fcbd1d2 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -7,7 +7,7 @@ use std::string::String as StdString; use serde::de::{self, IntoDeserializer}; use crate::error::{Error, Result}; -use crate::table::{TablePairs, TableSequence}; +use crate::table::{Table, TablePairs, TableSequence}; use crate::value::Value; /// A struct for deserializing Lua values into Rust values. @@ -158,11 +158,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { where V: de::Visitor<'de>, { - let (variant, value) = match self.value { + let (variant, value, _guard) = match self.value { Value::Table(table) => { - let lua = table.0.lua; - let ptr = unsafe { lua.get_ref_ptr(&table.0) }; - self.visited.borrow_mut().insert(ptr); + let _guard = RecursionGuard::new(&table, &self.visited); let mut iter = table.pairs::<StdString, Value>(); let (variant, value) = match iter.next() { @@ -185,9 +183,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { return Err(de::Error::custom("bad enum value")); } - (variant, Some(value)) + (variant, Some(value), Some(_guard)) } - Value::String(variant) => (variant.to_str()?.to_owned(), None), + Value::String(variant) => (variant.to_str()?.to_owned(), None, None), _ => return Err(de::Error::custom("bad enum value")), }; @@ -206,9 +204,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { { match self.value { Value::Table(t) => { - let lua = t.0.lua; - let ptr = unsafe { lua.get_ref_ptr(&t.0) }; - self.visited.borrow_mut().insert(ptr); + let _guard = RecursionGuard::new(&t, &self.visited); let len = t.raw_len() as usize; let mut deserializer = SeqDeserializer { @@ -261,9 +257,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { { match self.value { Value::Table(t) => { - let lua = t.0.lua; - let ptr = unsafe { lua.get_ref_ptr(&t.0) }; - self.visited.borrow_mut().insert(ptr); + let _guard = RecursionGuard::new(&t, &self.visited); let mut deserializer = MapDeserializer { pairs: t.pairs(), @@ -495,10 +489,34 @@ impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> { } } +// Adds `ptr` to the `visited` map and removes on drop +// Used to track recursive tables but allow to traverse same tables multiple times +struct RecursionGuard { + ptr: *const c_void, + visited: Rc<RefCell<HashSet<*const c_void>>>, +} + +impl RecursionGuard { + #[inline] + fn new(table: &Table, visited: &Rc<RefCell<HashSet<*const c_void>>>) -> Self { + let visited = Rc::clone(visited); + let ptr = unsafe { table.0.lua.get_ref_ptr(&table.0) }; + visited.borrow_mut().insert(ptr); + RecursionGuard { ptr, visited } + } +} + +impl Drop for RecursionGuard { + fn drop(&mut self) { + self.visited.borrow_mut().remove(&self.ptr); + } +} + +// Checks `options` and decides should we emit an error or skip next element fn check_value_if_skip( value: &Value, options: Options, - visited: &Rc<RefCell<HashSet<*const c_void>>>, + visited: &RefCell<HashSet<*const c_void>>, ) -> Result<bool> { match value { Value::Table(table) => { |