diff options
Diffstat (limited to 'examples/async_http_server.rs')
-rw-r--r-- | examples/async_http_server.rs | 147 |
1 files changed, 92 insertions, 55 deletions
diff --git a/examples/async_http_server.rs b/examples/async_http_server.rs index 17d60b0..43ae7a9 100644 --- a/examples/async_http_server.rs +++ b/examples/async_http_server.rs @@ -1,10 +1,16 @@ +use std::future::Future; use std::net::SocketAddr; +use std::pin::Pin; +use std::rc::Rc; +use std::task::{Context, Poll}; use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; +use hyper::service::Service; use hyper::{Body, Request, Response, Server}; -use mlua::{Error, Function, Lua, Result, Table, UserData, UserDataMethods}; +use mlua::{ + chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods, +}; struct LuaRequest(SocketAddr, Request<Body>); @@ -15,75 +21,106 @@ impl UserData for LuaRequest { } } -async fn run_server(handler: Function<'static>) -> Result<()> { - let make_svc = make_service_fn(|socket: &AddrStream| { - let remote_addr = socket.remote_addr(); - let handler = handler.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request<Body>| { - let handler = handler.clone(); - async move { - let lua_req = LuaRequest(remote_addr, req); - let lua_resp: Table = handler.call_async(lua_req).await?; - let body = lua_resp - .get::<_, Option<String>>("body")? - .unwrap_or_default(); +pub struct Svc(Rc<Lua>, SocketAddr); + +impl Service<Request<Body>> for Svc { + type Response = Response<Body>; + type Error = LuaError; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>; - let mut resp = Response::builder() - .status(lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200)); + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: Request<Body>) -> Self::Future { + // If handler returns an error then generate 5xx response + let lua = self.0.clone(); + let lua_req = LuaRequest(self.1, req); + Box::pin(async move { + let handler: Function = lua.named_registry_value("http_handler")?; + match handler.call_async::<_, Table>(lua_req).await { + Ok(lua_resp) => { + let status = lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200); + let mut resp = Response::builder().status(status); + + // Set headers if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? { - for pair in headers.pairs::<String, String>() { + for pair in headers.pairs::<String, LuaString>() { let (h, v) = pair?; - resp = resp.header(&h, v); + resp = resp.header(&h, v.as_bytes()); } } - Ok::<_, Error>(resp.body(Body::from(body)).unwrap()) + let body = lua_resp + .get::<_, Option<LuaString>>("body")? + .map(|b| Body::from(b.as_bytes().to_vec())) + .unwrap_or_else(Body::empty); + + Ok(resp.body(body).unwrap()) + } + Err(err) => { + eprintln!("{}", err); + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + } + } + }) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + let lua = Rc::new(Lua::new()); + + // Create Lua handler function + let handler: Function = lua + .load(chunk! { + function(req) + return { + status = 200, + headers = { + ["X-Req-Method"] = req:method(), + ["X-Remote-Addr"] = req:remote_addr(), + }, + body = "Hello from Lua!\n" } - })) - } - }); + end + }) + .eval() + .expect("cannot create Lua handler"); + + // Store it in the Registry + lua.set_named_registry_value("http_handler", handler) + .expect("cannot store Lua handler"); let addr = ([127, 0, 0, 1], 3000).into(); - let server = Server::bind(&addr).executor(LocalExec).serve(make_svc); + let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua)); println!("Listening on http://{}", addr); - tokio::task::LocalSet::new() - .run_until(server) - .await - .map_err(Error::external) + // Create `LocalSet` to spawn !Send futures + let local = tokio::task::LocalSet::new(); + local.run_until(server).await.expect("cannot run server") } -#[tokio::main] -async fn main() -> Result<()> { - let lua = Lua::new().into_static(); +struct MakeSvc(Rc<Lua>); - let handler: Function = lua - .load( - r#" - function(req) - return { - status = 200, - headers = { - ["X-Req-Method"] = req:method(), - ["X-Remote-Addr"] = req:remote_addr(), - }, - body = "Hello, World!\n" - } - end - "#, - ) - .eval()?; - - run_server(handler).await?; - - // Consume the static reference and drop it. - // This is safe as long as we don't hold any other references to Lua - // or alive resources. - unsafe { Lua::from_static(lua) }; - Ok(()) +impl Service<&AddrStream> for MakeSvc { + type Response = Svc; + type Error = hyper::Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, stream: &AddrStream) -> Self::Future { + let lua = self.0.clone(); + let remote_addr = stream.remote_addr(); + Box::pin(async move { Ok(Svc(lua, remote_addr)) }) + } } #[derive(Clone, Copy, Debug)] |