Protohackers in Rust: Budget Chat
This is the fourth article in the series solving Protohackers with Rust. The code in this series will be available on GitHub.
# The Challenge
In the fourth challenge we have to implement a simple chat server. Each message is a single line terminated by a newline character (`\n`).
When a client connects, the server should send a welcome message asking for the client's username. Once the client responds with a valid username, they have officially joined the chat room.
When a new user joins, the server must:
- Announce the presence of the new user to the other members of the chat.
- Send the list of all connected users to the joining user.
Upon receiving a chat message from a client, the server must relay it to the other users, prepending it with the name of the user that sent the message.
The server should also notify all remaining users when a joined user disconnects.
This is an example session with Alice as the client (`-->` marks lines sent by the server, `<--` lines sent by Alice):
```
--> Welcome to budgetchat! What shall I call you?
<-- alice
--> * The room contains: bob, charlie, dave
<-- Hello everyone
--> [bob] hi alice
--> [charlie] hello alice
--> * dave has left the room
```
The username should contain at least 1 character, and must consist entirely of alphanumeric characters (uppercase, lowercase, and digits).
For this challenge we must support at least 10 simultaneous clients.
# Getting Started
At this point our project structure looks like this:
```
❯ tree $PWD -I target
~/Coding/protohackers-rs
├── Cargo.lock
├── Cargo.toml
├── README.md
└── src
    ├── bin
    │   ├── means_to_an_end.rs
    │   ├── prime_time.rs
    │   └── smoke_test.rs
    └── lib.rs
```
For this challenge we are adding another binary, `budget_chat.rs`, in the `src/bin` folder, where we'll write our solution:

```
❯ touch src/bin/budget_chat.rs
```
We dealt with a line-based protocol in the second challenge and we could do the same here, but this time I'd like to try a higher-level solution, just for fun and profit.
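We will need these additional dependencies:
- `tokio-util`: utilities for Tokio (we need its `codec` feature)
- `futures`: abstractions for asynchronous programming, such as `Stream` and `Sink`
- `bytes`: byte buffer APIs
- `derive_more`: additional `#[derive(..)]` macros

```
cargo add tokio-util --features codec
cargo add futures
cargo add bytes
cargo add derive_more@0.99
```

The `#[display(fmt = ...)]` attribute syntax used later in this article comes from `derive_more` 0.99; newer major versions changed that syntax and put the derives behind cargo features, so we pin the 0.99 line.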
The main function setup needs to change a bit for this challenge. In the previous challenges, clients didn't need to interact or share state with other connected clients. In this challenge we need to share some kind of state between clients, at least to keep track of who is connected.
We'll create another function, `run_server_with_state`, similar to the `run_server` used in the previous challenges, but with an additional parameter representing our shared state, which is then passed to the `handle_client` function:
```rust
use std::{future::Future, net::SocketAddr};

use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, error, info};

pub async fn run_server_with_state<H, S, F>(port: u16, state: S, handler: H) -> anyhow::Result<()>
where
    S: Clone,
    H: Fn(S, TcpStream, SocketAddr) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    let listener = TcpListener::bind(&format!("0.0.0.0:{}", port)).await?;
    info!("Starting server at 0.0.0.0:{}", port);
    loop {
        let (socket, address) = listener.accept().await?;
        debug!("Got connection from {}", address);
        // every connection receives its own clone of the shared state
        let future = handler(state.clone(), socket, address);
        // handle each client in its own task so clients run concurrently
        tokio::task::spawn(async move {
            if let Err(err) = future.await {
                error!("Error handling connection {}: {}", address, err);
            }
        });
    }
}
```
The new parameter is called `state` and has the generic type `S`. Since we don't want to deal with lifetimes and references, and we want to pass the state around easily, we require it to be `S: Clone`: each connection gets its own clone.
To keep the code clean, I've also refactored the original `run_server` function to use `run_server_with_state`, without breaking the previous challenges:
```rust
pub async fn run_server<H, F>(port: u16, handler: H) -> anyhow::Result<()>
where
    H: Fn(TcpStream) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    run_server_with_state(port, (), |_, stream, _| handler(stream)).await
}
```
At this point we are still not sure what to put in the global state. For now we can just set up the `budget_chat.rs` binary like this:
```rust
use std::net::SocketAddr;

use protohackers_rs::run_server_with_state;
use tokio::net::TcpStream;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    // `()` is a placeholder until we decide what the shared state is
    run_server_with_state(8000, (), handle_client).await?;
    Ok(())
}

async fn handle_client(
    state: (),
    mut stream: TcpStream,
    address: SocketAddr,
) -> anyhow::Result<()> {
    todo!()
}

#[cfg(test)]
mod budget_chat_tests {}
```
# Budget Chat Implementation
The first thing we need to do after a client connects is the handshake phase. The server must send a line asking for the client's username and then wait for the response. If the name is valid, the client can join the chat room.
I'd like to apply some type-driven development here to model the username.
Let's define a type `Username`:

```rust
pub struct Username(String);
```
This is the newtype pattern. It's useful for implementing external traits on external types, and also for enforcing type safety and abstraction.
In our case we want to be sure that, when an instance of `Username` is created, it contains only a valid username.
Let's create a `parse` method that takes a `String`, validates it and, if the validation succeeds, returns a `Username`:
```rust
impl Username {
    pub fn parse(input: String) -> anyhow::Result<Username> {
        if input.is_empty() {
            anyhow::bail!("Name should be at least 1 character")
        }
        // The spec allows only ASCII letters and digits; `char::is_alphanumeric`
        // would also accept non-ASCII letters such as 'é'.
        if input.chars().any(|c| !c.is_ascii_alphanumeric()) {
            anyhow::bail!("Name should contain only alphanumeric characters")
        }
        Ok(Username(input))
    }
}
```
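Later snippets call `username.clone()`, format usernames with `{}`, compare messages with `assert_eq!` and call `inner_ref()`, so `Username` needs a few more capabilities than the bare struct shown above. Those parts aren't shown in the article, so here is a dependency-free sketch under stated assumptions: the derives are the obvious ones, `inner_ref(&self) -> &str` is an assumed signature, errors are plain `String`s instead of `anyhow`, and the check is ASCII-only as the spec asks.

```rust
use std::fmt;

/// Validated username (newtype over String).
/// Assumed derives: needed for cloning, `assert_eq!` and debug output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Username(String);

impl Username {
    /// Non-empty, ASCII letters and digits only.
    /// `String` error instead of `anyhow` keeps the sketch dependency-free.
    pub fn parse(input: String) -> Result<Username, String> {
        if input.is_empty() {
            return Err("Name should be at least 1 character".to_string());
        }
        if !input.chars().all(|c| c.is_ascii_alphanumeric()) {
            return Err("Name should contain only alphanumeric characters".to_string());
        }
        Ok(Username(input))
    }

    /// Borrow the inner name (assumed signature for the `inner_ref` used in the tests).
    pub fn inner_ref(&self) -> &str {
        &self.0
    }
}

/// Display prints the bare name, which is what the chat lines need.
impl fmt::Display for Username {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
```

Apart from the `String` error type and the ASCII-only check, `parse` mirrors the version above; the derives and `Display` are what make `username.clone()`, `assert_eq!` on messages and `{}` formatting work in the rest of the code.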
For reading and writing messages (lines) we are taking a different approach in this challenge.
At a high level, we would like incoming messages to be of type `String`, each representing a single line. Outgoing messages are still strings (lines), but we would like to model them with a type that better represents all the possible outgoing messages.
By reading the spec we can identify six types of messages:
- Welcome
- InvalidUsername
- Chat message
- User join
- User leave
- List users
Let's define an enum type `OutgoingMessage`:

```rust
pub enum OutgoingMessage {
    Welcome,
    Join(Username),
    Leave(Username),
    Chat { from: Username, msg: String },
    InvalidUsername(String),
    Participants(Vec<Username>),
}
```
Ideally we would like to use this type when working at a high level, and then encode it into its string representation just before sending it to the client.
# Framing
In the previous challenges, for reading and writing messages we used the IO traits `AsyncRead`, `AsyncReadExt`, `AsyncWrite` and `AsyncWriteExt`, but they only speak raw bytes. We need an intermediate layer that converts those raw bytes into high-level types. This process is called framing: it turns a byte stream into units of data transmitted between peers, called frames.
In our case we want:
- String for input
- OutgoingMessage for output
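To make the idea of framing concrete, here is a minimal, dependency-free sketch of newline framing on the decode side. It only illustrates the concept: the function `decode_line` is hypothetical, it works on a plain `Vec<u8>` instead of `bytes::BytesMut`, and it is not how `tokio-util` is implemented. Bytes accumulate in the buffer as they arrive, and a frame (a `String`) is produced only once a complete `\n`-terminated line is available; an incomplete tail stays in the buffer for the next read.

```rust
/// Try to extract one newline-terminated frame from `buf`.
/// Returns `None` when no complete line has arrived yet, leaving the
/// partial bytes in place so that a later read can complete them.
pub fn decode_line(buf: &mut Vec<u8>) -> Option<String> {
    // Index of the first '\n'; bail out early if there is none.
    let newline = buf.iter().position(|&b| b == b'\n')?;
    // Remove the whole line, including the '\n', from the front of the buffer.
    let mut line: Vec<u8> = buf.drain(..=newline).collect();
    line.pop(); // drop the trailing '\n'
    // Lossy conversion keeps the sketch short; a real codec would report invalid UTF-8.
    Some(String::from_utf8_lossy(&line).into_owned())
}
```

Returning `None` here plays roughly the role of `Ok(None)` from a `tokio-util` `Decoder`: it signals that more bytes are needed before a frame can be produced.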
Framing can be done manually, as documented here, but for this challenge we'll use the `tokio-util` helpers.
The entry point for framing with `tokio-util` is `Framed`, which provides a unified interface for `Stream` and `Sink`.
A `Stream` is similar to an `Iterator`, but for the asynchronous world: it represents a source of values produced over time.
A `Sink` is the opposite: a destination to which values can be sent over time.
To create a `Framed` instance we have to provide the `TcpStream` (or, in general, any `AsyncRead + AsyncWrite` object) and a codec as parameters:
```rust
async fn handle_client(state: (), stream: TcpStream, address: SocketAddr) -> anyhow::Result<()> {
    // `codec` is a placeholder: we write a concrete codec below
    let framed = Framed::new(stream, codec);
    todo!()
}
```
The `codec` should be an implementation of `Decoder` and `Encoder` for reading and writing raw data.
`Framed` is our intermediate framing layer. It reads and writes raw bytes from and to the underlying `AsyncRead + AsyncWrite`, converts those bytes into frames using the provided `codec`, and exposes a high-level API by implementing the `Stream` and `Sink` traits.
What does a codec look like?
`tokio-util` provides `LinesCodec`, which is useful for writing line-based protocols. It only gets us halfway there: our goal is to have `String` (lines) as input and `OutgoingMessage` as output.
For that we can easily wrap `LinesCodec` in our own codec implementation that uses `OutgoingMessage` as the output type:
```rust
use tokio_util::codec::{Decoder, Encoder, LinesCodec};

pub struct ChatCodec {
    lines: LinesCodec,
}

impl ChatCodec {
    pub fn new() -> Self {
        Self {
            lines: LinesCodec::new(),
        }
    }
}

impl Encoder<OutgoingMessage> for ChatCodec {
    type Error = anyhow::Error;

    fn encode(
        &mut self,
        item: OutgoingMessage,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), Self::Error> {
        // render the message as text, then let LinesCodec append the '\n'
        self.lines
            .encode(item.to_string(), dst)
            .map_err(anyhow::Error::from)
    }
}

impl Decoder for ChatCodec {
    type Item = String;
    type Error = anyhow::Error;

    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // incoming frames stay plain lines
        self.lines.decode(src).map_err(anyhow::Error::from)
    }
}
```
For the decoder we simply delegate to `LinesCodec::decode`, while for the `Encoder` we first convert the `OutgoingMessage` into a string and then call `LinesCodec::encode`.
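The same delegation idea can be sketched without `tokio-util`. In this dependency-free sketch, `LineFramer` and `TypedLineCodec` are hypothetical stand-ins for `LinesCodec` and `ChatCodec`, working on a plain `Vec<u8>`: the encode side accepts any `Display` item (as `ChatCodec` does through `to_string()`), while the decode side just forwards to the inner line logic.

```rust
use std::fmt::Display;

/// Hypothetical stand-in for `LinesCodec`: frames text as '\n'-terminated lines.
#[derive(Default)]
pub struct LineFramer;

impl LineFramer {
    pub fn encode(&mut self, line: &str, dst: &mut Vec<u8>) {
        dst.extend_from_slice(line.as_bytes());
        dst.push(b'\n');
    }

    pub fn decode(&mut self, src: &mut Vec<u8>) -> Option<String> {
        let end = src.iter().position(|&b| b == b'\n')?;
        let mut line: Vec<u8> = src.drain(..=end).collect();
        line.pop(); // strip the '\n'
        Some(String::from_utf8_lossy(&line).into_owned())
    }
}

/// Hypothetical stand-in for `ChatCodec`: typed output, plain `String` input.
#[derive(Default)]
pub struct TypedLineCodec {
    lines: LineFramer,
}

impl TypedLineCodec {
    /// Encode side: render the item with `Display`, then delegate the framing.
    pub fn encode<T: Display>(&mut self, item: T, dst: &mut Vec<u8>) {
        self.lines.encode(&item.to_string(), dst);
    }

    /// Decode side: pure delegation, frames stay `String`s.
    pub fn decode(&mut self, src: &mut Vec<u8>) -> Option<String> {
        self.lines.decode(src)
    }
}
```

Because the wrapper only relies on `Display`, any outgoing type with a `Display` impl can be framed this way, which is exactly what makes `derive_more::Display` on `OutgoingMessage` convenient.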
To convert the `OutgoingMessage` into a string I decided to use `derive_more`:
```rust
// Debug and PartialEq are needed by `assert_eq!` in the tests later on
// (which also requires `Username` to implement them).
#[derive(derive_more::Display, Debug, Clone, PartialEq)]
pub enum OutgoingMessage {
    #[display(fmt = "Welcome to budgetchat! What shall I call you?")]
    Welcome,
    #[display(fmt = "* {} has entered the room", _0)]
    Join(Username),
    #[display(fmt = "* {} has left the room", _0)]
    Leave(Username),
    #[display(fmt = "[{}] {}", from, msg)]
    Chat { from: Username, msg: String },
    #[display(fmt = "Invalid username {}", _0)]
    InvalidUsername(String),
    #[display(fmt = "* The room contains: {}", "self.participants(_0)")]
    Participants(Vec<Username>),
}

impl OutgoingMessage {
    // helper that formats the participant list as "alice, bob, ..."
    fn participants(&self, participants: &[Username]) -> String {
        participants
            .iter()
            .map(|user| user.to_string())
            .collect::<Vec<String>>()
            .join(", ")
    }
}
```
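For reference, the `derive_more` attributes above should behave roughly like the following hand-written `Display` impl. This is a dependency-free sketch, not `derive_more`'s literal expansion; `Username` is redefined minimally (assuming its `Display` prints the bare name) so that the block compiles on its own. Writing the impl by hand is a reasonable alternative if you prefer to avoid the extra dependency.

```rust
use std::fmt;

#[derive(Debug, Clone, PartialEq)]
pub struct Username(String);

// Assumption: a username displays as the bare name.
impl fmt::Display for Username {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

#[derive(Debug, Clone, PartialEq)]
pub enum OutgoingMessage {
    Welcome,
    Join(Username),
    Leave(Username),
    Chat { from: Username, msg: String },
    InvalidUsername(String),
    Participants(Vec<Username>),
}

impl fmt::Display for OutgoingMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OutgoingMessage::Welcome => write!(f, "Welcome to budgetchat! What shall I call you?"),
            OutgoingMessage::Join(user) => write!(f, "* {} has entered the room", user),
            OutgoingMessage::Leave(user) => write!(f, "* {} has left the room", user),
            OutgoingMessage::Chat { from, msg } => write!(f, "[{}] {}", from, msg),
            OutgoingMessage::InvalidUsername(reason) => write!(f, "Invalid username {}", reason),
            OutgoingMessage::Participants(users) => {
                // same job as the `participants` helper: comma-separated names
                let names: Vec<String> = users.iter().map(|u| u.to_string()).collect();
                write!(f, "* The room contains: {}", names.join(", "))
            }
        }
    }
}
```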
Let's plug the `ChatCodec` into `handle_client`:
```rust
use futures::{Sink, Stream, StreamExt};
use tokio_util::codec::Framed;

async fn handle_client(state: (), stream: TcpStream, address: SocketAddr) -> anyhow::Result<()> {
    // `StreamExt::split` returns the write half (Sink) first, then the read half (Stream)
    let (sink, stream) = Framed::new(stream, ChatCodec::new()).split();
    handle_client_internal(state, address, sink, stream).await
}

async fn handle_client_internal<I, O>(
    state: (),
    address: SocketAddr,
    mut sink: O,
    mut stream: I,
) -> anyhow::Result<()>
where
    I: Stream<Item = anyhow::Result<String>> + Unpin,
    O: Sink<OutgoingMessage, Error = anyhow::Error> + Unpin,
{
    todo!()
}
```
We call `split` on `Framed` to get a separate `Sink` and `Stream`. The `Stream` emits values of type `Result<String, _>` (lines), and the `Sink` allows us to send values of type `OutgoingMessage`.
Splitting the original `Framed` will come in handy when testing the `handle_client_internal` function.
# Handshake
With the framing layer set up in the form of a `Sink` and a `Stream`, we can easily handle the handshake phase like this:
```rust
// inside handle_client_internal (needs futures::{SinkExt, TryStreamExt} in scope)

// send the welcome message
sink.send(OutgoingMessage::Welcome).await?;
// receive the username
let username = stream
    .try_next()
    .await?
    .ok_or_else(|| anyhow::anyhow!("Error while waiting for the username"))?;
// check if the username is valid
let username = match Username::parse(username) {
    Ok(username) => username,
    Err(e) => {
        // send the error and return, closing the connection
        sink.send(OutgoingMessage::InvalidUsername(e.to_string()))
            .await?;
        return Ok(());
    }
};
```
The work we did on the framing layer with `Framed` and `ChatCodec` paid off: we were able to abstract away the byte handling in favor of `String` and `OutgoingMessage`.
# Modeling the state
Once the username has been validated, the client can officially join the chat room.
For that we need to model how the server keeps track of the joined clients and how they talk to each other. This state will be shared by all the connected clients through the `handle_client` function.
Let's create a type `Room` as follows:
```rust
use std::{collections::HashMap, net::SocketAddr, sync::Arc};

use tokio::sync::{mpsc, Mutex};

#[derive(Clone)]
pub struct Room(Arc<Mutex<HashMap<SocketAddr, User>>>);

pub struct User {
    username: Username,
    sender: mpsc::UnboundedSender<OutgoingMessage>,
}

impl Room {
    pub fn new() -> Room {
        Room(Arc::new(Mutex::new(HashMap::new())))
    }
}
```
The `Room` contains all the connected users, identified by their `SocketAddr`. The type `User` is composed of two fields:
- `username`: the username of the connected client.
- `sender`: the sending half of an unbounded mpsc channel. Whenever we want to send a message to a client we can just call `sender.send(msg)`.

The `Room` is shared state and we need to dynamically add and remove users, so we wrap the `HashMap` in an `Arc<Mutex<_>>` for concurrent access.
The `Room` should provide at least one public method for adding a user to the chat room.
A first implementation looks like this:
```rust
impl Room {
    // join a new user with addr and username
    pub async fn join(&self, addr: SocketAddr, username: Username) -> UserHandle {
        // Create the communication channel
        let (sender, receiver) = mpsc::unbounded_channel();
        // Acquire the lock on the HashMap
        let mut users = self.0.lock().await;
        // insert the new user
        users.insert(
            addr,
            User {
                username: username.clone(),
                sender,
            },
        );
        // returns a UserHandle
        UserHandle {
            username,
            receiver,
            room: self.clone(),
            address: addr,
        }
    }
}

pub struct UserHandle {
    username: Username,
    address: SocketAddr,
    receiver: mpsc::UnboundedReceiver<OutgoingMessage>,
    room: Room,
}
```
In the `join` method we create a communication channel for the joining user, store its sending half in the `HashMap`, and return an instance of the type `UserHandle`.
The `UserHandle` contains these fields:
- `receiver`: for receiving messages from the other participants.
- `username`: the user's username.
- `address`: the user's `SocketAddr`.
- `room`: the joined `Room`.

In the `UserHandle` we encapsulate all the actions that a user can perform in a `Room`. A `UserHandle` should be able to:
- send messages to the other users
- leave the room
```rust
impl UserHandle {
    pub async fn send_message(&self, msg: String) {
        self.room
            .broadcast(
                &self.address,
                OutgoingMessage::Chat {
                    from: self.username.clone(),
                    msg,
                },
            )
            .await;
    }

    pub async fn leave(self) {
        self.room.leave(&self.address).await
    }
}
```
The `UserHandle` just delegates the work to the `Room`, passing its identifier (`SocketAddr`):
```rust
impl Room {
    // method for broadcasting a message to the users excluding the one
    // identified by the addr
    async fn broadcast(&self, addr: &SocketAddr, msg: OutgoingMessage) {
        let mut users = self.0.lock().await;
        self.broadcast_internal(addr, msg, &mut users).await;
    }

    // method for handling the leave of a user
    async fn leave(&self, addr: &SocketAddr) {
        let mut users = self.0.lock().await;
        if let Some(leaving) = users.remove(addr) {
            // notify the users
            self.broadcast_internal(addr, OutgoingMessage::Leave(leaving.username), &mut users)
                .await;
        }
    }

    // internal broadcast implementation that actually delivers the message to each user.
    async fn broadcast_internal(
        &self,
        addr: &SocketAddr,
        msg: OutgoingMessage,
        users: &mut HashMap<SocketAddr, User>,
    ) {
        for (user_addr, user) in users.iter_mut() {
            if addr != user_addr {
                let _ = user.sender.send(msg.clone());
            }
        }
    }
}
```
The `broadcast` method sends the same `OutgoingMessage` to all users except the one identified by the given `SocketAddr`.
The `leave` method removes the user identified by the given `SocketAddr` from the `HashMap` and then broadcasts an `OutgoingMessage::Leave` message to the remaining users.
To complete the `Room` implementation we are still missing some bits.
When users join a room they should receive a notification listing all the room members, and the other users should be notified about the newly joined member.
For that we change the `Room::join` method a little:
```rust
impl Room {
    pub async fn join(&self, addr: SocketAddr, username: Username) -> UserHandle {
        let mut users = self.0.lock().await;
        let (sender, receiver) = mpsc::unbounded_channel();
        let names = users
            .iter()
            .map(|(_, user)| user.username.clone())
            .collect::<Vec<Username>>();
        // send the users list notification
        let _ = sender.send(OutgoingMessage::Participants(names));
        // broadcast the join message to the other users
        self.broadcast_internal(&addr, OutgoingMessage::Join(username.clone()), &mut users)
            .await;
        users.insert(
            addr,
            User {
                username: username.clone(),
                sender,
            },
        );
        UserHandle {
            username,
            receiver,
            room: self.clone(),
            address: addr,
        }
    }
}
```
Before adding the user to the chat room, we send a notification on its channel with the participants list, and then we broadcast the join message to the others.
All the logic for the `Room` state should now be complete and ready for testing.
I will not describe the `room_test` here, as I would like to focus on testing the `handle_client_internal` function once the implementation is done, but the code is available here.
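As a hedged stand-in for that test, here is a synchronous, dependency-free sketch of the same `Room` protocol. It is not the article's code: `std::sync::Mutex` and `std::sync::mpsc` replace the tokio primitives, names are plain `String`s, `Msg` is a simplified stand-in for `OutgoingMessage`, and `send_message` takes the sender's address instead of going through a `UserHandle`. It lets us check the join/broadcast/leave ordering without an async runtime.

```rust
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};

/// Simplified stand-in for `OutgoingMessage`.
#[derive(Debug, Clone, PartialEq)]
pub enum Msg {
    Join(String),
    Leave(String),
    Chat { from: String, msg: String },
    Participants(Vec<String>),
}

struct User {
    username: String,
    sender: Sender<Msg>,
}

#[derive(Clone, Default)]
pub struct Room(Arc<Mutex<HashMap<SocketAddr, User>>>);

impl Room {
    /// Send `msg` to every user except `skip`; closed receivers are ignored.
    fn broadcast_locked(users: &HashMap<SocketAddr, User>, skip: &SocketAddr, msg: Msg) {
        for (addr, user) in users {
            if addr != skip {
                let _ = user.sender.send(msg.clone());
            }
        }
    }

    /// List current members to the newcomer, announce the newcomer, then insert it.
    pub fn join(&self, addr: SocketAddr, username: &str) -> Receiver<Msg> {
        let mut users = self.0.lock().unwrap();
        let (sender, receiver) = channel();
        let names: Vec<String> = users.values().map(|u| u.username.clone()).collect();
        let _ = sender.send(Msg::Participants(names));
        Self::broadcast_locked(&users, &addr, Msg::Join(username.to_string()));
        users.insert(addr, User { username: username.to_string(), sender });
        receiver
    }

    /// Relay a chat line from the user at `addr` to everyone else.
    pub fn send_message(&self, addr: &SocketAddr, msg: &str) {
        let users = self.0.lock().unwrap();
        if let Some(from) = users.get(addr).map(|u| u.username.clone()) {
            Self::broadcast_locked(&users, addr, Msg::Chat { from, msg: msg.to_string() });
        }
    }

    /// Remove the user and notify the remaining ones.
    pub fn leave(&self, addr: &SocketAddr) {
        let mut users = self.0.lock().unwrap();
        if let Some(leaving) = users.remove(addr) {
            Self::broadcast_locked(&users, addr, Msg::Leave(leaving.username));
        }
    }
}
```

The ordering checked here (participants list to the newcomer first, then the join notification to the others) mirrors the order in which the article's `Room::join` sends them while holding the lock.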
# Wiring it all together
So far we have implemented the handshake phase and the `Room` API for managing users and broadcasting messages in the server.
Let's dive back into the `handle_client_internal` function and wire it all together.
After the handshake phase we call `Room::join`, passing the `SocketAddr` of the connected client and the received username. Once a user has joined the room, we need to process the incoming messages in a loop. At this point there are two possible sources of messages: the external `Stream` (the client) and the messages broadcast by other users.
If a message comes from the external `Stream` (a `String`), it's a chat message from the connected client, and we broadcast it to the others by calling `UserHandle::send_message`.
If a message comes from the `receiver` of the `UserHandle` (an `OutgoingMessage`), it's a message from another user or a room notification such as the participants list, and we just send it to the `Sink` of the connected client.
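To process multiple asynchronous sources we'll use `tokio::select!`.
The final `handle_client_internal` now takes a `Room` as state (so `handle_client` and `main` pass a `Room` created with `Room::new()` instead of `()`), and looks like this: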
```rust
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use tracing::error;

async fn handle_client_internal<I, O>(
    state: Room,
    address: SocketAddr,
    mut sink: O,
    mut stream: I,
) -> anyhow::Result<()>
where
    I: Stream<Item = anyhow::Result<String>> + Unpin,
    O: Sink<OutgoingMessage, Error = anyhow::Error> + Unpin,
{
    sink.send(OutgoingMessage::Welcome).await?;
    let username = stream
        .try_next()
        .await?
        .ok_or_else(|| anyhow::anyhow!("Error while waiting for the username"))?;
    let username = match Username::parse(username) {
        Ok(username) => username,
        Err(e) => {
            sink.send(OutgoingMessage::InvalidUsername(e.to_string()))
                .await?;
            return Ok(());
        }
    };
    // join the room
    let mut handle = state.join(address, username).await;
    // process messages until there is an error or the stream is exhausted
    loop {
        tokio::select! {
            // Message received from another user (or from the room)
            Some(msg) = handle.receiver.recv() => {
                // Send it to the connected client
                if let Err(e) = sink.send(msg).await {
                    error!("Error sending message {}", e);
                    break;
                }
            }
            // Message received from the connected client
            result = stream.next() => match result {
                Some(Ok(msg)) => {
                    // broadcast it
                    handle.send_message(msg).await;
                }
                // break on error
                Some(Err(e)) => {
                    error!("Error reading messages {}", e);
                    break;
                }
                // break on stream exhausted
                None => break,
            }
        };
    }
    // leave the room
    handle.leave().await;
    Ok(())
}
```
# Testing the solution
Testing `handle_client_internal` requires more setup. First we need to add `async-stream` as a dev dependency for creating ad-hoc streams:

```
❯ cargo add --dev async-stream
```
The `handle_client_internal` function takes as input:
- the shared state, a `Room`
- the `SocketAddr` of the client
- the `Sink` for sending messages (server -> client)
- the `Stream` for receiving messages (client -> server)

We just have to find a way to provide implementations of `Sink` and `Stream` that we can then use for testing. `tokio::sync::mpsc::channel` seems a good candidate for that.
Let's write a fixture function `connect` to simulate a client connection:
```rust
// inside `mod budget_chat_tests`
use super::*;
use futures::SinkExt;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::task::JoinHandle;
use tokio_util::sync::PollSender;

async fn connect(room: Room, addr: &str) -> UserTest {
    // channel for sending stuff server -> client
    let (sink_tx, sink_rx) = mpsc::channel(100);
    // channel for sending stuff client -> server
    let (stream_tx, mut stream_rx) = mpsc::channel(100);
    let address: SocketAddr = addr.parse().unwrap();
    // convert the stream receiver into a stream
    let stream = async_stream::stream! {
        while let Some(message) = stream_rx.recv().await {
            yield message
        }
    };
    let handle = tokio::spawn(async move {
        handle_client_internal(
            room,
            address,
            // adapt the mpsc sender into a Sink whose error type is anyhow::Error
            PollSender::new(sink_tx).sink_map_err(anyhow::Error::from),
            Box::pin(stream),
        )
        .await
    });
    UserTest {
        sink_receiver: sink_rx,
        stream_sender: Some(stream_tx),
        handle,
    }
}

struct UserTest {
    sink_receiver: Receiver<OutgoingMessage>,
    stream_sender: Option<Sender<anyhow::Result<String>>>,
    handle: JoinHandle<anyhow::Result<()>>,
}

impl UserTest {
    async fn send(&mut self, message: &str) {
        self.stream_sender
            .as_ref()
            .unwrap()
            .send(Ok(message.to_string()))
            .await
            .unwrap();
    }

    async fn leave(mut self) {
        // dropping the sender ends the client stream, like a disconnection
        let stream = self.stream_sender.take();
        drop(stream);
        self.handle.await.unwrap().unwrap()
    }

    async fn check_message(&mut self, msg: OutgoingMessage) {
        assert_eq!(self.sink_receiver.recv().await.unwrap(), msg);
    }
}
```
In the `connect` function we create two channels:
- the sink channel (server -> client)
- the stream channel (client -> server)

and we spawn a background task that handles the single client by calling `handle_client_internal`, with those two channels adapted into a `Sink` and a `Stream` as input.
We have to do some conversion magic in the middle. `tokio::sync::mpsc::Receiver` doesn't implement `Stream`, but fortunately we can use `async_stream::stream!` to convert the channel into a `Stream`.
`tokio::sync::mpsc::Sender` doesn't implement `Sink`, so we wrap the sender in a `PollSender` from `tokio-util`, which makes it compatible with the `Sink` trait; `sink_map_err` then converts its error type into the `anyhow::Error` that `handle_client_internal` expects.
The `connect` function returns a `UserTest`, which we can use in our tests to drive a single user through a chat session scenario.
It has the following fields:
- `sink_receiver`: for messages (server -> client)
- `stream_sender`: for messages (client -> server)
- `handle`: the `JoinHandle` of the background task
I've created a couple of additional methods on `UserTest` to make the tests more readable:
- `send`: send a message to the server
- `check_message`: check that the next message matches `msg`
- `leave`: simulate the disconnection by dropping the `stream_sender` and waiting for the background `handle_client_internal` task to complete.
Now we can use this API in a chat test scenario with two participants:
- alice: `0.0.0.0:10`
- bob: `0.0.0.0:11`
```
{server}-->{alice} Welcome to budgetchat! What shall I call you?
{server}<--{alice} alice
{server}-->{alice} * The room contains: 
{server}-->{bob} Welcome to budgetchat! What shall I call you?
{server}<--{bob} bob
{server}-->{bob} * The room contains: alice
{server}-->{alice} * bob has entered the room
{server}<--{alice} Hi bob!
{server}-->{bob} [alice] Hi bob!
{server}<--{bob} Hi alice!
{server}-->{alice} [bob] Hi alice!
{server}-->{alice} * bob has left the room
```
The code for testing the above session is:
```rust
#[tokio::test]
async fn example_session_test() {
    let room = Room::new();
    let alice_username = Username::parse("alice".to_string()).unwrap();
    let bob_username = Username::parse("bob".to_string()).unwrap();

    // alice connects
    let mut alice = connect(room.clone(), "0.0.0.0:10").await;
    alice.check_message(OutgoingMessage::Welcome).await;
    // alice sends the username and gets the participants list
    alice.send(&alice_username.inner_ref()).await;
    alice
        .check_message(OutgoingMessage::Participants(vec![]))
        .await;

    // bob connects
    let mut bob = connect(room.clone(), "0.0.0.0:11").await;
    bob.check_message(OutgoingMessage::Welcome).await;
    // bob sends the username and gets the participants list
    bob.send(&bob_username.inner_ref()).await;
    bob.check_message(OutgoingMessage::Participants(vec![alice_username.clone()]))
        .await;

    // alice gets the notification of bob joining the room
    alice
        .check_message(OutgoingMessage::Join(bob_username.clone()))
        .await;

    // alice sends a message
    alice.send("Hi bob!").await;
    // bob gets alice's message
    bob.check_message(OutgoingMessage::Chat {
        msg: "Hi bob!".to_string(),
        from: alice_username.clone(),
    })
    .await;

    // bob sends a message
    bob.send("Hi alice!").await;
    // alice gets bob's message
    alice
        .check_message(OutgoingMessage::Chat {
            msg: "Hi alice!".to_string(),
            from: bob_username.clone(),
        })
        .await;

    // bob leaves the room
    bob.leave().await;
    // alice gets the notification of bob leaving the room
    alice
        .check_message(OutgoingMessage::Leave(bob_username))
        .await;
}
```
Let's run it:
```
❯ cargo test example_session_test -- --nocapture
   Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
    Finished test [unoptimized + debuginfo] target(s) in 1.27s
     Running unittests src/bin/budget_chat.rs (target/debug/deps/budget_chat-9feb9b3a82ded77d)

running 1 test
test budget_chat_tests::example_session_test ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 1 filtered out; finished in 0.00s
```
... and it passed!!
# Wrapping Up
Wow! This post turned out longer than expected, but we did a lot for this challenge. Even though the solution is probably over-engineered, and we could have solved it with the `AsyncRead` and `AsyncWrite` primitives and their extensions, we explored concepts like `Sink` and `Stream` that abstract away the underlying network layer. We also learned a bit about framing and how to use `Framed`, `Encoder` and `Decoder` to convert a low-level stream of bytes into high-level frames.
Comments and suggestions are always welcome. Please feel free to send an email or comment on @wolf4ood@hachyderm.io.
Thank you for reading, see you in the next chapter!