Creating state
State must implement `Clone` to be shared across handlers:
use axum::{Router, routing::get, extract::State};
/// Application state handed to the router with `with_state`.
///
/// Every field must be `Clone`, because the state is cloned whenever a
/// handler extracts it. `request_count` is a plain `u64`, so each clone gets
/// its own copy and a change made through one clone is not seen by others.
#[derive(Clone)]
struct AppState {
    /// Key that the handler echoes back.
    api_key: String,
    /// Starts at zero. It is copied, not shared, between clones (see above).
    request_count: u64,
}
async fn handler(State(state): State<AppState>) -> String {
format!("API Key: {}", state.api_key)
}
// Build the state value once; `with_state` takes it by value.
let state = AppState {
    api_key: String::from("secret"),
    request_count: 0,
};
let app = Router::new().route("/", get(handler)).with_state(state);
Using with_state
Provide state to your router using the `with_state` method:
Define your state type
Create a struct that implements `Clone`:
#[derive(Clone)]
/// State shared by the handlers in this section.
struct AppState {
    /// Database connection pool (`DbPool` is defined elsewhere).
    db_pool: DbPool,
    /// Application configuration; the handler below reads `config.db_url`.
    config: Config,
}
Create handlers that use State
Use the `State` extractor in your handlers:
use axum::extract::State;
/// Reports which database URL the configuration points at.
async fn handler(State(state): State<AppState>) -> String {
    // Read whatever the handler needs from the shared state.
    let db_url = &state.config.db_url;
    format!("Connected to: {db_url}")
}
Shared mutable state
Since state is cloned for each request, wrap mutable data in `Arc<Mutex<_>>` or `Arc<RwLock<_>>` so that every clone shares the same underlying value:
- Using Arc<Mutex<_>>
- Using Arc<RwLock<_>>
- Using Tokio Mutex
use axum::{Router, routing::get, extract::State};
use std::sync::{Arc, Mutex};
/// State holding a counter shared by every clone.
///
/// Cloning `AppState` clones the `Arc`, not the `Mutex` inside it, so all
/// clones read and update the same `u64`.
#[derive(Clone)]
struct AppState {
    counter: Arc<Mutex<u64>>,
}
/// Increments the shared counter and reports its new value.
///
/// The `std::sync::Mutex` guard is confined to the inner block and never
/// lives across an `.await`.
async fn increment(State(state): State<AppState>) -> String {
    let value = {
        let mut guard = state.counter.lock().unwrap();
        *guard += 1;
        *guard
    };
    format!("Counter: {value}")
}
// Start the shared counter at zero; every clone of `state` points at it.
let counter = Arc::new(Mutex::new(0_u64));
let state = AppState { counter };
let app = Router::new().route("/increment", get(increment)).with_state(state);
Don’t hold a `std::sync::Mutex` guard across `.await` points. If the lock must stay held while awaiting, use `tokio::sync::Mutex` instead.
use axum::{Router, routing::get, extract::State};
use std::sync::{Arc, RwLock};
use std::collections::HashMap;
/// State holding a string map behind a reader/writer lock.
///
/// Clones share the same map through the `Arc`.
#[derive(Clone)]
struct AppState {
    data: Arc<RwLock<HashMap<String, String>>>,
}
async fn read_data(State(state): State<AppState>) -> String {
let data = state.data.read().unwrap();
format!("Data: {:?}", *data)
}
/// Inserts the fixed pair `"key" -> "value"`, overwriting any previous value.
///
/// The write guard is a temporary and is released at the end of the statement.
async fn write_data(State(state): State<AppState>) {
    state
        .data
        .write()
        .unwrap()
        .insert(String::from("key"), String::from("value"));
}
let state = AppState {
    data: Arc::new(RwLock::new(HashMap::new())),
};
// `write_data` mutates the shared map, so it is routed on POST. GET must be a
// safe, read-only method: crawlers and link prefetchers may issue GET
// requests on their own. `post` is referenced by full path because this
// snippet's `use` line only imports `get`.
let app = Router::new()
    .route("/read", get(read_data))
    .route("/write", axum::routing::post(write_data))
    .with_state(state);
use axum::{Router, routing::get, extract::State};
use std::sync::Arc;
use tokio::sync::Mutex;
/// State whose counter is guarded by the `Mutex` imported above from
/// `tokio::sync`. Its guard may be held across `.await` points.
#[derive(Clone)]
struct AppState {
    counter: Arc<Mutex<u64>>,
}
/// Increments the counter, pauses for 10 ms while still holding the lock,
/// then reports the new value.
async fn increment(State(state): State<AppState>) -> String {
    // `tokio::sync::Mutex::lock` is awaited, and its guard may stay alive
    // across later `.await` points.
    let mut guard = state.counter.lock().await;
    *guard += 1;
    // The guard is still held during this sleep, so other callers of
    // `lock()` wait until this handler finishes.
    let pause = std::time::Duration::from_millis(10);
    tokio::time::sleep(pause).await;
    let current = *guard;
    format!("Counter: {current}")
}
// `Mutex` here is `tokio::sync::Mutex`; the counter starts at zero.
let counter = Arc::new(Mutex::new(0_u64));
let state = AppState { counter };
let app = Router::new().route("/increment", get(increment)).with_state(state);
Real-world example
A complete example with a database connection pool:
use axum::{
Router,
routing::{get, post},
extract::State,
http::StatusCode,
};
use std::sync::Arc;
/// Simplified stand-in for a database connection pool.
///
/// In a real application this would be `sqlx::PgPool` or similar. This stub
/// carries no data.
struct DbPool {}
impl DbPool {
    /// Hands out a connection. No pooling happens in this stub: every call
    /// returns a new, empty `DbConnection`.
    fn get_connection(&self) -> DbConnection {
        DbConnection {}
    }
}
/// Stub connection returned by `DbPool::get_connection`; it holds no state.
struct DbConnection {}
impl DbConnection {
    /// Stubbed query: ignores the SQL text and always yields a single row,
    /// `"result"`.
    async fn query(&self, _sql: &str) -> Vec<String> {
        std::iter::once(String::from("result")).collect()
    }
}
/// State for the database example.
///
/// `DbPool` is not `Clone` in this example, so it sits behind an `Arc`.
/// Cloning `AppState` only bumps the pool's reference count.
#[derive(Clone)]
struct AppState {
    db: Arc<DbPool>,
    /// Not read by the handlers shown in this example.
    api_key: String,
}
/// Lists users by running a query on a connection taken from the pool.
///
/// The stub never fails, so this always returns `Ok`.
async fn list_users(State(state): State<AppState>) -> Result<String, StatusCode> {
    let connection = state.db.get_connection();
    let rows = connection.query("SELECT * FROM users").await;
    let body = format!("Users: {rows:?}");
    Ok(body)
}
/// Creates a user and responds with `201 Created`.
///
/// The insert itself is still a placeholder. The connection is bound to
/// `_conn` so the intent stays visible without an unused-variable warning.
async fn create_user(State(state): State<AppState>) -> StatusCode {
    let _conn = state.db.get_connection();
    // Insert user...
    StatusCode::CREATED
}
// The pool is created once and shared by every clone of the state.
let pool = Arc::new(DbPool {});
let state = AppState {
    db: pool,
    api_key: String::from("secret"),
};
let app = Router::new()
    .route("/users", get(list_users).post(create_user))
    .with_state(state);
Combining routers with state
When combining routers (for example with `merge`), they must have the same state type:
use axum::{Router, routing::get, extract::State};
/// State shared by both the user and post routers.
#[derive(Clone)]
struct AppState {
    /// Configuration label echoed by the handlers.
    config: String,
}
/// Routes for user resources.
///
/// The return type is `Router<AppState>`, the same as `post_routes`, so the
/// two can be merged.
fn user_routes() -> Router<AppState> {
    Router::<AppState>::new().route("/users", get(list_users))
}
/// Routes for post resources, using the same `AppState` as `user_routes`.
fn post_routes() -> Router<AppState> {
    Router::<AppState>::new().route("/posts", get(list_posts))
}
async fn list_users(State(state): State<AppState>) -> String {
format!("Config: {}", state.config)
}
/// Echoes the configuration label; same response shape as `list_users`.
async fn list_posts(State(state): State<AppState>) -> String {
    let config = &state.config;
    format!("Config: {config}")
}
let state = AppState {
    config: String::from("production"),
};
// Merge the feature routers first, then supply the shared state once.
let app = Router::new()
    .merge(user_routes())
    .merge(post_routes())
    .with_state(state);
Substates with FromRef
Extract parts of your state using the `FromRef` trait:
use axum::{
Router,
routing::get,
extract::{State, FromRef},
};
/// Main application state, composed of two independent sub-states.
///
/// The `FromRef` impls in this example let handlers extract either part on
/// its own.
#[derive(Clone)]
struct AppState {
    api_state: ApiState,
    db_state: DbState,
}
/// API-specific slice of the application state.
#[derive(Clone)]
struct ApiState {
    api_key: String,
}
/// Database-specific slice of the application state.
#[derive(Clone)]
struct DbState {
    connection_string: String,
}
// Allow extracting ApiState from AppState
impl FromRef<AppState> for ApiState {
fn from_ref(app_state: &AppState) -> ApiState {
app_state.api_state.clone()
}
}
// Allow extracting DbState from AppState
impl FromRef<AppState> for DbState {
fn from_ref(app_state: &AppState) -> DbState {
app_state.db_state.clone()
}
}
// Handler only needs ApiState
async fn api_handler(State(api): State<ApiState>) -> String {
format!("API Key: {}", api.api_key)
}
/// Handler that needs only the database sub-state.
async fn db_handler(State(db): State<DbState>) -> String {
    let target = &db.connection_string;
    format!("DB: {target}")
}
// Handler needs full state
async fn full_handler(State(state): State<AppState>) -> String {
format!("API: {}, DB: {}",
state.api_state.api_key,
state.db_state.connection_string
)
}
// Build each sub-state separately, then assemble the full state.
let api_state = ApiState {
    api_key: String::from("secret"),
};
let db_state = DbState {
    connection_string: String::from("postgres://..."),
};
let state = AppState { api_state, db_state };
let app = Router::new()
    .route("/api", get(api_handler))
    .route("/db", get(db_handler))
    .route("/full", get(full_handler))
    .with_state(state);
You can also use `#[derive(FromRef)]` (requires axum's `macros` feature) to implement `FromRef` automatically for each field:
use axum::extract::FromRef;
/// Same shape as the hand-written version above. Here `#[derive(FromRef)]`
/// generates a `FromRef<AppState>` impl for each field's type.
#[derive(Clone, FromRef)]
struct AppState {
    api_state: ApiState,
    db_state: DbState,
}
State in middleware
Access state in middleware using `from_fn_with_state`:
use axum::{
Router,
routing::get,
extract::{State, Request},
middleware::{self, Next},
response::Response,
http::StatusCode,
};
/// State consulted by the authentication middleware.
#[derive(Clone)]
struct AppState {
    /// Expected value of the `api-key` request header.
    api_key: String,
}
/// Compares two byte strings without stopping at the first mismatch.
///
/// The loop always visits every byte, so the running time is intended to
/// depend only on the input lengths, not on how many leading bytes match.
/// This is best-effort; production code may prefer a vetted constant-time
/// primitive.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}

/// Rejects requests whose `api-key` header does not match `state.api_key`.
///
/// Returns `401 Unauthorized` when the header is missing or wrong. Otherwise
/// the request is passed on to `next`.
async fn auth_middleware(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Use state to validate the request. The header is attacker-controlled,
    // so the comparison does not short-circuit on the first differing byte.
    // That way response timing does not reveal how much of a guessed key
    // was correct.
    let authorized = match request.headers().get("api-key") {
        Some(value) => constant_time_eq(value.as_bytes(), state.api_key.as_bytes()),
        None => false,
    };
    if !authorized {
        return Err(StatusCode::UNAUTHORIZED);
    }
    Ok(next.run(request).await)
}
let state = AppState {
    api_key: "secret".to_owned(),
};
// The middleware gets its own clone of the state; `with_state` then consumes
// the original for the router. The layer wraps the route registered before it.
let app = Router::new()
    .route("/protected", get(|| async { "Protected" }))
    .layer(middleware::from_fn_with_state(state.clone(), auth_middleware))
    .with_state(state);
Next steps
Middleware
Learn how to use middleware with state
Extractors
See all available extractors including State