summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--example.cfg1
-rw-r--r--src/bin/client.rs89
-rw-r--r--src/bin/server.rs181
-rw-r--r--src/lib.rs4
-rw-r--r--src/merkle.rs174
6 files changed, 352 insertions, 98 deletions
diff --git a/Cargo.toml b/Cargo.toml
index e2af555..999efa4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ travis-ci = { repository = "int08h/roughenough", branch = "master" }
[dependencies]
mio = "0.6"
+mio-extras = "2.0"
byteorder = "1"
ring = "0.12"
untrusted = "0.5"
diff --git a/example.cfg b/example.cfg
index c271481..d27f012 100644
--- a/example.cfg
+++ b/example.cfg
@@ -1,3 +1,4 @@
port: 8686
interface: 127.0.0.1
seed: a32049da0ffde0ded92ce10a0230d35fe615ec8461c14986baa63fe3b3bac3db
+batch_size: 64
diff --git a/src/bin/client.rs b/src/bin/client.rs
index 1c1da03..d9b37fc 100644
--- a/src/bin/client.rs
+++ b/src/bin/client.rs
@@ -9,7 +9,6 @@ extern crate hex;
use ring::rand;
use ring::rand::SecureRandom;
-use ring::digest;
use byteorder::{LittleEndian, ReadBytesExt};
@@ -20,8 +19,9 @@ use std::iter::Iterator;
use std::collections::HashMap;
use std::net::{UdpSocket, ToSocketAddrs};
-use roughenough::{RtMessage, Tag, VERSION, TREE_NODE_TWEAK, TREE_LEAF_TWEAK, CERTIFICATE_CONTEXT, SIGNED_RESPONSE_CONTEXT};
+use roughenough::{RtMessage, Tag, VERSION, CERTIFICATE_CONTEXT, SIGNED_RESPONSE_CONTEXT};
use roughenough::sign::Verifier;
+use roughenough::merkle::root_from_paths;
use clap::{Arg, App};
fn create_nonce() -> [u8; 64] {
@@ -56,6 +56,12 @@ struct ResponseHandler {
nonce: [u8; 64]
}
+struct ParsedResponse {
+ verified: bool,
+ midpoint: u64,
+ radius: u32
+}
+
impl ResponseHandler {
pub fn new(pub_key: Option<Vec<u8>>, response: RtMessage, nonce: [u8; 64]) -> ResponseHandler {
let msg = response.into_hash_map();
@@ -74,18 +80,24 @@ impl ResponseHandler {
}
}
- pub fn extract_time(&self) -> (u64, u32) {
+ pub fn extract_time(&self) -> ParsedResponse {
let midpoint = self.srep[&Tag::MIDP].as_slice().read_u64::<LittleEndian>().unwrap();
let radius = self.srep[&Tag::RADI].as_slice().read_u32::<LittleEndian>().unwrap();
+ let mut verified = false;
if self.pub_key.is_some() {
self.validate_dele();
self.validate_srep();
self.validate_merkle();
self.validate_midpoint(midpoint);
+ verified = true;
}
- (midpoint, radius)
+ ParsedResponse {
+ verified,
+ midpoint,
+ radius
+ }
}
fn validate_dele(&self) {
@@ -106,32 +118,12 @@ impl ResponseHandler {
fn validate_merkle(&self) {
let srep = RtMessage::from_bytes(&self.msg[&Tag::SREP]).unwrap().into_hash_map();
- let mut index = self.msg[&Tag::INDX].as_slice().read_u32::<LittleEndian>().unwrap();
+ let index = self.msg[&Tag::INDX].as_slice().read_u32::<LittleEndian>().unwrap();
let paths = &self.msg[&Tag::PATH];
- let mut hash = sha_512(TREE_LEAF_TWEAK, &self.nonce);
-
- assert_eq!(paths.len() % 64, 0);
-
- for path in paths.chunks(64) {
- let mut ctx = digest::Context::new(&digest::SHA512);
- ctx.update(TREE_NODE_TWEAK);
-
- if index & 1 == 0 {
- // Left
- ctx.update(&hash);
- ctx.update(path);
- } else {
- // Right
- ctx.update(path);
- ctx.update(&hash);
- }
- hash = Vec::from(ctx.finish().as_ref());
+ let hash = root_from_paths(index as usize, &self.nonce, paths);
- index >>= 1;
- }
-
- assert_eq!(hash, srep[&Tag::ROOT], "Nonce not in merkle tree!");
+ assert_eq!(Vec::from(hash), srep[&Tag::ROOT], "Nonce not in merkle tree!");
}
@@ -150,13 +142,6 @@ impl ResponseHandler {
}
}
-fn sha_512(prefix: &[u8], data: &[u8]) -> Vec<u8> {
- let mut ctx = digest::Context::new(&digest::SHA512);
- ctx.update(prefix);
- ctx.update(data);
- Vec::from(ctx.finish().as_ref())
-}
-
fn main() {
let matches = App::new("roughenough client")
.version(VERSION)
@@ -187,6 +172,11 @@ fn main() {
.help("The number of requests to make to the server (each from a different source port). This is mainly useful for testing batch response handling")
.default_value("1")
)
+ .arg(Arg::with_name("stress")
+ .short("s")
+ .long("stress")
+ .help("Stress-tests the server by sending the same request as fast as possible. Please only use this on your own server")
+ )
.get_matches();
let host = matches.value_of("host").unwrap();
@@ -194,10 +184,30 @@ fn main() {
let num_requests = value_t_or_exit!(matches.value_of("num-requests"), u16) as usize;
let pub_key = matches.value_of("public-key").map(|pkey| hex::decode(pkey).expect("Error parsing public key!"));
let time_format = matches.value_of("time-format").unwrap();
+ let stress = matches.is_present("stress");
println!("Requesting time from: {:?}:{:?}", host, port);
- let addrs: Vec<_> = (host, port).to_socket_addrs().unwrap().collect();
+ let addr = (host, port).to_socket_addrs().unwrap().next().unwrap();
+
+ if stress {
+
+ if !addr.ip().is_loopback() {
+ println!("ERROR: Cannot use non-loopback address {} for stress testing", addr.ip());
+ return;
+ }
+
+ println!("Stress-testing!");
+
+ let nonce = create_nonce();
+ let socket = UdpSocket::bind("0.0.0.0:0").expect("Couldn't open UDP socket");
+ let request = make_request(&nonce);
+
+ loop {
+ socket.send_to(&request, addr).unwrap();
+ }
+ }
+
let mut requests = Vec::with_capacity(num_requests);
@@ -210,18 +220,21 @@ fn main() {
}
for &mut (_, ref request, ref mut socket) in requests.iter_mut() {
- socket.send_to(request, addrs.as_slice()).unwrap();
+ socket.send_to(request, addr).unwrap();
}
for (nonce, _, mut socket) in requests {
let resp = receive_response(&mut socket);
- let (midpoint, radius) = ResponseHandler::new(pub_key.clone(), resp, nonce).extract_time();
+ let ParsedResponse {verified, midpoint, radius} = ResponseHandler::new(pub_key.clone(), resp.clone(), nonce).extract_time();
+
+ let map = resp.into_hash_map();
+ let index = map[&Tag::INDX].as_slice().read_u32::<LittleEndian>().unwrap();
let seconds = midpoint / 10_u64.pow(6);
let spec = Utc.timestamp(seconds as i64, ((midpoint - (seconds * 10_u64.pow(6))) * 10_u64.pow(3)) as u32);
let out = spec.format(time_format).to_string();
- println!("Recieved time from server: midpoint={:?}, radius={:?}", out, radius);
+ println!("Recieved time from server: midpoint={:?}, radius={:?} (merkle_index={}, verified={})", out, radius, index, verified);
}
}
diff --git a/src/bin/server.rs b/src/bin/server.rs
index 31e9392..86d8f57 100644
--- a/src/bin/server.rs
+++ b/src/bin/server.rs
@@ -22,6 +22,7 @@
//! interface: 127.0.0.1
//! port: 8686
//! seed: f61075c988feb9cb700a4a6a3291bfbc9cab11b9c9eca8c802468eb38a43d7d3
+//! batch_size: 64
//! ```
//!
//! Where:
@@ -31,6 +32,9 @@
//! * **seed** - A 32-byte hexadecimal value used as the seed to generate the
//! server's long-term key pair. **This is a secret value**, treat it
//! with care.
+//! * **batch_size** - The number of requests to process in one batch. All nonces
+//! in a batch are used to build a Merkle tree, the root of which
+//! is signed.
//!
//! # Running the Server
//!
@@ -38,8 +42,6 @@
//! $ cargo run --release --bin server /path/to/config.file
//! ```
-#![allow(deprecated)] // for mio::Timer
-
extern crate byteorder;
extern crate ring;
extern crate roughenough;
@@ -51,29 +53,31 @@ extern crate yaml_rust;
extern crate log;
extern crate simple_logger;
extern crate mio;
+extern crate mio_extras;
extern crate hex;
use std::env;
use std::process;
use std::fs::File;
-use std::io::Read;
-use std::io;
+use std::io::{Read, ErrorKind};
use std::time::Duration;
use std::net::SocketAddr;
use std::sync::Arc;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::thread;
use mio::{Poll, Token, Ready, PollOpt, Events};
use mio::net::UdpSocket;
-use mio::timer::Timer;
+use mio_extras::timer::Timer;
use byteorder::{LittleEndian, WriteBytesExt};
use roughenough::{RtMessage, Tag, Error};
-use roughenough::{VERSION, CERTIFICATE_CONTEXT, MIN_REQUEST_LENGTH, SIGNED_RESPONSE_CONTEXT, TREE_LEAF_TWEAK};
+use roughenough::{VERSION, CERTIFICATE_CONTEXT, MIN_REQUEST_LENGTH, SIGNED_RESPONSE_CONTEXT};
use roughenough::sign::Signer;
+use roughenough::merkle::*;
-use ring::{digest, rand};
+use ring::rand;
use ring::rand::SecureRandom;
use yaml_rust::YamlLoader;
@@ -81,6 +85,8 @@ use yaml_rust::YamlLoader;
const MESSAGE: Token = Token(0);
const STATUS: Token = Token(1);
+pub static NUM_RESPONSES: AtomicUsize = AtomicUsize::new(0);
+
fn create_ephemeral_key() -> Signer {
let rng = rand::SystemRandom::new();
let mut seed = [0u8; 32];
@@ -128,7 +134,7 @@ fn make_key_and_cert(seed: &[u8]) -> (Signer, Vec<u8>) {
(ephemeral_key, cert_bytes)
}
-fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) -> RtMessage {
+fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], root: &[u8], path: &[u8], idx: u32) -> RtMessage {
// create SREP
// sign SREP
// create response:
@@ -138,14 +144,15 @@ fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) ->
// - CERT (pre-created)
// - INDX (always 0)
- let path = [0u8; 0];
- let zeros = [0u8; 4];
+ let mut index = [0; 4];
- let mut radi: Vec<u8> = Vec::with_capacity(4);
- let mut midp: Vec<u8> = Vec::with_capacity(8);
+ let mut radi = [0; 4];
+ let mut midp = [0; 8];
+
+ (&mut index as &mut [u8]).write_u32::<LittleEndian>(idx).unwrap();
// one second (in microseconds)
- radi.write_u32::<LittleEndian>(1_000_000).unwrap();
+ (&mut radi as &mut [u8]).write_u32::<LittleEndian>(1_000_000).unwrap();
// current epoch time in microseconds
let now = {
@@ -155,20 +162,14 @@ fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) ->
secs + nsecs
};
- midp.write_u64::<LittleEndian>(now).unwrap();
+ (&mut midp as &mut [u8]).write_u64::<LittleEndian>(now).unwrap();
// Signed response SREP
let srep_bytes = {
- // hash request nonce
- let mut ctx = digest::Context::new(&digest::SHA512);
- ctx.update(TREE_LEAF_TWEAK);
- ctx.update(nonce);
- let digest = ctx.finish();
-
let mut srep_msg = RtMessage::new(3);
srep_msg.add_field(Tag::RADI, &radi).unwrap();
srep_msg.add_field(Tag::MIDP, &midp).unwrap();
- srep_msg.add_field(Tag::ROOT, digest.as_ref()).unwrap();
+ srep_msg.add_field(Tag::ROOT, root).unwrap();
srep_msg.encode().unwrap()
};
@@ -185,7 +186,7 @@ fn make_response(ephemeral_key: &mut Signer, cert_bytes: &[u8], nonce: &[u8]) ->
response.add_field(Tag::PATH, &path).unwrap();
response.add_field(Tag::SREP, &srep_bytes).unwrap();
response.add_field(Tag::CERT, cert_bytes).unwrap();
- response.add_field(Tag::INDX, &zeros).unwrap();
+ response.add_field(Tag::INDX, &index).unwrap();
response
}
@@ -211,7 +212,7 @@ fn nonce_from_request(buf: &[u8], num_bytes: usize) -> Result<&[u8], Error> {
}
}
-fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>) {
+fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>, u8) {
let mut infile = File::open(config_file)
.expect("failed to open config file");
@@ -229,12 +230,14 @@ fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>) {
let mut port: u16 = 0;
let mut iface: String = "unknown".to_string();
let mut seed: String = "".to_string();
+ let mut batch_size: u8 = 1;
for (key, value) in cfg[0].as_hash().unwrap() {
match key.as_str().unwrap() {
"port" => port = value.as_i64().unwrap() as u16,
"interface" => iface = value.as_str().unwrap().to_string(),
"seed" => seed = value.as_str().unwrap().to_string(),
+ "batch_size" => batch_size = value.as_i64().unwrap() as u8,
_ => warn!("ignoring unknown config key '{}'", key.as_str().unwrap())
}
}
@@ -246,10 +249,10 @@ fn load_config(config_file: &str) -> (SocketAddr, Vec<u8>) {
let binseed = hex::decode(seed)
.expect("seed value invalid; 'seed' should be 32 byte hex value");
- (sock_addr, binseed)
+ (sock_addr, binseed, batch_size)
}
-fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &[u8]) {
+fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &[u8], batch_size: u8) {
let keep_running = Arc::new(AtomicBool::new(true));
let kr = keep_running.clone();
@@ -257,62 +260,100 @@ fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &
.expect("failed setting Ctrl-C handler");
let socket = UdpSocket::bind(addr).expect("failed to bind to socket");
- let status_duration = Duration::from_secs(6_000);
+ let status_duration = Duration::from_secs(6);
let poll_duration = Some(Duration::from_millis(100));
let mut timer: Timer<()> = Timer::default();
- timer.set_timeout(status_duration, ()).expect("unable to set_timeout");
+ timer.set_timeout(status_duration, ());
+
let mut buf = [0u8; 65_536];
let mut events = Events::with_capacity(32);
- let mut num_responses = 0u64;
let mut num_bad_requests = 0u64;
let poll = Poll::new().unwrap();
poll.register(&socket, MESSAGE, Ready::readable(), PollOpt::edge()).unwrap();
poll.register(&timer, STATUS, Ready::readable(), PollOpt::edge()).unwrap();
- loop {
- if !keep_running.load(Ordering::Acquire) {
- info!("Ctrl-C caught, exiting...");
- break;
+ let mut merkle = MerkleTree::new();
+ let mut requests = Vec::with_capacity(batch_size as usize);
+
+ macro_rules! check_ctrlc {
+ () => {
+ if !keep_running.load(Ordering::Acquire) {
+ warn!("Ctrl-C caught, exiting...");
+ return;
+ }
}
+ }
+
+ loop {
+ check_ctrlc!();
poll.poll(&mut events, poll_duration).expect("poll failed");
for event in events.iter() {
+
match event.token() {
MESSAGE => {
- loop {
- match socket.recv_from(&mut buf) {
- Ok((num_bytes, src_addr)) => {
- if let Ok(nonce) = nonce_from_request(&buf, num_bytes) {
- let resp = make_response(&mut ephemeral_key, cert_bytes, nonce);
- let resp_bytes = resp.encode().unwrap();
-
- let bytes_sent = socket.send_to(&resp_bytes, &src_addr).expect("send_to failed");
-
- num_responses += 1;
- info!("Responded {} bytes to {} for '{}..' (resp #{})", bytes_sent, src_addr, hex::encode(&nonce[0..4]), num_responses);
- } else {
- num_bad_requests += 1;
- info!("Invalid request ({} bytes) from {} (resp #{})", num_bytes, src_addr, num_responses);
+
+ let mut done = false;
+
+ 'process_batch: loop {
+ check_ctrlc!();
+
+ merkle.reset();
+ requests.clear();
+
+ let resp_start = NUM_RESPONSES.load(Ordering::SeqCst);
+
+ for i in 0..batch_size {
+ match socket.recv_from(&mut buf) {
+ Ok((num_bytes, src_addr)) => {
+ if let Ok(nonce) = nonce_from_request(&buf, num_bytes) {
+ requests.push((Vec::from(nonce), src_addr));
+ merkle.push_leaf(nonce);
+ } else {
+ num_bad_requests += 1;
+ info!("Invalid request ({} bytes) from {} (#{} in batch, resp #{})", num_bytes, src_addr, i, resp_start + i as usize);
+ }
+ },
+ Err(e) => match e.kind() {
+ ErrorKind::WouldBlock => {
+ done = true;
+ break;
+ },
+ _ => panic!("recv_from failed with {:?}", e)
}
- }
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- break
- }
- Err(ref e) => {
- error!("Error {:?}: {:?}", e.kind(), e);
- break
- }
+ };
+ }
+
+ if requests.is_empty() {
+ break 'process_batch
+ }
+
+ let root = merkle.compute_root();
+ for (i, &(ref nonce, ref src_addr)) in requests.iter().enumerate() {
+ let paths: Vec<_> = merkle.get_paths(i).into_iter().flat_map(|x| x).collect();
+
+ let resp = make_response(&mut ephemeral_key, cert_bytes, &root, &paths, i as u32);
+ let resp_bytes = resp.encode().unwrap();
+
+ let bytes_sent = socket.send_to(&resp_bytes, &src_addr).expect("send_to failed");
+ let num_responses = NUM_RESPONSES.fetch_add(1, Ordering::SeqCst);
+
+ info!("Responded {} bytes to {} for '{}..' (#{} in batch, resp #{})", bytes_sent, src_addr, hex::encode(&nonce[0..4]), i, num_responses);
+ }
+ if done {
+ break 'process_batch
}
}
+
}
STATUS => {
- info!("responses {}, invalid requests {}", num_responses, num_bad_requests);
- timer.set_timeout(status_duration, ()).expect("unable to set_timeout");
+ info!("responses {}, invalid requests {}", NUM_RESPONSES.load(Ordering::SeqCst), num_bad_requests);
+ timer.set_timeout(status_duration, ());
}
_ => unreachable!()
@@ -321,8 +362,9 @@ fn polling_loop(addr: &SocketAddr, mut ephemeral_key: &mut Signer, cert_bytes: &
}
}
-fn main() {
+pub fn main() {
use log::Level;
+
simple_logger::init_with_level(Level::Info).unwrap();
info!("Roughenough server v{} starting", VERSION);
@@ -333,12 +375,31 @@ fn main() {
process::exit(1);
}
- let (addr, key_seed) = load_config(&args.nth(1).unwrap());
+ let (addr, key_seed, batch_size) = load_config(&args.nth(1).unwrap());
let (mut ephemeral_key, cert_bytes) = make_key_and_cert(&key_seed);
info!("Server listening on {}", addr);
- polling_loop(&addr, &mut ephemeral_key, &cert_bytes);
+ if env::var("BENCH").is_ok() {
+ log::set_max_level(log::LevelFilter::Warn);
+
+ thread::spawn(|| {
+ loop {
+ let old = time::get_time().sec;
+ let old_reqs = NUM_RESPONSES.load(Ordering::SeqCst);
+
+ thread::sleep(Duration::from_secs(1));
+
+ let new = time::get_time().sec;
+ let new_reqs = NUM_RESPONSES.load(Ordering::SeqCst);
+
+ warn!("Processing at {:?} reqs/sec", (new_reqs - old_reqs) / (new - old) as usize);
+ }
+ });
+ }
+
+
+ polling_loop(&addr, &mut ephemeral_key, &cert_bytes, batch_size);
info!("Done.");
process::exit(0);
diff --git a/src/lib.rs b/src/lib.rs
index 7259c26..72cdb69 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -58,6 +58,7 @@ mod tag;
mod message;
pub mod sign;
+pub mod merkle;
pub use error::Error;
pub use tag::Tag;
@@ -83,6 +84,9 @@ pub const NONCE_LENGTH: u32 = 64;
/// Size (in bytes) of an Ed25519 signature
pub const SIGNATURE_LENGTH: u32 = 64;
+/// Size (in bytes) of a SHA-512 hash
+pub const HASH_LENGTH: u32 = 64;
+
/// Size (in bytes) of server's timestamp value
pub const TIMESTAMP_LENGTH: u32 = 8;
diff --git a/src/merkle.rs b/src/merkle.rs
new file mode 100644
index 0000000..623770f
--- /dev/null
+++ b/src/merkle.rs
@@ -0,0 +1,174 @@
+// Copyright 2018 int08h LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern crate ring;
+
+use super::{TREE_LEAF_TWEAK, TREE_NODE_TWEAK, HASH_LENGTH};
+use self::ring::digest;
+
+type Data = Vec<u8>;
+type Hash = Data;
+
+pub struct MerkleTree {
+ levels: Vec<Vec<Data>>,
+}
+
+impl MerkleTree {
+ pub fn new() -> MerkleTree {
+ MerkleTree {
+ levels: vec![vec![]]
+ }
+ }
+
+ pub fn push_leaf(&mut self, data: &[u8]) {
+ let hash = self.hash_leaf(data);
+ self.levels[0].push(hash);
+ }
+
+ pub fn get_paths(&self, mut index: usize) -> Vec<Hash> {
+ let mut paths = Vec::new();
+ let mut level = 0;
+
+ while !self.levels[level].is_empty() {
+ let sibling = if index % 2 == 0 {
+ index + 1
+ } else {
+ index - 1
+ };
+
+ paths.push(self.levels[level][sibling].clone());
+ level += 1;
+ index /= 2;
+ }
+ paths
+ }
+
+ pub fn compute_root(&mut self) -> Hash {
+ assert!(self.levels[0].len() > 0, "Must have at least one leaf to hash!");
+
+ let mut level = 0;
+ let mut node_count = self.levels[0].len();
+ while node_count > 1 {
+ level += 1;
+
+ if self.levels.len() < level + 1 {
+ self.levels.push(vec![]);
+ }
+
+ if node_count % 2 != 0 {
+ self.levels[level - 1].push(vec![0; HASH_LENGTH as usize]);
+ node_count += 1;
+ }
+
+ node_count /= 2;
+
+ for i in 0..node_count {
+ let hash = self.hash_nodes(&self.levels[level - 1][i*2], &self.levels[level - 1][(i*2)+1]);
+ self.levels[level].push(hash);
+ }
+ }
+ assert_eq!(self.levels[level].len(), 1);
+ self.levels[level].pop().unwrap()
+ }
+
+ pub fn reset(&mut self) {
+ for mut level in &mut self.levels {
+ level.clear();
+ }
+ }
+
+ fn hash_leaf(&self, leaf: &[u8]) -> Data {
+ self.hash(&[TREE_LEAF_TWEAK, leaf])
+ }
+
+ fn hash_nodes(&self, first: &[u8], second: &[u8]) -> Data {
+ self.hash(&[TREE_NODE_TWEAK, first, second])
+ }
+
+ fn hash(&self, to_hash: &[&[u8]]) -> Data {
+ let mut ctx = digest::Context::new(&digest::SHA512);
+ for data in to_hash {
+ ctx.update(data);
+ }
+ Data::from(ctx.finish().as_ref())
+ }
+}
+
+pub fn root_from_paths(mut index: usize, data: &[u8], paths: &[u8]) -> Hash {
+ let mut hash = {
+ let mut ctx = digest::Context::new(&digest::SHA512);
+ ctx.update(TREE_LEAF_TWEAK);
+ ctx.update(data);
+ Hash::from(ctx.finish().as_ref())
+ };
+
+ assert_eq!(paths.len() % 64, 0);
+
+ for path in paths.chunks(64) {
+ let mut ctx = digest::Context::new(&digest::SHA512);
+ ctx.update(TREE_NODE_TWEAK);
+
+ if index & 1 == 0 {
+ // Left
+ ctx.update(&hash);
+ ctx.update(path);
+ } else {
+ // Right
+ ctx.update(path);
+ ctx.update(&hash);
+ }
+ hash = Hash::from(ctx.finish().as_ref());
+
+ index >>= 1;
+ }
+ hash
+}
+
+#[cfg(test)]
+mod test {
+
+ use merkle::*;
+
+ fn test_paths_with_num(num: usize) {
+ let mut merkle = MerkleTree::new();
+
+ for i in 0..num {
+ merkle.push_leaf(&[i as u8]);
+ }
+
+ let root = merkle.compute_root();
+
+ for i in 0..num {
+ println!("Testing {:?} {:?}", num, i);
+ let paths: Vec<u8> = merkle.get_paths(i).into_iter().flat_map(|x| x).collect();
+ let computed_root = root_from_paths(i, &[i as u8], &paths);
+
+ assert_eq!(root, computed_root);
+ }
+ }
+
+ #[test]
+ fn power_of_two() {
+ test_paths_with_num(2);
+ test_paths_with_num(4);
+ test_paths_with_num(8);
+ test_paths_with_num(16);
+ }
+
+ #[test]
+ fn not_power_of_two() {
+ test_paths_with_num(1);
+ test_paths_with_num(20);
+ }
+}