summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAaron Hill <aa1ronham@gmail.com>2018-03-01 20:00:05 -0500
committerAaron Hill <aa1ronham@gmail.com>2018-03-11 16:13:47 -0400
commit3471e04b9b28ded3d12ef31a2477a06ec71ae97e (patch)
treea6bd1191e6ee228d89e4e93915d0a1dd1bdd9410
parent8900bc7fab8eeeff2d3801c1bf951c8e57cffa33 (diff)
downloadroughenough-3471e04b9b28ded3d12ef31a2477a06ec71ae97e.zip
Add support for batch-signing requests
As documented in the Roughtime spec, servers can batch together requests, only signing the root of a computed Merkle tree, in order to increase efficiency. Following the example of the reference Roughtime implementation, the default batch size is set to 64. However, this value can be changed in the config. Two pieces of benchmark infrastructure are added - a simple "benchmark mode" on the server, and a "stress test mode" on the client. These features can be used to help pick an optimal batch size for the server. In "benchmark mode", the server does not log any requests. Instead, it prints out the current request processing speed every second. This helps to keep the output manageable when using the client's "stress test" mode. In "stress test mode", the client sends the same message to the server in a loop. To prevent accidental flooding of the users's local network, or a remote server, only loopback addresses are supported in this mode.
-rw-r--r--Cargo.toml1
-rw-r--r--example.cfg1
-rw-r--r--src/bin/client.rs89
-rw-r--r--src/bin/server.rs181
-rw-r--r--src/lib.rs4
-rw-r--r--src/merkle.rs174
6 files changed, 352 insertions, 98 deletions
diff --git a/Cargo.toml b/Cargo.toml
index e2af555..999efa4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ travis-ci = { repository = "int08h/roughenough", branch = "master" }
[dependencies]
mio = "0.6"
+mio-extras = "2.0"
byteorder = "1"
ring = "0.12"
untrusted = "0.5"
diff --git a/example.cfg b/example.cfg
index c271481..d27f012 100644
--- a/example.cfg
+++ b/example.cfg
@@ -1,3 +1,4 @@
port: 8686
interface: 127.0.0.1
seed: a32049da0ffde0ded92ce10a0230d35fe615ec8461c14986baa63fe3b3bac3db
+batch_size: 64
diff --git a/src/bin/client.rs b/src/bin/client.rs
index 1c1da03..d9b37fc 100644
--- a/src/bin/client.rs
+++ b/src/bin/client.rs
@@ -9,7 +9,6 @@ extern crate hex;
use ring::rand;
use ring::rand::SecureRandom;
-use ring::digest;
use byteorder::{LittleEndian, ReadBytesExt};
@@ -20,8 +19,9 @@ use std::iter::Iterator;
use std::collections::HashMap;
use std::net::{UdpSocket, ToSocketAddrs};
-use roughenough::{RtMessage, Tag, VERSION, TREE_NODE_TWEAK, TREE_LEAF_TWEAK, CERTIFICATE_CONTEXT, SIGNED_RESPONSE_CONTEXT};
+use roughenough::{RtMessage, Tag, VERSION, CERTIFICATE_CONTEXT, SIGNED_RESPONSE_CONTEXT};
use roughenough::sign::Verifier;
+use roughenough::merkle::root_from_paths;
use clap::{Arg, App};
fn create_nonce() -> [u8; 64] {
@@ -56,6 +56,12 @@ struct ResponseHandler {
nonce: [u8; 64]
}
+struct ParsedResponse {
+ verified: bool,
+ midpoint: u64,
+ radius: u32
+}
+
impl ResponseHandler {
pub fn new(pub_key: Option<Vec<u8>>, response: RtMessage, nonce: [u8; 64]) -> ResponseHandler {
let msg = response.into_hash_map();
@@ -74,18 +80,24 @@ impl ResponseHandler {
}
}
- pub fn extract_time(&self) -> (u64, u32) {
+ pub fn extract_time(&self) -> ParsedResponse {
let midpoint = self.srep[&Tag::MIDP].as_slice().read_u64::<LittleEndian>().unwrap();
let radius = self.srep[&Tag::RADI].as_slice().read_u32::<LittleEndian>().unwrap();
+ let mut verified = false;
if self.pub_key.is_some() {
self.validate_dele();
self.validate_srep();
self.validate_merkle();
self.validate_midpoint(midpoint);
+ verified = true;
}
- (midpoint, radius)
+ ParsedResponse {
+ verified,
+ midpoint,
+ radius
+ }
}
fn validate_dele(&self) {
@@ -106,32 +118,12 @@ impl ResponseHandler {
fn validate_merkle(&self) {
let srep = RtMessage::from_bytes(&self.msg[&Tag::SREP]).unwrap().into_hash_map();
- let mut index = self.msg[&Tag::INDX].as_slice().read_u32::<LittleEndian>().unwrap();
+ let index = self.msg[&Tag::INDX].as_slice().read_u32::<LittleEndian>().unwrap();
let paths = &self.msg[&Tag::PATH];
- let mut hash = sha_512(TREE_LEAF_TWEAK, &self.nonce);
-
- assert_eq!(paths.len() % 64, 0);
-
- for path in paths.chunks(64) {
- let mut ctx = digest::Context::new(&digest::SHA512);
- ctx.update(TREE_NODE_TWEAK);
-
- if index & 1 == 0 {
- // Left
- ctx.update(&hash);
- ctx.update(path);
- } else {
- // Right
- ctx.update(path);
- ctx.update(&hash);
- }
- hash = Vec::from(ctx.finish().as_ref());
+ let hash = root_from_paths(index as usize, &self.nonce, paths);
- index >>= 1;
- }
-
- assert_eq!(hash, srep[&Tag::ROOT], "Nonce not in merkle tree!");
+ assert_eq!(Vec::from(hash), srep[&Tag::ROOT], "Nonce not in merkle tree!");
}
@@ -150,13 +142,6 @@ impl ResponseHandler {
}
}
-fn sha_512(prefix: &[u8], data: &[u8]) -> Vec<u8> {
- let mut ctx = digest::Context::new(&digest::SHA512);
- ctx.update(prefix);
- ctx.update(data);
- Vec::from(ctx.finish().as_ref())
-}
-
fn main() {
let matches = App::new("roughenough client")
.version(VERSION)
@@ -187,6 +172,11 @@ fn main() {
.help("The number of requests to make to the server (each from a different source port). This is mainly useful for testing batch response handling")
.default_value("1")
)
+ .arg(Arg::with_name("stress")
+ .short("s")
+ .long("stress")
+ .help("Stress-tests the server by sending the same request as fast as possible. Please only use this on your own server")
+ )
.get_matches();
let host = matches.value_of("host").unwrap();
@@ -194,10 +184,30 @@ fn main() {
let num_requests = value_t_or_exit!(matches.value_of("num-requests"), u16) as usize;
let pub_key = matches.value_of("public-key").map(|pkey| hex::decode(pkey).expect("Error parsing public key!"));
let time_format = matches.value_of("time-format").unwrap();
+ let stress = matches.is_present("stress");
println!("Requesting time from: {:?}:{:?}", host, port);
- let addrs: Vec<_> = (host, port).to_socket_addrs().unwrap().collect();
+ let addr = (host, port).to_socket_addrs().unwrap().next().unwrap();
+
+ if stress {
+
+ if !addr.ip().is_loopback() {
+ println!("ERROR: Cannot use non-loopback address {} for stress testing", addr.ip());
+ return;
+ }
+
+ println!("Stress-testing!");
+
+ let nonce = create_nonce();
+ let socket = UdpSocket::bind("0.0.0.0:0").expect("Couldn't open UDP socket");
+ let request = make_request(&nonce);
+
+ loop {
+ socket.send_to(&request, addr).unwrap();
+ }
+ }
+
let mut requests = Vec::with_capacity(num_requests);
@@ -210,18 +220,21 @@ fn main() {
}
for &mut (_, ref request, ref mut socket) in requests.iter_mut() {
- socket.send_to(request, addrs.as_slice()).unwrap();
+ socket.send_to(request, addr).unwrap();
}
for (nonce, _, mut socket) in requests {
let resp = receive_response(&mut socket);
- let (midpoint, radius) = ResponseHandler::new(pub_key.clone(), resp, nonce).extract_time();
+ let ParsedResponse {verified, midpoint, radius} = ResponseHandler::new(pub_key.clone(), resp.clone(), nonce).extract_time();
+
+ let map = resp.into_hash_map();
+ let index = map[&Tag::INDX].as_slice().read_u32::<LittleEndian>().unwrap();
let seconds = midpoint / 10_u64.pow(6);
let spec = Utc.timestamp(seconds as i64, ((midpoint - (seconds * 10_u64.pow(6))) * 10_u64.pow(3)) as u32);
let out = spec.format(time_format).to_string();
- println!("Recieved time from server: midpoint={:?}, radius={:?}", out, radius);
+ println!("Recieved time from server: midpoint={:?}, radius={:?} (merkle_index={}, verified={})", out, radius, index, verified);
}
}
diff --git a/src/bin/server.rs b/src/bin/server.rs
index 31e9392..86d8f57 100644
--- a/src/bin/server.rs
+++ b/src/bin/server.rs
@@ -22,6 +22,7 @@
//! interface: 127.0.0.1
//! port: 8686
//! seed: f61075c988feb9cb700a4a6a3291bfbc9cab11b9c9eca8c802468eb38a43d7d3
+//! batch_size: 64
//! ```
//!
//! Where:
@@ -31,6 +32,9 @@
//! * **seed** - A 32-byte hexadecimal value used as the seed to generate the
//! server's long-term key pair. **This is a secret value**, treat it
//! with care.
+//! * **batch_size** - The number of requests to process in one batch. All nonces
+//! in a batch are used to build a Merkle tree, the root of which
+//! is signed.
//!
//! # Running the Server
//!
@@ -38,8 +42,6 @@
//! $ cargo run --release --bin server /path/to/config.file
//! ```
-#![allow(deprecated)] // for mio::Timer
-
extern crate byteorder;
extern crate ring;
extern crate roughenough;
@@ -51,29 +53,31 @@ extern crate yaml_rust;
extern crate log;
extern crate simple_logger;
extern crate mio;
+extern crate mio_extras;
extern crate hex;
use std::env;
use std::process;
use std::fs::File;
-use std::io::Read;
-use std::io;
+use std::io::{Read, ErrorKind};
use std::time::Duration;
use std::net::SocketAddr;
use std::sync::Arc;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::thread;
use mio::{Poll, Token, Ready, PollOpt, Events};
use mio::net::UdpSocket;
-use mio::timer::Timer;
+use mio_extras::timer::Timer;
use byteorder::{LittleEndian, WriteBytesExt};
use roughenough::{RtMessage, Tag, Error};
-use roughenough::{VERSION, CERTIFICATE_CONTEXT, MIN_REQUEST_LENGTH, SIGNED_RESPONSE_CONTEXT, TREE_LEAF_TWEAK};
+use roughenough::{VERSION, CERTIFICATE_CONTEXT, MIN_REQUEST_LENGTH, SIGNED_RESPONSE_CONTEXT};
use roughenough::sign::Signer;
+use roughenough::merkle::*;
-use ring::{digest, rand};
+use ring::rand;
use ring::rand::SecureRandom;
use yaml_rust::YamlLoader;
@@ -81,6 +85,8 @@ use yaml_rust::YamlLoader;
const MESSAGE: Token = Token(0);
const STATUS: Token = Token(1);
+pub static NUM_RESPONSES: AtomicUsize = AtomicUsize::new(0);
+
fn create_ephemeral_key() -> Signer {
let rng = rand::SystemRandom::new();
let mut seed = [0u8; 32];
@@ -128,7 +134,7 @@ fn make_key_and_cert(seed: &[u8]) -> (Signer, Vec<u8>) {
(ephemeral_key, cert_bytes)
}
-fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) -> RtMessage {
+fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], root: &[u8], path: &[u8], idx: u32) -> RtMessage {
// create SREP
// sign SREP
// create response:
@@ -138,14 +144,15 @@ fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) ->
// - CERT (pre-created)
// - INDX (always 0)
- let path = [0u8; 0];
- let zeros = [0u8; 4];
+ let mut index = [0; 4];
- let mut radi: Vec<u8> = Vec::with_capacity(4);
- let mut midp: Vec<u8> = Vec::with_capacity(8);
+ let mut radi = [0; 4];
+ let mut midp = [0; 8];
+
+ (&mut index as &mut [u8]).write_u32::<LittleEndian>(idx).unwrap();
// one second (in microseconds)
- radi.write_u32::<LittleEndian>(1_000_000).unwrap();
+ (&mut radi as &mut [u8]).write_u32::<LittleEndian>(1_000_000).unwrap();
// current epoch time in microseconds
let now = {
@@ -155,20 +162,14 @@ fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) ->
secs + nsecs
};
- midp.write_u64::<LittleEndian>(now).unwrap();
+ (&mut midp as &mut [u8]).write_u64::<LittleEndian>(now).unwrap();
// Signed response SREP
let srep_bytes = {
- // hash request nonce
- let mut ctx = digest::Context::new(&digest::SHA512);
- ctx.update(TREE_LEAF_TWEAK);
- ctx.update(nonce);
- let digest = ctx.finish();
-
let mut srep_msg = RtMessage::new(3);
srep_msg.add_field(Tag::RADI, &radi).unwrap();
srep_msg.add_field(Tag::MIDP, &midp).unwrap();
- srep_msg.add_field(Tag::ROOT, digest.as_ref()).unwrap();
+ srep_msg.add_field(Tag::ROOT, root).unwrap();
srep_msg.encode().unwrap()
};
@@ -185,7 +186,7 @@ fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) ->
response.add_field(Tag::PATH, &path).unwrap();
response.add_field(Tag::SREP, &srep_bytes).unwrap();
response.add_field(Tag::CERT, cert_bytes).unwrap();
- response.add_field(Tag::INDX, &zeros).unwrap();
+ response.add_field(Tag::INDX, &index).unwrap();
response
}
@@ -211,7 +212,7 @@ fn nonce_from_request(buf: &[u8], num_bytes: usize) -> Result<&[u8], Error> {
}
}
-fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>) {
+fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>, u8) {
let mut infile = File::open(config_file)
.expect("failed to open config file");
@@ -229,12 +230,14 @@ fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>) {
let mut port: u16 = 0;
let mut iface: String = "unknown".to_string();
let mut seed: String = "".to_string();
+ let mut batch_size: u8 = 1;
for (key, value) in cfg[0].as_hash().unwrap() {
match key.as_str().unwrap() {
"port" => port = value.as_i64().unwrap() as u16,
"interface" => iface = value.as_str().unwrap().to_string(),
"seed" => seed = value.as_str().unwrap().to_string(),
+ "batch_size" => batch_size = value.as_i64().unwrap() as u8,
_ => warn!("ignoring unknown config key '{}'", key.as_str().unwrap())
}
}
@@ -246,10 +249,10 @@ fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>) {
let binseed = hex::decode(seed)
.expect("seed value invalid; 'seed' should be 32 byte hex value");
- (sock_addr, binseed)
+ (sock_addr, binseed, batch_size)
}
-fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &[u8]) {
+fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &[u8], batch_size: u8) {
let keep_running = Arc::new(AtomicBool::new(true));
let kr = keep_running.clone();
@@ -257,62 +260,100 @@ fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &
.expect("failed setting Ctrl-C handler");
let socket = UdpSocket::bind(addr).expect("failed to bind to socket");
- let status_duration = Duration::from_secs(6_000);
+ let status_duration = Duration::from_secs(6);
let poll_duration = Some(Duration::from_millis(100));
let mut timer: Timer<()> = Timer::default();
- timer.set_timeout(status_duration, ()).expect("unable to set_timeout");
+ timer.set_timeout(status_duration, ());
+
let mut buf = [0u8; 65_536];
let mut events = Events::with_capacity(32);
- let mut num_responses = 0u64;
let mut num_bad_requests = 0u64;
let poll = Poll::new().unwrap();
poll.register(&socket, MESSAGE, Ready::readable(), PollOpt::edge()).unwrap();
poll.register(&timer, STATUS, Ready::readable(), PollOpt::edge()).unwrap();
- loop {
- if !keep_running.load(Ordering::Acquire) {
- info!("Ctrl-C caught, exiting...");
- break;
+ let mut merkle = MerkleTree::new();
+ let mut requests = Vec::with_capacity(batch_size as usize);
+
+ macro_rules! check_ctrlc {
+ () => {
+ if !keep_running.load(Ordering::Acquire) {
+ warn!("Ctrl-C caught, exiting...");
+ return;
+ }
}
+ }
+
+ loop {
+ check_ctrlc!();
poll.poll(&mut events, poll_duration).expect("poll failed");
for event in events.iter() {
+
match event.token() {
MESSAGE => {
- loop {
- match socket.recv_from(&mut buf) {
- Ok((num_bytes, src_addr)) => {
- if let Ok(nonce) = nonce_from_request(&buf, num_bytes) {
- let resp = make_response(&mut ephemeral_key, cert_bytes, nonce);
- let resp_bytes = resp.encode().unwrap();
-
- let bytes_sent = socket.send_to(&resp_bytes, &src_addr).expect("send_to failed");
-
- num_responses += 1;
- info!("Responded {} bytes to {} for '{}..' (resp #{})", bytes_sent, src_addr, hex::encode(&nonce[0..4]), num_responses);
- } else {
- num_bad_requests += 1;
- info!("Invalid request ({} bytes) from {} (resp #{})", num_bytes, src_addr, num_responses);
+
+ let mut done = false;
+
+ 'process_batch: loop {
+ check_ctrlc!();
+
+ merkle.reset();
+ requests.clear();
+
+ let resp_start = NUM_RESPONSES.load(Ordering::SeqCst);
+
+ for i in 0..batch_size {
+ match socket.recv_from(&mut buf) {
+ Ok((num_bytes, src_addr)) => {
+ if let Ok(nonce) = nonce_from_request(&buf, num_bytes) {
+ requests.push((Vec::from(nonce), src_addr));
+ merkle.push_leaf(nonce);
+ } else {
+ num_bad_requests += 1;
+ info!("Invalid request ({} bytes) from {} (#{} in batch, resp #{})", num_bytes, src_addr, i, resp_start + i as usize);
+ }
+ },
+ Err(e) => match e.kind() {
+ ErrorKind::WouldBlock => {
+ done = true;
+ break;
+ },
+ _ => panic!("recv_from failed with {:?}", e)
}
- }
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- break
- }
- Err(ref e) => {
- error!("Error {:?}: {:?}", e.kind(), e);
- break
- }
+ };
+ }
+
+ if requests.is_empty() {
+ break 'process_batch
+ }
+
+ let root = merkle.compute_root();
+ for (i, &(ref nonce, ref src_addr)) in requests.iter().enumerate() {
+ let paths: Vec<_> = merkle.get_paths(i).into_iter().flat_map(|x| x).collect();
+
+ let resp = make_response(&mut ephemeral_key, cert_bytes, &root, &paths, i as u32);
+ let resp_bytes = resp.encode().unwrap();
+
+ let bytes_sent = socket.send_to(&resp_bytes, &src_addr).expect("send_to failed");
+ let num_responses = NUM_RESPONSES.fetch_add(1, Ordering::SeqCst);
+
+ info!("Responded {} bytes to {} for '{}..' (#{} in batch, resp #{})", bytes_sent, src_addr, hex::encode(&nonce[0..4]), i, num_responses);
+ }
+ if done {
+ break 'process_batch
}
}
+
}
STATUS => {
- info!("responses {}, invalid requests {}", num_responses, num_bad_requests);
- timer.set_timeout(status_duration, ()).expect("unable to set_timeout");
+ info!("responses {}, invalid requests {}", NUM_RESPONSES.load(Ordering::SeqCst), num_bad_requests);
+ timer.set_timeout(status_duration, ());
}
_ => unreachable!()
@@ -321,8 +362,9 @@ fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &
}
}
-fn main() {
+pub fn main() {
use log::Level;
+
simple_logger::init_with_level(Level::Info).unwrap();
info!("Roughenough server v{} starting", VERSION);
@@ -333,12 +375,31 @@ fn main() {
process::exit(1);
}
- let (addr, key_seed) = load_config(&args.nth(1).unwrap());
+ let (addr, key_seed, batch_size) = load_config(&args.nth(1).unwrap());
let (mut ephemeral_key, cert_bytes) = make_key_and_cert(&key_seed);
info!("Server listening on {}", addr);
- polling_loop(&addr, &mut ephemeral_key, &cert_bytes);
+ if env::var("BENCH").is_ok() {
+ log::set_max_level(log::LevelFilter::Warn);
+
+ thread::spawn(|| {
+ loop {
+ let old = time::get_time().sec;
+ let old_reqs = NUM_RESPONSES.load(Ordering::SeqCst);
+
+ thread::sleep(Duration::from_secs(1));
+
+ let new = time::get_time().sec;
+ let new_reqs = NUM_RESPONSES.load(Ordering::SeqCst);
+
+ warn!("Processing at {:?} reqs/sec", (new_reqs - old_reqs) / (new - old) as usize);
+ }
+ });
+ }
+
+
+ polling_loop(&addr, &mut ephemeral_key, &cert_bytes, batch_size);
info!("Done.");
process::exit(0);
diff --git a/src/lib.rs b/src/lib.rs
index 7259c26..72cdb69 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -58,6 +58,7 @@ mod tag;
mod message;
pub mod sign;
+pub mod merkle;
pub use error::Error;
pub use tag::Tag;
@@ -83,6 +84,9 @@ pub const NONCE_LENGTH: u32 = 64;
/// Size (in bytes) of an Ed25519 signature
pub const SIGNATURE_LENGTH: u32 = 64;
+/// Size (in bytes) of a SHA-512 hash
+pub const HASH_LENGTH: u32 = 64;
+
/// Size (in bytes) of server's timestamp value
pub const TIMESTAMP_LENGTH: u32 = 8;
diff --git a/src/merkle.rs b/src/merkle.rs
new file mode 100644
index 0000000..623770f
--- /dev/null
+++ b/src/merkle.rs
@@ -0,0 +1,174 @@
+// Copyright 2018 int08h LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern crate ring;
+
+use super::{TREE_LEAF_TWEAK, TREE_NODE_TWEAK, HASH_LENGTH};
+use self::ring::digest;
+
+type Data = Vec<u8>;
+type Hash = Data;
+
+pub struct MerkleTree {
+ levels: Vec<Vec<Data>>,
+}
+
+impl MerkleTree {
+ pub fn new() -> MerkleTree {
+ MerkleTree {
+ levels: vec![vec![]]
+ }
+ }
+
+ pub fn push_leaf(&mut self, data: &[u8]) {
+ let hash = self.hash_leaf(data);
+ self.levels[0].push(hash);
+ }
+
+ pub fn get_paths(&self, mut index: usize) -> Vec<Hash> {
+ let mut paths = Vec::new();
+ let mut level = 0;
+
+ while !self.levels[level].is_empty() {
+ let sibling = if index % 2 == 0 {
+ index + 1
+ } else {
+ index - 1
+ };
+
+ paths.push(self.levels[level][sibling].clone());
+ level += 1;
+ index /= 2;
+ }
+ paths
+ }
+
+ pub fn compute_root(&mut self) -> Hash {
+ assert!(self.levels[0].len() > 0, "Must have at least one leaf to hash!");
+
+ let mut level = 0;
+ let mut node_count = self.levels[0].len();
+ while node_count > 1 {
+ level += 1;
+
+ if self.levels.len() < level + 1 {
+ self.levels.push(vec![]);
+ }
+
+ if node_count % 2 != 0 {
+ self.levels[level - 1].push(vec![0; HASH_LENGTH as usize]);
+ node_count += 1;
+ }
+
+ node_count /= 2;
+
+ for i in 0..node_count {
+ let hash = self.hash_nodes(&self.levels[level - 1][i*2], &self.levels[level - 1][(i*2)+1]);
+ self.levels[level].push(hash);
+ }
+ }
+ assert_eq!(self.levels[level].len(), 1);
+ self.levels[level].pop().unwrap()
+ }
+
+ pub fn reset(&mut self) {
+ for mut level in &mut self.levels {
+ level.clear();
+ }
+ }
+
+ fn hash_leaf(&self, leaf: &[u8]) -> Data {
+ self.hash(&[TREE_LEAF_TWEAK, leaf])
+ }
+
+ fn hash_nodes(&self, first: &[u8], second: &[u8]) -> Data {
+ self.hash(&[TREE_NODE_TWEAK, first, second])
+ }
+
+ fn hash(&self, to_hash: &[&[u8]]) -> Data {
+ let mut ctx = digest::Context::new(&digest::SHA512);
+ for data in to_hash {
+ ctx.update(data);
+ }
+ Data::from(ctx.finish().as_ref())
+ }
+}
+
+pub fn root_from_paths(mut index: usize, data: &[u8], paths: &[u8]) -> Hash {
+ let mut hash = {
+ let mut ctx = digest::Context::new(&digest::SHA512);
+ ctx.update(TREE_LEAF_TWEAK);
+ ctx.update(data);
+ Hash::from(ctx.finish().as_ref())
+ };
+
+ assert_eq!(paths.len() % 64, 0);
+
+ for path in paths.chunks(64) {
+ let mut ctx = digest::Context::new(&digest::SHA512);
+ ctx.update(TREE_NODE_TWEAK);
+
+ if index & 1 == 0 {
+ // Left
+ ctx.update(&hash);
+ ctx.update(path);
+ } else {
+ // Right
+ ctx.update(path);
+ ctx.update(&hash);
+ }
+ hash = Hash::from(ctx.finish().as_ref());
+
+ index >>= 1;
+ }
+ hash
+}
+
+#[cfg(test)]
+mod test {
+
+ use merkle::*;
+
+ fn test_paths_with_num(num: usize) {
+ let mut merkle = MerkleTree::new();
+
+ for i in 0..num {
+ merkle.push_leaf(&[i as u8]);
+ }
+
+ let root = merkle.compute_root();
+
+ for i in 0..num {
+ println!("Testing {:?} {:?}", num, i);
+ let paths: Vec<u8> = merkle.get_paths(i).into_iter().flat_map(|x| x).collect();
+ let computed_root = root_from_paths(i, &[i as u8], &paths);
+
+ assert_eq!(root, computed_root);
+ }
+ }
+
+ #[test]
+ fn power_of_two() {
+ test_paths_with_num(2);
+ test_paths_with_num(4);
+ test_paths_with_num(8);
+ test_paths_with_num(16);
+ }
+
+ #[test]
+ fn not_power_of_two() {
+ test_paths_with_num(1);
+ test_paths_with_num(20);
+ }
+}