summaryrefslogtreecommitdiff
path: root/examples/async_tcp_server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'examples/async_tcp_server.rs')
-rw-r--r--examples/async_tcp_server.rs179
1 files changed, 89 insertions, 90 deletions
diff --git a/examples/async_tcp_server.rs b/examples/async_tcp_server.rs
index 21c81c2..edfc114 100644
--- a/examples/async_tcp_server.rs
+++ b/examples/async_tcp_server.rs
@@ -1,122 +1,121 @@
-use std::sync::Arc;
+use std::io;
+use std::net::SocketAddr;
+use std::rc::Rc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::Mutex;
use tokio::task;
-use mlua::{chunk, Function, Lua, Result, String as LuaString, UserData, UserDataMethods};
+use mlua::{
+ chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods,
+};
-struct LuaTcp;
+struct LuaTcpStream(TcpStream);
-#[derive(Clone)]
-struct LuaTcpListener(Arc<Mutex<TcpListener>>);
-
-#[derive(Clone)]
-struct LuaTcpStream(Arc<Mutex<TcpStream>>);
-
-impl UserData for LuaTcp {
+impl UserData for LuaTcpStream {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
- methods.add_async_function("bind", |_, addr: String| async move {
- let listener = TcpListener::bind(addr).await?;
- Ok(LuaTcpListener(Arc::new(Mutex::new(listener))))
+ methods.add_method("peer_addr", |_, this, ()| {
+ Ok(this.0.peer_addr()?.to_string())
});
- }
-}
-impl UserData for LuaTcpListener {
- fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
- methods.add_async_method("accept", |_, listener, ()| async move {
- let (stream, _) = listener.0.lock().await.accept().await?;
- Ok(LuaTcpStream(Arc::new(Mutex::new(stream))))
+ methods.add_async_function(
+ "read",
+ |lua, (this, size): (AnyUserData, usize)| async move {
+ let mut this = this.borrow_mut::<Self>()?;
+ let mut buf = vec![0; size];
+ let n = this.0.read(&mut buf).await?;
+ buf.truncate(n);
+ lua.create_string(&buf)
+ },
+ );
+
+ methods.add_async_function(
+ "write",
+ |_, (this, data): (AnyUserData, LuaString)| async move {
+ let mut this = this.borrow_mut::<Self>()?;
+ let n = this.0.write(&data.as_bytes()).await?;
+ Ok(n)
+ },
+ );
+
+ methods.add_async_function("close", |_, this: AnyUserData| async move {
+ let mut this = this.borrow_mut::<Self>()?;
+ this.0.shutdown().await?;
+ Ok(())
});
}
}
-impl UserData for LuaTcpStream {
- fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
- methods.add_async_method("peer_addr", |_, stream, ()| async move {
- Ok(stream.0.lock().await.peer_addr()?.to_string())
- });
-
- methods.add_async_method("read", |lua, stream, size: usize| async move {
- let mut buf = vec![0; size];
- let n = stream.0.lock().await.read(&mut buf).await?;
- buf.truncate(n);
- lua.create_string(&buf)
- });
-
- methods.add_async_method("write", |_, stream, data: LuaString| async move {
- let n = stream.0.lock().await.write(&data.as_bytes()).await?;
- Ok(n)
- });
-
- methods.add_async_method("close", |_, stream, ()| async move {
- stream.0.lock().await.shutdown().await?;
- Ok(())
+async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> {
+ let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
+ let listener = TcpListener::bind(addr).await.expect("cannot bind addr");
+
+ println!("Listening on {}", addr);
+
+ let lua = Rc::new(lua);
+ let handler = Rc::new(handler);
+ loop {
+ let (stream, _) = match listener.accept().await {
+ Ok(res) => res,
+ Err(err) if is_transient_error(&err) => continue,
+ Err(err) => return Err(err),
+ };
+
+ let lua = lua.clone();
+ let handler = handler.clone();
+ task::spawn_local(async move {
+ let handler: Function = lua
+ .registry_value(&handler)
+ .expect("cannot get Lua handler");
+
+ let stream = LuaTcpStream(stream);
+ if let Err(err) = handler.call_async::<_, ()>(stream).await {
+ eprintln!("{}", err);
+ }
});
}
}
-async fn run_server(lua: &'static Lua) -> Result<()> {
- let spawn = lua.create_function(move |_, func: Function| {
- task::spawn_local(async move { func.call_async::<_, ()>(()).await });
- Ok(())
- })?;
-
- let tcp = LuaTcp;
+#[tokio::main(flavor = "current_thread")]
+async fn main() {
+ let lua = Lua::new();
- let server = lua
+ // Create Lua handler function
+ let handler_fn = lua
.load(chunk! {
- local addr = ...
- local listener = $tcp.bind(addr)
- print("listening on "..addr)
-
- local accept_new = true
- while true do
- local stream = listener:accept()
+ function(stream)
local peer_addr = stream:peer_addr()
print("connected from "..peer_addr)
- if not accept_new then
- return
- end
-
- $spawn(function()
- while true do
- local data = stream:read(100)
- data = data:match("^%s*(.-)%s*$") -- trim
- print("["..peer_addr.."] "..data)
- if data == "bye" then
- stream:write("bye bye\n")
- stream:close()
- return
- end
- if data == "exit" then
- stream:close()
- accept_new = false
- return
- end
- stream:write("echo: "..data.."\n")
+ while true do
+ local data = stream:read(100)
+ data = data:match("^%s*(.-)%s*$") // trim
+ print("["..peer_addr.."] "..data)
+ if data == "bye" then
+ stream:write("bye bye\n")
+ stream:close()
+ return
end
- end)
+ stream:write("echo: "..data.."\n")
+ end
end
})
- .into_function()?;
+ .eval::<Function>()
+ .expect("cannot create Lua handler");
+
+ // Store it in the Registry
+ let handler = lua
+ .create_registry_value(handler_fn)
+ .expect("cannot store Lua handler");
task::LocalSet::new()
- .run_until(server.call_async::<_, ()>("0.0.0.0:1234"))
+ .run_until(run_server(lua, handler))
.await
+ .expect("cannot run server")
}
-#[tokio::main]
-async fn main() {
- let lua = Lua::new().into_static();
-
- run_server(lua).await.unwrap();
-
- // Consume the static reference and drop it.
- // This is safe as long as we don't hold any other references to Lua
- // or alive resources.
- unsafe { Lua::from_static(lua) };
+fn is_transient_error(e: &io::Error) -> bool {
+ e.kind() == io::ErrorKind::ConnectionRefused
+ || e.kind() == io::ErrorKind::ConnectionAborted
+ || e.kind() == io::ErrorKind::ConnectionReset
}