Protohackers in Rust: Budget Chat

Enrico Risa | 19 min | 3781 words

This is the fourth article in the series solving Protohackers with Rust. The code in this series will be available on GitHub.

# The Challenge

In the fourth challenge we have to implement a simple chat server. Each message is a single line terminated by a newline character \n.

When a client connects, the server should send a welcome message asking for the client username. Once the client responds back with a valid username, they have officially joined the chat room.

When a new user joins, the server must:

  • Announce the presence of the new user to the other members of the chat.
  • Send the list of all connected users to the joining user.

Upon receiving a chat message from a client, the server must relay it to the other users, prepending it with the name of the user that sent the message.

The server should also notify all users when a joined user is disconnected.

This is an example session with Alice as client:

--> Welcome to budgetchat! What shall I call you?
<-- alice
--> * The room contains: bob, charlie, dave
<-- Hello everyone
--> [bob] hi alice
--> [charlie] hello alice
--> * dave has left the room

The username should contain at least 1 character, and must consist entirely of alphanumeric characters (uppercase, lowercase, and digits).

For this challenge we should support at least 10 simultaneous clients.

# Getting Started

At this point our project structure looks like this:

 tree $PWD -I target
~/Coding/protohackers-rs
├── Cargo.lock
├── Cargo.toml
├── README.md
└── src
    ├── bin
    │   ├── means_to_an_end.rs
    │   ├── prime_time.rs
    │   └── smoke_test.rs
    └── lib.rs

For this challenge we are adding another binary, budget_chat.rs, in the src/bin folder, where we'll write our solution:

 touch src/bin/budget_chat.rs

We dealt with a line-based protocol in the second challenge and we could do the same here, but this time I'd like to try a higher-level solution just for fun and profit.

We will need these additional dependencies:

cargo add tokio-util --features codec
cargo add futures
cargo add bytes
cargo add derive_more

The main function setup needs to change a bit for this challenge. In the previous challenges the clients didn't need to interact or share state with each other. In this challenge we need to share some kind of state between clients, at least to keep track of who is connected.

We'll create another function run_server_with_state, similar to the run_server used in the previous challenges, but with an additional parameter representing our shared state, which is then passed to the handle_client function:

pub async fn run_server_with_state<H, S, F>(port: u16, state: S, handler: H) -> anyhow::Result<()>
where
    S: Clone,
    H: Fn(S, TcpStream, SocketAddr) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    let listener = TcpListener::bind(&format!("0.0.0.0:{}", port)).await?;

    info!("Starting server at 0.0.0.0:{}", port);
    loop {
        let (socket, address) = listener.accept().await?;

        debug!("Got connection from {}", address);
        let future = handler(state.clone(), socket, address);
        tokio::task::spawn(async move {
            if let Err(err) = future.await {
                error!("Error handling connection {}: {}", address, err);
            }
        });
    }
}

The new parameter is called state and its type is a generic S. Since we don't want to deal with lifetimes and references, and we want to pass the state around easily, we require S: Clone.

To keep the code clean, I've also refactored the original run_server function on top of run_server_with_state, without breaking the previous challenges:

pub async fn run_server<H, F>(port: u16, handler: H) -> anyhow::Result<()>
where
    H: Fn(TcpStream) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    run_server_with_state(port, (), |_, stream, _| handler(stream)).await
}

At this point we are still not sure what to put in the global state. For now we can just set up the budget_chat.rs binary like this:

use std::net::SocketAddr;

use protohackers_rs::run_server_with_state;
use tokio::net::TcpStream;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    run_server_with_state(8000, (), handle_client).await?;
    Ok(())
}

async fn handle_client(
    state: (),
    mut stream: TcpStream,
    address: SocketAddr,
) -> anyhow::Result<()> {
    todo!()
}

#[cfg(test)]
mod budget_chat_tests {}

# Budget Chat Implementation

The first thing that we need to do after a client connects is the handshake phase. The server must send a line asking for the client username and then it should wait for the response. If the name is valid the client can join the chat room.

I'd like to apply some type driven development here for modeling the username.

Let's define a type Username:

pub struct Username(String);

This is the newtype pattern. It's useful for implementing external traits on external types, and also for enforcing type safety and abstraction.

In our case we want to be sure that, whenever an instance of Username is created, it contains only a valid username.

Let's create a parse method that takes a String, validates it and, if the validation succeeds, returns the Username:

impl Username {
    pub fn parse(input: String) -> anyhow::Result<Username> {
        if input.is_empty() {
            anyhow::bail!("Name should be at least 1 character")
        }
        if input.chars().any(|c| !c.is_alphanumeric()) {
            anyhow::bail!("Name should contain only alphanumeric characters")
        }
        Ok(Username(input))
    }
}
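One detail worth keeping in mind: char::is_alphanumeric follows Unicode, so it also accepts letters outside the ASCII range. If the "uppercase, lowercase, and digits" wording of the challenge is read as ASCII only (that reading is my assumption, the text quoted above does not say it explicitly), a stricter check could look like the sketch below. The is_valid_name helper is hypothetical and not part of the solution:

```rust
/// Hypothetical ASCII-only variant of the check done in `Username::parse`.
/// Returns true when `name` is non-empty and made only of a-z, A-Z and 0-9.
pub fn is_valid_name(name: &str) -> bool {
    !name.is_empty() && name.chars().all(|c| c.is_ascii_alphanumeric())
}
```

With this variant a name such as "alice1" is accepted, while an empty name, a name with a space, or a name with an accented letter is rejected.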

For reading and writing messages (lines) we are taking a different strategy in this challenge.

At a high level we would like the incoming messages to be plain strings, each representing a single line. The outgoing messages are still strings (lines) on the wire, but we would like to model them with a type that better represents all the possible outgoing messages.

By reading the specs we can identify six types of messages:

  • Welcome
  • InvalidUsername
  • Chat message
  • User join
  • User leave
  • List users

Let's define an enum type OutgoingMessage:

pub enum OutgoingMessage {
    Welcome,
    Join(Username),
    Leave(Username),
    Chat { from: Username, msg: String },
    InvalidUsername(String),
    Participants(Vec<Username>),
}

Ideally we would like to use this type when working at high level, and then somehow encode it in its string representation before sending it to the client.

# Framing

In the previous challenges we read and wrote messages with the IO traits {AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, but they only speak raw bytes. We need an intermediate layer that converts those raw bytes into high-level types. This process is called framing: it turns a byte stream into frames, the units of data transmitted between peers.

In our case we want:

  • String for input
  • OutgoingMessage for output

Framing can be done manually as documented here, but for this challenge we'll be using the helpers provided by tokio-util.
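To give an idea of what manual framing involves for a line-based protocol, here is a minimal, dependency-free sketch. The next_line helper is hypothetical: it works on a plain Vec<u8> buffer rather than on a socket, and it only illustrates the "return nothing until a full frame has arrived" idea.

```rust
/// Hypothetical helper: tries to pull one complete line out of `buf`.
/// Returns `None` when no '\n' has arrived yet and leaves `buf` untouched,
/// which is the "not enough data yet" case of a frame decoder.
pub fn next_line(buf: &mut Vec<u8>) -> Option<String> {
    // Find the first newline; without one the frame is incomplete.
    let pos = buf.iter().position(|b| *b == b'\n')?;
    // Remove the line and its terminator from the front of the buffer.
    let line: Vec<u8> = buf.drain(..=pos).collect();
    // Drop the trailing '\n' (and a '\r' if the client sent "\r\n").
    let mut end = line.len() - 1;
    if end > 0 && line[end - 1] == b'\r' {
        end -= 1;
    }
    Some(String::from_utf8_lossy(&line[..end]).into_owned())
}
```

The caller keeps appending the bytes read from the socket to buf and calls next_line until it returns None.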

The entry point for framing with tokio-util is Framed, which provides a unified interface for Stream and Sink.

A Stream is similar to an Iterator, but for the asynchronous world: it represents a source of values that are produced over time.

A Sink is the exact opposite: a destination for values that are sent over time.

To create a Framed instance we have to provide the TcpStream (or in general any AsyncRead + AsyncWrite object), and a codec as parameters:

async fn handle_client(state: (), stream: TcpStream, address: SocketAddr) -> anyhow::Result<()> {
    let framed = Framed::new(stream, codec);
    
    todo!()
}

The codec should be an implementation of Decoder and Encoder for reading and writing raw data.

Framed is our intermediate framing layer. It reads/writes raw bytes from/to the underlying AsyncRead + AsyncWrite, converts those bytes into frames using the provided codec, and exposes high-level APIs by implementing the Stream and Sink traits.

What does a codec look like?

tokio-util provides LinesCodec, which is useful when writing line-based protocols. This is half fine for us: our goal is to have String (a line) as input and OutgoingMessage as output.

For that we can wrap LinesCodec and write our own codec that uses the type OutgoingMessage as output:

pub struct ChatCodec {
    lines: LinesCodec,
}

impl ChatCodec {
    pub fn new() -> Self {
        Self {
            lines: LinesCodec::new(),
        }
    }
}

impl Encoder<OutgoingMessage> for ChatCodec {
    type Error = anyhow::Error;

    fn encode(
        &mut self,
        item: OutgoingMessage,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), Self::Error> {
        self.lines
            .encode(item.to_string(), dst)
            .map_err(anyhow::Error::from)
    }
}
impl Decoder for ChatCodec {
    type Item = String;

    type Error = anyhow::Error;

    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        self.lines.decode(src).map_err(anyhow::Error::from)
    }
}

For the decoder we simply delegate to LinesCodec::decode, while for the Encoder we first transform the OutgoingMessage to a string and then call LinesCodec::encode.

For converting the OutgoingMessage to a string I decided to use derive_more:

#[derive(derive_more::Display, Clone)]
pub enum OutgoingMessage {
    #[display(fmt = "Welcome to budgetchat! What shall I call you?")]
    Welcome,
    
    #[display(fmt = "* {} has entered the room", _0)]
    Join(Username),
    
    #[display(fmt = "* {} has left the room", _0)]
    Leave(Username),
    
    #[display(fmt = "[{}] {}", from, msg)]
    Chat { from: Username, msg: String },
    
    #[display(fmt = "Invalid username {}", _0)]
    InvalidUsername(String),

    #[display(fmt = "* The room contains: {}", "self.participants(_0)")]
    Participants(Vec<Username>),
}

impl OutgoingMessage {

    // helper function for formatting the list of users
    fn participants(&self, participants: &[Username]) -> String {
        participants
            .iter()
            .map(|user| user.to_string())
            .collect::<Vec<String>>()
            .join(", ")
    }
}
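For readers who prefer to see what such a derive boils down to, here is a hand-written sketch of an equivalent Display implementation using only the standard library. The Message enum is a simplified, hypothetical stand-in for OutgoingMessage: it uses plain String names instead of Username and leaves out InvalidUsername. It is not the code that derive_more actually generates.

```rust
use std::fmt;

/// Simplified stand-in for `OutgoingMessage`, with plain `String` names.
pub enum Message {
    Welcome,
    Join(String),
    Leave(String),
    Chat { from: String, msg: String },
    Participants(Vec<String>),
}

impl fmt::Display for Message {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Message::Welcome => write!(f, "Welcome to budgetchat! What shall I call you?"),
            Message::Join(name) => write!(f, "* {} has entered the room", name),
            Message::Leave(name) => write!(f, "* {} has left the room", name),
            Message::Chat { from, msg } => write!(f, "[{}] {}", from, msg),
            // Same output as the `participants` helper: names joined by ", ".
            Message::Participants(names) => {
                write!(f, "* The room contains: {}", names.join(", "))
            }
        }
    }
}
```

Implementing Display gives us to_string() for free, which is exactly what the encoder calls before handing the line to LinesCodec.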

Let's plug the ChatCodec in handle_client:

async fn handle_client(state: (), stream: TcpStream, address: SocketAddr) -> anyhow::Result<()> {
    // `split` returns the Sink half first and the Stream half second
    let (sink, stream) = Framed::new(stream, ChatCodec::new()).split();

    handle_client_internal(state, address, sink, stream).await
}

async fn handle_client_internal<I, O>(
    state: (),
    address: SocketAddr,
    mut sink: O,
    mut stream: I,
) -> anyhow::Result<()>
where
    I: Stream<Item = anyhow::Result<String>> + Unpin,
    O: Sink<OutgoingMessage, Error = anyhow::Error> + Unpin,
{
    todo!()
}

We call split on Framed to get a separate Sink and Stream. The Stream emits values of type Result<String, _> (lines) and the Sink allows us to send values of type OutgoingMessage.

Splitting the original Framed will come in handy when testing the handle_client_internal function.

# Handshake

With the framing layer setup done in the form of Sink and Stream, we can easily handle the handshake phase like this:

// send the welcome message
sink.send(OutgoingMessage::Welcome).await?;

// receive the username
let username = stream
    .try_next()
    .await?
    .ok_or_else(|| anyhow::anyhow!("Error while waiting for the username"))?;

// check if the username is valid
let username = match Username::parse(username) {
    Ok(username) => username,
    Err(e) => {
         // send the error and return for closing the connection
        sink.send(OutgoingMessage::InvalidUsername(e.to_string()))
            .await?;
        return Ok(());
    }
};

The work we did on the framing layer with Framed and ChatCodec paid off. We were able to abstract away the byte handling in favor of String and OutgoingMessage.

# Modeling the state

Once the username has been validated, the client can officially join the chat room.

For that we need to model how the server keeps track of the joined clients and how they talk to each other. This state will be passed to the handle_client function of every connected client.

Let's create a type Room as follows:

#[derive(Clone)]
pub struct Room(Arc<Mutex<HashMap<SocketAddr, User>>>);

pub struct User {
    username: Username,
    sender: mpsc::UnboundedSender<OutgoingMessage>,
}


impl Room {
    pub fn new() -> Room {
        Room(Arc::new(Mutex::new(HashMap::new())))
    }
}

The Room contains all the connected users, identified by their SocketAddr. The type User is composed of two fields:

  • username: the username of the connected client.
  • sender: the sender half of an unbounded mpsc channel.

Whenever we want to send a message to a client we can just call the sender.send(msg) method.

The Room is shared state and we need to add and remove users dynamically. To do that we wrap the HashMap in an Arc<Mutex<>>, which allows concurrent access.
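The sharing pattern can be tried in isolation with the standard library. The sketch below is only an analogy: it uses OS threads and std::sync::Mutex, whereas the Room is locked with lock().await inside async tasks. The register_all function and the usize keys are hypothetical.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical demo: every thread gets its own clone of the `Arc`
/// and registers one name in the shared map. Returns the final map size.
pub fn register_all(names: Vec<String>) -> usize {
    let shared: Arc<Mutex<HashMap<usize, String>>> = Arc::new(Mutex::new(HashMap::new()));

    let workers: Vec<_> = names
        .into_iter()
        .enumerate()
        .map(|(id, name)| {
            // Cloning the Arc only bumps a reference count;
            // all the clones point at the same HashMap.
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                // The lock guarantees one writer at a time.
                shared.lock().unwrap().insert(id, name);
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }
    let count = shared.lock().unwrap().len();
    count
}
```

This is why Room can simply derive Clone: cloning it clones the Arc, not the map behind it.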

The Room should provide at least one public method for adding a user to the chat room.

A first implementation looks like this:

impl Room {

    // join a new user with addr and username
    pub async fn join(&self, addr: SocketAddr, username: Username) -> UserHandle {
        // Create the communication channel
        let (sender, receiver) = mpsc::unbounded_channel();
        
        // Acquire the lock on the HashMap
        let mut users = self.0.lock().await;

        // insert the new user 
        users.insert(
            addr,
            User {
                username: username.clone(),
                sender,
            },
        );

        // returns a UserHandle 
        UserHandle {
            username,
            receiver,
            room: self.clone(),
            address: addr,
        }
    }
}

pub struct UserHandle {
    username: Username,
    address: SocketAddr,
    receiver: mpsc::UnboundedReceiver<OutgoingMessage>,
    room: Room,
}

In the join method we create a communication channel for the joining user, store it in the HashMap and return an instance of the type UserHandle.

The UserHandle contains these fields:

  • receiver: for receiving messages from the other participants.
  • username: Its username.
  • address: Its SocketAddr.
  • room: The joined Room.

In the UserHandle we will encapsulate all the actions that a user can do in a Room. A UserHandle should be able to:

  • send messages to the other users
  • leave the room

impl UserHandle {
    pub async fn send_message(&self, msg: String) {
        self.room
            .broadcast(
                &self.address,
                OutgoingMessage::Chat {
                    from: self.username.clone(),
                    msg,
                },
            )
            .await;
    }

    pub async fn leave(self) {
        self.room.leave(&self.address).await
    }
}

The UserHandle will just delegate the work to the Room by passing its identifier (SocketAddr):

impl Room {

    // method for broadcasting a message to the users excluding the one
    // identified by the addr
    async fn broadcast(&self, addr: &SocketAddr, msg: OutgoingMessage) {
        let mut users = self.0.lock().await;
        self.broadcast_internal(addr, msg, &mut users).await;
    }
    
    // method for handling the leave of a user
    async fn leave(&self, addr: &SocketAddr) {
        let mut users = self.0.lock().await;
        if let Some(leaving) = users.remove(addr) {
            // notify the users
            self.broadcast_internal(addr, OutgoingMessage::Leave(leaving.username), &mut users)
                .await;
        }
    }
    
    
    // internal broadcast implementation that actually delivers the message to each user.
    async fn broadcast_internal(
        &self,
        addr: &SocketAddr,
        msg: OutgoingMessage,
        users: &mut HashMap<SocketAddr, User>,
    ) {
        for (user_addr, user) in users.iter_mut() {
            if addr != user_addr {
                let _ = user.sender.send(msg.clone());
            }
        }
    }
}

The broadcast method sends the same OutgoingMessage to all users, filtering out the input SocketAddr.

The leave method removes the user identified by the input SocketAddr from the HashMap and then broadcasts an OutgoingMessage::Leave message.
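The "everyone except the author" rule used by broadcast is easy to check on its own. Below is a small synchronous sketch built on std::sync::mpsc. The broadcast function, the u32 ids (instead of SocketAddr) and the String messages (instead of OutgoingMessage) are simplifications made up for this example.

```rust
use std::collections::HashMap;
use std::sync::mpsc::Sender;

/// Hypothetical, synchronous stand-in for `Room::broadcast_internal`.
/// Sends `msg` to every user except `from` and returns how many got it.
pub fn broadcast(users: &HashMap<u32, Sender<String>>, from: u32, msg: &str) -> usize {
    let mut delivered = 0;
    for (id, sender) in users {
        // Skip the author, and ignore users whose receiver is already gone.
        if *id != from && sender.send(msg.to_string()).is_ok() {
            delivered += 1;
        }
    }
    delivered
}
```

As in the Room code, a failed send is not treated as an error: it only means that the other client has already disconnected.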

To complete the Room implementation we are still missing some bits.

When users join a room they should receive a notification with all the room members, and the other users should receive the notification of the newly joined member.

For that we change the Room::join method slightly:

impl Room {
    pub async fn join(&self, addr: SocketAddr, username: Username) -> UserHandle {
        let mut users = self.0.lock().await;
        let (sender, receiver) = mpsc::unbounded_channel();

        let names = users
            .iter()
            .map(|(_, user)| user.username.clone())
            .collect::<Vec<Username>>();

        // send the users list notification
        let _ = sender.send(OutgoingMessage::Participants(names));

        // broadcast the join message to the other users 
        self.broadcast_internal(&addr, OutgoingMessage::Join(username.clone()), &mut users)
            .await;
            
        users.insert(
            addr,
            User {
                username: username.clone(),
                sender,
            },
        );

        UserHandle {
            username,
            receiver,
            room: self.clone(),
            address: addr,
        }
    }
}

Before adding the user to the chat room we send a notification on its channel with the participants list and then we broadcast to the others the join message.
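The order of these three steps matters: because the newcomer is inserted last, it appears neither in its own participants list nor among the receivers of its own join announcement. The single-threaded sketch below models just that ordering. MiniRoom, the u32 ids and the pre-formatted String messages are hypothetical simplifications.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};

/// Hypothetical single-threaded model of the room: id -> (name, sender).
pub struct MiniRoom {
    users: HashMap<u32, (String, Sender<String>)>,
}

impl MiniRoom {
    pub fn new() -> Self {
        MiniRoom { users: HashMap::new() }
    }

    /// Same order as `Room::join`: list, announce, then insert.
    pub fn join(&mut self, id: u32, name: &str) -> Receiver<String> {
        let (sender, receiver) = channel();

        // 1. The newcomer is not in the map yet, so the list excludes it.
        let mut names: Vec<&str> = self.users.values().map(|(n, _)| n.as_str()).collect();
        names.sort(); // HashMap order is arbitrary; sort for a stable output
        let _ = sender.send(format!("* The room contains: {}", names.join(", ")));

        // 2. Everybody already in the room hears about the newcomer.
        for (_, other) in self.users.values() {
            let _ = other.send(format!("* {} has entered the room", name));
        }

        // 3. Only now does the newcomer become a member.
        self.users.insert(id, (name.to_string(), sender));
        receiver
    }
}
```

If alice joins an empty room and bob joins after her, alice first receives an empty list and then bob's announcement, while bob receives only the list containing alice.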

All the logic for the Room state should be completed and ready for testing now.

I will not describe here the room_test, as I would like to focus on testing the handle_client_internal function once the implementation is done, but the code is available here.

# Wiring it all together

So far we have done the handshake phase and the Room API for managing users and broadcasting messages in the server.

Let's dive back into the handle_client_internal function and wire it all together.

After the handshake phase we call Room::join, passing the SocketAddr of the connected client and the received username. Once a user has joined the room we need to process the incoming messages in a loop. There are two possible sources of messages at this point: the external Stream (the client) and the messages broadcast by the other users.

If a message comes from the external Stream (String), it's a chat message from the connected client and we broadcast it to the others by calling UserHandle::send_message.

If a message comes from the receiver of the UserHandle (OutgoingMessage), it's a broadcast message from another user and we just send it to the Sink of the connected client.

For processing multiple asynchronous sources we'll use tokio::select!.

The final handle_client_internal looks like this:

async fn handle_client_internal<I, O>(
    state: Room,
    address: SocketAddr,
    mut sink: O,
    mut stream: I,
) -> anyhow::Result<()>
where
    I: Stream<Item = anyhow::Result<String>> + Unpin,
    O: Sink<OutgoingMessage, Error = anyhow::Error> + Unpin,
{

    sink.send(OutgoingMessage::Welcome).await?;
    
    let username = stream
        .try_next()
        .await?
        .ok_or_else(|| anyhow::anyhow!("Error while waiting for the username"))?;


    let username = match Username::parse(username) {
        Ok(username) => username,
        Err(e) => {
            sink.send(OutgoingMessage::InvalidUsername(e.to_string()))
                .await?;
            return Ok(());
        }
    };
    
    // join the room
    let mut handle = state.join(address, username).await;

    // process messages until there is an error or the stream is exhausted
    loop {
        tokio::select! {
            // Message received from a user 
            Some(msg) = handle.receiver.recv() => {
                 // Send it to the connected client
                if let Err(e) = sink.send(msg).await {
                    error!("Error sending message {}",e);
                    break;
                }
            }

            // Message received from the connected client
            result = stream.next() => match result {
                Some(Ok(msg)) =>  {
                    // broadcast it
                    handle.send_message(msg).await;
                }
                // break on error 
                Some(Err(e)) => {
                    error!("Error reading messages {}",e);
                    break;
                }
                // break on stream exhausted
                None => break,
            }
        };
    }

    // leave the room
    handle.leave().await;
    Ok(())
}
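The routing decision made inside the loop can also be looked at without an async runtime. The sketch below does not use tokio::select!. It swaps in a different technique, funnelling both sources into one std channel as variants of an enum, only to keep the example free of dependencies. Event and run_loop are hypothetical names.

```rust
use std::sync::mpsc::Receiver;

/// Hypothetical event type: the two sources the loop has to handle,
/// plus the end of the client's stream.
pub enum Event {
    FromClient(String), // a line typed by the connected client
    FromRoom(String),   // a message broadcast by another user
    Disconnected,       // the client's stream is exhausted
}

/// Processes events until the client disconnects.
/// Returns (lines to broadcast to the room, lines to write to the client).
pub fn run_loop(events: Receiver<Event>) -> (Vec<String>, Vec<String>) {
    let mut to_room = Vec::new();
    let mut to_client = Vec::new();
    for event in events {
        match event {
            Event::FromClient(line) => to_room.push(line),
            Event::FromRoom(msg) => to_client.push(msg),
            Event::Disconnected => break,
        }
    }
    (to_room, to_client)
}
```

The shape is the same as in handle_client_internal: client lines go out to the room, room messages go to the client's Sink, and the loop ends when the client's stream ends.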

# Testing the solution

Testing handle_client_internal requires more setup. First we need to add async-stream as a dev dependency for creating ad-hoc streams.

 cargo add --dev async-stream

The handle_client_internal function takes as input:

  • The shared state Room
  • The SocketAddr of the client
  • The Sink for sending messages (server -> client)
  • The Stream for receiving messages (client -> server)

We just have to find a way to provide an implementation of Sink and Stream that we can use for testing. tokio::sync::mpsc::channel seems a good candidate for that.

Let's write a fixture function connect for simulating a client connection:

async fn connect(room: Room, addr: &str) -> UserTest {
    // channel for sending stuff server -> client
    let (sink_tx, sink_rx) = mpsc::channel(100);
    
    // channel for sending stuff client -> server
    let (stream_tx, mut stream_rx) = mpsc::channel(100);

    let address: SocketAddr = addr.parse().unwrap();

    // convert the stream receiver in a stream
    let stream = async_stream::stream! {
        while let Some(message) = stream_rx.recv().await {
            yield message
        }
    };

    let handle = tokio::spawn(async move {
        handle_client_internal(
            room,
            address,
            PollSender::new(sink_tx).sink_map_err(anyhow::Error::from),
            Box::pin(stream),
        )
        .await
    });

    UserTest {
        sink_receiver: sink_rx,
        stream_sender: Some(stream_tx),
        handle,
    }
}

struct UserTest {
    sink_receiver: Receiver<OutgoingMessage>,
    stream_sender: Option<Sender<anyhow::Result<String>>>,
    handle: JoinHandle<anyhow::Result<()>>,
}

impl UserTest {
    async fn send(&mut self, message: &str) {
        self.stream_sender
            .as_ref()
            .unwrap()
            .send(Ok(message.to_string()))
            .await
            .unwrap();
    }

    async fn leave(mut self) {
        let stream = self.stream_sender.take();
        drop(stream);

        self.handle.await.unwrap().unwrap()
    }

    async fn check_message(&mut self, msg: OutgoingMessage) {
        assert_eq!(self.sink_receiver.recv().await.unwrap(), msg);
    }
}

In the connect function we create two channels:

  • sink channel (server -> client)
  • stream channel (client -> server)

and we spawn a background task that handles the single client by calling the handle_client_internal function with those two channels as Sink and Stream.

We have to do some conversion magic in the middle. The tokio::sync::mpsc::Receiver doesn't implement Stream, but fortunately we can use async_stream::stream! to convert the channel into a Stream. The tokio::sync::mpsc::Sender doesn't implement Sink either, so we wrap it in a PollSender from tokio-util, which makes our sender compatible with the Sink trait.

The connect function returns a UserTest type, which we can use in our tests for driving a single user in a chat session scenario.

It has the following fields:

  • sink_receiver: for messages (server -> client)
  • stream_sender: for messages (client -> server)

I've created a couple of additional methods on UserTest for making the tests more readable:

  • send: send a message to the server
  • check_message: check if the next message matches msg
  • leave: simulate the disconnection by dropping the stream_sender and waiting for the background handle_client_internal task to complete.

Now we can use this API in a chat test scenario with two participants:

  • alice: "0.0.0.0:10"
  • bob: "0.0.0.0:11"

{server}-->{alice} Welcome to budgetchat! What shall I call you?
{server}<--{alice} alice
{server}-->{alice} * The room contains:
{server}-->{bob}   Welcome to budgetchat! What shall I call you?
{server}<--{bob}   bob
{server}-->{bob}   * The room contains: alice
{server}-->{alice} * bob has entered the room
{server}<--{alice} Hi bob!
{server}-->{bob}   [alice] Hi bob!
{server}<--{bob}   Hi alice!
{server}-->{alice} [bob] Hi alice!
{server}-->{alice} * bob has left the room

The code for testing the above session is:

#[tokio::test]
async fn example_session_test() {
    let room = Room::new();

    let alice_username = Username::parse("alice".to_string()).unwrap();
    let bob_username = Username::parse("bob".to_string()).unwrap();

    // alice connects
    let mut alice = connect(room.clone(), "0.0.0.0:10").await;
    alice.check_message(OutgoingMessage::Welcome).await;

    // alice sends the username and get the participants list
    alice.send(&alice_username.inner_ref()).await;
    alice
        .check_message(OutgoingMessage::Participants(vec![]))
        .await;

    // bob connects
    let mut bob = connect(room.clone(), "0.0.0.0:11").await;
    bob.check_message(OutgoingMessage::Welcome).await;

    // bob sends the username and get the participants list
    bob.send(&bob_username.inner_ref()).await;
    bob.check_message(OutgoingMessage::Participants(vec![alice_username.clone()]))
        .await;


    // alice gets the notification of bob joining the room
    alice
        .check_message(OutgoingMessage::Join(bob_username.clone()))
        .await;

    // alice sends a message
    alice.send("Hi bob!").await;


    // bob gets alice's message
    bob.check_message(OutgoingMessage::Chat {
        msg: "Hi bob!".to_string(),
        from: alice_username.clone(),
    })
    .await;


    // bob sends a message
    bob.send("Hi alice!").await;


    // alice gets bob's message
    alice
        .check_message(OutgoingMessage::Chat {
            msg: "Hi alice!".to_string(),
            from: bob_username.clone(),
        })
        .await;

    // bob leaves the room
    bob.leave().await;
   
    // alice gets the notification of bob leaving the room
    alice
        .check_message(OutgoingMessage::Leave(bob_username))
        .await;
}

Let's run it:

 cargo test example_session_test -- --nocapture
   Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
    Finished test [unoptimized + debuginfo] target(s) in 1.27s
     Running unittests src/bin/budget_chat.rs (target/debug/deps/budget_chat-9feb9b3a82ded77d)

running 1 test
test budget_chat_tests::example_session_test ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 1 filtered out; finished in 0.00s

... and it passed!!

# Wrapping Up

Wow! This post turned out longer than expected, but we did a lot of stuff for this challenge. Even though the solution is probably over-engineered, and we could have solved it with the AsyncRead and AsyncWrite primitives and their extensions, we explored concepts like Sink and Stream, which abstract away the underlying network layer. We also learned a bit about framing and how to use Framed, Encoder and Decoder to convert a low-level stream of bytes into high-level frames.

Comments and suggestions are always welcome. Please feel free to send an email or comment on @wolf4ood@hachyderm.io.

Thank you for reading, see you in the next chapter!