Protohackers in Rust: Prime Time
This is the second article in the series solving Protohackers with Rust. In the previous article we implemented a simple Echo Server, a kind of hello world for network programming. This was useful for getting familiar with async Rust using Tokio and with the Protohackers test suite.
The code in this series will be available on GitHub.
# The Challenge
In the second problem we have to implement a JSON-based request-response protocol, where each request is a single line terminated by a newline character \n.
Once a client is connected, we should consume each line in order, parse the JSON and send a response back to the client.
The JSON payload contains two required fields:
- method
- number
For a request to be valid, the method field should always contain the string isPrime, while the number field should always be a JSON number.
If the request is valid, we should send back a valid response. A valid response is a JSON object with two required fields:
- method: always containing the string isPrime
- prime: a boolean that tells whether the number was prime or not
Example:
{"method":"isPrime","number":123} // Request
{"method":"isPrime","prime":false} // Response
If a request is not valid we should send back an invalid response and close the connection.
Example:
{"method":"isPrime","number": "number"} // Request
malformed // Response and close the connection
Like the previous challenge, we need to handle at least 5 clients simultaneously.
# Getting Started
At this point the project structure looks like this:
❯ tree $PWD -I target
~/Coding/protohackers-rs
├── Cargo.lock
├── Cargo.toml
├── README.md
└── src
    ├── bin
    │   └── smoke_test.rs
    └── lib.rs
For this challenge we are going to add another binary, prime_time.rs, in the src/bin folder, where we'll implement our solution:
❯ touch src/bin/prime_time.rs
For handling JSON, we'll need:
- serde (serialization/deserialization framework)
- serde-json (JSON serialization/deserialization)
❯ cargo add serde --features derive
❯ cargo add serde-json
By default serde does not include the ability to generate implementations of serde::{Deserialize, Serialize} for our types using the #[derive(..)] macro; that's why we have to enable the derive feature when adding serde.
Now let's bootstrap the main function in prime_time.rs and init the logging system:
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    Ok(())
}
What we should do now is write a server that accepts TCP connections, like the one in the first Protohackers challenge. Since we'll probably need this kind of server in most of the challenges, it's a good idea to extract the acceptor code into a common function.
In lib.rs we could create a function run_server that looks like this:
use std::future::Future;
use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, error, info};

pub async fn run_server<H, F>(port: u16, handler: H) -> anyhow::Result<()>
where
    H: Fn(TcpStream) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    let listener = TcpListener::bind(&format!("0.0.0.0:{}", port)).await?;
    info!("Starting server at 0.0.0.0:{}", port);
    loop {
        let (socket, address) = listener.accept().await?;
        debug!("Got connection from {}", address);
        let future = handler(socket);
        tokio::task::spawn(async move {
            if let Err(err) = future.await {
                error!("Error handling connection {}: {}", address, err);
            }
        });
    }
}
The run_server function takes as input:
- port: the binding port
- handler: a function that will handle each connection
The handler parameter is an H (Handler) generic with the bounds:
where
    H: Fn(TcpStream) -> F,
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
meaning that the handler should be a function that takes a TcpStream as input and returns a Future.
The returned Future must satisfy the requirements of tokio::task::spawn, therefore we need to add the bounds:
- Send: the runtime is allowed to move the task between threads
- 'static: the task should not reference any data owned outside the task
In the acceptor loop, once a client is connected, we call the handler and then await the returned Future inside the spawned task to handle the TcpStream concurrently.
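The two bounds are not specific to Tokio: std::thread::spawn in the standard library asks for the same Send + 'static pair, for the same reasons. As a minimal, std-only sketch (the spawn_greeting helper is hypothetical and not part of this series' code), moving an owned String into the closure satisfies both bounds:

```rust
use std::thread::{self, JoinHandle};

/// Hypothetical helper: runs a tiny job on another thread.
fn spawn_greeting(name: String) -> JoinHandle<String> {
    // `move` transfers ownership of `name` into the closure, so the closure
    // borrows nothing from the caller (which makes it `'static`), and a
    // `String` may cross thread boundaries (which makes it `Send`).
    thread::spawn(move || format!("hello, {}", name))
}
```

Calling spawn_greeting(String::from("client")).join() hands the greeting back to the caller. Taking a borrowed &str and capturing it in the closure instead would be rejected by the compiler, because the borrow might not outlive the spawned thread.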
Let's plug the above function in our main:
use protohackers_rs::run_server;
use tokio::net::TcpStream;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();
    run_server(8000, handle_client).await?;
    Ok(())
}

// Placeholder handler: the real logic is added in the next section.
async fn handle_client(stream: TcpStream) -> anyhow::Result<()> {
    Ok(())
}
# Prime Time Implementation
In the handle_client function we first need to start reading lines from the socket. One option would be to read bytes from the socket and check whether there is a newline in them. But that would require handling the buffer manually, and we would like to avoid that for this challenge. Instead we'll use a Tokio IO helper that can do that for us.
We could use one of the methods available in tokio::io::AsyncBufReadExt:
- read_line: reads all bytes until a newline is reached
- lines: returns a stream over the lines of the underlying input source
I opted for the lines function, because I think it gives us a higher level of abstraction compared to read_line.
AsyncBufReadExt is available only for types that support buffering, which is not the case for TcpStream. To enhance the stream with buffering we have to wrap it in a BufReader. We can now write the while loop that reads line by line:
async fn handle_client(stream: TcpStream) -> anyhow::Result<()> {
    let input_stream = BufReader::new(stream);
    let mut lines = input_stream.lines();
    while let Some(line) = lines.next_line().await? {
        debug!("Got a line {}", line);
    }
    Ok(())
}
For each line we parse the JSON using serde-json and serde into a custom type using the #[derive(..)] macro. If the parsing fails, it means that the request is malformed.
We can represent this challenge's JSON format in a type like this:
#[derive(Deserialize, Debug)]
pub struct Request {
    method: Method,
    number: f64,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub enum Method {
    IsPrime,
}
By adding #[derive(Deserialize)], serde will automatically generate an implementation of the Deserialize trait for our types.
The requirement that the method field be a string with the value isPrime has been encoded in the enum Method; in this way we don't have to manually check that the input string is equal to isPrime.
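To get an intuition for what encoding the check in the type buys us, here is a rough hand-written analogue using only the standard library. It is a sketch of the idea, not the code serde generates: the FromStr implementation and the UnknownMethod error type are assumptions made for illustration.

```rust
use std::str::FromStr;

#[derive(Debug, PartialEq)]
enum Method {
    IsPrime,
}

/// Hypothetical error type for names that match no variant.
#[derive(Debug, PartialEq)]
struct UnknownMethod(String);

impl FromStr for Method {
    type Err = UnknownMethod;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            // With `rename_all = "camelCase"` the variant `IsPrime`
            // is spelled "isPrime" on the wire.
            "isPrime" => Ok(Method::IsPrime),
            other => Err(UnknownMethod(other.to_string())),
        }
    }
}
```

Here "isPrime".parse::<Method>() succeeds while any other string is an error, so holding a Method value is already proof that the check passed.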
Each line can now be parsed with the serde_json::from_str function:
match serde_json::from_str::<Request>(&line) {
    Ok(_) => {
        // valid request: check if the number is prime
    }
    Err(_) => {
        // malformed request
    }
};
If the request is a valid one, we check whether the number is prime and build the JSON response for the client. For that we could create a type Response, same as we did for Request:
#[derive(Serialize, Debug)]
pub struct Response {
    method: Method,
    prime: bool,
}

impl Response {
    pub fn new(prime: bool) -> Self {
        Self {
            method: Method::IsPrime,
            prime,
        }
    }
}
Note: We have to add the #[derive(Serialize, ..)] macro to the Method enum as well.
Now we can build the response and prepare the buffer:
let response = Response::new(is_prime(req.number));
// `bytes` must be mutable so that we can append the newline.
let mut bytes = serde_json::to_vec(&response)?;
bytes.push(b'\n');
where is_prime is a function that calculates if the input number is prime.
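The body of is_prime is not shown here, so the following is one possible sketch using trial division. It assumes the number arrives as an f64 (as in the Request type) and treats non-integers, negative numbers, 0 and 1 as not prime; that edge-case policy is an assumption, not something taken from the original code.

```rust
/// Sketch of a primality check for a JSON number parsed as `f64`.
fn is_prime(number: f64) -> bool {
    // NaN, infinities, fractional values and anything below 2 are not prime.
    if !number.is_finite() || number.fract() != 0.0 || number < 2.0 {
        return false;
    }
    // Float-to-integer `as` casts saturate, so huge values become u64::MAX.
    let n = number as u64;
    if n < 4 {
        return true; // 2 and 3
    }
    if n % 2 == 0 || n % 3 == 0 {
        return false;
    }
    // Every remaining candidate divisor has the form 6k - 1 or 6k + 1.
    let mut i: u64 = 5;
    // `i <= n / i` is the overflow-free way of writing `i * i <= n`.
    while i <= n / i {
        if n % i == 0 || n % (i + 2) == 0 {
            return false;
        }
        i += 6;
    }
    true
}
```

Trial division keeps the sketch short; it is not tuned for speed.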
Note: We have to append the \n character since it's a line-based protocol.
For writing the buffer we use AsyncWriteExt::write_all, but first we have to split the TcpStream, because the lines method takes ownership of input_stream:
async fn handle_client(mut stream: TcpStream) -> anyhow::Result<()> {
    let (input_stream, mut output_stream) = stream.split();
    let input_stream = BufReader::new(input_stream);
    let mut lines = input_stream.lines();
    while let Some(line) = lines.next_line().await? {
        debug!("Got a line {}", line);
        match serde_json::from_str::<Request>(&line) {
            Ok(req) => {
                let response = Response::new(is_prime(req.number));
                let mut bytes = serde_json::to_vec(&response)?;
                bytes.push(b'\n');
                output_stream.write_all(&bytes).await?;
            }
            Err(e) => {
                error!("Malformed request {}", e);
                output_stream.write_all(b"malformed\n").await?;
                break;
            }
        };
    }
    Ok(())
}
When a request is malformed we can just write back a malformed response and break out of the while loop: the function then returns, the stream is dropped, and the TCP connection is closed.
# Testing the solution
We can manually test our solution locally using socat:
❯ RUST_LOG=debug cargo run --bin prime_time
Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
Finished dev [unoptimized + debuginfo] target(s) in 1.30s
Running `target/debug/prime_time`
2023-01-19T11:49:51.482543Z INFO protohackers_rs: Starting server at 0.0.0.0:8000
❯ echo "test" | socat - tcp:localhost:8000
malformed
❯ echo '{"method": "isPrime","number": 123}' | socat - tcp:localhost:8000
{"method":"isPrime","prime":false}
❯ echo '{"method": "isPrime","number": 29}' | socat - tcp:localhost:8000
{"method":"isPrime","prime":true}
For this challenge we'd also like to write some tests to check that our implementation works as expected before running the Protohackers test suite.
The handle_client function is not easy to test without spawning the server, because it takes a TcpStream (a real client/server connection) as input. To solve this we could create another function, handle_client_internal, which works on traits rather than on TcpStream.
async fn handle_client(mut stream: TcpStream) -> anyhow::Result<()> {
    let (input_stream, output_stream) = stream.split();
    handle_client_internal(input_stream, output_stream).await
}

async fn handle_client_internal(
    input_stream: impl AsyncRead,
    output_stream: impl AsyncWrite,
) -> anyhow::Result<()> {
    let input_stream = BufReader::new(input_stream);
    ...
}
With this refactor the handle_client_internal function can take as input concrete channels (TcpStream) as well as memory buffers for testing.
Note that since we don't know whether the AsyncWrite IO object is a buffered one, we need to call flush to ensure that any buffered content is written. For a TcpStream it will be a no-op.
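The same idea works outside of async code. As a synchronous, std-only analogue (not the Tokio implementation used in this article), a function written against BufRead and Write accepts a socket in production and plain byte buffers in tests. The handle_lines name and the toy protocol, one integer per line answered with whether it is even, are stand-ins invented for this sketch.

```rust
use std::io::{self, BufRead, Write};

/// Hypothetical line handler: answers each integer line with `true`/`false`
/// (is it even?) and stops after the first line that is not an integer.
fn handle_lines(input: impl BufRead, mut output: impl Write) -> io::Result<()> {
    for line in input.lines() {
        let line = line?;
        match line.trim().parse::<u64>() {
            Ok(n) => writeln!(output, "{}", n % 2 == 0)?,
            Err(_) => {
                output.write_all(b"malformed\n")?;
                break;
            }
        }
    }
    // The writer may be buffered, so flush before returning.
    output.flush()
}
```

Feeding it "4\n7\noops\n8\n".as_bytes() and a &mut Vec<u8> leaves "true\nfalse\nmalformed\n" in the vector: the loop stops at the first bad line, so 8 is never answered.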
Here's an example of how we could write a test that checks a malformed request:
#[tokio::test]
async fn prime_time_test_malformed() {
    let input = "{}\n";
    let mut output: Vec<u8> = vec![];
    handle_client_internal(input.as_bytes(), &mut output)
        .await
        .expect("Failed to handle");
    assert_eq!(
        String::from("malformed\n"),
        String::from_utf8(output).unwrap()
    );
}
Note that #[tokio::test] is needed for bootstrapping the Tokio runtime and for running async functions in tests.
Let's try to run the test:
❯ cargo test prime_time_test_malformed
Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
error[E0277]: `impl AsyncWrite` cannot be unpinned
--> src/bin/prime_time.rs:35:31
|
35 | output_stream.write_all(&bytes).await?;
| ^^^^^^^^^ within `tokio::io::util::buf_writer::_::__Origin<'_, impl AsyncWrite>`, the trait `Unpin` is not implemented for `impl AsyncWrite`
|
= note: consider using `Box::pin`
= note: required because it appears within the type `tokio::io::util::buf_writer::_::__Origin<'_, impl AsyncWrite>`
= note: required because of the requirements on the impl of `Unpin` for `tokio::io::BufWriter<impl AsyncWrite>`
note: required by a bound in `tokio::io::AsyncWriteExt::write_all`
--> ~/.cargo/registry/src/github.com-1ecc6299db9ec823/tokio-1.24.1/src/io/util/async_write_ext.rs:366:19
|
366 | Self: Unpin,
| ^^^^^ required by this bound in `tokio::io::AsyncWriteExt::write_all`
help: consider further restricting this bound
|
23 | output_stream: impl AsyncWrite + std::marker::Unpin,
| ++++++++++++++++++++
Oops! The compiler is complaining that impl AsyncWrite cannot be unpinned, and the same goes for impl AsyncRead. Fortunately the Rust compiler gives us a hint on how to fix this by suggesting to stick a + Unpin on both parameters of handle_client_internal. Thanks Rust compiler!!
Wait, but what does Unpin mean?
For now we can just trust the compiler; we only need to know that:
- AsyncWriteExt::write_all
- Lines::next_line
require Unpin types.
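As a rough intuition, Unpin marks types that are safe to move even after they have been pinned, which is what lets a helper build a Pin<&mut T> from a plain mutable reference without unsafe code. The std-only sketch below shows that mechanism; pin_mut_ref is a hypothetical helper, and it only approximates what the Tokio helpers need the bound for.

```rust
use std::pin::Pin;

/// Hypothetical helper: pins a mutable reference without `unsafe`.
/// `Pin::new` is only available when the target type is `Unpin`,
/// so removing the `T: Unpin` bound makes this function fail to compile.
fn pin_mut_ref<T: Unpin>(value: &mut T) -> Pin<&mut T> {
    Pin::new(value)
}
```

Vec<u8> and &[u8], the buffers used in the test, are both Unpin, so they satisfy the extra bound.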
If you want to read more about Pin/Unpin, I suggest these links:
The final implementation after the refactoring is:
async fn handle_client_internal(
    input_stream: impl AsyncRead + Unpin,
    mut output_stream: impl AsyncWrite + Unpin,
) -> anyhow::Result<()> {
    let input_stream = BufReader::new(input_stream);
    let mut lines = input_stream.lines();
    while let Some(line) = lines.next_line().await? {
        debug!("Got a line {}", line);
        match serde_json::from_str::<Request>(&line) {
            Ok(req) => {
                let response = Response::new(is_prime(req.number));
                let mut bytes = serde_json::to_vec(&response)?;
                bytes.push(b'\n');
                output_stream.write_all(&bytes).await?;
                output_stream.flush().await?;
            }
            Err(e) => {
                error!("Malformed request {}", e);
                output_stream.write_all(b"malformed\n").await?;
                output_stream.flush().await?;
                break;
            }
        };
    }
    Ok(())
}
And now tests run just fine:
❯ cargo test prime_time_test_malformed
Compiling protohackers-rs v0.1.0 (~/Coding/protohackers-rs)
Finished test [unoptimized + debuginfo] target(s) in 1.08s
Running unittests src/bin/prime_time.rs (target/debug/deps/prime_time-f3e95c65c73516d5)
running 1 test
test tests::prime_time_test_malformed ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
# Wrapping Up
In this second part of the Protohackers challenges, we did a little bit of refactoring on the acceptor loop by extracting it into a common function that is reusable in most of the challenges.
We have learned how to handle line-based protocols and how to write functions that deal with IO in a way that can be easily unit-tested.
Comments and suggestions are always welcome. Please feel free to send an email or comment on @wolf4ood@hachyderm.io.
Thank you for reading, see you in the next chapter!