summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStuart Stock <stuart@int08h.com>2019-01-26 07:30:14 -0600
committerGitHub <noreply@github.com>2019-01-26 07:30:14 -0600
commit9515a562754b81a3d499f6ab39b1b4810076ff2c (patch)
tree1589241806560531ec1adb63d28837cda906e952
parentbdf087db2f5ebcadf31bd0968b690daa25820489 (diff)
parent796cd33d1a8f9ed2713447f23b24535eef3b34ba (diff)
downloadroughenough-9515a562754b81a3d499f6ab39b1b4810076ff2c.zip
Merge pull request #14 from int08h/1.1.2
Land 1.1.2
-rw-r--r--CHANGELOG.md6
-rw-r--r--Cargo.toml11
-rw-r--r--src/bin/roughenough-server.rs6
-rw-r--r--src/config/mod.rs4
-rw-r--r--src/lib.rs3
-rw-r--r--src/server.rs308
-rw-r--r--src/stats.rs362
7 files changed, 544 insertions, 156 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a98044..86b8175 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,9 @@
+## Version 1.1.2
+
+* Add client request statistics tracking.
+* Clean-up and simplification of server inner loop.
+* Rust 2018 edition required to compile.
+
## Version 1.1.1
* Provide auxiliary data to the AWS KMS decryption call. The auxiliary data _was_ provided in encrypt, but not decrypt, resulting in unconditional failure when unwrapping the long-term identity. See https://github.com/int08h/roughenough/commit/846128d08bd3fcd72f23b3123b332d0692782e41#diff-7f7c3059af30a5ded26269301caf8531R102
diff --git a/Cargo.toml b/Cargo.toml
index 165d81d..d7e75ca 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "roughenough"
-version = "1.1.1"
+version = "1.1.2"
repository = "https://github.com/int08h/roughenough"
authors = ["Stuart Stock <stuart@int08h.com>", "Aaron Hill <aa1ronham@gmail.com>"]
license = "Apache-2.0"
@@ -32,10 +32,14 @@ clap = "2"
chrono = "0.4"
hex = "0.3"
base64 = "0.9"
+hashbrown = "0.1"
+humansize = "1.0"
-rusoto_core = { version = "0.34", optional = true }
-rusoto_kms = { version = "0.34", optional = true }
+# Used by 'awskms'
+rusoto_core = { version = "0.36", optional = true }
+rusoto_kms = { version = "0.36", optional = true }
+# Used by 'gcpkms'
# google-cloudkms1 intentionally uses an old version of Hyper. See
# https://github.com/Byron/google-apis-rs/issues/173 for more information.
google-cloudkms1 = { version = "1.0.8+20181005", optional = true }
@@ -45,7 +49,6 @@ serde = { version = "^1.0", optional = true }
serde_json = { version = "^1.0", optional = true }
yup-oauth2 = { version = "^1.0", optional = true }
-
[dev-dependencies]
criterion = "0.2"
diff --git a/src/bin/roughenough-server.rs b/src/bin/roughenough-server.rs
index 5f70da0..88c8e83 100644
--- a/src/bin/roughenough-server.rs
+++ b/src/bin/roughenough-server.rs
@@ -32,6 +32,8 @@ use roughenough::config::ServerConfig;
use roughenough::roughenough_version;
use roughenough::server::Server;
+use mio::Events;
+
macro_rules! check_ctrlc {
($keep_running:expr) => {
if !$keep_running.load(Ordering::Acquire) {
@@ -74,9 +76,11 @@ fn polling_loop(config: Box<ServerConfig>) {
ctrlc::set_handler(move || kr.store(false, Ordering::Release))
.expect("failed setting Ctrl-C handler");
+ let mut events = Events::with_capacity(64);
+
loop {
check_ctrlc!(kr_new);
- if server.process_events() {
+ if server.process_events(&mut events) {
return;
}
}
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 65204e6..fb7854f 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -112,7 +112,7 @@ pub trait ServerConfig {
/// * `ENV` will return an [`EnvironmentConfig`](struct.EnvironmentConfig.html)
/// * any other value returns a [`FileConfig`](struct.FileConfig.html)
///
-pub fn make_config(arg: &str) -> Result<Box<ServerConfig>, Error> {
+pub fn make_config(arg: &str) -> Result<Box<dyn ServerConfig>, Error> {
if arg == "ENV" {
match EnvironmentConfig::new() {
Ok(cfg) => Ok(Box::new(cfg)),
@@ -129,7 +129,7 @@ pub fn make_config(arg: &str) -> Result<Box<ServerConfig>, Error> {
///
/// Validate configuration settings. Returns `true` if the config is valid, `false` otherwise.
///
-pub fn is_valid_config(cfg: &Box<ServerConfig>) -> bool {
+pub fn is_valid_config(cfg: &Box<dyn ServerConfig>) -> bool {
let mut is_valid = true;
if cfg.port() == 0 {
diff --git a/src/lib.rs b/src/lib.rs
index 4ad6390..9a21dfa 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -66,6 +66,7 @@ pub mod config;
pub mod key;
pub mod kms;
pub mod merkle;
+pub mod stats;
pub mod server;
pub mod sign;
@@ -74,7 +75,7 @@ pub use crate::message::RtMessage;
pub use crate::tag::Tag;
/// Version of Roughenough
-pub const VERSION: &str = "1.1.1";
+pub const VERSION: &str = "1.1.2";
/// Roughenough version string enriched with any compile-time optional features
pub fn roughenough_version() -> String {
diff --git a/src/server.rs b/src/server.rs
index 01bbef6..6045191 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -23,21 +23,24 @@ use std::process;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
+use std::io::Write;
use time;
use byteorder::{LittleEndian, WriteBytesExt};
+use humansize::{FileSize, file_size_opts as fsopts};
+
use mio::net::{TcpListener, UdpSocket};
use mio::{Events, Poll, PollOpt, Ready, Token};
+use mio::tcp::Shutdown;
use mio_extras::timer::Timer;
use crate::config::ServerConfig;
use crate::key::{LongTermKey, OnlineKey};
use crate::kms;
use crate::merkle::MerkleTree;
-use mio::tcp::Shutdown;
-use std::io::Write;
use crate::{Error, RtMessage, Tag, MIN_REQUEST_LENGTH};
+use crate::stats::{ClientStats, SimpleStats};
macro_rules! check_ctrlc {
($keep_running:expr) => {
@@ -67,20 +70,16 @@ const HTTP_RESPONSE: &str = "HTTP/1.1 200 OK\nContent-Length: 0\nConnection: clo
/// See [the config module](../config/index.html) for more information.
///
pub struct Server {
- config: Box<ServerConfig>,
+ config: Box<dyn ServerConfig>,
online_key: OnlineKey,
cert_bytes: Vec<u8>,
- response_counter: u64,
- num_bad_requests: u64,
-
socket: UdpSocket,
health_listener: Option<TcpListener>,
keep_running: Arc<AtomicBool>,
poll_duration: Option<Duration>,
timer: Timer<()>,
poll: Poll,
- events: Events,
merkle: MerkleTree,
requests: Vec<(Vec<u8>, SocketAddr)>,
buf: [u8; 65_536],
@@ -90,6 +89,8 @@ pub struct Server {
// Used to send requests to ourselves in fuzzing mode
#[cfg(fuzzing)]
fake_client_socket: UdpSocket,
+
+ stats: SimpleStats,
}
impl Server {
@@ -155,8 +156,6 @@ impl Server {
online_key,
cert_bytes,
- response_counter: 0,
- num_bad_requests: 0,
socket,
health_listener,
@@ -164,7 +163,6 @@ impl Server {
poll_duration,
timer,
poll,
- events: Events::with_capacity(32),
merkle,
requests,
buf: [0u8; 65_536],
@@ -173,150 +171,69 @@ impl Server {
#[cfg(fuzzing)]
fake_client_socket: UdpSocket::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(),
+
+ stats: SimpleStats::new(),
}
}
- /// Returns a reference counted pointer the this server's `keep_running` value.
- pub fn get_keep_running(&self) -> Arc<AtomicBool> {
- self.keep_running.clone()
+ /// Returns a reference to the server's long-term public key
+ pub fn get_public_key(&self) -> &str {
+ &self.public_key
}
- // extract the client's nonce from its request
- fn nonce_from_request<'a>(&self, buf: &'a [u8], num_bytes: usize) -> Result<&'a [u8], Error> {
- if num_bytes < MIN_REQUEST_LENGTH as usize {
- return Err(Error::RequestTooShort);
- }
-
- let tag_count = &buf[..4];
- let expected_nonc = &buf[8..12];
- let expected_pad = &buf[12..16];
-
- let tag_count_is_2 = tag_count == [0x02, 0x00, 0x00, 0x00];
- let tag1_is_nonc = expected_nonc == Tag::NONC.wire_value();
- let tag2_is_pad = expected_pad == Tag::PAD.wire_value();
-
- if tag_count_is_2 && tag1_is_nonc && tag2_is_pad {
- Ok(&buf[0x10..0x50])
- } else {
- Err(Error::InvalidRequest)
- }
+ /// Returns a reference to the server's on-line (delegated) key
+ pub fn get_online_key(&self) -> &OnlineKey {
+ &self.online_key
}
- fn make_response(
- &self,
- srep: &RtMessage,
- cert_bytes: &[u8],
- path: &[u8],
- idx: u32,
- ) -> RtMessage {
- let mut index = [0; 4];
- (&mut index as &mut [u8])
- .write_u32::<LittleEndian>(idx)
- .unwrap();
-
- let sig_bytes = srep.get_field(Tag::SIG).unwrap();
- let srep_bytes = srep.get_field(Tag::SREP).unwrap();
+ /// Returns a reference to the `ServerConfig` this server was configured with
+ pub fn get_config(&self) -> &Box<dyn ServerConfig> {
+ &self.config
+ }
- let mut response = RtMessage::new(5);
- response.add_field(Tag::SIG, sig_bytes).unwrap();
- response.add_field(Tag::PATH, path).unwrap();
- response.add_field(Tag::SREP, srep_bytes).unwrap();
- response.add_field(Tag::CERT, cert_bytes).unwrap();
- response.add_field(Tag::INDX, &index).unwrap();
+ /// Returns a reference counted pointer the this server's `keep_running` value.
+ pub fn get_keep_running(&self) -> Arc<AtomicBool> {
+ self.keep_running.clone()
+ }
- response
+ #[cfg(fuzzing)]
+ pub fn send_to_self(&mut self, data: &[u8]) {
+ self.response_counter = 0;
+ self.num_bad_requests = 0;
+ let res = self
+ .fake_client_socket
+ .send_to(data, &self.socket.local_addr().unwrap());
+ info!("Sent to self: {:?}", res);
}
/// The main processing function for incoming connections. This method should be
/// called repeatedly in a loop to process requests. It returns 'true' when the
/// server has shutdown (due to keep_running being set to 'false').
///
- pub fn process_events(&mut self) -> bool {
+ pub fn process_events(&mut self, events: &mut Events) -> bool {
self.poll
- .poll(&mut self.events, self.poll_duration)
+ .poll(events, self.poll_duration)
.expect("poll failed");
- for event in self.events.iter() {
- match event.token() {
+ for msg in events.iter() {
+ match msg.token() {
MESSAGE => {
- let mut done = false;
-
- 'process_batch: loop {
+ loop {
check_ctrlc!(self.keep_running);
- let resp_start = self.response_counter;
-
- for i in 0..self.config.batch_size() {
- match self.socket.recv_from(&mut self.buf) {
- Ok((num_bytes, src_addr)) => {
- match self.nonce_from_request(&self.buf, num_bytes) {
- Ok(nonce) => {
- self.requests.push((Vec::from(nonce), src_addr));
- self.merkle.push_leaf(nonce);
- }
- Err(e) => {
- self.num_bad_requests += 1;
-
- info!(
- "Invalid request: '{:?}' ({} bytes) from {} (#{} in batch, resp #{})",
- e, num_bytes, src_addr, i, resp_start + i as u64
- );
- }
- }
- }
- Err(e) => match e.kind() {
- ErrorKind::WouldBlock => {
- done = true;
- break;
- }
- _ => {
- error!(
- "Error receiving from socket: {:?}: {:?}",
- e.kind(),
- e
- );
- break;
- }
- },
- };
- }
-
- if self.requests.is_empty() {
- break 'process_batch;
- }
-
- let merkle_root = self.merkle.compute_root();
- let srep = self.online_key.make_srep(time::get_time(), &merkle_root);
-
- for (i, &(ref nonce, ref src_addr)) in self.requests.iter().enumerate() {
- let paths = self.merkle.get_paths(i);
-
- let resp =
- self.make_response(&srep, &self.cert_bytes, &paths, i as u32);
- let resp_bytes = resp.encode().unwrap();
-
- let bytes_sent = self
- .socket
- .send_to(&resp_bytes, &src_addr)
- .expect("send_to failed");
+ self.merkle.reset();
+ self.requests.clear();
- self.response_counter += 1;
+ let socket_now_empty = self.collect_requests();
- info!(
- "Responded {} bytes to {} for '{}..' (#{} in batch, resp #{})",
- bytes_sent,
- src_addr,
- hex::encode(&nonce[0..4]),
- i,
- self.response_counter
- );
+ if self.requests.is_empty() {
+ break;
}
- self.merkle.reset();
- self.requests.clear();
+ self.send_responses();
- if done {
- break 'process_batch;
+ if socket_now_empty {
+ break;
}
}
}
@@ -327,11 +244,12 @@ impl Server {
match listener.accept() {
Ok((ref mut stream, src_addr)) => {
info!("health check from {}", src_addr);
+ self.stats.add_health_check(&src_addr.ip());
match stream.write(HTTP_RESPONSE.as_bytes()) {
Ok(_) => (),
Err(e) => warn!("error writing health check {}", e),
- }
+ };
match stream.shutdown(Shutdown::Both) {
Ok(_) => (),
@@ -348,9 +266,21 @@ impl Server {
}
STATUS => {
+ for (addr, counts) in self.stats.iter() {
+ info!(
+ "{:16}: {} valid, {} invalid requests; {} responses ({} sent)",
+ format!("{}", addr), counts.valid_requests, counts.invalid_requests,
+ counts.responses_sent,
+ counts.bytes_sent.file_size(fsopts::BINARY).unwrap()
+ );
+ }
+
info!(
- "responses {}, invalid requests {}",
- self.response_counter, self.num_bad_requests
+ "Totals: {} unique clients; {} valid, {} invalid requests; {} responses ({} sent)",
+ self.stats.total_unique_clients(),
+ self.stats.total_valid_requests(), self.stats.total_invalid_requests(),
+ self.stats.total_responses_sent(),
+ self.stats.total_bytes_sent().file_size(fsopts::BINARY).unwrap()
);
self.timer.set_timeout(self.config.status_interval(), ());
@@ -362,28 +292,110 @@ impl Server {
false
}
- /// Returns a reference to the server's long-term public key
- pub fn get_public_key(&self) -> &str {
- &self.public_key
+ // Read and process client requests from socket until empty or 'batch_size' number of
+ // requests have been read.
+ fn collect_requests(&mut self) -> bool {
+ for i in 0..self.config.batch_size() {
+ match self.socket.recv_from(&mut self.buf) {
+ Ok((num_bytes, src_addr)) => {
+ match self.nonce_from_request(&self.buf, num_bytes) {
+ Ok(nonce) => {
+ self.stats.add_valid_request(&src_addr.ip());
+ self.requests.push((Vec::from(nonce), src_addr));
+ self.merkle.push_leaf(nonce);
+ }
+ Err(e) => {
+ self.stats.add_invalid_request(&src_addr.ip());
+
+ info!(
+ "Invalid request: '{:?}' ({} bytes) from {} (#{} in batch, resp #{})",
+ e, num_bytes, src_addr, i,
+ self.stats.total_responses_sent() + u64::from(i)
+ );
+ }
+ }
+ }
+ Err(e) => match e.kind() {
+ ErrorKind::WouldBlock => {
+ return true;
+ }
+ _ => {
+ error!("Error receiving from socket: {:?}: {:?}", e.kind(), e);
+ return false;
+ }
+ },
+ };
+ }
+
+ false
}
- /// Returns a reference to the server's on-line (delegated) key
- pub fn get_online_key(&self) -> &OnlineKey {
- &self.online_key
+ // extract the client's nonce from its request
+ fn nonce_from_request<'a>(&self, buf: &'a [u8], num_bytes: usize) -> Result<&'a [u8], Error> {
+ if num_bytes < MIN_REQUEST_LENGTH as usize {
+ return Err(Error::RequestTooShort);
+ }
+
+ let tag_count = &buf[..4];
+ let expected_nonc = &buf[8..12];
+ let expected_pad = &buf[12..16];
+
+ let tag_count_is_2 = tag_count == [0x02, 0x00, 0x00, 0x00];
+ let tag1_is_nonc = expected_nonc == Tag::NONC.wire_value();
+ let tag2_is_pad = expected_pad == Tag::PAD.wire_value();
+
+ if tag_count_is_2 && tag1_is_nonc && tag2_is_pad {
+ Ok(&buf[0x10..0x50])
+ } else {
+ Err(Error::InvalidRequest)
+ }
}
- /// Returns a reference to the `ServerConfig` this server was configured with
- pub fn get_config(&self) -> &Box<ServerConfig> {
- &self.config
+ fn send_responses(&mut self) -> () {
+ let merkle_root = self.merkle.compute_root();
+
+ // The SREP tag is identical for each response
+ let srep = self.online_key.make_srep(time::get_time(), &merkle_root);
+
+ for (i, &(ref nonce, ref src_addr)) in self.requests.iter().enumerate() {
+ let paths = self.merkle.get_paths(i);
+ let resp = self.make_response(&srep, &self.cert_bytes, &paths, i as u32);
+ let resp_bytes = resp.encode().unwrap();
+
+ let bytes_sent = self
+ .socket
+ .send_to(&resp_bytes, &src_addr)
+ .expect("send_to failed");
+
+ self.stats.add_response(&src_addr.ip(), bytes_sent);
+
+ info!(
+ "Responded {} bytes to {} for '{}..' (#{} in batch, resp #{})",
+ bytes_sent,
+ src_addr,
+ hex::encode(&nonce[0..4]),
+ i,
+ self.stats.total_responses_sent()
+ );
+ }
}
- #[cfg(fuzzing)]
- pub fn send_to_self(&mut self, data: &[u8]) {
- self.response_counter = 0;
- self.num_bad_requests = 0;
- let res = self
- .fake_client_socket
- .send_to(data, &self.socket.local_addr().unwrap());
- info!("Sent to self: {:?}", res);
+ fn make_response(&self, srep: &RtMessage, cert_bytes: &[u8], path: &[u8], idx: u32) -> RtMessage {
+ let mut index = [0; 4];
+ (&mut index as &mut [u8])
+ .write_u32::<LittleEndian>(idx)
+ .unwrap();
+
+ let sig_bytes = srep.get_field(Tag::SIG).unwrap();
+ let srep_bytes = srep.get_field(Tag::SREP).unwrap();
+
+ let mut response = RtMessage::new(5);
+ response.add_field(Tag::SIG, sig_bytes).unwrap();
+ response.add_field(Tag::PATH, path).unwrap();
+ response.add_field(Tag::SREP, srep_bytes).unwrap();
+ response.add_field(Tag::CERT, cert_bytes).unwrap();
+ response.add_field(Tag::INDX, &index).unwrap();
+
+ response
}
}
diff --git a/src/stats.rs b/src/stats.rs
new file mode 100644
index 0000000..d296e40
--- /dev/null
+++ b/src/stats.rs
@@ -0,0 +1,362 @@
+// Copyright 2017-2019 int08h LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//!
+//! Facilities for tracking client requests to the server
+//!
+
+use hashbrown::HashMap;
+use hashbrown::hash_map::Iter;
+
+use std::net::IpAddr;
+
+///
+/// Implementations of this trait record client activity
+///
+pub trait ClientStats {
+ fn add_valid_request(&mut self, addr: &IpAddr);
+
+ fn add_invalid_request(&mut self, addr: &IpAddr);
+
+ fn add_health_check(&mut self, addr: &IpAddr);
+
+ fn add_response(&mut self, addr: &IpAddr, bytes_sent: usize);
+
+ fn total_valid_requests(&self) -> u64;
+
+ fn total_invalid_requests(&self) -> u64;
+
+ fn total_health_checks(&self) -> u64;
+
+ fn total_responses_sent(&self) -> u64;
+
+ fn total_bytes_sent(&self) -> usize;
+
+ fn total_unique_clients(&self) -> u64;
+
+ fn get_stats(&self, addr: &IpAddr) -> Option<&StatEntry>;
+
+ fn iter(&self) -> Iter<IpAddr, StatEntry>;
+
+ fn clear(&mut self);
+}
+
+///
+/// Specific metrics tracked per each client
+///
+#[derive(Debug, Clone, Copy)]
+pub struct StatEntry {
+ pub valid_requests: u64,
+ pub invalid_requests: u64,
+ pub health_checks: u64,
+ pub responses_sent: u64,
+ pub bytes_sent: usize,
+}
+
+impl StatEntry {
+ fn new() -> Self {
+ StatEntry {
+ valid_requests: 0,
+ invalid_requests: 0,
+ health_checks: 0,
+ responses_sent: 0,
+ bytes_sent: 0,
+ }
+ }
+}
+
+///
+/// Implementation of `ClientStats` backed by a hashmap.
+///
+/// Maintains a maximum of `MAX_CLIENTS` unique entries to bound memory use. Excess
+/// entries beyond `MAX_CLIENTS` are ignored and `num_overflows` is incremented.
+///
+pub struct SimpleStats {
+ clients: HashMap<IpAddr, StatEntry>,
+ num_overflows: u64,
+ max_clients: usize,
+}
+
+impl SimpleStats {
+
+ /// Maximum number of stats entries to maintain to prevent
+ /// unbounded memory growth.
+ pub const MAX_CLIENTS: usize = 1_000_000;
+
+ pub fn new() -> Self {
+ SimpleStats {
+ clients: HashMap::with_capacity(128),
+ num_overflows: 0,
+ max_clients: SimpleStats::MAX_CLIENTS,
+ }
+ }
+
+ // visible for testing
+ #[cfg(test)]
+ fn with_limits(limit: usize) -> Self {
+ SimpleStats {
+ clients: HashMap::with_capacity(128),
+ num_overflows: 0,
+ max_clients: limit,
+ }
+ }
+
+ #[inline]
+ fn too_many_entries(&mut self) -> bool {
+ let too_big = self.clients.len() >= self.max_clients;
+
+ if too_big {
+ self.num_overflows += 1;
+ }
+
+ return too_big;
+ }
+
+ #[allow(dead_code)]
+ pub fn num_overflows(&self) -> u64 {
+ self.num_overflows
+ }
+}
+
+impl ClientStats for SimpleStats {
+ fn add_valid_request(&mut self, addr: &IpAddr) {
+ if self.too_many_entries() {
+ return;
+ }
+ self.clients
+ .entry(*addr)
+ .or_insert_with(StatEntry::new)
+ .valid_requests += 1;
+ }
+
+ fn add_invalid_request(&mut self, addr: &IpAddr) {
+ if self.too_many_entries() {
+ return;
+ }
+ self.clients
+ .entry(*addr)
+ .or_insert_with(StatEntry::new)
+ .invalid_requests += 1;
+ }
+
+ fn add_health_check(&mut self, addr: &IpAddr) {
+ if self.too_many_entries() {
+ return;
+ }
+ self.clients
+ .entry(*addr)
+ .or_insert_with(StatEntry::new)
+ .health_checks += 1;
+ }
+
+ fn add_response(&mut self, addr: &IpAddr, bytes_sent: usize) {
+ if self.too_many_entries() {
+ return;
+ }
+ let entry = self.clients
+ .entry(*addr)
+ .or_insert_with(StatEntry::new);
+
+ entry.responses_sent += 1;
+ entry.bytes_sent += bytes_sent;
+ }
+
+ fn total_valid_requests(&self) -> u64 {
+ self.clients
+ .values()
+ .map(|&v| v.valid_requests)
+ .sum()
+ }
+
+ fn total_invalid_requests(&self) -> u64 {
+ self.clients
+ .values()
+ .map(|&v| v.invalid_requests)
+ .sum()
+ }
+
+ fn total_health_checks(&self) -> u64 {
+ self.clients
+ .values()
+ .map(|&v| v.health_checks)
+ .sum()
+ }
+
+ fn total_responses_sent(&self) -> u64 {
+ self.clients
+ .values()
+ .map(|&v| v.responses_sent)
+ .sum()
+ }
+
+ fn total_bytes_sent(&self) -> usize {
+ self.clients
+ .values()
+ .map(|&v| v.bytes_sent)
+ .sum()
+ }
+
+ fn total_unique_clients(&self) -> u64 {
+ self.clients.len() as u64
+ }
+
+ fn get_stats(&self, addr: &IpAddr) -> Option<&StatEntry> {
+ self.clients.get(addr)
+ }
+
+ fn iter(&self) -> Iter<IpAddr, StatEntry> {
+ self.clients.iter()
+ }
+
+ fn clear(&mut self) {
+ self.clients.clear();
+ self.num_overflows = 0;
+ }
+}
+
+///
+/// A no-op implementation that does not track anything and has no runtime cost
+///
+#[allow(dead_code)]
+pub struct NoOpStats {
+ empty_map: HashMap<IpAddr, StatEntry>
+}
+
+impl NoOpStats {
+
+ #[allow(dead_code)]
+ pub fn new() -> Self {
+ NoOpStats {
+ empty_map: HashMap::new()
+ }
+ }
+}
+
+impl ClientStats for NoOpStats {
+ fn add_valid_request(&mut self, _addr: &IpAddr) {}
+
+ fn add_invalid_request(&mut self, _addr: &IpAddr) {}
+
+ fn add_health_check(&mut self, _addr: &IpAddr) {}
+
+ fn add_response(&mut self, _addr: &IpAddr, _bytes_sent: usize) {}
+
+ fn total_valid_requests(&self) -> u64 {
+ 0
+ }
+
+ fn total_invalid_requests(&self) -> u64 {
+ 0
+ }
+
+ fn total_health_checks(&self) -> u64 {
+ 0
+ }
+
+ fn total_responses_sent(&self) -> u64 {
+ 0
+ }
+
+ fn total_bytes_sent(&self) -> usize {
+ 0
+ }
+
+ fn total_unique_clients(&self) -> u64 {
+ 0
+ }
+
+ fn get_stats(&self, _addr: &IpAddr) -> Option<&StatEntry> {
+ None
+ }
+
+ fn iter(&self) -> Iter<IpAddr, StatEntry> {
+ self.empty_map.iter()
+ }
+
+ fn clear(&mut self) {}
+}
+
+#[cfg(test)]
+mod test {
+ use crate::stats::{ClientStats, SimpleStats};
+ use std::net::{IpAddr, Ipv4Addr};
+
+ #[test]
+ fn simple_stats_starts_empty() {
+ let stats = SimpleStats::new();
+
+ assert_eq!(stats.total_valid_requests(), 0);
+ assert_eq!(stats.total_invalid_requests(), 0);
+ assert_eq!(stats.total_health_checks(), 0);
+ assert_eq!(stats.total_responses_sent(), 0);
+ assert_eq!(stats.total_bytes_sent(), 0);
+ assert_eq!(stats.total_unique_clients(), 0);
+ assert_eq!(stats.num_overflows(), 0);
+ }
+
+ #[test]
+ fn client_requests_are_tracked() {
+ let mut stats = SimpleStats::new();
+
+ let ip1 = "127.0.0.1".parse().unwrap();
+ let ip2 = "127.0.0.2".parse().unwrap();
+ let ip3 = "127.0.0.3".parse().unwrap();
+
+ stats.add_valid_request(&ip1);
+ stats.add_valid_request(&ip2);
+ stats.add_valid_request(&ip3);
+ assert_eq!(stats.total_valid_requests(), 3);
+
+ stats.add_invalid_request(&ip2);
+ assert_eq!(stats.total_invalid_requests(), 1);
+
+ stats.add_response(&ip2, 8192);
+ assert_eq!(stats.total_bytes_sent(), 8192);
+
+ assert_eq!(stats.total_unique_clients(), 3);
+ }
+
+ #[test]
+ fn per_client_stats() {
+ let mut stats = SimpleStats::new();
+ let ip = "127.0.0.3".parse().unwrap();
+
+ stats.add_valid_request(&ip);
+ stats.add_response(&ip, 2048);
+ stats.add_response(&ip, 1024);
+
+ let entry = stats.get_stats(&ip).unwrap();
+ assert_eq!(entry.valid_requests, 1);
+ assert_eq!(entry.invalid_requests, 0);
+ assert_eq!(entry.responses_sent, 2);
+ assert_eq!(entry.bytes_sent, 3072);
+ }
+
+ #[test]
+ fn overflow_max_entries() {
+ let mut stats = SimpleStats::with_limits(100);
+
+ for i in 0..201 {
+ let ipv4 = Ipv4Addr::from(i as u32);
+ let addr = IpAddr::from(ipv4);
+
+ stats.add_valid_request(&addr);
+ };
+
+ assert_eq!(stats.total_unique_clients(), 100);
+ assert_eq!(stats.num_overflows(), 101);
+ }
+}
+
+