Skip to main content
Axum provides WebSocket support through the `axum::extract::ws` module (enabled by the `ws` Cargo feature). Its `WebSocketUpgrade` extractor enables real-time bidirectional communication between clients and servers.

Basic WebSocket setup

1

Extract the WebSocket upgrade

Use the WebSocketUpgrade extractor in your handler to initiate the WebSocket handshake:
use axum::{
    extract::ws::{WebSocketUpgrade, WebSocket},
    response::Response,
    routing::any,
    Router,
};

/// Accepts the WebSocket handshake and passes the upgraded socket to `handle_socket`.
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    // The returned response finishes the handshake; `handle_socket` runs after the upgrade.
    ws.on_upgrade(handle_socket)
}

let app = Router::new().route(
    "/ws",
    any(ws_handler), // matches both GET (HTTP/1.1) and CONNECT (HTTP/2+) upgrades
);
Register the route with any() rather than get(). HTTP/1.1 clients start the WebSocket handshake with a GET request, while HTTP/2 and later use an extended CONNECT request; any() accepts both.
2

Handle the WebSocket connection

Implement the socket handler to send and receive messages:
/// Echoes every message back to the client until the client disconnects.
async fn handle_socket(mut socket: WebSocket) {
    while let Some(received) = socket.recv().await {
        // A receive error means the client disconnected.
        let Ok(message) = received else {
            return;
        };

        // A failed send likewise means the client is gone.
        if socket.send(message).await.is_err() {
            return;
        }
    }
}
3

Start the server

Run your application with WebSocket support:
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
    .await
    .unwrap();
// `serve` resolves to an `io::Result`; unwrap it so server errors are not silently dropped.
axum::serve(listener, app).await.unwrap();

Extracting connection metadata

You can extract HTTP headers, client IP addresses, and other metadata during the upgrade:
use axum::{
    extract::{
        ws::{WebSocketUpgrade, WebSocket},
        ConnectInfo,
    },
    response::IntoResponse,
};
use axum_extra::TypedHeader;
use std::net::SocketAddr;

/// Logs the client's `User-Agent` and socket address, then upgrades the connection.
async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    // The `User-Agent` header is optional; use a placeholder when it is absent.
    let user_agent = match user_agent {
        Some(TypedHeader(agent)) => agent.to_string(),
        None => String::from("Unknown browser"),
    };

    println!("`{user_agent}` at {addr} connected.");
    // The `move` closure gives the socket handler its own copy of the client address.
    ws.on_upgrade(move |socket| handle_socket(socket, addr))
}

/// Per-connection handler; `who` is the client address captured during the upgrade.
async fn handle_socket(socket: WebSocket, who: SocketAddr) {
    // Handle the connection with access to client info
    // NOTE: placeholder body — `socket` and `who` are unused, so rustc warns until filled in.
}
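ConnectInfo is only available when the app is served with into_make_service_with_connect_info::<SocketAddr>(), e.g. `axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())`. Without it, the ConnectInfo extractor rejects the request.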

Message types

WebSocket messages are represented by the Message enum, which has Text, Binary, Ping, Pong, and Close variants:
use axum::extract::ws::Message;

// Send text
let msg = Message::Text("Hello, client!".into());
socket.send(msg).await.unwrap();

// Receive text
if let Message::Text(text) = msg {
    println!("Received: {text}");
}

Concurrent send and receive

Split the socket into a sender (sink) and a receiver (stream) so that sending and receiving can run concurrently in separate tasks:
use futures_util::{sink::SinkExt, stream::StreamExt};

/// Sends ten messages to the client while concurrently logging whatever it sends.
/// Both directions stop as soon as either one finishes.
async fn handle_socket(socket: WebSocket) {
    // `split` (from `StreamExt`) yields a sink and a stream that can move into separate tasks.
    let (mut sender, mut receiver) = socket.split();

    // Spawn a task that sends ten messages, 300 ms apart.
    let mut send_task = tokio::spawn(async move {
        for i in 0..10 {
            if sender
                .send(Message::Text(format!("Message {i}").into()))
                .await
                .is_err()
            {
                // The client disconnected.
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        }
    });

    // Spawn a second task that receives messages until the stream ends or yields an error.
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            // Process incoming messages
            println!("Received: {msg:?}");
        }
    });

    // Wait for whichever task finishes first, then abort the other so neither
    // outlives this handler.
    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    }
}

WebSocket protocols

Negotiate sub-protocols with clients:
/// Upgrades the connection, offering the GraphQL WebSocket sub-protocols in order of preference.
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    let offered = ["graphql-ws", "graphql-transport-ws"];
    ws.protocols(offered).on_upgrade(|socket| async move {
        // `protocol()` is `None` when the client requested none of the offered protocols.
        if let Some(selected) = socket.protocol() {
            let name = selected.to_str().unwrap();
            println!("Using protocol: {}", name);
        }
        handle_socket(socket).await;
    })
}
You can also check what protocols the client requested:
/// Logs the sub-protocols the client requested and manually selects `custom-protocol`
/// if the client offered it.
async fn ws_handler(mut ws: WebSocketUpgrade) -> Response {
    for protocol in ws.requested_protocols() {
        // These header values come from the client and may not be visible ASCII;
        // a malformed header must not panic the handler, so don't `unwrap()`.
        match protocol.to_str() {
            Ok(name) => println!("Client requested: {}", name),
            Err(_) => println!("Client requested a non-ASCII protocol: {protocol:?}"),
        }
    }

    // RFC 6455: the server may only select a sub-protocol the client offered;
    // otherwise the client must fail the connection.
    const CUSTOM: &str = "custom-protocol";
    if ws.requested_protocols().any(|protocol| protocol == CUSTOM) {
        ws.set_selected_protocol(HeaderValue::from_static(CUSTOM));
    }
    ws.on_upgrade(handle_socket)
}

Configuration options

Customize WebSocket behavior:
/// Upgrades the connection with custom size limits (all values are in bytes).
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    const KIB: usize = 1024;
    const MIB: usize = 1024 * KIB;

    let configured = ws
        .max_message_size(10 * MIB)
        .max_frame_size(2 * MIB)
        .write_buffer_size(256 * KIB);
    configured.on_upgrade(handle_socket)
}
  • max_message_size(usize) - Maximum message size (default: 64 MB)
  • max_frame_size(usize) - Maximum frame size (default: 16 MB)
  • read_buffer_size(usize) - Read buffer capacity (default: 128 KB)
  • write_buffer_size(usize) - Write buffer size (default: 128 KB)
  • max_write_buffer_size(usize) - Maximum write buffer size (default: unlimited)
  • accept_unmasked_frames(bool) - Accept unmasked frames from clients (default: false)

Error handling

Handle upgrade failures gracefully:
/// Upgrades the connection, logging the error if the upgrade fails after the response is sent.
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    let ws = ws.on_failed_upgrade(|error| {
        eprintln!("WebSocket upgrade failed: {error}");
        // Forward `error` to a monitoring service here if you use one.
    });
    ws.on_upgrade(handle_socket)
}

Full example

Here’s a complete example server. It pings each client, streams periodic messages to it, and logs the text messages it receives:
use axum::{
    body::Bytes,
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        ConnectInfo,
    },
    response::IntoResponse,
    routing::any,
    Router,
};
use futures_util::{sink::SinkExt, stream::StreamExt};
use std::{net::SocketAddr, ops::ControlFlow};
use tower_http::trace::TraceLayer;

/// Serves the `/ws` WebSocket endpoint on 127.0.0.1:3000 with HTTP tracing enabled.
#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/ws", any(ws_handler))
        .layer(TraceLayer::new_for_http());

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    // Serving with connect info is what makes `ConnectInfo<SocketAddr>` available to
    // `ws_handler`. `serve` resolves to an `io::Result`; unwrap it so errors surface.
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await
    .unwrap();
}

/// Logs the connecting client's address and upgrades the connection.
async fn ws_handler(
    ws: WebSocketUpgrade,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    println!("Client {addr} connected");
    // The closure captures its own copy of `addr` for the per-connection task.
    let on_socket = move |socket| handle_socket(socket, addr);
    ws.on_upgrade(on_socket)
}

/// Drives one client connection. It pings the client, then concurrently sends it ten
/// messages and logs the text messages it sends, until either direction stops.
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
    // Send a ping to kick things off; a failed send means the client is already gone.
    if socket
        .send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
        .await
        .is_err()
    {
        return;
    }
    println!("Pinged {who}...");

    // Split the socket so sending and receiving can run in separate tasks.
    let (mut sender, mut receiver) = socket.split();

    // Send one message per second, ten in total.
    let mut send_task = tokio::spawn(async move {
        for i in 0..10 {
            if sender
                .send(Message::Text(format!("Server message {i}").into()))
                .await
                .is_err()
            {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });

    // Log incoming text messages (they are not echoed back) until the client sends a
    // close frame, the stream ends, or a receive error occurs.
    let mut recv_task = tokio::spawn(async move {
        while let Some(result) = receiver.next().await {
            match result {
                Ok(Message::Text(text)) => println!("Received from {who}: {text}"),
                Ok(Message::Close(_)) => break,
                Ok(_) => {}
                Err(err) => {
                    // Report the error instead of stopping silently.
                    eprintln!("Error receiving from {who}: {err}");
                    break;
                }
            }
        }
    });

    // Wait for either task to complete, then abort the other.
    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    }

    println!("WebSocket connection {who} closed");
}
Always handle errors when sending or receiving messages, as clients can disconnect at any time without notice.