Axum provides full WebSocket support through the ws extractor, enabling real-time bidirectional communication between clients and servers.
Basic WebSocket setup
Extract the WebSocket upgrade
Use the WebSocketUpgrade extractor in your handler to initiate the WebSocket handshake: use axum :: {
extract :: ws :: { WebSocketUpgrade , WebSocket },
response :: Response ,
routing :: any,
Router ,
};
/// Accepts the upgrade request and hands the resulting socket to `handle_socket`.
async fn ws_handler(upgrade: WebSocketUpgrade) -> Response {
    upgrade.on_upgrade(handle_socket)
}
// Route `/ws` with `any()` so both GET (HTTP/1.1) and CONNECT (HTTP/2+) upgrades reach the handler.
let app = Router::new()
    .route("/ws", any(ws_handler));
Use any() instead of get() to support both HTTP/1.1 (which uses GET) and HTTP/2+ (which uses CONNECT) for WebSocket upgrades.
Handle the WebSocket connection
Implement the socket handler to send and receive messages:
/// Echoes every message received on `socket` back to the client.
/// Returns when the stream ends or when a receive or send fails (client disconnected).
async fn handle_socket(mut socket: WebSocket) {
    while let Some(received) = socket.recv().await {
        let Ok(message) = received else {
            // client disconnected
            return;
        };
        if socket.send(message).await.is_err() {
            // client disconnected
            return;
        }
    }
}
Start the server
Run your application with WebSocket support:
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
    .await
    .unwrap();
// `serve` resolves to an `io::Result`; surface server errors instead of silently discarding them.
axum::serve(listener, app).await.unwrap();
You can extract HTTP headers, client IP addresses, and other metadata during the upgrade:
use axum :: {
extract :: {
ws :: { WebSocketUpgrade , WebSocket },
ConnectInfo ,
},
response :: IntoResponse ,
};
use axum_extra :: TypedHeader ;
use std :: net :: SocketAddr ;
/// Logs the client's `User-Agent` (or a placeholder when the header is absent)
/// together with its peer address.
/// Then upgrades and passes that address on to `handle_socket`.
async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    let user_agent = match user_agent {
        Some(TypedHeader(agent)) => agent.to_string(),
        None => String::from("Unknown browser"),
    };
    println!("`{user_agent}` at {addr} connected.");
    // `addr` is captured by value so the upgrade callback can outlive this handler.
    ws.on_upgrade(move |socket| handle_socket(socket, addr))
}
/// Placeholder connection handler.
/// `who` is the peer address captured from `ConnectInfo` during the upgrade.
async fn handle_socket(socket: WebSocket, who: SocketAddr) {
    // Handle the connection with access to client info.
    // Until real logic is added, explicitly consume both values so the stub
    // compiles without `unused_variables` warnings.
    let _ = (socket, who);
}
Remember to use into_make_service_with_connect_info::<SocketAddr>() when serving to enable ConnectInfo extraction. Without it, the ConnectInfo extractor has no peer address to read and rejects the request.
Message types
WebSocket messages come in several types:
Text messages
Binary messages
Ping/Pong
Close frames
use axum::extract::ws::Message;
// Send text (assumes a `socket: WebSocket` is in scope, e.g. inside a handler)
let msg = Message::Text("Hello, client!".into());
socket.send(msg).await.unwrap();
// Receive text: `send` took ownership of `msg`, so read the next frame from the socket instead
if let Some(Ok(Message::Text(text))) = socket.recv().await {
    println!("Received: {text}");
}
Concurrent send and receive
Split the socket to send and receive simultaneously:
use futures_util :: { sink :: SinkExt , stream :: StreamExt };
/// Splits `socket` and runs a sender and a receiver concurrently.
/// - The sender emits ten numbered text messages, 300 ms apart.
/// - The receiver prints every incoming message.
/// As soon as either task finishes, the other one is aborted.
async fn handle_socket(socket: WebSocket) {
    let (mut sink, mut stream) = socket.split();

    // Producer: stops early if a send fails (client disconnected).
    let mut producer = tokio::spawn(async move {
        for i in 0..10 {
            let outgoing = Message::Text(format!("Message {i}").into());
            if sink.send(outgoing).await.is_err() {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(300)).await;
        }
    });

    // Consumer: stops at end of stream or on the first receive error.
    let mut consumer = tokio::spawn(async move {
        while let Some(Ok(msg)) = stream.next().await {
            println!("Received: {msg:?}");
        }
    });

    // Whichever task completes first aborts the other.
    tokio::select! {
        _ = (&mut producer) => consumer.abort(),
        _ = (&mut consumer) => producer.abort(),
    }
}
WebSocket protocols
Negotiate sub-protocols with clients:
/// Offers the `graphql-ws` and `graphql-transport-ws` sub-protocols.
/// Logs the one that was selected (if any) before delegating to `handle_socket`.
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    let ws = ws.protocols(["graphql-ws", "graphql-transport-ws"]);
    ws.on_upgrade(|socket| async move {
        // Check which protocol was selected; `None` means no sub-protocol was chosen.
        if let Some(chosen) = socket.protocol() {
            // The value comes from the static list above, so `to_str` cannot fail here.
            let name = chosen.to_str().unwrap();
            println!("Using protocol: {}", name);
        }
        handle_socket(socket).await;
    })
}
You can also check what protocols the client requested:
async fn ws_handler ( mut ws : WebSocketUpgrade ) -> Response {
for protocol in ws . requested_protocols () {
println! ( "Client requested: {}" , protocol . to_str () . unwrap ());
}
// Manually set the selected protocol
ws . set_selected_protocol ( HeaderValue :: from_static ( "custom-protocol" ));
ws . on_upgrade ( handle_socket )
}
Configuration options
Customize WebSocket behavior:
/// Upgrades with custom limits before handing the socket to `handle_socket`:
/// - 10 MiB maximum message size
/// - 2 MiB maximum frame size
/// - 256 KiB write buffer
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    const KIB: usize = 1024;
    const MIB: usize = 1024 * KIB;

    let configured = ws
        .max_message_size(10 * MIB)
        .max_frame_size(2 * MIB)
        .write_buffer_size(256 * KIB);
    configured.on_upgrade(handle_socket)
}
Available configuration options
max_message_size(usize) - Maximum message size (default: 64 MB)
max_frame_size(usize) - Maximum frame size (default: 16 MB)
read_buffer_size(usize) - Read buffer capacity (default: 128 KB)
write_buffer_size(usize) - Write buffer size (default: 128 KB)
max_write_buffer_size(usize) - Maximum write buffer size (default: unlimited)
accept_unmasked_frames(bool) - Accept unmasked frames from clients (default: false)
Error handling
Handle upgrade failures gracefully:
/// Upgrades the connection.
/// If the upgrade itself fails, the error is written to stderr.
async fn ws_handler(ws: WebSocketUpgrade) -> Response {
    let ws = ws.on_failed_upgrade(|error| {
        eprintln!("WebSocket upgrade failed: {error}");
        // Log to monitoring service, etc.
    });
    ws.on_upgrade(handle_socket)
}
Full example
Here’s a complete example server. It pings each client, streams it a series of periodic messages, and logs the text messages it receives:
use axum :: {
body :: Bytes ,
extract :: {
ws :: { Message , WebSocket , WebSocketUpgrade },
ConnectInfo ,
},
response :: IntoResponse ,
routing :: any,
Router ,
};
use futures_util :: { sink :: SinkExt , stream :: StreamExt };
use std :: { net :: SocketAddr , ops :: ControlFlow };
use tower_http :: trace :: TraceLayer ;
/// Serves the WebSocket endpoint at `/ws` on 127.0.0.1:3000 with HTTP tracing enabled.
#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/ws", any(ws_handler))
        .layer(TraceLayer::new_for_http());

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    // `into_make_service_with_connect_info` is what lets `ws_handler` extract
    // `ConnectInfo<SocketAddr>`. `serve` resolves to an `io::Result`; surface
    // server errors instead of silently discarding them.
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await
    .unwrap();
}
async fn ws_handler (
ws : WebSocketUpgrade ,
ConnectInfo ( addr ) : ConnectInfo < SocketAddr >,
) -> impl IntoResponse {
println! ( "Client {addr} connected" );
ws . on_upgrade ( move | socket | handle_socket ( socket , addr ))
}
/// Drives one client connection:
/// - Sends an initial ping.
/// - Then, concurrently, streams ten text messages (one per second) and logs
///   incoming text frames.
/// Both halves are torn down as soon as either one finishes.
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
    // Send a ping to kick things off; a failed send means the client is already gone.
    if socket
        .send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
        .await
        .is_err()
    {
        return;
    }
    println!("Pinged {who}...");

    // Split the socket so sending and receiving can run concurrently.
    let (mut sender, mut receiver) = socket.split();

    // Spawn a task to send periodic messages; it stops early if a send fails.
    let mut send_task = tokio::spawn(async move {
        for i in 0..10 {
            if sender
                .send(Message::Text(format!("Server message {i}").into()))
                .await
                .is_err()
            {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });

    // Receive and log incoming messages. Nothing is echoed back, because
    // `sender` now belongs to `send_task`. The loop stops on a Close frame,
    // a receive error, or end of stream.
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = receiver.next().await {
            match msg {
                Message::Text(text) => println!("Received from {who}: {text}"),
                Message::Close(_) => break,
                _ => {}
            }
        }
    });

    // Whichever task completes first aborts the other.
    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    }

    println!("WebSocket connection {who} closed");
}
Always handle errors when sending or receiving messages, as clients can disconnect at any time without notice.