Dec 02, 2024
3 min read
Rust,
Axum,

如何实现 rust axum 框架的中间件

axum 的独特之处在于它没有自己的定制中间件系统,而是与tower集成。这意味着tower的生态系统和 tower-http中间件都与 axum 一起工作。

axum 是一个 Rust Web 框架,它使用 tower 作为其内部抽象。因此,它没有自己的定制中间件系统。相反,它与 tower 集成。这意味着 tower 的生态系统和 tower-http 中间件都与 axum 一起工作。

虽然不需要完全理解 tower 来编写或使用 axum 中间件,但建议至少对 tower 的概念有基本的了解。有关一般介绍,请参阅tower的指南。还建议阅读tower::ServiceBuilder的文档。

如何实现 axum 中间件

事实上,axum 提供了多种编写中间件的方法,这些方法具有不同的抽象级别,并且具有不同的优缺点。简单来说,抽象级别越低,实现起来越复杂。

它有一些限制:

  1. 这样编写的中间件仅与 axum 兼容

axum::middleware::from_fnaxum::middleware::from_extractor_with_state

axum::middleware::from_fn 是最简单的一种方式,它接受一个函数,该函数接受一个 Request 和一个 Next,并返回一个 Response。它的签名看起来像这样:


async fn custom_middleware(
    request: Request,
    next: Next,
) -> Response {
    let response = next.run(request).await;
    response
}

事实上,它可以像Handler 一样使用提取器(Extractors)依赖注入,返回也可以是 Result,例如:

async fn my_middleware(
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode>  {
    let response = next.run(request).await;
    Ok(response)
}

axum::middleware::from_fn_with_state 功能基本和 axum::middleware::from_fn 一样,只是多了一个 State 参数,这很方便我们把 State 注入进来,实现更多功能的中间件(如认证、日志等)。

应用中间件的区别如下:


//axum::middleware::from_fn
let app = Router::new()
    .route("/", get(index))
    .layer(middleware::from_fn(custom_middleware));

//axum::middleware::from_fn_with_state
let app = Router::new()
    .route("/", get(index))
    .route_layer(middleware::from_fn_with_state(state.clone(), custom_middleware))
    .with_state(state);

数据传递

有一个场景是,我们需要在中间件中获取一些数据,再在后续的流程中(特别是响应处理阶段)使用。以从请求头中获取用户身份为例,我们先定义一个结构体:

#[derive(Clone)]
pub struct CurrentUser{
    pub user_id: String,
    pub email: String,
}

再定义一个中间件:

pub(crate) async fn auth(
    State(ref state): State<AppState>,
    mut req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // 获取请求头中的api_key,在 authorization 中
    let api_key = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok())
        .ok_or(StatusCode::UNAUTHORIZED)?
        .strip_prefix("Bearer ")
        .ok_or(StatusCode::UNAUTHORIZED)?.trim();
    
    let (user_id,email) = get_user_info(api_key).await?; /* */;

    let current_user = CurrentUser {
        user_id,
        email,
    };

    req.extensions_mut().insert(current_user); //current_user = 1

    Ok(next.run(req).await)
}

handler 中使用:


pub(crate) async fn index(
    Extension(current_user): Extension<CurrentUserFromApiKey>,
    State(ref db): State<DatabaseConnection>,
    State(ref config): State<Settings>,
) -> Result<impl IntoResponse> {

    // current_user = 1

    not_implemented!()
}

axum::middleware::from_extractoraxum::middleware::from_extractor_with_state

from_extractor 系列和 from_fn 类似,但是 from_extractorfrom_fn 不同的地方在于,你有一个类型,有时想用作为提取器使用,有时又想作为中间件使用。


struct RequireAuth;

#[async_trait]
impl<S> FromRequestParts<S> for RequireAuth
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok());

        match auth_header {
            Some(auth_header) if token_is_valid(auth_header) => {
                Ok(Self)
            }
            _ => Err(StatusCode::UNAUTHORIZED),
        }
    }
}

使用方式:

async fn index() {
    //todo
}
let app = Router::new()
    .route("/", get(index))
    .route_layer(from_extractor::<RequireAuth>());

axum::middleware::from_extractor_with_state 功能和和 axum::middleware::from_extractor 一样,只是多了一个 State 参数,这很方便我们把 State 注入进来,实现更多功能的中间件(如认证、日志等)。

tower’s combinators

tower 有几个实用程序组合器,可用于对请求或响应执行简单的修改。最常用的是

ServiceBuilder::map_request ServiceBuilder::map_response ServiceBuilder::then ServiceBuilder::and_then

不过这些组合器并不能满足所有需求,只支持简单的修改。

tower::ServicePin<Box<dyn Future>>

这是一个更低级别的 API,它允许你最大程度地控制请求处理。通常情况下,tower::ServicePin<Box<dyn Future>> 适合以下情形:

  1. 你的中间件需要可配置
  2. 你想中间件作为 crates 发布以供其他人使用。
  3. 你对实现自己的未来感到不舒服。

下面是使用官网 tower::Service 的模板例子:

use axum::{
    response::Response,
    body::Body,
    extract::Request,
};
use futures_util::future::BoxFuture;
use tower::{Service, Layer};
use std::task::{Context, Poll};

#[derive(Clone)]
struct MyLayer;

impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        MyMiddleware { inner }
    }
}

#[derive(Clone)]
struct MyMiddleware<S> {
    inner: S,
}

impl<S> Service<Request> for MyMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let future = self.inner.call(request);
        Box::pin(async move {
            let response: Response = future.await?;
            Ok(response)
        })
    }
}

应用 中间件:


let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
        .layer(HandleErrorLayer::new(|_: BoxError| async {
            StatusCode::BAD_REQUEST
        }))
    )

总结

如果只是简单的需求,最好使用axum::middleware::from_fnaxum::middleware::from_extractor_with_state 来实现。

如果你需要更复杂的需求,比如中间件需要配置,或者需要处理 futures,那么使用 tower::Service 是更好的选择。