summaryrefslogtreecommitdiff
path: root/examples/async_http_server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'examples/async_http_server.rs')
-rw-r--r--examples/async_http_server.rs147
1 files changed, 92 insertions, 55 deletions
diff --git a/examples/async_http_server.rs b/examples/async_http_server.rs
index 17d60b0..43ae7a9 100644
--- a/examples/async_http_server.rs
+++ b/examples/async_http_server.rs
@@ -1,10 +1,16 @@
+use std::future::Future;
use std::net::SocketAddr;
+use std::pin::Pin;
+use std::rc::Rc;
+use std::task::{Context, Poll};
use hyper::server::conn::AddrStream;
-use hyper::service::{make_service_fn, service_fn};
+use hyper::service::Service;
use hyper::{Body, Request, Response, Server};
-use mlua::{Error, Function, Lua, Result, Table, UserData, UserDataMethods};
+use mlua::{
+ chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods,
+};
struct LuaRequest(SocketAddr, Request<Body>);
@@ -15,75 +21,106 @@ impl UserData for LuaRequest {
}
}
-async fn run_server(handler: Function<'static>) -> Result<()> {
- let make_svc = make_service_fn(|socket: &AddrStream| {
- let remote_addr = socket.remote_addr();
- let handler = handler.clone();
- async move {
- Ok::<_, Error>(service_fn(move |req: Request<Body>| {
- let handler = handler.clone();
- async move {
- let lua_req = LuaRequest(remote_addr, req);
- let lua_resp: Table = handler.call_async(lua_req).await?;
- let body = lua_resp
- .get::<_, Option<String>>("body")?
- .unwrap_or_default();
+pub struct Svc(Rc<Lua>, SocketAddr);
+
+impl Service<Request<Body>> for Svc {
+ type Response = Response<Body>;
+ type Error = LuaError;
+ type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
- let mut resp = Response::builder()
- .status(lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200));
+ fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+ fn call(&mut self, req: Request<Body>) -> Self::Future {
+ // If handler returns an error then generate 5xx response
+ let lua = self.0.clone();
+ let lua_req = LuaRequest(self.1, req);
+ Box::pin(async move {
+ let handler: Function = lua.named_registry_value("http_handler")?;
+ match handler.call_async::<_, Table>(lua_req).await {
+ Ok(lua_resp) => {
+ let status = lua_resp.get::<_, Option<u16>>("status")?.unwrap_or(200);
+ let mut resp = Response::builder().status(status);
+
+ // Set headers
if let Some(headers) = lua_resp.get::<_, Option<Table>>("headers")? {
- for pair in headers.pairs::<String, String>() {
+ for pair in headers.pairs::<String, LuaString>() {
let (h, v) = pair?;
- resp = resp.header(&h, v);
+ resp = resp.header(&h, v.as_bytes());
}
}
- Ok::<_, Error>(resp.body(Body::from(body)).unwrap())
+ let body = lua_resp
+ .get::<_, Option<LuaString>>("body")?
+ .map(|b| Body::from(b.as_bytes().to_vec()))
+ .unwrap_or_else(Body::empty);
+
+ Ok(resp.body(body).unwrap())
+ }
+ Err(err) => {
+ eprintln!("{}", err);
+ Ok(Response::builder()
+ .status(500)
+ .body(Body::from("Internal Server Error"))
+ .unwrap())
+ }
+ }
+ })
+ }
+}
+
+#[tokio::main(flavor = "current_thread")]
+async fn main() {
+ let lua = Rc::new(Lua::new());
+
+ // Create Lua handler function
+ let handler: Function = lua
+ .load(chunk! {
+ function(req)
+ return {
+ status = 200,
+ headers = {
+ ["X-Req-Method"] = req:method(),
+ ["X-Remote-Addr"] = req:remote_addr(),
+ },
+ body = "Hello from Lua!\n"
}
- }))
- }
- });
+ end
+ })
+ .eval()
+ .expect("cannot create Lua handler");
+
+ // Store it in the Registry
+ lua.set_named_registry_value("http_handler", handler)
+ .expect("cannot store Lua handler");
let addr = ([127, 0, 0, 1], 3000).into();
- let server = Server::bind(&addr).executor(LocalExec).serve(make_svc);
+ let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua));
println!("Listening on http://{}", addr);
- tokio::task::LocalSet::new()
- .run_until(server)
- .await
- .map_err(Error::external)
+ // Create `LocalSet` to spawn !Send futures
+ let local = tokio::task::LocalSet::new();
+ local.run_until(server).await.expect("cannot run server")
}
-#[tokio::main]
-async fn main() -> Result<()> {
- let lua = Lua::new().into_static();
+struct MakeSvc(Rc<Lua>);
- let handler: Function = lua
- .load(
- r#"
- function(req)
- return {
- status = 200,
- headers = {
- ["X-Req-Method"] = req:method(),
- ["X-Remote-Addr"] = req:remote_addr(),
- },
- body = "Hello, World!\n"
- }
- end
- "#,
- )
- .eval()?;
-
- run_server(handler).await?;
-
- // Consume the static reference and drop it.
- // This is safe as long as we don't hold any other references to Lua
- // or alive resources.
- unsafe { Lua::from_static(lua) };
- Ok(())
+impl Service<&AddrStream> for MakeSvc {
+ type Response = Svc;
+ type Error = hyper::Error;
+ type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
+
+ fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn call(&mut self, stream: &AddrStream) -> Self::Future {
+ let lua = self.0.clone();
+ let remote_addr = stream.remote_addr();
+ Box::pin(async move { Ok(Svc(lua, remote_addr)) })
+ }
}
#[derive(Clone, Copy, Debug)]