Protohackers in Rust: Means to an End

Enrico Risa | 9 min | 1728 words

This is the third article in the series on solving Protohackers with Rust. The series index is available here.

The code in this series will be available on GitHub.

# The Challenge

In the third challenge we have to implement a simple binary protocol. Each message is composed of 9 bytes in this format:

Byte:  |  0  |  1     2     3     4  |  5     6     7     8  |
Type:  |char |         int32         |         int32         |

The first byte should be Q or I, which indicates the kind of operation (query or insert). The remaining 8 bytes are two 32-bit signed integers in big-endian order. The meaning of these two integers depends on the op code.

For an insert request the server should interpret the first integer as a timestamp and the second as a price, allowing clients to store the price of their asset at a given timestamp. An insert request doesn't require a response.

For a query request the server should interpret the two integers as a range of timestamps, where the first is mintime and the second is maxtime, allowing clients to ask for the average price over the given period. The server should then send back the response encoded as a 4-byte big-endian int32. If there is no data in the period, or mintime > maxtime, 0 should be sent back.
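To make the wire format concrete, here is a minimal sketch that uses only the standard library. The helper name `encode` is hypothetical and not part of the solution developed in this article; the byte values are the ones shown in the example session of the problem statement.

```rust
/// Hypothetical helper: builds one 9-byte request
/// (1 op-code byte followed by two big-endian i32 values).
fn encode(op: u8, first: i32, second: i32) -> [u8; 9] {
    let mut message = [0u8; 9];
    message[0] = op;
    message[1..5].copy_from_slice(&first.to_be_bytes());
    message[5..9].copy_from_slice(&second.to_be_bytes());
    message
}

fn main() {
    // "I 12345 101": insert price 101 at timestamp 12345.
    let insert = encode(b'I', 12345, 101);
    assert_eq!(insert, [0x49, 0x00, 0x00, 0x30, 0x39, 0x00, 0x00, 0x00, 0x65]);

    // "Q 12288 16384": ask for the mean price with 12288 <= timestamp <= 16384.
    let query = encode(b'Q', 12288, 16384);
    assert_eq!(query, [0x51, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x40, 0x00]);

    // The response to a query is a single big-endian i32, here 101.
    assert_eq!(101_i32.to_be_bytes(), [0x00, 0x00, 0x00, 0x65]);
}
```

Big-endian means the most significant byte comes first, which is why 12345 (0x3039) travels as 00 00 30 39.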

Like the previous challenge, the server needs to handle at least 5 clients simultaneously.

# Getting Started

At this point our project structure looks like this:

 tree $PWD -I target
~/Coding/protohackers-rs
├── Cargo.lock
├── Cargo.toml
├── README.md
└── src
    ├── bin
    │   ├── prime_time.rs
    │   └── smoke_test.rs
    └── lib.rs

For this challenge we are adding another binary, means_to_an_end.rs, in the src/bin folder, where we'll write our solution:

 touch src/bin/means_to_an_end.rs

No additional dependencies are needed this time: we should be able to solve the challenge with Tokio primitives alone.

The main function setup looks similar to the previous challenge:

use protohackers_rs::run_server;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    net::TcpStream,
};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    run_server(8000, handle_client).await?;
    Ok(())
}

async fn handle_client(mut stream: TcpStream) -> anyhow::Result<()> {
    let (input_stream, output_stream) = stream.split();
    handle_client_internal(input_stream, output_stream).await
}

async fn handle_client_internal(
    mut input_stream: impl AsyncRead + Unpin,
    mut output_stream: impl AsyncWrite + Unpin,
) -> anyhow::Result<()> {
    todo!();
}

#[cfg(test)]
mod means_to_an_end_tests {}

and we are going to focus on the handle_client_internal function where we will write our solution.

# Means to an End Implementation

Since each message is exactly 9 bytes long, we start by reading those bytes into a fixed-size buffer in a loop:

let mut buffer = [0; 9];
loop {
    input_stream.read_exact(&mut buffer).await?;
}

The read_exact method reads bytes until the buffer is filled. If the end of file is reached before the buffer is full, an error is returned. We want to catch the end-of-file error, if it arises, for a graceful shutdown. Any other error is simply returned to the caller (ErrorKind here is std::io::ErrorKind):

let mut buffer = [0; 9];
loop {
    match input_stream.read_exact(&mut buffer).await {
        Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
            return Ok(());
        }
        Err(err) => return Err(err.into()),
        Ok(_) => {}
    }
}
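To see the end-of-file behaviour in isolation, here is a small sketch based on the synchronous std::io::Read::read_exact, which reports a short read with the same ErrorKind::UnexpectedEof that Tokio's read_exact documents. The helper name `count_frames` is made up for this illustration.

```rust
use std::io::{ErrorKind, Read};

/// Hypothetical helper: counts how many complete 9-byte frames `input`
/// yields before the end of the stream is reached.
fn count_frames(mut input: impl Read) -> std::io::Result<usize> {
    let mut buffer = [0u8; 9];
    let mut frames = 0;
    loop {
        match input.read_exact(&mut buffer) {
            Ok(()) => frames += 1,
            // The stream ended (possibly in the middle of a frame): stop cleanly.
            Err(err) if err.kind() == ErrorKind::UnexpectedEof => return Ok(frames),
            // Anything else is a real I/O error and is propagated.
            Err(err) => return Err(err),
        }
    }
}

fn main() -> std::io::Result<()> {
    // 20 bytes: two complete frames followed by a truncated one.
    let bytes = [0u8; 20];
    assert_eq!(count_frames(&bytes[..])?, 2);

    // An empty stream hits the end of file on the very first read.
    let empty: &[u8] = &[];
    assert_eq!(count_frames(empty)?, 0);
    Ok(())
}
```

Note that with this approach a client that disconnects halfway through a message is treated the same way as one that disconnects between messages.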

# Message Parsing

With the buffer filled, we are now ready to decode those bytes into a high-level structure. Looking at the protocol specs, there are two kinds of messages:

  • Insert (I)
  • Query (Q)

We could model the message with this type:

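// Debug and PartialEq are needed to compare messages with assert_eq! in the tests.
#[derive(Debug, PartialEq)]
pub enum Message {
    Insert { timestamp: i32, price: i32 },
    Query { mintime: i32, maxtime: i32 },
}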

and implement a parse function that creates a Message from a fixed buffer of 9 bytes:

impl Message {
    pub fn parse(buffer: &[u8; 9]) -> anyhow::Result<Self> {
        let op = buffer[0];
        let first = i32::from_be_bytes(buffer[1..5].try_into()?);
        let second = i32::from_be_bytes(buffer[5..9].try_into()?);

        match op {
            b'I' => Ok(Message::Insert {
                timestamp: first,
                price: second,
            }),
            b'Q' => Ok(Message::Query {
                mintime: first,
                maxtime: second,
            }),
            _ => Err(anyhow::anyhow!("Unexpected op code {}", op)),
        }
    }
}

The operation code doesn't need to be decoded, since it's only one byte. To decode the two integers we call i32::from_be_bytes on the bytes of the original buffer at positions 1-4 and 5-8. The expression buffer[1..5] returns a slice &[u8], a view of the original buffer for the range 1..5, but i32::from_be_bytes takes a fixed-size array [u8; 4]. To make the conversion we have to go through the fallible conversion API, with buffer[1..5].try_into()?.
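The slice-to-array conversion can be tried out on its own with the standard library. This is only a sketch: `read_i32` is a hypothetical helper, not something the solution defines.

```rust
use std::array::TryFromSliceError;
// Already in the prelude since the 2021 edition; imported for older editions.
use std::convert::TryInto;

/// Hypothetical helper: decodes a slice of exactly 4 bytes as a big-endian i32.
fn read_i32(bytes: &[u8]) -> Result<i32, TryFromSliceError> {
    // &[u8] -> [u8; 4] only succeeds when the slice is exactly 4 bytes long.
    let array: [u8; 4] = bytes.try_into()?;
    Ok(i32::from_be_bytes(array))
}

fn main() {
    let buffer = [b'I', 0x00, 0x00, 0x30, 0x39, 0x00, 0x00, 0x00, 0x65];
    assert_eq!(read_i32(&buffer[1..5]).unwrap(), 12345);
    assert_eq!(read_i32(&buffer[5..9]).unwrap(), 101);

    // A slice of the wrong length is rejected with an error.
    assert!(read_i32(&buffer[1..4]).is_err());

    // Negative values are decoded as two's complement.
    assert_eq!(read_i32(&(-1_i32).to_be_bytes()).unwrap(), -1);
}
```

In Message::parse the ranges 1..5 and 5..9 are constants over a 9-byte array, so the conversion cannot actually fail there; the `?` is just the price of going through a slice.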

Let's plug in the message parser:

match Message::parse(&buffer)? {
    Message::Insert { timestamp, price } => {}
    Message::Query { mintime, maxtime } => {}
}

Before proceeding with the message handling, I'd like to write some tests for checking that the message parsing works as expected. The Message::parse method is a perfect candidate for some unit tests.

We should at least test that the parser works on valid input for both the Insert and Query cases, and also cover the case where the op code is neither I nor Q:

#[cfg(test)]
mod means_to_an_end_tests {

    use crate::Message;
    use tokio::io::AsyncWriteExt;

    async fn create_message(op: u8, first: i32, second: i32) -> [u8; 9] {
        let mut buffer = vec![];
        buffer.write_u8(op).await.unwrap();
        buffer.write_i32(first).await.unwrap();
        buffer.write_i32(second).await.unwrap();

        buffer.try_into().unwrap()
    }

    #[tokio::test]
    async fn message_parse_test_insert_ok() {
        let buffer = create_message(b'I', 10, 100).await;

        let message = Message::parse(&buffer).unwrap();

        assert_eq!(
            Message::Insert {
                timestamp: 10,
                price: 100
            },
            message
        )
    }

    #[tokio::test]
    async fn message_parse_test_query_ok() {
        let buffer = create_message(b'Q', 10, 100).await;

        let message = Message::parse(&buffer).unwrap();

        assert_eq!(
            Message::Query {
                mintime: 10,
                maxtime: 100,
            },
            message
        )
    }

    #[tokio::test]
    async fn message_parse_test_fail() {
        let buffer = create_message(b'Z', 10, 100).await;

        let result = Message::parse(&buffer);

        assert!(result.is_err());
    }
}

Here I've also created a helper function, create_message, for stubbing messages.

Let's run those tests:

 cargo test  message_parse_test
   Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
    Finished test [unoptimized + debuginfo] target(s) in 0.90s
     Running unittests src/bin/means_to_an_end.rs (target/debug/deps/means_to_an_end-498dcc6835502baa)

running 3 tests
test means_to_an_end_tests::message_parse_test_insert_ok ... ok
test means_to_an_end_tests::message_parse_test_fail ... ok
test means_to_an_end_tests::message_parse_test_query_ok ... ok

and all tests are green, Yay!

# Message Handling

To keep track of a client's asset prices, a type like HashMap or BTreeMap came to mind, since both can associate a price with a given timestamp. For this use case I think BTreeMap is the better pick: its keys are kept sorted, so it provides a nice API for finding the values within a range [mintime,maxtime].
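As a quick illustration of that range API, here is a minimal sketch using the data from the example session in the problem statement. The helper `prices_between` is hypothetical and only serves this example.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: prices whose timestamp lies in [mintime, maxtime],
/// in ascending timestamp order.
fn prices_between(prices: &BTreeMap<i32, i32>, mintime: i32, maxtime: i32) -> Vec<i32> {
    // BTreeMap::range panics if the start of the range is greater than its end,
    // so an inverted range is handled before calling it.
    if mintime > maxtime {
        return Vec::new();
    }
    prices
        .range(mintime..=maxtime)
        .map(|(_, price)| *price)
        .collect()
}

fn main() {
    let mut prices = BTreeMap::new();
    prices.insert(12345, 101);
    prices.insert(12346, 102);
    prices.insert(12347, 100);
    prices.insert(40960, 5);

    // Only the first three timestamps fall inside [12288, 16384].
    assert_eq!(prices_between(&prices, 12288, 16384), vec![101, 102, 100]);
    // No data in the period, or mintime > maxtime: nothing is returned.
    assert!(prices_between(&prices, 0, 100).is_empty());
    assert!(prices_between(&prices, 16384, 12288).is_empty());
}
```

A HashMap would need a scan over every entry to answer the same question, because its keys are not ordered.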

We can do even better. Instead of using HashMap, BTreeMap or whatever directly, we can encapsulate the storage behavior in a custom type. This gives us the flexibility to change how we store and retrieve the price data in the future without affecting the rest of the solution. The storage should provide at least two APIs: one for inserting a price at a given timestamp and one for computing the mean over a period.

Let's call this type Db:

use std::{collections::BTreeMap, ops::RangeInclusive};

pub struct Db(BTreeMap<i32, i32>);

impl Db {
    pub fn new() -> Db {
        Db(BTreeMap::new())
    }

    pub fn insert(&mut self, timestamp: i32, price: i32) {
        self.0.insert(timestamp, price);
    }

    pub fn mean(&self, range: RangeInclusive<i32>) -> i32 {
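        // BTreeMap::range panics when the start is greater than the end,
        // so an inverted range (mintime > maxtime) is answered up front.
        if range.is_empty() {
            return 0;
        }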
        let (count, sum) = self
            .0
            .range(range)
            .fold((0, 0_i64), |(count, sum), (_, amount)| {
                (count + 1, sum + *amount as i64)
            });

        if count > 0 {
            (sum / count) as i32
        } else {
            0
        }
    }
}

The insert method simply delegates to the underlying map. The mean method computes the average price over the period given by the range parameter, accumulating the sum in an i64 so that adding many i32 prices cannot overflow. If the range is empty (mintime > maxtime), or no price falls inside it, we just return 0.
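The arithmetic of the mean has a few edge cases worth checking in isolation. The sketch below reproduces it on a plain slice; the free function `mean` is a hypothetical stand-in for illustration, not the Db::mean method itself.

```rust
/// Hypothetical stand-in for the arithmetic in Db::mean, applied to a slice.
fn mean(prices: &[i32]) -> i32 {
    if prices.is_empty() {
        return 0;
    }
    // Widen to i64 before summing so the total cannot overflow an i32.
    let sum: i64 = prices.iter().map(|&price| i64::from(price)).sum();
    // The mean of i32 values always fits back into an i32.
    (sum / prices.len() as i64) as i32
}

fn main() {
    // The three prices from the example session.
    assert_eq!(mean(&[101, 102, 100]), 101);
    // No data: the answer is 0.
    assert_eq!(mean(&[]), 0);

    // Summing in i32 would overflow here; the i64 accumulator does not.
    assert_eq!(i32::MAX.checked_add(i32::MAX), None);
    assert_eq!(mean(&[i32::MAX, i32::MAX]), i32::MAX);

    // Integer division truncates toward zero.
    assert_eq!(mean(&[1, 2]), 1);
    assert_eq!(mean(&[-1, -2]), -1);
}
```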

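To complete the solution, we just have to use the Db type when handling the messages. The read_exact and write_i32 methods come from Tokio's AsyncReadExt and AsyncWriteExt extension traits, so both traits (and std::io::ErrorKind) need to be imported: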

async fn handle_client_internal(
    mut input_stream: impl AsyncRead + Unpin,
    mut output_stream: impl AsyncWrite + Unpin,
) -> anyhow::Result<()> {
    let mut buffer = [0; 9];
    let mut prices = Db::new();
    loop {
        match input_stream.read_exact(&mut buffer).await {
            Err(err) if err.kind() == ErrorKind::UnexpectedEof => {
                return Ok(());
            }
            Err(err) => return Err(err.into()),
            Ok(_) => {}
        }

        match Message::parse(&buffer)? {
            Message::Insert { timestamp, price } => {
                prices.insert(timestamp, price);
            }
            Message::Query { mintime, maxtime } => {
                let mean = prices.mean(mintime..=maxtime);
                output_stream.write_i32(mean).await?;
            }
        }
    }
}

When handling the Query message we call the Db::mean method to compute the average over the [mintime,maxtime] range, and then write the response back to the client using the AsyncWriteExt::write_i32 method, which encodes the value in big-endian order.

# Testing the Solution

The Message::parse method is already covered by some unit tests, but we are still not sure whether the handle_client_internal function behaves correctly.

We should create a test that covers at least the example session in the problem statement:

    Hexadecimal:                 Decoded:
<-- 49 00 00 30 39 00 00 00 65   I 12345 101
<-- 49 00 00 30 3a 00 00 00 66   I 12346 102
<-- 49 00 00 30 3b 00 00 00 64   I 12347 100
<-- 49 00 00 a0 00 00 00 00 05   I 40960 5
<-- 51 00 00 30 00 00 00 40 00   Q 12288 16384
--> 00 00 00 65                  101

In the above example we have four inserts, one query and the server should return 101 as mean value for the given mintime and maxtime:

#[tokio::test]
async fn example_session_test() {
    let messages = vec![
        create_message(b'I', 12345, 101).await,
        create_message(b'I', 12346, 102).await,
        create_message(b'I', 12347, 100).await,
        create_message(b'I', 40960, 5).await,
        create_message(b'Q', 12288, 16384).await,
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<u8>>();

    let mut output = vec![];

    handle_client_internal(messages.as_slice(), &mut output)
        .await
        .unwrap();

    assert_eq!(4, output.len());

    assert_eq!(101, i32::from_be_bytes(output[..4].try_into().unwrap()));
}

This test requires more setup. Since we are simulating a session, more messages are needed. We reuse the create_message function to build each message and then concatenate them into the final byte array representing the session we want to test. Because a byte slice can act as the input stream and a Vec<u8> as the output stream, no real socket is involved. In the output we expect to find a single i32 with value 101.

Let's run it:

 cargo test example_session_test
    Finished test [unoptimized + debuginfo] target(s) in 0.02s
     Running unittests src/bin/means_to_an_end.rs (target/debug/deps/means_to_an_end-022222498b945670)

running 1 test
test means_to_an_end_tests::example_session_test ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 3 filtered out; finished in 0.00s

..and it works 🚀 !

# Wrapping Up

That's it for the third article! This was fun: we got to play a little bit with binary protocols this time, learning how to read fixed-size messages and how to encode and decode i32 values.

Comments and suggestions are always welcome. Please feel free to send an email or comment on @wolf4ood@hachyderm.io.

Thank you for reading, see you in the next chapter!