Protohackers in Rust: Prime Time

Enrico Risa | 11 min | 2065 words

This is the second article in a series on solving the Protohackers challenges with Rust. In the previous article we implemented a simple Echo Server, a kind of hello world for network programming, which helped us get familiar with async Rust using Tokio and with the Protohackers test suite.

The code in this series will be available on GitHub.

# The Challenge

In the second problem we have to implement a JSON-based request-response protocol, where each request is a single line terminated by a newline character (\n). Once a client is connected, we should consume each line in order, parse the JSON and send a response back to the client.

The JSON payload contains two required fields:

  • method
  • number

For a request to be valid, the method field must contain the string isPrime, and the number field must be a JSON number. If the request is valid, we should send back a valid response: a JSON object with two required fields, method, always containing the string isPrime, and prime, a boolean that tells whether the number is prime.

Example:

```
{"method":"isPrime","number":123} // Request

{"method":"isPrime","prime":false} // Response
```

If a request is malformed, we should send back a malformed response (anything that is not a valid response) and close the connection.

Example:

```
{"method":"isPrime","number": "number"} // Request

malformed // Response and close the connection
```

Like the previous challenge, we need to handle at least 5 clients simultaneously.

# Getting Started

At this point the project structure looks like this:

```
$ tree $PWD -I target
~/Coding/protohackers-rs
├── Cargo.lock
├── Cargo.toml
├── README.md
└── src
    ├── bin
    │   └── smoke_test.rs
    └── lib.rs
```

For this challenge we are going to add another binary, prime_time.rs, in the src/bin folder, where we'll implement our solution:

```
$ touch src/bin/prime_time.rs
```

For handling JSON, we'll need:

  • serde (serialization/deserialization framework)
  • serde_json (JSON serialization/deserialization)

```
$ cargo add serde --features derive
$ cargo add serde_json
```

By default, serde does not include the ability to generate serde::{Deserialize, Serialize} implementations for our types with the #[derive(..)] macro; that's why we have to enable the derive feature when adding serde.

Now let's bootstrap the main function in prime_time.rs and initialize the logging system:

```rust
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    Ok(())
}
```

Next, we need a server that accepts TCP connections like the one in the first Protohackers challenge. Since we'll probably need this kind of server in most of the challenges, it's a good idea to extract the acceptor code into a common function.

In lib.rs we can create a function run_server that looks like this:

```rust
use std::future::Future;
use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, error, info};

pub async fn run_server<H, F>(port: u16, handler: H) -> anyhow::Result<()>
where
    H: Fn(TcpStream) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    let listener = TcpListener::bind(&format!("0.0.0.0:{}", port)).await?;

    info!("Starting server at 0.0.0.0:{}", port);
    loop {
        let (socket, address) = listener.accept().await?;

        debug!("Got connection from {}", address);
        let future = handler(socket);
        tokio::task::spawn(async move {
            if let Err(err) = future.await {
                error!("Error handling connection {}: {}", address, err);
            }
        });
    }
}
```

The run_server function takes as input:

  • port: the binding port
  • handler: a function that will handle each connection

The handler parameter has the generic type H (Handler), with the following bounds:

```rust
where
    H: Fn(TcpStream) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
```

This means that the handler must be a function that takes a TcpStream as input and returns a Future resolving to anyhow::Result<()>. The returned Future must also satisfy the requirements of tokio::task::spawn, so we need to add the bounds:

  • Send: the runtime is allowed to move the task between threads
  • 'static: the task should not reference any data owned outside the task

In the acceptor loop, once a client is connected, we call the handler and then we await the returned Future inside the spawned task to handle the TcpStream concurrently.

Let's plug the above function in our main:

```rust
use protohackers_rs::run_server;
use tokio::net::TcpStream;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    run_server(8000, handle_client).await?;
    Ok(())
}

async fn handle_client(stream: TcpStream) -> anyhow::Result<()> {
    // Implemented in the next section.
    Ok(())
}
```

# Prime Time Implementation

In the handle_client function, we first need to start reading lines from the socket. One option would be to read raw bytes from the socket and look for a newline in them, but that would require managing the buffer manually, which we'd like to avoid for this challenge. Instead, we'll use a Tokio IO helper that does that for us.
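For comparison, here is roughly what the manual approach would involve. This is a std-only sketch with a hypothetical LineBuffer type (not part of the article's code): it accumulates the raw chunks returned by each socket read and only hands back complete, newline-terminated lines, because a single read can stop in the middle of a line or contain several of them.

```rust
/// Hypothetical helper, for illustration only: accumulates raw bytes read
/// from a socket and hands back only complete, newline-terminated lines.
struct LineBuffer {
    buf: Vec<u8>,
}

impl LineBuffer {
    fn new() -> Self {
        Self { buf: Vec::new() }
    }

    /// Appends a chunk and returns every line it completed, without the
    /// trailing '\n'. Bytes after the last '\n' stay buffered for later.
    fn push(&mut self, chunk: &[u8]) -> Vec<String> {
        self.buf.extend_from_slice(chunk);
        let mut lines = Vec::new();
        while let Some(pos) = self.buf.iter().position(|&b| b == b'\n') {
            // Remove the line together with its '\n', then drop the '\n'.
            let mut line: Vec<u8> = self.buf.drain(..=pos).collect();
            line.pop();
            lines.push(String::from_utf8_lossy(&line).into_owned());
        }
        lines
    }
}
```

For instance, pushing `{"method":"isPr` returns no lines, and pushing `ime","number":7}\n{}\n{` afterwards returns two complete lines while `{` stays buffered until the next read. This is the bookkeeping a buffered line reader takes care of for us.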

We could use one of the methods available in tokio::io::AsyncBufReadExt:

  • read_line reads all bytes until a newline is reached
  • lines returns a stream over the lines of the underlying input source

I opted for lines, because I think it gives us a higher level of abstraction than read_line.
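Tokio's AsyncBufReadExt is modeled on the blocking std::io::BufRead trait, which has methods with the same names, so the difference between the two can be shown with a synchronous, std-only sketch (first_and_rest is a hypothetical helper, not part of the solution):

```rust
use std::io::{self, BufRead};

/// Reads the first line with read_line and the remaining ones with lines(),
/// so the different newline handling of the two methods is visible.
fn first_and_rest(mut reader: impl BufRead) -> io::Result<(String, Vec<String>)> {
    // read_line appends into a caller-owned String and keeps the trailing '\n'.
    let mut first = String::new();
    reader.read_line(&mut first)?;

    // lines() consumes the reader and yields one owned String per line,
    // with the trailing '\n' stripped.
    let rest = reader.lines().collect::<io::Result<Vec<String>>>()?;
    Ok((first, rest))
}
```

With read_line we would have to trim the newline ourselves and clear the String before every call, since it appends; lines handles both and simply yields one String per line.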

AsyncBufReadExt is only available for types that implement AsyncBufRead, i.e. that support buffering, which is not the case for TcpStream. To add buffering to the stream we have to wrap it in a BufReader. We can now write the while loop that reads the input line by line:

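```rust
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::debug;

async fn handle_client(stream: TcpStream) -> anyhow::Result<()> {
    // BufReader adds the buffering that AsyncBufReadExt::lines needs.
    let input_stream = BufReader::new(stream);
    let mut lines = input_stream.lines();
    // next_line returns Ok(None) once the client closes its side.
    while let Some(line) = lines.next_line().await? {
        debug!("Got a line {}", line);
    }
    Ok(())
}
```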

For each line, we parse the JSON into a custom type with serde_json, relying on serde's #[derive(..)] macro to generate the parsing code. If parsing fails, the request is malformed.

We can represent the JSON format of this challenge with types like these:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
pub struct Request {
    method: Method,
    number: f64,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub enum Method {
    IsPrime,
}
```

By adding #[derive(Deserialize)], serde automatically generates an implementation of the Deserialize trait for our types.

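The requirement that the method field be a string with the value isPrime is encoded in the Method enum: with #[serde(rename_all = "camelCase")], the variant IsPrime matches only the string isPrime, so we don't have to check the input string manually. Any other method value, a missing field, or a number field that isn't a JSON number makes deserialization fail, while unknown extra fields are ignored by serde by default.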
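To make the enum trick concrete without serde's generated code, here is a std-only analogue (our own illustration, not what serde actually generates): a FromStr implementation for Method that accepts exactly the string the camelCase rename produces for IsPrime and rejects everything else, including differently cased spellings.

```rust
use std::str::FromStr;

#[derive(Debug, PartialEq)]
pub enum Method {
    IsPrime,
}

/// Hand-written analogue of the check serde performs for us:
/// with rename_all = "camelCase", the variant IsPrime maps to "isPrime".
impl FromStr for Method {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "isPrime" => Ok(Method::IsPrime),
            other => Err(format!("unknown method: {other}")),
        }
    }
}
```

Because an unknown string has no variant to map to, an invalid method cannot even be represented once parsing succeeds, which is why the handler never needs an explicit string comparison.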

Each line now can be parsed with the serde_json::from_str function:

```rust
match serde_json::from_str::<Request>(&line) {
    Ok(_) => {
        // valid request check if is prime
    }
    Err(_) => {
        // malformed request
    }
};
```

If the request is valid, we check whether the number is prime and build the JSON response for the client. For that we can create a Response type, as we did for Request:

```rust
use serde::Serialize;

#[derive(Serialize, Debug)]
pub struct Response {
    method: Method,
    prime: bool,
}

impl Response {
    pub fn new(prime: bool) -> Self {
        Self {
            method: Method::IsPrime,
            prime,
        }
    }
}
```

Note: we also have to add Serialize to the #[derive(..)] on the Method enum, since Response contains a Method field.

Now we can build the response and prepare the buffer:

```rust
let response = Response::new(is_prime(req.number));
// `bytes` must be mutable because we append the newline to it.
let mut bytes = serde_json::to_vec(&response)?;
bytes.push(b'\n');
```

where is_prime is a function that determines whether the input number is prime.

Note: we have to append the \n character since it's a line-based protocol.
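The article doesn't show is_prime, so here is a minimal sketch using trial division. It assumes that non-integers, negative numbers, 0 and 1 are never prime (the Protohackers spec notes that non-integers cannot be prime), and it takes an f64 because that is how Request stores number. Every f64 at or above 2^53 is an even integer, so such values can be rejected immediately, which also keeps the conversion to u64 exact. Precision has already been lost at that point, though, so a request carrying a huge prime would get a wrong answer; handling that would require arbitrary-precision parsing.

```rust
/// Returns true if `n` is a prime number.
/// Non-integers, negatives, 0 and 1 are not prime.
fn is_prime(n: f64) -> bool {
    // NaN, infinities, non-integers and anything below 2 cannot be prime.
    if !n.is_finite() || n.fract() != 0.0 || n < 2.0 {
        return false;
    }
    // Every f64 >= 2^53 is a multiple of 2, hence not prime. This check
    // also guarantees that the conversion to u64 below is exact.
    if n >= 9_007_199_254_740_992.0 {
        return false;
    }
    let n = n as u64;
    if n < 4 {
        return true; // 2 and 3
    }
    if n % 2 == 0 {
        return false;
    }
    // Trial division by odd candidates up to sqrt(n).
    let mut d = 3;
    while d * d <= n {
        if n % d == 0 {
            return false;
        }
        d += 2;
    }
    true
}
```

Trial division is simple and, for values below 2^53, needs at most a few tens of millions of iterations; a deterministic Miller-Rabin test would be the usual upgrade if that ever became a bottleneck.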

To write the buffer we use AsyncWriteExt::write_all, but first we have to split the TcpStream into a read half and a write half, because the lines method takes ownership of input_stream (and therefore of the stream it wraps):

```rust
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

async fn handle_client(mut stream: TcpStream) -> anyhow::Result<()> {
    // Split into read and write halves so that `lines` can own the reader
    // while we keep writing responses to the socket.
    let (input_stream, mut output_stream) = stream.split();
    let input_stream = BufReader::new(input_stream);
    let mut lines = input_stream.lines();
    while let Some(line) = lines.next_line().await? {
        debug!("Got a line {}", line);
        match serde_json::from_str::<Request>(&line) {
            Ok(req) => {
                let response = Response::new(is_prime(req.number));
                let mut bytes = serde_json::to_vec(&response)?;
                bytes.push(b'\n');
                output_stream.write_all(&bytes).await?;
            }
            Err(e) => {
                error!("Malformed request {}", e);
                output_stream.write_all(b"malformed\n").await?;
                break;
            }
        };
    }
    Ok(())
}
```

When a request is malformed, we write back a malformed response and break out of the while loop. handle_client then returns and the TcpStream is dropped, which closes the TCP connection.

# Testing the solution

We can manually test our solution locally using socat:

```
$ RUST_LOG=debug cargo run --bin prime_time
   Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
    Finished dev [unoptimized + debuginfo] target(s) in 1.30s
     Running `target/debug/prime_time`
2023-01-19T11:49:51.482543Z  INFO protohackers_rs: Starting server at 0.0.0.0:8000
```

Then, from another terminal:

```
$ echo "test" | socat - tcp:localhost:8000
malformed

$ echo '{"method": "isPrime","number": 123}' | socat - tcp:localhost:8000
{"method":"isPrime","prime":false}

$ echo '{"method": "isPrime","number": 29}' | socat - tcp:localhost:8000
{"method":"isPrime","prime":true}
```

For this challenge we'd also like to write some tests to check that our implementation works as expected before running the Protohackers test suite.

The handle_client function is hard to test without spawning the server, because it takes a TcpStream (a real client/server connection) as input. To solve this, we can create another function, handle_client_internal, that is generic over the AsyncRead and AsyncWrite traits instead of being tied to TcpStream.

```rust
async fn handle_client(mut stream: TcpStream) -> anyhow::Result<()> {
    let (input_stream, output_stream) = stream.split();
    handle_client_internal(input_stream, output_stream).await
}

async fn handle_client_internal(
    input_stream: impl AsyncRead,
    output_stream: impl AsyncWrite,
) -> anyhow::Result<()> {
    let input_stream = BufReader::new(input_stream);
    ...
}
```

With this refactor, handle_client_internal can take real connections (the two halves of a TcpStream) as well as in-memory buffers for testing.

Note that since we don't know whether the AsyncWrite implementation is buffered, we need to call flush to ensure that any buffered content is actually written. For a TcpStream it is a no-op.
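The same separation can be sketched with blocking std::io types, which makes the idea easy to try without a runtime. This is a hypothetical, synchronous analogue rather than the article's code: handle_lines and its respond callback are illustrative names. The function is generic over BufRead and Write, so a byte slice and a Vec<u8> can stand in for the socket, and it stops at the first line that respond rejects, mirroring the malformed branch of the handler.

```rust
use std::io::{self, BufRead, Write};

/// Synchronous sketch of the handle_client_internal idea: generic over
/// BufRead/Write so tests can pass in-memory buffers instead of a socket.
/// `respond` returns Some(response) for a valid line, None for a malformed one.
fn handle_lines<R, W, F>(input: R, mut output: W, respond: F) -> io::Result<()>
where
    R: BufRead,
    W: Write,
    F: Fn(&str) -> Option<String>,
{
    for line in input.lines() {
        let line = line?;
        match respond(&line) {
            Some(response) => {
                // Line-based protocol: every response ends with '\n'.
                writeln!(output, "{}", response)?;
                output.flush()?;
            }
            None => {
                output.write_all(b"malformed\n")?;
                output.flush()?;
                // Stop reading; the caller then drops the connection.
                break;
            }
        }
    }
    Ok(())
}
```

For example, with a toy respond that parses each line as an integer and answers whether it is even, the input "4\n7\noops\n9\n" produces "true\nfalse\nmalformed\n": the 9 is never read because the loop stops at oops.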

Here's an example of how we could write a test that checks a malformed request:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn prime_time_test_malformed() {
        // "{}" is valid JSON but lacks the required fields, so it is malformed.
        let input = "{}\n";
        let mut output: Vec<u8> = vec![];

        handle_client_internal(input.as_bytes(), &mut output)
            .await
            .expect("Failed to handle");

        assert_eq!(
            String::from("malformed\n"),
            String::from_utf8(output).unwrap()
        );
    }
}
```

Note that #[tokio::test] is needed to bootstrap a Tokio runtime, so that the test function can be async.

Let's try to run the test:

```
$ cargo test prime_time_test_malformed
  Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
error[E0277]: `impl AsyncWrite` cannot be unpinned
  --> src/bin/prime_time.rs:35:31
   |
35  |                 output_stream.write_all(&bytes).await?;
   |                               ^^^^^^^^^ within `tokio::io::util::buf_writer::_::__Origin<'_, impl AsyncWrite>`, the trait `Unpin` is not implemented for `impl AsyncWrite`
   |
   = note: consider using `Box::pin`
   = note: required because it appears within the type `tokio::io::util::buf_writer::_::__Origin<'_, impl AsyncWrite>`
   = note: required because of the requirements on the impl of `Unpin` for `tokio::io::BufWriter<impl AsyncWrite>`
note: required by a bound in `tokio::io::AsyncWriteExt::write_all`
  --> ~/.cargo/registry/src/github.com-1ecc6299db9ec823/tokio-1.24.1/src/io/util/async_write_ext.rs:366:19
   |
366 |             Self: Unpin,
   |                   ^^^^^ required by this bound in `tokio::io::AsyncWriteExt::write_all`
help: consider further restricting this bound
   |
23  |     output_stream: impl AsyncWrite + std::marker::Unpin,
   |                                    ++++++++++++++++++++
```

Oops! The compiler is complaining that impl AsyncWrite cannot be unpinned, and the same goes for impl AsyncRead. Fortunately, the Rust compiler gives us a hint on how to fix this, suggesting that we add + Unpin to both parameters of handle_client_internal. Thanks, Rust compiler!

Wait, but what does Unpin mean?

For now we can just trust the compiler; all we need to know is that these methods require Unpin types:

  • AsyncWriteExt::write_all
  • Lines::next_line
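As a small, std-only illustration (our own, not from the article): a function with a T: Unpin bound accepts the plain in-memory values used in the unit test without any extra work, while a future created by an async block only satisfies the bound once it is pinned on the heap with Box::pin, the alternative mentioned in the compiler's note. The helper names are hypothetical.

```rust
use std::future::Future;
use std::pin::Pin;

/// Compiles only when T: Unpin, the same kind of bound that
/// AsyncWriteExt::write_all and Lines::next_line place on their types.
fn is_unpin<T: Unpin>(_value: &T) -> bool {
    true
}

/// The in-memory stand-ins used in the unit test are ordinary values,
/// so they satisfy Unpin out of the box.
fn test_buffers_are_unpin() -> bool {
    let input: &[u8] = b"{}\n";
    let mut output: Vec<u8> = Vec::new();
    let output_ref: &mut Vec<u8> = &mut output;
    is_unpin(&input) && is_unpin(&output_ref)
}

/// Futures created by async blocks are !Unpin, so passing one directly
/// would not compile; Box::pin produces a Pin<Box<_>>, which is Unpin.
fn boxed_future_is_unpin() -> bool {
    let fut = async { 42 };
    // is_unpin(&fut); // would not compile: the async block is !Unpin
    let boxed: Pin<Box<dyn Future<Output = i32>>> = Box::pin(fut);
    is_unpin(&boxed)
}
```

In the handler, adding + Unpin to the parameters is the cheaper fix, since every type we actually pass in (the TcpStream halves, a byte slice, a &mut Vec<u8>) already satisfies it, so no boxing is needed.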

If you want to read more about Pin/Unpin, I suggest these links:

The final implementation after the refactoring is:

```rust
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tracing::{debug, error};

async fn handle_client_internal(
    input_stream: impl AsyncRead + Unpin,
    mut output_stream: impl AsyncWrite + Unpin,
) -> anyhow::Result<()> {
    let input_stream = BufReader::new(input_stream);
    let mut lines = input_stream.lines();
    while let Some(line) = lines.next_line().await? {
        debug!("Got a line {}", line);
        match serde_json::from_str::<Request>(&line) {
            Ok(req) => {
                let response = Response::new(is_prime(req.number));
                let mut bytes = serde_json::to_vec(&response)?;
                bytes.push(b'\n');
                output_stream.write_all(&bytes).await?;
                output_stream.flush().await?;
            }
            Err(e) => {
                error!("Malformed request {}", e);
                output_stream.write_all(b"malformed\n").await?;
                output_stream.flush().await?;
                break;
            }
        };
    }
    Ok(())
}
```

And now tests run just fine:

```
$ cargo test prime_time_test_malformed
   Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
    Finished test [unoptimized + debuginfo] target(s) in 1.08s
     Running unittests src/bin/prime_time.rs (target/debug/deps/prime_time-f3e95c65c73516d5)

running 1 test
test tests::prime_time_test_malformed ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
```

# Wrapping Up

In this second part of the Protohackers series, we did a little bit of refactoring on the acceptor loop by extracting it into a common function that can be reused in most of the challenges.

We have learned how to handle line-based protocols and how to write functions that deal with IO in a way that can be easily unit-tested.

Comments and suggestions are always welcome. Please feel free to send an email or comment on @wolf4ood@hachyderm.io.

Thank you for reading, see you in the next chapter!