summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2024-03-28 13:05:01 +0000
committerAlex Orlenko <zxteam@protonmail.com>2024-03-28 13:05:01 +0000
commitfa217d3706ebebfd1519743c3e620a940c218967 (patch)
tree4efd3f3c714455fdcc21710d6fdb5462f25722a2
parentb62f2ee0f70dfa91f3a8f6ec5ae5be6c58b2f77f (diff)
downloadmlua-fa217d3706ebebfd1519743c3e620a940c218967.zip
Better Luau buffer type support.
- Add `Lua::create_buffer()` function - Support serializing buffer type as a byte slice - Support accessing copy of underlying bytes using `BString`
-rw-r--r--src/conversion.rs40
-rw-r--r--src/lua.rs21
-rw-r--r--src/serde/de.rs8
-rw-r--r--src/userdata.rs13
-rw-r--r--src/util/mod.rs14
-rw-r--r--tests/conversion.rs44
-rw-r--r--tests/serde.rs27
7 files changed, 162 insertions, 5 deletions
diff --git a/src/conversion.rs b/src/conversion.rs
index 2c57269..8ca8fe8 100644
--- a/src/conversion.rs
+++ b/src/conversion.rs
@@ -679,19 +679,49 @@ impl<'lua> IntoLua<'lua> for BString {
}
impl<'lua> FromLua<'lua> for BString {
- #[inline]
fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result<Self> {
let ty = value.type_name();
- Ok(BString::from(
- lua.coerce_string(value)?
+ match value {
+ Value::String(s) => Ok(s.as_bytes().into()),
+ #[cfg(feature = "luau")]
+ Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
+ let mut size = 0usize;
+ let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
+ mlua_assert!(!buf.is_null(), "invalid Luau buffer");
+ Ok(slice::from_raw_parts(buf as *const u8, size).into())
+ },
+ _ => Ok(lua
+ .coerce_string(value)?
.ok_or_else(|| Error::FromLuaConversionError {
from: ty,
to: "BString",
message: Some("expected string or number".to_string()),
})?
.as_bytes()
- .to_vec(),
- ))
+ .into()),
+ }
+ }
+
+ unsafe fn from_stack(idx: c_int, lua: &'lua Lua) -> Result<Self> {
+ let state = lua.state();
+ match ffi::lua_type(state, idx) {
+ ffi::LUA_TSTRING => {
+ let mut size = 0;
+ let data = ffi::lua_tolstring(state, idx, &mut size);
+ Ok(slice::from_raw_parts(data as *const u8, size).into())
+ }
+ #[cfg(feature = "luau")]
+ ffi::LUA_TBUFFER => {
+ let mut size = 0;
+ let buf = ffi::lua_tobuffer(state, idx, &mut size);
+ mlua_assert!(!buf.is_null(), "invalid Luau buffer");
+ Ok(slice::from_raw_parts(buf as *const u8, size).into())
+ }
+ _ => {
+ // Fallback to default
+ Self::from_lua(lua.stack_value(idx), lua)
+ }
+ }
}
}
diff --git a/src/lua.rs b/src/lua.rs
index 9938a74..6691bf0 100644
--- a/src/lua.rs
+++ b/src/lua.rs
@@ -1373,6 +1373,27 @@ impl Lua {
}
}
+ /// Create and return a Luau [buffer] object from a byte slice of data.
+ ///
+ /// Requires `feature = "luau"`
+ ///
+ /// [buffer]: https://luau-lang.org/library#buffer-library
+ #[cfg(feature = "luau")]
+ pub fn create_buffer(&self, buf: impl AsRef<[u8]>) -> Result<AnyUserData> {
+ let state = self.state();
+ unsafe {
+ if self.unlikely_memory_error() {
+ crate::util::push_buffer(self.ref_thread(), buf.as_ref(), false)?;
+ return Ok(AnyUserData(self.pop_ref_thread(), SubtypeId::Buffer));
+ }
+
+ let _sg = StackGuard::new(state);
+ check_stack(state, 4)?;
+ crate::util::push_buffer(state, buf.as_ref(), true)?;
+ Ok(AnyUserData(self.pop_ref(), SubtypeId::Buffer))
+ }
+ }
+
/// Creates and returns a new empty table.
pub fn create_table(&self) -> Result<Table> {
self.create_table_with_capacity(0, 0)
diff --git a/src/serde/de.rs b/src/serde/de.rs
index 5d3512c..6933e4e 100644
--- a/src/serde/de.rs
+++ b/src/serde/de.rs
@@ -148,6 +148,14 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
Value::UserData(ud) if ud.is_serializable() => {
serde_userdata(ud, |value| value.deserialize_any(visitor))
}
+ #[cfg(feature = "luau")]
+ Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
+ let mut size = 0usize;
+ let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
+ mlua_assert!(!buf.is_null(), "invalid Luau buffer");
+ let buf = std::slice::from_raw_parts(buf as *const u8, size);
+ visitor.visit_bytes(buf)
+ },
Value::Function(_)
| Value::Thread(_)
| Value::UserData(_)
diff --git a/src/userdata.rs b/src/userdata.rs
index e2c4a74..dbbbffa 100644
--- a/src/userdata.rs
+++ b/src/userdata.rs
@@ -1340,6 +1340,19 @@ impl<'lua> Serialize for AnyUserData<'lua> {
S: Serializer,
{
let lua = self.0.lua;
+
+ // Special case for Luau buffer type
+ #[cfg(feature = "luau")]
+ if self.1 == SubtypeId::Buffer {
+ let buf = unsafe {
+ let mut size = 0usize;
+ let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size);
+ mlua_assert!(!buf.is_null(), "invalid Luau buffer");
+ std::slice::from_raw_parts(buf as *const u8, size)
+ };
+ return serializer.serialize_bytes(buf);
+ }
+
let data = unsafe {
let _ = lua
.get_userdata_ref_type_id(&self.0)
diff --git a/src/util/mod.rs b/src/util/mod.rs
index 595407e..ff8b28e 100644
--- a/src/util/mod.rs
+++ b/src/util/mod.rs
@@ -253,6 +253,20 @@ pub unsafe fn push_string(state: *mut ffi::lua_State, s: &[u8], protect: bool) -
}
}
+// Uses 3 stack spaces (when protect), does not call checkstack.
+#[cfg(feature = "luau")]
+#[inline(always)]
+pub unsafe fn push_buffer(state: *mut ffi::lua_State, b: &[u8], protect: bool) -> Result<()> {
+ let data = if protect {
+ protect_lua!(state, 0, 1, |state| ffi::lua_newbuffer(state, b.len()))?
+ } else {
+ ffi::lua_newbuffer(state, b.len())
+ };
+ let buf = slice::from_raw_parts_mut(data as *mut u8, b.len());
+ buf.copy_from_slice(b);
+ Ok(())
+}
+
// Uses 3 stack spaces, does not call checkstack.
#[inline]
pub unsafe fn push_table(
diff --git a/tests/conversion.rs b/tests/conversion.rs
index a312f11..ad2c09a 100644
--- a/tests/conversion.rs
+++ b/tests/conversion.rs
@@ -2,6 +2,7 @@ use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::ffi::{CStr, CString};
+use bstr::BString;
use maplit::{btreemap, btreeset, hashmap, hashset};
use mlua::{
AnyUserData, Error, Function, IntoLua, Lua, RegistryKey, Result, Table, Thread, UserDataRef,
@@ -409,3 +410,46 @@ fn test_conv_array() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_bstring_from_lua() -> Result<()> {
+ let lua = Lua::new();
+
+ let s = lua.create_string("hello, world")?;
+ let bstr = lua.unpack::<BString>(Value::String(s))?;
+ assert_eq!(bstr, "hello, world");
+
+ let bstr = lua.unpack::<BString>(Value::Integer(123))?;
+ assert_eq!(bstr, "123");
+
+ let bstr = lua.unpack::<BString>(Value::Number(-123.55))?;
+ assert_eq!(bstr, "-123.55");
+
+ // Test from stack
+ let f = lua.create_function(|_, bstr: BString| Ok(bstr))?;
+ let bstr = f.call::<_, BString>("hello, world")?;
+ assert_eq!(bstr, "hello, world");
+
+ let bstr = f.call::<_, BString>(-43.22)?;
+ assert_eq!(bstr, "-43.22");
+
+ Ok(())
+}
+
+#[cfg(feature = "luau")]
+#[test]
+fn test_bstring_from_lua_buffer() -> Result<()> {
+ let lua = Lua::new();
+
+ let b = lua.create_buffer("hello, world")?;
+ let bstr = lua.unpack::<BString>(Value::UserData(b))?;
+ assert_eq!(bstr, "hello, world");
+
+ // Test from stack
+ let f = lua.create_function(|_, bstr: BString| Ok(bstr))?;
+ let buf = lua.create_buffer("hello, world")?;
+ let bstr = f.call::<_, BString>(buf)?;
+ assert_eq!(bstr, "hello, world");
+
+ Ok(())
+}
diff --git a/tests/serde.rs b/tests/serde.rs
index 7e2e251..9c287a4 100644
--- a/tests/serde.rs
+++ b/tests/serde.rs
@@ -728,3 +728,30 @@ fn test_arbitrary_precision() {
"{\n [\"$serde_json::private::Number\"] = \"124.4\",\n}"
);
}
+
+#[cfg(feature = "luau")]
+#[test]
+fn test_buffer_serialize() {
+ let lua = Lua::new();
+
+ let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap();
+ let val = serde_value::to_value(&buf).unwrap();
+ assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4]));
+
+ // Try empty buffer
+ let buf = lua.create_buffer(&[]).unwrap();
+ let val = serde_value::to_value(&buf).unwrap();
+ assert_eq!(val, serde_value::Value::Bytes(vec![]));
+}
+
+#[cfg(feature = "luau")]
+#[test]
+fn test_buffer_from_value() {
+ let lua = Lua::new();
+
+ let buf = lua.create_buffer(&[1, 2, 3, 4]).unwrap();
+ let val = lua
+ .from_value::<serde_value::Value>(Value::UserData(buf))
+ .unwrap();
+ assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4]));
+}