From 5465f4a5e8a50ad988a6c3cb4439d58e763df805 Mon Sep 17 00:00:00 2001 From: Michael Lazear Date: Wed, 2 Mar 2022 14:05:05 -0800 Subject: [PATCH] Call `handler.poll_ready()` before `handler.call()` According to the tower::Service documentation and API contract, `poll_ready` must be called and a `Poll:Ready` must be obtained before invoking `call` --- lambda-runtime/src/lib.rs | 73 +++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index a5be8fd1..a178fa3b 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -13,8 +13,8 @@ use serde::{Deserialize, Serialize}; use std::{convert::TryFrom, env, fmt, future::Future, panic}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::{Stream, StreamExt}; -use tower::util::ServiceFn; pub use tower::{self, service_fn, Service}; +use tower::{util::ServiceFn, ServiceExt}; use tracing::{error, trace}; mod requests; @@ -112,41 +112,56 @@ where env::set_var("_X_AMZN_TRACE_ID", xray_trace_id); let request_id = &ctx.request_id.clone(); - let task = panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(LambdaEvent::new(body, ctx)))); - - let req = match task { - Ok(response) => match response.await { - Ok(response) => { - trace!("Ok response from handler (run loop)"); - EventCompletionRequest { - request_id, - body: response, - } - .into_req() - } - Err(err) => { - error!("{:?}", err); // logs the error in CloudWatch - EventErrorRequest { - request_id, - diagnostic: Diagnostic { - error_type: type_name_of_val(&err).to_owned(), - error_message: format!("{}", err), // returns the error to the caller via Lambda API - }, + let req = match handler.ready().await { + Ok(handler) => { + let task = + panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(LambdaEvent::new(body, ctx)))); + match task { + Ok(response) => match response.await { + Ok(response) => { + trace!("Ok response from handler (run loop)"); + EventCompletionRequest { + request_id, + body: response, + } + .into_req() + } + Err(err) => { + error!("{:?}", err); // logs the error in CloudWatch + EventErrorRequest { + request_id, + diagnostic: Diagnostic { + error_type: type_name_of_val(&err).to_owned(), + error_message: format!("{}", err), // returns the error to the caller via Lambda API + }, + } + .into_req() + } + }, + Err(err) => { + error!("{:?}", err); + EventErrorRequest { + request_id, + diagnostic: Diagnostic { + error_type: type_name_of_val(&err).to_owned(), + error_message: if let Some(msg) = err.downcast_ref::<&str>() { + format!("Lambda panicked: {}", msg) + } else { + "Lambda panicked".to_string() + }, + }, + } + .into_req() } - .into_req() } - }, + } Err(err) => { - error!("{:?}", err); + error!("{:?}", err); // logs the error in CloudWatch EventErrorRequest { request_id, diagnostic: Diagnostic { error_type: type_name_of_val(&err).to_owned(), - error_message: if let Some(msg) = err.downcast_ref::<&str>() { - format!("Lambda panicked: {}", msg) - } else { - "Lambda panicked".to_string() - }, + error_message: format!("{}", err), // returns the error to the caller via Lambda API }, } .into_req()