summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Orlenko <zxteam@protonmail.com>2022-12-13 23:23:58 +0000
committerAlex Orlenko <zxteam@protonmail.com>2022-12-18 00:35:41 +0000
commit0aa30226df129a7c7497723ea5f3c3985c7377f5 (patch)
tree9d54aec8da1174c4e00ace156773e2f0aa75544d
parentfdb5724053828bc935b4381f43973c12f30e2104 (diff)
downloadmlua-0aa30226df129a7c7497723ea5f3c3985c7377f5.zip
Check for invalid args when parsing `#[lua_module(...)]` proc macro
-rw-r--r--mlua_derive/src/lib.rs40
-rw-r--r--tests/module/src/lib.rs6
2 files changed, 27 insertions, 19 deletions
diff --git a/mlua_derive/src/lib.rs b/mlua_derive/src/lib.rs
index 76fb415..7931dc0 100644
--- a/mlua_derive/src/lib.rs
+++ b/mlua_derive/src/lib.rs
@@ -1,7 +1,7 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::quote;
-use syn::{parse_macro_input, AttributeArgs, ItemFn, Lit, Meta, NestedMeta};
+use syn::{parse_macro_input, AttributeArgs, Error, ItemFn, Lit, Meta, NestedMeta, Result};
#[cfg(feature = "macros")]
use {
@@ -13,41 +13,49 @@ use {
struct ModuleArgs {
name: Option<Ident>,
}
+
impl ModuleArgs {
- fn parse(attr: AttributeArgs) -> Self {
+ fn parse(args: AttributeArgs) -> Result<Self> {
let mut ret = Self::default();
- for arg in attr {
+ for arg in args {
match arg {
NestedMeta::Meta(Meta::NameValue(meta)) => {
- if meta.path.segments.last().unwrap().ident == "name" {
- if let Lit::Str(val) = meta.lit {
- if let Ok(val) = val.parse() {
- ret.name = Some(val);
+ if meta.path.is_ident("name") {
+ match meta.lit {
+ Lit::Str(val) => {
+ ret.name = Some(val.parse()?);
+ }
+ _ => {
+ return Err(Error::new_spanned(meta.lit, "expected string literal"))
}
}
+ } else {
+ return Err(Error::new_spanned(meta.path, "expected `name`"));
}
}
- _ => {}
+ _ => {
+ return Err(Error::new_spanned(arg, "invalid argument"));
+ }
}
}
- ret
+
+ Ok(ret)
}
}
#[proc_macro_attribute]
pub fn lua_module(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as AttributeArgs);
- let args = ModuleArgs::parse(args);
+ let args = match ModuleArgs::parse(args) {
+ Ok(args) => args,
+ Err(err) => return err.to_compile_error().into(),
+ };
let func = parse_macro_input!(item as ItemFn);
let func_name = func.sig.ident.clone();
- let module_name = if let Some(name) = args.name {
- name
- } else {
- func_name.clone()
- };
- let ext_entrypoint_name = Ident::new(&format!("luaopen_{}", module_name), Span::call_site());
+ let module_name = args.name.unwrap_or_else(|| func_name.clone());
+ let ext_entrypoint_name = Ident::new(&format!("luaopen_{module_name}"), Span::call_site());
let wrapped = quote! {
::mlua::require_module_feature!();
diff --git a/tests/module/src/lib.rs b/tests/module/src/lib.rs
index c624078..4e5aae5 100644
--- a/tests/module/src/lib.rs
+++ b/tests/module/src/lib.rs
@@ -12,7 +12,7 @@ fn check_userdata(_: &Lua, ud: MyUserData) -> LuaResult<i32> {
Ok(ud.0)
}
-#[mlua::lua_module(name = "rust_module_first")]
+#[mlua::lua_module]
fn rust_module(lua: &Lua) -> LuaResult<LuaTable> {
let exports = lua.create_table()?;
exports.set("sum", lua.create_function(sum)?)?;
@@ -26,8 +26,8 @@ struct MyUserData(i32);
impl LuaUserData for MyUserData {}
-#[mlua::lua_module]
-fn rust_module_second(lua: &Lua) -> LuaResult<LuaTable> {
+#[mlua::lua_module(name = "rust_module_second")]
+fn rust_module2(lua: &Lua) -> LuaResult<LuaTable> {
let exports = lua.create_table()?;
exports.set("userdata", lua.create_userdata(MyUserData(123))?)?;
Ok(exports)