diff options
Diffstat (limited to 'examples/async_tcp_server.rs')
-rw-r--r-- | examples/async_tcp_server.rs | 179 |
1 files changed, 89 insertions, 90 deletions
diff --git a/examples/async_tcp_server.rs b/examples/async_tcp_server.rs index 21c81c2..edfc114 100644 --- a/examples/async_tcp_server.rs +++ b/examples/async_tcp_server.rs @@ -1,122 +1,121 @@ -use std::sync::Arc; +use std::io; +use std::net::SocketAddr; +use std::rc::Rc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::Mutex; use tokio::task; -use mlua::{chunk, Function, Lua, Result, String as LuaString, UserData, UserDataMethods}; +use mlua::{ + chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods, +}; -struct LuaTcp; +struct LuaTcpStream(TcpStream); -#[derive(Clone)] -struct LuaTcpListener(Arc<Mutex<TcpListener>>); - -#[derive(Clone)] -struct LuaTcpStream(Arc<Mutex<TcpStream>>); - -impl UserData for LuaTcp { +impl UserData for LuaTcpStream { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_function("bind", |_, addr: String| async move { - let listener = TcpListener::bind(addr).await?; - Ok(LuaTcpListener(Arc::new(Mutex::new(listener)))) + methods.add_method("peer_addr", |_, this, ()| { + Ok(this.0.peer_addr()?.to_string()) }); - } -} -impl UserData for LuaTcpListener { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("accept", |_, listener, ()| async move { - let (stream, _) = listener.0.lock().await.accept().await?; - Ok(LuaTcpStream(Arc::new(Mutex::new(stream)))) + methods.add_async_function( + "read", + |lua, (this, size): (AnyUserData, usize)| async move { + let mut this = this.borrow_mut::<Self>()?; + let mut buf = vec![0; size]; + let n = this.0.read(&mut buf).await?; + buf.truncate(n); + lua.create_string(&buf) + }, + ); + + methods.add_async_function( + "write", + |_, (this, data): (AnyUserData, LuaString)| async move { + let mut this = this.borrow_mut::<Self>()?; + let n = this.0.write(&data.as_bytes()).await?; + Ok(n) + }, + ); + + methods.add_async_function("close", |_, this: AnyUserData| async move { + let mut this = this.borrow_mut::<Self>()?; + this.0.shutdown().await?; + Ok(()) }); } } -impl UserData for LuaTcpStream { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("peer_addr", |_, stream, ()| async move { - Ok(stream.0.lock().await.peer_addr()?.to_string()) - }); - - methods.add_async_method("read", |lua, stream, size: usize| async move { - let mut buf = vec![0; size]; - let n = stream.0.lock().await.read(&mut buf).await?; - buf.truncate(n); - lua.create_string(&buf) - }); - - methods.add_async_method("write", |_, stream, data: LuaString| async move { - let n = stream.0.lock().await.write(&data.as_bytes()).await?; - Ok(n) - }); - - methods.add_async_method("close", |_, stream, ()| async move { - stream.0.lock().await.shutdown().await?; - Ok(()) +async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> { + let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); + let listener = TcpListener::bind(addr).await.expect("cannot bind addr"); + + println!("Listening on {}", addr); + + let lua = Rc::new(lua); + let handler = Rc::new(handler); + loop { + let (stream, _) = match listener.accept().await { + Ok(res) => res, + Err(err) if is_transient_error(&err) => continue, + Err(err) => return Err(err), + }; + + let lua = lua.clone(); + let handler = handler.clone(); + task::spawn_local(async move { + let handler: Function = lua + .registry_value(&handler) + .expect("cannot get Lua handler"); + + let stream = LuaTcpStream(stream); + if let Err(err) = handler.call_async::<_, ()>(stream).await { + eprintln!("{}", err); + } }); } } -async fn run_server(lua: &'static Lua) -> Result<()> { - let spawn = lua.create_function(move |_, func: Function| { - task::spawn_local(async move { func.call_async::<_, ()>(()).await }); - Ok(()) - })?; - - let tcp = LuaTcp; +#[tokio::main(flavor = "current_thread")] +async fn main() { + let lua = Lua::new(); - let server = lua + // Create Lua handler function + let handler_fn = lua .load(chunk! { - local addr = ... - local listener = $tcp.bind(addr) - print("listening on "..addr) - - local accept_new = true - while true do - local stream = listener:accept() + function(stream) local peer_addr = stream:peer_addr() print("connected from "..peer_addr) - if not accept_new then - return - end - - $spawn(function() - while true do - local data = stream:read(100) - data = data:match("^%s*(.-)%s*$") -- trim - print("["..peer_addr.."] "..data) - if data == "bye" then - stream:write("bye bye\n") - stream:close() - return - end - if data == "exit" then - stream:close() - accept_new = false - return - end - stream:write("echo: "..data.."\n") + while true do + local data = stream:read(100) + data = data:match("^%s*(.-)%s*$") // trim + print("["..peer_addr.."] "..data) + if data == "bye" then + stream:write("bye bye\n") + stream:close() + return end - end) + stream:write("echo: "..data.."\n") + end end }) - .into_function()?; + .eval::<Function>() + .expect("cannot create Lua handler"); + + // Store it in the Registry + let handler = lua + .create_registry_value(handler_fn) + .expect("cannot store Lua handler"); task::LocalSet::new() - .run_until(server.call_async::<_, ()>("0.0.0.0:1234")) + .run_until(run_server(lua, handler)) .await + .expect("cannot run server") } -#[tokio::main] -async fn main() { - let lua = Lua::new().into_static(); - - run_server(lua).await.unwrap(); - - // Consume the static reference and drop it. - // This is safe as long as we don't hold any other references to Lua - // or alive resources. - unsafe { Lua::from_static(lua) }; +fn is_transient_error(e: &io::Error) -> bool { + e.kind() == io::ErrorKind::ConnectionRefused + || e.kind() == io::ErrorKind::ConnectionAborted + || e.kind() == io::ErrorKind::ConnectionReset } |