ntp-proto-1.4.0/.cargo_vcs_info.json0000644000000001470000000000100130020ustar { "git": { "sha1": "ef33a8c17713a52546b54e1034679d3ebdb038af" }, "path_in_vcs": "ntp-proto" }ntp-proto-1.4.0/COPYRIGHT000064400000000000000000000005441046102023000130660ustar 00000000000000Copyright (c) 2022-2024 Trifecta Tech Foundation, Tweede Golf, and Contributors Except as otherwise noted (below and/or in individual files), ntpd-rs is licensed under the Apache License, Version 2.0 or or the MIT license or , at your option. ntp-proto-1.4.0/Cargo.toml0000644000000052300000000000100107760ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. 
[package] edition = "2021" rust-version = "1.70" name = "ntp-proto" version = "1.4.0" build = false publish = true autolib = false autobins = false autoexamples = false autotests = false autobenches = false description = "ntpd-rs packet parsing and algorithms" homepage = "https://github.com/pendulum-project/ntpd-rs" readme = "README.md" license = "Apache-2.0 OR MIT" repository = "https://github.com/pendulum-project/ntpd-rs" [lib] name = "ntp_proto" path = "src/lib.rs" [dependencies.aead] version = "0.5.0" [dependencies.aes-siv] version = "0.7.0" [dependencies.arbitrary] version = "1.0" optional = true [dependencies.md-5] version = "0.10.0" [dependencies.rand] version = "0.8.0" [dependencies.rustls-native-certs6] version = "0.6" optional = true package = "rustls-native-certs" [dependencies.rustls-native-certs7] version = "0.7" optional = true package = "rustls-native-certs" [dependencies.rustls-pemfile1] version = "1.0" optional = true package = "rustls-pemfile" [dependencies.rustls-pemfile2] version = "2.0" optional = true package = "rustls-pemfile" [dependencies.rustls-pki-types] version = "1.2" optional = true [dependencies.rustls21] version = "0.21.0" optional = true package = "rustls" [dependencies.rustls22] version = "0.22.0" features = [ "ring", "logging", "tls12", ] optional = true default-features = false package = "rustls" [dependencies.rustls23] version = "0.23.0" features = [ "ring", "logging", "std", "tls12", ] optional = true default-features = false package = "rustls" [dependencies.serde] version = "1.0.145" features = ["derive"] [dependencies.tracing] version = "0.1.37" [dependencies.zeroize] version = "1.7" [dev-dependencies.serde_test] version = "1.0.176" [features] __internal-api = [] __internal-fuzz = [ "arbitrary", "__internal-api", ] __internal-test = ["__internal-api"] default = ["rustls23"] ntpv5 = [] nts-pool = ["rustls23"] rustls21 = [ "dep:rustls21", "dep:rustls-pemfile1", "dep:rustls-native-certs6", ] rustls22 = [ "dep:rustls22", 
"dep:rustls-pemfile2", "dep:rustls-native-certs7", "dep:rustls-pki-types", ] rustls23 = [ "dep:rustls23", "dep:rustls-pemfile2", "dep:rustls-native-certs7", ] ntp-proto-1.4.0/Cargo.toml.orig000064400000000000000000000031331046102023000144570ustar 00000000000000[package] name = "ntp-proto" description = "ntpd-rs packet parsing and algorithms" readme = "README.md" version.workspace = true edition.workspace = true license.workspace = true repository.workspace = true homepage.workspace = true publish.workspace = true rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] default = ["rustls23"] __internal-fuzz = ["arbitrary", "__internal-api"] __internal-test = ["__internal-api"] __internal-api = [] ntpv5 = [] nts-pool = [ "rustls23" ] rustls23 = [ "dep:rustls23", "dep:rustls-pemfile2", "dep:rustls-native-certs7" ] rustls22 = [ "dep:rustls22", "dep:rustls-pemfile2", "dep:rustls-native-certs7", "dep:rustls-pki-types" ] rustls21 = [ "dep:rustls21", "dep:rustls-pemfile1", "dep:rustls-native-certs6" ] [dependencies] # Note: md5 is needed to calculate ReferenceIDs for IPv6 addresses per RFC5905 md-5.workspace = true rand.workspace = true tracing.workspace = true serde.workspace = true arbitrary = { workspace = true, optional = true } rustls23 = { workspace = true, optional = true } rustls22 = { workspace = true, optional = true } rustls21 = { workspace = true, optional = true } rustls-pki-types = { workspace = true, optional = true } rustls-pemfile2 = { workspace = true, optional = true } rustls-pemfile1 = { workspace = true, optional = true } rustls-native-certs6 = { workspace = true, optional = true } rustls-native-certs7 = { workspace = true, optional = true } aead.workspace = true aes-siv.workspace = true zeroize.workspace = true [dev-dependencies] serde_test.workspace = true ntp-proto-1.4.0/LICENSE-APACHE000064400000000000000000000227731046102023000135270ustar 00000000000000 Apache 
License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS ntp-proto-1.4.0/LICENSE-MIT000064400000000000000000000021201046102023000132170ustar 00000000000000Copyright (c) 2022-2024 Trifecta Tech Foundation, Tweede Golf, and Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ntp-proto-1.4.0/README.md000064400000000000000000000005741046102023000130550ustar 00000000000000# ntp-proto This crate contains packet parsing and algorithm code for ntpd-rs and is not intended as a public interface at this time. It follows the same version as the main ntpd-rs crate, but that version is not intended to give any stability guarantee. Use at your own risk. Please visit the [ntpd-rs](https://github.com/pendulum-project/ntpd-rs) project for more information. 
ntp-proto-1.4.0/src/algorithm/kalman/combiner.rs000064400000000000000000000231621046102023000177600ustar 00000000000000use crate::{packet::NtpLeapIndicator, time_types::NtpDuration}; use super::{config::AlgorithmConfig, source::KalmanState, SourceSnapshot}; pub(super) struct Combine { pub estimate: KalmanState, pub sources: Vec, pub delay: NtpDuration, pub leap_indicator: Option, } fn vote_leap(selection: &[SourceSnapshot]) -> Option { let mut votes_59 = 0; let mut votes_61 = 0; let mut votes_none = 0; for snapshot in selection { match snapshot.leap_indicator { NtpLeapIndicator::NoWarning => votes_none += 1, NtpLeapIndicator::Leap61 => votes_61 += 1, NtpLeapIndicator::Leap59 => votes_59 += 1, NtpLeapIndicator::Unknown => { panic!("Unsynchronized source selected for synchronization!") } } } if votes_none * 2 > selection.len() { Some(NtpLeapIndicator::NoWarning) } else if votes_59 * 2 > selection.len() { Some(NtpLeapIndicator::Leap59) } else if votes_61 * 2 > selection.len() { Some(NtpLeapIndicator::Leap61) } else { None } } pub(super) fn combine( selection: &[SourceSnapshot], algo_config: &AlgorithmConfig, ) -> Option> { selection.first().map(|first| { let mut estimate = first.state; if !algo_config.ignore_server_dispersion { estimate = estimate.add_server_dispersion(first.source_uncertainty.to_seconds()) } let mut used_sources = vec![(first.index, estimate.uncertainty.determinant())]; for snapshot in selection.iter().skip(1) { let source_estimate = if algo_config.ignore_server_dispersion { snapshot.state } else { snapshot .state .add_server_dispersion(snapshot.source_uncertainty.to_seconds()) }; used_sources.push((snapshot.index, source_estimate.uncertainty.determinant())); estimate = estimate.merge(&source_estimate); } used_sources.sort_by(|a, b| a.1.total_cmp(&b.1)); Combine { estimate, sources: used_sources.iter().map(|v| v.0).collect(), delay: selection .iter() .map(|v| NtpDuration::from_seconds(v.delay) + v.source_delay) .min() 
.unwrap_or(NtpDuration::from_seconds(first.delay) + first.source_delay), leap_indicator: vote_leap(selection), } }) } #[cfg(test)] mod tests { use crate::{ algorithm::kalman::{ matrix::{Matrix, Vector}, source::KalmanState, }, time_types::NtpTimestamp, }; use super::*; fn snapshot_for_state( state: Vector<2>, uncertainty: Matrix<2, 2>, source_uncertainty: f64, ) -> SourceSnapshot { SourceSnapshot { index: 0, state: KalmanState { state, uncertainty, time: NtpTimestamp::from_fixed_int(0), }, wander: 0.0, delay: 0.0, source_uncertainty: NtpDuration::from_seconds(source_uncertainty), source_delay: NtpDuration::from_seconds(0.01), leap_indicator: NtpLeapIndicator::NoWarning, last_update: NtpTimestamp::from_fixed_int(0), } } #[test] fn test_none() { let selected: Vec> = vec![]; let algconfig = AlgorithmConfig::default(); assert!(combine(&selected, &algconfig).is_none()); } #[test] fn test_single() { let selected = vec![snapshot_for_state( Vector::new_vector([0.0, 0.0]), Matrix::new([[1e-6, 0.0], [0.0, 1e-12]]), 1e-3, )]; let algconfig = AlgorithmConfig { ..Default::default() }; let result = combine(&selected, &algconfig).unwrap(); assert!((result.estimate.offset_variance() - 2e-6).abs() < 1e-12); let algconfig = AlgorithmConfig { ignore_server_dispersion: true, ..Default::default() }; let result = combine(&selected, &algconfig).unwrap(); assert!((result.estimate.offset_variance() - 1e-6).abs() < 1e-12); } #[test] fn test_multiple() { let selected = vec![ snapshot_for_state( Vector::new_vector([0.0, 0.0]), Matrix::new([[1e-6, 0.0], [0.0, 1e-12]]), 1e-3, ), snapshot_for_state( Vector::new_vector([1e-3, 0.0]), Matrix::new([[1e-6, 0.0], [0.0, 1e-12]]), 1e-3, ), ]; let algconfig = AlgorithmConfig { ..Default::default() }; let result = combine(&selected, &algconfig).unwrap(); assert!((result.estimate.offset() - 5e-4).abs() < 1e-8); assert!(result.estimate.frequency().abs() < 1e-8); assert!((result.estimate.offset_variance() - 1e-6).abs() < 1e-12); 
assert!((result.estimate.frequency_variance() - 5e-13).abs() < 1e-16); let algconfig = AlgorithmConfig { ignore_server_dispersion: true, ..Default::default() }; let result = combine(&selected, &algconfig).unwrap(); assert!((result.estimate.offset() - 5e-4).abs() < 1e-8); assert!(result.estimate.frequency().abs() < 1e-8); assert!((result.estimate.offset_variance() - 5e-7).abs() < 1e-12); assert!((result.estimate.frequency_variance() - 5e-13).abs() < 1e-16); } #[test] fn test_sort_order() { let mut selected = vec![ snapshot_for_state( Vector::new_vector([0.0, 0.0]), Matrix::new([[1e-6, 0.0], [0.0, 1e-12]]), 1e-3, ), snapshot_for_state( Vector::new_vector([1e-3, 0.0]), Matrix::new([[2e-6, 0.0], [0.0, 2e-12]]), 1e-3, ), ]; selected[0].index = 0; selected[1].index = 1; let algconfig = AlgorithmConfig { ..Default::default() }; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.sources, vec![0, 1]); let mut selected = vec![ snapshot_for_state( Vector::new_vector([1e-3, 0.0]), Matrix::new([[2e-6, 0.0], [0.0, 2e-12]]), 1e-3, ), snapshot_for_state( Vector::new_vector([0.0, 0.0]), Matrix::new([[1e-6, 0.0], [0.0, 1e-12]]), 1e-3, ), ]; selected[0].index = 0; selected[1].index = 1; let algconfig = AlgorithmConfig { ..Default::default() }; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.sources, vec![1, 0]); } fn snapshot_for_leap(leap: NtpLeapIndicator) -> SourceSnapshot { SourceSnapshot { index: 0, state: KalmanState { state: Vector::new_vector([0.0, 0.0]), uncertainty: Matrix::new([[1e-6, 0.0], [0.0, 1e-12]]), time: NtpTimestamp::from_fixed_int(0), }, wander: 0.0, delay: 0.0, source_uncertainty: NtpDuration::from_seconds(0.0), source_delay: NtpDuration::from_seconds(0.0), leap_indicator: leap, last_update: NtpTimestamp::from_fixed_int(0), } } #[test] fn test_leap_vote() { let algconfig = AlgorithmConfig::default(); let selected = vec![ snapshot_for_leap(NtpLeapIndicator::NoWarning), snapshot_for_leap(NtpLeapIndicator::NoWarning), 
snapshot_for_leap(NtpLeapIndicator::NoWarning), ]; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.leap_indicator, Some(NtpLeapIndicator::NoWarning)); let selected = vec![ snapshot_for_leap(NtpLeapIndicator::Leap59), snapshot_for_leap(NtpLeapIndicator::Leap59), snapshot_for_leap(NtpLeapIndicator::Leap59), ]; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.leap_indicator, Some(NtpLeapIndicator::Leap59)); let selected = vec![ snapshot_for_leap(NtpLeapIndicator::Leap61), snapshot_for_leap(NtpLeapIndicator::Leap61), snapshot_for_leap(NtpLeapIndicator::Leap61), ]; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.leap_indicator, Some(NtpLeapIndicator::Leap61)); let selected = vec![ snapshot_for_leap(NtpLeapIndicator::Leap61), snapshot_for_leap(NtpLeapIndicator::Leap59), ]; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.leap_indicator, None); let selected = vec![ snapshot_for_leap(NtpLeapIndicator::NoWarning), snapshot_for_leap(NtpLeapIndicator::Leap61), snapshot_for_leap(NtpLeapIndicator::Leap61), ]; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.leap_indicator, Some(NtpLeapIndicator::Leap61)); let selected = vec![ snapshot_for_leap(NtpLeapIndicator::NoWarning), snapshot_for_leap(NtpLeapIndicator::Leap59), snapshot_for_leap(NtpLeapIndicator::Leap61), ]; let result = combine(&selected, &algconfig).unwrap(); assert_eq!(result.leap_indicator, None); } } ntp-proto-1.4.0/src/algorithm/kalman/config.rs000064400000000000000000000205161046102023000174270ustar 00000000000000use serde::Deserialize; use crate::time_types::NtpDuration; #[derive(Debug, Copy, Clone, Deserialize)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub struct AlgorithmConfig { /// Probability bound below which we start moving towards decreasing /// our precision estimate. 
(probability, 0-1) #[serde(default = "default_precision_low_probability")] pub precision_low_probability: f64, /// Probability bound above which we start moving towards increasing /// our precision estimate. (probability, 0-1) #[serde(default = "default_precision_high_probability")] pub precision_high_probability: f64, /// Amount of hysteresis in changing the precision estimate. (count, 1+) #[serde(default = "default_precision_hysteresis")] pub precision_hysteresis: i32, /// Lower bound on the amount of effect our precision estimate /// has on the total noise estimate before we allow decreasing /// of the precision estimate. (weight, 0-1) #[serde(default = "default_precision_minimum_weight")] pub precision_minimum_weight: f64, /// Amount which a measurement contributes to the state, below /// which we start increasing the poll interval. (weight, 0-1) #[serde(default = "default_poll_interval_low_weight")] pub poll_interval_low_weight: f64, /// Amount which a measurement contributes to the state, above /// which we start decreasing the poll_interval interval. (weight, 0-1) #[serde(default = "default_poll_interval_high_weight")] pub poll_interval_high_weight: f64, /// Amount of hysteresis in changing the poll interval (count, 1+) #[serde(default = "default_poll_interval_hysteresis")] pub poll_interval_hysteresis: i32, /// Probability threshold for when a measurement is considered a /// significant enough outlier that we decide something weird is /// going on and we need to do more measurements. (probability, 0-1) #[serde(default = "default_poll_interval_step_threshold")] pub poll_interval_step_threshold: f64, /// Threshold (in number of standard deviations) above which /// measurements with a significantly larger network delay /// are rejected. (standard deviations, 0+) #[serde(default = "default_delay_outlier_threshold")] pub delay_outlier_threshold: f64, /// Initial estimate of the clock wander of the combination /// of our local clock and that of the source. 
(s/s^2) #[serde(default = "default_initial_wander")] pub initial_wander: f64, /// Initial uncertainty of the frequency difference between /// our clock and that of the source. (s/s) #[serde(default = "default_initial_frequency_uncertainty")] pub initial_frequency_uncertainty: f64, /// Maximum source uncertainty before we start disregarding it /// Note that this is combined uncertainty due to noise and /// possible asymmetry error (see also weights below). (seconds) #[serde(default = "default_maximum_source_uncertainty")] pub maximum_source_uncertainty: f64, /// Weight of statistical uncertainty when constructing /// overlap ranges. (standard deviations, 0+) #[serde(default = "default_range_statistical_weight")] pub range_statistical_weight: f64, /// Weight of delay uncertainty when constructing overlap /// ranges. (weight, 0-1) #[serde(default = "default_range_delay_weight")] pub range_delay_weight: f64, /// How far from 0 (in multiples of the uncertainty) should /// the offset be before we correct. (standard deviations, 0+) #[serde(default = "default_steer_offset_threshold")] pub steer_offset_threshold: f64, /// How many standard deviations do we leave after offset /// correction? (standard deviations, 0+) #[serde(default = "default_steer_offset_leftover")] pub steer_offset_leftover: f64, /// How far from 0 (in multiples of the uncertainty) should /// the frequency estimate be before we correct. (standard deviations, 0+) #[serde(default = "default_steer_frequency_threshold")] pub steer_frequency_threshold: f64, /// How many standard deviations do we leave after frequency /// correction? (standard deviations, 0+) #[serde(default = "default_steer_frequency_leftover")] pub steer_frequency_leftover: f64, /// From what offset should we step the clock instead of /// trying to adjust gradually? 
(seconds, 0+) #[serde(default = "default_step_threshold")] pub step_threshold: f64, /// What is the maximum frequency offset during a slew (s/s) #[serde(default = "default_slew_maximum_frequency_offset")] pub slew_maximum_frequency_offset: f64, /// What is the minimum duration of a slew (s) #[serde(default = "default_slew_minimum_duration")] pub slew_minimum_duration: f64, /// Absolute maximum frequency correction (s/s) #[serde(default = "default_maximum_frequency_steer")] pub maximum_frequency_steer: f64, /// Ignore a servers advertised dispersion when synchronizing. /// Can improve synchronization quality with servers reporting /// overly conservative root dispersion. #[serde(default)] pub ignore_server_dispersion: bool, /// Threshold for detecting external clock meddling #[serde(default = "default_meddling_threshold")] pub meddling_threshold: NtpDuration, } impl Default for AlgorithmConfig { fn default() -> Self { Self { precision_low_probability: default_precision_low_probability(), precision_high_probability: default_precision_high_probability(), precision_hysteresis: default_precision_hysteresis(), precision_minimum_weight: default_precision_minimum_weight(), poll_interval_low_weight: default_poll_interval_low_weight(), poll_interval_high_weight: default_poll_interval_high_weight(), poll_interval_hysteresis: default_poll_interval_hysteresis(), poll_interval_step_threshold: default_poll_interval_step_threshold(), delay_outlier_threshold: default_delay_outlier_threshold(), initial_wander: default_initial_wander(), initial_frequency_uncertainty: default_initial_frequency_uncertainty(), maximum_source_uncertainty: default_maximum_source_uncertainty(), range_statistical_weight: default_range_statistical_weight(), range_delay_weight: default_range_delay_weight(), steer_offset_threshold: default_steer_offset_threshold(), steer_offset_leftover: default_steer_offset_leftover(), steer_frequency_threshold: default_steer_frequency_threshold(), steer_frequency_leftover: 
default_steer_frequency_leftover(), step_threshold: default_step_threshold(), slew_maximum_frequency_offset: default_slew_maximum_frequency_offset(), slew_minimum_duration: default_slew_minimum_duration(), maximum_frequency_steer: default_maximum_frequency_steer(), ignore_server_dispersion: false, meddling_threshold: default_meddling_threshold(), } } } fn default_precision_low_probability() -> f64 { 1. / 3. } fn default_precision_high_probability() -> f64 { 2. / 3. } fn default_precision_hysteresis() -> i32 { 16 } fn default_precision_minimum_weight() -> f64 { 0.1 } fn default_poll_interval_low_weight() -> f64 { 0.4 } fn default_poll_interval_high_weight() -> f64 { 0.6 } fn default_poll_interval_hysteresis() -> i32 { 16 } fn default_poll_interval_step_threshold() -> f64 { 1e-6 } fn default_delay_outlier_threshold() -> f64 { 5. } fn default_initial_wander() -> f64 { 1e-8 } fn default_initial_frequency_uncertainty() -> f64 { 100e-6 } fn default_maximum_source_uncertainty() -> f64 { 0.250 } fn default_range_statistical_weight() -> f64 { 2. } fn default_range_delay_weight() -> f64 { 0.25 } fn default_steer_offset_threshold() -> f64 { 2.0 } fn default_steer_offset_leftover() -> f64 { 1.0 } fn default_steer_frequency_threshold() -> f64 { 0.0 } fn default_steer_frequency_leftover() -> f64 { 0.0 } fn default_step_threshold() -> f64 { 0.010 } fn default_slew_maximum_frequency_offset() -> f64 { 200e-6 } fn default_maximum_frequency_steer() -> f64 { 495e-6 } fn default_slew_minimum_duration() -> f64 { 8.0 } fn default_meddling_threshold() -> NtpDuration { NtpDuration::from_seconds(5.) 
} ntp-proto-1.4.0/src/algorithm/kalman/matrix.rs000064400000000000000000000137251046102023000174720ustar 00000000000000use std::ops::{Add, Mul, Sub}; #[derive(Debug, Clone, Copy, PartialEq)] pub struct Matrix { data: [[f64; M]; N], } pub type Vector = Matrix; impl Matrix { pub fn new(data: [[f64; M]; N]) -> Self { Matrix { data } } pub fn transpose(self) -> Matrix { Matrix { data: std::array::from_fn(|i| std::array::from_fn(|j| self.data[j][i])), } } pub fn entry(&self, i: usize, j: usize) -> f64 { assert!(i < N && j < M); self.data[i][j] } } impl Vector { pub fn new_vector(data: [f64; N]) -> Self { Self { data: std::array::from_fn(|i| std::array::from_fn(|_| data[i])), } } pub fn ventry(&self, i: usize) -> f64 { self.data[i][0] } pub fn inner(&self, rhs: Vector) -> f64 { (0..N).map(|i| self.data[i][0] * rhs.data[i][0]).sum() } } impl Matrix { pub fn symmetrize(self) -> Self { Matrix { data: std::array::from_fn(|i| { std::array::from_fn(|j| (self.data[i][j] + self.data[j][i]) / 2.) }), } } pub fn unit() -> Self { Matrix { data: std::array::from_fn(|i| std::array::from_fn(|j| if i == j { 1.0 } else { 0.0 })), } } } impl Matrix<1, 1> { pub fn inverse(self) -> Self { Matrix { data: [[1. / self.data[0][0]]], } } pub fn determinant(self) -> f64 { self.data[0][0] } } impl Matrix<2, 2> { pub fn inverse(self) -> Self { let d = 1. 
/ (self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0]); Matrix { data: [ [d * self.data[1][1], -d * self.data[0][1]], [-d * self.data[1][0], d * self.data[0][0]], ], } } pub fn determinant(self) -> f64 { self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0] } } impl std::fmt::Display for Matrix { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for i in 0..N { for j in 0..M { if j != 0 { f.write_str(" ")?; } f.write_fmt(format_args!("{:>14.10}", self.data[i][j]))?; } if i != N - 1 { f.write_str("\n")?; } } Ok(()) } } impl Mul> for Matrix { type Output = Matrix; fn mul(self, rhs: Matrix) -> Self::Output { Matrix { data: std::array::from_fn(|i| { std::array::from_fn(|j| (0..K).map(|k| self.data[i][k] * rhs.data[k][j]).sum()) }), } } } impl Mul> for f64 { type Output = Matrix; fn mul(self, rhs: Matrix) -> Self::Output { Matrix { data: std::array::from_fn(|i| std::array::from_fn(|j| self * rhs.data[i][j])), } } } impl Add> for Matrix { type Output = Matrix; fn add(self, rhs: Matrix) -> Self::Output { Matrix { data: std::array::from_fn(|i| { std::array::from_fn(|j| self.data[i][j] + rhs.data[i][j]) }), } } } impl Sub> for Matrix { type Output = Matrix; fn sub(self, rhs: Matrix) -> Self::Output { Matrix { data: std::array::from_fn(|i| { std::array::from_fn(|j| self.data[i][j] - rhs.data[i][j]) }), } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_matrix_mul() { let a = Matrix::new([[1., 2.], [3., 4.]]); let b = Matrix::new([[5., 6.], [7., 8.]]); let c = Matrix::new([[19., 22.], [43., 50.]]); assert_eq!(c, a * b); } #[test] fn test_matrix_vector_mul() { let a = Matrix::new([[1., 2.], [3., 4.]]); let b = Vector::new_vector([5., 6.]); let c = Vector::new_vector([17., 39.]); assert_eq!(c, a * b); } #[test] fn test_matrix_inverse() { let a = Matrix::new([[1., 1.], [1., 2.]]); let b = a.inverse(); assert_eq!(a * b, Matrix::unit()); } #[test] fn test_matrix_transpose() { let a = Matrix::new([[1., 1.], [0., 1.]]); 
let b = Matrix::new([[1., 0.], [1., 1.]]); assert_eq!(a.transpose(), b); assert_eq!(b.transpose(), a); } #[test] fn test_matrix_add() { let a = Matrix::new([[1., 0.], [0., 1.]]); let b = Matrix::new([[0., -1.], [-1., 0.]]); let c = Matrix::new([[1., -1.], [-1., 1.]]); assert_eq!(a + b, c); } #[test] fn test_matrix_sub() { let a = Matrix::new([[1., 0.], [0., 1.]]); let b = Matrix::new([[0., 1.], [1., 0.]]); let c = Matrix::new([[1., -1.], [-1., 1.]]); assert_eq!(a - b, c); } #[test] fn test_vector_add() { let a = Vector::new_vector([1., 0.]); let b = Vector::new_vector([0., -1.]); let c = Vector::new_vector([1., -1.]); assert_eq!(a + b, c); } #[test] fn test_vector_sub() { let a = Vector::new_vector([1., 0.]); let b = Vector::new_vector([0., 1.]); let c = Vector::new_vector([1., -1.]); assert_eq!(a - b, c); } #[test] fn test_matrix_rendering() { let a = Matrix::new([[1.0, 2.0], [3.0, 4.0]]); assert_eq!( format!("{a}"), " 1.0000000000 2.0000000000\n 3.0000000000 4.0000000000" ); } #[test] fn test_vector_rendering() { let a = Vector::new_vector([5.0, 6.0]); assert_eq!(format!("{a}"), " 5.0000000000\n 6.0000000000"); } } ntp-proto-1.4.0/src/algorithm/kalman/mod.rs000064400000000000000000000670211046102023000167430ustar 00000000000000use std::{collections::HashMap, fmt::Debug, hash::Hash, time::Duration}; pub(crate) use source::AveragingBuffer; use source::OneWayKalmanSourceController; use tracing::{debug, error, info}; use crate::{ clock::NtpClock, config::{SourceDefaultsConfig, SynchronizationConfig}, packet::NtpLeapIndicator, system::TimeSnapshot, time_types::{NtpDuration, NtpTimestamp}, }; use self::{combiner::combine, config::AlgorithmConfig, source::KalmanState}; use super::{ObservableSourceTimedata, StateUpdate, TimeSyncController}; mod combiner; pub(super) mod config; mod matrix; mod select; mod source; pub use source::{KalmanSourceController, TwoWayKalmanSourceController}; fn sqr(x: f64) -> f64 { x * x } #[derive(Debug, Clone, Copy)] struct SourceSnapshot { 
index: Index, state: KalmanState, wander: f64, delay: f64, source_uncertainty: NtpDuration, source_delay: NtpDuration, leap_indicator: NtpLeapIndicator, last_update: NtpTimestamp, } impl SourceSnapshot { fn offset(&self) -> f64 { self.state.offset() } fn offset_uncertainty(&self) -> f64 { self.state.offset_variance().sqrt() } fn observe(&self) -> ObservableSourceTimedata { ObservableSourceTimedata { offset: NtpDuration::from_seconds(self.offset()), uncertainty: NtpDuration::from_seconds(self.offset_uncertainty()), delay: NtpDuration::from_seconds(self.delay), remote_delay: self.source_delay, remote_uncertainty: self.source_uncertainty, last_update: self.last_update, } } } #[derive(Debug, Clone)] pub struct KalmanControllerMessage { inner: KalmanControllerMessageInner, } #[derive(Debug, Clone)] enum KalmanControllerMessageInner { Step { steer: f64 }, FreqChange { steer: f64, time: NtpTimestamp }, } #[derive(Debug, Clone, Copy)] pub struct KalmanSourceMessage { inner: SourceSnapshot, } #[derive(Debug, Clone)] pub struct KalmanClockController { sources: HashMap>, bool)>, clock: C, synchronization_config: SynchronizationConfig, source_defaults_config: SourceDefaultsConfig, algo_config: AlgorithmConfig, freq_offset: f64, timedata: TimeSnapshot, desired_freq: f64, in_startup: bool, } impl KalmanClockController { fn update_clock( &mut self, time: NtpTimestamp, ) -> StateUpdate { // ensure all filters represent the same (current) time if self .sources .iter() .filter_map(|(_, (state, _))| state.map(|v| v.state.time)) .any(|sourcetime| time - sourcetime < NtpDuration::ZERO) { return StateUpdate { source_message: None, used_sources: None, time_snapshot: Some(self.timedata), next_update: None, }; } for (_, (state, _)) in self.sources.iter_mut() { if let Some(ref mut snapshot) = state { snapshot.state = snapshot.state.progress_time(time, snapshot.wander) } } let selection = select::select( &self.synchronization_config, &self.algo_config, self.sources .iter() .filter_map( |(_, 
(state, usable))| { if *usable { state.as_ref() } else { None } }, ) .cloned() .collect(), ); if let Some(combined) = combine(&selection, &self.algo_config) { info!( "Offset: {}+-{}ms, frequency: {}+-{}ppm", combined.estimate.offset() * 1e3, combined.estimate.offset_variance().sqrt() * 1e3, combined.estimate.frequency() * 1e6, combined.estimate.frequency_variance().sqrt() * 1e6 ); if self.in_startup { self.clock .disable_ntp_algorithm() .expect("Cannot update clock"); } let freq_delta = combined.estimate.frequency() - self.desired_freq; let freq_uncertainty = combined.estimate.frequency_variance().sqrt(); let offset_delta = combined.estimate.offset(); let offset_uncertainty = combined.estimate.offset_variance().sqrt(); let next_update = if self.desired_freq == 0.0 && offset_delta.abs() > offset_uncertainty * self.algo_config.steer_offset_threshold { // Note: because of threshold effects, offset_delta is likely an extreme estimate // at this point. Hence we only correct it partially in order to avoid // overcorrecting. // The same does not apply to freq_delta, so if we start slewing // it can be fully corrected without qualms. self.steer_offset( offset_delta - offset_uncertainty * self.algo_config.steer_offset_leftover * offset_delta.signum(), freq_delta, ) } else if freq_delta.abs() > freq_uncertainty * self.algo_config.steer_frequency_threshold { // Note: because of threshold effects, freq_delta is likely an extreme estimate // at this point. Hence we only correct it partially in order to avoid // overcorrecting. 
self.steer_frequency( freq_delta - freq_uncertainty * self.algo_config.steer_frequency_leftover * freq_delta.signum(), ) } else { StateUpdate::default() }; self.timedata.root_delay = combined.delay; self.timedata.root_dispersion = NtpDuration::from_seconds(combined.estimate.offset_variance().sqrt()); self.clock .error_estimate_update(self.timedata.root_dispersion, self.timedata.root_delay) .expect("Cannot update clock"); if let Some(leap) = combined.leap_indicator { self.clock.status_update(leap).expect("Cannot update clock"); self.timedata.leap_indicator = leap; } // After a successful measurement we are out of startup. self.in_startup = false; StateUpdate { used_sources: Some(combined.sources), time_snapshot: Some(self.timedata), ..next_update } } else { info!("No consensus on current time"); StateUpdate { time_snapshot: Some(self.timedata), ..StateUpdate::default() } } } fn check_offset_steer(&mut self, change: f64) { let change = NtpDuration::from_seconds(change); if self.in_startup { if !self .synchronization_config .startup_step_panic_threshold .is_within(change) { error!("Unusually large clock step suggested, please manually verify system clock and reference clock state and restart if appropriate. If the clock is significantly wrong, you can use `ntp-ctl force-sync` to correct it."); #[cfg(not(test))] std::process::exit(crate::exitcode::SOFTWARE); #[cfg(test)] panic!("Threshold exceeded"); } } else { self.timedata.accumulated_steps += change.abs(); if !self .synchronization_config .single_step_panic_threshold .is_within(change) || self .synchronization_config .accumulated_step_panic_threshold .map(|v| self.timedata.accumulated_steps > v) .unwrap_or(false) { error!("Unusually large clock step suggested, please manually verify system clock and reference clock state and restart if appropriate. 
If the clock is significantly wrong, you can use `ntp-ctl force-sync` to correct it."); #[cfg(not(test))] std::process::exit(crate::exitcode::SOFTWARE); #[cfg(test)] panic!("Threshold exceeded"); } } } fn steer_offset( &mut self, change: f64, freq_delta: f64, ) -> StateUpdate { if change.abs() > self.algo_config.step_threshold { // jump self.check_offset_steer(change); self.clock .step_clock(NtpDuration::from_seconds(change)) .expect("Cannot adjust clock"); for (state, _) in self.sources.values_mut() { if let Some(ref mut state) = state { state.state = state.state.process_offset_steering(change); } } info!("Jumped offset by {}ms", change * 1e3); StateUpdate { source_message: Some(KalmanControllerMessage { inner: KalmanControllerMessageInner::Step { steer: change }, }), ..StateUpdate::default() } } else { // start slew let freq = self .algo_config .slew_maximum_frequency_offset .min(change.abs() / self.algo_config.slew_minimum_duration); let duration = Duration::from_secs_f64(change.abs() / freq); debug!( "Slewing by {}ms over {}s", change * 1e3, duration.as_secs_f64(), ); let update = self.change_desired_frequency(-freq * change.signum(), freq_delta); StateUpdate { next_update: Some(duration), ..update } } } fn change_desired_frequency( &mut self, new_freq: f64, freq_delta: f64, ) -> StateUpdate { let change = self.desired_freq - new_freq + freq_delta; self.desired_freq = new_freq; self.steer_frequency(change) } fn steer_frequency(&mut self, change: f64) -> StateUpdate { let new_freq_offset = ((1.0 + self.freq_offset) * (1.0 + change) - 1.0).clamp( -self.algo_config.maximum_frequency_steer, self.algo_config.maximum_frequency_steer, ); let actual_change = (1.0 + new_freq_offset) / (1.0 + self.freq_offset) - 1.0; self.freq_offset = new_freq_offset; let freq_update = self .clock .set_frequency(self.freq_offset) .expect("Cannot adjust clock"); for (state, _) in self.sources.values_mut() { if let Some(ref mut state) = state { state.state = state .state 
.process_frequency_steering(freq_update, actual_change, state.wander) } } debug!( "Changed frequency, current steer {}ppm, desired freq {}ppm", self.freq_offset * 1e6, self.desired_freq * 1e6, ); StateUpdate { source_message: Some(KalmanControllerMessage { inner: KalmanControllerMessageInner::FreqChange { steer: actual_change, time: freq_update, }, }), ..StateUpdate::default() } } } impl TimeSyncController for KalmanClockController { type Clock = C; type SourceId = SourceId; type AlgorithmConfig = AlgorithmConfig; type ControllerMessage = KalmanControllerMessage; type SourceMessage = KalmanSourceMessage; type NtpSourceController = TwoWayKalmanSourceController; type OneWaySourceController = OneWayKalmanSourceController; fn new( clock: C, synchronization_config: SynchronizationConfig, source_defaults_config: SourceDefaultsConfig, algo_config: Self::AlgorithmConfig, ) -> Result { // Setup clock let freq_offset = clock.get_frequency()?; Ok(KalmanClockController { sources: HashMap::new(), clock, synchronization_config, source_defaults_config, algo_config, freq_offset, desired_freq: 0.0, timedata: TimeSnapshot::default(), in_startup: true, }) } fn take_control(&mut self) -> Result<(), ::Error> { self.clock.disable_ntp_algorithm()?; self.clock.status_update(NtpLeapIndicator::Unknown)?; Ok(()) } fn add_source(&mut self, id: SourceId) -> Self::NtpSourceController { self.sources.insert(id, (None, false)); KalmanSourceController::new( id, self.algo_config, self.source_defaults_config, AveragingBuffer::default(), ) } fn add_one_way_source( &mut self, id: SourceId, measurement_noise_estimate: f64, ) -> Self::OneWaySourceController { self.sources.insert(id, (None, false)); KalmanSourceController::new( id, self.algo_config, self.source_defaults_config, measurement_noise_estimate, ) } fn remove_source(&mut self, id: SourceId) { self.sources.remove(&id); } fn source_update(&mut self, id: SourceId, usable: bool) { if let Some(state) = self.sources.get_mut(&id) { state.1 = usable; } 
} fn time_update(&mut self) -> StateUpdate { // End slew self.change_desired_frequency(0.0, 0.0) } fn source_message( &mut self, id: SourceId, message: Self::SourceMessage, ) -> StateUpdate { if let Some(source) = self.sources.get_mut(&id) { let time = message.inner.last_update; source.0 = Some(message.inner); self.update_clock(time) } else { error!("Internal error: Update from non-existing source"); StateUpdate::default() } } } #[cfg(test)] mod tests { use std::cell::RefCell; use matrix::{Matrix, Vector}; use crate::config::StepThreshold; use crate::source::Measurement; use crate::time_types::NtpInstant; use crate::SourceController; use super::*; #[derive(Debug, Clone)] struct TestClock { has_steered: RefCell, current_time: NtpTimestamp, } impl NtpClock for TestClock { type Error = std::io::Error; fn now(&self) -> Result { Ok(self.current_time) } fn set_frequency(&self, _freq: f64) -> Result { *self.has_steered.borrow_mut() = true; Ok(self.current_time) } fn get_frequency(&self) -> Result { Ok(0.0) } fn step_clock(&self, _offset: NtpDuration) -> Result { *self.has_steered.borrow_mut() = true; Ok(self.current_time) } fn disable_ntp_algorithm(&self) -> Result<(), Self::Error> { Ok(()) } fn error_estimate_update( &self, _est_error: NtpDuration, _maximum_error: NtpDuration, ) -> Result<(), Self::Error> { Ok(()) } fn status_update(&self, _leap_status: NtpLeapIndicator) -> Result<(), Self::Error> { Ok(()) } } #[test] fn test_startup_flag_unsets() { let synchronization_config = SynchronizationConfig { minimum_agreeing_sources: 1, ..SynchronizationConfig::default() }; let algo_config = AlgorithmConfig::default(); let source_defaults_config = SourceDefaultsConfig::default(); let mut algo = KalmanClockController::new( TestClock { has_steered: RefCell::new(false), current_time: NtpTimestamp::from_fixed_int(0), }, synchronization_config, source_defaults_config, algo_config, ) .unwrap(); let mut cur_instant = NtpInstant::now(); // ignore startup steer of frequency. 
*algo.clock.has_steered.borrow_mut() = false; let mut source = algo.add_source(0); algo.source_update(0, true); assert!(algo.in_startup); let mut noise = 1e-9; while !*algo.clock.has_steered.borrow() { cur_instant = cur_instant + std::time::Duration::from_secs(1); algo.clock.current_time += NtpDuration::from_seconds(1.0); noise += 1e-9; let message = source.handle_measurement(Measurement { delay: NtpDuration::from_seconds(0.001 + noise), offset: NtpDuration::from_seconds(1700.0 + noise), localtime: algo.clock.current_time, monotime: cur_instant, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }); if let Some(message) = message { let actions = algo.source_message(0, message); if let Some(source_message) = actions.source_message { source.handle_message(source_message); } } } assert!(!algo.in_startup); assert_eq!(algo.timedata.leap_indicator, NtpLeapIndicator::NoWarning); assert_ne!(algo.timedata.root_delay, NtpDuration::ZERO); assert_ne!(algo.timedata.root_dispersion, NtpDuration::ZERO); } #[test] fn slews_dont_accumulate() { let synchronization_config = SynchronizationConfig { minimum_agreeing_sources: 1, single_step_panic_threshold: StepThreshold { forward: None, backward: None, }, ..SynchronizationConfig::default() }; let algo_config = AlgorithmConfig { step_threshold: 1800.0, ..Default::default() }; let source_defaults_config = SourceDefaultsConfig::default(); let mut algo = KalmanClockController::<_, u32>::new( TestClock { has_steered: RefCell::new(false), current_time: NtpTimestamp::from_fixed_int(0), }, synchronization_config, source_defaults_config, algo_config, ) .unwrap(); algo.in_startup = false; algo.steer_offset(1000.0, 0.0); assert_eq!(algo.timedata.accumulated_steps, NtpDuration::ZERO); } #[test] #[should_panic] fn jumps_add_absolutely() { let synchronization_config = SynchronizationConfig { minimum_agreeing_sources: 1, single_step_panic_threshold: StepThreshold { 
forward: None, backward: None, }, accumulated_step_panic_threshold: Some(NtpDuration::from_seconds(1800.0)), ..SynchronizationConfig::default() }; let algo_config = AlgorithmConfig::default(); let source_defaults_config = SourceDefaultsConfig::default(); let mut algo = KalmanClockController::<_, u32>::new( TestClock { has_steered: RefCell::new(false), current_time: NtpTimestamp::from_fixed_int(0), }, synchronization_config, source_defaults_config, algo_config, ) .unwrap(); algo.in_startup = false; algo.steer_offset(1000.0, 0.0); algo.steer_offset(-1000.0, 0.0); } #[test] fn test_jumps_update_state() { let synchronization_config = SynchronizationConfig::default(); let algo_config = AlgorithmConfig::default(); let source_defaults_config = SourceDefaultsConfig::default(); let mut algo = KalmanClockController::<_, u32>::new( TestClock { has_steered: RefCell::new(false), current_time: NtpTimestamp::from_fixed_int(0), }, synchronization_config, source_defaults_config, algo_config, ) .unwrap(); algo.sources.insert( 0, ( Some(SourceSnapshot { index: 0, state: KalmanState { state: Vector::new_vector([0.0, 0.0]), uncertainty: Matrix::new([[1e-18, 0.0], [0.0, 1e-18]]), time: NtpTimestamp::from_fixed_int(0), }, wander: 0.0, delay: 0.0, source_uncertainty: NtpDuration::ZERO, source_delay: NtpDuration::ZERO, leap_indicator: NtpLeapIndicator::NoWarning, last_update: NtpTimestamp::from_fixed_int(0), }), true, ), ); algo.steer_offset(100.0, 0.0); assert_eq!( algo.sources.get(&0).unwrap().0.unwrap().state.offset(), -100.0 ); assert_eq!( algo.sources.get(&0).unwrap().0.unwrap().state.time, NtpTimestamp::from_seconds_nanos_since_ntp_era(100, 0) ); }

    /// A frequency steer applied by the controller must be reflected in the
    /// state of every tracked source: steering by `+1e-6` should move a
    /// source whose frequency estimate was `0.0` to (approximately) `-1e-6`.
    #[test]
    fn test_freqsteer_update_state() {
        let synchronization_config = SynchronizationConfig::default();
        let algo_config = AlgorithmConfig::default();
        let source_defaults_config = SourceDefaultsConfig::default();
        let mut algo = KalmanClockController::<_, u32>::new(
            TestClock {
                has_steered: RefCell::new(false),
                current_time: NtpTimestamp::from_fixed_int(0),
            },
            synchronization_config,
            source_defaults_config,
            algo_config,
        )
        .unwrap();

        // Register a single source with a zeroed state so the effect of the
        // steer can be read back directly.
        algo.sources.insert(
            0,
            (
                Some(SourceSnapshot {
                    index: 0,
                    state: KalmanState {
                        state: Vector::new_vector([0.0, 0.0]),
                        uncertainty: Matrix::new([[1e-18, 0.0], [0.0, 1e-18]]),
                        time: NtpTimestamp::from_fixed_int(0),
                    },
                    wander: 0.0,
                    delay: 0.0,
                    source_uncertainty: NtpDuration::ZERO,
                    source_delay: NtpDuration::ZERO,
                    leap_indicator: NtpLeapIndicator::NoWarning,
                    last_update: NtpTimestamp::from_fixed_int(0),
                }),
                true,
            ),
        );

        algo.steer_frequency(1e-6);

        // Compare the absolute deviation: a one-sided `x - expected < eps`
        // would also accept any value far below the expected one, so an
        // over-correction would go unnoticed.
        let frequency = algo.sources.get(&0).unwrap().0.unwrap().state.frequency();
        assert!((frequency - -1e-6).abs() < 1e-12);
    }

#[test] #[should_panic] fn test_large_offset_eventually_panics() { let synchronization_config = SynchronizationConfig { minimum_agreeing_sources: 1, ..SynchronizationConfig::default() }; let algo_config = AlgorithmConfig::default(); let source_defaults_config = SourceDefaultsConfig::default(); let mut algo = KalmanClockController::new( TestClock { has_steered: RefCell::new(false), current_time: NtpTimestamp::from_fixed_int(0), }, synchronization_config, source_defaults_config, algo_config, ) .unwrap(); let mut cur_instant = NtpInstant::now(); // ignore startup steer of frequency. 
*algo.clock.has_steered.borrow_mut() = false; let mut source = algo.add_source(0); algo.source_update(0, true); let mut noise = 1e-9; loop { cur_instant = cur_instant + std::time::Duration::from_secs(1); algo.clock.current_time += NtpDuration::from_seconds(1800.0); noise += 1e-9; let message = source.handle_measurement(Measurement { delay: NtpDuration::from_seconds(0.001 + noise), offset: NtpDuration::from_seconds(1700.0 + noise), localtime: algo.clock.current_time, monotime: cur_instant, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }); if let Some(message) = message { let actions = algo.source_message(0, message); if let Some(source_message) = actions.source_message { source.handle_message(source_message); } } } } #[test] #[should_panic] fn test_backward_step_panics_before_steer() { let synchronization_config = SynchronizationConfig { minimum_agreeing_sources: 1, startup_step_panic_threshold: StepThreshold { forward: None, backward: Some(NtpDuration::from_seconds(1800.)), }, ..SynchronizationConfig::default() }; let algo_config = AlgorithmConfig::default(); let source_defaults_config = SourceDefaultsConfig::default(); let mut algo = KalmanClockController::new( TestClock { has_steered: RefCell::new(false), current_time: NtpTimestamp::from_fixed_int(0), }, synchronization_config, source_defaults_config, algo_config, ) .unwrap(); let mut cur_instant = NtpInstant::now(); // ignore startup steer of frequency. 
*algo.clock.has_steered.borrow_mut() = false; let mut source = algo.add_source(0); algo.source_update(0, true); let mut noise = 1e-9; while !*algo.clock.has_steered.borrow() { cur_instant = cur_instant + std::time::Duration::from_secs(1); algo.clock.current_time += NtpDuration::from_seconds(1.0); noise *= -1.0; let message = source.handle_measurement(Measurement { delay: NtpDuration::from_seconds(0.001 + noise), offset: NtpDuration::from_seconds(-3600.0 + noise), localtime: algo.clock.current_time, monotime: cur_instant, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }); if let Some(message) = message { let actions = algo.source_message(0, message); if let Some(source_message) = actions.source_message { source.handle_message(source_message); } } } } } ntp-proto-1.4.0/src/algorithm/kalman/select.rs000064400000000000000000000206611046102023000174420ustar 00000000000000use crate::config::SynchronizationConfig; use super::{config::AlgorithmConfig, SourceSnapshot}; #[derive(Debug)] enum BoundType { Start, End, } // Select a maximum overlapping set of candidates. Note that here we define // overlapping to mean that any part of their confidence intervals overlaps, instead // of the NTP convention that all centers need to be within each others confidence // intervals. // The advantage of doing this is that the algorithm becomes a lot simpler, and it // is also statistically more sound. Any difference (larger set of accepted sources) // can be compensated for if desired by setting tighter bounds on the weights // determining the confidence interval. 
/// Pick the largest group of candidate sources whose confidence intervals
/// share a common point.
///
/// Each candidate gets an interval `offset ± radius`, where
/// `radius = offset_uncertainty * range_statistical_weight
///           + delay * range_delay_weight`.
/// A candidate is ignored entirely when its radius exceeds
/// `maximum_source_uncertainty` (or is NaN) or when its leap indicator
/// reports it as unsynchronized.
///
/// Returns the candidates whose interval contains the point of maximum
/// overlap, or an empty vector when that group is smaller than
/// `minimum_agreeing_sources` or is not a strict majority of the eligible
/// candidates.
pub(super) fn select<Index: Copy>(
    synchronization_config: &SynchronizationConfig,
    algo_config: &AlgorithmConfig,
    candidates: Vec<SourceSnapshot<Index>>,
) -> Vec<SourceSnapshot<Index>> {
    // Single source of truth for eligibility and interval half-width, used by
    // both the sweep and the final filter so the two passes cannot disagree.
    // `radius <= max` is false for NaN, so a source with an undefined
    // uncertainty contributes neither bounds nor majority count.
    let radius_of = |snapshot: &SourceSnapshot<Index>| -> Option<f64> {
        let radius = snapshot.offset_uncertainty() * algo_config.range_statistical_weight
            + snapshot.delay * algo_config.range_delay_weight;
        (radius <= algo_config.maximum_source_uncertainty
            && snapshot.leap_indicator.is_synchronized())
        .then_some(radius)
    };

    // Two bounds (interval start and end) per eligible candidate.
    let mut bounds: Vec<(f64, BoundType)> = Vec::with_capacity(2 * candidates.len());
    for snapshot in candidates.iter() {
        if let Some(radius) = radius_of(snapshot) {
            bounds.push((snapshot.offset() - radius, BoundType::Start));
            bounds.push((snapshot.offset() + radius, BoundType::End));
        }
    }

    bounds.sort_by(|a, b| a.0.total_cmp(&b.0));

    // Sweep over the sorted bounds, tracking how many intervals are open and
    // remembering the first point at which the highest count is reached.
    let mut max: usize = 0;
    let mut maxt: f64 = 0.0;
    let mut cur: usize = 0;
    for (time, boundtype) in bounds.iter() {
        match boundtype {
            BoundType::Start => cur += 1,
            BoundType::End => cur -= 1,
        }
        if cur > max {
            max = cur;
            maxt = *time;
        }
    }

    // `bounds` holds two entries per eligible candidate, so
    // `max * 4 > bounds.len()` means `max > eligible / 2`: the agreeing group
    // must be a strict majority, and an exact tie selects nothing.
    if max >= synchronization_config.minimum_agreeing_sources && max * 4 > bounds.len() {
        candidates
            .iter()
            .filter(|snapshot| {
                radius_of(*snapshot).map_or(false, |radius| {
                    snapshot.offset() - radius <= maxt && snapshot.offset() + radius >= maxt
                })
            })
            .cloned()
            .collect()
    } else {
        vec![]
    }
}

#[cfg(test)] mod tests { use crate::{ algorithm::kalman::source::KalmanState, packet::NtpLeapIndicator, time_types::{NtpDuration, NtpTimestamp}, }; use super::super::{ matrix::{Matrix, Vector}, sqr, }; use super::*; fn snapshot_for_range(center: f64, uncertainty: f64, delay: f64) -> SourceSnapshot { SourceSnapshot { index: 0, state: KalmanState { state: Vector::new_vector([center, 0.0]), uncertainty: Matrix::new([[sqr(uncertainty), 0.0], [0.0, 10e-12]]), time: NtpTimestamp::from_fixed_int(0), }, wander: 0.0, delay, source_uncertainty: NtpDuration::from_seconds(0.01), source_delay: NtpDuration::from_seconds(0.01), 
leap_indicator: NtpLeapIndicator::NoWarning, last_update: NtpTimestamp::from_fixed_int(0), } } #[test] fn test_weighing() { // Test that there only is sufficient overlap in the below set when // both statistical and delay based errors are considered. let candidates = vec![ snapshot_for_range(0.0, 0.01, 0.09), snapshot_for_range(0.0, 0.09, 0.01), snapshot_for_range(0.05, 0.01, 0.09), snapshot_for_range(0.05, 0.09, 0.01), ]; let sysconfig = SynchronizationConfig { minimum_agreeing_sources: 4, ..Default::default() }; let algconfig = AlgorithmConfig { maximum_source_uncertainty: 1.0, range_statistical_weight: 1.0, range_delay_weight: 0.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates.clone()); assert_eq!(result.len(), 0); let algconfig = AlgorithmConfig { maximum_source_uncertainty: 1.0, range_statistical_weight: 0.0, range_delay_weight: 1.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates.clone()); assert_eq!(result.len(), 0); let algconfig = AlgorithmConfig { maximum_source_uncertainty: 1.0, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates); assert_eq!(result.len(), 4); } #[test] fn test_rejection() { // Test sources get properly rejected as rejection bound gets tightened. 
let candidates = vec![ snapshot_for_range(0.0, 1.0, 1.0), snapshot_for_range(0.0, 0.1, 0.1), snapshot_for_range(0.0, 0.01, 0.01), ]; let sysconfig = SynchronizationConfig { minimum_agreeing_sources: 1, ..Default::default() }; let algconfig = AlgorithmConfig { maximum_source_uncertainty: 3.0, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates.clone()); assert_eq!(result.len(), 3); let algconfig = AlgorithmConfig { maximum_source_uncertainty: 0.3, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates.clone()); assert_eq!(result.len(), 2); let algconfig = AlgorithmConfig { maximum_source_uncertainty: 0.03, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates.clone()); assert_eq!(result.len(), 1); let algconfig = AlgorithmConfig { maximum_source_uncertainty: 0.003, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates); assert_eq!(result.len(), 0); } #[test] fn test_min_survivors() { // Test that minimum number of survivors is correctly tested for. 
let candidates = vec![ snapshot_for_range(0.0, 0.1, 0.1), snapshot_for_range(0.0, 0.1, 0.1), snapshot_for_range(0.0, 0.1, 0.1), snapshot_for_range(0.5, 0.1, 0.1), snapshot_for_range(0.5, 0.1, 0.1), ]; let algconfig = AlgorithmConfig { maximum_source_uncertainty: 3.0, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let sysconfig = SynchronizationConfig { minimum_agreeing_sources: 3, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates.clone()); assert_eq!(result.len(), 3); let sysconfig = SynchronizationConfig { minimum_agreeing_sources: 4, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates); assert_eq!(result.len(), 0); } #[test] fn test_tie() { // Test that in the case of a tie no group is chosen. let candidates = vec![ snapshot_for_range(0.0, 0.1, 0.1), snapshot_for_range(0.0, 0.1, 0.1), snapshot_for_range(0.5, 0.1, 0.1), snapshot_for_range(0.5, 0.1, 0.1), ]; let algconfig = AlgorithmConfig { maximum_source_uncertainty: 3.0, range_statistical_weight: 1.0, range_delay_weight: 1.0, ..Default::default() }; let sysconfig = SynchronizationConfig { minimum_agreeing_sources: 1, ..Default::default() }; let result = select(&sysconfig, &algconfig, candidates); assert_eq!(result.len(), 0); } } ntp-proto-1.4.0/src/algorithm/kalman/source.rs000064400000000000000000002147361046102023000174730ustar 00000000000000/// This module implements a kalman filter to filter the measurements /// provided by the sources. /// /// The filter tracks the time difference between the local and remote /// timescales. For ease of implementation, it actually is programmed /// mostly as if the local timescale is absolute truth, and the remote /// timescale is the one that is estimated. The filter state is kept at /// a local timestamp t, and progressed in time as needed for processing /// measurements and producing estimation outputs. 
///
/// This approach is chosen so that it is possible to line up the filters
/// from multiple sources (this has no real meaning when using remote
/// timescales for that), and makes sure that we control the timescale
/// used to express the filter in.
///
/// The state is a vector (D, w) where
/// - D is the offset between the remote and local timescales
/// - w is (in seconds per second) the frequency difference.
///
/// For process noise, we assume this is fully resultant from frequency
/// drift between the local and remote timescale, and that this frequency
/// drift is the result of (the limit of) a random walk process (Wiener
/// process). Under this assumption, a time change from t1 to t2 has a
/// state propagation matrix
///     1  (t2-t1)
///     0  1
/// and a noise matrix given by
///     v*(t2-t1)^3/3  v*(t2-t1)^2/2
///     v*(t2-t1)^2/2  v*(t2-t1)
/// where v is a constant describing how much the frequency drifts per
/// unit of time.
///
/// This module's input consists of measurements containing:
/// - the time of the measurement t_m
/// - the measured offset d
/// - the measured transmission delay r
///
/// On these, we assume that
/// - there is no impact from frequency differences on r
/// - individual measurements are independent
///
/// This information on its own is not enough to feed the Kalman filter.
/// For this, a further piece of information is needed: a measurement
/// related to the frequency difference. Although mathematically not
/// entirely sound, we construct the frequency measurement also using
/// the previous measurement (which we will denote with t_p and D_p).
/// It turns out this works well in practice.
///
/// The observation is then the vector (D, D-D_p), and the observation
/// matrix is given by
///     1  0
///     0  t_m-t_p
///
/// To estimate the measurement noise, the variance s of the transmission
/// delays r is used. Writing r as r1 - r2, where r1 is the time
/// difference on the client-to-server leg and r2 the time difference on
/// the server-to-client leg, and assuming the two legs are independent,
/// we have Var(D) = Var(1/2 (r1 + r2)) = 1/4 Var(r1 + r2)
/// = 1/4 Var(r1 - r2) = 1/4 Var(r). Furthermore, because individual
/// measurements are independent, Var(D-D_p) = Var(D) + Var(D_p)
/// = 1/2 Var(r) and Covar(D, D-D_p) = Covar(D, D) - Covar(D, D_p) = Var(D).
/// The measurement noise matrix is therefore
///     s/4  s/4
///     s/4  s/2
///
/// This setup leaves two major issues:
/// - How often do we want measurements (what is the desired polling interval)
/// - What is v
///
/// The polling interval is changed dynamically such that each measurement
/// is approximately halved before contributing to the state (see below).
///
/// The value for v is determined by observing how well the distribution
/// of measurement errors matches up with what we would statistically expect.
/// If they are often too small, v is quartered, and if they are often too
/// large, v is quadrupled (note, this corresponds with halving/doubling
/// the more intuitive standard deviation).
use tracing::{debug, trace}; use crate::{ algorithm::{KalmanControllerMessage, KalmanSourceMessage, SourceController}, config::SourceDefaultsConfig, source::Measurement, time_types::{NtpDuration, NtpTimestamp, PollInterval, PollIntervalLimits}, ObservableSourceTimedata, }; use core::fmt::Debug; use super::{ config::AlgorithmConfig, matrix::{Matrix, Vector}, sqr, SourceSnapshot, }; #[derive(Debug, Clone, Copy)] pub(super) struct KalmanState { pub state: Vector<2>, pub uncertainty: Matrix<2, 2>, // current time of the filter state pub time: NtpTimestamp, } pub(super) struct MeasurementStats { // Probability that the measurement was as close or closer to the prediction from the filter pub observe_probability: f64, // How much the measurement affected the filter state pub weight: f64, } impl KalmanState { #[must_use] pub fn progress_time(&self, time: NtpTimestamp, wander: f64) -> KalmanState { debug_assert!(!time.is_before(self.time)); if time.is_before(self.time) { return *self; } // Time step parameters let delta_t = (time - self.time).to_seconds(); let update = Matrix::new([[1.0, delta_t], [0.0, 1.0]]); let process_noise = Matrix::new([ [ wander * delta_t * delta_t * delta_t / 3., wander * delta_t * delta_t / 2., ], [wander * delta_t * delta_t / 2., wander * delta_t], ]); // Kalman filter update KalmanState { state: update * self.state, uncertainty: update * self.uncertainty * update.transpose() + process_noise, time, } } #[must_use] pub fn absorb_measurement( &self, measurement: Matrix<1, 2>, value: Vector<1>, noise: Matrix<1, 1>, ) -> (KalmanState, MeasurementStats) { let difference = value - measurement * self.state; let difference_covariance = measurement * self.uncertainty * measurement.transpose() + noise; let update_strength = self.uncertainty * measurement.transpose() * difference_covariance.inverse(); // Statistics let observe_probability = chi_1(difference.inner(difference_covariance.inverse() * difference)); // Calculate an indicator of how much of the 
measurement was incorporated // into the state. 1.0 - is needed here as this should become lower as // measurement noise's contribution to difference uncertainty increases. let weight = 1.0 - noise.determinant() / difference_covariance.determinant(); ( KalmanState { state: self.state + update_strength * difference, uncertainty: ((Matrix::unit() - update_strength * measurement) * self.uncertainty) .symmetrize(), time: self.time, }, MeasurementStats { observe_probability, weight, }, ) } #[must_use] pub fn merge(&self, other: &KalmanState) -> KalmanState { debug_assert_eq!(self.time, other.time); let mixer = (self.uncertainty + other.uncertainty).inverse(); KalmanState { state: self.state + self.uncertainty * mixer * (other.state - self.state), uncertainty: self.uncertainty * mixer * other.uncertainty, time: self.time, } } #[must_use] pub fn add_server_dispersion(&self, dispersion: f64) -> KalmanState { KalmanState { state: self.state, uncertainty: self.uncertainty + Matrix::new([[sqr(dispersion), 0.0], [0.0, 0.0]]), time: self.time, } } #[must_use] pub fn offset(&self) -> f64 { self.state.ventry(0) } #[must_use] pub fn offset_variance(&self) -> f64 { self.uncertainty.entry(0, 0) } #[must_use] pub fn frequency(&self) -> f64 { self.state.ventry(1) } #[must_use] pub fn frequency_variance(&self) -> f64 { self.uncertainty.entry(1, 1) } #[must_use] pub fn process_offset_steering(&self, steer: f64) -> KalmanState { KalmanState { state: self.state - Vector::new_vector([steer, 0.0]), uncertainty: self.uncertainty, time: self.time + NtpDuration::from_seconds(steer), } } #[must_use] pub fn process_frequency_steering( &self, time: NtpTimestamp, steer: f64, wander: f64, ) -> KalmanState { let mut result = self.progress_time(time, wander); result.state = result.state - Vector::new_vector([0.0, steer]); result } } #[derive(Debug, Default, Clone)] pub struct AveragingBuffer { data: [f64; 8], next_idx: usize, } // Large frequency uncertainty as early time essentially gives no 
reasonable info on frequency. const INITIALIZATION_FREQ_UNCERTAINTY: f64 = 100.0; /// Approximation of 1 - the chi-squared cdf with 1 degree of freedom /// source: https://en.wikipedia.org/wiki/Error_function fn chi_1(chi: f64) -> f64 { const P: f64 = 0.3275911; const A1: f64 = 0.254829592; const A2: f64 = -0.284496736; const A3: f64 = 1.421413741; const A4: f64 = -1.453152027; const A5: f64 = 1.061405429; let x = (chi / 2.).sqrt(); let t = 1. / (1. + P * x); (A1 * t + A2 * t * t + A3 * t * t * t + A4 * t * t * t * t + A5 * t * t * t * t * t) * (-(x * x)).exp() } impl AveragingBuffer { fn mean(&self) -> f64 { self.data.iter().sum::() / (self.data.len() as f64) } fn variance(&self) -> f64 { let mean = self.mean(); self.data.iter().map(|v| sqr(v - mean)).sum::() / ((self.data.len() - 1) as f64) } fn update(&mut self, rtt: f64) { self.data[self.next_idx] = rtt; self.next_idx = (self.next_idx + 1) % self.data.len(); } } pub trait MeasurementNoiseEstimator { type MeasurementDelay; fn update(&mut self, delay: Self::MeasurementDelay); fn get_noise_estimate(&self) -> f64; fn is_outlier(&self, delay: Self::MeasurementDelay, threshold: f64) -> bool; fn preprocess(&self, delay: Self::MeasurementDelay) -> Self::MeasurementDelay; fn reset(&mut self) -> Self; // for SourceSnapshot fn get_max_roundtrip(&self, samples: &i32) -> Option; fn get_delay_mean(&self) -> f64; } impl MeasurementNoiseEstimator for AveragingBuffer { type MeasurementDelay = NtpDuration; fn update(&mut self, delay: Self::MeasurementDelay) { self.update(delay.to_seconds()) } fn get_noise_estimate(&self) -> f64 { self.variance() / 4. 
} fn is_outlier(&self, delay: Self::MeasurementDelay, threshold: f64) -> bool { (delay.to_seconds() - self.mean()) > threshold * self.variance().sqrt() } fn preprocess(&self, delay: Self::MeasurementDelay) -> Self::MeasurementDelay { delay.max(MIN_DELAY) } fn reset(&mut self) -> Self { AveragingBuffer::default() } fn get_max_roundtrip(&self, samples: &i32) -> Option { self.data[..*samples as usize] .iter() .copied() .fold(None, |v1, v2| { if v2.is_nan() { v1 } else if let Some(v1) = v1 { Some(v2.max(v1)) } else { Some(v2) } }) } fn get_delay_mean(&self) -> f64 { self.mean() } } impl MeasurementNoiseEstimator for f64 { type MeasurementDelay = (); fn update(&mut self, _delay: Self::MeasurementDelay) {} fn get_noise_estimate(&self) -> f64 { *self } fn is_outlier(&self, _delay: Self::MeasurementDelay, _threshold: f64) -> bool { false } fn preprocess(&self, _delay: Self::MeasurementDelay) -> Self::MeasurementDelay {} fn reset(&mut self) -> Self { *self } fn get_max_roundtrip(&self, _samples: &i32) -> Option { Some(1.) } fn get_delay_mean(&self) -> f64 { 0. 
} } #[derive(Debug, Clone)] struct InitialSourceFilter< D: Debug + Copy + Clone, N: MeasurementNoiseEstimator + Clone, > { noise_estimator: N, init_offset: AveragingBuffer, last_measurement: Option>, samples: i32, } impl + Clone> InitialSourceFilter { fn update(&mut self, measurement: Measurement) { self.noise_estimator.update(measurement.delay); self.init_offset.update(measurement.offset.to_seconds()); self.samples += 1; self.last_measurement = Some(measurement); debug!(samples = self.samples, "Initial source update"); } fn process_offset_steering(&mut self, steer: f64) { for sample in self.init_offset.data.iter_mut() { *sample -= steer; } } } #[derive(Debug, Clone)] struct SourceFilter> { state: KalmanState, clock_wander: f64, noise_estimator: N, precision_score: i32, poll_score: i32, desired_poll_interval: PollInterval, last_measurement: Measurement, prev_was_outlier: bool, // Last time a packet was processed last_iter: NtpTimestamp, } impl> SourceFilter { /// Move the filter forward to reflect the situation at a new, later timestamp fn progress_filtertime(&mut self, time: NtpTimestamp) { self.state = self.state.progress_time(time, self.clock_wander); trace!(?time, "Filter progressed"); } /// Absorb knowledge from a measurement fn absorb_measurement(&mut self, measurement: Measurement) -> (f64, f64, f64) { // Measurement parameters let m_delta_t = (measurement.localtime - self.last_measurement.localtime).to_seconds(); // Kalman filter update let measurement_vec = Vector::new_vector([measurement.offset.to_seconds()]); let measurement_transform = Matrix::new([[1., 0.]]); let measurement_noise = Matrix::new([[self.noise_estimator.get_noise_estimate()]]); let (new_state, stats) = self.state.absorb_measurement( measurement_transform, measurement_vec, measurement_noise, ); self.state = new_state; self.last_measurement = measurement; trace!( stats.observe_probability, stats.weight, "Measurement absorbed" ); (stats.observe_probability, stats.weight, m_delta_t) } /// 
Ensure we poll often enough to keep the filter well-fed with information, but /// not so much that each individual poll message gives us very little new information. fn update_desired_poll( &mut self, source_defaults_config: &SourceDefaultsConfig, algo_config: &AlgorithmConfig, p: f64, weight: f64, measurement_period: f64, ) { // We don't want to speed up when we already want more than we get, and vice versa. let reference_measurement_period = self.desired_poll_interval.as_duration().to_seconds(); if weight < algo_config.poll_interval_low_weight && measurement_period / reference_measurement_period > 0.75 { self.poll_score -= 1; } else if weight > algo_config.poll_interval_high_weight && measurement_period / reference_measurement_period < 1.4 { self.poll_score += 1; } else { self.poll_score -= self.poll_score.signum(); } trace!(poll_score = self.poll_score, ?weight, "Poll desire update"); if p <= algo_config.poll_interval_step_threshold { self.desired_poll_interval = source_defaults_config.poll_interval_limits.min; self.poll_score = 0; } else if self.poll_score <= -algo_config.poll_interval_hysteresis { self.desired_poll_interval = self .desired_poll_interval .inc(source_defaults_config.poll_interval_limits); self.poll_score = 0; debug!(interval = ?self.desired_poll_interval, "Increased poll interval"); } else if self.poll_score >= algo_config.poll_interval_hysteresis { self.desired_poll_interval = self .desired_poll_interval .dec(source_defaults_config.poll_interval_limits); self.poll_score = 0; debug!(interval = ?self.desired_poll_interval, "Decreased poll interval"); } } // Our estimate for the clock stability might be completely wrong. The code here // correlates the estimation for errors to what we actually observe, so we can // update our estimate should it turn out to be significantly off. 
fn update_wander_estimate(&mut self, algo_config: &AlgorithmConfig, p: f64, weight: f64) { // Note that chi is exponentially distributed with mean 2 // Also, we do not steer towards a smaller precision estimate when measurement noise dominates. if 1. - p < algo_config.precision_low_probability && weight > algo_config.precision_minimum_weight { self.precision_score -= 1; } else if 1. - p > algo_config.precision_high_probability { self.precision_score += 1; } else { self.precision_score -= self.precision_score.signum(); } trace!( precision_score = self.precision_score, p, "Wander estimate update" ); if self.precision_score <= -algo_config.precision_hysteresis { self.clock_wander /= 4.0; self.precision_score = 0; debug!( wander = self.clock_wander.sqrt(), "Decreased wander estimate" ); } else if self.precision_score >= algo_config.precision_hysteresis { self.clock_wander *= 4.0; self.precision_score = 0; debug!( wander = self.clock_wander.sqrt(), "Increased wander estimate" ); } } /// Update our estimates based on a new measurement. fn update( &mut self, source_defaults_config: &SourceDefaultsConfig, algo_config: &AlgorithmConfig, measurement: Measurement, ) -> bool { // Always update the root_delay, root_dispersion, leap second status and stratum, as they always represent the most accurate state. self.last_measurement.root_delay = measurement.root_delay; self.last_measurement.root_dispersion = measurement.root_dispersion; self.last_measurement.stratum = measurement.stratum; self.last_measurement.leap = measurement.leap; if measurement.localtime.is_before(self.state.time) { // Ignore the past return false; } // This was a valid measurement, so no matter what this represents our current iteration time // for the purposes of synchronizing self.last_iter = measurement.localtime; // Filter out one-time outliers (based on delay!) 
if !self.prev_was_outlier && self .noise_estimator .is_outlier(measurement.delay, algo_config.delay_outlier_threshold) { self.prev_was_outlier = true; return false; } // Environment update self.progress_filtertime(measurement.localtime); self.noise_estimator.update(measurement.delay); let (p, weight, measurement_period) = self.absorb_measurement(measurement); self.update_wander_estimate(algo_config, p, weight); self.update_desired_poll( source_defaults_config, algo_config, p, weight, measurement_period, ); debug!( "source offset {}±{}ms, freq {}±{}ppm", self.state.offset() * 1000., (self.state.offset_variance() + sqr(self.last_measurement.root_dispersion.to_seconds())) .sqrt() * 1000., self.state.frequency() * 1e6, self.state.frequency_variance().sqrt() * 1e6 ); true } fn process_offset_steering(&mut self, steer: f64) { self.state = self.state.process_offset_steering(steer); self.last_measurement.offset -= NtpDuration::from_seconds(steer); self.last_measurement.localtime += NtpDuration::from_seconds(steer); } fn process_frequency_steering(&mut self, time: NtpTimestamp, steer: f64) { self.state = self .state .process_frequency_steering(time, steer, self.clock_wander); self.last_measurement.offset += NtpDuration::from_seconds( steer * (time - self.last_measurement.localtime).to_seconds(), ); } } #[derive(Debug, Clone)] #[allow(clippy::large_enum_variant)] enum SourceStateInner< D: Debug + Copy + Clone, N: MeasurementNoiseEstimator + Clone, > { Initial(InitialSourceFilter), Stable(SourceFilter), } #[derive(Debug, Clone)] pub(super) struct SourceState< D: Debug + Copy + Clone, N: MeasurementNoiseEstimator + Clone, >(SourceStateInner); const MIN_DELAY: NtpDuration = NtpDuration::from_exponent(-18); impl + Clone> SourceState { pub(super) fn new(noise_estimator: N) -> Self { SourceState(SourceStateInner::Initial(InitialSourceFilter { noise_estimator, init_offset: AveragingBuffer::default(), last_measurement: None, samples: 0, })) } // Returns whether the clock may need 
adjusting. pub fn update_self_using_measurement( &mut self, source_defaults_config: &SourceDefaultsConfig, algo_config: &AlgorithmConfig, mut measurement: Measurement, ) -> bool { // preprocessing let noise_estimator = match self { SourceState(SourceStateInner::Initial(filter)) => &filter.noise_estimator, SourceState(SourceStateInner::Stable(filter)) => &filter.noise_estimator, }; measurement.delay = noise_estimator.preprocess(measurement.delay); self.update_self_using_raw_measurement(source_defaults_config, algo_config, measurement) } fn update_self_using_raw_measurement( &mut self, source_defaults_config: &SourceDefaultsConfig, algo_config: &AlgorithmConfig, measurement: Measurement, ) -> bool { match &mut self.0 { SourceStateInner::Initial(filter) => { filter.update(measurement); if filter.samples == 8 { *self = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([filter.init_offset.mean(), 0.]), uncertainty: Matrix::new([ [filter.init_offset.variance(), 0.], [0., sqr(algo_config.initial_frequency_uncertainty)], ]), time: measurement.localtime, }, clock_wander: sqr(algo_config.initial_wander), noise_estimator: filter.noise_estimator.clone(), precision_score: 0, poll_score: 0, desired_poll_interval: source_defaults_config.initial_poll_interval, last_measurement: measurement, prev_was_outlier: false, last_iter: measurement.localtime, })); debug!("Initial source measurements complete"); } true } SourceStateInner::Stable(filter) => { // We check that the difference between the localtime and monotonic // times of the measurement is in line with what would be expected // from recent steering. This check needs to be done here since we // need to revert back to the initial state. 
let localtime_difference = measurement.localtime - filter.last_measurement.localtime; let monotime_difference = measurement .monotime .abs_diff(filter.last_measurement.monotime); if localtime_difference.abs_diff(monotime_difference) > algo_config.meddling_threshold { let msg = "Detected clock meddling. Has another process updated the clock?"; tracing::warn!(msg); *self = SourceState(SourceStateInner::Initial(InitialSourceFilter { noise_estimator: filter.noise_estimator.reset(), init_offset: AveragingBuffer::default(), last_measurement: None, samples: 0, })); false } else { filter.update(source_defaults_config, algo_config, measurement) } } } } fn snapshot( &self, index: Index, config: &AlgorithmConfig, ) -> Option> { match &self.0 { SourceStateInner::Initial(InitialSourceFilter { noise_estimator, init_offset, last_measurement: Some(last_measurement), samples, }) if *samples > 0 => { let max_roundtrip = noise_estimator.get_max_roundtrip(samples)?; Some(SourceSnapshot { index, source_uncertainty: last_measurement.root_dispersion, source_delay: last_measurement.root_delay, leap_indicator: last_measurement.leap, last_update: last_measurement.localtime, delay: max_roundtrip, state: KalmanState { state: Vector::new_vector([ init_offset.data[..*samples as usize] .iter() .copied() .sum::() / (*samples as f64), 0.0, ]), uncertainty: Matrix::new([ [max_roundtrip, 0.0], [0.0, INITIALIZATION_FREQ_UNCERTAINTY], ]), time: last_measurement.localtime, }, wander: config.initial_wander, }) } SourceStateInner::Stable(filter) => Some(SourceSnapshot { index, state: filter.state, wander: filter.clock_wander, delay: filter.noise_estimator.get_delay_mean(), source_uncertainty: filter.last_measurement.root_dispersion, source_delay: filter.last_measurement.root_delay, leap_indicator: filter.last_measurement.leap, last_update: filter.last_iter, }), _ => None, } } pub fn get_desired_poll(&self, limits: &PollIntervalLimits) -> PollInterval { match &self.0 { SourceStateInner::Initial(_) => 
limits.min, SourceStateInner::Stable(filter) => filter.desired_poll_interval, } } pub fn process_offset_steering(&mut self, steer: f64) { match &mut self.0 { SourceStateInner::Initial(filter) => filter.process_offset_steering(steer), SourceStateInner::Stable(filter) => filter.process_offset_steering(steer), } } pub fn process_frequency_steering(&mut self, time: NtpTimestamp, steer: f64) { match &mut self.0 { SourceStateInner::Initial(_) => {} SourceStateInner::Stable(filter) => filter.process_frequency_steering(time, steer), } } } #[derive(Debug)] pub struct KalmanSourceController< SourceId, D: Debug + Copy + Clone, N: MeasurementNoiseEstimator + Clone, > { index: SourceId, state: SourceState, algo_config: AlgorithmConfig, source_defaults_config: SourceDefaultsConfig, } pub type TwoWayKalmanSourceController = KalmanSourceController; pub type OneWayKalmanSourceController = KalmanSourceController; impl< SourceId: Copy, D: Debug + Copy + Clone, N: MeasurementNoiseEstimator + Clone, > KalmanSourceController { pub(super) fn new( index: SourceId, algo_config: AlgorithmConfig, source_defaults_config: SourceDefaultsConfig, noise_estimator: N, ) -> Self { KalmanSourceController { index, state: SourceState::new(noise_estimator), algo_config, source_defaults_config, } } } impl< SourceId: std::fmt::Debug + Copy + Send + 'static, D: Debug + Copy + Clone + Send + 'static, N: MeasurementNoiseEstimator + Clone + Send + 'static, > SourceController for KalmanSourceController { type ControllerMessage = KalmanControllerMessage; type SourceMessage = KalmanSourceMessage; type MeasurementDelay = D; fn handle_message(&mut self, message: Self::ControllerMessage) { match message.inner { super::KalmanControllerMessageInner::Step { steer } => { self.state.process_offset_steering(steer); } super::KalmanControllerMessageInner::FreqChange { steer, time } => { self.state.process_frequency_steering(time, steer) } } } fn handle_measurement( &mut self, measurement: Measurement, ) -> Option { if 
self.state.update_self_using_measurement( &self.source_defaults_config, &self.algo_config, measurement, ) { self.state .snapshot(self.index, &self.algo_config) .map(|snapshot| KalmanSourceMessage { inner: snapshot }) } else { None } } fn desired_poll_interval(&self) -> PollInterval { self.state .get_desired_poll(&self.source_defaults_config.poll_interval_limits) } fn observe(&self) -> super::super::ObservableSourceTimedata { self.state .snapshot(&self.index, &self.algo_config) .map(|snapshot| snapshot.observe()) .unwrap_or(ObservableSourceTimedata { offset: NtpDuration::ZERO, uncertainty: NtpDuration::MAX, delay: NtpDuration::MAX, remote_delay: NtpDuration::MAX, remote_uncertainty: NtpDuration::MAX, last_update: NtpTimestamp::default(), }) } } #[cfg(test)] mod tests { use crate::{packet::NtpLeapIndicator, time_types::NtpInstant}; use super::*; #[test] fn test_meddling_detection() { let base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([20e-3, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(20e-3), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(20e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei 
+ std::time::Duration::from_secs(2800), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!(matches!(source, SourceState(SourceStateInner::Initial(_)))); let mut source = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([20e-3, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(20e-3), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.process_offset_steering(-1800.0); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(20e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(2800), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!(matches!(source, SourceState(SourceStateInner::Stable(_)))); let mut source = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([20e-3, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, precision_score: 0, poll_score: 0, desired_poll_interval: 
PollIntervalLimits::default().min, last_measurement: Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(20e-3), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.process_offset_steering(1800.0); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(20e-3), localtime: base + NtpDuration::from_seconds(2800.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!(matches!(source, SourceState(SourceStateInner::Stable(_)))); } fn test_offset_steering_and_measurements< D: Debug + Clone + Copy, N: MeasurementNoiseEstimator + Clone, >( noise_estimator: N, delay: D, ) { let base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([20e-3, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: noise_estimator.clone(), precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay, offset: NtpDuration::from_seconds(20e-3), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.process_offset_steering(20e-3); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() .abs() < 1e-7 ); let mut source = 
SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([20e-3, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 0.0, noise_estimator: noise_estimator.clone(), precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay, offset: NtpDuration::from_seconds(20e-3), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.process_offset_steering(20e-3); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() .abs() < 1e-7 ); source.update_self_using_raw_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(20e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( dbg!((source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() - 20e-3) .abs()) < 1e-7 ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency() - 20e-6) .abs() < 1e-7 ); let mut source = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([-20e-3, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 0.0, noise_estimator: noise_estimator.clone(), precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay, offset: NtpDuration::from_seconds(-20e-3), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: 
NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.process_offset_steering(-20e-3); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() .abs() < 1e-7 ); source.update_self_using_raw_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(-20e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( dbg!((source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() - -20e-3) .abs()) < 1e-7 ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency() - -20e-6) .abs() < 1e-7 ); } #[test] fn test_offset_steering_and_measurements_normal() { test_offset_steering_and_measurements( AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, NtpDuration::from_seconds(0.0), ); } #[test] fn test_offset_steering_and_measurements_constant_noise_estimate() { test_offset_steering_and_measurements(1e-9, ()); } #[test] fn test_freq_steering() { let noise_estimator = AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }; let delay = NtpDuration::from_seconds(0.0); let base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceFilter { state: KalmanState { state: Vector::new_vector([0.0, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: noise_estimator.clone(), precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay, offset: NtpDuration::from_seconds(0.0), localtime: base, 
monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, }; source.process_frequency_steering(base + NtpDuration::from_seconds(5.0), 200e-6); assert!((source.state.frequency() - -200e-6).abs() < 1e-10); assert!(source.state.offset().abs() < 1e-8); assert!((source.last_measurement.offset.to_seconds() - 1e-3).abs() < 1e-8); source.process_frequency_steering(base + NtpDuration::from_seconds(10.0), -200e-6); assert!(source.state.frequency().abs() < 1e-10); assert!((source.state.offset() - -1e-3).abs() < 1e-8); assert!((source.last_measurement.offset.to_seconds() - -1e-3).abs() < 1e-8); let mut source = SourceState(SourceStateInner::Stable(SourceFilter { state: KalmanState { state: Vector::new_vector([0.0, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: noise_estimator.clone(), precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay, offset: NtpDuration::from_seconds(0.0), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, })); source.process_frequency_steering(base + NtpDuration::from_seconds(5.0), 200e-6); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency() - -200e-6) .abs() < 1e-10 ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() .abs() < 1e-8 ); source.process_frequency_steering(base + NtpDuration::from_seconds(10.0), -200e-6); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency() .abs() < 1e-10 ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() - -1e-3) 
.abs() < 1e-8 ); } fn test_init< D: Debug + Clone + Copy, N: MeasurementNoiseEstimator + Clone, >( noise_estimator: N, delay: D, ) { let base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceState::new(noise_estimator); assert!(source .snapshot(0_usize, &AlgorithmConfig::default()) .is_none()); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(0e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(1e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(2e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), 
&AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(3e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(4e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(5e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(6e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) 
.unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay, offset: NtpDuration::from_seconds(7e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() - 3.5e-3) .abs() < 1e-7 ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset_variance() - 1e-6) > 0. ); } #[test] fn test_init_normal() { test_init( AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, NtpDuration::from_seconds(0.0), ); } #[test] fn test_init_constant_noise_estimate() { test_init(1e-3, ()); } #[test] fn test_steer_during_init() { let base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceState::new(AveragingBuffer::default()); assert!(source .snapshot(0_usize, &AlgorithmConfig::default()) .is_none()); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(4e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(5e-3), localtime: base + 
NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(6e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(7e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); source.process_offset_steering(4e-3); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(4e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, 
&AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(5e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(6e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .frequency_variance() > 1.0 ); source.update_self_using_measurement( &SourceDefaultsConfig::default(), &AlgorithmConfig::default(), Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(7e-3), localtime: base + NtpDuration::from_seconds(1000.0), monotime: basei + std::time::Duration::from_secs(1000), stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset() - 3.5e-3) .abs() < 1e-7 ); assert!( (source .snapshot(0_usize, &AlgorithmConfig::default()) .unwrap() .state .offset_variance() - 1e-6) > 0. 
); } #[test] fn test_poll_duration_variation() { let config = SourceDefaultsConfig::default(); let algo_config = AlgorithmConfig { poll_interval_hysteresis: 2, ..Default::default() }; let base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceFilter { state: KalmanState { state: Vector::new_vector([0.0, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(0.0), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, }; let baseinterval = source.desired_poll_interval.as_duration().to_seconds(); let pollup = source .desired_poll_interval .inc(PollIntervalLimits::default()); source.update_desired_poll(&config, &algo_config, 1.0, 1.0, baseinterval * 2.); assert_eq!(source.poll_score, 0); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll(&config, &algo_config, 1.0, 0.0, baseinterval * 2.); assert_eq!(source.poll_score, -1); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll(&config, &algo_config, 1.0, 0.0, baseinterval * 2.); assert_eq!(source.poll_score, 0); assert_eq!(source.desired_poll_interval, pollup); source.update_desired_poll(&config, &algo_config, 1.0, 1.0, baseinterval * 3.); assert_eq!(source.poll_score, 0); assert_eq!(source.desired_poll_interval, pollup); source.update_desired_poll(&config, &algo_config, 1.0, 0.0, baseinterval); assert_eq!(source.poll_score, 0); 
assert_eq!(source.desired_poll_interval, pollup); source.update_desired_poll(&config, &algo_config, 0.0, 0.0, baseinterval * 3.); assert_eq!(source.poll_score, 0); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll(&config, &algo_config, 1.0, 0.0, baseinterval * 2.); assert_eq!(source.poll_score, -1); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll(&config, &algo_config, 1.0, 0.0, baseinterval * 2.); assert_eq!(source.poll_score, 0); assert_eq!(source.desired_poll_interval, pollup); source.update_desired_poll(&config, &algo_config, 1.0, 1.0, baseinterval); assert_eq!(source.poll_score, 1); assert_eq!(source.desired_poll_interval, pollup); source.update_desired_poll(&config, &algo_config, 1.0, 1.0, baseinterval); assert_eq!(source.poll_score, 0); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll(&config, &algo_config, 1.0, 0.0, baseinterval); assert_eq!(source.poll_score, -1); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll( &config, &algo_config, 1.0, (algo_config.poll_interval_high_weight + algo_config.poll_interval_low_weight) / 2., baseinterval, ); assert_eq!(source.poll_score, 0); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll(&config, &algo_config, 1.0, 1.0, baseinterval); assert_eq!(source.poll_score, 1); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); source.update_desired_poll( &config, &algo_config, 1.0, (algo_config.poll_interval_high_weight + algo_config.poll_interval_low_weight) / 2., baseinterval, ); assert_eq!(source.poll_score, 0); assert_eq!( source.desired_poll_interval, PollIntervalLimits::default().min ); } #[test] fn test_wander_estimation() { let algo_config = AlgorithmConfig { precision_hysteresis: 2, ..Default::default() }; let 
base = NtpTimestamp::from_fixed_int(0); let basei = NtpInstant::now(); let mut source = SourceFilter { state: KalmanState { state: Vector::new_vector([0.0, 0.]), uncertainty: Matrix::new([[1e-6, 0.], [0., 1e-8]]), time: base, }, clock_wander: 1e-8, noise_estimator: AveragingBuffer { data: [0.0, 0.0, 0.0, 0.0, 0.875e-6, 0.875e-6, 0.875e-6, 0.875e-6], next_idx: 0, }, precision_score: 0, poll_score: 0, desired_poll_interval: PollIntervalLimits::default().min, last_measurement: Measurement { delay: NtpDuration::from_seconds(0.0), offset: NtpDuration::from_seconds(0.0), localtime: base, monotime: basei, stratum: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), leap: NtpLeapIndicator::NoWarning, precision: 0, }, prev_was_outlier: false, last_iter: base, }; source.update_wander_estimate(&algo_config, 1.0, 0.0); assert_eq!(source.precision_score, 0); assert!((source.clock_wander - 1e-8).abs() < 1e-12); source.update_wander_estimate(&algo_config, 1.0, 1.0); assert_eq!(source.precision_score, -1); assert!((source.clock_wander - 1e-8).abs() < 1e-12); source.update_wander_estimate(&algo_config, 1.0, 1.0); assert_eq!(source.precision_score, 0); assert!(dbg!((source.clock_wander - 0.25e-8).abs()) < 1e-12); source.update_wander_estimate(&algo_config, 0.0, 0.0); assert_eq!(source.precision_score, 1); assert!(dbg!((source.clock_wander - 0.25e-8).abs()) < 1e-12); source.update_wander_estimate(&algo_config, 0.0, 1.0); assert_eq!(source.precision_score, 0); assert!((source.clock_wander - 1e-8).abs() < 1e-12); source.update_wander_estimate(&algo_config, 0.0, 0.0); assert_eq!(source.precision_score, 1); assert!((source.clock_wander - 1e-8).abs() < 1e-12); source.update_wander_estimate( &algo_config, (algo_config.precision_high_probability + algo_config.precision_low_probability) / 2.0, 0.0, ); assert_eq!(source.precision_score, 0); assert!((source.clock_wander - 1e-8).abs() < 1e-12); source.update_wander_estimate(&algo_config, 1.0, 1.0); 
assert_eq!(source.precision_score, -1); assert!((source.clock_wander - 1e-8).abs() < 1e-12); source.update_wander_estimate( &algo_config, (algo_config.precision_high_probability + algo_config.precision_low_probability) / 2.0, 0.0, ); assert_eq!(source.precision_score, 0); assert!((source.clock_wander - 1e-8).abs() < 1e-12); } } ntp-proto-1.4.0/src/algorithm/mod.rs000064400000000000000000000106631046102023000155000ustar 00000000000000use std::{fmt::Debug, time::Duration}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::{ clock::NtpClock, config::{SourceDefaultsConfig, SynchronizationConfig}, source::Measurement, system::TimeSnapshot, time_types::{NtpDuration, NtpTimestamp}, PollInterval, }; #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ObservableSourceTimedata { pub offset: NtpDuration, pub uncertainty: NtpDuration, pub delay: NtpDuration, pub remote_delay: NtpDuration, pub remote_uncertainty: NtpDuration, pub last_update: NtpTimestamp, } #[derive(Debug, Clone)] pub struct StateUpdate { // Message for all sources, if any pub source_message: Option, // Update to the time snapshot, if any pub time_snapshot: Option, // Update to the used sources, if any pub used_sources: Option>, // Requested timestamp for next non-measurement update pub next_update: Option, } // Note: this default implementation is necessary since the // derive only works if SourceId is Default (which it isn't // necessarily) impl Default for StateUpdate { fn default() -> Self { Self { source_message: None, time_snapshot: None, used_sources: None, next_update: None, } } } pub trait TimeSyncController: Sized + Send + 'static { type Clock: NtpClock; type SourceId; type AlgorithmConfig: Debug + Copy + DeserializeOwned + Send; type ControllerMessage: Debug + Clone + Send + 'static; type SourceMessage: Debug + Clone + Send + 'static; type NtpSourceController: SourceController< ControllerMessage = Self::ControllerMessage, SourceMessage = Self::SourceMessage, 
MeasurementDelay = NtpDuration, >; type OneWaySourceController: SourceController< ControllerMessage = Self::ControllerMessage, SourceMessage = Self::SourceMessage, MeasurementDelay = (), >; /// Create a new clock controller controlling the given clock fn new( clock: Self::Clock, synchronization_config: SynchronizationConfig, source_defaults_config: SourceDefaultsConfig, algorithm_config: Self::AlgorithmConfig, ) -> Result::Error>; /// Take control of the clock (should not be done in new!) fn take_control(&mut self) -> Result<(), ::Error>; /// Create a new source with given identity fn add_source(&mut self, id: Self::SourceId) -> Self::NtpSourceController; /// Create a new one way source with given identity (used e.g. with GPS sock sources) fn add_one_way_source( &mut self, id: Self::SourceId, measurement_noise_estimate: f64, ) -> Self::OneWaySourceController; /// Notify the controller that a previous source has gone fn remove_source(&mut self, id: Self::SourceId); /// Notify the controller that the status of a source (whether /// or not it is usable for synchronization) has changed. fn source_update(&mut self, id: Self::SourceId, usable: bool); /// Notify the controller of a new measurement from a source. /// The list of SourceIds is used for loop detection, with the /// first SourceId given considered the primary source used. 
fn source_message( &mut self, id: Self::SourceId, message: Self::SourceMessage, ) -> StateUpdate; /// Non-message driven update (queued via next_update) fn time_update(&mut self) -> StateUpdate; } pub trait SourceController: Sized + Send + 'static { type ControllerMessage: Debug + Clone + Send + 'static; type SourceMessage: Debug + Clone + Send + 'static; type MeasurementDelay: Debug + Copy + Clone; fn handle_message(&mut self, message: Self::ControllerMessage); fn handle_measurement( &mut self, measurement: Measurement, ) -> Option; fn desired_poll_interval(&self) -> PollInterval; fn observe(&self) -> ObservableSourceTimedata; } mod kalman; pub use kalman::{ config::AlgorithmConfig, KalmanClockController, KalmanControllerMessage, KalmanSourceController, KalmanSourceMessage, TwoWayKalmanSourceController, }; ntp-proto-1.4.0/src/clock.rs000064400000000000000000000032751046102023000140270ustar 00000000000000use crate::{ packet::NtpLeapIndicator, time_types::{NtpDuration, NtpTimestamp}, }; /// Interface for a clock settable by the ntp implementation. /// This needs to be a trait as a single system can have multiple clocks /// which need different implementation for steering and/or now. pub trait NtpClock: Clone + Send + 'static { type Error: std::error::Error + Send + Sync; // Get current time fn now(&self) -> Result; // Change the frequency of the clock, returning the time // at which the change was applied. fn set_frequency(&self, freq: f64) -> Result; // Get the frequency of the clock fn get_frequency(&self) -> Result; // Change the current time of the clock by offset. Returns // the time at which the change was applied. fn step_clock(&self, offset: NtpDuration) -> Result; // A clock can have a built in NTP clock discipline algorithm // that does more processing on the offsets it receives. This // functions disables that discipline. 
fn disable_ntp_algorithm(&self) -> Result<(), Self::Error>; // Provide the system with our current best estimates for // the statistical error of the clock (est_error), and // the maximum deviation due to frequency error and // distance to the root clock. fn error_estimate_update( &self, est_error: NtpDuration, max_error: NtpDuration, ) -> Result<(), Self::Error>; // Change the indicators for upcoming leap seconds and // the clocks synchronization status. fn status_update(&self, leap_status: NtpLeapIndicator) -> Result<(), Self::Error>; } ntp-proto-1.4.0/src/config.rs000064400000000000000000000234761046102023000142060ustar 00000000000000use std::fmt; use serde::{ de::{self, MapAccess, Unexpected, Visitor}, Deserialize, Deserializer, }; use crate::time_types::{NtpDuration, PollInterval, PollIntervalLimits}; fn deserialize_option_accumulated_step_panic_threshold<'de, D>( deserializer: D, ) -> Result, D::Error> where D: Deserializer<'de>, { let duration: NtpDuration = Deserialize::deserialize(deserializer)?; Ok(if duration == NtpDuration::ZERO { None } else { Some(duration) }) } #[derive(Debug, Default, Copy, Clone)] pub struct StepThreshold { pub forward: Option, pub backward: Option, } impl StepThreshold { pub fn is_within(&self, duration: NtpDuration) -> bool { self.forward.map(|v| duration < v).unwrap_or(true) && self.backward.map(|v| duration > -v).unwrap_or(true) } } #[derive(Debug, Copy, Clone)] struct ThresholdPart(Option); impl<'de> Deserialize<'de> for ThresholdPart { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct ThresholdPartVisitor; impl Visitor<'_> for ThresholdPartVisitor { type Value = ThresholdPart; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("float or \"inf\"") } fn visit_f64(self, v: f64) -> Result where E: de::Error, { Ok(ThresholdPart(Some(NtpDuration::from_seconds(v)))) } fn visit_i64(self, v: i64) -> Result where E: de::Error, { self.visit_f64(v as f64) } fn 
visit_u64(self, v: u64) -> Result where E: de::Error, { self.visit_f64(v as f64) } fn visit_str(self, v: &str) -> Result where E: de::Error, { if v != "inf" { return Err(de::Error::invalid_value( de::Unexpected::Str(v), &"float or \"inf\"", )); } Ok(ThresholdPart(None)) } } deserializer.deserialize_any(ThresholdPartVisitor) } } // We have a custom deserializer for StepThreshold because we // want to deserialize it from either a number or map impl<'de> Deserialize<'de> for StepThreshold { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct StepThresholdVisitor; impl<'de> Visitor<'de> for StepThresholdVisitor { type Value = StepThreshold; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("float, map or \"inf\"") } fn visit_f64(self, v: f64) -> Result where E: de::Error, { if v.is_nan() || v.is_infinite() || v < 0.0 { return Err(serde::de::Error::invalid_value( Unexpected::Float(v), &"a positive number", )); } let duration = NtpDuration::from_seconds(v); Ok(StepThreshold { forward: Some(duration), backward: Some(duration), }) } fn visit_i64(self, v: i64) -> Result where E: de::Error, { self.visit_f64(v as f64) } fn visit_u64(self, v: u64) -> Result where E: de::Error, { self.visit_f64(v as f64) } fn visit_str(self, v: &str) -> Result where E: de::Error, { if v != "inf" { return Err(de::Error::invalid_value( de::Unexpected::Str(v), &"float, map or \"inf\"", )); } Ok(StepThreshold { forward: None, backward: None, }) } fn visit_map>(self, mut map: M) -> Result { let mut forward = None; let mut backward = None; while let Some(key) = map.next_key::()? 
{ match key.as_str() { "forward" => { if forward.is_some() { return Err(de::Error::duplicate_field("forward")); } let raw: ThresholdPart = map.next_value()?; forward = Some(raw.0); } "backward" => { if backward.is_some() { return Err(de::Error::duplicate_field("backward")); } let raw: ThresholdPart = map.next_value()?; backward = Some(raw.0); } _ => { return Err(de::Error::unknown_field( key.as_str(), &["forward", "backward"], )); } } } Ok(StepThreshold { forward: forward.flatten(), backward: backward.flatten(), }) } } deserializer.deserialize_any(StepThresholdVisitor) } } #[derive(Deserialize, Debug, Clone, Copy)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub struct SourceDefaultsConfig { /// Minima and maxima for the poll interval of clients #[serde(default)] pub poll_interval_limits: PollIntervalLimits, /// Initial poll interval of the system #[serde(default = "default_initial_poll_interval")] pub initial_poll_interval: PollInterval, } impl Default for SourceDefaultsConfig { fn default() -> Self { Self { poll_interval_limits: Default::default(), initial_poll_interval: default_initial_poll_interval(), } } } fn default_initial_poll_interval() -> PollInterval { PollIntervalLimits::default().min } #[derive(Deserialize, Debug, Clone, Copy)] #[serde(rename_all = "kebab-case", deny_unknown_fields)] pub struct SynchronizationConfig { /// Minimum number of survivors needed to be able to discipline the system clock. /// More survivors (so more servers from which to get the time) means a more accurate time. /// /// The spec notes (CMIN was renamed to MIN_INTERSECTION_SURVIVORS in our implementation): /// /// > CMIN defines the minimum number of servers consistent with the correctness requirements. /// > Suspicious operators would set CMIN to ensure multiple redundant servers are available for the /// > algorithms to mitigate properly. However, for historic reasons the default value for CMIN is one. 
#[serde(default = "default_minimum_agreeing_sources")] pub minimum_agreeing_sources: usize, /// The maximum amount the system clock is allowed to change in a single go /// before we conclude something is seriously wrong. This is used to limit /// the changes to the clock to reasonable amounts, and stop issues with /// remote servers from causing us to drift too far. /// /// Note that this is not used during startup. To limit system clock changes /// during startup, use startup_panic_threshold #[serde(default = "default_single_step_panic_threshold")] pub single_step_panic_threshold: StepThreshold, /// The maximum amount the system clock is allowed to change during startup. /// This can be used to limit the impact of bad servers if the system clock /// is known to be reasonable on startup #[serde(default = "default_startup_step_panic_threshold")] pub startup_step_panic_threshold: StepThreshold, /// The maximum amount distributed amongst all steps except at startup the /// daemon is allowed to step the system clock. #[serde( deserialize_with = "deserialize_option_accumulated_step_panic_threshold", default )] pub accumulated_step_panic_threshold: Option, /// Stratum of the local clock, when not synchronized through ntp. 
This /// can be used in servers to indicate that there are external mechanisms /// synchronizing the clock #[serde(default = "default_local_stratum")] pub local_stratum: u8, } impl Default for SynchronizationConfig { fn default() -> Self { Self { minimum_agreeing_sources: default_minimum_agreeing_sources(), single_step_panic_threshold: default_single_step_panic_threshold(), startup_step_panic_threshold: default_startup_step_panic_threshold(), accumulated_step_panic_threshold: None, local_stratum: default_local_stratum(), } } } fn default_minimum_agreeing_sources() -> usize { 3 } fn default_single_step_panic_threshold() -> StepThreshold { let raw = NtpDuration::from_seconds(1000.); StepThreshold { forward: Some(raw), backward: Some(raw), } } fn default_startup_step_panic_threshold() -> StepThreshold { // No forward limit, backwards max. 1 day StepThreshold { forward: None, backward: Some(NtpDuration::from_seconds(86400.)), } } fn default_local_stratum() -> u8 { 16 } ntp-proto-1.4.0/src/cookiestash.rs000064400000000000000000000055451046102023000152520ustar 00000000000000//! Datastructure for managing cookies. It keeps the following //! invariants: //! - Each cookie is yielded at most once //! - The oldest cookie is always yielded first //! //! Note that as a consequence, this type is not Clone! 
pub const MAX_COOKIES: usize = 8; #[derive(Default, PartialEq, Eq)] pub(crate) struct CookieStash { cookies: [Vec; MAX_COOKIES], read: usize, valid: usize, } impl std::fmt::Debug for CookieStash { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CookieStash") .field("cookies", &self.cookies.len()) .field("read", &self.read) .field("valid", &self.valid) .finish() } } impl CookieStash { /// Store a new cookie pub fn store(&mut self, cookie: Vec) { let wpos = (self.read + self.valid) % self.cookies.len(); self.cookies[wpos] = cookie; if self.valid < self.cookies.len() { self.valid += 1; } else { debug_assert!(self.valid == self.cookies.len()); // No place for extra cookies, but it is still // newer so just keep the newest cookies. self.read = (self.read + 1) % self.cookies.len(); } } /// Get oldest cookie pub fn get(&mut self) -> Option> { if self.valid == 0 { None } else { // takes the cookie, puts `vec![]` in its place let result = std::mem::take(&mut self.cookies[self.read]); self.read = (self.read + 1) % self.cookies.len(); self.valid -= 1; Some(result) } } /// Number of cookies missing from the stash pub fn gap(&self) -> u8 { // This never overflows or underflows since cookies.len will // fit in a u8 and 0 <= self.valid <= self.cookies.len() (self.cookies.len() - self.valid) as u8 } pub fn len(&self) -> usize { self.valid } pub fn is_empty(&self) -> bool { self.valid == 0 } } #[cfg(test)] mod tests { use super::*; #[test] fn test_empty_read() { let mut stash = CookieStash::default(); assert_eq!(stash.get(), None); } #[test] fn test_overfill() { let mut stash = CookieStash::default(); for i in 0..10_u8 { stash.store(vec![i]); } assert_eq!(stash.get(), Some(vec![2])); assert_eq!(stash.get(), Some(vec![3])); } #[test] fn test_normal_op() { let mut stash = CookieStash::default(); for i in 0..8_u8 { stash.store(vec![i]); assert_eq!(stash.gap(), 7 - i); } for i in 8_u8..32_u8 { assert_eq!(stash.get(), Some(vec![i - 8])); 
assert_eq!(stash.gap(), 1); stash.store(vec![i]); assert_eq!(stash.gap(), 0); } } } ntp-proto-1.4.0/src/identifiers.rs000064400000000000000000000054671046102023000152460ustar 00000000000000use std::net::IpAddr; use md5::{Digest, Md5}; use serde::{Deserialize, Serialize}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ReferenceId(u32); impl ReferenceId { // Note: Names chosen to match the identifiers given in rfc5905 pub const KISS_DENY: ReferenceId = ReferenceId(u32::from_be_bytes(*b"DENY")); pub const KISS_RATE: ReferenceId = ReferenceId(u32::from_be_bytes(*b"RATE")); pub const KISS_RSTR: ReferenceId = ReferenceId(u32::from_be_bytes(*b"RSTR")); pub const NONE: ReferenceId = ReferenceId(u32::from_be_bytes(*b"XNON")); pub const SOCK: ReferenceId = ReferenceId(u32::from_be_bytes(*b"SOCK")); // Network Time Security (NTS) negative-acknowledgment (NAK), from rfc8915 pub const KISS_NTSN: ReferenceId = ReferenceId(u32::from_be_bytes(*b"NTSN")); pub fn from_ip(addr: IpAddr) -> ReferenceId { match addr { IpAddr::V4(addr) => ReferenceId(u32::from_be_bytes(addr.octets())), IpAddr::V6(addr) => ReferenceId(u32::from_be_bytes( Md5::digest(addr.octets())[0..4].try_into().unwrap(), )), } } pub(crate) const fn from_int(value: u32) -> ReferenceId { ReferenceId(value) } pub(crate) fn is_deny(&self) -> bool { *self == Self::KISS_DENY } pub(crate) fn is_rate(&self) -> bool { *self == Self::KISS_RATE } pub(crate) fn is_rstr(&self) -> bool { *self == Self::KISS_RSTR } pub(crate) fn is_ntsn(&self) -> bool { *self == Self::KISS_NTSN } pub(crate) fn to_bytes(self) -> [u8; 4] { self.0.to_be_bytes() } pub(crate) fn from_bytes(bits: [u8; 4]) -> ReferenceId { ReferenceId(u32::from_be_bytes(bits)) } } #[cfg(test)] mod tests { use super::*; #[test] fn referenceid_serialization_roundtrip() { let a = [12, 34, 56, 78]; let b = ReferenceId::from_bytes(a); let c = b.to_bytes(); let d = ReferenceId::from_bytes(c); assert_eq!(a, c); assert_eq!(b, d); } #[test] fn 
referenceid_kiss_codes() { let a = [b'R', b'A', b'T', b'E']; let b = ReferenceId::from_bytes(a); assert!(b.is_rate()); let a = [b'R', b'S', b'T', b'R']; let b = ReferenceId::from_bytes(a); assert!(b.is_rstr()); let a = [b'D', b'E', b'N', b'Y']; let b = ReferenceId::from_bytes(a); assert!(b.is_deny()); } #[test] fn referenceid_from_ipv4() { let ip: IpAddr = "12.34.56.78".parse().unwrap(); let rep = [12, 34, 56, 78]; let a = ReferenceId::from_ip(ip); let b = ReferenceId::from_bytes(rep); assert_eq!(a, b); // TODO: Generate and add a testcase for ipv6 addresses once // we have access to an ipv6 network. } } ntp-proto-1.4.0/src/io.rs000064400000000000000000000015141046102023000133350ustar 00000000000000/// Write trait for structs that implement std::io::Write without doing blocking io pub trait NonBlockingWrite: std::io::Write {} impl NonBlockingWrite for std::io::Cursor where std::io::Cursor: std::io::Write {} impl NonBlockingWrite for Vec {} impl NonBlockingWrite for &mut [u8] {} impl NonBlockingWrite for std::collections::VecDeque {} impl NonBlockingWrite for Box where W: NonBlockingWrite {} impl NonBlockingWrite for &mut W where W: NonBlockingWrite {} pub trait NonBlockingRead: std::io::Read {} impl NonBlockingRead for std::io::Cursor where std::io::Cursor: std::io::Read {} impl NonBlockingRead for &[u8] {} impl NonBlockingRead for std::collections::VecDeque {} impl NonBlockingRead for Box where R: NonBlockingRead {} impl NonBlockingRead for &mut R where R: NonBlockingRead {} ntp-proto-1.4.0/src/ipfilter.rs000064400000000000000000000257401046102023000145530ustar 00000000000000use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use crate::server::IpSubnet; /// One part of a `BitTree` #[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] struct TreeNode { // Where in the array the child nodes of this // node are located. 
A child node is only // generated if the symbol cannot be used to // make a final decision at this level child_offset: u32, inset: u16, outset: u16, } #[derive(Debug, Clone, PartialEq, Eq)] /// `BitTree` is a Trie on 128 bit integers encoding /// which integers are part of the set. /// /// It matches the integer a 4-bit segment at a time /// recording at each level whether for a given symbol /// all integers with the prefix extended with that /// symbol are either in or outside of the set. struct BitTree { nodes: Vec, } const fn top_nibble(v: u128) -> u8 { ((v >> 124) & 0xF) as u8 } /// retain only the top `128 - len` bits const fn apply_mask(val: u128, len: u8) -> u128 { match u128::MAX.checked_shl((128 - len) as u32) { Some(mask) => val & mask, None => 0, } } impl BitTree { /// Lookup whether a given value is in the set encoded in this `BitTree` /// Complexity is O(log(l)), where l is the length of the longest /// prefix in the set. fn lookup(&self, mut val: u128) -> bool { let mut node = &self.nodes[0]; loop { // extract the current symbol as bit and see if we know the answer immediately. // (example: symbol 1 maps to 0x2, symbol 5 maps to 0x10) let cur = 1 << top_nibble(val); if node.inset & cur != 0 { return true; } if node.outset & cur != 0 { return false; } // no decision, shift to next symbol val <<= 4; // To calculate the child index we need to know how many symbols smaller // than our symbol are not decided here. We do this by generating the bitmap // of symbols neither in in or out, then masking out all symbols >=cur // and finally counting how many are left. let next_idx = node.child_offset + (!(node.inset | node.outset) & (cur - 1)).count_ones(); node = &self.nodes[next_idx as usize]; } } /// Create a `BitTree` from the given prefixes. Complexity is O(n*log(l)), /// where n is the number of prefixes, and l the length of the longest /// prefix. 
fn create(data: &mut [(u128, u8)]) -> Self { // Ensure values only have 1s in significant positions for (val, len) in data.iter_mut() { *val = apply_mask(*val, *len); } // Ensure values are sorted by value and then by length data.sort(); let mut result = BitTree { nodes: vec![TreeNode::default()], }; result.fill_node(data, 0); result } /// Create the substructure for a node, recursively. /// Max recursion depth is maximum value of data[i].1/4 /// for any i fn fill_node(&mut self, mut data: &mut [(u128, u8)], node_index: usize) { // distribute the data into 16 4-bit buckets let mut counts = [0; 16]; for (val, _) in data.iter() { counts[top_nibble(*val) as usize] += 1; } // Actually split into the relevant subsegments, relies on the input being sorted. let mut subsegments: [&mut [(u128, u8)]; 16] = Default::default(); for (i, start) in counts.iter().enumerate() { (subsegments[i], data) = data.split_at_mut(*start); } // Fill in node let child_offset = self.nodes.len(); let node = &mut self.nodes[node_index]; node.child_offset = child_offset as u32; for (i, segment) in subsegments.iter().enumerate() { match segment.first().copied() { // Probably empty, unless covered earlier, but we fix that later None => node.outset |= 1 << i, // Definitely covered, mark all that is needed // Note that due to sorting order, len here // is guaranteed to be largest amongst all // parts of the segment Some((_, len)) if len <= 4 => { // mark ALL parts of node covered by the segment as in the set. for j in 0..(1 << (4 - len)) { node.inset |= 1 << (i + j as usize); } } // May be covered by a the union of all its parts, we need to check // for that. 
Otherwise it is undecided Some(_) => { let offset = (i as u128) << 124; let mut last = 0; for part in segment.iter() { if part.0 - offset <= last { last = u128::max(last, part.0 - offset + (1_u128 << (128 - part.1))); } } if last >= (1 << 124) { // All parts together cover the segment, so mark as in node.inset |= 1 << i; } } } } // the outset should not contain anything that is included in the inset // (this can happen due to overcoverage) node.outset &= !node.inset; // bitmap of subsegments for which we have a decision let known_bitmap = node.inset | node.outset; // allocate additional empty nodes let unknown_count = known_bitmap.count_zeros() as usize; self.nodes .extend(std::iter::repeat(TreeNode::default()).take(unknown_count)); // Create children for segments undecided at this level. let mut child_offset = child_offset; for (i, segment) in subsegments.iter_mut().enumerate() { if known_bitmap & (1 << i) != 0 { continue; // no child needed } // we've taken care of the top nibble, // so shift everything over and do a recursive call for (val, len) in segment.iter_mut() { *val <<= 4; *len -= 4; } self.fill_node(segment, child_offset); child_offset += 1; } } } #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct IpFilter { ipv4_filter: BitTree, ipv6_filter: BitTree, } impl IpFilter { /// Create a filter from a list of subnets /// Complexity: O(n) with n length of list pub fn new(subnets: &[IpSubnet]) -> Self { let mut ipv4list = Vec::new(); let mut ipv6list = Vec::new(); for subnet in subnets { match subnet.addr { IpAddr::V4(addr) => ipv4list.push(( (u32::from_be_bytes(addr.octets()) as u128) << 96, subnet.mask, )), IpAddr::V6(addr) => { ipv6list.push((u128::from_be_bytes(addr.octets()), subnet.mask)); } } } IpFilter { ipv4_filter: BitTree::create(ipv4list.as_mut_slice()), ipv6_filter: BitTree::create(ipv6list.as_mut_slice()), } } /// Check whether a given ip address is contained in the filter. 
/// Complexity: O(1) pub fn is_in(&self, addr: &IpAddr) -> bool { match addr { IpAddr::V4(addr) => self.is_in4(addr), IpAddr::V6(addr) => self.is_in6(addr), } } fn is_in4(&self, addr: &Ipv4Addr) -> bool { self.ipv4_filter .lookup((u32::from_be_bytes(addr.octets()) as u128) << 96) } fn is_in6(&self, addr: &Ipv6Addr) -> bool { self.ipv6_filter.lookup(u128::from_be_bytes(addr.octets())) } } #[cfg(feature = "__internal-fuzz")] pub mod fuzz { use super::*; fn contains(subnet: &IpSubnet, addr: &IpAddr) -> bool { match (subnet.addr, addr) { (IpAddr::V4(net), IpAddr::V4(addr)) => { let net = u32::from_be_bytes(net.octets()); let addr = u32::from_be_bytes(addr.octets()); let mask = 0xFFFFFFFF_u32 .checked_shl((32 - subnet.mask) as u32) .unwrap_or(0); (net & mask) == (addr & mask) } (IpAddr::V6(net), IpAddr::V6(addr)) => { let net = u128::from_be_bytes(net.octets()); let addr = u128::from_be_bytes(addr.octets()); let mask = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF_u128 .checked_shl((128 - subnet.mask) as u32) .unwrap_or(0); (net & mask) == (addr & mask) } _ => false, } } fn any_contains(subnets: &[IpSubnet], addr: &IpAddr) -> bool { for net in subnets { if contains(net, addr) { return true; } } false } pub fn fuzz_ipfilter(nets: &[IpSubnet], addr: &[IpAddr]) { let filter = IpFilter::new(nets); for addr in addr { assert_eq!(filter.is_in(addr), any_contains(nets, addr)); } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_bittree() { let mut data = [ (0x10 << 120, 4), (0x20 << 120, 3), (0x43 << 120, 8), (0x82 << 120, 7), ]; let tree = BitTree::create(&mut data); assert!(tree.lookup(0x11 << 120)); assert!(!tree.lookup(0x40 << 120)); assert!(tree.lookup(0x30 << 120)); assert!(tree.lookup(0x43 << 120)); assert!(!tree.lookup(0xC4 << 120)); assert!(tree.lookup(0x82 << 120)); assert!(tree.lookup(0x83 << 120)); assert!(!tree.lookup(0x81 << 120)); } #[test] fn test_filter() { let filter = IpFilter::new(&[ "127.0.0.0/24".parse().unwrap(), "::FFFF:0000:0000/96".parse().unwrap(), ]); 
assert!(filter.is_in(&"127.0.0.1".parse().unwrap())); assert!(!filter.is_in(&"192.168.1.1".parse().unwrap())); assert!(filter.is_in(&"::FFFF:ABCD:0123".parse().unwrap())); assert!(!filter.is_in(&"::FEEF:ABCD:0123".parse().unwrap())); } #[test] fn test_subnet_edgecases() { let filter = IpFilter::new(&["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]); assert!(filter.is_in(&"0.0.0.0".parse().unwrap())); assert!(filter.is_in(&"255.255.255.255".parse().unwrap())); assert!(filter.is_in(&"::".parse().unwrap())); assert!(filter.is_in(&"FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF".parse().unwrap())); let filter = IpFilter::new(&[ "1.2.3.4/32".parse().unwrap(), "10:32:54:76:98:BA:DC:FE/128".parse().unwrap(), ]); assert!(filter.is_in(&"1.2.3.4".parse().unwrap())); assert!(!filter.is_in(&"1.2.3.5".parse().unwrap())); assert!(filter.is_in(&"10:32:54:76:98:BA:DC:FE".parse().unwrap())); assert!(!filter.is_in(&"10:32:54:76:98:BA:DC:FF".parse().unwrap())); } } ntp-proto-1.4.0/src/keyset.rs000064400000000000000000000373361046102023000142450ustar 00000000000000use std::{ io::{Read, Write}, sync::Arc, }; use aead::{generic_array::GenericArray, KeyInit}; use crate::{ nts_record::AeadAlgorithm, packet::{ AesSivCmac256, AesSivCmac512, Cipher, CipherHolder, CipherProvider, DecryptError, EncryptResult, ExtensionField, }, }; pub struct DecodedServerCookie { pub(crate) algorithm: AeadAlgorithm, pub s2c: Box, pub c2s: Box, } impl DecodedServerCookie { fn plaintext(&self) -> Vec { let mut plaintext = Vec::new(); let algorithm_bytes = (self.algorithm as u16).to_be_bytes(); plaintext.extend_from_slice(&algorithm_bytes); plaintext.extend_from_slice(self.s2c.key_bytes()); plaintext.extend_from_slice(self.c2s.key_bytes()); plaintext } } impl std::fmt::Debug for DecodedServerCookie { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("DecodedServerCookie") .field("algorithm", &self.algorithm) .finish() } } #[derive(Debug)] pub struct KeySetProvider { current: Arc, 
history: usize, } impl KeySetProvider { /// Create a new keysetprovider that keeps history old /// keys around (so in total, history+1 keys are valid /// at any time) pub fn new(history: usize) -> Self { KeySetProvider { current: Arc::new(KeySet { keys: vec![AesSivCmac512::new(aes_siv::Aes256SivAead::generate_key( rand::thread_rng(), ))], id_offset: 0, primary: 0, }), history, } } #[cfg(feature = "__internal-fuzz")] pub fn dangerous_new_deterministic(history: usize) -> Self { KeySetProvider { current: Arc::new(KeySet { keys: vec![AesSivCmac512::new( std::array::from_fn(|i| (i as u8)).into(), )], id_offset: 0, primary: 0, }), history, } } /// Rotate a new key in as primary, forgetting an old one if needed pub fn rotate(&mut self) { let next_key = AesSivCmac512::new(aes_siv::Aes256SivAead::generate_key(rand::thread_rng())); let mut keys = Vec::with_capacity((self.history + 1).min(self.current.keys.len() + 1)); for key in self.current.keys [self.current.keys.len().saturating_sub(self.history)..self.current.keys.len()] .iter() { // This is the rare case where we do really want to make a copy. 
keys.push(AesSivCmac512::new(GenericArray::clone_from_slice( key.key_bytes(), ))); } keys.push(next_key); self.current = Arc::new(KeySet { id_offset: self .current .id_offset .wrapping_add(self.current.keys.len().saturating_sub(self.history) as u32), primary: keys.len() as u32 - 1, keys, }); } pub fn load( reader: &mut impl Read, history: usize, ) -> std::io::Result<(Self, std::time::SystemTime)> { let mut buf = [0; 64]; reader.read_exact(&mut buf[0..20])?; let time = std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(u64::from_be_bytes(buf[0..8].try_into().unwrap())); let id_offset = u32::from_be_bytes(buf[8..12].try_into().unwrap()); let primary = u32::from_be_bytes(buf[12..16].try_into().unwrap()); let len = u32::from_be_bytes(buf[16..20].try_into().unwrap()); if primary > len { return Err(std::io::ErrorKind::Other.into()); } let mut keys = vec![]; for _ in 0..len { reader.read_exact(&mut buf[0..64])?; keys.push(AesSivCmac512::new(buf.into())); } Ok(( KeySetProvider { current: Arc::new(KeySet { keys, id_offset, primary, }), history, }, time, )) } pub fn store(&self, writer: &mut impl Write) -> std::io::Result<()> { let time = std::time::SystemTime::now() .duration_since(std::time::SystemTime::UNIX_EPOCH) .expect("Could not get current time"); writer.write_all(&time.as_secs().to_be_bytes())?; writer.write_all(&self.current.id_offset.to_be_bytes())?; writer.write_all(&self.current.primary.to_be_bytes())?; writer.write_all(&(self.current.keys.len() as u32).to_be_bytes())?; for key in self.current.keys.iter() { writer.write_all(key.key_bytes())?; } Ok(()) } /// Get the current KeySet pub fn get(&self) -> Arc { self.current.clone() } } pub struct KeySet { keys: Vec, id_offset: u32, primary: u32, } impl KeySet { #[cfg(feature = "__internal-fuzz")] pub fn encode_cookie_pub(&self, cookie: &DecodedServerCookie) -> Vec { self.encode_cookie(cookie) } pub(crate) fn encode_cookie(&self, cookie: &DecodedServerCookie) -> Vec { let mut output = 
cookie.plaintext(); let plaintext_length = output.as_slice().len(); // Add space for header (4 + 2 bytes), additional ciphertext // data from the cmac (16 bytes) and nonce (16 bytes). output.resize(output.len() + 2 + 4 + 16 + 16, 0); // And move plaintext to make space for header output.copy_within(0..plaintext_length, 6); let EncryptResult { nonce_length, ciphertext_length, } = self.keys[self.primary as usize] .encrypt(&mut output[6..], plaintext_length, &[]) .expect("Failed to encrypt cookie"); debug_assert_eq!(nonce_length, 16); debug_assert_eq!(plaintext_length + 16, ciphertext_length); output[0..4].copy_from_slice(&(self.primary.wrapping_add(self.id_offset)).to_be_bytes()); output[4..6].copy_from_slice(&(ciphertext_length as u16).to_be_bytes()); debug_assert_eq!(output.len(), 6 + nonce_length + ciphertext_length); output } #[cfg(feature = "__internal-fuzz")] pub fn decode_cookie_pub(&self, cookie: &[u8]) -> Result { self.decode_cookie(cookie) } pub(crate) fn decode_cookie(&self, cookie: &[u8]) -> Result { // we need at least an id, cipher text length and nonce for this message to be valid if cookie.len() < 4 + 2 + 16 { return Err(DecryptError); } let id = u32::from_be_bytes(cookie[0..4].try_into().unwrap()); let id = id.wrapping_sub(self.id_offset) as usize; let key = self.keys.get(id).ok_or(DecryptError)?; let cipher_text_length = u16::from_be_bytes([cookie[4], cookie[5]]) as usize; let nonce = &cookie[6..22]; let ciphertext = cookie[22..].get(..cipher_text_length).ok_or(DecryptError)?; let plaintext = key.decrypt(nonce, ciphertext, &[])?; let [b0, b1, ref key_bytes @ ..] = plaintext[..] 
else { return Err(DecryptError); }; let algorithm = AeadAlgorithm::try_deserialize(u16::from_be_bytes([b0, b1])).ok_or(DecryptError)?; Ok(match algorithm { AeadAlgorithm::AeadAesSivCmac256 => { const KEY_WIDTH: usize = 32; if key_bytes.len() != 2 * KEY_WIDTH { return Err(DecryptError); } let (s2c, c2s) = key_bytes.split_at(KEY_WIDTH); DecodedServerCookie { algorithm, s2c: Box::new(AesSivCmac256::new(GenericArray::clone_from_slice(s2c))), c2s: Box::new(AesSivCmac256::new(GenericArray::clone_from_slice(c2s))), } } AeadAlgorithm::AeadAesSivCmac512 => { const KEY_WIDTH: usize = 64; if key_bytes.len() != 2 * KEY_WIDTH { return Err(DecryptError); } let (s2c, c2s) = key_bytes.split_at(KEY_WIDTH); DecodedServerCookie { algorithm, s2c: Box::new(AesSivCmac512::new(GenericArray::clone_from_slice(s2c))), c2s: Box::new(AesSivCmac512::new(GenericArray::clone_from_slice(c2s))), } } }) } #[cfg(test)] pub(crate) fn new() -> Self { Self { keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())], id_offset: 1, primary: 0, } } } impl CipherProvider for KeySet { fn get(&self, context: &[ExtensionField<'_>]) -> Option> { let mut decoded = None; for ef in context { if let ExtensionField::NtsCookie(cookie) = ef { if decoded.is_some() { // more than one cookie, abort return None; } decoded = Some(self.decode_cookie(cookie).ok()?); } } decoded.map(CipherHolder::DecodedServerCookie) } } impl std::fmt::Debug for KeySet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("KeySet") .field("keys", &self.keys.len()) .field("id_offset", &self.id_offset) .field("primary", &self.primary) .finish() } } #[cfg(any(test, feature = "__internal-fuzz"))] pub fn test_cookie() -> DecodedServerCookie { DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), } } #[cfg(test)] mod tests { use std::io::Cursor; use super::*; #[test] fn 
roundtrip_aes_siv_cmac_256() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let keyset = KeySet { keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())], id_offset: 1, primary: 0, }; let encoded = keyset.encode_cookie(&decoded); let round = keyset.decode_cookie(&encoded).unwrap(); assert_eq!(decoded.algorithm, round.algorithm); assert_eq!(decoded.s2c.key_bytes(), round.s2c.key_bytes()); assert_eq!(decoded.c2s.key_bytes(), round.c2s.key_bytes()); } #[test] fn test_encode_after_rotate() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let mut provider = KeySetProvider::new(1); provider.rotate(); let keyset = provider.get(); let encoded = keyset.encode_cookie(&decoded); let round = keyset.decode_cookie(&encoded).unwrap(); assert_eq!(decoded.algorithm, round.algorithm); assert_eq!(decoded.s2c.key_bytes(), round.s2c.key_bytes()); assert_eq!(decoded.c2s.key_bytes(), round.c2s.key_bytes()); } #[test] fn can_decode_cookie_with_padding() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac512, s2c: Box::new(AesSivCmac512::new((0..64_u8).collect())), c2s: Box::new(AesSivCmac512::new((64..128_u8).collect())), }; let keyset = KeySet { keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())], id_offset: 1, primary: 0, }; let mut encoded = keyset.encode_cookie(&decoded); encoded.extend([0, 0]); let round = keyset.decode_cookie(&encoded).unwrap(); assert_eq!(decoded.algorithm, round.algorithm); assert_eq!(decoded.s2c.key_bytes(), round.s2c.key_bytes()); assert_eq!(decoded.c2s.key_bytes(), round.c2s.key_bytes()); } #[test] fn roundtrip_aes_siv_cmac_512() { let decoded = DecodedServerCookie { algorithm: 
AeadAlgorithm::AeadAesSivCmac512, s2c: Box::new(AesSivCmac512::new((0..64_u8).collect())), c2s: Box::new(AesSivCmac512::new((64..128_u8).collect())), }; let keyset = KeySet { keys: vec![AesSivCmac512::new(std::iter::repeat(0).take(64).collect())], id_offset: 1, primary: 0, }; let encoded = keyset.encode_cookie(&decoded); let round = keyset.decode_cookie(&encoded).unwrap(); assert_eq!(decoded.algorithm, round.algorithm); assert_eq!(decoded.s2c.key_bytes(), round.s2c.key_bytes()); assert_eq!(decoded.c2s.key_bytes(), round.c2s.key_bytes()); } #[test] fn test_save_restore() { let mut provider = KeySetProvider::new(8); provider.rotate(); provider.rotate(); let mut output = Cursor::new(vec![]); provider.store(&mut output).unwrap(); let mut input = Cursor::new(output.into_inner()); let (copy, time) = KeySetProvider::load(&mut input, 8).unwrap(); assert!( std::time::SystemTime::now() .duration_since(time) .unwrap() .as_secs() < 2 ); assert_eq!(provider.get().primary, copy.get().primary); assert_eq!(provider.get().id_offset, copy.get().id_offset); for i in 0..provider.get().keys.len() { assert_eq!( provider.get().keys[i].key_bytes(), copy.get().keys[i].key_bytes() ); } } #[test] fn old_cookie_still_valid() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let mut provider = KeySetProvider::new(1); let encoded = provider.get().encode_cookie(&decoded); let round = provider.get().decode_cookie(&encoded).unwrap(); assert_eq!(decoded.algorithm, round.algorithm); assert_eq!(decoded.s2c.key_bytes(), round.s2c.key_bytes()); assert_eq!(decoded.c2s.key_bytes(), round.c2s.key_bytes()); provider.rotate(); let round = provider.get().decode_cookie(&encoded).unwrap(); assert_eq!(decoded.algorithm, round.algorithm); assert_eq!(decoded.s2c.key_bytes(), round.s2c.key_bytes()); assert_eq!(decoded.c2s.key_bytes(), round.c2s.key_bytes()); 
provider.rotate(); assert!(provider.get().decode_cookie(&encoded).is_err()); } #[test] fn invalid_cookie_length() { // this cookie data lies about its length, pretending to be longer than it actually is. let input = b"\x23\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x04\x00\x24\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x04\x00\x18\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x04\x00\x28\x00\x10\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; let provider = KeySetProvider::new(1); let output = provider.get().decode_cookie(input); assert!(output.is_err()); } } ntp-proto-1.4.0/src/lib.rs000064400000000000000000000075041046102023000135010ustar 00000000000000//! This crate contains packet parsing and algorithm code for ntpd-rs and is not //! intended as a public interface at this time. It follows the same version as the //! main ntpd-rs crate, but that version is not intended to give any stability //! guarantee. Use at your own risk. //! //! Please visit the [ntpd-rs](https://github.com/pendulum-project/ntpd-rs) project //! for more information. #![forbid(unsafe_code)] #![cfg_attr(not(feature = "__internal-api"), allow(unused))] mod algorithm; mod clock; mod config; mod cookiestash; mod identifiers; mod io; mod ipfilter; mod keyset; mod nts_record; mod packet; mod server; mod source; mod system; mod time_types; #[cfg(feature = "nts-pool")] mod nts_pool_ke; pub mod tls_utils; pub(crate) mod exitcode { /// An internal software error has been detected. This /// should be limited to non-operating system related /// errors as possible. 
#[cfg(not(test))] pub const SOFTWARE: i32 = 70; } mod exports { pub use super::algorithm::{ AlgorithmConfig, KalmanClockController, KalmanControllerMessage, KalmanSourceController, KalmanSourceMessage, ObservableSourceTimedata, SourceController, StateUpdate, TimeSyncController, TwoWayKalmanSourceController, }; pub use super::clock::NtpClock; pub use super::config::{SourceDefaultsConfig, StepThreshold, SynchronizationConfig}; pub use super::identifiers::ReferenceId; #[cfg(feature = "__internal-fuzz")] pub use super::ipfilter::fuzz::fuzz_ipfilter; pub use super::keyset::{DecodedServerCookie, KeySet, KeySetProvider}; #[cfg(feature = "__internal-fuzz")] pub use super::keyset::test_cookie; #[cfg(feature = "__internal-fuzz")] pub use super::packet::ExtensionField; pub use super::packet::{ Cipher, CipherProvider, EncryptResult, ExtensionHeaderVersion, NoCipher, NtpAssociationMode, NtpLeapIndicator, NtpPacket, PacketParsingError, }; pub use super::server::{ FilterAction, FilterList, IpSubnet, Server, ServerAction, ServerConfig, ServerReason, ServerResponse, ServerStatHandler, SubnetParseError, }; #[cfg(feature = "__internal-test")] pub use super::source::source_snapshot; pub use super::source::{ AcceptSynchronizationError, Measurement, NtpSource, NtpSourceAction, NtpSourceActionIterator, NtpSourceSnapshot, NtpSourceUpdate, ObservableSourceState, OneWaySource, OneWaySourceSnapshot, OneWaySourceUpdate, ProtocolVersion, Reach, SourceNtsData, }; pub use super::system::{ System, SystemAction, SystemActionIterator, SystemSnapshot, SystemSourceUpdate, TimeSnapshot, }; #[cfg(feature = "__internal-fuzz")] pub use super::time_types::fuzz_duration_from_seconds; pub use super::time_types::{ FrequencyTolerance, NtpDuration, NtpInstant, NtpTimestamp, PollInterval, PollIntervalLimits, }; #[cfg(feature = "__internal-fuzz")] pub use super::nts_record::fuzz_key_exchange_result_decoder; #[cfg(feature = "__internal-fuzz")] pub use super::nts_record::fuzz_key_exchange_server_decoder; pub use 
super::nts_record::{ KeyExchangeClient, KeyExchangeError, KeyExchangeResult, KeyExchangeServer, NtpVersion, NtsRecord, NtsRecordDecoder, WriteError, }; pub use super::cookiestash::MAX_COOKIES; #[cfg(feature = "ntpv5")] pub mod v5 { pub use crate::packet::v5::server_reference_id::{BloomFilter, ServerId}; } #[cfg(feature = "nts-pool")] pub use super::nts_record::AeadAlgorithm; #[cfg(feature = "nts-pool")] pub use super::nts_pool_ke::{ ClientToPoolData, ClientToPoolDecoder, PoolToServerData, PoolToServerDecoder, SupportedAlgorithmsDecoder, }; } #[cfg(feature = "__internal-api")] pub use exports::*; #[cfg(not(feature = "__internal-api"))] pub(crate) use exports::*; ntp-proto-1.4.0/src/nts_pool_ke.rs000064400000000000000000000251541046102023000152500ustar 00000000000000use std::ops::ControlFlow; use crate::{ nts_record::{AeadAlgorithm, NtsKeys, ProtocolId}, KeyExchangeError, NtsRecord, NtsRecordDecoder, }; /// Pool KE decoding records reserved from an NTS KE #[derive(Debug, Default)] pub struct SupportedAlgorithmsDecoder { decoder: NtsRecordDecoder, supported_algorithms: Vec<(u16, u16)>, } impl SupportedAlgorithmsDecoder { pub fn step_with_slice( mut self, bytes: &[u8], ) -> ControlFlow, KeyExchangeError>, Self> { self.decoder.extend(bytes.iter().copied()); loop { match self.decoder.step() { Err(e) => return ControlFlow::Break(Err(e.into())), Ok(Some(record)) => self = self.step_with_record(record)?, Ok(None) => return ControlFlow::Continue(self), } } } #[inline(always)] fn step_with_record( self, record: NtsRecord, ) -> ControlFlow, KeyExchangeError>, Self> { use ControlFlow::{Break, Continue}; use NtsRecord::*; let mut state = self; match record { EndOfMessage => Break(Ok(state.supported_algorithms)), Error { errorcode } => Break(Err(KeyExchangeError::from_error_code(errorcode))), Warning { warningcode } => { tracing::warn!(warningcode, "Received key exchange warning code"); Continue(state) } #[cfg(feature = "nts-pool")] SupportedAlgorithmList { supported_algorithms, 
} => { state.supported_algorithms = supported_algorithms; Continue(state) } _ => Continue(state), } } } /// Pool KE decoding records from the client #[derive(Debug, Default)] pub struct ClientToPoolDecoder { decoder: NtsRecordDecoder, /// AEAD algorithm that the client is able to use and that we support /// it may be that the server and client supported algorithms have no /// intersection! algorithm: AeadAlgorithm, /// Protocol (NTP version) that is supported by both client and server protocol: ProtocolId, records: Vec, denied_servers: Vec, #[cfg(feature = "ntpv5")] allow_v5: bool, } #[derive(Debug)] pub struct ClientToPoolData { pub algorithm: AeadAlgorithm, pub protocol: ProtocolId, pub records: Vec, pub denied_servers: Vec, } impl ClientToPoolData { pub fn extract_nts_keys( &self, stream: &rustls23::ConnectionCommon, ) -> Result { self.algorithm .extract_nts_keys(self.protocol, stream) .map_err(KeyExchangeError::Tls) } } impl ClientToPoolDecoder { pub fn step_with_slice( mut self, bytes: &[u8], ) -> ControlFlow, Self> { self.decoder.extend(bytes.iter().copied()); loop { match self.decoder.step() { Err(e) => return ControlFlow::Break(Err(e.into())), Ok(Some(record)) => self = self.step_with_record(record)?, Ok(None) => return ControlFlow::Continue(self), } } } #[inline(always)] fn step_with_record( self, record: NtsRecord, ) -> ControlFlow, Self> { use self::AeadAlgorithm as Algorithm; use ControlFlow::{Break, Continue}; use KeyExchangeError::*; use NtsRecord::*; let mut state = self; match record { EndOfMessage => { // NOTE EndOfMessage not pushed onto the vector let result = ClientToPoolData { algorithm: state.algorithm, protocol: state.protocol, records: state.records, denied_servers: state.denied_servers, }; Break(Ok(result)) } Error { errorcode } => { // Break(Err(KeyExchangeError::from_error_code(errorcode))) } Warning { warningcode } => { tracing::debug!(warningcode, "Received key exchange warning code"); state.records.push(record); Continue(state) } 
#[cfg(feature = "ntpv5")] DraftId { data } => { if data == crate::packet::v5::DRAFT_VERSION.as_bytes() { state.allow_v5 = true; } Continue(state) } NextProtocol { protocol_ids } => { #[cfg(feature = "ntpv5")] let selected = if state.allow_v5 { protocol_ids .iter() .copied() .find_map(ProtocolId::try_deserialize_v5) } else { protocol_ids .iter() .copied() .find_map(ProtocolId::try_deserialize) }; #[cfg(not(feature = "ntpv5"))] let selected = protocol_ids .iter() .copied() .find_map(ProtocolId::try_deserialize); match selected { None => Break(Err(NoValidProtocol)), Some(protocol) => { state.protocol = protocol; Continue(state) } } } AeadAlgorithm { algorithm_ids, .. } => { let selected = algorithm_ids .iter() .copied() .find_map(Algorithm::try_deserialize); match selected { None => Break(Err(NoValidAlgorithm)), Some(algorithm) => { state.algorithm = algorithm; Continue(state) } } } #[cfg(feature = "nts-pool")] NtpServerDeny { denied } => { state.denied_servers.push(denied); Continue(state) } other => { // just forward other records blindly state.records.push(other); Continue(state) } } } } /// Pool KE decoding records from the NTS KE #[derive(Debug, Default)] pub struct PoolToServerDecoder { decoder: NtsRecordDecoder, /// AEAD algorithm that the client is able to use and that we support /// it may be that the server and client supported algorithms have no /// intersection! 
algorithm: AeadAlgorithm, /// Protocol (NTP version) that is supported by both client and server protocol: ProtocolId, records: Vec, #[cfg(feature = "ntpv5")] allow_v5: bool, } #[derive(Debug)] pub struct PoolToServerData { pub algorithm: AeadAlgorithm, pub protocol: ProtocolId, pub records: Vec, } impl PoolToServerDecoder { pub fn step_with_slice( mut self, bytes: &[u8], ) -> ControlFlow, Self> { self.decoder.extend(bytes.iter().copied()); loop { match self.decoder.step() { Err(e) => return ControlFlow::Break(Err(e.into())), Ok(Some(record)) => self = self.step_with_record(record)?, Ok(None) => return ControlFlow::Continue(self), } } } #[inline(always)] fn step_with_record( self, record: NtsRecord, ) -> ControlFlow, Self> { use self::AeadAlgorithm as Algorithm; use ControlFlow::{Break, Continue}; use KeyExchangeError::*; use NtsRecord::*; let mut state = self; match &record { EndOfMessage => { state.records.push(EndOfMessage); let result = PoolToServerData { algorithm: state.algorithm, protocol: state.protocol, records: state.records, }; Break(Ok(result)) } Error { errorcode } => { // Break(Err(KeyExchangeError::from_error_code(*errorcode))) } Warning { warningcode } => { tracing::debug!(warningcode, "Received key exchange warning code"); state.records.push(record); Continue(state) } #[cfg(feature = "ntpv5")] DraftId { data } => { if data == crate::packet::v5::DRAFT_VERSION.as_bytes() { state.allow_v5 = true; } Continue(state) } NextProtocol { protocol_ids } => { #[cfg(feature = "ntpv5")] let selected = if state.allow_v5 { protocol_ids .iter() .copied() .find_map(ProtocolId::try_deserialize_v5) } else { protocol_ids .iter() .copied() .find_map(ProtocolId::try_deserialize) }; #[cfg(not(feature = "ntpv5"))] let selected = protocol_ids .iter() .copied() .find_map(ProtocolId::try_deserialize); state.records.push(record); match selected { None => Break(Err(NoValidProtocol)), Some(protocol) => { state.protocol = protocol; Continue(state) } } } AeadAlgorithm { 
algorithm_ids, .. } => { let selected = algorithm_ids .iter() .copied() .find_map(Algorithm::try_deserialize); state.records.push(record); match selected { None => Break(Err(NoValidAlgorithm)), Some(algorithm) => { state.algorithm = algorithm; Continue(state) } } } _other => { // just forward other records blindly state.records.push(record); Continue(state) } } } } ntp-proto-1.4.0/src/nts_record.rs000064400000000000000000003344171046102023000151030ustar 00000000000000use std::{ fmt::Display, io::{Read, Write}, ops::ControlFlow, sync::Arc, }; use crate::tls_utils::{self, ServerName}; use crate::{ cookiestash::CookieStash, io::{NonBlockingRead, NonBlockingWrite}, keyset::{DecodedServerCookie, KeySet}, packet::{AesSivCmac256, AesSivCmac512, Cipher}, source::{ProtocolVersion, SourceNtsData}, }; #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum NtpVersion { V4, #[cfg(feature = "ntpv5")] V5, } #[derive(Debug)] pub enum WriteError { Invalid, TooLong, } impl std::fmt::Display for WriteError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Invalid => f.write_str("Invalid NTS-KE record"), Self::TooLong => f.write_str("NTS-KE record too long"), } } } impl std::error::Error for WriteError {} impl NtsRecord { fn record_type(&self) -> u16 { match self { NtsRecord::EndOfMessage => 0, NtsRecord::NextProtocol { .. } => 1, NtsRecord::Error { .. } => 2, NtsRecord::Warning { .. } => 3, NtsRecord::AeadAlgorithm { .. } => 4, NtsRecord::NewCookie { .. } => 5, NtsRecord::Server { .. } => 6, NtsRecord::Port { .. } => 7, #[cfg(feature = "nts-pool")] NtsRecord::KeepAlive { .. } => 0x4000, #[cfg(feature = "nts-pool")] NtsRecord::SupportedAlgorithmList { .. } => 0x4001, #[cfg(feature = "nts-pool")] NtsRecord::FixedKeyRequest { .. } => 0x4002, #[cfg(feature = "nts-pool")] NtsRecord::NtpServerDeny { .. } => 0x4003, #[cfg(feature = "ntpv5")] NtsRecord::DraftId { .. } => 0x4008, NtsRecord::Unknown { record_type, .. 
} => record_type & !0x8000, } } fn is_critical(&self) -> bool { match self { NtsRecord::EndOfMessage => true, NtsRecord::NextProtocol { .. } => true, NtsRecord::Error { .. } => true, NtsRecord::Warning { .. } => true, NtsRecord::AeadAlgorithm { critical, .. } => *critical, NtsRecord::NewCookie { .. } => false, NtsRecord::Server { critical, .. } => *critical, NtsRecord::Port { critical, .. } => *critical, #[cfg(feature = "nts-pool")] NtsRecord::KeepAlive { .. } => false, #[cfg(feature = "nts-pool")] NtsRecord::SupportedAlgorithmList { .. } => true, #[cfg(feature = "nts-pool")] NtsRecord::FixedKeyRequest { .. } => true, #[cfg(feature = "nts-pool")] NtsRecord::NtpServerDeny { .. } => false, #[cfg(feature = "ntpv5")] NtsRecord::DraftId { .. } => false, NtsRecord::Unknown { critical, .. } => *critical, } } fn validate(&self) -> Result<(), WriteError> { match self { NtsRecord::Unknown { record_type, data, .. } => { if *record_type & 0x8000 != 0 { return Err(WriteError::Invalid); } if data.len() > u16::MAX as usize { return Err(WriteError::TooLong); } } NtsRecord::NextProtocol { protocol_ids } => { if protocol_ids.len() >= (u16::MAX as usize) / 2 { return Err(WriteError::TooLong); } } NtsRecord::AeadAlgorithm { algorithm_ids, .. } => { if algorithm_ids.len() >= (u16::MAX as usize) / 2 { return Err(WriteError::TooLong); } } NtsRecord::NewCookie { cookie_data } => { if cookie_data.len() > u16::MAX as usize { return Err(WriteError::TooLong); } } NtsRecord::Server { name, .. 
} => { if name.as_bytes().len() >= (u16::MAX as usize) { return Err(WriteError::TooLong); } } _ => {} } Ok(()) } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum NtsRecord { EndOfMessage, NextProtocol { protocol_ids: Vec, }, Error { errorcode: u16, }, Warning { warningcode: u16, }, AeadAlgorithm { critical: bool, algorithm_ids: Vec, }, NewCookie { cookie_data: Vec, }, Server { critical: bool, name: String, }, Port { critical: bool, port: u16, }, Unknown { record_type: u16, critical: bool, data: Vec, }, #[cfg(feature = "ntpv5")] DraftId { data: Vec, }, #[cfg(feature = "nts-pool")] KeepAlive, #[cfg(feature = "nts-pool")] SupportedAlgorithmList { supported_algorithms: Vec<(u16, u16)>, }, #[cfg(feature = "nts-pool")] FixedKeyRequest { c2s: Vec, s2c: Vec, }, #[cfg(feature = "nts-pool")] NtpServerDeny { denied: String, }, } fn read_u16_be(reader: &mut impl NonBlockingRead) -> std::io::Result { let mut bytes = [0, 0]; reader.read_exact(&mut bytes)?; Ok(u16::from_be_bytes(bytes)) } fn read_u16s_be(reader: &mut impl NonBlockingRead, length: usize) -> std::io::Result> { (0..length).map(|_| read_u16_be(reader)).collect() } #[cfg(feature = "nts-pool")] fn read_u16_tuples_be( reader: &mut impl NonBlockingRead, length: usize, ) -> std::io::Result> { (0..length) .map(|_| Ok((read_u16_be(reader)?, read_u16_be(reader)?))) .collect() } fn read_bytes_exact(reader: &mut impl NonBlockingRead, length: usize) -> std::io::Result> { let mut output = vec![0; length]; reader.read_exact(&mut output)?; Ok(output) } impl NtsRecord { pub const UNRECOGNIZED_CRITICAL_RECORD: u16 = 0; pub const BAD_REQUEST: u16 = 1; pub const INTERNAL_SERVER_ERROR: u16 = 2; #[allow(unused_variables)] pub fn client_key_exchange_records( ntp_version: Option, denied_servers: impl IntoIterator, ) -> Box<[NtsRecord]> { let mut base = vec![ #[cfg(feature = "ntpv5")] NtsRecord::DraftId { data: crate::packet::v5::DRAFT_VERSION.as_bytes().into(), }, #[cfg(feature = "ntpv5")] match ntp_version { None => 
NtsRecord::NextProtocol { protocol_ids: vec![0x8001, 0], }, Some(NtpVersion::V4) => NtsRecord::NextProtocol { protocol_ids: vec![0], }, Some(NtpVersion::V5) => NtsRecord::NextProtocol { protocol_ids: vec![0x8001], }, }, #[cfg(not(feature = "ntpv5"))] NtsRecord::NextProtocol { protocol_ids: vec![ #[cfg(feature = "ntpv5")] 0x8001, 0, ], }, NtsRecord::AeadAlgorithm { critical: false, algorithm_ids: AeadAlgorithm::IN_ORDER_OF_PREFERENCE .iter() .map(|algorithm| *algorithm as u16) .collect(), }, ]; #[cfg(feature = "nts-pool")] base.extend( denied_servers .into_iter() .map(|server| NtsRecord::NtpServerDeny { denied: server }), ); base.push(NtsRecord::EndOfMessage); base.into_boxed_slice() } #[cfg(feature = "nts-pool")] pub fn client_key_exchange_records_fixed( c2s: Vec, s2c: Vec, ) -> [NtsRecord; if cfg!(feature = "ntpv5") { 5 } else { 4 }] { [ #[cfg(feature = "ntpv5")] NtsRecord::DraftId { data: crate::packet::v5::DRAFT_VERSION.as_bytes().into(), }, NtsRecord::NextProtocol { protocol_ids: vec![ #[cfg(feature = "ntpv5")] 0x8001, 0, ], }, NtsRecord::AeadAlgorithm { critical: false, algorithm_ids: AeadAlgorithm::IN_ORDER_OF_PREFERENCE .iter() .map(|algorithm| *algorithm as u16) .collect(), }, #[cfg(feature = "nts-pool")] NtsRecord::FixedKeyRequest { c2s, s2c }, NtsRecord::EndOfMessage, ] } fn server_key_exchange_records( protocol: ProtocolId, algorithm: AeadAlgorithm, keyset: &KeySet, keys: NtsKeys, ntp_port: Option, ntp_server: Option, #[cfg(feature = "nts-pool")] send_supported_algorithms: bool, ) -> Box<[NtsRecord]> { let cookie = DecodedServerCookie { algorithm, s2c: keys.s2c, c2s: keys.c2s, }; let next_cookie = || -> NtsRecord { NtsRecord::NewCookie { cookie_data: keyset.encode_cookie(&cookie), } }; let mut response = Vec::new(); //Probably, a NTS request should not send this record while attempting //to negotiate a "standard key exchange" at the same time. The current spec //does not outright say this, however, so we will add it whenever requested. 
#[cfg(feature = "nts-pool")] if send_supported_algorithms { response.push(NtsRecord::SupportedAlgorithmList { supported_algorithms: crate::nts_record::AeadAlgorithm::IN_ORDER_OF_PREFERENCE .iter() .map(|&algo| (algo as u16, algo.key_size())) .collect(), }) } if let Some(ntp_port) = ntp_port { response.push(NtsRecord::Port { critical: ntp_port != 123, port: ntp_port, }); } if let Some(ntp_server) = ntp_server { response.push(NtsRecord::Server { critical: true, name: ntp_server, }); } response.extend(vec![ NtsRecord::NextProtocol { protocol_ids: vec![protocol as u16], }, NtsRecord::AeadAlgorithm { critical: false, algorithm_ids: vec![algorithm as u16], }, next_cookie(), next_cookie(), next_cookie(), next_cookie(), next_cookie(), next_cookie(), next_cookie(), next_cookie(), NtsRecord::EndOfMessage, ]); response.into_boxed_slice() } pub fn read(reader: &mut impl NonBlockingRead) -> std::io::Result { let raw_record_type = read_u16_be(reader)?; let critical = raw_record_type & 0x8000 != 0; let record_type = raw_record_type & !0x8000; let record_len = read_u16_be(reader)? 
as usize; Ok(match record_type { 0 if record_len == 0 && critical => NtsRecord::EndOfMessage, 1 if record_len % 2 == 0 && critical => { let n_protocols = record_len / 2; let protocol_ids = read_u16s_be(reader, n_protocols)?; NtsRecord::NextProtocol { protocol_ids } } 2 if record_len == 2 && critical => NtsRecord::Error { errorcode: read_u16_be(reader)?, }, 3 if record_len == 2 && critical => NtsRecord::Warning { warningcode: read_u16_be(reader)?, }, 4 if record_len % 2 == 0 => { let n_algorithms = record_len / 2; let algorithm_ids = read_u16s_be(reader, n_algorithms)?; NtsRecord::AeadAlgorithm { critical, algorithm_ids, } } 5 if !critical => { let cookie_data = read_bytes_exact(reader, record_len)?; NtsRecord::NewCookie { cookie_data } } 6 => { // NOTE: the string data should be ascii (not utf8) but we don't enforce that here let str_data = read_bytes_exact(reader, record_len)?; match String::from_utf8(str_data) { Ok(name) => NtsRecord::Server { critical, name }, Err(e) => NtsRecord::Unknown { record_type, critical, data: e.into_bytes(), }, } } 7 if record_len == 2 => NtsRecord::Port { critical, port: read_u16_be(reader)?, }, #[cfg(feature = "nts-pool")] 0x4000 if !critical => NtsRecord::KeepAlive, #[cfg(feature = "nts-pool")] 0x4001 if record_len % 4 == 0 && critical => { let n_algorithms = record_len / 4; // 4 bytes per element let supported_algorithms = read_u16_tuples_be(reader, n_algorithms)?; NtsRecord::SupportedAlgorithmList { supported_algorithms, } } #[cfg(feature = "nts-pool")] 0x4002 if record_len % 2 == 0 && critical => { let mut c2s = vec![0; record_len / 2]; let mut s2c = vec![0; record_len / 2]; reader.read_exact(&mut c2s)?; reader.read_exact(&mut s2c)?; NtsRecord::FixedKeyRequest { c2s, s2c } } #[cfg(feature = "nts-pool")] 0x4003 => { // NOTE: the string data should be ascii (not utf8) but we don't enforce that here let str_data = read_bytes_exact(reader, record_len)?; match String::from_utf8(str_data) { Ok(denied) => NtsRecord::NtpServerDeny { 
denied }, Err(e) => NtsRecord::Unknown { record_type, critical, data: e.into_bytes(), }, } } #[cfg(feature = "ntpv5")] 0x4008 => NtsRecord::DraftId { data: read_bytes_exact(reader, record_len)?, }, _ => NtsRecord::Unknown { record_type, critical, data: read_bytes_exact(reader, record_len)?, }, }) } pub fn write(&self, mut writer: impl NonBlockingWrite) -> std::io::Result<()> { // error out early when the record is invalid if let Err(e) = self.validate() { return Err(std::io::Error::new(std::io::ErrorKind::Other, e)); } // all messages start with the record type let record_type = self.record_type() | ((self.is_critical() as u16) << 15); writer.write_all(&record_type.to_be_bytes())?; let size_of_u16 = std::mem::size_of::() as u16; match self { NtsRecord::EndOfMessage => { writer.write_all(&0_u16.to_be_bytes())?; } NtsRecord::Unknown { data, .. } => { writer.write_all(&(data.len() as u16).to_be_bytes())?; writer.write_all(data)?; } NtsRecord::NextProtocol { protocol_ids } => { let length = size_of_u16 * protocol_ids.len() as u16; writer.write_all(&length.to_be_bytes())?; for id in protocol_ids { writer.write_all(&id.to_be_bytes())?; } } NtsRecord::Error { errorcode } => { writer.write_all(&size_of_u16.to_be_bytes())?; writer.write_all(&errorcode.to_be_bytes())?; } NtsRecord::Warning { warningcode } => { writer.write_all(&size_of_u16.to_be_bytes())?; writer.write_all(&warningcode.to_be_bytes())?; } NtsRecord::AeadAlgorithm { algorithm_ids, .. } => { let length = size_of_u16 * algorithm_ids.len() as u16; writer.write_all(&length.to_be_bytes())?; for id in algorithm_ids { writer.write_all(&id.to_be_bytes())?; } } NtsRecord::NewCookie { cookie_data } => { let length = cookie_data.len() as u16; writer.write_all(&length.to_be_bytes())?; writer.write_all(cookie_data)?; } NtsRecord::Server { name, .. 
} => { // NOTE: the server name should be ascii #[cfg(not(feature = "__internal-fuzz"))] debug_assert!(name.is_ascii()); let length = name.len() as u16; writer.write_all(&length.to_be_bytes())?; writer.write_all(name.as_bytes())?; } NtsRecord::Port { port, .. } => { writer.write_all(&size_of_u16.to_be_bytes())?; writer.write_all(&port.to_be_bytes())?; } #[cfg(feature = "nts-pool")] NtsRecord::KeepAlive => { // nothing to encode; there is no payload let length = 0u16; writer.write_all(&length.to_be_bytes())?; } #[cfg(feature = "nts-pool")] NtsRecord::SupportedAlgorithmList { supported_algorithms, } => { let length = size_of_u16 * 2 * supported_algorithms.len() as u16; writer.write_all(&length.to_be_bytes())?; for (algorithm_id, key_length) in supported_algorithms { writer.write_all(&algorithm_id.to_be_bytes())?; writer.write_all(&key_length.to_be_bytes())?; } } #[cfg(feature = "nts-pool")] NtsRecord::FixedKeyRequest { c2s, s2c } => { debug_assert_eq!(c2s.len(), s2c.len()); let length = (c2s.len() + s2c.len()) as u16; writer.write_all(&length.to_be_bytes())?; writer.write_all(c2s)?; writer.write_all(s2c)?; } #[cfg(feature = "nts-pool")] NtsRecord::NtpServerDeny { denied: name } => { // NOTE: the server name should be ascii #[cfg(not(feature = "__internal-fuzz"))] debug_assert!(name.is_ascii()); let length = name.len() as u16; writer.write_all(&length.to_be_bytes())?; writer.write_all(name.as_bytes())?; } #[cfg(feature = "ntpv5")] NtsRecord::DraftId { data } => { writer.write_all(&(data.len() as u16).to_be_bytes())?; writer.write_all(data)?; } } Ok(()) } pub fn decoder() -> NtsRecordDecoder { NtsRecordDecoder { bytes: vec![] } } } #[cfg(feature = "__internal-fuzz")] impl<'a> arbitrary::Arbitrary<'a> for NtsRecord { fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { let record = u16::arbitrary(u)?; let critical = record & 0x8000 != 0; let record_type = record & !0x8000; use NtsRecord::*; Ok(match record_type { 0 => EndOfMessage, 1 => NextProtocol 
{ protocol_ids: u.arbitrary()?, }, 2 => Error { errorcode: u.arbitrary()?, }, 3 => Warning { warningcode: u.arbitrary()?, }, 4 => AeadAlgorithm { critical, algorithm_ids: u.arbitrary()?, }, 5 => NewCookie { cookie_data: u.arbitrary()?, }, 6 => Server { critical, name: u.arbitrary()?, }, 7 => Port { critical, port: u.arbitrary()?, }, _ => NtsRecord::Unknown { record_type, critical, data: u.arbitrary()?, }, }) } } #[derive(Debug, Clone, Default)] pub struct NtsRecordDecoder { bytes: Vec, } impl Extend for NtsRecordDecoder { fn extend>(&mut self, iter: T) { self.bytes.extend(iter); } } impl NtsRecordDecoder { /// the size of the KE packet header: /// /// - 2 bytes for the record type + critical flag /// - 2 bytes for the record length const HEADER_BYTES: usize = 4; /// Try to decode the next record. Returns None when there are not enough bytes pub fn step(&mut self) -> std::io::Result> { if self.bytes.len() < Self::HEADER_BYTES { return Ok(None); } let record_len = u16::from_be_bytes([self.bytes[2], self.bytes[3]]); let message_len = Self::HEADER_BYTES + record_len as usize; if self.bytes.len() >= message_len { let record = NtsRecord::read(&mut self.bytes.as_slice())?; // remove the first `message_len` bytes from the buffer self.bytes.copy_within(message_len.., 0); self.bytes.truncate(self.bytes.len() - message_len); Ok(Some(record)) } else { Ok(None) } } pub fn new() -> Self { Self::default() } } #[derive(Debug)] pub enum KeyExchangeError { UnrecognizedCriticalRecord, BadRequest, InternalServerError, UnknownErrorCode(u16), BadResponse, NoValidProtocol, NoValidAlgorithm, InvalidFixedKeyLength, NoCookies, Io(std::io::Error), Tls(tls_utils::Error), Certificate(tls_utils::Error), DnsName(tls_utils::InvalidDnsNameError), IncompleteResponse, } impl Display for KeyExchangeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::UnrecognizedCriticalRecord => { write!(f, "Unrecognized record is marked as critical") } Self::BadRequest => 
write!(f, "Remote: Bad request"), Self::InternalServerError => write!(f, "Remote: Internal server error"), Self::UnknownErrorCode(e) => write!(f, "Remote: Error with unknown code {e}"), Self::BadResponse => write!(f, "The server response is invalid"), Self::NoValidProtocol => write!( f, "No continuation protocol supported by both us and server" ), Self::NoValidAlgorithm => { write!(f, "No encryption algorithm supported by both us and server") } Self::InvalidFixedKeyLength => write!( f, "The length of a fixed key does not match the algorithm used" ), Self::NoCookies => write!(f, "Missing cookies"), Self::Io(e) => write!(f, "{e}"), Self::Tls(e) => write!(f, "{e}"), Self::Certificate(e) => write!(f, "{e}"), Self::DnsName(e) => write!(f, "{e}"), Self::IncompleteResponse => write!(f, "Incomplete response"), } } } impl From for KeyExchangeError { fn from(value: std::io::Error) -> Self { Self::Io(value) } } impl From for KeyExchangeError { fn from(value: crate::tls_utils::Error) -> Self { Self::Tls(value) } } impl From for KeyExchangeError { fn from(value: tls_utils::InvalidDnsNameError) -> Self { Self::DnsName(value) } } impl std::error::Error for KeyExchangeError {} impl KeyExchangeError { pub(crate) fn from_error_code(error_code: u16) -> Self { match error_code { 0 => Self::UnrecognizedCriticalRecord, 1 => Self::BadRequest, 2 => Self::InternalServerError, _ => Self::UnknownErrorCode(error_code), } } pub fn to_error_code(&self) -> u16 { use KeyExchangeError::*; match self { UnrecognizedCriticalRecord => NtsRecord::UNRECOGNIZED_CRITICAL_RECORD, BadRequest => NtsRecord::BAD_REQUEST, InternalServerError | Io(_) => NtsRecord::INTERNAL_SERVER_ERROR, UnknownErrorCode(_) | BadResponse | NoValidProtocol | NoValidAlgorithm | InvalidFixedKeyLength | NoCookies | Tls(_) | Certificate(_) | DnsName(_) | IncompleteResponse => NtsRecord::BAD_REQUEST, } } } /// From https://www.rfc-editor.org/rfc/rfc8915.html#name-network-time-security-next- #[derive(Debug, PartialEq, Eq, Clone, Copy, 
Default)] #[repr(u16)] pub enum ProtocolId { #[default] NtpV4 = 0, #[cfg(feature = "ntpv5")] NtpV5 = 0x8001, } impl ProtocolId { const IN_ORDER_OF_PREFERENCE: &'static [Self] = &[ #[cfg(feature = "ntpv5")] Self::NtpV5, Self::NtpV4, ]; pub const fn try_deserialize(number: u16) -> Option { match number { 0 => Some(Self::NtpV4), _ => None, } } #[cfg(feature = "ntpv5")] pub const fn try_deserialize_v5(number: u16) -> Option { match number { 0 => Some(Self::NtpV4), 0x8001 => Some(Self::NtpV5), _ => None, } } } /// From https://www.iana.org/assignments/aead-parameters/aead-parameters.xhtml #[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] #[repr(u16)] pub enum AeadAlgorithm { #[default] AeadAesSivCmac256 = 15, AeadAesSivCmac512 = 17, } impl AeadAlgorithm { // per https://www.rfc-editor.org/rfc/rfc8915.html#section-5.1 pub const fn c2s_context(self, protocol: ProtocolId) -> [u8; 5] { // The final octet SHALL be 0x00 for the C2S key [ (protocol as u16 >> 8) as u8, protocol as u8, (self as u16 >> 8) as u8, self as u8, 0, ] } // per https://www.rfc-editor.org/rfc/rfc8915.html#section-5.1 pub const fn s2c_context(self, protocol: ProtocolId) -> [u8; 5] { // The final octet SHALL be 0x01 for the S2C key [ (protocol as u16 >> 8) as u8, protocol as u8, (self as u16 >> 8) as u8, self as u8, 1, ] } pub const fn try_deserialize(number: u16) -> Option { match number { 15 => Some(AeadAlgorithm::AeadAesSivCmac256), 17 => Some(AeadAlgorithm::AeadAesSivCmac512), _ => None, } } const IN_ORDER_OF_PREFERENCE: &'static [Self] = &[Self::AeadAesSivCmac512, Self::AeadAesSivCmac256]; pub(crate) fn extract_nts_keys( &self, protocol: ProtocolId, tls_connection: &tls_utils::ConnectionCommon, ) -> Result { match self { AeadAlgorithm::AeadAesSivCmac256 => { let c2s = extract_nts_key(tls_connection, self.c2s_context(protocol))?; let s2c = extract_nts_key(tls_connection, self.s2c_context(protocol))?; let c2s = Box::new(AesSivCmac256::new(c2s)); let s2c = Box::new(AesSivCmac256::new(s2c)); 
Ok(NtsKeys { c2s, s2c }) } AeadAlgorithm::AeadAesSivCmac512 => { let c2s = extract_nts_key(tls_connection, self.c2s_context(protocol))?; let s2c = extract_nts_key(tls_connection, self.s2c_context(protocol))?; let c2s = Box::new(AesSivCmac512::new(c2s)); let s2c = Box::new(AesSivCmac512::new(s2c)); Ok(NtsKeys { c2s, s2c }) } } } #[cfg(feature = "nts-pool")] fn try_into_nts_keys(&self, RequestedKeys { c2s, s2c }: &RequestedKeys) -> Option { match self { AeadAlgorithm::AeadAesSivCmac256 => { let c2s = Box::new(AesSivCmac256::from_key_bytes(c2s).ok()?); let s2c = Box::new(AesSivCmac256::from_key_bytes(s2c).ok()?); Some(NtsKeys { c2s, s2c }) } AeadAlgorithm::AeadAesSivCmac512 => { let c2s = Box::new(AesSivCmac512::from_key_bytes(c2s).ok()?); let s2c = Box::new(AesSivCmac512::from_key_bytes(s2c).ok()?); Some(NtsKeys { c2s, s2c }) } } } #[cfg(feature = "nts-pool")] fn key_size(&self) -> u16 { match self { AeadAlgorithm::AeadAesSivCmac256 => AesSivCmac256::key_size() as u16, AeadAlgorithm::AeadAesSivCmac512 => AesSivCmac512::key_size() as u16, } } } pub struct NtsKeys { c2s: Box, s2c: Box, } impl NtsKeys { #[cfg(feature = "nts-pool")] pub fn as_fixed_key_request(&self) -> NtsRecord { NtsRecord::FixedKeyRequest { c2s: self.c2s.key_bytes().to_vec(), s2c: self.s2c.key_bytes().to_vec(), } } } impl std::fmt::Debug for NtsKeys { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("NtsKeys") .field("c2s", &"") .field("s2c", &"") .finish() } } fn extract_nts_key, ConnectionData>( tls_connection: &tls_utils::ConnectionCommon, context: [u8; 5], ) -> Result { let mut key = T::default(); tls_connection.export_keying_material( &mut key, b"EXPORTER-network-time-security", Some(context.as_slice()), )?; Ok(key) } #[derive(Debug, PartialEq, Eq)] pub struct PartialKeyExchangeData { remote: Option, port: Option, protocol: ProtocolId, algorithm: AeadAlgorithm, cookies: CookieStash, #[cfg(feature = "nts-pool")] supported_algorithms: Option>, } #[derive(Debug, 
Default)] pub struct KeyExchangeResultDecoder { decoder: NtsRecordDecoder, remote: Option, port: Option, algorithm: Option, protocol: Option, cookies: CookieStash, #[cfg(feature = "nts-pool")] keep_alive: bool, #[cfg(feature = "nts-pool")] supported_algorithms: Option>, } impl KeyExchangeResultDecoder { pub fn step_with_slice( mut self, bytes: &[u8], ) -> ControlFlow, Self> { self.decoder.extend(bytes.iter().copied()); loop { match self.decoder.step() { Err(e) => return ControlFlow::Break(Err(e.into())), Ok(Some(record)) => self = self.step_with_record(record)?, Ok(None) => return ControlFlow::Continue(self), } } } #[inline(always)] fn step_with_record( self, record: NtsRecord, ) -> ControlFlow, Self> { use self::AeadAlgorithm as Algorithm; use ControlFlow::{Break, Continue}; use KeyExchangeError::*; use NtsRecord::*; let mut state = self; match record { EndOfMessage => { let Some(protocol) = state.protocol else { return ControlFlow::Break(Err(KeyExchangeError::NoValidProtocol)); }; // the spec notes // // > If the NTS Next Protocol Negotiation record offers Protocol ID 0 (for NTPv4), // > then this record MUST be included exactly once. Other protocols MAY require it as well. // // but we only support Protocol ID 0 (and assume ntpv5 behaves like ntpv4 in this regard) let Some(algorithm) = state.algorithm else { return ControlFlow::Break(Err(KeyExchangeError::NoValidAlgorithm)); }; if state.cookies.is_empty() { Break(Err(KeyExchangeError::NoCookies)) } else { Break(Ok(PartialKeyExchangeData { remote: state.remote, port: state.port, protocol, algorithm, cookies: state.cookies, #[cfg(feature = "nts-pool")] supported_algorithms: state.supported_algorithms, })) } } #[cfg(feature = "ntpv5")] DraftId { .. } => { tracing::debug!("Unexpected draft id"); Continue(state) } NewCookie { cookie_data } => { state.cookies.store(cookie_data); Continue(state) } Server { name, .. } => { state.remote = Some(name); Continue(state) } Port { port, .. 
} => { state.port = Some(port); Continue(state) } Error { errorcode } => { // Break(Err(KeyExchangeError::from_error_code(errorcode))) } Warning { warningcode } => { tracing::warn!(warningcode, "Received key exchange warning code"); Continue(state) } NextProtocol { protocol_ids } => { let selected = ProtocolId::IN_ORDER_OF_PREFERENCE .iter() .find_map(|proto| protocol_ids.contains(&(*proto as u16)).then_some(*proto)); match selected { None => Break(Err(NoValidProtocol)), Some(protocol) => { // The NTS Next Protocol Negotiation record [..] MUST occur exactly once in every NTS-KE request and response. match state.protocol { None => { state.protocol = Some(protocol); Continue(state) } Some(_) => Break(Err(KeyExchangeError::BadResponse)), } } } } AeadAlgorithm { algorithm_ids, .. } => { // it MUST include at most one let algorithm_id = match algorithm_ids[..] { [] => return Break(Err(NoValidAlgorithm)), [algorithm_id] => algorithm_id, _ => return Break(Err(BadResponse)), }; let selected = Algorithm::IN_ORDER_OF_PREFERENCE .iter() .find(|algo| (algorithm_id == (**algo as u16))); match selected { None => Break(Err(NoValidAlgorithm)), Some(algorithm) => { // for the protocol ids we support, the AeadAlgorithm record must be present match state.algorithm { None => { state.algorithm = Some(*algorithm); Continue(state) } Some(_) => Break(Err(KeyExchangeError::BadResponse)), } } } } Unknown { critical, .. 
} => { if critical { Break(Err(KeyExchangeError::UnrecognizedCriticalRecord)) } else { Continue(state) } } #[cfg(feature = "nts-pool")] KeepAlive => { state.keep_alive = true; Continue(state) } #[cfg(feature = "nts-pool")] SupportedAlgorithmList { supported_algorithms, } => { use self::AeadAlgorithm; state.supported_algorithms = Some( supported_algorithms .into_iter() .filter_map(|(aead_protocol_id, key_length)| { let aead_algorithm = AeadAlgorithm::try_deserialize(aead_protocol_id)?; Some((aead_algorithm, key_length)) }) .collect::>() .into_boxed_slice(), ); Continue(state) } #[cfg(feature = "nts-pool")] FixedKeyRequest { .. } => { // a client should never receive a FixedKeyRequest tracing::warn!("Unexpected fixed key request"); Continue(state) } #[cfg(feature = "nts-pool")] NtpServerDeny { .. } => { // a client should never receive a NtpServerDeny tracing::warn!("Unexpected ntp server deny"); Continue(state) } } } fn new() -> Self { Self::default() } } #[derive(Debug)] pub struct KeyExchangeResult { pub remote: String, pub port: u16, pub nts: Box, pub protocol_version: ProtocolVersion, #[cfg(feature = "nts-pool")] pub algorithms_reported_by_server: Option>, } pub struct KeyExchangeClient { tls_connection: tls_utils::ClientConnection, decoder: KeyExchangeResultDecoder, server_name: String, } impl KeyExchangeClient { const NTP_DEFAULT_PORT: u16 = 123; pub fn wants_read(&self) -> bool { self.tls_connection.wants_read() } pub fn read_socket(&mut self, rd: &mut dyn Read) -> std::io::Result { self.tls_connection.read_tls(rd) } pub fn wants_write(&self) -> bool { self.tls_connection.wants_write() } pub fn write_socket(&mut self, wr: &mut dyn Write) -> std::io::Result { self.tls_connection.write_tls(wr) } pub fn progress(mut self) -> ControlFlow, Self> { // Move any received data from tls to decoder let mut buf = [0; 128]; loop { if let Err(e) = self.tls_connection.process_new_packets() { return ControlFlow::Break(Err(e.into())); } match 
self.tls_connection.reader().read(&mut buf) { Ok(0) => return ControlFlow::Break(Err(KeyExchangeError::IncompleteResponse)), Ok(n) => { self.decoder = match self.decoder.step_with_slice(&buf[..n]) { ControlFlow::Continue(decoder) => decoder, ControlFlow::Break(Ok(result)) => { let algorithm = result.algorithm; let protocol = result.protocol; tracing::debug!(?algorithm, "selected AEAD algorithm"); let keys = match algorithm .extract_nts_keys(protocol, &self.tls_connection) { Ok(keys) => keys, Err(e) => return ControlFlow::Break(Err(KeyExchangeError::Tls(e))), }; let nts = Box::new(SourceNtsData { cookies: result.cookies, c2s: keys.c2s, s2c: keys.s2c, }); return ControlFlow::Break(Ok(KeyExchangeResult { remote: result.remote.unwrap_or(self.server_name), protocol_version: match protocol { ProtocolId::NtpV4 => ProtocolVersion::V4, #[cfg(feature = "ntpv5")] ProtocolId::NtpV5 => ProtocolVersion::V5, }, port: result.port.unwrap_or(Self::NTP_DEFAULT_PORT), nts, #[cfg(feature = "nts-pool")] algorithms_reported_by_server: result.supported_algorithms, })); } ControlFlow::Break(Err(error)) => return ControlFlow::Break(Err(error)), } } Err(e) => match e.kind() { std::io::ErrorKind::WouldBlock => return ControlFlow::Continue(self), _ => return ControlFlow::Break(Err(e.into())), }, } } } // should only be used in tests! 
fn new_without_tls_write( server_name: String, mut tls_config: tls_utils::ClientConfig, ) -> Result { // Ensure we send only ntske/1 as alpn tls_config.alpn_protocols.clear(); tls_config.alpn_protocols.push(b"ntske/1".to_vec()); // TLS only works when the server name is a DNS name; an IP address does not work let tls_connection = tls_utils::ClientConnection::new( Arc::new(tls_config), ServerName::try_from(&server_name as &str)?.to_owned(), )?; Ok(KeyExchangeClient { tls_connection, decoder: KeyExchangeResultDecoder::new(), server_name, }) } pub fn new( server_name: String, tls_config: tls_utils::ClientConfig, ntp_version: Option, denied_servers: impl IntoIterator, ) -> Result { let mut client = Self::new_without_tls_write(server_name, tls_config)?; // Make the request immediately (note, this will only go out to the wire via the write functions above) // We use an intermediary buffer to ensure that all records are sent at once. // This should not be needed, but works around issues in some NTS-ke server implementations let mut buffer = Vec::with_capacity(1024); for record in NtsRecord::client_key_exchange_records(ntp_version, denied_servers).iter() { record.write(&mut buffer)?; } client.tls_connection.writer().write_all(&buffer)?; Ok(client) } } #[derive(Debug, Default)] struct KeyExchangeServerDecoder { decoder: NtsRecordDecoder, /// AEAD algorithm that the client is able to use and that we support /// it may be that the server and client supported algorithms have no /// intersection! 
algorithm: Option, /// Protocol (NTP version) that is supported by both client and server protocol: Option, #[cfg(feature = "ntpv5")] allow_v5: bool, #[cfg(feature = "nts-pool")] keep_alive: Option, #[cfg(feature = "nts-pool")] requested_supported_algorithms: bool, #[cfg(feature = "nts-pool")] fixed_key_request: Option, } #[cfg(feature = "nts-pool")] #[derive(Debug, PartialEq, Eq)] struct RequestedKeys { c2s: Vec, s2c: Vec, } #[derive(Debug, PartialEq, Eq)] struct ServerKeyExchangeData { algorithm: AeadAlgorithm, protocol: ProtocolId, /// By default, perform key extraction to acquire the c2s and s2c keys; otherwise, use the fixed keys. #[cfg(feature = "nts-pool")] fixed_keys: Option, #[cfg(feature = "nts-pool")] requested_supported_algorithms: bool, } impl KeyExchangeServerDecoder { pub fn step_with_slice( mut self, bytes: &[u8], ) -> ControlFlow, Self> { self.decoder.extend(bytes.iter().copied()); loop { match self.decoder.step() { Err(e) => return ControlFlow::Break(Err(e.into())), Ok(Some(record)) => self = self.step_with_record(record)?, Ok(None) => return ControlFlow::Continue(self), } } } fn validate(self) -> Result { let Some(protocol) = self.protocol else { // The NTS Next Protocol Negotiation record [..] MUST occur exactly once in every NTS-KE request and response. 
return Err(KeyExchangeError::NoValidProtocol);
        };

        let Some(algorithm) = self.algorithm else {
            // for the protocol ids we support, the AeadAlgorithm record must be present
            return Err(KeyExchangeError::NoValidAlgorithm);
        };

        let result = ServerKeyExchangeData {
            algorithm,
            protocol,
            #[cfg(feature = "nts-pool")]
            fixed_keys: self.fixed_key_request,
            #[cfg(feature = "nts-pool")]
            requested_supported_algorithms: self.requested_supported_algorithms,
        };

        Ok(result)
    }

    /// Finishes decoding after the end-of-message record.
    ///
    /// When the client asked for the supported algorithm list, a missing
    /// protocol or algorithm is replaced by its default value instead of being
    /// reported as an error; otherwise the regular validation applies.
    #[cfg(feature = "nts-pool")]
    fn done(self) -> Result<ServerKeyExchangeData, KeyExchangeError> {
        if self.requested_supported_algorithms {
            let protocol = self.protocol.unwrap_or_default();
            let algorithm = self.algorithm.unwrap_or_default();

            let result = ServerKeyExchangeData {
                algorithm,
                protocol,
                #[cfg(feature = "nts-pool")]
                fixed_keys: self.fixed_key_request,
                #[cfg(feature = "nts-pool")]
                requested_supported_algorithms: self.requested_supported_algorithms,
            };

            Ok(result)
        } else {
            self.validate()
        }
    }

    /// Finishes decoding after the end-of-message record.
    #[cfg(not(feature = "nts-pool"))]
    fn done(self) -> Result<ServerKeyExchangeData, KeyExchangeError> {
        self.validate()
    }

    /// Applies a single decoded record to the decoder state.
    ///
    /// Returns `Continue` with the updated state while the request is still
    /// being received, and `Break` when the request is complete
    /// (`EndOfMessage`) or must be rejected.
    #[inline(always)]
    fn step_with_record(
        self,
        record: NtsRecord,
    ) -> ControlFlow<Result<ServerKeyExchangeData, KeyExchangeError>, Self> {
        use self::AeadAlgorithm as Algorithm;
        use ControlFlow::{Break, Continue};
        use KeyExchangeError::*;
        use NtsRecord::*;

        let mut state = self;

        match record {
            EndOfMessage => {
                // perform a final validation step: did we receive everything that we should?
                Break(state.done())
            }
            #[cfg(feature = "ntpv5")]
            DraftId { data } => {
                if data == crate::packet::v5::DRAFT_VERSION.as_bytes() {
                    state.allow_v5 = true;
                }
                Continue(state)
            }
            NewCookie { .. } => {
                // > Clients MUST NOT send records of this type
                //
                // TODO should we actively error when a client does?
                Continue(state)
            }
            Server { name: _, .. } => {
                // > When this record is sent by the client, it indicates that the client wishes to associate with the specified NTP
                // > server. The NTS-KE server MAY incorporate this request when deciding which NTPv4 Server Negotiation
                // > records to respond with, but honoring the client's preference is OPTIONAL. The client MUST NOT send more
                // > than one record of this type.
                //
                // we ignore the client's preference
                Continue(state)
            }
            Port { port: _, .. } => {
                // > When this record is sent by the client in conjunction with a NTPv4 Server Negotiation record, it indicates that
                // > the client wishes to associate with the NTP server at the specified port. The NTS-KE server MAY incorporate this
                // > request when deciding what NTPv4 Server Negotiation and NTPv4 Port Negotiation records to respond with,
                // > but honoring the client's preference is OPTIONAL
                //
                // we ignore the client's preference
                Continue(state)
            }
            Error { errorcode } => {
                // an error record ends the exchange with the error mapped from its code
                Break(Err(KeyExchangeError::from_error_code(errorcode)))
            }
            Warning { warningcode } => {
                tracing::debug!(warningcode, "Received key exchange warning code");

                Continue(state)
            }
            NextProtocol { protocol_ids } => {
                // pick the first protocol id in the client's list that we can deserialize
                #[cfg(feature = "ntpv5")]
                let selected = if state.allow_v5 {
                    protocol_ids
                        .iter()
                        .copied()
                        .find_map(ProtocolId::try_deserialize_v5)
                } else {
                    protocol_ids
                        .iter()
                        .copied()
                        .find_map(ProtocolId::try_deserialize)
                };

                #[cfg(not(feature = "ntpv5"))]
                let selected = protocol_ids
                    .iter()
                    .copied()
                    .find_map(ProtocolId::try_deserialize);

                match selected {
                    None => Break(Err(NoValidProtocol)),
                    Some(protocol) => {
                        // The NTS Next Protocol Negotiation record [..] MUST occur exactly once in every NTS-KE request and response.
                        match state.protocol {
                            None => {
                                state.protocol = Some(protocol);
                                Continue(state)
                            }
                            Some(_) => Break(Err(KeyExchangeError::BadRequest)),
                        }
                    }
                }
            }
            AeadAlgorithm { algorithm_ids, .. } => {
                // pick the first algorithm id in the client's list that we can deserialize
                let selected = algorithm_ids
                    .iter()
                    .copied()
                    .find_map(Algorithm::try_deserialize);

                match selected {
                    None => Break(Err(NoValidAlgorithm)),
                    Some(algorithm) => {
                        // for the protocol ids we support, the AeadAlgorithm record must be present
                        match state.algorithm {
                            None => {
                                state.algorithm = Some(algorithm);
                                Continue(state)
                            }
                            // a second AeadAlgorithm record is rejected
                            Some(_) => Break(Err(KeyExchangeError::BadRequest)),
                        }
                    }
                }
            }
            #[cfg(feature = "nts-pool")]
            KeepAlive => {
                state.keep_alive = Some(true);
                Continue(state)
            }
            #[cfg(feature = "nts-pool")]
            SupportedAlgorithmList {
                supported_algorithms: _supported_algorithms,
            } => {
                // the list itself is only checked (to be empty) in debug builds
                // outside of fuzzing; receiving the record is what matters
                #[cfg(not(feature = "__internal-fuzz"))]
                debug_assert_eq!(_supported_algorithms, &[]);

                state.requested_supported_algorithms = true;

                Continue(state)
            }
            #[cfg(feature = "nts-pool")]
            FixedKeyRequest { c2s, s2c } => {
                state.fixed_key_request = Some(RequestedKeys { c2s, s2c });
                Continue(state)
            }
            #[cfg(feature = "nts-pool")]
            NtpServerDeny { denied: _ } => {
                // we are not a NTS pool server, so we ignore this record
                Continue(state)
            }
            Unknown { critical, .. } => {
                if critical {
                    Break(Err(KeyExchangeError::UnrecognizedCriticalRecord))
                } else {
                    Continue(state)
                }
            }
        }
    }

    fn new() -> Self {
        Self::default()
    }
}

/// Server side of a key exchange over a single TLS connection.
// NOTE(review): the type arguments of `keyset`, `ntp_port` and `ntp_server`
// were lost in this copy of the file; `Arc<KeySet>`, `Option<u16>` and
// `Option<String>` are presumed -- confirm against
// `NtsRecord::server_key_exchange_records`.
#[derive(Debug)]
pub struct KeyExchangeServer {
    tls_connection: tls_utils::ServerConnection,
    state: State,
    keyset: Arc<KeySet>,
    ntp_port: Option<u16>,
    ntp_server: Option<String>,
    /// Certificates whose holders count as privileged peers.
    #[cfg(feature = "nts-pool")]
    pool_certificates: Arc<[tls_utils::Certificate]>,
}

/// Progress of the exchange on the server side.
#[derive(Debug)]
enum State {
    /// Still decoding the client's request.
    Active { decoder: KeyExchangeServerDecoder },
    /// An error record was queued for the client; `error` is returned once the
    /// connection ends.
    PendingError { error: KeyExchangeError },
    /// The request was fully decoded and handed to `decoder_done`.
    Done,
}

impl KeyExchangeServer {
    /// Forwards to `wants_read` of the underlying TLS connection.
    pub fn wants_read(&self) -> bool {
        self.tls_connection.wants_read()
    }

    /// Feeds TLS data from `rd` into the connection (`read_tls`).
    pub fn read_socket(&mut self, rd: &mut dyn Read) -> std::io::Result<usize> {
        self.tls_connection.read_tls(rd)
    }

    /// Forwards to `wants_write` of the underlying TLS connection.
    pub fn wants_write(&self) -> bool {
        self.tls_connection.wants_write()
    }

    /// Writes pending TLS data from the connection into `wr` (`write_tls`).
    pub fn write_socket(&mut self, wr: &mut dyn Write) -> std::io::Result<usize> {
        self.tls_connection.write_tls(wr)
    }

    /// Serializes `records` into one buffer, queues it on the TLS connection
    /// in a single write and then queues a close-notify.
    fn send_records(
        tls_connection: &mut tls_utils::ServerConnection,
        records: &[NtsRecord],
    ) -> std::io::Result<()> {
        let mut buffer = Vec::with_capacity(1024);
        for record in records.iter() {
            record.write(&mut buffer)?;
        }

        tls_connection.writer().write_all(&buffer)?;
        tls_connection.send_close_notify();

        Ok(())
    }

    /// Queues an error response for `error`; a failure to do so is only logged.
    fn send_error_record(
        tls_connection: &mut tls_utils::ServerConnection,
        error: &KeyExchangeError,
    ) {
        let error_records = [
            NtsRecord::Error {
                errorcode: error.to_error_code(),
            },
            NtsRecord::NextProtocol {
                protocol_ids: vec![ProtocolId::NtpV4 as u16],
            },
            NtsRecord::EndOfMessage,
        ];

        if let Err(io) = Self::send_records(tls_connection, &error_records) {
            tracing::debug!(key_exchange_error = ?error, io_error = ?io, "sending error record failed");
        }
    }

    /// Processes received TLS data and advances the exchange.
    ///
    /// Returns `Continue(self)` while the exchange is still in flight (more
    /// input is needed, or queued output still has to be written) and `Break`
    /// with the TLS connection or an error once the connection has ended or
    /// TLS processing failed.
    pub fn progress(
        mut self,
    ) -> ControlFlow<Result<tls_utils::ServerConnection, KeyExchangeError>, Self> {
        // Move any received data from tls to decoder
        if let Err(e) = self.tls_connection.process_new_packets() {
            return ControlFlow::Break(Err(e.into()));
        }

        let mut buf = [0; 512];
        loop {
            match self.tls_connection.reader().read(&mut buf) {
                Ok(0) => {
                    // the connection was closed cleanly by the client
                    // see https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
                    if self.wants_write() {
                        return ControlFlow::Continue(self);
                    } else {
                        return ControlFlow::Break(self.end_of_file());
                    }
                }
                Ok(n) => {
                    match self.state {
                        State::Active { decoder } => match decoder.step_with_slice(&buf[..n]) {
                            ControlFlow::Continue(decoder) => {
                                // more bytes are needed
                                self.state = State::Active { decoder };
                            }
                            ControlFlow::Break(Ok(data)) => {
                                // all records have been decoded; send a response
                                // continues for a clean shutdown of the connection by the client
                                self.state = State::Done;
                                return self.decoder_done(data);
                            }
                            ControlFlow::Break(Err(error)) => {
                                Self::send_error_record(&mut self.tls_connection, &error);
                                self.state = State::PendingError { error };
                                return ControlFlow::Continue(self);
                            }
                        },
                        State::PendingError { .. } | State::Done => {
                            // client is sending more bytes, but we don't expect any more
                            // these extra bytes are ignored
                            return ControlFlow::Continue(self);
                        }
                    }
                }
                Err(e) => match e.kind() {
                    std::io::ErrorKind::WouldBlock => {
                        // basically an await; give other tasks a chance
                        return ControlFlow::Continue(self);
                    }
                    std::io::ErrorKind::UnexpectedEof => {
                        // the connection was closed uncleanly by the client
                        // see https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
                        if self.wants_write() {
                            return ControlFlow::Continue(self);
                        } else {
                            return ControlFlow::Break(self.end_of_file());
                        }
                    }
                    _ => {
                        let error = KeyExchangeError::Io(e);
                        Self::send_error_record(&mut self.tls_connection, &error);
                        self.state = State::PendingError { error };
                        return ControlFlow::Continue(self);
                    }
                },
            }
        }
    }

    /// Maps the final state to the overall result once the connection has ended.
    fn end_of_file(self) -> Result<tls_utils::ServerConnection, KeyExchangeError> {
        match self.state {
            State::Active { ..
} => {
                // there are no more client bytes, but decoding was not finished yet
                Err(KeyExchangeError::IncompleteResponse)
            }
            State::PendingError { error } => {
                // We can now return the error
                Err(error)
            }
            State::Done => {
                // we're all done
                Ok(self.tls_connection)
            }
        }
    }

    /// Returns `true` when the first certificate presented by the peer is one
    /// of the configured pool certificates.
    #[cfg(feature = "nts-pool")]
    pub fn privileged_connection(&self) -> bool {
        self.tls_connection
            .peer_certificates()
            .and_then(|cert_chain| cert_chain.first())
            .map(|cert| self.pool_certificates.contains(cert))
            .unwrap_or(false)
    }

    /// Produces the keys for this exchange.
    ///
    /// Fixed keys from the request are used only on a privileged connection;
    /// on any other connection a fixed key request is refused. Without fixed
    /// keys, the keys are extracted from the TLS connection.
    // NOTE(review): the return type arguments were lost in this copy of the
    // file; `NtsKeys` is the presumed name of the key pair type -- confirm.
    #[cfg(feature = "nts-pool")]
    fn extract_nts_keys(&self, data: &ServerKeyExchangeData) -> Result<NtsKeys, KeyExchangeError> {
        if let Some(keys) = &data.fixed_keys {
            if self.privileged_connection() {
                tracing::debug!("using fixed keys for AEAD algorithm");
                data.algorithm
                    .try_into_nts_keys(keys)
                    .ok_or(KeyExchangeError::InvalidFixedKeyLength)
            } else {
                tracing::debug!("refused fixed key request due to improper authorization");
                Err(KeyExchangeError::UnrecognizedCriticalRecord)
            }
        } else {
            self.extract_nts_keys_tls(data)
        }
    }

    /// Produces the keys for this exchange by extracting them from the TLS connection.
    // NOTE(review): `NtsKeys` is the presumed name of the key pair type -- confirm.
    #[cfg(not(feature = "nts-pool"))]
    fn extract_nts_keys(&self, data: &ServerKeyExchangeData) -> Result<NtsKeys, KeyExchangeError> {
        self.extract_nts_keys_tls(data)
    }

    /// Extracts the keys for the negotiated algorithm and protocol from the
    /// TLS connection, mapping a failure to `KeyExchangeError::Tls`.
    fn extract_nts_keys_tls(
        &self,
        data: &ServerKeyExchangeData,
    ) -> Result<NtsKeys, KeyExchangeError> {
        tracing::debug!("using AEAD keys extracted from TLS connection");

        data.algorithm
            .extract_nts_keys(data.protocol, &self.tls_connection)
            .map_err(KeyExchangeError::Tls)
    }

    /// Queues the response for a fully decoded request.
    ///
    /// On success the response records are queued and `Continue(self)` is
    /// returned so the caller can flush them. When key extraction fails, an
    /// error record is queued and the error is kept as pending. Only a failure
    /// to queue the regular response breaks immediately.
    fn decoder_done(
        mut self,
        data: ServerKeyExchangeData,
    ) -> ControlFlow<Result<tls_utils::ServerConnection, KeyExchangeError>, Self> {
        let algorithm = data.algorithm;
        let protocol = data.protocol;
        //TODO: see comment in fn server_key_exchange_records()
        #[cfg(feature = "nts-pool")]
        let send_algorithm_list = data.requested_supported_algorithms;

        tracing::debug!(?protocol, ?algorithm, "selected AEAD algorithm");

        match self.extract_nts_keys(&data) {
            Ok(keys) => {
                let records = NtsRecord::server_key_exchange_records(
                    protocol,
                    algorithm,
                    &self.keyset,
                    keys,
                    self.ntp_port,
                    self.ntp_server.clone(),
                    #[cfg(feature = "nts-pool")]
                    send_algorithm_list,
                );

                match Self::send_records(&mut self.tls_connection, &records) {
                    Err(e) => ControlFlow::Break(Err(KeyExchangeError::Io(e))),
                    Ok(()) => ControlFlow::Continue(self),
                }
            }
            Err(key_extract_error) => {
                Self::send_error_record(&mut self.tls_connection, &key_extract_error);
                self.state = State::PendingError {
                    error: key_extract_error,
                };
                ControlFlow::Continue(self)
            }
        }
    }

    /// Creates the server side of a key exchange on a new TLS connection.
    ///
    /// `pool_certificates` is only stored when the `nts-pool` feature is
    /// enabled; otherwise it is dropped.
    // NOTE(review): the type arguments of `tls_config`, `keyset`, `ntp_port`,
    // `ntp_server` and the return type were lost in this copy of the file; the
    // ones below are presumed -- confirm against the callers.
    pub fn new(
        tls_config: Arc<tls_utils::ServerConfig>,
        keyset: Arc<KeySet>,
        ntp_port: Option<u16>,
        ntp_server: Option<String>,
        pool_certificates: Arc<[tls_utils::Certificate]>,
    ) -> Result<Self, KeyExchangeError> {
        // The caller must have configured ntske/1 as the only alpn protocol
        // (checked in debug builds only)
        debug_assert_eq!(tls_config.alpn_protocols, &[b"ntske/1".to_vec()]);

        let tls_connection = tls_utils::ServerConnection::new(tls_config)?;

        #[cfg(not(feature = "nts-pool"))]
        let _ = pool_certificates;

        Ok(Self {
            tls_connection,
            state: State::Active {
                decoder: KeyExchangeServerDecoder::new(),
            },
            keyset,
            ntp_port,
            ntp_server,
            #[cfg(feature = "nts-pool")]
            pool_certificates,
        })
    }
}

/// Fuzzing entry point: feeds `data` to a fresh server decoder in chunks and
/// discards the outcome.
#[cfg(feature = "__internal-fuzz")]
pub fn fuzz_key_exchange_server_decoder(data: &[u8]) {
    // this fuzz harness is inspired by the server_decoder_finds_algorithm() test
    let mut decoder = KeyExchangeServerDecoder::new();

    let decode_output = || {
        // chunk size 24 is taken from the original test function, this may
        // benefit from additional changes
        for chunk in data.chunks(24) {
            decoder = match decoder.step_with_slice(chunk) {
                ControlFlow::Continue(d) => d,
                ControlFlow::Break(done) => return done,
            };
        }

        Err(KeyExchangeError::IncompleteResponse)
    };

    let _result = decode_output();
}

/// Fuzzing entry point: feeds `data` to a fresh client-side result decoder in
/// one step and discards the outcome.
#[cfg(feature = "__internal-fuzz")]
pub fn fuzz_key_exchange_result_decoder(data: &[u8]) {
    let decoder = KeyExchangeResultDecoder::new();

    let _res = match decoder.step_with_slice(data) {
        ControlFlow::Continue(decoder) => decoder,
        ControlFlow::Break(_result) => return,
    };
}

#[cfg(test)]
mod test {
    use std::io::Cursor;

    use crate::keyset::KeySetProvider;

    use super::*;

    #[test]
    fn test_algorithm_decoding() {
        for i in
0..=u16::MAX { if let Some(alg) = AeadAlgorithm::try_deserialize(i) { assert_eq!(alg as u16, i); } } } #[test] fn test_protocol_decoding() { for i in 0..=u16::MAX { if let Some(proto) = ProtocolId::try_deserialize(i) { assert_eq!(proto as u16, i); } } } #[cfg(not(feature = "ntpv5"))] #[test] fn test_client_key_exchange_records() { let mut buffer = Vec::with_capacity(1024); for record in NtsRecord::client_key_exchange_records(None, []).iter() { record.write(&mut buffer).unwrap(); } assert_eq!( buffer, &[128, 1, 0, 2, 0, 0, 0, 4, 0, 4, 0, 17, 0, 15, 128, 0, 0, 0] ); } #[cfg(not(feature = "ntpv5"))] #[test] fn test_decode_client_key_exchange_records() { let bytes = [128, 1, 0, 2, 0, 0, 0, 4, 0, 4, 0, 17, 0, 15, 128, 0, 0, 0]; let mut decoder = NtsRecord::decoder(); decoder.extend(bytes); assert_eq!( [ decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), ], NtsRecord::client_key_exchange_records(None, vec![]).as_ref() ); assert!(decoder.step().unwrap().is_none()); } #[test] fn encode_decode_server_invalid_utf8() { let buffer = vec![ 0, 6, // type 0, 4, // length 0xF8, 0x80, 0x80, 0x80, // content (invalid utf8 sequence) ]; let record = NtsRecord::Unknown { record_type: 6, critical: false, data: vec![0xF8, 0x80, 0x80, 0x80], }; let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] fn encode_decode_error_record() { let mut buffer = Vec::new(); let record = NtsRecord::Error { errorcode: 42 }; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] fn encode_decode_warning_record() { let mut buffer = Vec::new(); let record = NtsRecord::Warning { warningcode: 42 }; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] fn encode_decode_unknown_record() { let mut buffer = Vec::new(); let record = 
NtsRecord::Unknown { record_type: 8, critical: true, data: vec![1, 2, 3], }; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] #[cfg(feature = "nts-pool")] fn encode_decode_keep_alive_record() { let mut buffer = Vec::new(); let record = NtsRecord::KeepAlive; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] #[cfg(feature = "nts-pool")] fn encode_decode_supported_protocol_list_record() { let mut buffer = Vec::new(); let record = NtsRecord::SupportedAlgorithmList { supported_algorithms: vec![ (AeadAlgorithm::AeadAesSivCmac256 as u16, 256 / 8), (AeadAlgorithm::AeadAesSivCmac512 as u16, 512 / 8), ], }; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] #[cfg(feature = "nts-pool")] fn encode_decode_fixed_key_request_record() { let mut buffer = Vec::new(); let c2s: Vec<_> = (0..).take(8).collect(); let s2c: Vec<_> = (0..).skip(8).take(8).collect(); let record = NtsRecord::FixedKeyRequest { c2s, s2c }; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] #[cfg(feature = "nts-pool")] fn encode_decode_server_deny_record() { let mut buffer = Vec::new(); let record = NtsRecord::NtpServerDeny { denied: String::from("a string"), }; record.write(&mut buffer).unwrap(); let decoded = NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } #[test] #[cfg(feature = "nts-pool")] fn encode_decode_server_deny_invalid_utf8() { let [a, b] = 0x4003u16.to_be_bytes(); let buffer = vec![ a, b, // type 0, 4, // length 0xF8, 0x80, 0x80, 0x80, // content (invalid utf8 sequence) ]; let record = NtsRecord::Unknown { record_type: 0x4003, critical: false, data: vec![0xF8, 0x80, 0x80, 0x80], }; let decoded = 
NtsRecord::read(&mut Cursor::new(buffer)).unwrap(); assert_eq!(record, decoded); } fn client_decode_records( records: &[NtsRecord], ) -> Result { let mut decoder = KeyExchangeResultDecoder::new(); let mut buffer = Vec::with_capacity(1024); for record in records { buffer.clear(); record.write(&mut buffer).unwrap(); decoder = match decoder.step_with_slice(&buffer) { ControlFlow::Continue(decoder) => decoder, ControlFlow::Break(result) => return result, } } Err(KeyExchangeError::IncompleteResponse) } #[test] fn client_decoder_immediate_next_protocol_end_of_message() { assert!(matches!( client_decode_records(&[ NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], }, NtsRecord::EndOfMessage ]), Err(KeyExchangeError::NoCookies) )); } #[test] fn client_decoder_immediate_end_of_message() { assert!(matches!( client_decode_records(&[NtsRecord::EndOfMessage]), Err(KeyExchangeError::NoValidProtocol) )); } #[test] fn client_decoder_missing_aead_algorithm_record() { assert!(matches!( client_decode_records(&[ NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::EndOfMessage ]), Err(KeyExchangeError::NoValidAlgorithm) )); } #[test] fn client_decoder_empty_aead_algorithm_list() { assert!(matches!( client_decode_records(&[ NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![], }, NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::EndOfMessage, ]), Err(KeyExchangeError::NoValidAlgorithm) )); } #[test] fn client_decoder_invalid_aead_algorithm_id() { assert!(matches!( client_decode_records(&[ NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![42], }, NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::EndOfMessage, ]), Err(KeyExchangeError::NoValidAlgorithm) )); } #[test] fn client_decoder_no_valid_protocol() { let records = [ NtsRecord::NextProtocol { protocol_ids: vec![1234], }, NtsRecord::EndOfMessage, ]; let error = client_decode_records(&records).unwrap_err(); 
assert!(matches!(error, KeyExchangeError::NoValidProtocol)) } #[test] fn client_decoder_double_next_protocol() { let records = vec![ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::EndOfMessage, ]; let error = client_decode_records(records.as_slice()).unwrap_err(); assert!(matches!(error, KeyExchangeError::BadResponse)); } #[test] fn client_decoder_double_aead_algorithm() { let records = vec![ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15, 16], }, NtsRecord::EndOfMessage, ]; let error = client_decode_records(records.as_slice()).unwrap_err(); assert!(matches!(error, KeyExchangeError::BadResponse)); } #[test] fn client_decoder_twice_aead_algorithm() { let records = vec![ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], }, NtsRecord::EndOfMessage, ]; let error = client_decode_records(records.as_slice()).unwrap_err(); assert!(matches!(error, KeyExchangeError::BadResponse)); } #[test] fn host_port_updates() { let name = String::from("ntp.time.nl"); let port = 4567; let records = [ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], }, NtsRecord::Server { critical: true, name: name.clone(), }, NtsRecord::Port { critical: true, port, }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 68, 107, 175, 34, 77, 63, 18, 34, 122, 22, 95, 242, 175, 224, 29, 173, 58, 187, 47, 11, 245, 247, 119, 89, 5, 8, 221, 162, 106, 66, 30, 65, 218, 13, 108, 238, 12, 29, 200, 9, 92, 218, 38, 20, 238, 251, 68, 35, 44, 129, 189, 132, 4, 93, 117, 136, 91, 234, 58, 195, 223, 171, 207, 247, 172, 128, 5, 219, 97, 21, 128, 107, 96, 220, 189, 53, 223, 111, 181, 164, 185, 173, 80, 101, 75, 18, 180, 129, 243, 140, 253, 236, 45, 62, 101, 155, 252, 51, 
102, 97, ], }, NtsRecord::EndOfMessage, ]; let state = client_decode_records(records.as_slice()).unwrap(); assert_eq!(state.remote, Some(name)); assert_eq!(state.port, Some(port)); } const EXAMPLE_COOKIE_DATA: &[u8] = &[ 178, 15, 188, 164, 68, 107, 175, 34, 77, 63, 18, 34, 122, 22, 95, 242, 175, 224, 29, 173, 58, 187, 47, 11, 245, 247, 119, 89, 5, 8, 221, 162, 106, 66, 30, 65, 218, 13, 108, 238, 12, 29, 200, 9, 92, 218, 38, 20, 238, 251, 68, 35, 44, 129, 189, 132, 4, 93, 117, 136, 91, 234, 58, 195, 223, 171, 207, 247, 172, 128, 5, 219, 97, 21, 128, 107, 96, 220, 189, 53, 223, 111, 181, 164, 185, 173, 80, 101, 75, 18, 180, 129, 243, 140, 253, 236, 45, 62, 101, 155, 252, 51, 102, 97, ]; #[test] fn hit_error_record() { let cookie = NtsRecord::NewCookie { cookie_data: EXAMPLE_COOKIE_DATA.to_vec(), }; // this fails. In theory it's alright if the protocol ID is not 0, // but we do not support any. (we assume ntpv5 has the same behavior as ntpv4 here) let records = [ cookie.clone(), NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::EndOfMessage, ]; assert!(matches!( client_decode_records(records.as_slice()), Err(KeyExchangeError::NoValidAlgorithm) )); // a warning does not change the outcome let records = [ cookie.clone(), NtsRecord::Warning { warningcode: 42 }, NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::EndOfMessage, ]; assert!(matches!( client_decode_records(records.as_slice()), Err(KeyExchangeError::NoValidAlgorithm) )); // an unknown non-critical does not change the outcome let records = [ cookie.clone(), NtsRecord::Unknown { record_type: 8, critical: false, data: vec![1, 2, 3], }, NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::EndOfMessage, ]; assert!(matches!( client_decode_records(records.as_slice()), Err(KeyExchangeError::NoValidAlgorithm) )); // fails with the expected error if there is an error record let records = [ cookie.clone(), NtsRecord::Error { errorcode: 42 }, NtsRecord::NextProtocol { protocol_ids: 
vec![0], }, NtsRecord::EndOfMessage, ]; let error = client_decode_records(records.as_slice()).unwrap_err(); assert!(matches!(error, KeyExchangeError::UnknownErrorCode(42))); let _ = cookie; } #[test] fn client_critical_unknown_record() { // an unknown non-critical does not change the outcome let records = [ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::Unknown { record_type: 8, critical: true, data: vec![1, 2, 3], }, NtsRecord::EndOfMessage, ]; assert!(matches!( client_decode_records(records.as_slice()), Err(KeyExchangeError::UnrecognizedCriticalRecord) )); } #[test] fn incomplete_response() { let error = client_decode_records(&[]).unwrap_err(); assert!(matches!(error, KeyExchangeError::IncompleteResponse)); // this succeeds on its own let records = [NtsRecord::NewCookie { cookie_data: EXAMPLE_COOKIE_DATA.to_vec(), }]; let error = client_decode_records(records.as_slice()).unwrap_err(); assert!(matches!(error, KeyExchangeError::IncompleteResponse)); } const NTS_TIME_NL_RESPONSE: &[u8] = &[ 128, 1, 0, 2, 0, 0, 0, 4, 0, 2, 0, 15, 0, 5, 0, 104, 178, 15, 188, 164, 68, 107, 175, 34, 77, 63, 18, 34, 122, 22, 95, 242, 175, 224, 29, 173, 58, 187, 47, 11, 245, 247, 119, 89, 5, 8, 221, 162, 106, 66, 30, 65, 218, 13, 108, 238, 12, 29, 200, 9, 92, 218, 38, 20, 238, 251, 68, 35, 44, 129, 189, 132, 4, 93, 117, 136, 91, 234, 58, 195, 223, 171, 207, 247, 172, 128, 5, 219, 97, 21, 128, 107, 96, 220, 189, 53, 223, 111, 181, 164, 185, 173, 80, 101, 75, 18, 180, 129, 243, 140, 253, 236, 45, 62, 101, 155, 252, 51, 102, 97, 0, 5, 0, 104, 178, 15, 188, 164, 106, 99, 31, 229, 75, 104, 141, 204, 89, 184, 80, 227, 43, 85, 25, 33, 78, 82, 22, 97, 167, 52, 65, 243, 216, 198, 99, 98, 161, 219, 215, 253, 165, 121, 130, 232, 131, 150, 158, 136, 113, 141, 34, 223, 42, 122, 185, 132, 185, 153, 158, 249, 192, 80, 167, 251, 116, 45, 179, 151, 82, 248, 13, 208, 33, 74, 125, 233, 176, 153, 61, 58, 25, 23, 54, 106, 208, 31, 40, 155, 227, 63, 58, 219, 119, 76, 101, 62, 154, 34, 187, 212, 
106, 162, 140, 223, 37, 194, 20, 107, 0, 5, 0, 104, 178, 15, 188, 164, 240, 20, 28, 103, 149, 25, 37, 145, 187, 196, 100, 113, 36, 76, 171, 29, 69, 40, 19, 70, 95, 60, 30, 27, 188, 25, 1, 148, 55, 18, 253, 131, 8, 108, 44, 173, 236, 74, 227, 49, 47, 183, 156, 118, 152, 88, 31, 254, 134, 220, 129, 254, 186, 117, 80, 163, 167, 223, 208, 8, 124, 141, 240, 43, 161, 240, 60, 54, 241, 44, 87, 135, 116, 63, 236, 40, 138, 162, 65, 143, 193, 98, 44, 9, 61, 189, 89, 19, 45, 94, 6, 102, 82, 8, 175, 206, 87, 132, 51, 63, 0, 5, 0, 104, 178, 15, 188, 164, 56, 48, 71, 172, 153, 142, 223, 150, 73, 72, 201, 236, 26, 68, 29, 14, 139, 66, 190, 77, 218, 206, 90, 117, 75, 128, 88, 186, 187, 156, 130, 57, 198, 118, 176, 199, 55, 56, 173, 109, 35, 37, 15, 223, 17, 53, 110, 167, 251, 167, 208, 44, 158, 89, 113, 22, 178, 92, 235, 114, 176, 41, 255, 172, 175, 191, 227, 29, 85, 70, 152, 125, 67, 125, 96, 151, 151, 160, 188, 8, 35, 205, 152, 142, 225, 59, 71, 224, 254, 84, 20, 51, 162, 164, 94, 241, 7, 15, 9, 138, 0, 5, 0, 104, 178, 15, 188, 164, 198, 114, 113, 134, 102, 130, 116, 104, 6, 6, 81, 118, 89, 146, 119, 198, 80, 135, 104, 155, 101, 107, 51, 215, 243, 241, 163, 55, 84, 206, 179, 241, 105, 210, 184, 30, 44, 133, 235, 227, 87, 7, 40, 230, 185, 47, 180, 189, 84, 157, 182, 81, 69, 168, 147, 115, 94, 53, 242, 198, 132, 188, 56, 86, 70, 201, 78, 219, 140, 212, 94, 100, 38, 106, 168, 35, 57, 236, 156, 41, 86, 176, 225, 129, 152, 206, 49, 176, 252, 29, 235, 180, 161, 148, 195, 223, 27, 217, 85, 220, 0, 5, 0, 104, 178, 15, 188, 164, 52, 150, 226, 182, 229, 113, 23, 67, 155, 54, 34, 141, 125, 225, 98, 4, 22, 105, 111, 150, 212, 32, 9, 204, 212, 242, 161, 213, 135, 199, 246, 74, 160, 126, 167, 94, 174, 76, 11, 228, 13, 251, 20, 135, 0, 197, 207, 18, 168, 118, 218, 39, 79, 100, 203, 234, 224, 116, 59, 234, 247, 156, 128, 58, 104, 57, 204, 85, 48, 68, 229, 37, 20, 146, 159, 67, 49, 235, 142, 58, 225, 149, 187, 3, 11, 146, 193, 114, 122, 160, 19, 180, 146, 196, 50, 229, 22, 10, 86, 219, 0, 5, 0, 
104, 178, 15, 188, 164, 98, 15, 6, 117, 71, 114, 79, 45, 197, 158, 30, 187, 51, 12, 43, 131, 252, 74, 92, 251, 139, 159, 99, 163, 149, 111, 89, 184, 95, 125, 73, 106, 62, 214, 210, 50, 190, 83, 138, 46, 65, 126, 152, 54, 137, 189, 19, 247, 37, 116, 79, 178, 83, 51, 31, 129, 24, 172, 108, 58, 10, 171, 128, 40, 220, 250, 168, 133, 164, 32, 47, 19, 231, 181, 124, 242, 192, 212, 153, 25, 10, 165, 52, 170, 177, 42, 232, 2, 77, 246, 118, 192, 68, 96, 152, 77, 238, 130, 53, 128, 0, 5, 0, 104, 178, 15, 188, 164, 208, 86, 125, 128, 153, 10, 107, 157, 50, 100, 148, 177, 10, 163, 41, 208, 32, 142, 176, 21, 10, 15, 39, 208, 111, 47, 233, 154, 23, 161, 191, 192, 105, 242, 25, 68, 234, 211, 81, 89, 244, 142, 184, 187, 236, 171, 34, 23, 227, 55, 207, 94, 48, 71, 236, 188, 146, 223, 77, 213, 74, 234, 190, 192, 151, 172, 223, 158, 44, 230, 247, 248, 212, 245, 43, 131, 80, 57, 187, 105, 148, 232, 15, 107, 239, 84, 131, 9, 222, 225, 137, 73, 202, 40, 48, 57, 122, 198, 245, 40, 128, 0, 0, 0, ]; fn nts_time_nl_records() -> [NtsRecord; 11] { [ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: false, algorithm_ids: vec![15], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 68, 107, 175, 34, 77, 63, 18, 34, 122, 22, 95, 242, 175, 224, 29, 173, 58, 187, 47, 11, 245, 247, 119, 89, 5, 8, 221, 162, 106, 66, 30, 65, 218, 13, 108, 238, 12, 29, 200, 9, 92, 218, 38, 20, 238, 251, 68, 35, 44, 129, 189, 132, 4, 93, 117, 136, 91, 234, 58, 195, 223, 171, 207, 247, 172, 128, 5, 219, 97, 21, 128, 107, 96, 220, 189, 53, 223, 111, 181, 164, 185, 173, 80, 101, 75, 18, 180, 129, 243, 140, 253, 236, 45, 62, 101, 155, 252, 51, 102, 97, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 106, 99, 31, 229, 75, 104, 141, 204, 89, 184, 80, 227, 43, 85, 25, 33, 78, 82, 22, 97, 167, 52, 65, 243, 216, 198, 99, 98, 161, 219, 215, 253, 165, 121, 130, 232, 131, 150, 158, 136, 113, 141, 34, 223, 42, 122, 185, 132, 185, 153, 158, 249, 192, 80, 167, 
251, 116, 45, 179, 151, 82, 248, 13, 208, 33, 74, 125, 233, 176, 153, 61, 58, 25, 23, 54, 106, 208, 31, 40, 155, 227, 63, 58, 219, 119, 76, 101, 62, 154, 34, 187, 212, 106, 162, 140, 223, 37, 194, 20, 107, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 240, 20, 28, 103, 149, 25, 37, 145, 187, 196, 100, 113, 36, 76, 171, 29, 69, 40, 19, 70, 95, 60, 30, 27, 188, 25, 1, 148, 55, 18, 253, 131, 8, 108, 44, 173, 236, 74, 227, 49, 47, 183, 156, 118, 152, 88, 31, 254, 134, 220, 129, 254, 186, 117, 80, 163, 167, 223, 208, 8, 124, 141, 240, 43, 161, 240, 60, 54, 241, 44, 87, 135, 116, 63, 236, 40, 138, 162, 65, 143, 193, 98, 44, 9, 61, 189, 89, 19, 45, 94, 6, 102, 82, 8, 175, 206, 87, 132, 51, 63, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 56, 48, 71, 172, 153, 142, 223, 150, 73, 72, 201, 236, 26, 68, 29, 14, 139, 66, 190, 77, 218, 206, 90, 117, 75, 128, 88, 186, 187, 156, 130, 57, 198, 118, 176, 199, 55, 56, 173, 109, 35, 37, 15, 223, 17, 53, 110, 167, 251, 167, 208, 44, 158, 89, 113, 22, 178, 92, 235, 114, 176, 41, 255, 172, 175, 191, 227, 29, 85, 70, 152, 125, 67, 125, 96, 151, 151, 160, 188, 8, 35, 205, 152, 142, 225, 59, 71, 224, 254, 84, 20, 51, 162, 164, 94, 241, 7, 15, 9, 138, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 198, 114, 113, 134, 102, 130, 116, 104, 6, 6, 81, 118, 89, 146, 119, 198, 80, 135, 104, 155, 101, 107, 51, 215, 243, 241, 163, 55, 84, 206, 179, 241, 105, 210, 184, 30, 44, 133, 235, 227, 87, 7, 40, 230, 185, 47, 180, 189, 84, 157, 182, 81, 69, 168, 147, 115, 94, 53, 242, 198, 132, 188, 56, 86, 70, 201, 78, 219, 140, 212, 94, 100, 38, 106, 168, 35, 57, 236, 156, 41, 86, 176, 225, 129, 152, 206, 49, 176, 252, 29, 235, 180, 161, 148, 195, 223, 27, 217, 85, 220, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 52, 150, 226, 182, 229, 113, 23, 67, 155, 54, 34, 141, 125, 225, 98, 4, 22, 105, 111, 150, 212, 32, 9, 204, 212, 242, 161, 213, 135, 199, 246, 74, 160, 126, 167, 
94, 174, 76, 11, 228, 13, 251, 20, 135, 0, 197, 207, 18, 168, 118, 218, 39, 79, 100, 203, 234, 224, 116, 59, 234, 247, 156, 128, 58, 104, 57, 204, 85, 48, 68, 229, 37, 20, 146, 159, 67, 49, 235, 142, 58, 225, 149, 187, 3, 11, 146, 193, 114, 122, 160, 19, 180, 146, 196, 50, 229, 22, 10, 86, 219, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 98, 15, 6, 117, 71, 114, 79, 45, 197, 158, 30, 187, 51, 12, 43, 131, 252, 74, 92, 251, 139, 159, 99, 163, 149, 111, 89, 184, 95, 125, 73, 106, 62, 214, 210, 50, 190, 83, 138, 46, 65, 126, 152, 54, 137, 189, 19, 247, 37, 116, 79, 178, 83, 51, 31, 129, 24, 172, 108, 58, 10, 171, 128, 40, 220, 250, 168, 133, 164, 32, 47, 19, 231, 181, 124, 242, 192, 212, 153, 25, 10, 165, 52, 170, 177, 42, 232, 2, 77, 246, 118, 192, 68, 96, 152, 77, 238, 130, 53, 128, ], }, NtsRecord::NewCookie { cookie_data: vec![ 178, 15, 188, 164, 208, 86, 125, 128, 153, 10, 107, 157, 50, 100, 148, 177, 10, 163, 41, 208, 32, 142, 176, 21, 10, 15, 39, 208, 111, 47, 233, 154, 23, 161, 191, 192, 105, 242, 25, 68, 234, 211, 81, 89, 244, 142, 184, 187, 236, 171, 34, 23, 227, 55, 207, 94, 48, 71, 236, 188, 146, 223, 77, 213, 74, 234, 190, 192, 151, 172, 223, 158, 44, 230, 247, 248, 212, 245, 43, 131, 80, 57, 187, 105, 148, 232, 15, 107, 239, 84, 131, 9, 222, 225, 137, 73, 202, 40, 48, 57, 122, 198, 245, 40, ], }, NtsRecord::EndOfMessage, ] } #[test] fn test_nts_time_nl_response() { let state = client_decode_records(nts_time_nl_records().as_slice()).unwrap(); assert_eq!(state.remote, None); assert_eq!(state.port, None); assert_eq!(state.cookies.gap(), 0); } #[test] fn test_decode_nts_time_nl_response() { let mut decoder = NtsRecord::decoder(); decoder.extend(NTS_TIME_NL_RESPONSE.iter().copied()); assert_eq!( [ decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), // cookies decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), 
decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), decoder.step().unwrap().unwrap(), // end of message decoder.step().unwrap().unwrap(), ], nts_time_nl_records() ); assert!(decoder.step().unwrap().is_none()); } fn server_decode_records( records: &[NtsRecord], ) -> Result { let mut bytes = Vec::with_capacity(1024); for record in records { record.write(&mut bytes).unwrap(); } let mut decoder = KeyExchangeServerDecoder::new(); for chunk in bytes.chunks(24) { decoder = match decoder.step_with_slice(chunk) { ControlFlow::Continue(d) => d, ControlFlow::Break(done) => return done, }; } Err(KeyExchangeError::IncompleteResponse) } #[test] fn server_decoder_immediate_end_of_message() { assert!(matches!( server_decode_records(&[NtsRecord::EndOfMessage]), Err(KeyExchangeError::NoValidProtocol) )); } #[test] fn server_decoder_missing_aead_algorithm_record() { assert!(matches!( server_decode_records(&[ NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::EndOfMessage ]), Err(KeyExchangeError::NoValidAlgorithm) )); } #[test] fn server_decoder_empty_aead_algorithm_list() { assert!(matches!( server_decode_records(&[ NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![], }, NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::EndOfMessage, ]), Err(KeyExchangeError::NoValidAlgorithm) )); } #[test] fn server_decoder_invalid_aead_algorithm_id() { assert!(matches!( server_decode_records(&[ NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![42], }, NtsRecord::NextProtocol { protocol_ids: vec![0] }, NtsRecord::EndOfMessage, ]), Err(KeyExchangeError::NoValidAlgorithm) )); } #[test] fn server_decoder_finds_algorithm() { let result = server_decode_records(&NtsRecord::client_key_exchange_records(None, vec![])).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_ignores_new_cookie() { let mut records = NtsRecord::client_key_exchange_records(None, 
vec![]).to_vec(); records.insert( 0, NtsRecord::NewCookie { cookie_data: EXAMPLE_COOKIE_DATA.to_vec(), }, ); let result = server_decode_records(&records).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_ignores_server_and_port_preference() { let mut records = NtsRecord::client_key_exchange_records(None, vec![]).to_vec(); records.insert( 0, NtsRecord::Server { critical: true, name: String::from("example.com"), }, ); records.insert( 0, NtsRecord::Port { critical: true, port: 4242, }, ); let result = server_decode_records(&records).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_ignores_warn() { let mut records = NtsRecord::client_key_exchange_records(None, vec![]).to_vec(); records.insert(0, NtsRecord::Warning { warningcode: 42 }); let result = server_decode_records(&records).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_ignores_unknown_not_critical() { let mut records = NtsRecord::client_key_exchange_records(None, vec![]).to_vec(); records.insert( 0, NtsRecord::Unknown { record_type: 8, critical: false, data: vec![1, 2, 3], }, ); let result = server_decode_records(&records).unwrap(); assert_eq!(result.algorithm, AeadAlgorithm::AeadAesSivCmac512); } #[test] fn server_decoder_reports_unknown_critical() { let mut records = NtsRecord::client_key_exchange_records(None, vec![]).to_vec(); records.insert( 0, NtsRecord::Unknown { record_type: 8, critical: true, data: vec![1, 2, 3], }, ); let result = server_decode_records(&records).unwrap_err(); assert!(matches!( result, KeyExchangeError::UnrecognizedCriticalRecord )); } #[test] fn server_decoder_reports_error() { let mut records = NtsRecord::client_key_exchange_records(None, vec![]).to_vec(); records.insert(0, NtsRecord::Error { errorcode: 2 }); let error = server_decode_records(&records).unwrap_err(); assert!(matches!(error, 
KeyExchangeError::InternalServerError)); } #[test] fn server_decoder_no_valid_protocol() { let records = [ NtsRecord::NextProtocol { protocol_ids: vec![42], }, NtsRecord::EndOfMessage, ]; let error = server_decode_records(&records).unwrap_err(); assert!(matches!(error, KeyExchangeError::NoValidProtocol)); } #[test] fn server_decoder_double_next_protocol() { let records = [ NtsRecord::NextProtocol { protocol_ids: vec![42], }, NtsRecord::EndOfMessage, ]; let error = server_decode_records(&records).unwrap_err(); assert!(matches!(error, KeyExchangeError::NoValidProtocol)); } #[test] fn server_decoder_double_aead_algorithm() { let records = vec![ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], }, NtsRecord::EndOfMessage, ]; let error = server_decode_records(records.as_slice()).unwrap_err(); assert!(matches!(error, KeyExchangeError::BadRequest)); } #[test] fn server_decoder_no_valid_algorithm() { let records = [ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: false, algorithm_ids: vec![1234], }, NtsRecord::EndOfMessage, ]; let error = server_decode_records(&records).unwrap_err(); assert!(matches!(error, KeyExchangeError::NoValidAlgorithm)); } #[test] fn server_decoder_incomplete_response() { let error = server_decode_records(&[]).unwrap_err(); assert!(matches!(error, KeyExchangeError::IncompleteResponse)); let records = [ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::Unknown { record_type: 8, critical: false, data: vec![1, 2, 3], }, ]; let error = server_decode_records(&records).unwrap_err(); assert!(matches!(error, KeyExchangeError::IncompleteResponse)); } #[test] #[cfg(feature = "nts-pool")] fn server_decoder_supported_algorithms() { let records = vec![ NtsRecord::NextProtocol { protocol_ids: vec![0], }, NtsRecord::AeadAlgorithm { critical: true, algorithm_ids: vec![15], 
}, NtsRecord::SupportedAlgorithmList { supported_algorithms: vec![], }, NtsRecord::NewCookie { cookie_data: vec![], }, NtsRecord::EndOfMessage, ]; let data = server_decode_records(records.as_slice()).unwrap(); assert!(data.requested_supported_algorithms); let records = vec![ NtsRecord::SupportedAlgorithmList { supported_algorithms: vec![], }, NtsRecord::NewCookie { cookie_data: vec![], }, NtsRecord::EndOfMessage, ]; let data = server_decode_records(records.as_slice()).unwrap(); assert!(data.requested_supported_algorithms); } #[test] fn test_keyexchange_client() { let cert_chain: Vec = tls_utils::pemfile::certs( &mut std::io::BufReader::new(include_bytes!("../test-keys/end.fullchain.pem") as &[u8]), ) .map(|res| res.unwrap()) .collect(); let key_der = tls_utils::pemfile::pkcs8_private_keys(&mut std::io::BufReader::new( include_bytes!("../test-keys/end.key") as &[u8], )) .map(|res| res.unwrap()) .next() .unwrap(); let serverconfig = tls_utils::server_config_builder() .with_no_client_auth() .with_single_cert(cert_chain, key_der.into()) .unwrap(); let mut root_store = tls_utils::RootCertStore::empty(); #[cfg(any(feature = "rustls22", feature = "rustls23"))] root_store.add_parsable_certificates( tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( "../test-keys/testca.pem" ) as &[u8])) .map(|res| res.unwrap()), ); #[cfg(not(any(feature = "rustls22", feature = "rustls23")))] root_store.add_parsable_certificates( &tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( "../test-keys/testca.pem" ) as &[u8])) .map(|res| res.unwrap()) .collect::>(), ); let clientconfig = tls_utils::client_config_builder() .with_root_certificates(root_store) .with_no_client_auth(); let mut server = tls_utils::ServerConnection::new(Arc::new(serverconfig)).unwrap(); let mut client = KeyExchangeClient::new("localhost".into(), clientconfig, None, vec![]).unwrap(); server.writer().write_all(NTS_TIME_NL_RESPONSE).unwrap(); let mut buf = [0; 4096]; let result = 
'result: loop { while client.wants_write() { let size = client.write_socket(&mut &mut buf[..]).unwrap(); let mut offset = 0; while offset < size { let cur = server.read_tls(&mut &buf[offset..size]).unwrap(); offset += cur; server.process_new_packets().unwrap(); } } while server.wants_write() { let size = server.write_tls(&mut &mut buf[..]).unwrap(); let mut offset = 0; while offset < size { let cur = client.read_socket(&mut &buf[offset..size]).unwrap(); offset += cur; client = match client.progress() { ControlFlow::Continue(client) => client, ControlFlow::Break(result) => break 'result result, } } } } .unwrap(); assert_eq!(result.remote, "localhost"); assert_eq!(result.port, 123); } #[allow(dead_code)] enum ClientType { Uncertified, Certified, } fn client_server_pair(client_type: ClientType) -> (KeyExchangeClient, KeyExchangeServer) { #[allow(unused)] use tls_utils::CloneKeyShim; let cert_chain: Vec = tls_utils::pemfile::certs( &mut std::io::BufReader::new(include_bytes!("../test-keys/end.fullchain.pem") as &[u8]), ) .map(|res| res.unwrap()) .collect(); let key_der = tls_utils::pemfile::pkcs8_private_keys(&mut std::io::BufReader::new( include_bytes!("../test-keys/end.key") as &[u8], )) .map(|res| res.unwrap()) .next() .unwrap(); let mut root_store = tls_utils::RootCertStore::empty(); #[cfg(any(feature = "rustls22", feature = "rustls23"))] root_store.add_parsable_certificates( tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( "../test-keys/testca.pem" ) as &[u8])) .map(|res| res.unwrap()), ); #[cfg(not(any(feature = "rustls22", feature = "rustls23")))] root_store.add_parsable_certificates( &tls_utils::pemfile::certs(&mut std::io::BufReader::new(include_bytes!( "../test-keys/testca.pem" ) as &[u8])) .map(|res| res.unwrap()) .collect::>(), ); let mut serverconfig = tls_utils::server_config_builder() .with_client_cert_verifier(Arc::new( #[cfg(not(feature = "nts-pool"))] tls_utils::NoClientAuth, #[cfg(feature = "nts-pool")] 
crate::tls_utils::AllowAnyAnonymousOrCertificateBearingClient::new( // We know that our previous call to ServerConfig::builder already // installed a default provider, but this is undocumented rustls23::crypto::CryptoProvider::get_default().unwrap(), ), )) .with_single_cert(cert_chain.clone(), key_der.clone_key().into()) .unwrap(); serverconfig.alpn_protocols.clear(); serverconfig.alpn_protocols.push(b"ntske/1".to_vec()); let clientconfig = match client_type { ClientType::Uncertified => tls_utils::client_config_builder() .with_root_certificates(root_store) .with_no_client_auth(), ClientType::Certified => tls_utils::client_config_builder() .with_root_certificates(root_store) .with_client_auth_cert(cert_chain, key_der.into()) .unwrap(), }; let keyset = KeySetProvider::new(8).get(); let pool_cert: Vec = tls_utils::pemfile::certs( &mut std::io::BufReader::new(include_bytes!("../test-keys/end.pem") as &[u8]), ) .map(|res| res.unwrap()) .collect(); assert!(pool_cert.len() == 1); let client = KeyExchangeClient::new_without_tls_write("localhost".into(), clientconfig).unwrap(); let server = KeyExchangeServer::new(Arc::new(serverconfig), keyset, None, None, pool_cert.into()) .unwrap(); (client, server) } fn keyexchange_loop( mut client: KeyExchangeClient, mut server: KeyExchangeServer, ) -> Result { let mut buf = [0; 4096]; 'result: loop { while server.wants_write() { let size = server.write_socket(&mut &mut buf[..]).unwrap(); let mut offset = 0; while offset < size { let cur = client .tls_connection .read_tls(&mut &buf[offset..size]) .unwrap(); offset += cur; client = match client.progress() { ControlFlow::Continue(client) => client, ControlFlow::Break(result) => break 'result result, } } } if client.wants_write() { let size = client.tls_connection.write_tls(&mut &mut buf[..]).unwrap(); let mut offset = 0; while offset < size { let cur = server.read_socket(&mut &buf[offset..size]).unwrap(); offset += cur; match server.progress() { ControlFlow::Continue(new) => server = new, 
ControlFlow::Break(Err(key_exchange_error)) => { return Err(key_exchange_error) } ControlFlow::Break(Ok(mut tls_connection)) => { // the server is now done but the client still needs to complete while tls_connection.wants_write() { let size = tls_connection.write_tls(&mut &mut buf[..]).unwrap(); let mut offset = 0; while offset < size { let cur = client .tls_connection .read_tls(&mut &buf[offset..size]) .unwrap(); offset += cur; client = match client.progress() { ControlFlow::Continue(client) => client, ControlFlow::Break(result) => return result, } } } unreachable!("client should finish up when the server is done") } } } } if !server.wants_write() && !client.wants_write() { client.tls_connection.send_close_notify(); } } } #[test] fn test_keyexchange_roundtrip() { let (mut client, server) = client_server_pair(ClientType::Uncertified); let mut buffer = Vec::with_capacity(1024); for record in NtsRecord::client_key_exchange_records(None, []).iter() { record.write(&mut buffer).unwrap(); } client.tls_connection.writer().write_all(&buffer).unwrap(); let result = keyexchange_loop(client, server).unwrap(); assert_eq!(&result.remote, "localhost"); assert_eq!(result.port, 123); assert_eq!(result.nts.cookies.len(), 8); #[cfg(feature = "ntpv5")] assert_eq!(result.protocol_version, ProtocolVersion::V5); // test that the supported algorithms record is not provided "unasked for" #[cfg(feature = "nts-pool")] assert!(result.algorithms_reported_by_server.is_none()); } #[test] #[cfg(feature = "nts-pool")] fn test_keyexchange_roundtrip_fixed_not_authorized() { let (mut client, server) = client_server_pair(ClientType::Uncertified); let c2s: Vec<_> = (0..).take(64).collect(); let s2c: Vec<_> = (0..).skip(64).take(64).collect(); let mut buffer = Vec::with_capacity(1024); for record in NtsRecord::client_key_exchange_records_fixed(c2s.clone(), s2c.clone()) { record.write(&mut buffer).unwrap(); } client.tls_connection.writer().write_all(&buffer).unwrap(); let error = 
keyexchange_loop(client, server); assert!(matches!( error, Err(KeyExchangeError::UnrecognizedCriticalRecord) )); } #[test] #[cfg(feature = "nts-pool")] fn test_keyexchange_roundtrip_fixed_authorized() { let (mut client, server) = client_server_pair(ClientType::Certified); let c2s: Vec<_> = (0..).take(64).collect(); let s2c: Vec<_> = (0..).skip(64).take(64).collect(); let mut buffer = Vec::with_capacity(1024); for record in NtsRecord::client_key_exchange_records_fixed(c2s.clone(), s2c.clone()) { record.write(&mut buffer).unwrap(); } client.tls_connection.writer().write_all(&buffer).unwrap(); let keyset = server.keyset.clone(); let mut result = keyexchange_loop(client, server).unwrap(); assert_eq!(&result.remote, "localhost"); assert_eq!(result.port, 123); let cookie = result.nts.get_cookie().unwrap(); let cookie = keyset.decode_cookie(&cookie).unwrap(); assert_eq!(cookie.c2s.key_bytes(), c2s); assert_eq!(cookie.s2c.key_bytes(), s2c); #[cfg(feature = "ntpv5")] assert_eq!(result.protocol_version, ProtocolVersion::V5); } #[cfg(feature = "nts-pool")] #[test] fn test_supported_algos_roundtrip() { let (mut client, server) = client_server_pair(ClientType::Uncertified); let mut buffer = Vec::with_capacity(1024); for record in [ NtsRecord::SupportedAlgorithmList { supported_algorithms: vec![], }, NtsRecord::EndOfMessage, ] { record.write(&mut buffer).unwrap(); } client.tls_connection.writer().write_all(&buffer).unwrap(); let result = keyexchange_loop(client, server).unwrap(); let algos = result.algorithms_reported_by_server.unwrap(); assert!(algos.contains(&(AeadAlgorithm::AeadAesSivCmac512, 64))); assert!(algos.contains(&(AeadAlgorithm::AeadAesSivCmac256, 32))); } #[test] fn test_keyexchange_invalid_input() { let mut buffer = Vec::with_capacity(1024); for record in NtsRecord::client_key_exchange_records(None, []).iter() { record.write(&mut buffer).unwrap(); } for n in 0..buffer.len() { let (mut client, server) = client_server_pair(ClientType::Uncertified); client 
.tls_connection .writer() .write_all(&buffer[..n]) .unwrap(); let error = keyexchange_loop(client, server).unwrap_err(); assert!(matches!(error, KeyExchangeError::IncompleteResponse)); } } } ntp-proto-1.4.0/src/packet/crypto.rs000064400000000000000000000342221046102023000155170ustar 00000000000000use std::fmt::Display; use aes_siv::{siv::Aes128Siv, siv::Aes256Siv, Key, KeyInit}; use rand::Rng; use zeroize::{Zeroize, ZeroizeOnDrop}; use crate::keyset::DecodedServerCookie; use super::extension_fields::ExtensionField; #[derive(Debug)] pub struct DecryptError; impl Display for DecryptError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Could not decrypt ciphertext") } } impl std::error::Error for DecryptError {} #[derive(Debug)] pub struct KeyError; impl Display for KeyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Invalid key") } } impl std::error::Error for KeyError {} struct Buffer<'a> { buffer: &'a mut [u8], valid: usize, } impl<'a> Buffer<'a> { fn new(buffer: &'a mut [u8], valid: usize) -> Self { Self { buffer, valid } } fn valid(&self) -> usize { self.valid } } impl AsMut<[u8]> for Buffer<'_> { fn as_mut(&mut self) -> &mut [u8] { &mut self.buffer[..self.valid] } } impl AsRef<[u8]> for Buffer<'_> { fn as_ref(&self) -> &[u8] { &self.buffer[..self.valid] } } impl aead::Buffer for Buffer<'_> { fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> { self.buffer .get_mut(self.valid..(self.valid + other.len())) .ok_or(aead::Error)? 
.copy_from_slice(other); self.valid += other.len(); Ok(()) } fn truncate(&mut self, len: usize) { self.valid = std::cmp::min(self.valid, len); } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EncryptResult { pub nonce_length: usize, pub ciphertext_length: usize, } pub trait Cipher: Sync + Send + ZeroizeOnDrop + 'static { /// encrypts the plaintext present in the buffer /// /// - encrypts `plaintext_length` bytes from the buffer /// - puts the nonce followed by the ciphertext into the buffer /// - returns the size of the nonce and ciphertext fn encrypt( &self, buffer: &mut [u8], plaintext_length: usize, associated_data: &[u8], ) -> std::io::Result; // MUST support arbitrary length nonces fn decrypt( &self, nonce: &[u8], ciphertext: &[u8], associated_data: &[u8], ) -> Result, DecryptError>; fn key_bytes(&self) -> &[u8]; } pub enum CipherHolder<'a> { DecodedServerCookie(DecodedServerCookie), Other(&'a dyn Cipher), } impl AsRef for CipherHolder<'_> { fn as_ref(&self) -> &dyn Cipher { match self { CipherHolder::DecodedServerCookie(cookie) => cookie.c2s.as_ref(), CipherHolder::Other(cipher) => *cipher, } } } pub trait CipherProvider { fn get(&self, context: &[ExtensionField<'_>]) -> Option>; } pub struct NoCipher; impl CipherProvider for NoCipher { fn get<'a>(&self, _context: &[ExtensionField<'_>]) -> Option> { None } } impl CipherProvider for dyn Cipher { fn get(&self, _context: &[ExtensionField<'_>]) -> Option> { Some(CipherHolder::Other(self)) } } impl CipherProvider for Option<&dyn Cipher> { fn get(&self, _context: &[ExtensionField<'_>]) -> Option> { self.map(CipherHolder::Other) } } impl CipherProvider for C { fn get(&self, _context: &[ExtensionField<'_>]) -> Option> { Some(CipherHolder::Other(self)) } } impl CipherProvider for Option { fn get(&self, _context: &[ExtensionField<'_>]) -> Option> { self.as_ref().map(|v| CipherHolder::Other(v)) } } pub struct AesSivCmac256 { // 128 vs 256 difference is due to using the official name (us) vs // the number of 
bits of security (aes_siv crate) key: Key, } impl ZeroizeOnDrop for AesSivCmac256 {} impl AesSivCmac256 { pub fn new(key: Key) -> Self { AesSivCmac256 { key } } #[cfg(feature = "nts-pool")] pub fn key_size() -> usize { // prefer trust in compiler optimisation over trust in mental arithmetic Self::new(Default::default()).key.len() } #[cfg(feature = "nts-pool")] pub fn from_key_bytes(key_bytes: &[u8]) -> Result { (key_bytes.len() == Self::key_size()) .then(|| Self::new(*aead::Key::::from_slice(key_bytes))) .ok_or(KeyError) } } impl Drop for AesSivCmac256 { fn drop(&mut self) { self.key.zeroize(); } } impl Cipher for AesSivCmac256 { fn encrypt( &self, buffer: &mut [u8], plaintext_length: usize, associated_data: &[u8], ) -> std::io::Result { let mut siv = Aes128Siv::new(&self.key); let nonce: [u8; 16] = rand::thread_rng().gen(); // Prepare the buffer for in place encryption by moving the plaintext // back, creating space for the nonce. if buffer.len() < nonce.len() + plaintext_length { return Err(std::io::ErrorKind::WriteZero.into()); } buffer.copy_within(..plaintext_length, nonce.len()); // And place the nonce where the caller expects it buffer[..nonce.len()].copy_from_slice(&nonce); // Create a wrapper around the plaintext portion of the buffer that has // the methods aes_siv needs to do encryption in-place. 
let mut buffer_wrap = Buffer::new(&mut buffer[nonce.len()..], plaintext_length); siv.encrypt_in_place([associated_data, &nonce], &mut buffer_wrap) .map_err(|_| std::io::ErrorKind::Other)?; Ok(EncryptResult { nonce_length: nonce.len(), ciphertext_length: buffer_wrap.valid(), }) } fn decrypt( &self, nonce: &[u8], ciphertext: &[u8], associated_data: &[u8], ) -> Result, DecryptError> { let mut siv = Aes128Siv::new(&self.key); siv.decrypt([associated_data, nonce], ciphertext) .map_err(|_| DecryptError) } fn key_bytes(&self) -> &[u8] { &self.key } } // Ensure siv is not shown in debug output impl std::fmt::Debug for AesSivCmac256 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AesSivCmac256").finish() } } pub struct AesSivCmac512 { // 256 vs 512 difference is due to using the official name (us) vs // the number of bits of security (aes_siv crate) key: Key, } impl AesSivCmac512 { pub fn new(key: Key) -> Self { AesSivCmac512 { key } } #[cfg(feature = "nts-pool")] pub fn key_size() -> usize { // prefer trust in compiler optimisation over trust in mental arithmetic Self::new(Default::default()).key.len() } #[cfg(feature = "nts-pool")] pub fn from_key_bytes(key_bytes: &[u8]) -> Result { (key_bytes.len() == Self::key_size()) .then(|| Self::new(*aead::Key::::from_slice(key_bytes))) .ok_or(KeyError) } } impl ZeroizeOnDrop for AesSivCmac512 {} impl Drop for AesSivCmac512 { fn drop(&mut self) { self.key.zeroize(); } } impl Cipher for AesSivCmac512 { fn encrypt( &self, buffer: &mut [u8], plaintext_length: usize, associated_data: &[u8], ) -> std::io::Result { let mut siv = Aes256Siv::new(&self.key); let nonce: [u8; 16] = rand::thread_rng().gen(); // Prepare the buffer for in place encryption by moving the plaintext // back, creating space for the nonce. 
if buffer.len() < nonce.len() + plaintext_length { return Err(std::io::ErrorKind::WriteZero.into()); } buffer.copy_within(..plaintext_length, nonce.len()); // And place the nonce where the caller expects it buffer[..nonce.len()].copy_from_slice(&nonce); // Create a wrapper around the plaintext portion of the buffer that has // the methods aes_siv needs to do encryption in-place. let mut buffer_wrap = Buffer::new(&mut buffer[nonce.len()..], plaintext_length); siv.encrypt_in_place([associated_data, &nonce], &mut buffer_wrap) .map_err(|_| std::io::ErrorKind::Other)?; Ok(EncryptResult { nonce_length: nonce.len(), ciphertext_length: buffer_wrap.valid(), }) } fn decrypt( &self, nonce: &[u8], ciphertext: &[u8], associated_data: &[u8], ) -> Result, DecryptError> { let mut siv = Aes256Siv::new(&self.key); siv.decrypt([associated_data, nonce], ciphertext) .map_err(|_| DecryptError) } fn key_bytes(&self) -> &[u8] { &self.key } } // Ensure siv is not shown in debug output impl std::fmt::Debug for AesSivCmac512 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AesSivCmac512").finish() } } #[cfg(test)] pub struct IdentityCipher { nonce_length: usize, } #[cfg(test)] impl IdentityCipher { pub fn new(nonce_length: usize) -> Self { Self { nonce_length } } } #[cfg(test)] impl ZeroizeOnDrop for IdentityCipher {} #[cfg(test)] impl Cipher for IdentityCipher { fn encrypt( &self, buffer: &mut [u8], plaintext_length: usize, associated_data: &[u8], ) -> std::io::Result { debug_assert!(associated_data.is_empty()); let nonce: Vec = (0..self.nonce_length as u8).collect(); // Prepare the buffer for in place encryption by moving the plaintext // back, creating space for the nonce. 
if buffer.len() < nonce.len() + plaintext_length { return Err(std::io::ErrorKind::WriteZero.into()); } buffer.copy_within(..plaintext_length, nonce.len()); // And place the nonce where the caller expects it buffer[..nonce.len()].copy_from_slice(&nonce); Ok(EncryptResult { nonce_length: nonce.len(), ciphertext_length: plaintext_length, }) } fn decrypt( &self, nonce: &[u8], ciphertext: &[u8], associated_data: &[u8], ) -> Result, DecryptError> { debug_assert!(associated_data.is_empty()); debug_assert_eq!(nonce.len(), self.nonce_length); Ok(ciphertext.to_vec()) } fn key_bytes(&self) -> &[u8] { unimplemented!() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_aes_siv_cmac_256() { let mut testvec: Vec = (0..16).collect(); testvec.resize(testvec.len() + 32, 0); let key = AesSivCmac256::new([0u8; 32].into()); let EncryptResult { nonce_length, ciphertext_length, } = key.encrypt(&mut testvec, 16, &[]).unwrap(); let result = key .decrypt( &testvec[..nonce_length], &testvec[nonce_length..(nonce_length + ciphertext_length)], &[], ) .unwrap(); assert_eq!(result, (0..16).collect::>()); } #[test] fn test_aes_siv_cmac_256_with_assoc_data() { let mut testvec: Vec = (0..16).collect(); testvec.resize(testvec.len() + 32, 0); let key = AesSivCmac256::new([0u8; 32].into()); let EncryptResult { nonce_length, ciphertext_length, } = key.encrypt(&mut testvec, 16, &[1]).unwrap(); assert!(key .decrypt( &testvec[..nonce_length], &testvec[nonce_length..(nonce_length + ciphertext_length)], &[2] ) .is_err()); let result = key .decrypt( &testvec[..nonce_length], &testvec[nonce_length..(nonce_length + ciphertext_length)], &[1], ) .unwrap(); assert_eq!(result, (0..16).collect::>()); } #[test] fn test_aes_siv_cmac_512() { let mut testvec: Vec = (0..16).collect(); testvec.resize(testvec.len() + 32, 0); let key = AesSivCmac512::new([0u8; 64].into()); let EncryptResult { nonce_length, ciphertext_length, } = key.encrypt(&mut testvec, 16, &[]).unwrap(); let result = key .decrypt( 
&testvec[..nonce_length], &testvec[nonce_length..(nonce_length + ciphertext_length)], &[], ) .unwrap(); assert_eq!(result, (0..16).collect::>()); } #[test] fn test_aes_siv_cmac_512_with_assoc_data() { let mut testvec: Vec = (0..16).collect(); testvec.resize(testvec.len() + 32, 0); let key = AesSivCmac512::new([0u8; 64].into()); let EncryptResult { nonce_length, ciphertext_length, } = key.encrypt(&mut testvec, 16, &[1]).unwrap(); assert!(key .decrypt( &testvec[..nonce_length], &testvec[nonce_length..(nonce_length + ciphertext_length)], &[2] ) .is_err()); let result = key .decrypt( &testvec[..nonce_length], &testvec[nonce_length..(nonce_length + ciphertext_length)], &[1], ) .unwrap(); assert_eq!(result, (0..16).collect::>()); } #[cfg(feature = "nts-pool")] #[test] fn key_functions_correctness() { use aead::KeySizeUser; assert_eq!(Aes128Siv::key_size(), AesSivCmac256::key_size()); assert_eq!(Aes256Siv::key_size(), AesSivCmac512::key_size()); let key_bytes = (1..=64).collect::>(); assert!(AesSivCmac256::from_key_bytes(&key_bytes).is_err()); let slice = &key_bytes[..AesSivCmac256::key_size()]; assert_eq!( AesSivCmac256::from_key_bytes(slice).unwrap().key_bytes(), slice ); let slice = &key_bytes[..AesSivCmac512::key_size()]; assert_eq!( AesSivCmac512::from_key_bytes(slice).unwrap().key_bytes(), slice ); } } ntp-proto-1.4.0/src/packet/error.rs000064400000000000000000000046031046102023000153300ustar 00000000000000use std::fmt::Display; use super::NtpPacket; #[derive(Debug)] pub enum ParsingError { InvalidVersion(u8), IncorrectLength, MalformedNtsExtensionFields, MalformedNonce, MalformedCookiePlaceholder, DecryptError(T), #[cfg(feature = "ntpv5")] V5(super::v5::V5Error), } impl ParsingError { pub(super) fn get_decrypt_error(self) -> Result> { use ParsingError::*; match self { InvalidVersion(v) => Err(InvalidVersion(v)), IncorrectLength => Err(IncorrectLength), MalformedNtsExtensionFields => Err(MalformedNtsExtensionFields), MalformedNonce => Err(MalformedNonce), 
MalformedCookiePlaceholder => Err(MalformedCookiePlaceholder), DecryptError(decrypt_error) => Ok(decrypt_error), #[cfg(feature = "ntpv5")] V5(e) => Err(V5(e)), } } } impl ParsingError { pub(super) fn generalize(self) -> ParsingError { use ParsingError::*; match self { InvalidVersion(v) => InvalidVersion(v), IncorrectLength => IncorrectLength, MalformedNtsExtensionFields => MalformedNtsExtensionFields, MalformedNonce => MalformedNonce, MalformedCookiePlaceholder => MalformedCookiePlaceholder, DecryptError(decrypt_error) => match decrypt_error {}, #[cfg(feature = "ntpv5")] V5(e) => V5(e), } } } pub type PacketParsingError<'a> = ParsingError>; impl Display for ParsingError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::InvalidVersion(version) => f.write_fmt(format_args!("Invalid version {version}")), Self::IncorrectLength => f.write_str("Incorrect packet length"), Self::MalformedNtsExtensionFields => f.write_str("Malformed nts extension fields"), Self::MalformedNonce => f.write_str("Malformed nonce (likely invalid length)"), Self::MalformedCookiePlaceholder => f.write_str("Malformed cookie placeholder"), Self::DecryptError(_) => f.write_str("Failed to decrypt NTS extension fields"), #[cfg(feature = "ntpv5")] Self::V5(e) => Display::fmt(e, f), } } } impl std::error::Error for ParsingError {} ntp-proto-1.4.0/src/packet/extension_fields.rs000064400000000000000000001366601046102023000175520ustar 00000000000000use std::{ borrow::Cow, io::{Cursor, Write}, }; use crate::{io::NonBlockingWrite, keyset::DecodedServerCookie}; #[cfg(feature = "ntpv5")] use crate::packet::v5::extension_fields::{ReferenceIdRequest, ReferenceIdResponse}; use super::{crypto::EncryptResult, error::ParsingError, Cipher, CipherProvider, Mac}; #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum ExtensionFieldTypeId { UniqueIdentifier, NtsCookie, NtsCookiePlaceholder, NtsEncryptedField, Unknown { type_id: u16, }, #[cfg(feature = "ntpv5")] DraftIdentification, 
#[cfg(feature = "ntpv5")] Padding, #[cfg(feature = "ntpv5")] ReferenceIdRequest, #[cfg(feature = "ntpv5")] ReferenceIdResponse, } impl ExtensionFieldTypeId { fn from_type_id(type_id: u16) -> Self { match type_id { 0x104 => Self::UniqueIdentifier, 0x204 => Self::NtsCookie, 0x304 => Self::NtsCookiePlaceholder, 0x404 => Self::NtsEncryptedField, #[cfg(feature = "ntpv5")] 0xF5FF => Self::DraftIdentification, #[cfg(feature = "ntpv5")] 0xF501 => Self::Padding, #[cfg(feature = "ntpv5")] 0xF503 => Self::ReferenceIdRequest, #[cfg(feature = "ntpv5")] 0xF504 => Self::ReferenceIdResponse, _ => Self::Unknown { type_id }, } } fn to_type_id(self) -> u16 { match self { ExtensionFieldTypeId::UniqueIdentifier => 0x104, ExtensionFieldTypeId::NtsCookie => 0x204, ExtensionFieldTypeId::NtsCookiePlaceholder => 0x304, ExtensionFieldTypeId::NtsEncryptedField => 0x404, #[cfg(feature = "ntpv5")] ExtensionFieldTypeId::DraftIdentification => 0xF5FF, #[cfg(feature = "ntpv5")] ExtensionFieldTypeId::Padding => 0xF501, #[cfg(feature = "ntpv5")] ExtensionFieldTypeId::ReferenceIdRequest => 0xF503, #[cfg(feature = "ntpv5")] ExtensionFieldTypeId::ReferenceIdResponse => 0xF504, ExtensionFieldTypeId::Unknown { type_id } => type_id, } } } #[derive(Clone, PartialEq, Eq)] pub enum ExtensionField<'a> { UniqueIdentifier(Cow<'a, [u8]>), NtsCookie(Cow<'a, [u8]>), NtsCookiePlaceholder { cookie_length: u16, }, InvalidNtsEncryptedField, #[cfg(feature = "ntpv5")] DraftIdentification(Cow<'a, str>), #[cfg(feature = "ntpv5")] Padding(usize), #[cfg(feature = "ntpv5")] ReferenceIdRequest(super::v5::extension_fields::ReferenceIdRequest), #[cfg(feature = "ntpv5")] ReferenceIdResponse(super::v5::extension_fields::ReferenceIdResponse<'a>), Unknown { type_id: u16, data: Cow<'a, [u8]>, }, } impl std::fmt::Debug for ExtensionField<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::UniqueIdentifier(arg0) => f.debug_tuple("UniqueIdentifier").field(arg0).finish(), Self::NtsCookie(arg0) 
=> f.debug_tuple("NtsCookie").field(arg0).finish(), Self::NtsCookiePlaceholder { cookie_length: body_length, } => f .debug_struct("NtsCookiePlaceholder") .field("body_length", body_length) .finish(), Self::InvalidNtsEncryptedField => f.debug_struct("InvalidNtsEncryptedField").finish(), #[cfg(feature = "ntpv5")] Self::DraftIdentification(arg0) => { f.debug_tuple("DraftIdentification").field(arg0).finish() } #[cfg(feature = "ntpv5")] Self::Padding(len) => f.debug_struct("Padding").field("length", &len).finish(), #[cfg(feature = "ntpv5")] Self::ReferenceIdRequest(r) => f.debug_tuple("ReferenceIdRequest").field(r).finish(), #[cfg(feature = "ntpv5")] Self::ReferenceIdResponse(r) => f.debug_tuple("ReferenceIdResponse").field(r).finish(), Self::Unknown { type_id: typeid, data, } => f .debug_struct("Unknown") .field("typeid", typeid) .field("length", &data.len()) .field("data", data) .finish(), } } } impl<'a> ExtensionField<'a> { const HEADER_LENGTH: usize = 4; pub fn into_owned(self) -> ExtensionField<'static> { use ExtensionField::*; match self { Unknown { type_id: typeid, data, } => Unknown { type_id: typeid, data: Cow::Owned(data.into_owned()), }, UniqueIdentifier(data) => UniqueIdentifier(Cow::Owned(data.into_owned())), NtsCookie(data) => NtsCookie(Cow::Owned(data.into_owned())), NtsCookiePlaceholder { cookie_length: body_length, } => NtsCookiePlaceholder { cookie_length: body_length, }, InvalidNtsEncryptedField => InvalidNtsEncryptedField, #[cfg(feature = "ntpv5")] DraftIdentification(data) => DraftIdentification(Cow::Owned(data.into_owned())), #[cfg(feature = "ntpv5")] Padding(len) => Padding(len), #[cfg(feature = "ntpv5")] ReferenceIdRequest(req) => ReferenceIdRequest(req), #[cfg(feature = "ntpv5")] ReferenceIdResponse(res) => ReferenceIdResponse(res.into_owned()), } } pub(crate) fn serialize( &self, w: impl NonBlockingWrite, minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { use ExtensionField::*; match self { Unknown { type_id, data } 
=> { Self::encode_unknown(w, *type_id, data, minimum_size, version) } UniqueIdentifier(identifier) => { Self::encode_unique_identifier(w, identifier, minimum_size, version) } NtsCookie(cookie) => Self::encode_nts_cookie(w, cookie, minimum_size, version), NtsCookiePlaceholder { cookie_length: body_length, } => Self::encode_nts_cookie_placeholder(w, *body_length, minimum_size, version), InvalidNtsEncryptedField => Err(std::io::ErrorKind::Other.into()), #[cfg(feature = "ntpv5")] DraftIdentification(data) => { Self::encode_draft_identification(w, data, minimum_size, version) } #[cfg(feature = "ntpv5")] Padding(len) => Self::encode_padding_field(w, *len, minimum_size, version), #[cfg(feature = "ntpv5")] ReferenceIdRequest(req) => req.serialize(w), #[cfg(feature = "ntpv5")] ReferenceIdResponse(res) => res.serialize(w), } } #[cfg(feature = "__internal-fuzz")] pub fn serialize_pub( &self, w: impl NonBlockingWrite, minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { self.serialize(w, minimum_size, version) } fn encode_framing( mut w: impl NonBlockingWrite, ef_id: ExtensionFieldTypeId, data_length: usize, minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { if data_length > u16::MAX as usize - ExtensionField::HEADER_LENGTH { return Err(std::io::Error::new( std::io::ErrorKind::Other, "Extension field too long", )); } // u16 for the type_id, u16 for the length let mut actual_length = (data_length as u16 + ExtensionField::HEADER_LENGTH as u16).max(minimum_size); if version == ExtensionHeaderVersion::V4 { actual_length = next_multiple_of_u16(actual_length, 4) } w.write_all(&ef_id.to_type_id().to_be_bytes())?; w.write_all(&actual_length.to_be_bytes()) } fn encode_padding( w: impl NonBlockingWrite, data_length: usize, minimum_size: u16, ) -> std::io::Result<()> { if data_length > u16::MAX as usize - ExtensionField::HEADER_LENGTH { return Err(std::io::Error::new( std::io::ErrorKind::Other, "Extension field too long", )); } let 
actual_length = next_multiple_of_usize( (data_length + ExtensionField::HEADER_LENGTH).max(minimum_size as usize), 4, ); Self::write_zeros( w, actual_length - data_length - ExtensionField::HEADER_LENGTH, ) } fn write_zeros(mut w: impl NonBlockingWrite, n: usize) -> std::io::Result<()> { let mut remaining = n; let padding_bytes = [0_u8; 32]; while remaining > 0 { let added = usize::min(remaining, padding_bytes.len()); w.write_all(&padding_bytes[..added])?; remaining -= added; } Ok(()) } fn encode_unique_identifier( mut w: impl NonBlockingWrite, identifier: &[u8], minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( &mut w, ExtensionFieldTypeId::UniqueIdentifier, identifier.len(), minimum_size, version, )?; w.write_all(identifier)?; Self::encode_padding(w, identifier.len(), minimum_size) } fn encode_nts_cookie( mut w: impl NonBlockingWrite, cookie: &[u8], minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( &mut w, ExtensionFieldTypeId::NtsCookie, cookie.len(), minimum_size, version, )?; w.write_all(cookie)?; Self::encode_padding(w, cookie.len(), minimum_size)?; Ok(()) } fn encode_nts_cookie_placeholder( mut w: impl NonBlockingWrite, cookie_length: u16, minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( &mut w, ExtensionFieldTypeId::NtsCookiePlaceholder, cookie_length as usize, minimum_size, version, )?; Self::write_zeros(&mut w, cookie_length as usize)?; Self::encode_padding(w, cookie_length as usize, minimum_size)?; Ok(()) } fn encode_unknown( mut w: impl NonBlockingWrite, type_id: u16, data: &[u8], minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( &mut w, ExtensionFieldTypeId::Unknown { type_id }, data.len(), minimum_size, version, )?; w.write_all(data)?; Self::encode_padding(w, data.len(), minimum_size)?; Ok(()) } fn encode_encrypted( w: &mut Cursor<&mut [u8]>, 
fields_to_encrypt: &[ExtensionField], cipher: &dyn Cipher, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { let padding = [0; 4]; let header_start = w.position(); // Placeholder header let type_id: u16 = ExtensionFieldTypeId::NtsEncryptedField.to_type_id(); w.write_all(&type_id.to_be_bytes())?; w.write_all(&0u16.to_be_bytes())?; w.write_all(&0u16.to_be_bytes())?; w.write_all(&0u16.to_be_bytes())?; // Write plaintext for the fields let plaintext_start = w.position(); for field in fields_to_encrypt { // RFC 8915, section 5.5: contrary to the RFC 7822 requirement that fields have a minimum length of 16 or 28 octets, // encrypted extension fields MAY be arbitrarily short (but still MUST be a multiple of 4 octets in length) let minimum_size = 0; field.serialize(&mut *w, minimum_size, version)?; } let plaintext_length = w.position() - plaintext_start; let (packet_so_far, cur_extension_field) = w.get_mut().split_at_mut(header_start as usize); let header_size = (plaintext_start - header_start) as usize; let EncryptResult { nonce_length, ciphertext_length, } = cipher.encrypt( &mut cur_extension_field[header_size..], plaintext_length as usize, packet_so_far, )?; // Nonce and ciphertext lengths may not be a multiple of 4, so add padding to them // to make their lengths multiples of 4. 
let padded_nonce_length = next_multiple_of_usize(nonce_length, 4); let padded_ciphertext_length = next_multiple_of_usize(ciphertext_length, 4); if cur_extension_field.len() < (plaintext_start - header_start) as usize + padded_ciphertext_length + padded_nonce_length { return Err(std::io::ErrorKind::WriteZero.into()); } // move the ciphertext over to make space for nonce padding cur_extension_field.copy_within( header_size + nonce_length..header_size + nonce_length + ciphertext_length, header_size + padded_nonce_length, ); // zero out then nonce padding let nonce_padding = padded_nonce_length - nonce_length; cur_extension_field[header_size + nonce_length..][..nonce_padding] .copy_from_slice(&padding[..nonce_padding]); // zero out the ciphertext padding let ciphertext_padding = padded_ciphertext_length - ciphertext_length; debug_assert_eq!( ciphertext_padding, 0, "extension field encoding should add padding" ); cur_extension_field[header_size + padded_nonce_length + ciphertext_length..] [..ciphertext_padding] .copy_from_slice(&padding[..ciphertext_padding]); // go back and fill in the header let signature_length = header_size + padded_nonce_length + padded_ciphertext_length; w.set_position(header_start); let type_id: u16 = ExtensionFieldTypeId::NtsEncryptedField.to_type_id(); w.write_all(&type_id.to_be_bytes())?; w.write_all(&(signature_length as u16).to_be_bytes())?; w.write_all(&(nonce_length as u16).to_be_bytes())?; w.write_all(&(ciphertext_length as u16).to_be_bytes())?; // set the final position w.set_position(header_start + signature_length as u64); Ok(()) } #[cfg(feature = "ntpv5")] fn encode_draft_identification( mut w: impl NonBlockingWrite, data: &str, minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( &mut w, ExtensionFieldTypeId::DraftIdentification, data.len(), minimum_size, version, )?; w.write_all(data.as_bytes())?; Self::encode_padding(w, data.len(), minimum_size)?; Ok(()) } #[cfg(feature = "ntpv5")] 
pub fn encode_padding_field( mut w: impl NonBlockingWrite, length: usize, minimum_size: u16, version: ExtensionHeaderVersion, ) -> std::io::Result<()> { Self::encode_framing( &mut w, ExtensionFieldTypeId::Padding, length - Self::HEADER_LENGTH, minimum_size, version, )?; Self::write_zeros(&mut w, length - Self::HEADER_LENGTH)?; Self::encode_padding(w, length - Self::HEADER_LENGTH, minimum_size)?; Ok(()) } fn decode_unique_identifier( message: &'a [u8], ) -> Result> { // The string MUST be at least 32 octets long // TODO: Discuss if we really want this check here if message.len() < 32 { return Err(ParsingError::IncorrectLength); } Ok(ExtensionField::UniqueIdentifier(message[..].into())) } fn decode_nts_cookie( message: &'a [u8], ) -> Result> { Ok(ExtensionField::NtsCookie(message[..].into())) } fn decode_nts_cookie_placeholder( message: &'a [u8], ) -> Result> { if message.iter().any(|b| *b != 0) { Err(ParsingError::MalformedCookiePlaceholder) } else { Ok(ExtensionField::NtsCookiePlaceholder { cookie_length: message.len() as u16, }) } } fn decode_unknown( type_id: u16, message: &'a [u8], ) -> Result> { Ok(ExtensionField::Unknown { type_id, data: Cow::Borrowed(message), }) } #[cfg(feature = "ntpv5")] fn decode_draft_identification( message: &'a [u8], extension_header_version: ExtensionHeaderVersion, ) -> Result> { let di = match core::str::from_utf8(message) { Ok(di) if di.is_ascii() => di, _ => return Err(super::v5::V5Error::InvalidDraftIdentification.into()), }; let di = match extension_header_version { ExtensionHeaderVersion::V4 => di.trim_end_matches('\0'), ExtensionHeaderVersion::V5 => di, }; Ok(ExtensionField::DraftIdentification(Cow::Borrowed(di))) } fn decode( raw: RawExtensionField<'a>, #[cfg_attr(not(feature = "ntpv5"), allow(unused_variables))] extension_header_version: ExtensionHeaderVersion, ) -> Result> { type EF<'a> = ExtensionField<'a>; type TypeId = ExtensionFieldTypeId; let message = &raw.message_bytes; match raw.type_id { TypeId::UniqueIdentifier => 
EF::decode_unique_identifier(message), TypeId::NtsCookie => EF::decode_nts_cookie(message), TypeId::NtsCookiePlaceholder => EF::decode_nts_cookie_placeholder(message), #[cfg(feature = "ntpv5")] TypeId::DraftIdentification => { EF::decode_draft_identification(message, extension_header_version) } #[cfg(feature = "ntpv5")] TypeId::ReferenceIdRequest => Ok(ReferenceIdRequest::decode(message)?.into()), #[cfg(feature = "ntpv5")] TypeId::ReferenceIdResponse => Ok(ReferenceIdResponse::decode(message).into()), type_id => EF::decode_unknown(type_id.to_type_id(), message), } } } #[derive(Debug, Clone, PartialEq, Eq, Default)] pub(super) struct ExtensionFieldData<'a> { pub(super) authenticated: Vec>, pub(super) encrypted: Vec>, pub(super) untrusted: Vec>, } #[derive(Debug)] pub(super) struct DeserializedExtensionField<'a> { pub(super) efdata: ExtensionFieldData<'a>, pub(super) remaining_bytes: &'a [u8], pub(super) cookie: Option, } #[derive(Debug)] pub(super) struct InvalidNtsExtensionField<'a> { pub(super) efdata: ExtensionFieldData<'a>, pub(super) remaining_bytes: &'a [u8], } impl<'a> ExtensionFieldData<'a> { pub(super) fn into_owned(self) -> ExtensionFieldData<'static> { let map_into_owned = |vec: Vec| vec.into_iter().map(ExtensionField::into_owned).collect(); ExtensionFieldData { authenticated: map_into_owned(self.authenticated), encrypted: map_into_owned(self.encrypted), untrusted: map_into_owned(self.untrusted), } } pub(super) fn serialize( &self, w: &mut Cursor<&mut [u8]>, cipher: &(impl CipherProvider + ?Sized), version: ExtensionHeaderVersion, ) -> std::io::Result<()> { if !self.authenticated.is_empty() || !self.encrypted.is_empty() { let cipher = match cipher.get(&self.authenticated) { Some(cipher) => cipher, None => return Err(std::io::Error::new(std::io::ErrorKind::Other, "no cipher")), }; // the authenticated extension fields are always followed by the encrypted extension // field. 
We don't (currently) encode a MAC, so the minimum size per RFC 7822 is 16 octets let minimum_size = 16; for field in &self.authenticated { field.serialize(&mut *w, minimum_size, version)?; } // RFC 8915, section 5.5: contrary to the RFC 7822 requirement that fields have a minimum length of 16 or 28 octets, // encrypted extension fields MAY be arbitrarily short (but still MUST be a multiple of 4 octets in length) // hence we don't provide a minimum size here ExtensionField::encode_encrypted(w, &self.encrypted, cipher.as_ref(), version)?; } // per RFC 7822, section 7.5.1.4. let mut it = self.untrusted.iter().peekable(); while let Some(field) = it.next() { let is_last = it.peek().is_none(); let minimum_size = match version { ExtensionHeaderVersion::V4 if is_last => 28, ExtensionHeaderVersion::V4 => 16, #[cfg(feature = "ntpv5")] ExtensionHeaderVersion::V5 => 4, }; field.serialize(&mut *w, minimum_size, version)?; } Ok(()) } #[allow(clippy::type_complexity)] pub(super) fn deserialize( data: &'a [u8], header_size: usize, cipher: &(impl CipherProvider + ?Sized), version: ExtensionHeaderVersion, ) -> Result, ParsingError>> { use ExtensionField::InvalidNtsEncryptedField; let mut efdata = Self::default(); let mut size = 0; let mut is_valid_nts = true; let mut cookie = None; let mac_size = match version { ExtensionHeaderVersion::V4 => Mac::MAXIMUM_SIZE, #[cfg(feature = "ntpv5")] ExtensionHeaderVersion::V5 => 0, }; for field in RawExtensionField::deserialize_sequence( &data[header_size..], mac_size, RawExtensionField::V4_UNENCRYPTED_MINIMUM_SIZE, version, ) { let (offset, field) = field.map_err(|e| e.generalize())?; size = offset + field.wire_length(version); match field.type_id { ExtensionFieldTypeId::NtsEncryptedField => { let encrypted = RawEncryptedField::from_message_bytes(field.message_bytes) .map_err(|e| e.generalize())?; let cipher = match cipher.get(&efdata.untrusted) { Some(cipher) => cipher, None => { efdata.untrusted.push(InvalidNtsEncryptedField); is_valid_nts = 
false; continue; } }; let encrypted_fields = match encrypted.decrypt( cipher.as_ref(), &data[..header_size + offset], version, ) { Ok(encrypted_fields) => encrypted_fields, Err(e) => { // early return if it's anything but a decrypt error e.get_decrypt_error()?; efdata.untrusted.push(InvalidNtsEncryptedField); is_valid_nts = false; continue; } }; // for the current ciphers we allow in non-test code, // the nonce should always be 16 bytes debug_assert_eq!(encrypted.nonce.len(), 16); efdata.encrypted.extend(encrypted_fields); cookie = match cipher { super::crypto::CipherHolder::DecodedServerCookie(cookie) => Some(cookie), super::crypto::CipherHolder::Other(_) => None, }; // All previous untrusted fields are now validated efdata.authenticated.append(&mut efdata.untrusted); } _ => { let field = ExtensionField::decode(field, version).map_err(|e| e.generalize())?; efdata.untrusted.push(field); } } } let remaining_bytes = &data[header_size + size..]; if is_valid_nts { let result = DeserializedExtensionField { efdata, remaining_bytes, cookie, }; Ok(result) } else { let result = InvalidNtsExtensionField { efdata, remaining_bytes, }; Err(ParsingError::DecryptError(result)) } } } struct RawEncryptedField<'a> { nonce: &'a [u8], ciphertext: &'a [u8], } impl<'a> RawEncryptedField<'a> { fn from_message_bytes( message_bytes: &'a [u8], ) -> Result> { use ParsingError::*; let [b0, b1, b2, b3, ref rest @ ..] = message_bytes[..] else { return Err(IncorrectLength); }; let nonce_length = u16::from_be_bytes([b0, b1]) as usize; let ciphertext_length = u16::from_be_bytes([b2, b3]) as usize; let nonce = rest.get(..nonce_length).ok_or(IncorrectLength)?; // skip the lengths and the nonce. 
pad to a multiple of 4 let ciphertext_start = 4 + next_multiple_of_u16(nonce_length as u16, 4) as usize; let ciphertext = message_bytes .get(ciphertext_start..ciphertext_start + ciphertext_length) .ok_or(IncorrectLength)?; Ok(Self { nonce, ciphertext }) } fn decrypt( &self, cipher: &dyn Cipher, aad: &[u8], version: ExtensionHeaderVersion, ) -> Result>, ParsingError>> { let plaintext = match cipher.decrypt(self.nonce, self.ciphertext, aad) { Ok(plain) => plain, Err(_) => { return Err(ParsingError::DecryptError( ExtensionField::InvalidNtsEncryptedField, )); } }; RawExtensionField::deserialize_sequence( &plaintext, 0, RawExtensionField::BARE_MINIMUM_SIZE, version, ) .map(|encrypted_field| { let encrypted_field = encrypted_field.map_err(|e| e.generalize())?.1; if encrypted_field.type_id == ExtensionFieldTypeId::NtsEncryptedField { // TODO: Discuss whether we want this check Err(ParsingError::MalformedNtsExtensionFields) } else { Ok(ExtensionField::decode(encrypted_field, version) .map_err(|e| e.generalize())? .into_owned()) } }) .collect() } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ExtensionHeaderVersion { V4, #[cfg(feature = "ntpv5")] V5, } #[cfg(feature = "__internal-fuzz")] impl<'a> arbitrary::Arbitrary<'a> for ExtensionHeaderVersion { #[cfg(not(feature = "ntpv5"))] fn arbitrary(_u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { Ok(Self::V4) } #[cfg(feature = "ntpv5")] fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { Ok(if bool::arbitrary(u)? { Self::V4 } else { Self::V5 }) } } #[derive(Debug)] struct RawExtensionField<'a> { type_id: ExtensionFieldTypeId, // bytes of the value and any padding. 
Does not include the header (field type and length) // https://www.rfc-editor.org/rfc/rfc5905.html#section-7.5 message_bytes: &'a [u8], } impl<'a> RawExtensionField<'a> { const BARE_MINIMUM_SIZE: usize = 4; const V4_UNENCRYPTED_MINIMUM_SIZE: usize = 4; fn wire_length(&self, version: ExtensionHeaderVersion) -> usize { // field type + length + value + padding let length = 2 + 2 + self.message_bytes.len(); if version == ExtensionHeaderVersion::V4 { // All extension fields are zero-padded to a word (four octets) boundary. // // message_bytes should include this padding, so this should already be true debug_assert_eq!(length % 4, 0); } next_multiple_of_usize(length, 4) } fn deserialize( data: &'a [u8], minimum_size: usize, version: ExtensionHeaderVersion, ) -> Result> { use ParsingError::IncorrectLength; let [b0, b1, b2, b3, ..] = data[..] else { return Err(IncorrectLength); }; let type_id = u16::from_be_bytes([b0, b1]); // The Length field is a 16-bit unsigned integer that indicates the length of // the entire extension field in octets, including the Padding field. 
let field_length = u16::from_be_bytes([b2, b3]) as usize; if field_length < minimum_size { return Err(IncorrectLength); } // In NTPv4: padding is up to a multiple of 4 bytes, so a valid field length is divisible by 4 if version == ExtensionHeaderVersion::V4 && field_length % 4 != 0 { return Err(IncorrectLength); } // In NTPv5: There must still be enough room in the packet for data + padding data.get(4..next_multiple_of_usize(field_length, 4)) .ok_or(IncorrectLength)?; // because the field length includes padding, the message bytes may not exactly match the input let message_bytes = data.get(4..field_length).ok_or(IncorrectLength)?; Ok(Self { type_id: ExtensionFieldTypeId::from_type_id(type_id), message_bytes, }) } fn deserialize_sequence( buffer: &'a [u8], cutoff: usize, minimum_size: usize, version: ExtensionHeaderVersion, ) -> impl Iterator< Item = Result<(usize, RawExtensionField<'a>), ParsingError>, > + 'a { ExtensionFieldStreamer { buffer, cutoff, minimum_size, offset: 0, version, } } } struct ExtensionFieldStreamer<'a> { buffer: &'a [u8], cutoff: usize, minimum_size: usize, offset: usize, version: ExtensionHeaderVersion, } impl<'a> Iterator for ExtensionFieldStreamer<'a> { type Item = Result<(usize, RawExtensionField<'a>), ParsingError>; fn next(&mut self) -> Option { let remaining = &self.buffer.get(self.offset..)?; if remaining.len() <= self.cutoff { return None; } match RawExtensionField::deserialize(remaining, self.minimum_size, self.version) { Ok(field) => { let offset = self.offset; self.offset += field.wire_length(self.version); Some(Ok((offset, field))) } Err(error) => { self.offset = self.buffer.len(); Some(Err(error)) } } } } const fn next_multiple_of_u16(lhs: u16, rhs: u16) -> u16 { match lhs % rhs { 0 => lhs, r => lhs + (rhs - r), } } const fn next_multiple_of_usize(lhs: usize, rhs: usize) -> usize { match lhs % rhs { 0 => lhs, r => lhs + (rhs - r), } } #[cfg(test)] mod tests { use crate::{keyset::KeySet, packet::AesSivCmac256}; use super::*; 
#[test] fn roundtrip_ef_typeid() { for i in 0..=u16::MAX { let a = ExtensionFieldTypeId::from_type_id(i); assert_eq!(i, a.to_type_id()); } } #[test] fn test_unique_identifier() { let identifier: Vec<_> = (0..16).collect(); let mut w = vec![]; ExtensionField::encode_unique_identifier( &mut w, &identifier, 0, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!( w, &[1, 4, 0, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] ); } #[test] fn test_nts_cookie() { let cookie: Vec<_> = (0..16).collect(); let mut w = vec![]; ExtensionField::encode_nts_cookie(&mut w, &cookie, 0, ExtensionHeaderVersion::V4).unwrap(); assert_eq!( w, &[2, 4, 0, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] ); } #[test] fn test_nts_cookie_placeholder() { const COOKIE_LENGTH: usize = 16; let mut w = vec![]; ExtensionField::encode_nts_cookie_placeholder( &mut w, COOKIE_LENGTH as u16, 0, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!( w, &[3, 4, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,] ); let raw = RawExtensionField { type_id: ExtensionFieldTypeId::NtsCookiePlaceholder, message_bytes: &[1; COOKIE_LENGTH], }; let output = ExtensionField::decode(raw, ExtensionHeaderVersion::V4).unwrap_err(); assert!(matches!(output, ParsingError::MalformedCookiePlaceholder)); let raw = RawExtensionField { type_id: ExtensionFieldTypeId::NtsCookiePlaceholder, message_bytes: &[0; COOKIE_LENGTH], }; let output = ExtensionField::decode(raw, ExtensionHeaderVersion::V4).unwrap(); let ExtensionField::NtsCookiePlaceholder { cookie_length } = output else { panic!("incorrect variant"); }; assert_eq!(cookie_length, 16); } #[test] fn test_unknown() { let data: Vec<_> = (0..16).collect(); let mut w = vec![]; ExtensionField::encode_unknown(&mut w, 42, &data, 0, ExtensionHeaderVersion::V4).unwrap(); assert_eq!( w, &[0, 42, 0, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] ); } #[cfg(feature = "ntpv5")] #[test] fn draft_identification() { let test_id = 
crate::packet::v5::DRAFT_VERSION; let len = u16::try_from(4 + test_id.len()).unwrap(); let mut data = vec![]; data.extend(&[0xF5, 0xFF]); // Type data.extend(&len.to_be_bytes()); // Length data.extend(test_id.as_bytes()); // Payload data.extend(&[0]); // Padding let raw = RawExtensionField::deserialize(&data, 4, ExtensionHeaderVersion::V5).unwrap(); let ef = ExtensionField::decode(raw, ExtensionHeaderVersion::V4).unwrap(); let ExtensionField::DraftIdentification(ref parsed) = ef else { panic!("Unexpected extension field {ef:?}... expected DraftIdentification"); }; assert_eq!(parsed, test_id); let mut out = vec![]; ef.serialize(&mut out, 4, ExtensionHeaderVersion::V5) .unwrap(); assert_eq!(&out, &data); } #[cfg(feature = "ntpv5")] #[test] fn extension_field_length() { let data: Vec<_> = (0..21).collect(); let mut w = vec![]; ExtensionField::encode_unknown(&mut w, 42, &data, 16, ExtensionHeaderVersion::V4).unwrap(); let raw: RawExtensionField<'_> = RawExtensionField::deserialize(&w, 16, ExtensionHeaderVersion::V4).unwrap(); // v4 extension field header length includes padding bytes assert_eq!(w[3], 28); assert_eq!(w.len(), 28); assert_eq!(raw.message_bytes.len(), 24); assert_eq!(raw.wire_length(ExtensionHeaderVersion::V4), 28); let mut w = vec![]; ExtensionField::encode_unknown(&mut w, 42, &data, 16, ExtensionHeaderVersion::V5).unwrap(); let raw: RawExtensionField<'_> = RawExtensionField::deserialize(&w, 16, ExtensionHeaderVersion::V5).unwrap(); // v5 extension field header length does not include padding bytes assert_eq!(w[3], 25); assert_eq!(w.len(), 28); assert_eq!(raw.message_bytes.len(), 21); assert_eq!(raw.wire_length(ExtensionHeaderVersion::V5), 28); } #[test] fn extension_field_minimum_size() { let minimum_size = 32; let expected_size = minimum_size as usize; let data: Vec<_> = (0..16).collect(); let mut w = vec![]; ExtensionField::encode_unique_identifier( &mut w, &data, minimum_size, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!(w.len(), 
expected_size); let mut w = vec![]; ExtensionField::encode_nts_cookie(&mut w, &data, minimum_size, ExtensionHeaderVersion::V4) .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; ExtensionField::encode_nts_cookie_placeholder( &mut w, data.len() as u16, minimum_size, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; ExtensionField::encode_unknown(&mut w, 42, &data, minimum_size, ExtensionHeaderVersion::V4) .unwrap(); assert_eq!(w.len(), expected_size); // NOTE: encrypted fields do not have a minimum_size } #[test] fn extension_field_padding() { let minimum_size = 0; let expected_size = 20; let data: Vec<_> = (0..15).collect(); // 15 bytes, so padding is needed let mut w = vec![]; ExtensionField::encode_unique_identifier( &mut w, &data, minimum_size, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; ExtensionField::encode_nts_cookie(&mut w, &data, minimum_size, ExtensionHeaderVersion::V4) .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; ExtensionField::encode_nts_cookie_placeholder( &mut w, data.len() as u16, minimum_size, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!(w.len(), expected_size); let mut w = vec![]; ExtensionField::encode_unknown(&mut w, 42, &data, minimum_size, ExtensionHeaderVersion::V4) .unwrap(); assert_eq!(w.len(), expected_size); let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); let c2s = [0; 32]; let cipher = AesSivCmac256::new(c2s.into()); let fields_to_encrypt = [ExtensionField::UniqueIdentifier(Cow::Borrowed( data.as_slice(), ))]; ExtensionField::encode_encrypted( &mut cursor, &fields_to_encrypt, &cipher, ExtensionHeaderVersion::V4, ) .unwrap(); assert_eq!( cursor.position() as usize, 2 + 6 + c2s.len() + expected_size ); } #[test] fn nonce_padding() { let nonce_length = 11; let cipher = crate::packet::crypto::IdentityCipher::new(nonce_length); // multiple of 4; no padding is needed let 
fields_to_encrypt = [ExtensionField::Unknown { type_id: 42u16, data: Cow::Borrowed(&[1, 2, 3, 4]), }]; // 6 bytes of data, rounded up to a multiple of 4 let plaintext_length = 8; let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); ExtensionField::encode_encrypted( &mut cursor, &fields_to_encrypt, &cipher, ExtensionHeaderVersion::V4, ) .unwrap(); let expected_length = 2 + 6 + next_multiple_of_usize(nonce_length, 4) + plaintext_length; assert_eq!(cursor.position() as usize, expected_length,); let message_bytes = &w.as_ref()[..expected_length]; let mut it = RawExtensionField::deserialize_sequence( message_bytes, 0, 0, ExtensionHeaderVersion::V4, ); let field = it.next().unwrap().unwrap(); assert!(it.next().is_none()); match field { ( 0, RawExtensionField { type_id: ExtensionFieldTypeId::NtsEncryptedField, message_bytes, }, ) => { let raw = RawEncryptedField::from_message_bytes(message_bytes).unwrap(); let decrypted_fields = raw .decrypt(&cipher, &[], ExtensionHeaderVersion::V4) .unwrap(); assert_eq!(decrypted_fields, fields_to_encrypt); } _ => panic!("invalid"), } } #[test] fn deserialize_extension_field_data_no_cipher() { let cookie = ExtensionField::NtsCookie(Cow::Borrowed(&[0; 16])); let cipher = crate::packet::crypto::NoCipher; // cause an error when the cipher is needed { let data = ExtensionFieldData { authenticated: vec![cookie.clone()], encrypted: vec![], untrusted: vec![], }; let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); assert!(data .serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) .is_err()); } // but succeed when the cipher is not needed { let data = ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: vec![cookie.clone()], }; let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); assert!(data .serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) .is_ok()); } } #[test] fn serialize_untrusted_fields() { let cookie = 
ExtensionField::NtsCookie(Cow::Borrowed(&[0; 16])); let data = ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: vec![cookie.clone(), cookie], }; let nonce_length = 11; let cipher = crate::packet::crypto::IdentityCipher::new(nonce_length); let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; // the cookie we provide is `2 + 2 + 16 = 20` bytes let expected_length = Ord::max(20, 28) + Ord::max(20, 16); assert_eq!(slice.len(), expected_length); } #[test] fn serialize_untrusted_fields_smaller_than_minimum() { let cookie = ExtensionField::NtsCookie(Cow::Borrowed(&[0; 4])); let data = ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: vec![cookie.clone(), cookie], }; let nonce_length = 11; let cipher = crate::packet::crypto::IdentityCipher::new(nonce_length); let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; // now we hit the minimum widths of extension fields // let minimum_size = if is_last { 28 } else { 16 }; assert_eq!(slice.len(), 28 + 16); } #[test] fn deserialize_without_cipher() { let cookie = ExtensionField::NtsCookie(Cow::Borrowed(&[0; 32])); let data = ExtensionFieldData { authenticated: vec![], encrypted: vec![cookie], untrusted: vec![], }; let nonce_length = 11; let cipher = crate::packet::crypto::IdentityCipher::new(nonce_length); let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; let cipher = crate::packet::crypto::NoCipher; let result = ExtensionFieldData::deserialize(slice, 0, &cipher, ExtensionHeaderVersion::V4) 
.unwrap_err(); let ParsingError::DecryptError(InvalidNtsExtensionField { efdata, remaining_bytes, }) = result else { panic!("invalid variant"); }; let invalid = ExtensionField::InvalidNtsEncryptedField; assert_eq!(efdata.authenticated, &[]); assert_eq!(efdata.encrypted, &[]); assert_eq!(efdata.untrusted, &[invalid]); assert_eq!(remaining_bytes, &[]); } #[test] fn deserialize_different_cipher() { let cookie = ExtensionField::NtsCookie(Cow::Borrowed(&[0; 32])); let data = ExtensionFieldData { authenticated: vec![], encrypted: vec![cookie], untrusted: vec![], }; let nonce_length = 11; let cipher = crate::packet::crypto::IdentityCipher::new(nonce_length); let mut w = [0u8; 128]; let mut cursor = Cursor::new(w.as_mut_slice()); data.serialize(&mut cursor, &cipher, ExtensionHeaderVersion::V4) .unwrap(); let n = cursor.position() as usize; let slice = &w.as_slice()[..n]; // now use a different (valid) cipher for deserialization let c2s = [0; 32]; let cipher = AesSivCmac256::new(c2s.into()); let result = ExtensionFieldData::deserialize(slice, 0, &cipher, ExtensionHeaderVersion::V4) .unwrap_err(); let ParsingError::DecryptError(InvalidNtsExtensionField { efdata, remaining_bytes, }) = result else { panic!("invalid variant"); }; let invalid = ExtensionField::InvalidNtsEncryptedField; assert_eq!(efdata.authenticated, &[]); assert_eq!(efdata.encrypted, &[]); assert_eq!(efdata.untrusted, &[invalid]); assert_eq!(remaining_bytes, &[]); } #[test] fn deserialize_with_keyset() { let keyset = KeySet::new(); let decoded_server_cookie = crate::keyset::test_cookie(); let cookie_data = keyset.encode_cookie(&decoded_server_cookie); let cookie = ExtensionField::NtsCookie(Cow::Borrowed(&cookie_data)); let data = ExtensionFieldData { authenticated: vec![cookie.clone()], encrypted: vec![cookie], untrusted: vec![], }; let mut w = [0u8; 256]; let mut cursor = Cursor::new(w.as_mut_slice()); data.serialize(&mut cursor, &keyset, ExtensionHeaderVersion::V4) .unwrap(); let n = cursor.position() as 
usize;
        let slice = &w.as_slice()[..n];

        let result =
            ExtensionFieldData::deserialize(slice, 0, &keyset, ExtensionHeaderVersion::V4).unwrap();

        let DeserializedExtensionField {
            efdata,
            remaining_bytes,
            cookie,
        } = result;

        assert_eq!(efdata.authenticated.len(), 1);
        assert_eq!(efdata.encrypted.len(), 1);
        assert_eq!(efdata.untrusted, &[]);

        assert_eq!(remaining_bytes, &[]);

        assert!(cookie.is_some());
    }
}
ntp-proto-1.4.0/src/packet/mac.rs000064400000000000000000000031041046102023000147320ustar 00000000000000use std::borrow::Cow;

use crate::io::NonBlockingWrite;

use super::error::ParsingError;

/// A message authentication code trailer: a 32-bit key identifier followed by
/// the digest bytes. The digest is either borrowed from the packet buffer or
/// owned (see [`Mac::into_owned`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct Mac<'a> {
    keyid: u32,
    mac: Cow<'a, [u8]>,
}

impl<'a> Mac<'a> {
    // As per RFC7822:
    // If a MAC is used, it resides at the end of the packet. This field
    // can be either 24 octets long, 20 octets long, or a 4-octet
    // crypto-NAK.
    pub(super) const MAXIMUM_SIZE: usize = 24;

    /// Detaches the MAC from the buffer it was parsed from by copying the
    /// digest bytes into an owned allocation.
    pub(super) fn into_owned(self) -> Mac<'static> {
        Mac {
            keyid: self.keyid,
            mac: Cow::Owned(self.mac.into_owned()),
        }
    }

    /// Writes the key id (big-endian) followed by the digest bytes to `w`.
    pub(super) fn serialize(&self, mut w: impl NonBlockingWrite) -> std::io::Result<()> {
        w.write_all(&self.keyid.to_be_bytes())?;
        w.write_all(&self.mac)
    }

    /// Parses a MAC from `data`: the first four bytes are the big-endian key
    /// id, everything after that is kept (borrowed) as the digest.
    ///
    /// Returns `ParsingError::IncorrectLength` when `data` is shorter than the
    /// 4-octet key id or longer than [`Self::MAXIMUM_SIZE`]. A MAC of exactly
    /// `MAXIMUM_SIZE` (24) octets is valid per the RFC 7822 text quoted above
    /// and is therefore accepted.
    //
    // NOTE(review): the type arguments of the return type were lost in this
    // copy of the file; they are restored from how callers use the result
    // (`.map_err(|e| e.generalize())`) - confirm the `ParsingError` parameter
    // against error.rs.
    pub(super) fn deserialize(
        data: &'a [u8],
    ) -> Result<Mac<'a>, ParsingError<std::convert::Infallible>> {
        // Upper bound is inclusive: only reject what is strictly too long.
        if data.len() < 4 || data.len() > Self::MAXIMUM_SIZE {
            return Err(ParsingError::IncorrectLength);
        }

        Ok(Mac {
            keyid: u32::from_be_bytes(data[0..4].try_into().unwrap()),
            mac: Cow::Borrowed(&data[4..]),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip() {
        let input = Mac {
            keyid: 42,
            mac: Cow::Borrowed(&[1, 2, 3, 4, 5, 6, 7, 8]),
        };

        let input = input.to_owned();

        let mut w = Vec::new();
        input.serialize(&mut w).unwrap();

        let output = Mac::deserialize(&w).unwrap();

        assert_eq!(input, output);
    }

    #[test]
    fn length_bounds() {
        // the 4-octet crypto-NAK and the 24-octet maximum are both valid
        assert!(Mac::deserialize(&[0u8; 4]).is_ok());
        assert!(Mac::deserialize(&[0u8; Mac::MAXIMUM_SIZE]).is_ok());

        // too short to hold a key id, or longer than any defined MAC
        assert!(Mac::deserialize(&[0u8; 3]).is_err());
        assert!(Mac::deserialize(&[0u8; Mac::MAXIMUM_SIZE + 1]).is_err());
    }
}
ntp-proto-1.4.0/src/packet/mod.rs000064400000000000000000002663071046102023000147710ustar 00000000000000use std::{borrow::Cow, io::Cursor};

use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};

use crate::{
clock::NtpClock, identifiers::ReferenceId, io::NonBlockingWrite, keyset::{DecodedServerCookie, KeySet}, system::SystemSnapshot, time_types::{NtpDuration, NtpTimestamp, PollInterval}, }; use self::{error::ParsingError, extension_fields::ExtensionFieldData, mac::Mac}; mod crypto; mod error; mod extension_fields; mod mac; #[cfg(feature = "ntpv5")] pub mod v5; pub use crypto::{ AesSivCmac256, AesSivCmac512, Cipher, CipherHolder, CipherProvider, DecryptError, EncryptResult, NoCipher, }; pub use error::PacketParsingError; pub use extension_fields::{ExtensionField, ExtensionHeaderVersion}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum NtpLeapIndicator { NoWarning, Leap61, Leap59, Unknown, } impl NtpLeapIndicator { // This function should only ever be called with 2 bit values // (in the least significant position) fn from_bits(bits: u8) -> NtpLeapIndicator { match bits { 0 => NtpLeapIndicator::NoWarning, 1 => NtpLeapIndicator::Leap61, 2 => NtpLeapIndicator::Leap59, 3 => NtpLeapIndicator::Unknown, // This function should only ever be called from the packet parser // with just two bits, so this really should be unreachable _ => unreachable!(), } } fn to_bits(self) -> u8 { match self { NtpLeapIndicator::NoWarning => 0, NtpLeapIndicator::Leap61 => 1, NtpLeapIndicator::Leap59 => 2, NtpLeapIndicator::Unknown => 3, } } pub fn is_synchronized(&self) -> bool { !matches!(self, Self::Unknown) } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum NtpAssociationMode { Reserved, SymmetricActive, SymmetricPassive, Client, Server, Broadcast, Control, Private, } impl NtpAssociationMode { // This function should only ever be called with 3 bit values // (in the least significant position) fn from_bits(bits: u8) -> NtpAssociationMode { match bits { 0 => NtpAssociationMode::Reserved, 1 => NtpAssociationMode::SymmetricActive, 2 => NtpAssociationMode::SymmetricPassive, 3 => NtpAssociationMode::Client, 4 => NtpAssociationMode::Server, 5 => 
NtpAssociationMode::Broadcast, 6 => NtpAssociationMode::Control, 7 => NtpAssociationMode::Private, // This function should only ever be called from the packet parser // with just three bits, so this really should be unreachable _ => unreachable!(), } } fn to_bits(self) -> u8 { match self { NtpAssociationMode::Reserved => 0, NtpAssociationMode::SymmetricActive => 1, NtpAssociationMode::SymmetricPassive => 2, NtpAssociationMode::Client => 3, NtpAssociationMode::Server => 4, NtpAssociationMode::Broadcast => 5, NtpAssociationMode::Control => 6, NtpAssociationMode::Private => 7, } } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct NtpPacket<'a> { header: NtpHeader, efdata: ExtensionFieldData<'a>, mac: Option>, } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum NtpHeader { V3(NtpHeaderV3V4), V4(NtpHeaderV3V4), #[cfg(feature = "ntpv5")] V5(v5::NtpHeaderV5), } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct NtpHeaderV3V4 { leap: NtpLeapIndicator, mode: NtpAssociationMode, stratum: u8, poll: PollInterval, precision: i8, root_delay: NtpDuration, root_dispersion: NtpDuration, reference_id: ReferenceId, reference_timestamp: NtpTimestamp, /// Time at the client when the request departed for the server origin_timestamp: NtpTimestamp, /// Time at the server when the request arrived from the client receive_timestamp: NtpTimestamp, /// Time at the server when the response left for the client transmit_timestamp: NtpTimestamp, } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct RequestIdentifier { expected_origin_timestamp: NtpTimestamp, uid: Option<[u8; 32]>, } impl NtpHeaderV3V4 { const WIRE_LENGTH: usize = 48; /// A new, empty NtpHeader fn new() -> Self { Self { leap: NtpLeapIndicator::NoWarning, mode: NtpAssociationMode::Client, stratum: 0, poll: PollInterval::from_byte(0), precision: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), reference_id: ReferenceId::from_int(0), reference_timestamp: NtpTimestamp::default(), 
origin_timestamp: NtpTimestamp::default(), receive_timestamp: NtpTimestamp::default(), transmit_timestamp: NtpTimestamp::default(), } } fn deserialize(data: &[u8]) -> Result<(Self, usize), ParsingError> { if data.len() < Self::WIRE_LENGTH { return Err(ParsingError::IncorrectLength); } Ok(( Self { leap: NtpLeapIndicator::from_bits((data[0] & 0xC0) >> 6), mode: NtpAssociationMode::from_bits(data[0] & 0x07), stratum: data[1], poll: PollInterval::from_byte(data[2]), precision: data[3] as i8, root_delay: NtpDuration::from_bits_short(data[4..8].try_into().unwrap()), root_dispersion: NtpDuration::from_bits_short(data[8..12].try_into().unwrap()), reference_id: ReferenceId::from_bytes(data[12..16].try_into().unwrap()), reference_timestamp: NtpTimestamp::from_bits(data[16..24].try_into().unwrap()), origin_timestamp: NtpTimestamp::from_bits(data[24..32].try_into().unwrap()), receive_timestamp: NtpTimestamp::from_bits(data[32..40].try_into().unwrap()), transmit_timestamp: NtpTimestamp::from_bits(data[40..48].try_into().unwrap()), }, Self::WIRE_LENGTH, )) } fn serialize(&self, mut w: impl NonBlockingWrite, version: u8) -> std::io::Result<()> { w.write_all(&[(self.leap.to_bits() << 6) | (version << 3) | self.mode.to_bits()])?; w.write_all(&[self.stratum, self.poll.as_byte(), self.precision as u8])?; w.write_all(&self.root_delay.to_bits_short())?; w.write_all(&self.root_dispersion.to_bits_short())?; w.write_all(&self.reference_id.to_bytes())?; w.write_all(&self.reference_timestamp.to_bits())?; w.write_all(&self.origin_timestamp.to_bits())?; w.write_all(&self.receive_timestamp.to_bits())?; w.write_all(&self.transmit_timestamp.to_bits())?; Ok(()) } fn poll_message(poll_interval: PollInterval) -> (Self, RequestIdentifier) { let mut packet = Self::new(); packet.poll = poll_interval; packet.mode = NtpAssociationMode::Client; // In order to increase the entropy of the transmit timestamp // it is just a randomly generated timestamp. 
// We then expect to get it back identically from the remote // in the origin field. let transmit_timestamp = thread_rng().gen(); packet.transmit_timestamp = transmit_timestamp; ( packet, RequestIdentifier { expected_origin_timestamp: transmit_timestamp, uid: None, }, ) } fn timestamp_response( system: &SystemSnapshot, input: Self, recv_timestamp: NtpTimestamp, clock: &C, ) -> Self { Self { mode: NtpAssociationMode::Server, stratum: system.stratum, origin_timestamp: input.transmit_timestamp, receive_timestamp: recv_timestamp, reference_id: system.reference_id, poll: input.poll, precision: system.time_snapshot.precision.log2(), root_delay: system.time_snapshot.root_delay, root_dispersion: system.time_snapshot.root_dispersion, // Timestamp must be last to make it as accurate as possible. transmit_timestamp: clock.now().expect("Failed to read time"), leap: system.time_snapshot.leap_indicator, reference_timestamp: Default::default(), } } fn rate_limit_response(packet_from_client: Self) -> Self { Self { mode: NtpAssociationMode::Server, stratum: 0, // indicates a kiss code reference_id: ReferenceId::KISS_RATE, origin_timestamp: packet_from_client.transmit_timestamp, ..Self::new() } } fn deny_response(packet_from_client: Self) -> Self { Self { mode: NtpAssociationMode::Server, stratum: 0, // indicates a kiss code reference_id: ReferenceId::KISS_DENY, origin_timestamp: packet_from_client.transmit_timestamp, ..Self::new() } } fn nts_nak_response(packet_from_client: Self) -> Self { Self { mode: NtpAssociationMode::Server, stratum: 0, reference_id: ReferenceId::KISS_NTSN, origin_timestamp: packet_from_client.transmit_timestamp, ..Self::new() } } } impl<'a> NtpPacket<'a> { pub fn into_owned(self) -> NtpPacket<'static> { NtpPacket::<'static> { header: self.header, efdata: self.efdata.into_owned(), mac: self.mac.map(|v| v.into_owned()), } } #[allow(clippy::result_large_err)] pub fn deserialize( data: &'a [u8], cipher: &(impl CipherProvider + ?Sized), ) -> Result<(Self, Option), 
PacketParsingError<'a>> { if data.is_empty() { return Err(PacketParsingError::IncorrectLength); } let version = (data[0] & 0b0011_1000) >> 3; match version { 3 => { let (header, header_size) = NtpHeaderV3V4::deserialize(data).map_err(|e| e.generalize())?; let mac = if header_size != data.len() { Some(Mac::deserialize(&data[header_size..]).map_err(|e| e.generalize())?) } else { None }; Ok(( NtpPacket { header: NtpHeader::V3(header), efdata: ExtensionFieldData::default(), mac, }, None, )) } 4 => { let (header, header_size) = NtpHeaderV3V4::deserialize(data).map_err(|e| e.generalize())?; let construct_packet = |remaining_bytes: &'a [u8], efdata| { let mac = if !remaining_bytes.is_empty() { Some(Mac::deserialize(remaining_bytes)?) } else { None }; let packet = NtpPacket { header: NtpHeader::V4(header), efdata, mac, }; Ok::<_, ParsingError>(packet) }; match ExtensionFieldData::deserialize( data, header_size, cipher, ExtensionHeaderVersion::V4, ) { Ok(decoded) => { let packet = construct_packet(decoded.remaining_bytes, decoded.efdata) .map_err(|e| e.generalize())?; Ok((packet, decoded.cookie)) } Err(e) => { // return early if it is anything but a decrypt error let invalid = e.get_decrypt_error()?; let packet = construct_packet(invalid.remaining_bytes, invalid.efdata) .map_err(|e| e.generalize())?; Err(ParsingError::DecryptError(packet)) } } } #[cfg(feature = "ntpv5")] 5 => { let (header, header_size) = v5::NtpHeaderV5::deserialize(data).map_err(|e| e.generalize())?; let construct_packet = |remaining_bytes: &'a [u8], efdata| { let mac = if !remaining_bytes.is_empty() { Some(Mac::deserialize(remaining_bytes)?) 
} else { None }; let packet = NtpPacket { header: NtpHeader::V5(header), efdata, mac, }; Ok::<_, ParsingError>(packet) }; // TODO: Check extension field handling in V5 let res_packet = match ExtensionFieldData::deserialize( data, header_size, cipher, ExtensionHeaderVersion::V5, ) { Ok(decoded) => { let packet = construct_packet(decoded.remaining_bytes, decoded.efdata) .map_err(|e| e.generalize())?; Ok((packet, decoded.cookie)) } Err(e) => { // return early if it is anything but a decrypt error let invalid = e.get_decrypt_error()?; let packet = construct_packet(invalid.remaining_bytes, invalid.efdata) .map_err(|e| e.generalize())?; Err(ParsingError::DecryptError(packet)) } }; let (packet, cookie) = res_packet?; match packet.draft_id() { Some(id) if id == v5::DRAFT_VERSION => Ok((packet, cookie)), received @ (Some(_) | None) => { tracing::error!( expected = v5::DRAFT_VERSION, received, "Mismatched draft ID ignoring packet!" ); Err(ParsingError::V5(v5::V5Error::InvalidDraftIdentification)) } } } _ => Err(PacketParsingError::InvalidVersion(version)), } } #[cfg(test)] pub fn serialize_without_encryption_vec( &self, #[cfg_attr(not(feature = "ntpv5"), allow(unused_variables))] desired_size: Option, ) -> std::io::Result> { let mut buffer = vec![0u8; 1024]; let mut cursor = Cursor::new(buffer.as_mut_slice()); self.serialize(&mut cursor, &NoCipher, desired_size)?; let length = cursor.position() as usize; let buffer = cursor.into_inner()[..length].to_vec(); Ok(buffer) } pub fn serialize( &self, w: &mut Cursor<&mut [u8]>, cipher: &(impl CipherProvider + ?Sized), #[cfg_attr(not(feature = "ntpv5"), allow(unused_variables))] desired_size: Option, ) -> std::io::Result<()> { #[cfg(feature = "ntpv5")] let start = w.position(); match self.header { NtpHeader::V3(header) => header.serialize(&mut *w, 3)?, NtpHeader::V4(header) => header.serialize(&mut *w, 4)?, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.serialize(&mut *w)?, }; match self.header { NtpHeader::V3(_) => { /* 
No extension fields in V3 */ } NtpHeader::V4(_) => { self.efdata .serialize(&mut *w, cipher, ExtensionHeaderVersion::V4)? } #[cfg(feature = "ntpv5")] NtpHeader::V5(_) => { self.efdata .serialize(&mut *w, cipher, ExtensionHeaderVersion::V5)? } } if let Some(ref mac) = self.mac { mac.serialize(&mut *w)?; } #[cfg(feature = "ntpv5")] if let Some(desired_size) = desired_size { let written = (w.position() - start) as usize; if desired_size > written { ExtensionField::Padding(desired_size - written).serialize( w, 4, ExtensionHeaderVersion::V5, )?; } } Ok(()) } pub fn nts_poll_message( cookie: &'a [u8], new_cookies: u8, poll_interval: PollInterval, ) -> (NtpPacket<'static>, RequestIdentifier) { let (header, id) = NtpHeaderV3V4::poll_message(poll_interval); let identifier: [u8; 32] = rand::thread_rng().gen(); let mut authenticated = vec![ ExtensionField::UniqueIdentifier(identifier.to_vec().into()), ExtensionField::NtsCookie(cookie.to_vec().into()), ]; for _ in 1..new_cookies { authenticated.push(ExtensionField::NtsCookiePlaceholder { cookie_length: cookie.len() as u16, }); } ( NtpPacket { header: NtpHeader::V4(header), efdata: ExtensionFieldData { authenticated, encrypted: vec![], untrusted: vec![], }, mac: None, }, RequestIdentifier { uid: Some(identifier), ..id }, ) } #[cfg(feature = "ntpv5")] pub fn nts_poll_message_v5( cookie: &'a [u8], new_cookies: u8, poll_interval: PollInterval, ) -> (NtpPacket<'static>, RequestIdentifier) { let (header, id) = v5::NtpHeaderV5::poll_message(poll_interval); let identifier: [u8; 32] = rand::thread_rng().gen(); let mut authenticated = vec![ ExtensionField::UniqueIdentifier(identifier.to_vec().into()), ExtensionField::NtsCookie(cookie.to_vec().into()), ]; for _ in 1..new_cookies { authenticated.push(ExtensionField::NtsCookiePlaceholder { cookie_length: cookie.len() as u16, }); } let draft_id = ExtensionField::DraftIdentification(Cow::Borrowed(v5::DRAFT_VERSION)); authenticated.push(draft_id); ( NtpPacket { header: NtpHeader::V5(header), 
efdata: ExtensionFieldData { authenticated, encrypted: vec![], untrusted: vec![], }, mac: None, }, RequestIdentifier { uid: Some(identifier), ..id }, ) } pub fn poll_message(poll_interval: PollInterval) -> (Self, RequestIdentifier) { let (header, id) = NtpHeaderV3V4::poll_message(poll_interval); ( NtpPacket { header: NtpHeader::V4(header), efdata: Default::default(), mac: None, }, id, ) } #[cfg(feature = "ntpv5")] pub fn poll_message_upgrade_request(poll_interval: PollInterval) -> (Self, RequestIdentifier) { let (mut header, id) = NtpHeaderV3V4::poll_message(poll_interval); header.reference_timestamp = v5::UPGRADE_TIMESTAMP; ( NtpPacket { header: NtpHeader::V4(header), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: vec![], }, mac: None, }, id, ) } #[cfg(feature = "ntpv5")] pub fn poll_message_v5(poll_interval: PollInterval) -> (Self, RequestIdentifier) { let (header, id) = v5::NtpHeaderV5::poll_message(poll_interval); let draft_id = ExtensionField::DraftIdentification(Cow::Borrowed(v5::DRAFT_VERSION)); ( NtpPacket { header: NtpHeader::V5(header), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: vec![draft_id], }, mac: None, }, id, ) } #[cfg_attr(not(feature = "ntpv5"), allow(unused_mut))] pub fn timestamp_response( system: &SystemSnapshot, input: Self, recv_timestamp: NtpTimestamp, clock: &C, ) -> Self { match &input.header { NtpHeader::V3(header) => NtpPacket { header: NtpHeader::V3(NtpHeaderV3V4::timestamp_response( system, *header, recv_timestamp, clock, )), efdata: Default::default(), mac: None, }, NtpHeader::V4(header) => { let mut response_header = NtpHeaderV3V4::timestamp_response(system, *header, recv_timestamp, clock); #[cfg(feature = "ntpv5")] { // Respond with the upgrade timestamp (NTP5NTP5) iff the input had it and the packet // had the correct draft identification if header.reference_timestamp == v5::UPGRADE_TIMESTAMP { response_header.reference_timestamp = v5::UPGRADE_TIMESTAMP; }; } 
NtpPacket { header: NtpHeader::V4(response_header), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], // Ignore encrypted so as not to accidentally leak anything untrusted: input .efdata .untrusted .into_iter() .chain(input.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), }, mac: None, } } #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { // TODO deduplicate extension handling with V4 header: NtpHeader::V5(v5::NtpHeaderV5::timestamp_response( system, *header, recv_timestamp, clock, )), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], // Ignore encrypted so as not to accidentally leak anything untrusted: input .efdata .untrusted .into_iter() .chain(input.efdata.authenticated) .filter_map(|ef| match ef { uid @ ExtensionField::UniqueIdentifier(_) => Some(uid), ExtensionField::ReferenceIdRequest(req) => { let response = req.to_response(&system.bloom_filter)?; Some(ExtensionField::ReferenceIdResponse(response).into_owned()) } _ => None, }) .chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), }, mac: None, }, } } #[cfg(feature = "ntpv5")] fn draft_id(&self) -> Option<&'_ str> { self.efdata .untrusted .iter() .chain(self.efdata.authenticated.iter()) .find_map(|ef| match ef { ExtensionField::DraftIdentification(id) => Some(&**id), _ => None, }) } pub fn nts_timestamp_response( system: &SystemSnapshot, input: Self, recv_timestamp: NtpTimestamp, clock: &C, cookie: &DecodedServerCookie, keyset: &KeySet, ) -> Self { match input.header { NtpHeader::V3(_) => unreachable!("NTS shouldn't work with NTPv3"), NtpHeader::V4(header) => NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::timestamp_response( system, header, recv_timestamp, clock, )), efdata: ExtensionFieldData { encrypted: input .efdata .authenticated .iter() .chain(input.efdata.encrypted.iter()) .filter_map(|f| match f { ExtensionField::NtsCookiePlaceholder { 
cookie_length } => { let new_cookie = keyset.encode_cookie(cookie); if new_cookie.len() > *cookie_length as usize { None } else { Some(ExtensionField::NtsCookie(Cow::Owned(new_cookie))) } } ExtensionField::NtsCookie(old_cookie) => { let new_cookie = keyset.encode_cookie(cookie); if new_cookie.len() > old_cookie.len() { None } else { Some(ExtensionField::NtsCookie(Cow::Owned(new_cookie))) } } _ => None, }) .collect(), authenticated: input .efdata .authenticated .into_iter() .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), // Ignore encrypted so as not to accidentally leak anything untrusted: vec![], }, mac: None, }, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { header: NtpHeader::V5(v5::NtpHeaderV5::timestamp_response( system, header, recv_timestamp, clock, )), efdata: ExtensionFieldData { encrypted: input .efdata .authenticated .iter() .chain(input.efdata.encrypted.iter()) .filter_map(|f| match f { ExtensionField::NtsCookiePlaceholder { cookie_length } => { let new_cookie = keyset.encode_cookie(cookie); if new_cookie.len() > *cookie_length as usize { None } else { Some(ExtensionField::NtsCookie(Cow::Owned(new_cookie))) } } ExtensionField::NtsCookie(old_cookie) => { let new_cookie = keyset.encode_cookie(cookie); if new_cookie.len() > old_cookie.len() { None } else { Some(ExtensionField::NtsCookie(Cow::Owned(new_cookie))) } } _ => None, }) .collect(), authenticated: input .efdata .authenticated .into_iter() .filter_map(|ef| match ef { uid @ ExtensionField::UniqueIdentifier(_) => Some(uid), ExtensionField::ReferenceIdRequest(req) => { let response = req.to_response(&system.bloom_filter)?; Some(ExtensionField::ReferenceIdResponse(response).into_owned()) } _ => None, }) .chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), untrusted: vec![], }, mac: None, }, } } pub fn rate_limit_response(packet_from_client: Self) -> Self { match packet_from_client.header { 
NtpHeader::V3(header) => NtpPacket { header: NtpHeader::V3(NtpHeaderV3V4::rate_limit_response(header)), efdata: Default::default(), mac: None, }, NtpHeader::V4(header) => NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::rate_limit_response(header)), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], // Ignore encrypted so as not to accidentally leak anything untrusted: packet_from_client .efdata .untrusted .into_iter() .chain(packet_from_client.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), }, mac: None, }, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { header: NtpHeader::V5(v5::NtpHeaderV5::rate_limit_response(header)), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], // Ignore encrypted so as not to accidentally leak anything untrusted: packet_from_client .efdata .untrusted .into_iter() .chain(packet_from_client.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), }, mac: None, }, } } pub fn nts_rate_limit_response(packet_from_client: Self) -> Self { match packet_from_client.header { NtpHeader::V3(_) => unreachable!("NTS shouldn't work with NTPv3"), NtpHeader::V4(header) => NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::rate_limit_response(header)), efdata: ExtensionFieldData { authenticated: packet_from_client .efdata .authenticated .into_iter() .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), encrypted: vec![], untrusted: vec![], }, mac: None, }, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { header: NtpHeader::V5(v5::NtpHeaderV5::rate_limit_response(header)), efdata: ExtensionFieldData { authenticated: packet_from_client .efdata .authenticated .into_iter() .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) 
.chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), encrypted: vec![], untrusted: vec![], }, mac: None, }, } } pub fn deny_response(packet_from_client: Self) -> Self { match packet_from_client.header { NtpHeader::V3(header) => NtpPacket { header: NtpHeader::V3(NtpHeaderV3V4::deny_response(header)), efdata: Default::default(), mac: None, }, NtpHeader::V4(header) => NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::deny_response(header)), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], // Ignore encrypted so as not to accidentally leak anything untrusted: packet_from_client .efdata .untrusted .into_iter() .chain(packet_from_client.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), }, mac: None, }, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { header: NtpHeader::V5(v5::NtpHeaderV5::deny_response(header)), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], // Ignore encrypted so as not to accidentally leak anything untrusted: packet_from_client .efdata .untrusted .into_iter() .chain(packet_from_client.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), }, mac: None, }, } } pub fn nts_deny_response(packet_from_client: Self) -> Self { match packet_from_client.header { NtpHeader::V3(_) => unreachable!("NTS shouldn't work with NTPv3"), NtpHeader::V4(header) => NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::deny_response(header)), efdata: ExtensionFieldData { authenticated: packet_from_client .efdata .authenticated .into_iter() .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), encrypted: vec![], untrusted: vec![], }, mac: None, }, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { header: 
NtpHeader::V5(v5::NtpHeaderV5::deny_response(header)), efdata: ExtensionFieldData { authenticated: packet_from_client .efdata .authenticated .into_iter() .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), encrypted: vec![], untrusted: vec![], }, mac: None, }, } } pub fn nts_nak_response(packet_from_client: Self) -> Self { match packet_from_client.header { NtpHeader::V3(_) => unreachable!("NTS shouldn't work with NTPv3"), NtpHeader::V4(header) => NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::nts_nak_response(header)), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: packet_from_client .efdata .untrusted .into_iter() .chain(packet_from_client.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .collect(), }, mac: None, }, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => NtpPacket { header: NtpHeader::V5(v5::NtpHeaderV5::nts_nak_response(header)), efdata: ExtensionFieldData { authenticated: vec![], encrypted: vec![], untrusted: packet_from_client .efdata .untrusted .into_iter() .chain(packet_from_client.efdata.authenticated) .filter(|ef| matches!(ef, ExtensionField::UniqueIdentifier(_))) .chain(std::iter::once(ExtensionField::DraftIdentification( Cow::Borrowed(v5::DRAFT_VERSION), ))) .collect(), }, mac: None, }, } } } impl<'a> NtpPacket<'a> { pub fn new_cookies<'b: 'a>(&'b self) -> impl Iterator> + 'b { self.efdata.encrypted.iter().filter_map(|ef| match ef { ExtensionField::NtsCookie(cookie) => Some(cookie.to_vec()), _ => None, }) } pub fn version(&self) -> u8 { match self.header { NtpHeader::V3(_) => 3, NtpHeader::V4(_) => 4, #[cfg(feature = "ntpv5")] NtpHeader::V5(_) => 5, } } pub fn header(&self) -> NtpHeader { self.header } pub fn leap(&self) -> NtpLeapIndicator { match self.header { NtpHeader::V3(header) => header.leap, NtpHeader::V4(header) => header.leap, #[cfg(feature 
= "ntpv5")] NtpHeader::V5(header) => header.leap, } } pub fn mode(&self) -> NtpAssociationMode { match self.header { NtpHeader::V3(header) => header.mode, NtpHeader::V4(header) => header.mode, // FIXME long term the return type should change to capture both mode types #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => match header.mode { v5::NtpMode::Request => NtpAssociationMode::Client, v5::NtpMode::Response => NtpAssociationMode::Server, }, } } pub fn poll(&self) -> PollInterval { match self.header { NtpHeader::V3(h) | NtpHeader::V4(h) => h.poll, #[cfg(feature = "ntpv5")] NtpHeader::V5(h) => h.poll, } } pub fn stratum(&self) -> u8 { match self.header { NtpHeader::V3(header) => header.stratum, NtpHeader::V4(header) => header.stratum, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.stratum, } } pub fn precision(&self) -> i8 { match self.header { NtpHeader::V3(header) => header.precision, NtpHeader::V4(header) => header.precision, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.precision, } } pub fn root_delay(&self) -> NtpDuration { match self.header { NtpHeader::V3(header) => header.root_delay, NtpHeader::V4(header) => header.root_delay, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.root_delay, } } pub fn root_dispersion(&self) -> NtpDuration { match self.header { NtpHeader::V3(header) => header.root_dispersion, NtpHeader::V4(header) => header.root_dispersion, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.root_dispersion, } } pub fn receive_timestamp(&self) -> NtpTimestamp { match self.header { NtpHeader::V3(header) => header.receive_timestamp, NtpHeader::V4(header) => header.receive_timestamp, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.receive_timestamp, } } pub fn transmit_timestamp(&self) -> NtpTimestamp { match self.header { NtpHeader::V3(header) => header.transmit_timestamp, NtpHeader::V4(header) => header.transmit_timestamp, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => 
header.transmit_timestamp, } } pub fn reference_id(&self) -> ReferenceId { match self.header { NtpHeader::V3(header) => header.reference_id, NtpHeader::V4(header) => header.reference_id, #[cfg(feature = "ntpv5")] // TODO NTPv5 does not have reference IDs so this should always be None for now NtpHeader::V5(_header) => ReferenceId::NONE, } } fn kiss_code(&self) -> ReferenceId { match self.header { NtpHeader::V3(header) => header.reference_id, NtpHeader::V4(header) => header.reference_id, #[cfg(feature = "ntpv5")] // Kiss code in ntpv5 is the first four bytes of the server cookie NtpHeader::V5(header) => { ReferenceId::from_bytes(header.server_cookie.0[..4].try_into().unwrap()) } } } pub fn is_kiss(&self) -> bool { match self.header { NtpHeader::V3(header) => header.stratum == 0, NtpHeader::V4(header) => header.stratum == 0, #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.stratum == 0, } } pub fn is_kiss_deny(&self) -> bool { self.is_kiss() && match self.header { NtpHeader::V3(_) | NtpHeader::V4(_) => self.kiss_code().is_deny(), #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.poll == PollInterval::NEVER, } } pub fn is_kiss_rate( &self, #[cfg_attr(not(feature = "ntpv5"), allow(unused))] own_interval: PollInterval, ) -> bool { self.is_kiss() && match self.header { NtpHeader::V3(_) | NtpHeader::V4(_) => self.kiss_code().is_rate(), #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => { header.poll > own_interval && header.poll != PollInterval::NEVER } } } pub fn is_kiss_rstr(&self) -> bool { self.is_kiss() && match self.header { NtpHeader::V3(_) | NtpHeader::V4(_) => self.kiss_code().is_rstr(), #[cfg(feature = "ntpv5")] NtpHeader::V5(_) => false, } } pub fn is_kiss_ntsn(&self) -> bool { self.is_kiss() && match self.header { NtpHeader::V3(_) | NtpHeader::V4(_) => self.kiss_code().is_ntsn(), #[cfg(feature = "ntpv5")] NtpHeader::V5(header) => header.flags.authnak, } } #[cfg(feature = "ntpv5")] pub fn is_upgrade(&self) -> bool { matches!( self.header, 
NtpHeader::V4(NtpHeaderV3V4 {
                reference_timestamp: v5::UPGRADE_TIMESTAMP,
                ..
            }),
        )
    }

    /// Checks whether this packet is an acceptable response to the request
    /// described by `identifier`.
    ///
    /// When the request carried a unique identifier, the unique-identifier
    /// extension fields of the response are checked first (see the inline
    /// comment for the exact rule). After that the echoed origin timestamp
    /// (client cookie for NTPv5) must equal the expected one.
    pub fn valid_server_response(&self, identifier: RequestIdentifier, nts_enabled: bool) -> bool {
        if let Some(uid) = identifier.uid {
            let auth = check_uid_extensionfield(self.efdata.authenticated.iter(), &uid);
            let encr = check_uid_extensionfield(self.efdata.encrypted.iter(), &uid);
            let untrusted = check_uid_extensionfield(self.efdata.untrusted.iter(), &uid);

            // we need at least one uid ef that matches, and none should contradict
            // our uid. Untrusted uids should only be considered on nts naks or
            // non-nts requests.
            let uid_ok = auth != Some(false)
                && encr != Some(false)
                && (untrusted != Some(false) || (nts_enabled && !self.is_kiss_ntsn()))
                && (auth.is_some()
                    || encr.is_some()
                    || ((!nts_enabled || self.is_kiss_ntsn()) && untrusted.is_some()));
            if !uid_ok {
                return false;
            }
        }
        match self.header {
            NtpHeader::V3(header) => {
                header.origin_timestamp == identifier.expected_origin_timestamp
            }
            NtpHeader::V4(header) => {
                header.origin_timestamp == identifier.expected_origin_timestamp
            }
            #[cfg(feature = "ntpv5")]
            NtpHeader::V5(header) => {
                header.client_cookie
                    == v5::NtpClientCookie::from_ntp_timestamp(identifier.expected_origin_timestamp)
            }
        }
    }

    /// Iterates over the extension fields stored in the untrusted list.
    //
    // The `<Item = ...>` argument was lost in this copy of the file; the item
    // type follows from iterating the `untrusted` vector by reference.
    pub fn untrusted_extension_fields(&self) -> impl Iterator<Item = &ExtensionField<'a>> {
        self.efdata.untrusted.iter()
    }

    /// Iterates over the extension fields stored in the authenticated list.
    pub fn authenticated_extension_fields(&self) -> impl Iterator<Item = &ExtensionField<'a>> {
        self.efdata.authenticated.iter()
    }

    /// Appends `ef` to the packet. If the packet already has authenticated or
    /// encrypted extension fields the new field joins the authenticated list;
    /// otherwise it is added to the untrusted list.
    pub fn push_additional(&mut self, ef: ExtensionField<'static>) {
        if !self.efdata.authenticated.is_empty() || !self.efdata.encrypted.is_empty() {
            self.efdata.authenticated.push(ef);
        } else {
            self.efdata.untrusted.push(ef);
        }
    }
}

// Returns whether all uid extension fields found match the given uid, or
// None if there were none.
fn check_uid_extensionfield<'a, I: IntoIterator>>( iter: I, uid: &[u8], ) -> Option { let mut found_uid = false; for ef in iter { if let ExtensionField::UniqueIdentifier(pid) = ef { if pid.len() < uid.len() || &pid[0..uid.len()] != uid { return Some(false); } found_uid = true; } } if found_uid { Some(true) } else { None } } #[cfg(any(test, feature = "__internal-fuzz", feature = "__internal-test"))] impl NtpPacket<'_> { pub fn test() -> Self { Self::default() } pub fn set_mode(&mut self, mode: NtpAssociationMode) { match &mut self.header { NtpHeader::V3(ref mut header) => header.mode = mode, NtpHeader::V4(ref mut header) => header.mode = mode, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => { header.mode = match mode { NtpAssociationMode::Client => v5::NtpMode::Request, NtpAssociationMode::Server => v5::NtpMode::Response, _ => todo!("NTPv5 can only handle client-server"), } } } } pub fn set_origin_timestamp(&mut self, timestamp: NtpTimestamp) { match &mut self.header { NtpHeader::V3(ref mut header) => header.origin_timestamp = timestamp, NtpHeader::V4(ref mut header) => header.origin_timestamp = timestamp, #[cfg(feature = "ntpv5")] // TODO can we just reuse the cookie as the origin timestamp? 
NtpHeader::V5(ref mut header) => { header.client_cookie = v5::NtpClientCookie::from_ntp_timestamp(timestamp) } } } pub fn set_transmit_timestamp(&mut self, timestamp: NtpTimestamp) { match &mut self.header { NtpHeader::V3(ref mut header) => header.transmit_timestamp = timestamp, NtpHeader::V4(ref mut header) => header.transmit_timestamp = timestamp, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.transmit_timestamp = timestamp, } } pub fn set_receive_timestamp(&mut self, timestamp: NtpTimestamp) { match &mut self.header { NtpHeader::V3(ref mut header) => header.receive_timestamp = timestamp, NtpHeader::V4(ref mut header) => header.receive_timestamp = timestamp, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.receive_timestamp = timestamp, } } pub fn set_precision(&mut self, precision: i8) { match &mut self.header { NtpHeader::V3(ref mut header) => header.precision = precision, NtpHeader::V4(ref mut header) => header.precision = precision, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.precision = precision, } } pub fn set_leap(&mut self, leap: NtpLeapIndicator) { match &mut self.header { NtpHeader::V3(ref mut header) => header.leap = leap, NtpHeader::V4(ref mut header) => header.leap = leap, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.leap = leap, } } pub fn set_stratum(&mut self, stratum: u8) { match &mut self.header { NtpHeader::V3(ref mut header) => header.stratum = stratum, NtpHeader::V4(ref mut header) => header.stratum = stratum, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.stratum = stratum, } } pub fn set_reference_id(&mut self, reference_id: ReferenceId) { match &mut self.header { NtpHeader::V3(ref mut header) => header.reference_id = reference_id, NtpHeader::V4(ref mut header) => header.reference_id = reference_id, #[cfg(feature = "ntpv5")] NtpHeader::V5(_header) => todo!("NTPv5 does not have reference IDs"), } } pub fn set_root_delay(&mut self, root_delay: 
NtpDuration) { match &mut self.header { NtpHeader::V3(ref mut header) => header.root_delay = root_delay, NtpHeader::V4(ref mut header) => header.root_delay = root_delay, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.root_delay = root_delay, } } pub fn set_root_dispersion(&mut self, root_dispersion: NtpDuration) { match &mut self.header { NtpHeader::V3(ref mut header) => header.root_dispersion = root_dispersion, NtpHeader::V4(ref mut header) => header.root_dispersion = root_dispersion, #[cfg(feature = "ntpv5")] NtpHeader::V5(ref mut header) => header.root_dispersion = root_dispersion, } } } impl Default for NtpPacket<'_> { fn default() -> Self { Self { header: NtpHeader::V4(NtpHeaderV3V4::new()), efdata: Default::default(), mac: None, } } } #[cfg(test)] mod tests { use crate::{ keyset::KeySetProvider, nts_record::AeadAlgorithm, system::TimeSnapshot, time_types::PollIntervalLimits, }; use super::*; #[derive(Debug, Clone)] struct TestClock { now: NtpTimestamp, } impl NtpClock for TestClock { type Error = std::io::Error; fn now(&self) -> Result { Ok(self.now) } fn set_frequency(&self, _freq: f64) -> Result { panic!("Unexpected clock steer"); } fn get_frequency(&self) -> Result { Ok(0.0) } fn step_clock(&self, _offset: NtpDuration) -> Result { panic!("Unexpected clock steer"); } fn disable_ntp_algorithm(&self) -> Result<(), Self::Error> { panic!("Unexpected clock steer"); } fn error_estimate_update( &self, _est_error: NtpDuration, _max_error: NtpDuration, ) -> Result<(), Self::Error> { panic!("Unexpected clock steer"); } fn status_update(&self, _leap_status: NtpLeapIndicator) -> Result<(), Self::Error> { panic!("Unexpected clock steer"); } } #[test] fn roundtrip_bitrep_leap() { for i in 0..4u8 { let a = NtpLeapIndicator::from_bits(i); let b = a.to_bits(); let c = NtpLeapIndicator::from_bits(b); assert_eq!(i, b); assert_eq!(a, c); } } #[test] fn roundtrip_bitrep_mode() { for i in 0..8u8 { let a = NtpAssociationMode::from_bits(i); let b = a.to_bits(); 
let c = NtpAssociationMode::from_bits(b); assert_eq!(i, b); assert_eq!(a, c); } } #[test] fn test_captured_client() { let packet = b"\x23\x02\x06\xe8\x00\x00\x03\xff\x00\x00\x03\x7d\x5e\xc6\x9f\x0f\xe5\xf6\x62\x98\x7b\x61\xb9\xaf\xe5\xf6\x63\x66\x7b\x64\x99\x5d\xe5\xf6\x63\x66\x81\x40\x55\x90\xe5\xf6\x63\xa8\x76\x1d\xde\x48"; let reference = NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4 { leap: NtpLeapIndicator::NoWarning, mode: NtpAssociationMode::Client, stratum: 2, poll: PollInterval::from_byte(6), precision: -24, root_delay: NtpDuration::from_fixed_int(1023 << 16), root_dispersion: NtpDuration::from_fixed_int(893 << 16), reference_id: ReferenceId::from_int(0x5ec69f0f), reference_timestamp: NtpTimestamp::from_fixed_int(0xe5f662987b61b9af), origin_timestamp: NtpTimestamp::from_fixed_int(0xe5f663667b64995d), receive_timestamp: NtpTimestamp::from_fixed_int(0xe5f6636681405590), transmit_timestamp: NtpTimestamp::from_fixed_int(0xe5f663a8761dde48), }), efdata: Default::default(), mac: None, }; assert_eq!( reference, NtpPacket::deserialize(packet, &NoCipher).unwrap().0 ); match reference.serialize_without_encryption_vec(None) { Ok(buf) => assert_eq!(packet[..], buf[..]), Err(e) => panic!("{e:?}"), } let packet = b"\x1B\x02\x06\xe8\x00\x00\x03\xff\x00\x00\x03\x7d\x5e\xc6\x9f\x0f\xe5\xf6\x62\x98\x7b\x61\xb9\xaf\xe5\xf6\x63\x66\x7b\x64\x99\x5d\xe5\xf6\x63\x66\x81\x40\x55\x90\xe5\xf6\x63\xa8\x76\x1d\xde\x48"; let reference = NtpPacket { header: NtpHeader::V3(NtpHeaderV3V4 { leap: NtpLeapIndicator::NoWarning, mode: NtpAssociationMode::Client, stratum: 2, poll: PollInterval::from_byte(6), precision: -24, root_delay: NtpDuration::from_fixed_int(1023 << 16), root_dispersion: NtpDuration::from_fixed_int(893 << 16), reference_id: ReferenceId::from_int(0x5ec69f0f), reference_timestamp: NtpTimestamp::from_fixed_int(0xe5f662987b61b9af), origin_timestamp: NtpTimestamp::from_fixed_int(0xe5f663667b64995d), receive_timestamp: NtpTimestamp::from_fixed_int(0xe5f6636681405590), 
transmit_timestamp: NtpTimestamp::from_fixed_int(0xe5f663a8761dde48), }), efdata: Default::default(), mac: None, }; assert_eq!( reference, NtpPacket::deserialize(packet, &NoCipher).unwrap().0 ); match reference.serialize_without_encryption_vec(None) { Ok(buf) => assert_eq!(packet[..], buf[..]), Err(e) => panic!("{e:?}"), } } #[test] fn test_captured_server() { let packet = b"\x24\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; let reference = NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4 { leap: NtpLeapIndicator::NoWarning, mode: NtpAssociationMode::Server, stratum: 2, poll: PollInterval::from_byte(6), precision: -23, root_delay: NtpDuration::from_fixed_int(566 << 16), root_dispersion: NtpDuration::from_fixed_int(951 << 16), reference_id: ReferenceId::from_int(0xc035676c), reference_timestamp: NtpTimestamp::from_fixed_int(0xe5f661fd6f165f03), origin_timestamp: NtpTimestamp::from_fixed_int(0xe5f663a87619ef40), receive_timestamp: NtpTimestamp::from_fixed_int(0xe5f663a8798c6581), transmit_timestamp: NtpTimestamp::from_fixed_int(0xe5f663a8798eae2b), }), efdata: Default::default(), mac: None, }; assert_eq!( reference, NtpPacket::deserialize(packet, &NoCipher).unwrap().0 ); match reference.serialize_without_encryption_vec(None) { Ok(buf) => assert_eq!(packet[..], buf[..]), Err(e) => panic!("{e:?}"), } } #[test] fn test_version() { let packet = b"\x04\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x0B\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; 
assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x14\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x34\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); let packet = b"\x3B\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); #[cfg(not(feature = "ntpv5"))] { // Version 5 packet should not parse without the ntpv5 feature let packet = b"\x2C\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b"; assert!(NtpPacket::deserialize(packet, &NoCipher).is_err()); } } #[test] fn test_packed_flags() { let base = b"\x24\x02\x06\xe9\x00\x00\x02\x36\x00\x00\x03\xb7\xc0\x35\x67\x6c\xe5\xf6\x61\xfd\x6f\x16\x5f\x03\xe5\xf6\x63\xa8\x76\x19\xef\x40\xe5\xf6\x63\xa8\x79\x8c\x65\x81\xe5\xf6\x63\xa8\x79\x8e\xae\x2b".to_owned(); let base_structured = NtpPacket::deserialize(&base, &NoCipher).unwrap().0; for leap_type in 0..3 { for mode in 0..8 { let mut header = base_structured.clone(); header.set_leap(NtpLeapIndicator::from_bits(leap_type)); header.set_mode(NtpAssociationMode::from_bits(mode)); let data = header.serialize_without_encryption_vec(None).unwrap(); let copy = NtpPacket::deserialize(&data, &NoCipher).unwrap().0; assert_eq!(header, copy); } } for i in 0..=0xFF { let mut packet = base; packet[0] = i; if let 
Ok((a, _)) = NtpPacket::deserialize(&packet, &NoCipher) { let b = a.serialize_without_encryption_vec(None).unwrap(); assert_eq!(packet[..], b[..]); } } } #[test] fn test_nts_roundtrip() { let cookie = [0; 16]; let (packet1, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let cipher = AesSivCmac512::new(std::array::from_fn::<_, 64, _>(|i| i as u8).into()); let mut buffer = [0u8; 2048]; let mut cursor = Cursor::new(buffer.as_mut()); packet1.serialize(&mut cursor, &cipher, None).unwrap(); let (packet2, _) = NtpPacket::deserialize(&cursor.get_ref()[..cursor.position() as usize], &cipher) .unwrap(); assert_eq!(packet1, packet2); } #[test] fn test_nts_captured_server() { let packet = b"\x24\x01\x04\xe8\x00\x00\x00\x00\x00\x00\x00\x60\x54\x4d\x4e\x4c\xe8\x49\x48\x92\xf9\x29\x57\x9e\x62\x87\xdb\x47\x3f\xf7\x5f\x58\xe8\x49\x48\xb2\xb6\x40\xd7\x01\xe8\x49\x48\xb2\xb6\x44\xbf\xf8\x01\x04\x00\x24\xe4\x83\x3a\x8d\x60\x0e\x13\x42\x43\x5c\xb2\x9d\xe5\x50\xac\xc0\xf8\xd8\xfa\x16\xe5\xc5\x37\x0a\x62\x0b\x15\x5f\x58\x6a\xda\xd6\x04\x04\x00\xd4\x00\x10\x00\xbc\x6a\x1d\xe3\xc2\x6e\x13\xeb\x10\xc7\x39\xd7\x0b\x84\x1f\xad\x1b\x86\xe2\x30\xc6\x3e\x9e\xa5\xf7\x1b\x62\xa8\xa7\x98\x81\xce\x7c\x6b\x17\xcb\x31\x32\x49\x0f\xde\xcf\x21\x10\x56\x4e\x36\x88\x92\xdd\xee\xf1\xf4\x23\xf6\x55\x53\x41\xc2\xc9\x17\x61\x20\xa5\x18\xdc\x1a\x7e\xdc\x5e\xe3\xc8\x3b\x05\x08\x7b\x73\x03\xf7\xab\x86\xd5\x2c\xc7\x49\x0c\xe8\x29\x39\x72\x23\xdc\xef\x2d\x94\xfa\xf8\xd7\x1d\x12\x80\xda\x03\x2d\xd7\x04\x69\xe9\xac\x5f\x82\xef\x57\x81\xd2\x07\xfb\xac\xb4\xa8\xb6\x31\x91\x14\xd5\xf5\x6f\xb2\x2a\x0c\xb6\xd7\xdc\xf7\x7d\xf0\x21\x46\xf6\x7e\x46\x01\xb5\x3b\x21\x7c\xa8\xac\x1a\x4d\x97\xd5\x9b\xce\xeb\x98\x33\x99\x7f\x10\x0e\xd4\x69\x85\x8b\xcd\x73\x52\x01\xad\xec\x38\xcf\x8c\xb2\xc6\xd0\x54\x1a\x97\x67\xdd\xb3\xea\x09\x1d\x63\xd9\x8d\x03\xdd\x6e\x48\x15\x3d\xc9\xb6\x1f\xe5\xd9\x1d\x74\xae\x35\x48"; let cipher = AesSivCmac512::new( [ 244, 6, 63, 13, 47, 226, 180, 25, 104, 212, 47, 14, 
186, 70, 187, 93, 134, 140, 2, 82, 238, 254, 113, 79, 90, 31, 135, 138, 123, 210, 121, 47, 228, 208, 243, 76, 126, 213, 196, 233, 65, 15, 33, 163, 196, 30, 6, 197, 222, 105, 40, 14, 73, 138, 200, 45, 235, 127, 48, 248, 171, 8, 141, 180, ] .into(), ); assert!(NtpPacket::deserialize(packet, &cipher).is_ok()); } #[test] fn test_nts_captured_client() { let packet = b"\x23\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x62\x87\xdb\x47\x3f\xf7\x5f\x58\x01\x04\x00\x24\xe4\x83\x3a\x8d\x60\x0e\x13\x42\x43\x5c\xb2\x9d\xe5\x50\xac\xc0\xf8\xd8\xfa\x16\xe5\xc5\x37\x0a\x62\x0b\x15\x5f\x58\x6a\xda\xd6\x02\x04\x00\xac\x1c\xc4\x0a\x94\xda\x3f\x94\xa4\xd1\x2a\xc2\xd6\x09\xf1\x6f\x72\x11\x59\x6a\x0a\xce\xfc\x62\xd1\x1f\x28\x3a\xd1\x08\xd8\x01\xb5\x91\x38\x5d\x9b\xf5\x07\xf9\x0d\x21\x82\xe6\x81\x2a\x58\xa7\x35\xdc\x49\xc4\xd3\xe9\xb7\x9c\x72\xb7\xf6\x44\x64\xf8\xfc\x0d\xed\x25\xea\x1f\x7c\x9b\x31\x5c\xd8\x60\x86\xfd\x67\x74\x90\xf5\x0e\x61\xe6\x68\x0e\x29\x0d\x49\x77\x0c\xed\x44\xd4\x2f\x2d\x9b\xa8\x9f\x4d\x5d\xce\x4f\xdd\x57\x49\x51\x49\x5a\x1f\x38\xdb\xc7\xec\x1b\x86\x5b\xa5\x8f\x23\x1e\xdd\x76\xee\x1d\xaf\xdd\x66\xb2\xb2\x64\x1f\x03\xc6\x47\x9b\x42\x9c\x7f\xf6\x59\x6b\x82\x44\xcf\x67\xb5\xa2\xcd\x20\x9d\x39\xbb\xe6\x40\x2b\xf6\x20\x45\xdf\x95\x50\xf0\x38\x77\x06\x89\x79\x12\x18\x04\x04\x00\x28\x00\x10\x00\x10\xce\x89\xee\x97\x34\x42\xbc\x0f\x43\xaa\xce\x49\x99\xbd\xf5\x8e\x8f\xee\x7b\x1a\x2d\x58\xaf\x6d\xe9\xa2\x0e\x56\x1f\x7f\xf0\x6a"; let cipher = AesSivCmac512::new( [ 170, 111, 161, 118, 7, 200, 232, 128, 145, 250, 170, 186, 87, 143, 171, 252, 110, 241, 170, 179, 13, 150, 134, 147, 211, 248, 62, 207, 122, 155, 198, 109, 167, 15, 18, 118, 146, 63, 186, 146, 212, 188, 175, 27, 89, 3, 237, 212, 52, 113, 28, 21, 203, 200, 230, 17, 8, 186, 126, 1, 52, 230, 86, 40, ] .into(), ); assert!(NtpPacket::deserialize(packet, &cipher).is_ok()); } #[test] fn test_nts_poll_message() { let 
cookie = [0; 16]; let (packet1, ref1) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); assert_eq!(0, packet1.efdata.encrypted.len()); assert_eq!(0, packet1.efdata.untrusted.len()); let mut have_uid = false; let mut have_cookie = false; let mut nplaceholders = 0; for ef in packet1.efdata.authenticated { match ef { ExtensionField::UniqueIdentifier(uid) => { assert_eq!(ref1.uid.as_ref().unwrap(), uid.as_ref()); assert!(!have_uid); have_uid = true; } ExtensionField::NtsCookie(cookie_p) => { assert_eq!(&cookie, cookie_p.as_ref()); assert!(!have_cookie); have_cookie = true; } ExtensionField::NtsCookiePlaceholder { cookie_length } => { assert_eq!(cookie_length, cookie.len() as u16); nplaceholders += 1; } _ => unreachable!(), } } assert!(have_cookie); assert!(have_uid); assert_eq!(nplaceholders, 0); let (packet2, ref2) = NtpPacket::nts_poll_message(&cookie, 3, PollIntervalLimits::default().min); assert_ne!( ref1.expected_origin_timestamp, ref2.expected_origin_timestamp ); assert_ne!(ref1.uid, ref2.uid); assert_eq!(0, packet2.efdata.encrypted.len()); assert_eq!(0, packet2.efdata.untrusted.len()); let mut have_uid = false; let mut have_cookie = false; let mut nplaceholders = 0; for ef in packet2.efdata.authenticated { match ef { ExtensionField::UniqueIdentifier(uid) => { assert_eq!(ref2.uid.as_ref().unwrap(), uid.as_ref()); assert!(!have_uid); have_uid = true; } ExtensionField::NtsCookie(cookie_p) => { assert_eq!(&cookie, cookie_p.as_ref()); assert!(!have_cookie); have_cookie = true; } ExtensionField::NtsCookiePlaceholder { cookie_length } => { assert_eq!(cookie_length, cookie.len() as u16); nplaceholders += 1; } _ => unreachable!(), } } assert!(have_cookie); assert!(have_uid); assert_eq!(nplaceholders, 2); } #[test] fn test_nts_response_validation() { let cookie = [0; 16]; let (packet, id) = NtpPacket::nts_poll_message(&cookie, 0, PollIntervalLimits::default().min); let mut response = NtpPacket::timestamp_response( &SystemSnapshot::default(), 
packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(2), }, ); assert!(response.valid_server_response(id, false)); assert!(!response.valid_server_response(id, true)); response .efdata .untrusted .push(ExtensionField::UniqueIdentifier(Cow::Borrowed( id.uid.as_ref().unwrap(), ))); assert!(response.valid_server_response(id, false)); assert!(!response.valid_server_response(id, true)); response.efdata.untrusted.clear(); response .efdata .authenticated .push(ExtensionField::UniqueIdentifier(Cow::Borrowed( id.uid.as_ref().unwrap(), ))); assert!(response.valid_server_response(id, false)); assert!(response.valid_server_response(id, true)); response .efdata .untrusted .push(ExtensionField::UniqueIdentifier(Cow::Borrowed(&[]))); assert!(!response.valid_server_response(id, false)); assert!(response.valid_server_response(id, true)); response.efdata.untrusted.clear(); response .efdata .encrypted .push(ExtensionField::UniqueIdentifier(Cow::Borrowed(&[]))); assert!(!response.valid_server_response(id, false)); assert!(!response.valid_server_response(id, true)); } #[cfg(feature = "ntpv5")] #[test] fn v5_upgrade_packet() { let (packet, _) = NtpPacket::poll_message_upgrade_request(PollInterval::default()); let response = NtpPacket::timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, ); let NtpHeader::V4(header) = response.header else { panic!("wrong version"); }; assert_eq!( header.reference_timestamp, NtpTimestamp::from_fixed_int(0x4E54503544524654) ); } #[test] fn test_timestamp_response() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let keysetprovider = KeySetProvider::new(1); let cookie = keysetprovider.get().encode_cookie(&decoded); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 0, 
PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = NtpPacket::timestamp_response( &SystemSnapshot { time_snapshot: TimeSnapshot { leap_indicator: NtpLeapIndicator::Leap59, ..Default::default() }, ..Default::default() }, packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, ); let response_id = response .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!( response.receive_timestamp(), NtpTimestamp::from_fixed_int(0) ); assert_eq!( response.transmit_timestamp(), NtpTimestamp::from_fixed_int(1) ); assert_eq!(response.leap(), NtpLeapIndicator::Leap59); let (mut packet, _) = NtpPacket::nts_poll_message(&cookie, 0, PollIntervalLimits::default().min); std::mem::swap( &mut packet.efdata.authenticated, &mut packet.efdata.untrusted, ); let packet_id = packet .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = NtpPacket::timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, ); let response_id = response .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!( response.receive_timestamp(), NtpTimestamp::from_fixed_int(0) ); assert_eq!( response.transmit_timestamp(), NtpTimestamp::from_fixed_int(1) ); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 0, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let 
ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = NtpPacket::nts_timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, &decoded, &keysetprovider.get(), ); let response_id = response .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!( response.receive_timestamp(), NtpTimestamp::from_fixed_int(0) ); assert_eq!( response.transmit_timestamp(), NtpTimestamp::from_fixed_int(1) ); let (mut packet, _) = NtpPacket::nts_poll_message(&cookie, 0, PollIntervalLimits::default().min); std::mem::swap( &mut packet.efdata.authenticated, &mut packet.efdata.untrusted, ); let response = NtpPacket::nts_timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, &decoded, &keysetprovider.get(), ); assert!(response .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .is_none()); assert_eq!( response.receive_timestamp(), NtpTimestamp::from_fixed_int(0) ); assert_eq!( response.transmit_timestamp(), NtpTimestamp::from_fixed_int(1) ); } #[test] fn test_timestamp_cookies() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let keysetprovider = KeySetProvider::new(1); let cookie = keysetprovider.get().encode_cookie(&decoded); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let response = NtpPacket::nts_timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: 
NtpTimestamp::from_fixed_int(1), }, &decoded, &keysetprovider.get(), ); assert_eq!(response.new_cookies().count(), 1); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 2, PollIntervalLimits::default().min); let response = NtpPacket::nts_timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, &decoded, &keysetprovider.get(), ); assert_eq!(response.new_cookies().count(), 2); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 3, PollIntervalLimits::default().min); let response = NtpPacket::nts_timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, &decoded, &keysetprovider.get(), ); assert_eq!(response.new_cookies().count(), 3); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 4, PollIntervalLimits::default().min); let response = NtpPacket::nts_timestamp_response( &SystemSnapshot::default(), packet, NtpTimestamp::from_fixed_int(0), &TestClock { now: NtpTimestamp::from_fixed_int(1), }, &decoded, &keysetprovider.get(), ); assert_eq!(response.new_cookies().count(), 4); } #[test] fn test_deny_response() { let decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let keysetprovider = KeySetProvider::new(1); let cookie = keysetprovider.get().encode_cookie(&decoded); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = NtpPacket::deny_response(packet); let response_id = response .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) 
.unwrap(); assert_eq!(packet_id, response_id); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_deny()); let (mut packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); std::mem::swap( &mut packet.efdata.authenticated, &mut packet.efdata.untrusted, ); let response = NtpPacket::deny_response(packet); let response_id = response .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_deny()); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = NtpPacket::nts_deny_response(packet); let response_id = response .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_deny()); let (mut packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); std::mem::swap( &mut packet.efdata.authenticated, &mut packet.efdata.untrusted, ); let response = NtpPacket::nts_deny_response(packet); assert!(response .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .is_none()); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_deny()); } #[test] fn test_rate_response() { let 
decoded = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new((0..32_u8).collect())), c2s: Box::new(AesSivCmac256::new((32..64_u8).collect())), }; let keysetprovider = KeySetProvider::new(1); let cookie = keysetprovider.get().encode_cookie(&decoded); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = NtpPacket::rate_limit_response(packet); let response_id = response .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_rate(PollIntervalLimits::default().min)); let (mut packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); std::mem::swap( &mut packet.efdata.authenticated, &mut packet.efdata.untrusted, ); let response = NtpPacket::rate_limit_response(packet); let response_id = response .efdata .untrusted .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_rate(PollIntervalLimits::default().min)); let (packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); let packet_id = packet .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); let response = 
NtpPacket::nts_rate_limit_response(packet); let response_id = response .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .unwrap(); assert_eq!(packet_id, response_id); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_rate(PollIntervalLimits::default().min)); let (mut packet, _) = NtpPacket::nts_poll_message(&cookie, 1, PollIntervalLimits::default().min); std::mem::swap( &mut packet.efdata.authenticated, &mut packet.efdata.untrusted, ); let response = NtpPacket::nts_rate_limit_response(packet); assert!(response .efdata .authenticated .iter() .find_map(|f| { if let ExtensionField::UniqueIdentifier(id) = f { Some(id.clone().into_owned()) } else { None } }) .is_none()); assert_eq!(response.new_cookies().count(), 0); assert!(response.is_kiss_rate(PollIntervalLimits::default().min)); } #[test] fn test_new_cookies_only_from_encrypted() { let allowed: [u8; 16] = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let disallowed: [u8; 16] = [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let packet = NtpPacket { header: NtpHeader::V4(NtpHeaderV3V4::poll_message(PollIntervalLimits::default().min).0), efdata: ExtensionFieldData { authenticated: vec![ExtensionField::NtsCookie(Cow::Borrowed(&disallowed))], encrypted: vec![ExtensionField::NtsCookie(Cow::Borrowed(&allowed))], untrusted: vec![ExtensionField::NtsCookie(Cow::Borrowed(&disallowed))], }, mac: None, }; assert_eq!(1, packet.new_cookies().count()); for cookie in packet.new_cookies() { assert_eq!(&cookie, &allowed); } } #[test] fn test_undersized_ef_in_encrypted_data() { let cipher = AesSivCmac256::new([0_u8; 32].into()); let packet = [ 35, 2, 6, 232, 0, 0, 3, 255, 0, 0, 3, 125, 94, 198, 159, 15, 229, 246, 98, 152, 123, 97, 185, 175, 229, 246, 99, 102, 123, 100, 153, 93, 229, 246, 99, 102, 129, 64, 85, 144, 229, 246, 99, 168, 118, 29, 222, 72, 4, 4, 0, 44, 0, 16, 0, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 39, 24, 181, 156, 166, 35, 154, 207, 38, 150, 15, 190, 152, 87, 142, 206, 254, 105, 0, 0, ]; //should not crash assert!(NtpPacket::deserialize(&packet, &cipher).is_err()); } #[test] fn test_undersized_ef() { let packet = [ 35, 2, 6, 232, 0, 0, 3, 255, 0, 0, 3, 125, 94, 198, 159, 15, 229, 246, 98, 152, 123, 97, 185, 175, 229, 246, 99, 102, 123, 100, 153, 93, 229, 246, 99, 102, 129, 64, 85, 144, 229, 246, 99, 168, 118, 29, 222, 72, 4, 4, ]; //should not crash assert!(NtpPacket::deserialize(&packet, &NoCipher).is_err()); } #[test] fn test_undersized_nonce() { let input = [ 32, 206, 206, 206, 77, 206, 206, 255, 216, 216, 216, 127, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 206, 217, 216, 216, 216, 216, 216, 216, 206, 206, 206, 1, 0, 0, 0, 206, 206, 206, 4, 44, 4, 4, 4, 4, 4, 4, 4, 0, 4, 206, 206, 222, 206, 206, 206, 206, 0, 0, 0, 206, 206, 206, 0, 0, 0, 206, 206, 206, 206, 206, 206, 131, 206, 206, ]; //should not crash assert!(NtpPacket::deserialize(&input, &NoCipher).is_err()); } #[test] fn test_undersized_encryption_ef() { let input = [ 32, 206, 206, 206, 77, 206, 216, 216, 127, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 206, 217, 216, 216, 216, 216, 216, 216, 206, 206, 206, 1, 0, 0, 0, 206, 206, 206, 4, 44, 4, 4, 4, 4, 4, 4, 4, 0, 4, 4, 0, 12, 206, 206, 222, 206, 206, 206, 206, 0, 0, 0, 12, 206, 206, 222, 206, 206, 206, 206, 206, 206, 206, 206, 131, 206, 206, ]; assert!(NtpPacket::deserialize(&input, &NoCipher).is_err()); } #[test] fn round_trip_with_ef() { let (mut p, _) = NtpPacket::poll_message(PollInterval::default()); p.efdata.untrusted.push(ExtensionField::Unknown { type_id: 0x42, data: vec![].into(), }); let serialized = p.serialize_without_encryption_vec(None).unwrap(); let (mut out, _) = NtpPacket::deserialize(&serialized, &NoCipher).unwrap(); // Strip any padding let ExtensionField::Unknown { data, .. 
} = &mut out.efdata.untrusted[0] else { panic!("wrong ef"); }; assert!(data.iter().all(|&e| e == 0)); *data = vec![].into(); assert_eq!(p, out); } #[cfg(feature = "ntpv5")] #[test] fn ef_with_missing_padding_v5() { let (packet, _) = NtpPacket::poll_message_v5(PollInterval::default()); let mut data = packet.serialize_without_encryption_vec(None).unwrap(); data.extend([ 0, 0, // Type = Unknown 0, 6, // Length = 5 1, 2, // Data // Missing 2 padding bytes ]); assert!(matches!( NtpPacket::deserialize(&data, &NoCipher), Err(ParsingError::IncorrectLength) )); } #[cfg(feature = "ntpv5")] #[test] fn padding_v5() { for i in 10..40 { let packet = NtpPacket::poll_message_v5(PollInterval::default()).0; let data = packet .serialize_without_encryption_vec(Some(4 * i)) .unwrap(); assert_eq!(data.len(), 76.max(i * 4)); assert!(NtpPacket::deserialize(&data, &NoCipher).is_ok()); } } } ntp-proto-1.4.0/src/packet/v5/error.rs000064400000000000000000000017261046102023000156650ustar 00000000000000use crate::packet::error::ParsingError; use std::fmt::{Display, Formatter}; #[derive(Debug)] pub enum V5Error { InvalidDraftIdentification, MalformedTimescale, MalformedMode, InvalidFlags, } impl V5Error { /// `const` alternative to `.into()` pub const fn into_parse_err(self) -> ParsingError { ParsingError::V5(self) } } impl Display for V5Error { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::InvalidDraftIdentification => f.write_str("Draft Identification invalid"), Self::MalformedTimescale => f.write_str("Malformed timescale"), Self::MalformedMode => f.write_str("Malformed mode"), Self::InvalidFlags => f.write_str("Invalid flags specified"), } } } impl From for crate::packet::error::ParsingError { fn from(value: V5Error) -> Self { Self::V5(value) } } ntp-proto-1.4.0/src/packet/v5/extension_fields.rs000064400000000000000000000146031046102023000200740ustar 00000000000000use crate::io::NonBlockingWrite; use crate::packet::error::ParsingError; use 
crate::packet::v5::server_reference_id::BloomFilter; use crate::packet::ExtensionField; use std::borrow::Cow; use std::convert::Infallible; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Type { DraftIdentification, Padding, Mac, ReferenceIdRequest, ReferenceIdResponse, ServerInformation, Correction, ReferenceTimestamp, MonotonicReceiveTimestamp, SecondaryReceiveTimestamp, Unknown(u16), } impl Type { pub const fn from_bits(bits: u16) -> Self { match bits { 0xF5FF => Self::DraftIdentification, 0xF501 => Self::Padding, 0xF502 => Self::Mac, 0xF503 => Self::ReferenceIdRequest, 0xF504 => Self::ReferenceIdResponse, 0xF505 => Self::ServerInformation, 0xF506 => Self::Correction, 0xF507 => Self::ReferenceTimestamp, 0xF508 => Self::MonotonicReceiveTimestamp, 0xF509 => Self::SecondaryReceiveTimestamp, other => Self::Unknown(other), } } pub const fn to_bits(self) -> u16 { match self { Self::DraftIdentification => 0xF5FF, Self::Padding => 0xF501, Self::Mac => 0xF502, Self::ReferenceIdRequest => 0xF503, Self::ReferenceIdResponse => 0xF504, Self::ServerInformation => 0xF505, Self::Correction => 0xF506, Self::ReferenceTimestamp => 0xF507, Self::MonotonicReceiveTimestamp => 0xF508, Self::SecondaryReceiveTimestamp => 0xF509, Self::Unknown(other) => other, } } #[cfg(test)] fn all_known() -> impl Iterator { [ Self::DraftIdentification, Self::Padding, Self::Mac, Self::ReferenceIdRequest, Self::ReferenceIdResponse, Self::ServerInformation, Self::Correction, Self::ReferenceTimestamp, Self::MonotonicReceiveTimestamp, Self::SecondaryReceiveTimestamp, ] .into_iter() } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ReferenceIdRequest { payload_len: u16, offset: u16, } impl ReferenceIdRequest { pub const fn new(payload_len: u16, offset: u16) -> Option { if payload_len % 4 != 0 { return None; } if payload_len + offset > 512 { return None; } Some(Self { payload_len, offset, }) } pub fn to_response(self, filter: &BloomFilter) -> Option { let offset = usize::from(self.offset); 
let payload_len = usize::from(self.payload_len); let bytes = filter.as_bytes().get(offset..)?.get(..payload_len)?.into(); Some(ReferenceIdResponse { bytes }) } pub fn serialize(&self, mut writer: impl NonBlockingWrite) -> std::io::Result<()> { let payload_len = self.payload_len; let ef_len: u16 = payload_len + 4; writer.write_all(&Type::ReferenceIdRequest.to_bits().to_be_bytes())?; writer.write_all(&ef_len.to_be_bytes())?; writer.write_all(&self.offset.to_be_bytes())?; writer.write_all(&[0; 2])?; let words = payload_len / 4; assert_eq!(payload_len % 4, 0); for _ in 1..words { writer.write_all(&[0; 4])?; } Ok(()) } pub fn decode(msg: &[u8]) -> Result> { let payload_len = u16::try_from(msg.len()).expect("NTP fields can not be longer than u16::MAX"); let offset_bytes: [u8; 2] = msg .get(0..2) .ok_or(ParsingError::IncorrectLength)? .try_into() .unwrap(); Ok(Self { payload_len, offset: u16::from_be_bytes(offset_bytes), }) } pub const fn offset(&self) -> u16 { self.offset } pub const fn payload_len(&self) -> u16 { self.payload_len } } #[derive(Debug, Clone, Eq, PartialEq)] pub struct ReferenceIdResponse<'a> { bytes: Cow<'a, [u8]>, } impl<'a> ReferenceIdResponse<'a> { pub const fn new(bytes: &'a [u8]) -> Option { if bytes.len() % 4 != 0 { return None; } if bytes.len() > 512 { return None; } Some(Self { bytes: Cow::Borrowed(bytes), }) } pub fn into_owned(self) -> ReferenceIdResponse<'static> { ReferenceIdResponse { bytes: Cow::Owned(self.bytes.into_owned()), } } pub fn serialize(&self, mut writer: impl NonBlockingWrite) -> std::io::Result<()> { let len: u16 = self.bytes.len().try_into().unwrap(); let len = len + 4; // Add room for type and length assert_eq!(len % 4, 0); writer.write_all(&Type::ReferenceIdResponse.to_bits().to_be_bytes())?; writer.write_all(&len.to_be_bytes())?; writer.write_all(self.bytes.as_ref())?; Ok(()) } pub const fn decode(bytes: &'a [u8]) -> Self { Self { bytes: Cow::Borrowed(bytes), } } pub fn bytes(&self) -> &[u8] { &self.bytes } } impl From for 
ExtensionField<'static> { fn from(value: ReferenceIdRequest) -> Self { Self::ReferenceIdRequest(value) } } impl<'a> From> for ExtensionField<'a> { fn from(value: ReferenceIdResponse<'a>) -> Self { Self::ReferenceIdResponse(value) } } #[cfg(test)] mod tests { use super::*; #[test] fn type_round_trip() { for i in 0..=u16::MAX { let ty = Type::from_bits(i); assert_eq!(i, ty.to_bits()); } for ty in Type::all_known() { let bits = ty.to_bits(); let ty2 = Type::from_bits(bits); assert_eq!(ty, ty2); let bits2 = ty2.to_bits(); assert_eq!(bits, bits2); } } #[test] fn test_reference_id_request_too_short() { assert!(matches!( ReferenceIdRequest::decode(&[]), Err(ParsingError::IncorrectLength) )); } #[test] fn test_reference_id_request_decode() { let res = ReferenceIdRequest::decode(&[0, 2, 0, 0, 0]).unwrap(); assert_eq!(res.payload_len, 5); assert_eq!(res.offset, 2); } } ntp-proto-1.4.0/src/packet/v5/mod.rs000064400000000000000000000431501046102023000153100ustar 00000000000000#![warn(clippy::missing_const_for_fn)] use crate::{ io::NonBlockingWrite, NtpClock, NtpDuration, NtpLeapIndicator, NtpTimestamp, PollInterval, SystemSnapshot, }; use rand::random; mod error; #[allow(dead_code)] pub mod extension_fields; pub mod server_reference_id; use crate::packet::error::ParsingError; pub use error::V5Error; use super::RequestIdentifier; #[allow(dead_code)] pub(crate) const DRAFT_VERSION: &str = "draft-ietf-ntp-ntpv5-02"; pub(crate) const UPGRADE_TIMESTAMP: NtpTimestamp = NtpTimestamp::from_bits(*b"NTP5DRFT"); #[repr(u8)] #[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum NtpMode { Request = 3, Response = 4, } impl NtpMode { const fn from_bits(bits: u8) -> Result> { Ok(match bits { 3 => Self::Request, 4 => Self::Response, _ => return Err(V5Error::MalformedMode.into_parse_err()), }) } const fn to_bits(self) -> u8 { self as u8 } #[allow(dead_code)] pub(crate) const fn is_request(self) -> bool { matches!(self, Self::Request) } #[allow(dead_code)] pub(crate) const fn is_response(self) 
-> bool { matches!(self, Self::Response) } } #[repr(u8)] #[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum NtpTimescale { Utc = 0, Tai = 1, Ut1 = 2, LeapSmearedUtc = 3, } impl NtpTimescale { const fn from_bits(bits: u8) -> Result> { Ok(match bits { 0 => Self::Utc, 1 => Self::Tai, 2 => Self::Ut1, 3 => Self::LeapSmearedUtc, _ => return Err(V5Error::MalformedTimescale.into_parse_err()), }) } const fn to_bits(self) -> u8 { self as u8 } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct NtpEra(pub u8); #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct NtpFlags { pub unknown_leap: bool, pub interleaved_mode: bool, pub authnak: bool, } impl NtpFlags { const fn from_bits(bits: [u8; 2]) -> Result> { if bits[0] != 0x00 || bits[1] & 0b1111_1000 != 0 { return Err(V5Error::InvalidFlags.into_parse_err()); } Ok(Self { unknown_leap: bits[1] & 0b01 != 0, interleaved_mode: bits[1] & 0b10 != 0, authnak: bits[1] & 0b100 != 0, }) } const fn as_bits(self) -> [u8; 2] { let mut flags: u8 = 0; if self.unknown_leap { flags |= 0b01; } if self.interleaved_mode { flags |= 0b10; } if self.authnak { flags |= 0b100; } [0x00, flags] } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct NtpServerCookie(pub [u8; 8]); impl NtpServerCookie { fn new_random() -> Self { // TODO does this match entropy handling of the rest of the system? Self(random()) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct NtpClientCookie(pub [u8; 8]); impl NtpClientCookie { fn new_random() -> Self { // TODO does this match entropy handling of the rest of the system? 
Self(random()) } pub const fn from_ntp_timestamp(ts: NtpTimestamp) -> Self { Self(ts.to_bits()) } pub const fn into_ntp_timestamp(self) -> NtpTimestamp { NtpTimestamp::from_bits(self.0) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct NtpHeaderV5 { pub leap: NtpLeapIndicator, pub mode: NtpMode, pub stratum: u8, pub poll: PollInterval, pub precision: i8, pub timescale: NtpTimescale, pub era: NtpEra, pub flags: NtpFlags, pub root_delay: NtpDuration, pub root_dispersion: NtpDuration, pub server_cookie: NtpServerCookie, pub client_cookie: NtpClientCookie, /// Time at the server when the request arrived from the client pub receive_timestamp: NtpTimestamp, /// Time at the server when the response left for the client pub transmit_timestamp: NtpTimestamp, } impl NtpHeaderV5 { fn new() -> Self { Self { leap: NtpLeapIndicator::NoWarning, mode: NtpMode::Request, stratum: 0, poll: PollInterval::from_byte(0), precision: 0, root_delay: NtpDuration::default(), root_dispersion: NtpDuration::default(), receive_timestamp: NtpTimestamp::default(), transmit_timestamp: NtpTimestamp::default(), timescale: NtpTimescale::Utc, era: NtpEra(0), flags: NtpFlags { unknown_leap: false, interleaved_mode: false, authnak: false, }, server_cookie: NtpServerCookie([0; 8]), client_cookie: NtpClientCookie([0; 8]), } } pub(crate) fn timestamp_response( system: &SystemSnapshot, input: Self, recv_timestamp: NtpTimestamp, clock: &C, ) -> Self { Self { leap: system.time_snapshot.leap_indicator, mode: NtpMode::Response, stratum: system.stratum, // TODO this changed in NTPv5 poll: input.poll, precision: system.time_snapshot.precision.log2(), // TODO this is new in NTPv5 timescale: NtpTimescale::Utc, // TODO this is new in NTPv5 era: NtpEra(0), // TODO this is new in NTPv5 flags: NtpFlags { unknown_leap: false, interleaved_mode: false, authnak: false, }, root_delay: system.time_snapshot.root_delay, root_dispersion: system.time_snapshot.root_dispersion, server_cookie: NtpServerCookie::new_random(), 
client_cookie: input.client_cookie, receive_timestamp: recv_timestamp, transmit_timestamp: clock.now().expect("Failed to read time"), } } fn kiss_response(packet_from_client: Self) -> Self { Self { mode: NtpMode::Response, flags: NtpFlags { unknown_leap: false, interleaved_mode: false, authnak: false, }, server_cookie: NtpServerCookie::new_random(), client_cookie: packet_from_client.client_cookie, stratum: 0, ..Self::new() } } pub(crate) fn rate_limit_response(packet_from_client: Self) -> Self { Self { poll: packet_from_client.poll.force_inc(), ..Self::kiss_response(packet_from_client) } } pub(crate) fn deny_response(packet_from_client: Self) -> Self { Self { poll: PollInterval::NEVER, ..Self::kiss_response(packet_from_client) } } pub(crate) fn nts_nak_response(packet_from_client: Self) -> Self { Self { flags: NtpFlags { unknown_leap: false, interleaved_mode: false, authnak: true, }, ..Self::kiss_response(packet_from_client) } } const WIRE_LENGTH: usize = 48; const VERSION: u8 = 5; pub(crate) fn deserialize( data: &[u8], ) -> Result<(Self, usize), ParsingError> { if data.len() < Self::WIRE_LENGTH { return Err(ParsingError::IncorrectLength); } let version = (data[0] >> 3) & 0b111; if version != 5 { return Err(ParsingError::InvalidVersion(version)); } Ok(( Self { leap: NtpLeapIndicator::from_bits((data[0] & 0xC0) >> 6), mode: NtpMode::from_bits(data[0] & 0x07)?, stratum: data[1], poll: PollInterval::from_byte(data[2]), precision: data[3] as i8, timescale: NtpTimescale::from_bits(data[4])?, era: NtpEra(data[5]), flags: NtpFlags::from_bits(data[6..8].try_into().unwrap())?, root_delay: NtpDuration::from_bits_time32(data[8..12].try_into().unwrap()), root_dispersion: NtpDuration::from_bits_time32(data[12..16].try_into().unwrap()), server_cookie: NtpServerCookie(data[16..24].try_into().unwrap()), client_cookie: NtpClientCookie(data[24..32].try_into().unwrap()), receive_timestamp: NtpTimestamp::from_bits(data[32..40].try_into().unwrap()), transmit_timestamp: 
NtpTimestamp::from_bits(data[40..48].try_into().unwrap()), }, Self::WIRE_LENGTH, )) } #[allow(dead_code)] pub(crate) fn serialize(&self, mut w: impl NonBlockingWrite) -> std::io::Result<()> { w.write_all(&[(self.leap.to_bits() << 6) | (Self::VERSION << 3) | self.mode.to_bits()])?; w.write_all(&[self.stratum, self.poll.as_byte(), self.precision as u8])?; w.write_all(&[self.timescale.to_bits()])?; w.write_all(&[self.era.0])?; w.write_all(&self.flags.as_bits())?; w.write_all(&self.root_delay.to_bits_time32())?; w.write_all(&self.root_dispersion.to_bits_time32())?; w.write_all(&self.server_cookie.0)?; w.write_all(&self.client_cookie.0)?; w.write_all(&self.receive_timestamp.to_bits())?; w.write_all(&self.transmit_timestamp.to_bits())?; Ok(()) } pub fn poll_message(poll_interval: PollInterval) -> (Self, RequestIdentifier) { let mut packet = Self::new(); packet.poll = poll_interval; packet.mode = NtpMode::Request; let client_cookie = NtpClientCookie::new_random(); packet.client_cookie = client_cookie; ( packet, RequestIdentifier { expected_origin_timestamp: client_cookie.into_ntp_timestamp(), uid: None, }, ) } } #[cfg(test)] mod tests { use super::*; use std::io::Cursor; #[test] fn round_trip_timescale() { for i in 0..=u8::MAX { if let Ok(ts) = NtpTimescale::from_bits(i) { assert_eq!(ts as u8, i); } } } #[test] fn flags() { let flags = NtpFlags::from_bits([0x00, 0x00]).unwrap(); assert!(!flags.unknown_leap); assert!(!flags.interleaved_mode); let flags = NtpFlags::from_bits([0x00, 0x01]).unwrap(); assert!(flags.unknown_leap); assert!(!flags.interleaved_mode); let flags = NtpFlags::from_bits([0x00, 0x02]).unwrap(); assert!(!flags.unknown_leap); assert!(flags.interleaved_mode); let flags = NtpFlags::from_bits([0x00, 0x03]).unwrap(); assert!(flags.unknown_leap); assert!(flags.interleaved_mode); let result = NtpFlags::from_bits([0xFF, 0xFF]); assert!(matches!( result, Err(ParsingError::V5(V5Error::InvalidFlags)) )); } #[test] fn parse_request() { 
#[allow(clippy::unusual_byte_groupings)] // Bits are grouped by fields #[rustfmt::skip] let data = [ // LI VN Mode 0b_00_101_011, // Stratum 0x00, // Poll 0x05, // Precision 0x00, // Timescale (0: UTC, 1: TAI, 2: UT1, 3: Leap-smeared UTC) 0x02, // Era 0x00, // Flags 0x00, 0b0000_00_1_0, // Root Delay 0x00, 0x00, 0x00, 0x00, // Root Dispersion 0x00, 0x00, 0x00, 0x00, // Server Cookie 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Client Cookie 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, // Receive Timestamp 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Transmit Timestamp 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; let (parsed, len) = NtpHeaderV5::deserialize(&data).unwrap(); assert_eq!(len, 48); assert_eq!(parsed.leap, NtpLeapIndicator::NoWarning); assert!(parsed.mode.is_request()); assert_eq!(parsed.stratum, 0); assert_eq!(parsed.poll, PollInterval::from_byte(5)); assert_eq!(parsed.precision, 0); assert_eq!(parsed.timescale, NtpTimescale::Ut1); assert_eq!(parsed.era, NtpEra(0)); assert!(parsed.flags.interleaved_mode); assert!(!parsed.flags.unknown_leap); assert!(parsed.flags.interleaved_mode); assert_eq!(parsed.root_delay, NtpDuration::from_seconds(0.0)); assert_eq!(parsed.root_dispersion, NtpDuration::from_seconds(0.0)); assert_eq!(parsed.server_cookie, NtpServerCookie([0x0; 8])); assert_eq!( parsed.client_cookie, NtpClientCookie([0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]) ); assert_eq!(parsed.receive_timestamp, NtpTimestamp::from_fixed_int(0x0)); assert_eq!(parsed.transmit_timestamp, NtpTimestamp::from_fixed_int(0x0)); let mut buffer: [u8; 48] = [0u8; 48]; let cursor = Cursor::new(buffer.as_mut_slice()); parsed.serialize(cursor).unwrap(); assert_eq!(data, buffer); } #[test] fn parse_response() { #[allow(clippy::unusual_byte_groupings)] // Bits are grouped by fields #[rustfmt::skip] let data = [ // LI VN Mode 0b_00_101_100, // Stratum 0x04, // Poll 0x05, // Precision 0x06, // Timescale (0: UTC, 1: TAI, 2: UT1, 3: Leap-smeared UTC) 0x01, 
// Era 0x07, // Flags 0x00, 0b0000_00_1_0, // Root Delay 0x10, 0x00, 0x00, 0x00, // Root Dispersion 0x20, 0x00, 0x00, 0x00, // Server Cookie 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // Client Cookie 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, // Receive Timestamp 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, // Transmit Timestamp 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, ]; let (parsed, len) = NtpHeaderV5::deserialize(&data).unwrap(); assert_eq!(len, 48); assert_eq!(parsed.leap, NtpLeapIndicator::NoWarning); assert!(parsed.mode.is_response()); assert_eq!(parsed.stratum, 4); assert_eq!(parsed.poll, PollInterval::from_byte(5)); assert_eq!(parsed.precision, 6); assert_eq!(parsed.timescale, NtpTimescale::Tai); assert_eq!(parsed.era, NtpEra(7)); assert!(parsed.flags.interleaved_mode); assert!(!parsed.flags.unknown_leap); assert!(parsed.flags.interleaved_mode); assert_eq!(parsed.root_delay, NtpDuration::from_seconds(1.0)); assert_eq!(parsed.root_dispersion, NtpDuration::from_seconds(2.0)); assert_eq!( parsed.server_cookie, NtpServerCookie([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]) ); assert_eq!( parsed.client_cookie, NtpClientCookie([0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]) ); assert_eq!( parsed.receive_timestamp, NtpTimestamp::from_fixed_int(0x1111111111111111) ); assert_eq!( parsed.transmit_timestamp, NtpTimestamp::from_fixed_int(0x2222222222222222) ); let mut buffer: [u8; 48] = [0u8; 48]; let cursor = Cursor::new(buffer.as_mut_slice()); parsed.serialize(cursor).unwrap(); assert_eq!(data, buffer); } #[test] fn test_encode_decode_roundtrip() { for i in 0..=u8::MAX { let header = NtpHeaderV5 { leap: NtpLeapIndicator::from_bits(i % 4), mode: NtpMode::from_bits(3 + (i % 2)).unwrap(), stratum: i.wrapping_add(1), poll: PollInterval::from_byte(i.wrapping_add(3)), precision: i.wrapping_add(4) as i8, timescale: NtpTimescale::from_bits(i % 4).unwrap(), era: NtpEra(i.wrapping_add(6)), flags: NtpFlags { unknown_leap: i % 3 == 0, interleaved_mode: 
i % 4 == 0, authnak: i % 5 == 0, }, root_delay: NtpDuration::from_bits_time32([i; 4]), root_dispersion: NtpDuration::from_bits_time32([i.wrapping_add(1); 4]), server_cookie: NtpServerCookie([i.wrapping_add(2); 8]), client_cookie: NtpClientCookie([i.wrapping_add(3); 8]), receive_timestamp: NtpTimestamp::from_bits([i.wrapping_add(4); 8]), transmit_timestamp: NtpTimestamp::from_bits([i.wrapping_add(5); 8]), }; let mut buffer: [u8; 48] = [0u8; 48]; let mut cursor = Cursor::new(buffer.as_mut_slice()); header.serialize(&mut cursor).unwrap(); let (parsed, _) = NtpHeaderV5::deserialize(&buffer).unwrap(); assert_eq!(header, parsed); } } #[test] fn fail_on_incorrect_length() { let data: [u8; 47] = [0u8; 47]; assert!(matches!( NtpHeaderV5::deserialize(&data), Err(ParsingError::IncorrectLength) )); } #[test] #[allow(clippy::unusual_byte_groupings)] // Bits are grouped by fields fn fail_on_incorrect_version() { let mut data: [u8; 48] = [0u8; 48]; data[0] = 0b_00_111_100; assert!(matches!( NtpHeaderV5::deserialize(&data), Err(ParsingError::InvalidVersion(7)) )); } } ntp-proto-1.4.0/src/packet/v5/server_reference_id.rs000064400000000000000000000256751046102023000205450ustar 00000000000000use crate::packet::v5::extension_fields::{ReferenceIdRequest, ReferenceIdResponse}; use crate::packet::v5::NtpClientCookie; use rand::distributions::{Distribution, Standard}; use rand::{thread_rng, Rng}; use std::array::from_fn; use std::fmt::{Debug, Formatter}; #[derive(Copy, Clone, Debug)] struct U12(u16); impl U12 { pub const MAX: Self = Self(4095); /// For an array of bytes calculate the index at which a bit would live as well as a mask where the /// corresponding bit in that byte would be set const fn byte_and_mask(self) -> (usize, u8) { (self.0 as usize / 8, 1 << (self.0 % 8)) } } impl Distribution for Standard { fn sample(&self, rng: &mut R) -> U12 { U12(rng.gen_range(0..4096)) } } impl From for u16 { fn from(value: U12) -> Self { value.0 } } impl TryFrom for U12 { type Error = (); fn 
try_from(value: u16) -> Result { if value > Self::MAX.into() { Err(()) } else { Ok(Self(value)) } } } #[derive(Debug, Copy, Clone)] pub struct ServerId([U12; 10]); impl ServerId { /// Generate a new random `ServerId` pub fn new(rng: &mut impl Rng) -> Self { // FIXME: sort IDs so we access the filters predictably // FIXME: check for double rolls to reduce false positive rate Self(from_fn(|_| rng.gen())) } } impl Default for ServerId { fn default() -> Self { Self::new(&mut thread_rng()) } } #[derive(Copy, Clone, Eq, PartialEq)] pub struct BloomFilter([u8; Self::BYTES]); impl BloomFilter { pub const BYTES: usize = 512; pub const fn new() -> Self { Self([0; Self::BYTES]) } pub fn contains_id(&self, other: &ServerId) -> bool { other.0.iter().all(|idx| self.is_set(*idx)) } pub fn add_id(&mut self, id: &ServerId) { for idx in id.0 { self.set_bit(idx); } } pub fn add(&mut self, other: &BloomFilter) { for (ours, theirs) in self.0.iter_mut().zip(other.0.iter()) { *ours |= theirs; } } pub fn union<'a>(others: impl Iterator) -> Self { let mut union = Self::new(); for other in others { union.add(other); } union } pub fn count_ones(&self) -> u16 { self.0.iter().map(|b| b.count_ones() as u16).sum() } pub const fn as_bytes(&self) -> &[u8; Self::BYTES] { &self.0 } fn set_bit(&mut self, idx: U12) { let (idx, mask) = idx.byte_and_mask(); self.0[idx] |= mask; } const fn is_set(&self, idx: U12) -> bool { let (idx, mask) = idx.byte_and_mask(); self.0[idx] & mask != 0 } } impl<'a> FromIterator<&'a BloomFilter> for BloomFilter { fn from_iter>(iter: T) -> Self { Self::union(iter.into_iter()) } } impl Default for BloomFilter { fn default() -> Self { Self::new() } } impl Debug for BloomFilter { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let str: String = self .0 .chunks_exact(32) .map(|chunk| chunk.iter().fold(0, |acc, b| acc | b)) .map(|b| char::from_u32(0x2800 + b as u32).unwrap()) .collect(); f.debug_tuple("BloomFilter").field(&str).finish() } } pub struct 
RemoteBloomFilter { filter: BloomFilter, chunk_size: u16, last_requested: Option<(u16, NtpClientCookie)>, next_to_request: u16, is_filled: bool, } impl RemoteBloomFilter { /// Create a new `BloomFilter` that can poll chunks from the server /// /// `chunk_size` has to be: /// * divisible by 4 /// * divide 512 without remainder /// * between `4..=512` pub const fn new(chunk_size: u16) -> Option { if chunk_size % 4 != 0 { return None; } if chunk_size == 0 || chunk_size > 512 { return None; } if 512 % chunk_size != 0 { return None; } Some(Self { filter: BloomFilter::new(), chunk_size, last_requested: None, next_to_request: 0, is_filled: false, }) } /// Returns the fully fetched filter or None if not all chunks were received yet pub fn full_filter(&self) -> Option<&BloomFilter> { self.is_filled.then_some(&self.filter) } pub fn next_request(&mut self, cookie: NtpClientCookie) -> ReferenceIdRequest { let offset = self.next_to_request; let last_request = self.last_requested.replace((offset, cookie)); if let Some(_last_request) = last_request { // TODO log something about never got a response } ReferenceIdRequest::new(self.chunk_size, offset) .expect("We ensure that our request always falls within the BloomFilter") } pub fn handle_response( &mut self, cookie: NtpClientCookie, response: &ReferenceIdResponse, ) -> Result<(), ResponseHandlingError> { let Some((offset, expected_cookie)) = self.last_requested else { return Err(ResponseHandlingError::NotAwaitingResponse); }; if cookie != expected_cookie { return Err(ResponseHandlingError::MismatchedCookie); } if response.bytes().len() != self.chunk_size as usize { return Err(ResponseHandlingError::MismatchedLength); } self.filter.0[(offset as usize)..][..(self.chunk_size as usize)] .copy_from_slice(response.bytes()); self.advance_next_to_request(); self.last_requested = None; Ok(()) } fn advance_next_to_request(&mut self) { self.next_to_request = (self.next_to_request + self.chunk_size) % BloomFilter::BYTES as u16; if 
self.next_to_request == 0 { // We made the round at least once... so we must be fully filled self.is_filled = true; } } } impl Debug for RemoteBloomFilter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RemoteBloomFilter") .field("chunk_size", &self.chunk_size) .field("last_requested", &self.last_requested) .field("next_to_request", &self.next_to_request) .field("is_filled", &self.is_filled) .finish() } } #[derive(Debug, Copy, Clone)] pub enum ResponseHandlingError { NotAwaitingResponse, MismatchedCookie, MismatchedLength, } #[cfg(test)] mod tests { use super::*; use rand::thread_rng; #[test] fn set_bits() { let mut rid = BloomFilter::new(); assert!(rid.0.iter().all(|x| x == &0)); assert!((0..4096).all(|idx| !rid.is_set(U12(idx)))); assert_eq!(rid.count_ones(), 0); rid.set_bit(U12(0)); assert_eq!(rid.count_ones(), 1); assert!(rid.is_set(U12(0))); assert_eq!(rid.0[0], 1); rid.set_bit(U12(4)); assert_eq!(rid.count_ones(), 2); assert!(rid.is_set(U12(4))); assert_eq!(rid.0[0], 0b0001_0001); rid.set_bit(U12::MAX); assert_eq!(rid.count_ones(), 3); assert!(rid.is_set(U12::MAX)); assert_eq!(rid.0[511], 0b1000_0000); } #[test] fn set_contains() { let mut rng = thread_rng(); let mut filter = BloomFilter::new(); let id = ServerId::new(&mut rng); assert!(!filter.contains_id(&id)); filter.add_id(&id); assert!(filter.contains_id(&id)); for _ in 0..128 { let rid = ServerId::new(&mut rng); filter.add_id(&rid); assert!(filter.contains_id(&rid)); } } #[test] fn set_collect() { let mut rng = thread_rng(); let mut ids = vec![]; let mut filters = vec![]; for _ in 0..10 { let id = ServerId::new(&mut rng); let mut filter = BloomFilter::new(); filter.add_id(&id); ids.push(id); filters.push(filter); } let set: BloomFilter = filters.iter().collect(); for rid in &ids { assert!(set.contains_id(rid)); } } #[test] fn requesting() { use ResponseHandlingError::{MismatchedCookie, MismatchedLength, NotAwaitingResponse}; let chunk_size = 16; let mut bf = 
RemoteBloomFilter::new(chunk_size).unwrap(); assert!(matches!( bf.handle_response( NtpClientCookie::new_random(), &ReferenceIdResponse::new(&[0u8; 16]).unwrap() ), Err(NotAwaitingResponse) )); let cookie = NtpClientCookie::new_random(); let req = bf.next_request(cookie); assert_eq!(req.offset(), 0); assert_eq!(req.payload_len(), chunk_size); assert!(matches!( bf.handle_response(cookie, &ReferenceIdResponse::new(&[0; 24]).unwrap()), Err(MismatchedLength) )); let mut wrong_cookie = cookie; wrong_cookie.0[0] ^= 0xFF; // Flip all bits in first byte assert!(matches!( bf.handle_response(wrong_cookie, &ReferenceIdResponse::new(&[0; 16]).unwrap()), Err(MismatchedCookie) )); bf.handle_response(cookie, &ReferenceIdResponse::new(&[1; 16]).unwrap()) .unwrap(); assert_eq!(bf.next_to_request, 16); assert_eq!(bf.last_requested, None); assert!(!bf.is_filled); assert!(bf.full_filter().is_none()); assert_eq!(&bf.filter.0[..16], &[1; 16]); assert_eq!(&bf.filter.0[16..], &[0; 512 - 16]); for chunk in 1..(512 / chunk_size) { let cookie = NtpClientCookie::new_random(); let req = bf.next_request(cookie); assert_eq!(req.offset(), chunk * chunk_size); assert!(bf.full_filter().is_none()); let bytes: Vec<_> = (0..req.payload_len()).map(|_| chunk as u8 + 1).collect(); let response = ReferenceIdResponse::new(&bytes).unwrap(); bf.handle_response(cookie, &response).unwrap(); } assert_eq!(bf.next_to_request, 0); assert!(bf.full_filter().is_some()); } #[test] fn works_with_any_chunk_size() { let mut target_filter = BloomFilter::new(); for _ in 0..16 { target_filter.add_id(&ServerId::new(&mut thread_rng())); } for chunk_size in 0..=512 { let Some(mut bf) = RemoteBloomFilter::new(chunk_size) else { continue; }; for _chunk in 0..((512 / chunk_size) + 1) { let cookie = NtpClientCookie::new_random(); let request = bf.next_request(cookie); let response = request.to_response(&target_filter).unwrap(); bf.handle_response(cookie, &response).unwrap(); } let result_filter = bf.full_filter().unwrap(); 
assert_eq!(&target_filter, result_filter); } } } ntp-proto-1.4.0/src/server.rs000064400000000000000000001366351046102023000142510ustar 00000000000000use std::{ collections::hash_map::RandomState, fmt::Display, io::Cursor, net::{AddrParseError, IpAddr}, sync::Arc, time::{Duration, Instant}, }; use serde::{de, Deserialize, Deserializer}; use crate::{ ipfilter::IpFilter, KeySet, NoCipher, NtpClock, NtpPacket, NtpTimestamp, PacketParsingError, SystemSnapshot, }; pub enum ServerAction<'a> { Ignore, Respond { message: &'a [u8] }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ServerReason { /// Rate limit mechanism kicked in RateLimit, /// Packet could not be parsed because it was malformed in some way ParseError, /// Packet could be parsed but the cryptography was invalid InvalidCrypto, /// Internal error in the server InternalError, /// Configuration was used to decide response Policy, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ServerResponse { /// NTS was invalid (failure to decrypt etc) NTSNak, /// Sent a deny response to client Deny, /// Only for a conscious choice to not respond, error conditions are separate Ignore, /// Accepted packet and provided time to requestor ProvideTime, } pub trait ServerStatHandler { /// Called by the server handle once per packet fn register(&mut self, version: u8, nts: bool, reason: ServerReason, response: ServerResponse); } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)] #[serde(rename_all = "lowercase")] pub enum FilterAction { Ignore, Deny, } impl From for ServerResponse { fn from(value: FilterAction) -> Self { match value { FilterAction::Ignore => ServerResponse::Ignore, FilterAction::Deny => ServerResponse::Deny, } } } #[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)] pub struct FilterList { pub filter: Vec, pub action: FilterAction, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ServerConfig { pub denylist: 
FilterList, pub allowlist: FilterList, pub rate_limiting_cache_size: usize, pub rate_limiting_cutoff: Duration, pub require_nts: Option, } pub struct Server { config: ServerConfig, clock: C, denyfilter: IpFilter, allowfilter: IpFilter, client_cache: TimestampedCache, system: SystemSnapshot, keyset: Arc, } // Quick estimation of ntp packet message version without doing full parsing fn fallback_message_version(message: &[u8]) -> u8 { message.first().map(|v| (v & 0b0011_1000) >> 3).unwrap_or(0) } impl Server { /// Create a new server pub fn new( config: ServerConfig, clock: C, system: SystemSnapshot, keyset: Arc, ) -> Self { let denyfilter = IpFilter::new(&config.denylist.filter); let allowfilter = IpFilter::new(&config.allowlist.filter); let client_cache = TimestampedCache::new(config.rate_limiting_cache_size); Self { config, clock, denyfilter, allowfilter, client_cache, system, keyset, } } /// Update the [`ServerConfig`] of the server pub fn update_config(&mut self, config: ServerConfig) { if self.config.denylist.filter != config.denylist.filter { self.denyfilter = IpFilter::new(&config.denylist.filter); } if self.config.allowlist.filter != config.allowlist.filter { self.allowfilter = IpFilter::new(&config.allowlist.filter); } if self.config.rate_limiting_cache_size != config.rate_limiting_cache_size { self.client_cache = TimestampedCache::new(config.rate_limiting_cache_size); } self.config = config; } /// Provide the server with the latest [`SystemSnapshot`] pub fn update_system(&mut self, system: SystemSnapshot) { self.system = system; } /// Provide the server with a new [`KeySet`] pub fn update_keyset(&mut self, keyset: Arc) { self.keyset = keyset; } fn intended_action(&mut self, client_ip: IpAddr) -> (ServerResponse, ServerReason) { if self.denyfilter.is_in(&client_ip) { // First apply denylist (self.config.denylist.action.into(), ServerReason::Policy) } else if !self.allowfilter.is_in(&client_ip) { // Then allowlist (self.config.allowlist.action.into(), 
ServerReason::Policy) } else if !self.client_cache.is_allowed( client_ip, Instant::now(), self.config.rate_limiting_cutoff, ) { // Then ratelimit (ServerResponse::Ignore, ServerReason::RateLimit) } else { // Then accept (ServerResponse::ProvideTime, ServerReason::Policy) } } } impl Server { /// Handle a packet sent to the server /// /// If the buffer isn't large enough to encode the reply, this /// will log an error and ignore the incoming packet. A buffer /// as large as the message will always suffice. pub fn handle<'a>( &mut self, client_ip: IpAddr, recv_timestamp: NtpTimestamp, message: &[u8], buffer: &'a mut [u8], stats_handler: &mut impl ServerStatHandler, ) -> ServerAction<'a> { let (mut action, mut reason) = self.intended_action(client_ip); if action == ServerResponse::Ignore { // Early exit for ignore stats_handler.register(fallback_message_version(message), false, reason, action); return ServerAction::Ignore; } // Try and parse the message let (packet, cookie) = match NtpPacket::deserialize(message, self.keyset.as_ref()) { Ok(packet) => packet, Err(PacketParsingError::DecryptError(packet)) => { // Don't care about decryption errors when denying anyway if action != ServerResponse::Deny { action = ServerResponse::NTSNak; reason = ServerReason::InvalidCrypto; } (packet, None) } Err(_) => { stats_handler.register( fallback_message_version(message), false, ServerReason::ParseError, ServerResponse::Ignore, ); return ServerAction::Ignore; } }; // Generate the appropriate response let version = packet.version(); let nts = cookie.is_some() || action == ServerResponse::NTSNak; // ignore non-NTS packets when configured to require NTS if let (false, Some(non_nts_action)) = (nts, self.config.require_nts) { if non_nts_action == FilterAction::Ignore { stats_handler.register(version, nts, ServerReason::Policy, ServerResponse::Ignore); return ServerAction::Ignore; } else { action = ServerResponse::Deny; reason = ServerReason::Policy; } } let mut cursor = 
Cursor::new(buffer); let result = match action { ServerResponse::NTSNak => { NtpPacket::nts_nak_response(packet).serialize(&mut cursor, &NoCipher, None) } ServerResponse::Deny => { if let Some(cookie) = cookie { NtpPacket::nts_deny_response(packet).serialize( &mut cursor, cookie.s2c.as_ref(), None, ) } else { NtpPacket::deny_response(packet).serialize(&mut cursor, &NoCipher, None) } } ServerResponse::ProvideTime => { if let Some(cookie) = cookie { NtpPacket::nts_timestamp_response( &self.system, packet, recv_timestamp, &self.clock, &cookie, &self.keyset, ) .serialize( &mut cursor, cookie.s2c.as_ref(), Some(message.len()), ) } else { NtpPacket::timestamp_response(&self.system, packet, recv_timestamp, &self.clock) .serialize(&mut cursor, &NoCipher, Some(message.len())) } } ServerResponse::Ignore => unreachable!(), }; match result { Ok(_) => { stats_handler.register(version, nts, reason, action); let length = cursor.position(); ServerAction::Respond { message: &cursor.into_inner()[..length as _], } } Err(e) => { tracing::error!("Could not serialize response: {}", e); stats_handler.register( version, nts, ServerReason::InternalError, ServerResponse::Ignore, ); ServerAction::Ignore } } } } /// A size-bounded cache where each entry is timestamped. /// /// The planned use is in rate limiting: we keep track of when a source last checked in. If it checks /// in too often, we issue a rate limiting KISS code. /// /// For this use case we want fast /// /// - lookups: for each incoming IP we must check when it last checked in /// - inserts: for each incoming IP we store that its most recent check-in is now /// /// Hence, this data structure is a vector, and we use a simple hash function to turn the incoming /// address into an index. Lookups and inserts are therefore O(1). /// /// The likelihood of hash collisions can be controlled by changing the size of the cache. Hash collisions /// will happen, so this cache should not be relied on if perfect alerting is deemed critical. 
#[derive(Debug)] struct TimestampedCache { randomstate: RandomState, elements: Vec>, } impl TimestampedCache { fn new(length: usize) -> Self { Self { // looks a bit odd, but prevents a `Clone` constraint elements: std::iter::repeat_with(|| None).take(length).collect(), randomstate: RandomState::new(), } } fn index(&self, item: &T) -> usize { use std::hash::{BuildHasher, Hasher}; let mut hasher = self.randomstate.build_hasher(); item.hash(&mut hasher); hasher.finish() as usize % self.elements.len() } fn is_allowed(&mut self, item: T, timestamp: Instant, cutoff: Duration) -> bool { if self.elements.is_empty() { // cache disabled, always OK return true; } let index = self.index(&item); // check if the current occupant of this slot is actually the same item let timestamp_if_same = self.elements[index] .as_ref() .and_then(|(v, t)| (&item == v).then_some(t)) .copied(); self.elements[index] = Some((item, timestamp)); if let Some(old_timestamp) = timestamp_if_same { // old and new are the same; check the time timestamp.duration_since(old_timestamp) >= cutoff } else { // old and new are different; this is always OK true } } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct IpSubnet { pub addr: IpAddr, pub mask: u8, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum SubnetParseError { Subnet, Ip(AddrParseError), Mask, } impl std::error::Error for SubnetParseError {} impl Display for SubnetParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Subnet => write!(f, "Invalid subnet syntax"), Self::Ip(e) => write!(f, "{e} in subnet"), Self::Mask => write!(f, "Invalid subnet mask"), } } } impl From for SubnetParseError { fn from(value: AddrParseError) -> Self { Self::Ip(value) } } impl std::str::FromStr for IpSubnet { type Err = SubnetParseError; fn from_str(s: &str) -> Result { let (addr, mask) = s.split_once('/').ok_or(SubnetParseError::Subnet)?; let addr: IpAddr = addr.parse()?; let mask: u8 = mask.parse().map_err(|_| 
SubnetParseError::Mask)?; let max_mask = match addr { IpAddr::V4(_) => 32, IpAddr::V6(_) => 128, }; if mask > max_mask { return Err(SubnetParseError::Mask); } Ok(IpSubnet { addr, mask }) } } impl<'de> Deserialize<'de> for IpSubnet { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let s = String::deserialize(deserializer)?; std::str::FromStr::from_str(&s).map_err(de::Error::custom) } } #[cfg(test)] mod tests { use std::net::{Ipv4Addr, Ipv6Addr}; use serde_test::{assert_de_tokens, assert_de_tokens_error, Token}; use crate::{ nts_record::AeadAlgorithm, packet::AesSivCmac256, Cipher, DecodedServerCookie, KeySetProvider, NtpDuration, NtpLeapIndicator, PollIntervalLimits, }; use super::*; #[derive(Debug, Clone, Default)] struct TestClock { cur: NtpTimestamp, } impl NtpClock for TestClock { type Error = std::time::SystemTimeError; fn now(&self) -> std::result::Result { Ok(self.cur) } fn set_frequency(&self, _freq: f64) -> Result { panic!("Shouldn't be called by server"); } fn get_frequency(&self) -> Result { Ok(0.0) } fn step_clock(&self, _offset: NtpDuration) -> Result { panic!("Shouldn't be called by server"); } fn disable_ntp_algorithm(&self) -> Result<(), Self::Error> { panic!("Shouldn't be called by server"); } fn error_estimate_update( &self, _est_error: NtpDuration, _max_error: NtpDuration, ) -> Result<(), Self::Error> { panic!("Shouldn't be called by server"); } fn status_update(&self, _leap_status: NtpLeapIndicator) -> Result<(), Self::Error> { panic!("Shouldn't be called by source"); } } #[derive(Debug, Default)] struct TestStatHandler { last_register: Option<(u8, bool, ServerReason, ServerResponse)>, } impl ServerStatHandler for TestStatHandler { fn register( &mut self, version: u8, nts: bool, reason: ServerReason, response: ServerResponse, ) { assert!(self.last_register.is_none()); self.last_register = Some((version, nts, reason, response)); } } fn serialize_packet_unencrypted(send_packet: &NtpPacket) -> Vec { let mut buf = vec![0; 
1024]; let mut cursor = Cursor::new(buf.as_mut_slice()); send_packet.serialize(&mut cursor, &NoCipher, None).unwrap(); let end = cursor.position() as usize; buf.truncate(end); buf } fn serialize_packet_encrypted(send_packet: &NtpPacket, key: &dyn Cipher) -> Vec { let mut buf = vec![0; 1024]; let mut cursor = Cursor::new(buf.as_mut_slice()); send_packet.serialize(&mut cursor, key, None).unwrap(); let end = cursor.position() as usize; buf.truncate(end); buf } #[test] fn test_server_allow_filter() { let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["127.0.0.0/24".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_secs(1), rate_limiting_cache_size: 0, require_nts: None, }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let mut server = Server::new( config, clock, SystemSnapshot::default(), KeySetProvider::new(1).get(), ); let (packet, id) = NtpPacket::poll_message(PollIntervalLimits::default().min); let serialized = serialize_packet_unencrypted(&packet); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let mut buf = [0; 48]; let response = server.handle( "128.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, 
&mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["127.0.0.0/24".parse().unwrap()], action: FilterAction::Deny, }, rate_limiting_cutoff: Duration::from_secs(1), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "128.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::Deny)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert!(packet.valid_server_response(id, false)); assert!(packet.is_kiss_deny()); } #[test] fn test_server_deny_filter() { let config = ServerConfig { denylist: FilterList { filter: vec!["128.0.0.0/24".parse().unwrap()], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_secs(1), rate_limiting_cache_size: 0, require_nts: None, }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let mut server = Server::new( config, clock, SystemSnapshot::default(), KeySetProvider::new(1).get(), ); let (packet, id) = NtpPacket::poll_message(PollIntervalLimits::default().min); let serialized = serialize_packet_unencrypted(&packet); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, 
ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let mut buf = [0; 48]; let response = server.handle( "128.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::Deny)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert!(packet.valid_server_response(id, false)); assert!(packet.is_kiss_deny()); let config = ServerConfig { denylist: FilterList { filter: vec!["128.0.0.0/24".parse().unwrap()], action: FilterAction::Ignore, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_secs(1), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "128.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); } #[test] fn test_server_rate_limit() { let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 
32, require_nts: None, }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let mut server = Server::new( config, clock, SystemSnapshot::default(), KeySetProvider::new(1).get(), ); let (packet, id) = NtpPacket::poll_message(PollIntervalLimits::default().min); let serialized = serialize_packet_unencrypted(&packet); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::RateLimit, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); std::thread::sleep(std::time::Duration::from_millis(120)); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, 
false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); } #[test] fn test_server_corrupted() { let config = ServerConfig { denylist: FilterList { filter: 
vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let mut server = Server::new( config, clock, SystemSnapshot::default(), KeySetProvider::new(1).get(), ); let (packet, _) = NtpPacket::poll_message(PollIntervalLimits::default().min); let mut serialized = serialize_packet_unencrypted(&packet); let mut buf = [0; 1]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some(( 4, false, ServerReason::InternalError, ServerResponse::Ignore )) ); assert!(matches!(response, ServerAction::Ignore)); serialized[0] = 42; let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::ParseError, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["128.0.0.0/24".parse().unwrap()], action: FilterAction::Deny, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::ParseError, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, 
allowlist: FilterList { filter: vec!["128.0.0.0/24".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::Policy, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); let config = ServerConfig { denylist: FilterList { filter: vec!["127.0.0.0/24".parse().unwrap()], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::ParseError, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); let config = ServerConfig { denylist: FilterList { filter: vec!["127.0.0.0/24".parse().unwrap()], action: FilterAction::Ignore, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; server.update_config(config); let mut buf = [0; 48]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::Policy, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); } #[test] fn test_server_nts() { let config = ServerConfig { denylist: FilterList { filter: 
vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: Some(FilterAction::Ignore), }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let keyset = KeySetProvider::new(1).get(); let mut server = Server::new(config, clock, SystemSnapshot::default(), keyset.clone()); let decodedcookie = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new([0; 32].into())), c2s: Box::new(AesSivCmac256::new([0; 32].into())), }; let cookie = keyset.encode_cookie(&decodedcookie); let (packet, id) = NtpPacket::nts_poll_message(&cookie, 0, PollIntervalLimits::default().min); let serialized = serialize_packet_encrypted(&packet, decodedcookie.c2s.as_ref()); let mut buf = [0; 1024]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, true, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, decodedcookie.s2c.as_ref()) .unwrap() .0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, true)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let cookie_invalid = KeySetProvider::new(1).get().encode_cookie(&decodedcookie); let (packet_invalid, _) = NtpPacket::nts_poll_message(&cookie_invalid, 0, PollIntervalLimits::default().min); let serialized = serialize_packet_encrypted(&packet_invalid, decodedcookie.c2s.as_ref()); let mut buf = [0; 1024]; let response = server.handle( 
"127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, true, ServerReason::InvalidCrypto, ServerResponse::NTSNak)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, decodedcookie.s2c.as_ref()) .unwrap() .0; assert!(packet.is_kiss_ntsn()); } #[test] fn test_server_require_nts() { let mut config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["0.0.0.0/0".parse().unwrap()], action: FilterAction::Ignore, }, rate_limiting_cutoff: Duration::from_secs(1), rate_limiting_cache_size: 0, require_nts: Some(FilterAction::Ignore), }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let mut server = Server::new( config.clone(), clock, SystemSnapshot::default(), KeySetProvider::new(1).get(), ); let (packet, _) = NtpPacket::poll_message(PollIntervalLimits::default().min); let serialized = serialize_packet_unencrypted(&packet); let mut buf = [0; 1024]; let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::Ignore)) ); assert!(matches!(response, ServerAction::Ignore)); let decodedcookie = DecodedServerCookie { algorithm: AeadAlgorithm::AeadAesSivCmac256, s2c: Box::new(AesSivCmac256::new([0; 32].into())), c2s: Box::new(AesSivCmac256::new([0; 32].into())), }; let cookie_invalid = KeySetProvider::new(1).get().encode_cookie(&decodedcookie); let (packet_invalid, _) = NtpPacket::nts_poll_message(&cookie_invalid, 0, PollIntervalLimits::default().min); let serialized = serialize_packet_encrypted(&packet_invalid, decodedcookie.c2s.as_ref()); let response = 
server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, true, ServerReason::InvalidCrypto, ServerResponse::NTSNak)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, decodedcookie.s2c.as_ref()) .unwrap() .0; assert!(packet.is_kiss_ntsn()); config.require_nts = Some(FilterAction::Deny); server.update_config(config.clone()); let (packet, id) = NtpPacket::poll_message(PollIntervalLimits::default().min); let serialized = serialize_packet_unencrypted(&packet); let response = server.handle( "127.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((4, false, ServerReason::Policy, ServerResponse::Deny)) ); let ServerAction::Respond { message } = response else { panic!("Server ignored packet") }; let packet = NtpPacket::deserialize(message, &NoCipher).unwrap().0; assert!(packet.valid_server_response(id, false)); assert!(packet.is_kiss_deny()); } #[cfg(feature = "ntpv5")] #[test] fn test_server_v5() { let config = ServerConfig { denylist: FilterList { filter: vec![], action: FilterAction::Deny, }, allowlist: FilterList { filter: vec!["127.0.0.0/24".parse().unwrap()], action: FilterAction::Deny, }, rate_limiting_cutoff: Duration::from_millis(100), rate_limiting_cache_size: 0, require_nts: None, }; let clock = TestClock { cur: NtpTimestamp::from_fixed_int(200), }; let mut stats = TestStatHandler::default(); let mut server = Server::new( config, clock, SystemSnapshot::default(), KeySetProvider::new(1).get(), ); let (packet, id) = NtpPacket::poll_message_v5(PollIntervalLimits::default().min); let serialized = serialize_packet_unencrypted(&packet); let mut buf = [0; 1024]; let response = server.handle( "127.0.0.1".parse().unwrap(), 
NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::Policy, ServerResponse::ProvideTime)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert_ne!(packet.stratum(), 0); assert!(packet.valid_server_response(id, false)); assert_eq!( packet.receive_timestamp(), NtpTimestamp::from_fixed_int(100) ); assert_eq!( packet.transmit_timestamp(), NtpTimestamp::from_fixed_int(200) ); let mut buf = [0; 1024]; let response = server.handle( "128.0.0.1".parse().unwrap(), NtpTimestamp::from_fixed_int(100), &serialized, &mut buf, &mut stats, ); assert_eq!( stats.last_register.take(), Some((5, false, ServerReason::Policy, ServerResponse::Deny)) ); let data = match response { ServerAction::Ignore => panic!("Server ignored packet"), ServerAction::Respond { message } => message, }; let packet = NtpPacket::deserialize(data, &NoCipher).unwrap().0; assert!(packet.valid_server_response(id, false)); assert!(packet.is_kiss_deny()); } // TimestampedCache tests #[test] fn timestamped_cache() { let length = 8u8; let mut cache: TimestampedCache = TimestampedCache::new(length as usize); let second = Duration::from_secs(1); let instant = Instant::now(); assert!(cache.is_allowed(0, instant, second)); assert!(!cache.is_allowed(0, instant, second)); let later = instant + 2 * second; assert!(cache.is_allowed(0, later, second)); // simulate a hash collision let even_later = later + 2 * second; assert!(cache.is_allowed(length, even_later, second)); } #[test] fn timestamped_cache_size_0() { let mut cache = TimestampedCache::new(0); let second = Duration::from_secs(1); let instant = Instant::now(); assert!(cache.is_allowed(0, instant, second)); } // IpSubnet parsing tests #[test] fn test_ipv4_subnet_parse() { use std::str::FromStr; 
assert!(IpSubnet::from_str("bla/5").is_err()); assert!(IpSubnet::from_str("0.0.0.0").is_err()); assert!(IpSubnet::from_str("0.0.0.0/33").is_err()); assert_eq!( IpSubnet::from_str("0.0.0.0/0"), Ok(IpSubnet { addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), mask: 0 }) ); assert_eq!( IpSubnet::from_str("127.0.0.1/32"), Ok(IpSubnet { addr: IpAddr::V4(Ipv4Addr::LOCALHOST), mask: 32 }) ); assert_de_tokens_error::( &[Token::Str("bla/5")], "invalid IP address syntax in subnet", ); assert_de_tokens_error::(&[Token::Str("0.0.0.0")], "Invalid subnet syntax"); assert_de_tokens_error::(&[Token::Str("0.0.0.0/33")], "Invalid subnet mask"); assert_de_tokens( &IpSubnet { addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), mask: 0, }, &[Token::Str("0.0.0.0/0")], ); assert_de_tokens( &IpSubnet { addr: IpAddr::V4(Ipv4Addr::LOCALHOST), mask: 32, }, &[Token::Str("127.0.0.1/32")], ); } #[test] fn test_ipv6_subnet_parse() { use std::str::FromStr; assert!(IpSubnet::from_str("bla/5").is_err()); assert!(IpSubnet::from_str("::").is_err()); assert!(IpSubnet::from_str("::/129").is_err()); assert_eq!( IpSubnet::from_str("::/0"), Ok(IpSubnet { addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), mask: 0 }) ); assert_eq!( IpSubnet::from_str("::1/128"), Ok(IpSubnet { addr: IpAddr::V6(Ipv6Addr::LOCALHOST), mask: 128 }) ); assert_de_tokens_error::( &[Token::Str("bla/5")], "invalid IP address syntax in subnet", ); assert_de_tokens_error::(&[Token::Str("::")], "Invalid subnet syntax"); assert_de_tokens_error::(&[Token::Str("::/129")], "Invalid subnet mask"); assert_de_tokens( &IpSubnet { addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), mask: 0, }, &[Token::Str("::/0")], ); assert_de_tokens( &IpSubnet { addr: IpAddr::V6(Ipv6Addr::LOCALHOST), mask: 128, }, &[Token::Str("::1/128")], ); } } ntp-proto-1.4.0/src/source.rs000064400000000000000000001646171046102023000142440ustar 00000000000000#[cfg(feature = "ntpv5")] use crate::packet::{ v5::server_reference_id::{BloomFilter, RemoteBloomFilter}, ExtensionField, NtpHeader, }; use crate::{ 
algorithm::{ObservableSourceTimedata, SourceController}, config::SourceDefaultsConfig, cookiestash::CookieStash, identifiers::ReferenceId, packet::{Cipher, NtpAssociationMode, NtpLeapIndicator, NtpPacket, RequestIdentifier}, system::{SystemSnapshot, SystemSourceUpdate}, time_types::{NtpDuration, NtpInstant, NtpTimestamp, PollInterval}, }; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use std::{ fmt::Debug, io::Cursor, net::{IpAddr, SocketAddr}, time::Duration, }; use tracing::{debug, trace, warn}; const MAX_STRATUM: u8 = 16; const POLL_WINDOW: std::time::Duration = std::time::Duration::from_secs(5); const STARTUP_TRIES_THRESHOLD: usize = 3; #[cfg(feature = "ntpv5")] const AFTER_UPGRADE_TRIES_THRESHOLD: u32 = 2; pub struct SourceNtsData { pub(crate) cookies: CookieStash, // Note: we use Box to support the use // of multiple different ciphers, that might differ // in the key information they need to keep. pub(crate) c2s: Box, pub(crate) s2c: Box, } #[cfg(any(test, feature = "__internal-test"))] impl SourceNtsData { pub fn get_cookie(&mut self) -> Option> { self.cookies.get() } pub fn get_keys(self) -> (Box, Box) { (self.c2s, self.s2c) } } impl std::fmt::Debug for SourceNtsData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SourceNtsData") .field("cookies", &self.cookies) .finish() } } #[derive(Debug)] pub struct NtpSource> { nts: Option>, // Poll interval used when sending last poll message. last_poll_interval: PollInterval, // The poll interval desired by the remove server. // Must be increased when the server sends the RATE kiss code. remote_min_poll_interval: PollInterval, // Identifier of the last request sent to the server. This is correlated // with any received response from the server to guard against replay // attacks and packet reordering. 
current_request_identifier: Option<(RequestIdentifier, NtpInstant)>, stratum: u8, reference_id: ReferenceId, source_addr: SocketAddr, source_id: ReferenceId, reach: Reach, tries: usize, controller: Controller, source_defaults_config: SourceDefaultsConfig, buffer: [u8; 1024], protocol_version: ProtocolVersion, #[cfg(feature = "ntpv5")] // TODO we only need this if we run as a server bloom_filter: RemoteBloomFilter, } pub struct OneWaySource> { controller: Controller, } impl> OneWaySource { pub(crate) fn new(controller: Controller) -> OneWaySource { OneWaySource { controller } } pub fn handle_measurement( &mut self, measurement: Measurement<()>, ) -> Option { self.controller.handle_measurement(measurement) } pub fn handle_message(&mut self, message: Controller::ControllerMessage) { self.controller.handle_message(message) } } #[derive(Debug, Copy, Clone)] pub struct Measurement { pub delay: D, pub offset: NtpDuration, pub localtime: NtpTimestamp, pub monotime: NtpInstant, pub stratum: u8, pub root_delay: NtpDuration, pub root_dispersion: NtpDuration, pub leap: NtpLeapIndicator, pub precision: i8, } impl Measurement { fn from_packet( packet: &NtpPacket, send_timestamp: NtpTimestamp, recv_timestamp: NtpTimestamp, local_clock_time: NtpInstant, ) -> Self { Self { delay: (recv_timestamp - send_timestamp) - (packet.transmit_timestamp() - packet.receive_timestamp()), offset: ((packet.receive_timestamp() - send_timestamp) + (packet.transmit_timestamp() - recv_timestamp)) / 2, localtime: send_timestamp + (recv_timestamp - send_timestamp) / 2, monotime: local_clock_time, stratum: packet.stratum(), root_delay: packet.root_delay(), root_dispersion: packet.root_dispersion(), leap: packet.leap(), precision: packet.precision(), } } } /// Used to determine whether the server is reachable and the data are fresh /// /// This value is represented as an 8-bit shift register. The register is shifted left /// by one bit when a packet is sent and the rightmost bit is set to zero. 
/// As valid packets arrive, the rightmost bit is set to one. /// If the register contains any nonzero bits, the server is considered reachable; /// otherwise, it is unreachable. #[derive(Default, Clone, Copy, Serialize, Deserialize)] pub struct Reach(u8); impl std::fmt::Debug for Reach { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.is_reachable() { write!( f, "Reach(0b{:07b} ({} polls until unreachable))", self.0, 7 - self.0.trailing_zeros() ) } else { write!(f, "Reach(unreachable)",) } } } impl Reach { pub fn is_reachable(&self) -> bool { self.0 != 0 } /// We have just received a packet, so the source is definitely reachable pub(crate) fn received_packet(&mut self) { self.0 |= 1; } /// A packet received some number of poll intervals ago is decreasingly relevant for /// determining that a source is still reachable. We discount the packets received so far. fn poll(&mut self) { self.0 <<= 1; } /// Number of polls since the last message we received pub fn unanswered_polls(&self) -> u32 { self.0.trailing_zeros() } } #[derive(Debug, Clone)] pub struct OneWaySourceUpdate { pub snapshot: OneWaySourceSnapshot, pub message: Option, } #[derive(Debug, Clone, Copy)] #[allow(clippy::large_enum_variant)] pub enum SourceSnapshot { Ntp(NtpSourceSnapshot), OneWay(OneWaySourceSnapshot), } #[derive(Debug, Clone, Copy)] pub struct OneWaySourceSnapshot { pub source_id: ReferenceId, pub stratum: u8, } #[derive(Debug, Clone, Copy)] pub struct NtpSourceSnapshot { pub source_addr: SocketAddr, pub source_id: ReferenceId, pub poll_interval: PollInterval, pub reach: Reach, pub stratum: u8, pub reference_id: ReferenceId, pub protocol_version: ProtocolVersion, #[cfg(feature = "ntpv5")] pub bloom_filter: Option, } impl NtpSourceSnapshot { pub fn accept_synchronization( &self, local_stratum: u8, local_ips: &[IpAddr], #[cfg_attr(not(feature = "ntpv5"), allow(unused_variables))] system: &SystemSnapshot, ) -> Result<(), AcceptSynchronizationError> { use 
AcceptSynchronizationError::*; if self.stratum >= local_stratum { debug!( source_stratum = self.stratum, own_stratum = local_stratum, "Source rejected due to invalid stratum. The stratum of a source must be lower than the own stratum", ); return Err(Stratum); } // Detect whether the remote uses us as their main time reference. // if so, we shouldn't sync to them as that would create a loop. // Note, this can only ever be an issue if the source is not using // hardware as its source, so ignore reference_id if stratum is 1. if self.stratum != 1 && local_ips .iter() .any(|ip| ReferenceId::from_ip(*ip) == self.source_id) { debug!("Source rejected because of detected synchronization loop (ref id)"); return Err(Loop); } #[cfg(feature = "ntpv5")] match self.bloom_filter { Some(filter) if filter.contains_id(&system.server_id) => { debug!("Source rejected because of detected synchronization loop (bloom filter)"); return Err(Loop); } _ => {} } // An unreachable error occurs if the server is unreachable. 
if !self.reach.is_reachable() { debug!("Source is unreachable"); return Err(ServerUnreachable); } Ok(()) } pub fn from_source>( source: &NtpSource, ) -> Self { Self { source_addr: source.source_addr, source_id: source.source_id, stratum: source.stratum, reference_id: source.reference_id, reach: source.reach, poll_interval: source.last_poll_interval, protocol_version: source.protocol_version, #[cfg(feature = "ntpv5")] bloom_filter: source.bloom_filter.full_filter().copied(), } } } #[cfg(feature = "__internal-test")] pub fn source_snapshot() -> NtpSourceSnapshot { use std::net::Ipv4Addr; let mut reach = crate::source::Reach::default(); reach.received_packet(); NtpSourceSnapshot { source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), source_id: ReferenceId::from_int(0), stratum: 0, reference_id: ReferenceId::from_int(0), reach, poll_interval: crate::time_types::PollIntervalLimits::default().min, protocol_version: Default::default(), #[cfg(feature = "ntpv5")] bloom_filter: None, } } #[derive(Debug, PartialEq, Eq)] #[repr(u8)] pub enum AcceptSynchronizationError { ServerUnreachable, Loop, Distance, Stratum, } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ProtocolVersion { V4, #[cfg(feature = "ntpv5")] V4UpgradingToV5 { tries_left: u8, }, #[cfg(feature = "ntpv5")] UpgradedToV5, #[cfg(feature = "ntpv5")] V5, } impl ProtocolVersion { pub fn is_expected_incoming_version(&self, incoming_version: u8) -> bool { match self { ProtocolVersion::V4 => incoming_version == 4 || incoming_version == 3, #[cfg(feature = "ntpv5")] ProtocolVersion::V4UpgradingToV5 { .. 
} => incoming_version == 4, #[cfg(feature = "ntpv5")] ProtocolVersion::UpgradedToV5 | ProtocolVersion::V5 => incoming_version == 5, } } } impl Default for ProtocolVersion { #[cfg(feature = "ntpv5")] fn default() -> Self { Self::V4UpgradingToV5 { tries_left: 8 } } #[cfg(not(feature = "ntpv5"))] fn default() -> Self { Self::V4 } } pub struct NtpSourceUpdate { pub(crate) snapshot: NtpSourceSnapshot, pub(crate) message: Option, } impl std::fmt::Debug for NtpSourceUpdate { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("NtpSourceUpdate") .field("snapshot", &self.snapshot) .field("message", &self.message) .finish() } } impl Clone for NtpSourceUpdate { fn clone(&self) -> Self { Self { snapshot: self.snapshot, message: self.message.clone(), } } } #[cfg(feature = "__internal-test")] impl NtpSourceUpdate { pub fn snapshot(snapshot: NtpSourceSnapshot) -> Self { NtpSourceUpdate { snapshot, message: None, } } } #[derive(Debug, Clone)] #[allow(clippy::large_enum_variant)] pub enum NtpSourceAction { /// Send a message over the network. When this is issued, the network port may be changed. Send(Vec), /// Send an update to [`System`](crate::system::System) UpdateSystem(NtpSourceUpdate), /// Call [`NtpSource::handle_timer`] after given duration SetTimer(Duration), /// A complete reset of the connection is necessary, including a potential new NTSKE client session and/or DNS lookup. Reset, /// We must stop talking to this particular server. Demobilize, } #[derive(Debug)] pub struct NtpSourceActionIterator { iter: > as IntoIterator>::IntoIter, } impl Default for NtpSourceActionIterator { fn default() -> Self { Self { iter: vec![].into_iter(), } } } impl Iterator for NtpSourceActionIterator { type Item = NtpSourceAction; fn next(&mut self) -> Option { self.iter.next() } } impl NtpSourceActionIterator { fn from(data: Vec>) -> Self { Self { iter: data.into_iter(), } } } macro_rules!
actions { [$($action:expr),*] => { { NtpSourceActionIterator::from(vec![$($action),*]) } } } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ObservableSourceState { #[serde(flatten)] pub timedata: ObservableSourceTimedata, pub unanswered_polls: u32, pub poll_interval: PollInterval, pub nts_cookies: Option, pub name: String, pub address: String, pub id: SourceId, } impl> NtpSource { pub(crate) fn new( source_addr: SocketAddr, source_defaults_config: SourceDefaultsConfig, protocol_version: ProtocolVersion, controller: Controller, nts: Option>, ) -> (Self, NtpSourceActionIterator) { ( Self { nts, last_poll_interval: source_defaults_config.poll_interval_limits.min, remote_min_poll_interval: source_defaults_config.poll_interval_limits.min, current_request_identifier: None, source_id: ReferenceId::from_ip(source_addr.ip()), source_addr, reach: Default::default(), tries: 0, stratum: 16, reference_id: ReferenceId::NONE, source_defaults_config, controller, buffer: [0; 1024], protocol_version, // TODO make this configurable #[cfg(feature = "ntpv5")] bloom_filter: RemoteBloomFilter::new(16).expect("16 is a valid chunk size"), }, actions!(NtpSourceAction::SetTimer(Duration::from_secs(0))), ) } pub fn observe(&self, name: String, id: SourceId) -> ObservableSourceState { ObservableSourceState { timedata: self.controller.observe(), unanswered_polls: self.reach.unanswered_polls(), poll_interval: self.last_poll_interval, nts_cookies: self.nts.as_ref().map(|nts| nts.cookies.len()), name, address: self.source_addr.to_string(), id, } } pub fn current_poll_interval(&self) -> PollInterval { self.controller .desired_poll_interval() .max(self.remote_min_poll_interval) } #[cfg_attr(not(feature = "ntpv5"), allow(unused_mut))] pub fn handle_timer(&mut self) -> NtpSourceActionIterator { if !self.reach.is_reachable() && self.tries >= STARTUP_TRIES_THRESHOLD { return actions!(NtpSourceAction::Reset); } #[cfg(feature = "ntpv5")] if matches!(self.protocol_version, 
ProtocolVersion::UpgradedToV5) && self.reach.unanswered_polls() >= AFTER_UPGRADE_TRIES_THRESHOLD { // For some reason V5 communication isn't working, even though we and the server support it. Fall back. self.protocol_version = ProtocolVersion::V4; } self.reach.poll(); self.tries = self.tries.saturating_add(1); let poll_interval = self.current_poll_interval(); let (mut packet, identifier) = match &mut self.nts { Some(nts) => { let Some(cookie) = nts.cookies.get() else { return actions!(NtpSourceAction::Reset); }; // Do ensure we don't exceed the buffer size // when requesting new cookies. We keep 300 // bytes of margin for header, ids, extension // field headers and signature. let new_cookies = nts .cookies .gap() .min(((self.buffer.len() - 300) / cookie.len()).min(u8::MAX as usize) as u8); match self.protocol_version { ProtocolVersion::V4 => { NtpPacket::nts_poll_message(&cookie, new_cookies, poll_interval) } #[cfg(feature = "ntpv5")] ProtocolVersion::V4UpgradingToV5 { .. } | ProtocolVersion::V5 | ProtocolVersion::UpgradedToV5 => { NtpPacket::nts_poll_message_v5(&cookie, new_cookies, poll_interval) } } } None => match self.protocol_version { ProtocolVersion::V4 => NtpPacket::poll_message(poll_interval), #[cfg(feature = "ntpv5")] ProtocolVersion::V4UpgradingToV5 { ..
} => { NtpPacket::poll_message_upgrade_request(poll_interval) } #[cfg(feature = "ntpv5")] ProtocolVersion::UpgradedToV5 | ProtocolVersion::V5 => { NtpPacket::poll_message_v5(poll_interval) } }, }; self.current_request_identifier = Some((identifier, NtpInstant::now() + POLL_WINDOW)); #[cfg(feature = "ntpv5")] if let NtpHeader::V5(header) = packet.header() { let req_ef = self.bloom_filter.next_request(header.client_cookie); packet.push_additional(ExtensionField::ReferenceIdRequest(req_ef)); } // update the poll interval self.last_poll_interval = poll_interval; let snapshot = NtpSourceSnapshot::from_source(self); // Write packet to buffer let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut self.buffer); packet .serialize( &mut cursor, &self.nts.as_ref().map(|nts| nts.c2s.as_ref()), None, ) .expect("Internal error: could not serialize packet"); let used = cursor.position(); let result = &cursor.into_inner()[..used as usize]; actions!( NtpSourceAction::Send(result.into()), NtpSourceAction::UpdateSystem(NtpSourceUpdate { snapshot, message: None }), // randomize the poll interval a little to make it harder to predict poll requests NtpSourceAction::SetTimer( poll_interval .as_system_duration() .mul_f64(thread_rng().gen_range(1.01..=1.05)) ) ) } pub fn handle_system_update( &mut self, update: SystemSourceUpdate, ) -> NtpSourceActionIterator { self.controller.handle_message(update.message); actions!() } pub fn handle_incoming( &mut self, message: &[u8], local_clock_time: NtpInstant, send_time: NtpTimestamp, recv_time: NtpTimestamp, ) -> NtpSourceActionIterator { let message = match NtpPacket::deserialize(message, &self.nts.as_ref().map(|nts| nts.s2c.as_ref())) { Ok((packet, _)) => packet, Err(e) => { warn!("received invalid packet: {}", e); return actions!(); } }; if !self .protocol_version .is_expected_incoming_version(message.version()) { warn!( incoming_version = message.version(), expected_version = ?self.protocol_version, "Received packet with unexpected version from 
source" ); return actions!(); } let request_identifier = match self.current_request_identifier { Some((next_expected_origin, validity)) if validity >= NtpInstant::now() => { next_expected_origin } _ => { debug!("Received old/unexpected packet from source"); return actions!(); } }; #[cfg(feature = "ntpv5")] if message.valid_server_response(request_identifier, self.nts.is_some()) { if let ProtocolVersion::V4UpgradingToV5 { tries_left } = self.protocol_version { let tries_left = tries_left.saturating_sub(1); if message.is_upgrade() { debug!("Received a valid upgrade response, switching to NTPv5!"); self.protocol_version = ProtocolVersion::UpgradedToV5; } else if tries_left == 0 { debug!("Server does not support NTPv5, stopping the upgrade process"); self.protocol_version = ProtocolVersion::V4; } else { debug!(tries_left, "Server did not yet respond with upgrade code"); self.protocol_version = ProtocolVersion::V4UpgradingToV5 { tries_left }; }; } else if let ProtocolVersion::UpgradedToV5 = self.protocol_version { self.protocol_version = ProtocolVersion::V5; } } if !message.valid_server_response(request_identifier, self.nts.is_some()) { // Packets should be a response to a previous request from us, // if not just ignore. Note that this might also happen when // we reset between sending the request and receiving the response. // We do this as the first check since accepting even a KISS // packet that is not a response will leave us vulnerable // to denial of service attacks. 
debug!("Received old/unexpected packet from source"); actions!() } else if message.is_kiss_rate(self.last_poll_interval) { // KISS packets may not have correct timestamps at all, handle them anyway self.remote_min_poll_interval = Ord::max( self.remote_min_poll_interval .inc(self.source_defaults_config.poll_interval_limits), self.last_poll_interval, ); warn!(?self.remote_min_poll_interval, "Source requested rate limit"); actions!() } else if message.is_kiss_rstr() || message.is_kiss_deny() { warn!("Source denied service"); // KISS packets may not have correct timestamps at all, handle them anyway actions!(NtpSourceAction::Demobilize) } else if message.is_kiss_ntsn() { warn!("Received nts not-acknowledge"); // as these can be easily faked, we don't immediately give up on receiving // a response. actions!() } else if message.is_kiss() { warn!("Unrecognized KISS Message from source"); // Ignore unrecognized control messages actions!() } else if message.stratum() > MAX_STRATUM { // A server's stratum should be between 1 and MAX_STRATUM (16) inclusive.
warn!( "Received message from server with excessive stratum {}", message.stratum() ); actions!() } else if message.mode() != NtpAssociationMode::Server { // we currently only support a client <-> server association warn!("Received packet with invalid mode"); actions!() } else { self.process_message(message, local_clock_time, send_time, recv_time) } } fn process_message( &mut self, message: NtpPacket, local_clock_time: NtpInstant, send_time: NtpTimestamp, recv_time: NtpTimestamp, ) -> NtpSourceActionIterator { trace!("Packet accepted for processing"); // For reachability, mark that we have had a response self.reach.received_packet(); // we received this packet, and don't want to accept future ones with this next_expected_origin self.current_request_identifier = None; // Update stratum and reference id self.stratum = message.stratum(); self.reference_id = message.reference_id(); #[cfg(feature = "ntpv5")] if let NtpHeader::V5(header) = message.header() { // Handle new requested poll interval let requested_poll = message.poll(); if requested_poll > self.remote_min_poll_interval { debug!( ?requested_poll, ?self.remote_min_poll_interval, "Adapting to longer poll interval requested by server" ); self.remote_min_poll_interval = requested_poll; } // Update our bloom filter (we need separate branches due to types) let bloom_responses = if self.nts.is_some() { message .authenticated_extension_fields() .filter_map(|ef| match ef { ExtensionField::ReferenceIdResponse(response) => Some(response), _ => None, }) .next() } else { message .untrusted_extension_fields() .filter_map(|ef| match ef { ExtensionField::ReferenceIdResponse(response) => Some(response), _ => None, }) .next() }; if let Some(ref_id) = bloom_responses { let result = self .bloom_filter .handle_response(header.client_cookie, ref_id); if let Err(err) = result { warn!(?err, "Invalid ReferenceIdResponse from source, ignoring...") } } } // generate and handle measurement let measurement =
Measurement::from_packet(&message, send_time, recv_time, local_clock_time); let controller_message = self.controller.handle_measurement(measurement); // Process new cookies if let Some(nts) = self.nts.as_mut() { for cookie in message.new_cookies() { nts.cookies.store(cookie); } } actions!(NtpSourceAction::UpdateSystem(NtpSourceUpdate { snapshot: NtpSourceSnapshot::from_source(self), message: controller_message, })) } #[cfg(test)] pub(crate) fn test_ntp_source(controller: Controller) -> Self { use std::net::Ipv4Addr; NtpSource { nts: None, last_poll_interval: PollInterval::default(), remote_min_poll_interval: PollInterval::default(), current_request_identifier: None, source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), source_id: ReferenceId::from_int(0), reach: Reach::default(), tries: 0, stratum: 0, reference_id: ReferenceId::from_int(0), source_defaults_config: SourceDefaultsConfig::default(), controller, buffer: [0; 1024], protocol_version: Default::default(), #[cfg(feature = "ntpv5")] bloom_filter: RemoteBloomFilter::new(16).unwrap(), } } } #[cfg(test)] mod test { use crate::{packet::NoCipher, time_types::PollIntervalLimits, NtpClock}; use super::*; #[cfg(feature = "ntpv5")] use crate::packet::v5::server_reference_id::ServerId; #[cfg(feature = "ntpv5")] use rand::thread_rng; #[derive(Debug, Clone, Default)] struct TestClock {} const EPOCH_OFFSET: u32 = (70 * 365 + 17) * 86400; impl NtpClock for TestClock { type Error = std::time::SystemTimeError; fn now(&self) -> std::result::Result { let cur = std::time::SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH)?; Ok(NtpTimestamp::from_seconds_nanos_since_ntp_era( EPOCH_OFFSET.wrapping_add(cur.as_secs() as u32), cur.subsec_nanos(), )) } fn set_frequency(&self, _freq: f64) -> Result { panic!("Shouldn't be called by source"); } fn get_frequency(&self) -> Result { Ok(0.0) } fn step_clock(&self, _offset: NtpDuration) -> Result { panic!("Shouldn't be called by source"); } fn 
disable_ntp_algorithm(&self) -> Result<(), Self::Error> { panic!("Shouldn't be called by source"); } fn error_estimate_update( &self, _est_error: NtpDuration, _max_error: NtpDuration, ) -> Result<(), Self::Error> { panic!("Shouldn't be called by source"); } fn status_update(&self, _leap_status: NtpLeapIndicator) -> Result<(), Self::Error> { panic!("Shouldn't be called by source"); } } struct NoopController; impl SourceController for NoopController { type ControllerMessage = (); type SourceMessage = (); type MeasurementDelay = NtpDuration; fn handle_message(&mut self, _: Self::ControllerMessage) { // do nothing } fn handle_measurement( &mut self, _: Measurement, ) -> Option { // do nothing Some(()) } fn desired_poll_interval(&self) -> PollInterval { PollInterval::default() } fn observe(&self) -> crate::ObservableSourceTimedata { panic!("Not implemented on noop controller"); } } #[test] fn test_measurement_from_packet() { let instant = NtpInstant::now(); let mut packet = NtpPacket::test(); packet.set_receive_timestamp(NtpTimestamp::from_fixed_int(1)); packet.set_transmit_timestamp(NtpTimestamp::from_fixed_int(2)); let result = Measurement::from_packet( &packet, NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(3), instant, ); assert_eq!(result.offset, NtpDuration::from_fixed_int(0)); assert_eq!(result.delay, NtpDuration::from_fixed_int(2)); packet.set_receive_timestamp(NtpTimestamp::from_fixed_int(2)); packet.set_transmit_timestamp(NtpTimestamp::from_fixed_int(3)); let result = Measurement::from_packet( &packet, NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(3), instant, ); assert_eq!(result.offset, NtpDuration::from_fixed_int(1)); assert_eq!(result.delay, NtpDuration::from_fixed_int(2)); packet.set_receive_timestamp(NtpTimestamp::from_fixed_int(0)); packet.set_transmit_timestamp(NtpTimestamp::from_fixed_int(5)); let result = Measurement::from_packet( &packet, NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(3), instant, ); 
assert_eq!(result.offset, NtpDuration::from_fixed_int(1)); assert_eq!(result.delay, NtpDuration::from_fixed_int(-2)); } #[test] fn reachability() { let mut reach = Reach::default(); // the default reach register value is 0, and hence not reachable assert!(!reach.is_reachable()); // when we receive a packet, we set the right-most bit; // we just received a packet from the source, so it is reachable reach.received_packet(); assert!(reach.is_reachable()); // on every poll, the register is shifted to the left, and there are // 8 bits. So we can poll 7 times and the source is still considered reachable for _ in 0..7 { reach.poll(); } assert!(reach.is_reachable()); // but one more poll and all 1 bits have been shifted out; // the source is no longer reachable reach.poll(); assert!(!reach.is_reachable()); // until we receive a packet from it again reach.received_packet(); assert!(reach.is_reachable()); } #[test] fn test_accept_synchronization() { use AcceptSynchronizationError::*; let mut source = NtpSource::test_ntp_source(NoopController); #[cfg_attr(not(feature = "ntpv5"), allow(unused_mut))] let mut system = SystemSnapshot::default(); #[cfg(feature = "ntpv5")] { system.server_id = ServerId::new(&mut thread_rng()); } macro_rules! 
accept { () => {{ let snapshot = NtpSourceSnapshot::from_source(&source); snapshot.accept_synchronization(16, &["127.0.0.1".parse().unwrap()], &system) }}; } source.source_id = ReferenceId::from_ip("127.0.0.1".parse().unwrap()); assert_eq!(accept!(), Err(Loop)); source.source_id = ReferenceId::from_ip("127.0.1.1".parse().unwrap()); assert_eq!(accept!(), Err(ServerUnreachable)); source.reach.received_packet(); assert_eq!(accept!(), Ok(())); source.stratum = 42; assert_eq!(accept!(), Err(Stratum)); } #[test] fn test_poll_interval() { struct PollIntervalController(PollInterval); impl SourceController for PollIntervalController { type ControllerMessage = (); type SourceMessage = (); type MeasurementDelay = NtpDuration; fn handle_message(&mut self, _: Self::ControllerMessage) {} fn handle_measurement( &mut self, _: Measurement, ) -> Option { None } fn desired_poll_interval(&self) -> PollInterval { self.0 } fn observe(&self) -> crate::ObservableSourceTimedata { unimplemented!() } } let mut source = NtpSource::test_ntp_source(PollIntervalController(PollIntervalLimits::default().min)); assert!(source.current_poll_interval() >= source.remote_min_poll_interval); assert!(source.current_poll_interval() >= source.controller.0); source.controller.0 = PollIntervalLimits::default().max; assert!(source.current_poll_interval() >= source.remote_min_poll_interval); assert!(source.current_poll_interval() >= source.controller.0); source.controller.0 = PollIntervalLimits::default().min; source.remote_min_poll_interval = PollIntervalLimits::default().max; assert!(source.current_poll_interval() >= source.remote_min_poll_interval); assert!(source.current_poll_interval() >= source.controller.0); } #[test] fn test_handle_incoming() { let base = NtpInstant::now(); let mut source = NtpSource::test_ntp_source(NoopController); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); 
if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let outgoingbuf = outgoingbuf.unwrap(); let outgoing = NtpPacket::deserialize(&outgoingbuf, &NoCipher).unwrap().0; let mut packet = NtpPacket::test(); packet.set_stratum(1); packet.set_mode(NtpAssociationMode::Server); packet.set_origin_timestamp(outgoing.transmit_timestamp()); packet.set_receive_timestamp(NtpTimestamp::from_fixed_int(100)); packet.set_transmit_timestamp(NtpTimestamp::from_fixed_int(200)); let actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(400), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize | NtpSourceAction::SetTimer(_) | NtpSourceAction::Send(_) )); } let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(500), ); assert!(actions.next().is_none()); } #[test] fn test_startup_unreachable() { let mut source = NtpSource::test_ntp_source(NoopController); let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let mut actions = source.handle_timer(); assert!(matches!(actions.next(), Some(NtpSourceAction::Reset))); } #[test] fn test_running_unreachable() { let base = NtpInstant::now(); let mut source = NtpSource::test_ntp_source(NoopController); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, 
NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let outgoingbuf = outgoingbuf.unwrap(); let outgoing = NtpPacket::deserialize(&outgoingbuf, &NoCipher).unwrap().0; let mut packet = NtpPacket::test(); packet.set_stratum(1); packet.set_mode(NtpAssociationMode::Server); packet.set_origin_timestamp(outgoing.transmit_timestamp()); packet.set_receive_timestamp(NtpTimestamp::from_fixed_int(100)); packet.set_transmit_timestamp(NtpTimestamp::from_fixed_int(200)); let actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(400), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize | NtpSourceAction::SetTimer(_) | NtpSourceAction::Send(_) )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let actions = source.handle_timer(); for action in actions { 
assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); } let mut actions = source.handle_timer(); assert!(matches!(actions.next(), Some(NtpSourceAction::Reset))); } #[test] fn test_stratum_checks() { let base = NtpInstant::now(); let mut source = NtpSource::test_ntp_source(NoopController); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let outgoingbuf = outgoingbuf.unwrap(); let outgoing = NtpPacket::deserialize(&outgoingbuf, &NoCipher).unwrap().0; let mut packet = NtpPacket::test(); packet.set_stratum(MAX_STRATUM + 1); packet.set_mode(NtpAssociationMode::Server); packet.set_origin_timestamp(outgoing.transmit_timestamp()); packet.set_receive_timestamp(NtpTimestamp::from_fixed_int(100)); packet.set_transmit_timestamp(NtpTimestamp::from_fixed_int(200)); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(500), ); assert!(actions.next().is_none()); packet.set_stratum(0); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(500), ); assert!(actions.next().is_none()); } #[test] fn test_handle_kod() { let base = NtpInstant::now(); let mut source = NtpSource::test_ntp_source(NoopController); let mut packet = NtpPacket::test(); packet.set_reference_id(ReferenceId::KISS_RSTR); packet.set_mode(NtpAssociationMode::Server); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(100), ); assert!(actions.next().is_none()); let mut packet = 
NtpPacket::test(); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let outgoingbuf = outgoingbuf.unwrap(); let outgoing = NtpPacket::deserialize(&outgoingbuf, &NoCipher).unwrap().0; packet.set_reference_id(ReferenceId::KISS_RSTR); packet.set_origin_timestamp(outgoing.transmit_timestamp()); packet.set_mode(NtpAssociationMode::Server); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(100), ); assert!(matches!(actions.next(), Some(NtpSourceAction::Demobilize))); let mut packet = NtpPacket::test(); packet.set_reference_id(ReferenceId::KISS_DENY); packet.set_mode(NtpAssociationMode::Server); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(100), ); assert!(actions.next().is_none()); let mut packet = NtpPacket::test(); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let outgoingbuf = outgoingbuf.unwrap(); let outgoing = NtpPacket::deserialize(&outgoingbuf, &NoCipher).unwrap().0; packet.set_reference_id(ReferenceId::KISS_DENY); packet.set_origin_timestamp(outgoing.transmit_timestamp()); packet.set_mode(NtpAssociationMode::Server); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(100), ); assert!(matches!(actions.next(), Some(NtpSourceAction::Demobilize))); let 
old_remote_interval = source.remote_min_poll_interval; let mut packet = NtpPacket::test(); packet.set_reference_id(ReferenceId::KISS_RATE); packet.set_mode(NtpAssociationMode::Server); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(100), ); assert!(actions.next().is_none()); assert_eq!(source.remote_min_poll_interval, old_remote_interval); let old_remote_interval = source.remote_min_poll_interval; let mut packet = NtpPacket::test(); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let outgoingbuf = outgoingbuf.unwrap(); let outgoing = NtpPacket::deserialize(&outgoingbuf, &NoCipher).unwrap().0; packet.set_reference_id(ReferenceId::KISS_RATE); packet.set_origin_timestamp(outgoing.transmit_timestamp()); packet.set_mode(NtpAssociationMode::Server); let mut actions = source.handle_incoming( &packet.serialize_without_encryption_vec(None).unwrap(), base + Duration::from_secs(1), NtpTimestamp::from_fixed_int(0), NtpTimestamp::from_fixed_int(100), ); assert!(actions.next().is_none()); assert!(source.remote_min_poll_interval >= old_remote_interval); } #[cfg(feature = "ntpv5")] #[test] fn upgrade_state_machine_does_stop() { let mut source = NtpSource::test_ntp_source(NoopController); let clock = TestClock {}; assert!(matches!( source.protocol_version, ProtocolVersion::V4UpgradingToV5 { .. 
} )); for _ in 0..8 { let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let poll_len: usize = poll.len(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); assert_eq!(poll.version(), 4); assert!(poll.is_upgrade()); let response = NtpPacket::timestamp_response( &SystemSnapshot::default(), poll, NtpTimestamp::default(), &clock, ); let mut response = response .serialize_without_encryption_vec(Some(poll_len)) .unwrap(); // Kill the reference timestamp response[16] = 0; let actions = source.handle_incoming( &response, NtpInstant::now(), NtpTimestamp::default(), NtpTimestamp::default(), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Demobilize | NtpSourceAction::Reset )); } } let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); assert_eq!(poll.version(), 4); assert!(!poll.is_upgrade()); } #[cfg(feature = "ntpv5")] #[test] fn upgrade_state_machine_does_upgrade() { let mut source = NtpSource::test_ntp_source(NoopController); let clock = TestClock {}; assert!(matches!( source.protocol_version, ProtocolVersion::V4UpgradingToV5 { .. 
} )); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let poll_len = poll.len(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); assert_eq!(poll.version(), 4); assert!(poll.is_upgrade()); let response = NtpPacket::timestamp_response( &SystemSnapshot::default(), poll, NtpTimestamp::default(), &clock, ); let response = response .serialize_without_encryption_vec(Some(poll_len)) .unwrap(); let actions = source.handle_incoming( &response, NtpInstant::now(), NtpTimestamp::default(), NtpTimestamp::default(), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Demobilize | NtpSourceAction::Reset )); } // We should have received a upgrade response and updated to NTPv5 assert!(matches!( source.protocol_version, ProtocolVersion::UpgradedToV5 )); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); assert_eq!(poll.version(), 5); let response = NtpPacket::timestamp_response( &SystemSnapshot::default(), poll, NtpTimestamp::default(), &clock, ); let response = response .serialize_without_encryption_vec(Some(poll_len)) .unwrap(); let actions = source.handle_incoming( &response, NtpInstant::now(), NtpTimestamp::default(), NtpTimestamp::default(), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Demobilize | NtpSourceAction::Reset )); } // NtpV5 is confirmed to work now assert!(matches!(source.protocol_version, ProtocolVersion::V5)); } #[cfg(feature = "ntpv5")] #[test] fn 
upgrade_state_machine_does_fallback_after_upgrade() { let mut source = NtpSource::test_ntp_source(NoopController); let clock = TestClock {}; assert!(matches!( source.protocol_version, ProtocolVersion::V4UpgradingToV5 { .. } )); let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let poll_len = poll.len(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); assert_eq!(poll.version(), 4); assert!(poll.is_upgrade()); let response = NtpPacket::timestamp_response( &SystemSnapshot::default(), poll, NtpTimestamp::default(), &clock, ); let response = response .serialize_without_encryption_vec(Some(poll_len)) .unwrap(); let actions = source.handle_incoming( &response, NtpInstant::now(), NtpTimestamp::default(), NtpTimestamp::default(), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Demobilize | NtpSourceAction::Reset )); } // We should have received a upgrade response and updated to NTPv5 assert!(matches!( source.protocol_version, ProtocolVersion::UpgradedToV5 )); for _ in 0..2 { let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); assert_eq!(poll.version(), 5); } let actions = source.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let poll = outgoingbuf.unwrap(); let (poll, _) = NtpPacket::deserialize(&poll, &NoCipher).unwrap(); 
assert!(matches!(source.protocol_version, ProtocolVersion::V4)); assert_eq!(poll.version(), 4); } #[cfg(feature = "ntpv5")] #[test] fn bloom_filters_will_synchronize_at_some_point() { let mut server_filter = BloomFilter::new(); server_filter.add_id(&ServerId::new(&mut thread_rng())); let mut client = NtpSource::test_ntp_source(NoopController); client.protocol_version = ProtocolVersion::V5; let clock = TestClock::default(); let server_system = SystemSnapshot { bloom_filter: server_filter, ..Default::default() }; let mut tries = 0; while client.bloom_filter.full_filter().is_none() && tries < 100 { let actions = client.handle_timer(); let mut outgoingbuf = None; for action in actions { assert!(!matches!( action, NtpSourceAction::Reset | NtpSourceAction::Demobilize )); if let NtpSourceAction::Send(buf) = action { outgoingbuf = Some(buf); } } let req = outgoingbuf.unwrap(); let (req, _) = NtpPacket::deserialize(&req, &NoCipher).unwrap(); let response = NtpPacket::timestamp_response(&server_system, req, NtpTimestamp::default(), &clock); let resp_bytes = response.serialize_without_encryption_vec(None).unwrap(); let actions = client.handle_incoming( &resp_bytes, NtpInstant::now(), NtpTimestamp::default(), NtpTimestamp::default(), ); for action in actions { assert!(!matches!( action, NtpSourceAction::Demobilize | NtpSourceAction::Reset )); } tries += 1; } assert_eq!(Some(&server_filter), client.bloom_filter.full_filter()); } } ntp-proto-1.4.0/src/system.rs000064400000000000000000000352731046102023000142630ustar 00000000000000use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; use std::{fmt::Debug, hash::Hash}; #[cfg(feature = "ntpv5")] use crate::packet::v5::server_reference_id::{BloomFilter, ServerId}; use crate::source::{NtpSourceUpdate, SourceSnapshot}; use crate::{ algorithm::{StateUpdate, TimeSyncController}, clock::NtpClock, config::{SourceDefaultsConfig, 
SynchronizationConfig}, identifiers::ReferenceId, packet::NtpLeapIndicator, source::{NtpSource, NtpSourceActionIterator, ProtocolVersion, SourceNtsData}, time_types::NtpDuration, }; use crate::{OneWaySource, OneWaySourceUpdate}; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub struct TimeSnapshot { /// Precision of the local clock pub precision: NtpDuration, /// Current root delay pub root_delay: NtpDuration, /// Current root dispersion pub root_dispersion: NtpDuration, /// Current leap indicator state pub leap_indicator: NtpLeapIndicator, /// Total amount that the clock has stepped pub accumulated_steps: NtpDuration, } impl Default for TimeSnapshot { fn default() -> Self { Self { precision: NtpDuration::from_exponent(-18), root_delay: NtpDuration::ZERO, root_dispersion: NtpDuration::ZERO, leap_indicator: NtpLeapIndicator::Unknown, accumulated_steps: NtpDuration::ZERO, } } } #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct SystemSnapshot { /// Log of the precision of the local clock pub stratum: u8, /// Reference ID of current primary time source pub reference_id: ReferenceId, /// Crossing this amount of stepping will cause a Panic pub accumulated_steps_threshold: Option, /// Timekeeping data #[serde(flatten)] pub time_snapshot: TimeSnapshot, #[cfg(feature = "ntpv5")] /// Bloom filter that contains all currently used time sources #[serde(skip)] pub bloom_filter: BloomFilter, #[cfg(feature = "ntpv5")] /// NTPv5 reference ID for this instance #[serde(skip)] pub server_id: ServerId, } impl SystemSnapshot { pub fn update_timedata(&mut self, timedata: TimeSnapshot, config: &SynchronizationConfig) { self.time_snapshot = timedata; self.accumulated_steps_threshold = config.accumulated_step_panic_threshold; } pub fn update_used_sources(&mut self, used_sources: impl Iterator) { let mut used_sources = used_sources.peekable(); if let Some(system_source_snapshot) = used_sources.peek() { let (stratum, source_id) = match 
system_source_snapshot { SourceSnapshot::Ntp(snapshot) => (snapshot.stratum, snapshot.source_id), SourceSnapshot::OneWay(snapshot) => (snapshot.stratum, snapshot.source_id), }; self.stratum = stratum.saturating_add(1); self.reference_id = source_id; } #[cfg(feature = "ntpv5")] { self.bloom_filter = BloomFilter::new(); for source in used_sources { if let SourceSnapshot::Ntp(source) = source { if let Some(bf) = &source.bloom_filter { self.bloom_filter.add(bf); } else if let ProtocolVersion::V5 = source.protocol_version { tracing::warn!("Using NTPv5 source without a bloom filter!"); } } } self.bloom_filter.add_id(&self.server_id); } } } impl Default for SystemSnapshot { fn default() -> Self { Self { stratum: 16, reference_id: ReferenceId::NONE, accumulated_steps_threshold: None, time_snapshot: TimeSnapshot::default(), #[cfg(feature = "ntpv5")] bloom_filter: BloomFilter::new(), #[cfg(feature = "ntpv5")] server_id: ServerId::new(&mut rand::thread_rng()), } } } pub struct SystemSourceUpdate { pub message: ControllerMessage, } impl std::fmt::Debug for SystemSourceUpdate { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SystemSourceUpdate") .field("message", &self.message) .finish() } } impl Clone for SystemSourceUpdate { fn clone(&self) -> Self { Self { message: self.message.clone(), } } } #[derive(Debug, Clone)] #[allow(clippy::large_enum_variant)] pub enum SystemAction { UpdateSources(SystemSourceUpdate), SetTimer(Duration), } #[derive(Debug)] pub struct SystemActionIterator { iter: > as IntoIterator>::IntoIter, } impl Default for SystemActionIterator { fn default() -> Self { Self { iter: vec![].into_iter(), } } } impl From>> for SystemActionIterator { fn from(value: Vec>) -> Self { Self { iter: value.into_iter(), } } } impl Iterator for SystemActionIterator { type Item = SystemAction; fn next(&mut self) -> Option { self.iter.next() } } macro_rules! 
actions { [$($action:expr),*] => { { SystemActionIterator::from(vec![$($action),*]) } } } pub struct System { synchronization_config: SynchronizationConfig, source_defaults_config: SourceDefaultsConfig, system: SystemSnapshot, ip_list: Arc<[IpAddr]>, sources: HashMap>, controller: Controller, controller_took_control: bool, } impl> System { pub fn new( clock: Controller::Clock, synchronization_config: SynchronizationConfig, source_defaults_config: SourceDefaultsConfig, algorithm_config: Controller::AlgorithmConfig, ip_list: Arc<[IpAddr]>, ) -> Result::Error> { // Setup system snapshot let mut system = SystemSnapshot { stratum: synchronization_config.local_stratum, ..Default::default() }; if synchronization_config.local_stratum == 1 { // We are a stratum 1 server so mark our selves synchronized. system.time_snapshot.leap_indicator = NtpLeapIndicator::NoWarning; } Ok(System { synchronization_config, source_defaults_config, system, ip_list, sources: Default::default(), controller: Controller::new( clock, synchronization_config, source_defaults_config, algorithm_config, )?, controller_took_control: false, }) } pub fn system_snapshot(&self) -> SystemSnapshot { self.system } pub fn check_clock_access(&mut self) -> Result<(), ::Error> { self.ensure_controller_control() } fn ensure_controller_control(&mut self) -> Result<(), ::Error> { if !self.controller_took_control { self.controller.take_control()?; self.controller_took_control = true; } Ok(()) } pub fn create_sock_source( &mut self, id: SourceId, measurement_noise_estimate: f64, ) -> Result< OneWaySource, ::Error, > { self.ensure_controller_control()?; let controller = self .controller .add_one_way_source(id, measurement_noise_estimate); self.sources.insert(id, None); Ok(OneWaySource::new(controller)) } #[allow(clippy::type_complexity)] pub fn create_ntp_source( &mut self, id: SourceId, source_addr: SocketAddr, protocol_version: ProtocolVersion, nts: Option>, ) -> Result< ( NtpSource, NtpSourceActionIterator, ), 
::Error, > { self.ensure_controller_control()?; let controller = self.controller.add_source(id); self.sources.insert(id, None); Ok(NtpSource::new( source_addr, self.source_defaults_config, protocol_version, controller, nts, )) } pub fn handle_source_remove( &mut self, id: SourceId, ) -> Result<(), ::Error> { self.controller.remove_source(id); self.sources.remove(&id); Ok(()) } pub fn handle_source_update( &mut self, id: SourceId, update: NtpSourceUpdate, ) -> Result< SystemActionIterator, ::Error, > { let usable = update .snapshot .accept_synchronization( self.synchronization_config.local_stratum, self.ip_list.as_ref(), &self.system, ) .is_ok(); self.controller.source_update(id, usable); *self.sources.get_mut(&id).unwrap() = Some(SourceSnapshot::Ntp(update.snapshot)); if let Some(message) = update.message { let update = self.controller.source_message(id, message); Ok(self.handle_algorithm_state_update(update)) } else { Ok(actions!()) } } pub fn handle_one_way_source_update( &mut self, id: SourceId, update: OneWaySourceUpdate, ) -> Result< SystemActionIterator, ::Error, > { self.controller.source_update(id, true); *self.sources.get_mut(&id).unwrap() = Some(SourceSnapshot::OneWay(update.snapshot)); if let Some(message) = update.message { let update = self.controller.source_message(id, message); Ok(self.handle_algorithm_state_update(update)) } else { Ok(actions!()) } } fn handle_algorithm_state_update( &mut self, update: StateUpdate, ) -> SystemActionIterator { let mut actions = vec![]; if let Some(ref used_sources) = update.used_sources { self.system .update_used_sources(used_sources.iter().map(|v| { self.sources.get(v).and_then(|snapshot| *snapshot).expect( "Critical error: Source used for synchronization that is not known to system", ) })); } if let Some(time_snapshot) = update.time_snapshot { self.system .update_timedata(time_snapshot, &self.synchronization_config); } if let Some(timeout) = update.next_update { actions.push(SystemAction::SetTimer(timeout)); } if 
let Some(message) = update.source_message { actions.push(SystemAction::UpdateSources(SystemSourceUpdate { message })) } actions.into() } pub fn handle_timer(&mut self) -> SystemActionIterator { tracing::debug!("Timer expired"); let update = self.controller.time_update(); self.handle_algorithm_state_update(update) } pub fn update_ip_list(&mut self, ip_list: Arc<[IpAddr]>) { self.ip_list = ip_list; } } #[cfg(test)] mod tests { use std::net::{Ipv4Addr, SocketAddr}; use crate::{time_types::PollIntervalLimits, NtpSourceSnapshot}; use super::*; #[test] fn test_empty_source_update() { let mut system = SystemSnapshot::default(); // Should do nothing system.update_used_sources(std::iter::empty()); assert_eq!(system.stratum, 16); assert_eq!(system.reference_id, ReferenceId::NONE); } #[test] fn test_source_update() { let mut system = SystemSnapshot::default(); system.update_used_sources( vec![ SourceSnapshot::Ntp(NtpSourceSnapshot { source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), source_id: ReferenceId::KISS_DENY, poll_interval: PollIntervalLimits::default().max, reach: Default::default(), stratum: 2, reference_id: ReferenceId::NONE, protocol_version: Default::default(), #[cfg(feature = "ntpv5")] bloom_filter: None, }), SourceSnapshot::Ntp(NtpSourceSnapshot { source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), source_id: ReferenceId::KISS_RATE, poll_interval: PollIntervalLimits::default().max, reach: Default::default(), stratum: 3, reference_id: ReferenceId::NONE, protocol_version: Default::default(), #[cfg(feature = "ntpv5")] bloom_filter: None, }), ] .into_iter(), ); assert_eq!(system.stratum, 3); assert_eq!(system.reference_id, ReferenceId::KISS_DENY); } #[test] fn test_timedata_update() { let mut system = SystemSnapshot::default(); let new_root_delay = NtpDuration::from_seconds(1.0); let new_accumulated_threshold = NtpDuration::from_seconds(2.0); let snapshot = TimeSnapshot { root_delay: new_root_delay, ..Default::default() }; 
system.update_timedata( snapshot, &SynchronizationConfig { accumulated_step_panic_threshold: Some(new_accumulated_threshold), ..Default::default() }, ); assert_eq!(system.time_snapshot, snapshot); assert_eq!( system.accumulated_steps_threshold, Some(new_accumulated_threshold), ); } } ntp-proto-1.4.0/src/time_types.rs000064400000000000000000000701471046102023000151200ustar 00000000000000use rand::{ distributions::{Distribution, Standard}, Rng, }; use serde::{de::Unexpected, Deserialize, Serialize}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::time::{Duration, Instant}; /// NtpInstant is a monotonically increasing value modelling the uptime of the NTP service /// /// It is used to validate packets that we send out, and to order internal operations. #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord)] pub struct NtpInstant { instant: Instant, } impl NtpInstant { pub fn now() -> Self { Self { instant: Instant::now(), } } pub fn abs_diff(self, rhs: Self) -> NtpDuration { // our code should always give the bigger argument first. debug_assert!( self >= rhs, "self >= rhs, this could indicate another program adjusted the clock" ); // NOTE: `std::time::Duration` cannot be negative, so a simple `lhs - rhs` could give an // empty duration. In our logic, we're always interested in the absolute delta between two // points in time. let duration = if self.instant >= rhs.instant { self.instant - rhs.instant } else { rhs.instant - self.instant }; NtpDuration::from_system_duration(duration) } pub fn elapsed(&self) -> std::time::Duration { self.instant.elapsed() } } impl Add for NtpInstant { type Output = NtpInstant; fn add(mut self, rhs: Duration) -> Self::Output { self.instant += rhs; self } } /// NtpTimestamp represents an ntp timestamp without the era number. 
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Default, Serialize, Deserialize)] pub struct NtpTimestamp { timestamp: u64, } impl std::fmt::Debug for NtpTimestamp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("NtpTimestamp") .field(&self.timestamp) .finish() } } impl NtpTimestamp { pub(crate) const fn from_bits(bits: [u8; 8]) -> NtpTimestamp { NtpTimestamp { timestamp: u64::from_be_bytes(bits), } } pub(crate) const fn to_bits(self) -> [u8; 8] { self.timestamp.to_be_bytes() } /// Create an NTP timestamp from the number of seconds and nanoseconds that have /// passed since the last ntp era boundary. pub const fn from_seconds_nanos_since_ntp_era(seconds: u32, nanos: u32) -> Self { // Although having a valid interpretation, providing more // than 1 second worth of nanoseconds as input probably // indicates an error from the caller. debug_assert!(nanos < 1_000_000_000); // NTP uses 1/2^32 sec as its unit of fractional time. // our time is in nanoseconds, so 1/1e9 seconds let fraction = ((nanos as u64) << 32) / 1_000_000_000; // alternatively, abuse FP arithmetic to save an instruction // let fraction = (nanos as f64 * 4.294967296) as u64; let timestamp = ((seconds as u64) << 32) + fraction; NtpTimestamp::from_bits(timestamp.to_be_bytes()) } pub fn is_before(self, other: NtpTimestamp) -> bool { // Around an era change, self can be near the maximum value // for NtpTimestamp and other near the minimum, and that must // be interpreted as self being before other (which it is due // to wrapping in subtraction of NtpTimestamp) self - other < NtpDuration::ZERO } #[cfg(test)] pub(crate) const fn from_fixed_int(timestamp: u64) -> NtpTimestamp { NtpTimestamp { timestamp } } } // In order to provide increased entropy on origin timestamps, // we should generate these randomly. This helps avoid // attacks from attackers guessing our current time. 
impl Distribution for Standard { fn sample(&self, rng: &mut R) -> NtpTimestamp { NtpTimestamp { timestamp: rng.gen(), } } } impl Add for NtpTimestamp { type Output = NtpTimestamp; fn add(self, rhs: NtpDuration) -> Self::Output { // In order to properly deal with ntp era changes, timestamps // need to roll over. Converting the duration to u64 here // still gives desired effects because of how two's complement // arithmetic works. NtpTimestamp { timestamp: self.timestamp.wrapping_add(rhs.duration as u64), } } } impl AddAssign for NtpTimestamp { fn add_assign(&mut self, rhs: NtpDuration) { // In order to properly deal with ntp era changes, timestamps // need to roll over. Converting the duration to u64 here // still gives desired effects because of how two's complement // arithmetic works. self.timestamp = self.timestamp.wrapping_add(rhs.duration as u64); } } impl Sub for NtpTimestamp { type Output = NtpDuration; fn sub(self, rhs: Self) -> Self::Output { // In order to properly deal with ntp era changes, timestamps // need to roll over. Doing a wrapping subtract to a signed // integer type always gives us the result as if the eras of // the timestamps were chosen to minimize the norm of the // difference, which is the desired behaviour NtpDuration { duration: self.timestamp.wrapping_sub(rhs.timestamp) as i64, } } } impl Sub for NtpTimestamp { type Output = NtpTimestamp; fn sub(self, rhs: NtpDuration) -> Self::Output { // In order to properly deal with ntp era changes, timestamps // need to roll over. Converting the duration to u64 here // still gives desired effects because of how two's complement // arithmetic works. NtpTimestamp { timestamp: self.timestamp.wrapping_sub(rhs.duration as u64), } } } impl SubAssign for NtpTimestamp { fn sub_assign(&mut self, rhs: NtpDuration) { // In order to properly deal with ntp era changes, timestamps // need to roll over. 
Converting the duration to u64 here // still gives desired effects because of how two's complement // arithmetic works. self.timestamp = self.timestamp.wrapping_sub(rhs.duration as u64); } } /// NtpDuration is used to represent signed intervals between NtpTimestamps. /// A negative duration interval is interpreted to mean that the first /// timestamp used to define the interval represents a point in time after /// the second timestamp. #[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Default)] pub struct NtpDuration { duration: i64, } impl std::fmt::Debug for NtpDuration { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "NtpDuration({} ms)", self.to_seconds() * 1e3) } } impl NtpDuration { pub const ZERO: Self = Self { duration: 0 }; pub const MAX: Self = Self { duration: i64::MAX }; pub(crate) const fn from_bits(bits: [u8; 8]) -> Self { Self { duration: i64::from_be_bytes(bits), } } pub(crate) const fn from_bits_short(bits: [u8; 4]) -> Self { NtpDuration { duration: (u32::from_be_bytes(bits) as i64) << 16, } } pub(crate) const fn to_bits_short(self) -> [u8; 4] { // serializing negative durations should never happen // and indicates a programming error elsewhere. // as for duration that are too large, saturating is // the safe option. assert!(self.duration >= 0); // Although saturating is safe to do, it probably still // should never happen in practice, so ensure we will // see it when running in debug mode. debug_assert!(self.duration <= 0x0000FFFFFFFFFFFF); match self.duration > 0x0000FFFFFFFFFFFF { true => 0xFFFFFFFF_u32, false => ((self.duration & 0x0000FFFFFFFF0000) >> 16) as u32, } .to_be_bytes() } #[cfg(feature = "ntpv5")] pub(crate) const fn from_bits_time32(bits: [u8; 4]) -> Self { NtpDuration { duration: (u32::from_be_bytes(bits) as i64) << 4, } } #[cfg(feature = "ntpv5")] pub(crate) fn to_bits_time32(self) -> [u8; 4] { // serializing negative durations should never happen // and indicates a programming error elsewhere. 
// as for duration that are too large, saturating is // the safe option. assert!(self.duration >= 0); // On overflow we just saturate to the maximum 16s u32::try_from(self.duration >> 4) .unwrap_or(u32::MAX) .to_be_bytes() } /// Convert to an f64; required for statistical calculations /// (e.g. in clock filtering) pub fn to_seconds(self) -> f64 { // dividing by u32::MAX moves the decimal point to the right position self.duration as f64 / u32::MAX as f64 } pub fn from_seconds(seconds: f64) -> Self { debug_assert!(!(seconds.is_nan() || seconds.is_infinite())); let i = seconds.floor(); let f = seconds - i; // Ensure proper saturating behaviour let duration = match i as i64 { i if i >= i32::MIN as i64 && i <= i32::MAX as i64 => { (i << 32) | (f * u32::MAX as f64) as i64 } i if i < i32::MIN as i64 => i64::MIN, i if i > i32::MAX as i64 => i64::MAX, _ => unreachable!(), }; Self { duration } } /// Interval of same length, but positive direction pub const fn abs(self) -> Self { Self { duration: self.duration.abs(), } } /// Interval of same length, but positive direction pub fn abs_diff(self, other: Self) -> Self { (self - other).abs() } /// Get the number of seconds (first return value) and nanoseconds /// (second return value) representing the length of this duration. 
/// The number of nanoseconds is guaranteed to be positive and less /// than 10^9 pub const fn as_seconds_nanos(self) -> (i32, u32) { ( (self.duration >> 32) as i32, (((self.duration & 0xFFFFFFFF) * 1_000_000_000) >> 32) as u32, ) } /// Interpret an exponent `k` as `2^k` seconds, expressed as an NtpDuration pub const fn from_exponent(input: i8) -> Self { Self { duration: match input { exp if exp > 30 => i64::MAX, exp if exp > 0 && exp <= 30 => 0x1_0000_0000_i64 << exp, exp if exp >= -32 && exp <= 0 => 0x1_0000_0000_i64 >> -exp, _ => 0, }, } } /// calculate the log2 (floored) of the duration in seconds (i8::MIN if 0) pub fn log2(self) -> i8 { if self == NtpDuration::ZERO { return i8::MIN; } 31 - (self.duration.leading_zeros() as i8) } pub fn from_system_duration(duration: Duration) -> Self { let seconds = duration.as_secs(); let nanos = duration.subsec_nanos(); // Although having a valid interpretation, providing more // than 1 second worth of nanoseconds as input probably // indicates an error from the caller. debug_assert!(nanos < 1_000_000_000); // NTP uses 1/2^32 sec as its unit of fractional time. 
// our time is in nanoseconds, so 1/1e9 seconds let fraction = ((nanos as u64) << 32) / 1_000_000_000; // alternatively, abuse FP arithmetic to save an instruction // let fraction = (nanos as f64 * 4.294967296) as u64; let timestamp = (seconds << 32) + fraction; NtpDuration::from_bits(timestamp.to_be_bytes()) } #[cfg(test)] pub(crate) const fn from_fixed_int(duration: i64) -> NtpDuration { NtpDuration { duration } } } impl Serialize for NtpDuration { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { let seconds = self.to_seconds(); seconds.serialize(serializer) } } impl<'de> Deserialize<'de> for NtpDuration { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let seconds: f64 = Deserialize::deserialize(deserializer)?; if seconds.is_nan() || seconds.is_infinite() { return Err(serde::de::Error::invalid_value( Unexpected::Float(seconds), &"a valid number", )); } Ok(NtpDuration::from_seconds(seconds)) } } impl Add for NtpDuration { type Output = NtpDuration; fn add(self, rhs: Self) -> Self::Output { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. NtpDuration { duration: self.duration.saturating_add(rhs.duration), } } } impl AddAssign for NtpDuration { fn add_assign(&mut self, rhs: Self) { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. self.duration = self.duration.saturating_add(rhs.duration); } } impl Sub for NtpDuration { type Output = NtpDuration; fn sub(self, rhs: Self) -> Self::Output { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. 
NtpDuration { duration: self.duration.saturating_sub(rhs.duration), } } } impl SubAssign for NtpDuration { fn sub_assign(&mut self, rhs: Self) { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. self.duration = self.duration.saturating_sub(rhs.duration); } } impl Neg for NtpDuration { type Output = NtpDuration; fn neg(self) -> Self::Output { NtpDuration { duration: -self.duration, } } } macro_rules! ntp_duration_scalar_mul { ($scalar_type:ty) => { impl Mul for $scalar_type { type Output = NtpDuration; fn mul(self, rhs: NtpDuration) -> NtpDuration { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. NtpDuration { duration: rhs.duration.saturating_mul(self as i64), } } } impl Mul<$scalar_type> for NtpDuration { type Output = NtpDuration; fn mul(self, rhs: $scalar_type) -> NtpDuration { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. NtpDuration { duration: self.duration.saturating_mul(rhs as i64), } } } impl MulAssign<$scalar_type> for NtpDuration { fn mul_assign(&mut self, rhs: $scalar_type) { // For duration, saturation is safer as that ensures // addition or subtraction of two big durations never // unintentionally cancel, ensuring that filtering // can properly reject on the result. 
self.duration = self.duration.saturating_mul(rhs as i64); } } }; } ntp_duration_scalar_mul!(i8); ntp_duration_scalar_mul!(i16); ntp_duration_scalar_mul!(i32); ntp_duration_scalar_mul!(i64); ntp_duration_scalar_mul!(isize); ntp_duration_scalar_mul!(u8); ntp_duration_scalar_mul!(u16); ntp_duration_scalar_mul!(u32); // u64 and usize deliberately excluded as they can result in overflows macro_rules! ntp_duration_scalar_div { ($scalar_type:ty) => { impl Div<$scalar_type> for NtpDuration { type Output = NtpDuration; fn div(self, rhs: $scalar_type) -> NtpDuration { // No overflow risks for division NtpDuration { duration: self.duration / (rhs as i64), } } } impl DivAssign<$scalar_type> for NtpDuration { fn div_assign(&mut self, rhs: $scalar_type) { // No overflow risks for division self.duration /= (rhs as i64); } } }; } ntp_duration_scalar_div!(i8); ntp_duration_scalar_div!(i16); ntp_duration_scalar_div!(i32); ntp_duration_scalar_div!(i64); ntp_duration_scalar_div!(isize); ntp_duration_scalar_div!(u8); ntp_duration_scalar_div!(u16); ntp_duration_scalar_div!(u32); // u64 and usize deliberately excluded as they can result in overflows /// Stores when we will next exchange packages with a remote server. 
// // The value is in seconds stored in log2 format: // // - a value of 4 means 2^4 = 16 seconds // - a value of 17 is 2^17 = ~36h #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct PollInterval(i8); #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct PollIntervalLimits { pub min: PollInterval, pub max: PollInterval, } // here we follow the spec (the code skeleton and ntpd repository use different values) // with the exception that we have lowered the MAX value, which is needed because // we don't support bursting, and hence using a larger poll interval gives issues // with the responsiveness of the client to environmental changes impl Default for PollIntervalLimits { fn default() -> Self { Self { min: PollInterval(4), max: PollInterval(10), } } } impl std::fmt::Debug for PollInterval { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "PollInterval({} s)", 2.0_f64.powf(self.0 as _)) } } impl PollInterval { pub const NEVER: PollInterval = PollInterval(i8::MAX); #[cfg(test)] pub fn test_new(value: i8) -> Self { Self(value) } pub fn from_byte(value: u8) -> Self { Self(value as i8) } pub fn as_byte(self) -> u8 { self.0 as u8 } #[must_use] pub fn inc(self, limits: PollIntervalLimits) -> Self { Self(self.0 + 1).min(limits.max) } #[must_use] pub fn force_inc(self) -> Self { Self(self.0.saturating_add(1)) } #[must_use] pub fn dec(self, limits: PollIntervalLimits) -> Self { Self(self.0 - 1).max(limits.min) } pub const fn as_log(self) -> i8 { self.0 } pub const fn as_duration(self) -> NtpDuration { NtpDuration { duration: 1 << (self.0 + 32), } } pub const fn as_system_duration(self) -> Duration { Duration::from_secs(1 << self.0) } } impl Default for PollInterval { fn default() -> Self { Self(4) } } /// Frequency tolerance PHI (unit: seconds per second) #[derive(Debug, Clone, Copy)] pub struct FrequencyTolerance { ppm: u32, } impl<'de> Deserialize<'de> for FrequencyTolerance { fn 
deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let val: u32 = Deserialize::deserialize(deserializer)?; Ok(FrequencyTolerance { ppm: val }) } } impl FrequencyTolerance { pub const fn ppm(ppm: u32) -> Self { Self { ppm } } } impl Mul for NtpDuration { type Output = NtpDuration; fn mul(self, rhs: FrequencyTolerance) -> Self::Output { (self * rhs.ppm) / 1_000_000 } } #[cfg(feature = "__internal-fuzz")] pub fn fuzz_duration_from_seconds(v: f64) { if v.is_finite() { let duration = NtpDuration::from_seconds(v); assert!(v.signum() as i64 * duration.duration.signum() >= 0); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_timestamp_sub() { let a = NtpTimestamp::from_fixed_int(5); let b = NtpTimestamp::from_fixed_int(3); assert_eq!(a - b, NtpDuration::from_fixed_int(2)); assert_eq!(b - a, NtpDuration::from_fixed_int(-2)); } #[test] fn test_timestamp_era_change() { let mut a = NtpTimestamp::from_fixed_int(1); let b = NtpTimestamp::from_fixed_int(0xFFFFFFFFFFFFFFFF); assert_eq!(a - b, NtpDuration::from_fixed_int(2)); assert_eq!(b - a, NtpDuration::from_fixed_int(-2)); let c = NtpDuration::from_fixed_int(2); let d = NtpDuration::from_fixed_int(-2); assert_eq!(b + c, a); assert_eq!(b - d, a); assert_eq!(a - c, b); assert_eq!(a + d, b); a -= c; assert_eq!(a, b); a += c; assert_eq!(a, NtpTimestamp::from_fixed_int(1)); } #[test] fn test_timestamp_from_seconds_nanos() { assert_eq!( NtpTimestamp::from_seconds_nanos_since_ntp_era(0, 500_000_000), NtpTimestamp::from_fixed_int(0x80000000) ); assert_eq!( NtpTimestamp::from_seconds_nanos_since_ntp_era(1, 0), NtpTimestamp::from_fixed_int(1 << 32) ); } #[test] fn test_timestamp_duration_math() { let mut a = NtpTimestamp::from_fixed_int(5); let b = NtpDuration::from_fixed_int(2); assert_eq!(a + b, NtpTimestamp::from_fixed_int(7)); assert_eq!(a - b, NtpTimestamp::from_fixed_int(3)); a += b; assert_eq!(a, NtpTimestamp::from_fixed_int(7)); a -= b; assert_eq!(a, NtpTimestamp::from_fixed_int(5)); } #[test] 
fn test_duration_as_seconds_nanos() { assert_eq!( NtpDuration::from_fixed_int(0x80000000).as_seconds_nanos(), (0, 500_000_000) ); assert_eq!( NtpDuration::from_fixed_int(1 << 33).as_seconds_nanos(), (2, 0) ); } #[test] fn test_duration_math() { let mut a = NtpDuration::from_fixed_int(5); let b = NtpDuration::from_fixed_int(2); assert_eq!(a + b, NtpDuration::from_fixed_int(7)); assert_eq!(a - b, NtpDuration::from_fixed_int(3)); a += b; assert_eq!(a, NtpDuration::from_fixed_int(7)); a -= b; assert_eq!(a, NtpDuration::from_fixed_int(5)); } macro_rules! ntp_duration_scaling_test { ($name:ident, $scalar_type:ty) => { #[test] fn $name() { let mut a = NtpDuration::from_fixed_int(31); let b: $scalar_type = 2; assert_eq!(a * b, NtpDuration::from_fixed_int(62)); assert_eq!(b * a, NtpDuration::from_fixed_int(62)); assert_eq!(a / b, NtpDuration::from_fixed_int(15)); a /= b; assert_eq!(a, NtpDuration::from_fixed_int(15)); a *= b; assert_eq!(a, NtpDuration::from_fixed_int(30)); } }; } ntp_duration_scaling_test!(ntp_duration_scaling_i8, i8); ntp_duration_scaling_test!(ntp_duration_scaling_i16, i16); ntp_duration_scaling_test!(ntp_duration_scaling_i32, i32); ntp_duration_scaling_test!(ntp_duration_scaling_i64, i64); ntp_duration_scaling_test!(ntp_duration_scaling_isize, isize); ntp_duration_scaling_test!(ntp_duration_scaling_u8, u8); ntp_duration_scaling_test!(ntp_duration_scaling_u16, u16); ntp_duration_scaling_test!(ntp_duration_scaling_u32, u32); macro_rules! 
assert_eq_epsilon { ($a:expr, $b:expr, $epsilon:expr) => { assert!( ($a - $b).abs() < $epsilon, "Left not nearly equal to right:\nLeft: {}\nRight: {}\n", $a, $b ); }; } #[test] fn duration_seconds_roundtrip() { assert_eq_epsilon!(NtpDuration::from_seconds(0.0).to_seconds(), 0.0, 1e-9); assert_eq_epsilon!(NtpDuration::from_seconds(1.0).to_seconds(), 1.0, 1e-9); assert_eq_epsilon!(NtpDuration::from_seconds(1.5).to_seconds(), 1.5, 1e-9); assert_eq_epsilon!(NtpDuration::from_seconds(2.0).to_seconds(), 2.0, 1e-9); } #[test] fn duration_from_exponent() { assert_eq_epsilon!(NtpDuration::from_exponent(0).to_seconds(), 1.0, 1e-9); assert_eq_epsilon!(NtpDuration::from_exponent(1).to_seconds(), 2.0, 1e-9); assert_eq_epsilon!( NtpDuration::from_exponent(17).to_seconds(), 2.0f64.powi(17), 1e-4 // Less precision due to larger exponent ); assert_eq_epsilon!(NtpDuration::from_exponent(-1).to_seconds(), 0.5, 1e-9); assert_eq_epsilon!( NtpDuration::from_exponent(-5).to_seconds(), 1.0 / 2.0f64.powi(5), 1e-9 ); } #[test] fn duration_from_exponent_reasonable() { for i in -32..=127 { assert!(NtpDuration::from_exponent(i) > NtpDuration::from_fixed_int(0)); } for i in -128..-32 { NtpDuration::from_exponent(i); // should not crash } } #[test] fn duration_from_float_seconds_saturates() { assert_eq!( NtpDuration::from_seconds(1e40), NtpDuration::from_fixed_int(i64::MAX) ); assert_eq!( NtpDuration::from_seconds(-1e40), NtpDuration::from_fixed_int(i64::MIN) ); } #[test] fn poll_interval_clamps() { let mut interval = PollInterval::default(); let limits = PollIntervalLimits::default(); for _ in 0..100 { interval = interval.inc(limits); assert!(interval <= limits.max); } for _ in 0..100 { interval = interval.dec(limits); assert!(interval >= limits.min); } for _ in 0..100 { interval = interval.inc(limits); assert!(interval <= limits.max); } } #[test] fn poll_interval_to_duration() { assert_eq!( PollInterval(4).as_duration(), NtpDuration::from_fixed_int(16 << 32) ); assert_eq!( 
PollInterval(5).as_duration(), NtpDuration::from_fixed_int(32 << 32) ); let mut interval = PollInterval::default(); for _ in 0..100 { assert_eq!( interval.as_duration().as_seconds_nanos().0, interval.as_system_duration().as_secs() as i32 ); interval = interval.inc(PollIntervalLimits::default()); } for _ in 0..100 { assert_eq!( interval.as_duration().as_seconds_nanos().0, interval.as_system_duration().as_secs() as i32 ); interval = interval.dec(PollIntervalLimits::default()); } } #[test] fn frequency_tolerance() { assert_eq!( NtpDuration::from_seconds(1.0), NtpDuration::from_seconds(1.0) * FrequencyTolerance::ppm(1_000_000), ); } #[test] #[cfg(feature = "ntpv5")] fn time32() { type D = NtpDuration; assert_eq!(D::from_bits_time32([0, 0, 0, 0]), D::ZERO); assert_eq!(D::from_bits_time32([0x10, 0, 0, 0]), D::from_seconds(1.0)); assert_eq!(D::from_bits_time32([0, 0, 0, 1]).as_seconds_nanos(), (0, 3)); assert_eq!( D::from_bits_time32([0, 0, 0, 10]).as_seconds_nanos(), (0, 37) ); assert_eq!(D::from_seconds(16.0).to_bits_time32(), [0xFF; 4]); assert_eq!(D { duration: 0xF }.to_bits_time32(), [0; 4]); assert_eq!(D { duration: 0x1F }.to_bits_time32(), [0, 0, 0, 1]); for i in 0..u8::MAX { let mut bits = [i, i, i, i]; for (idx, b) in bits.iter_mut().enumerate() { *b = b.wrapping_add(idx as u8); } let d = D::from_bits_time32(bits); let out_bits = d.to_bits_time32(); assert_eq!(bits, out_bits); } } } ntp-proto-1.4.0/src/tls_utils.rs000064400000000000000000000257231046102023000147600ustar 00000000000000#[cfg(feature = "rustls23")] mod rustls23_shim { /// The intent of this ClientCertVerifier is that it accepts any connections that are either /// a.) not presenting a client certificate /// b.) are presenting a well-formed, but otherwise not checked (against a trust root) client certificate /// /// This is because RusTLS apparently doesn't accept every kind of self-signed certificate. 
/// /// The only goal of this ClientCertVerifier is to achieve that, if a client presents a TLS certificate, /// this certificate shows up in the .peer_certificates() for that connection. #[cfg(feature = "nts-pool")] #[derive(Debug)] pub struct AllowAnyAnonymousOrCertificateBearingClient { supported_algs: WebPkiSupportedAlgorithms, } #[cfg(feature = "nts-pool")] use rustls23::{ crypto::{CryptoProvider, WebPkiSupportedAlgorithms}, pki_types::CertificateDer, server::danger::ClientCertVerified, }; #[cfg(feature = "nts-pool")] impl AllowAnyAnonymousOrCertificateBearingClient { pub fn new(provider: &CryptoProvider) -> Self { AllowAnyAnonymousOrCertificateBearingClient { supported_algs: provider.signature_verification_algorithms, } } } #[cfg(feature = "nts-pool")] impl rustls23::server::danger::ClientCertVerifier for AllowAnyAnonymousOrCertificateBearingClient { fn verify_client_cert( &self, _end_entity: &CertificateDer, _intermediates: &[CertificateDer], _now: rustls23::pki_types::UnixTime, ) -> Result { Ok(ClientCertVerified::assertion()) } fn client_auth_mandatory(&self) -> bool { false } fn root_hint_subjects(&self) -> &[rustls23::DistinguishedName] { &[] } fn verify_tls12_signature( &self, message: &[u8], cert: &rustls23::pki_types::CertificateDer<'_>, dss: &rustls23::DigitallySignedStruct, ) -> Result { rustls23::crypto::verify_tls12_signature(message, cert, dss, &self.supported_algs) } fn verify_tls13_signature( &self, message: &[u8], cert: &rustls23::pki_types::CertificateDer<'_>, dss: &rustls23::DigitallySignedStruct, ) -> Result { rustls23::crypto::verify_tls13_signature(message, cert, dss, &self.supported_algs) } fn supported_verify_schemes(&self) -> Vec { self.supported_algs.supported_schemes() } } pub use rustls23::pki_types::InvalidDnsNameError; pub use rustls23::pki_types::ServerName; pub use rustls23::server::NoClientAuth; pub use rustls23::version::TLS13; pub use rustls23::ClientConfig; pub use rustls23::ClientConnection; pub use 
rustls23::ConnectionCommon; pub use rustls23::Error; pub use rustls23::RootCertStore; pub use rustls23::ServerConfig; pub use rustls23::ServerConnection; pub type Certificate = rustls23::pki_types::CertificateDer<'static>; pub type PrivateKey = rustls23::pki_types::PrivateKeyDer<'static>; pub mod pemfile { pub use rustls_native_certs7::load_native_certs; pub use rustls_pemfile2::certs; pub use rustls_pemfile2::pkcs8_private_keys; pub use rustls_pemfile2::private_key; pub fn rootstore_ref_shim(cert: &super::Certificate) -> super::Certificate { cert.clone() } } pub trait CloneKeyShim {} pub fn client_config_builder( ) -> rustls23::ConfigBuilder { ClientConfig::builder() } pub fn client_config_builder_with_protocol_versions( versions: &[&'static rustls23::SupportedProtocolVersion], ) -> rustls23::ConfigBuilder { ClientConfig::builder_with_protocol_versions(versions) } pub fn server_config_builder( ) -> rustls23::ConfigBuilder { ServerConfig::builder() } pub fn server_config_builder_with_protocol_versions( versions: &[&'static rustls23::SupportedProtocolVersion], ) -> rustls23::ConfigBuilder { ServerConfig::builder_with_protocol_versions(versions) } } #[cfg(feature = "rustls22")] mod rustls22_shim { pub use rustls22::server::NoClientAuth; pub use rustls22::version::TLS13; pub use rustls22::ClientConfig; pub use rustls22::ClientConnection; pub use rustls22::ConnectionCommon; pub use rustls22::Error; pub use rustls22::RootCertStore; pub use rustls22::ServerConfig; pub use rustls22::ServerConnection; pub use rustls_pki_types::InvalidDnsNameError; pub use rustls_pki_types::ServerName; pub type Certificate = rustls_pki_types::CertificateDer<'static>; pub type PrivateKey = rustls_pki_types::PrivateKeyDer<'static>; pub mod pemfile { pub use rustls_native_certs7::load_native_certs; pub use rustls_pemfile2::certs; pub use rustls_pemfile2::pkcs8_private_keys; pub use rustls_pemfile2::private_key; pub fn rootstore_ref_shim(cert: &super::Certificate) -> super::Certificate { 
cert.clone() } } pub trait CloneKeyShim {} pub fn client_config_builder( ) -> rustls22::ConfigBuilder { ClientConfig::builder() } pub fn client_config_builder_with_protocol_versions( versions: &[&'static rustls22::SupportedProtocolVersion], ) -> rustls22::ConfigBuilder { ClientConfig::builder_with_protocol_versions(versions) } pub fn server_config_builder( ) -> rustls22::ConfigBuilder { ServerConfig::builder() } pub fn server_config_builder_with_protocol_versions( versions: &[&'static rustls22::SupportedProtocolVersion], ) -> rustls22::ConfigBuilder { ServerConfig::builder_with_protocol_versions(versions) } } #[cfg(feature = "rustls21")] mod rustls21_shim { pub use rustls21::client::InvalidDnsNameError; pub use rustls21::client::ServerName; pub use rustls21::server::NoClientAuth; pub use rustls21::version::TLS13; pub use rustls21::Certificate; pub use rustls21::ClientConfig; pub use rustls21::ClientConnection; pub use rustls21::ConnectionCommon; pub use rustls21::Error; pub use rustls21::PrivateKey; pub use rustls21::RootCertStore; pub use rustls21::ServerConfig; pub use rustls21::ServerConnection; pub fn client_config_builder( ) -> rustls21::ConfigBuilder { ClientConfig::builder().with_safe_defaults() } pub fn server_config_builder( ) -> rustls21::ConfigBuilder { ServerConfig::builder().with_safe_defaults() } pub fn client_config_builder_with_protocol_versions( versions: &[&'static rustls21::SupportedProtocolVersion], ) -> rustls21::ConfigBuilder { // Expect is ok here as this should never fail (not user controlled) ClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(versions) .expect("Could not set protocol versions") } pub fn server_config_builder_with_protocol_versions( versions: &[&'static rustls21::SupportedProtocolVersion], ) -> rustls21::ConfigBuilder { // Expect is ok here as this should never fail (not user controlled) ServerConfig::builder() .with_safe_default_cipher_suites() 
.with_safe_default_kx_groups() .with_protocol_versions(versions) .expect("Could not set protocol versions") } pub trait CloneKeyShim { fn clone_key(&self) -> Self; } impl CloneKeyShim for PrivateKey { fn clone_key(&self) -> Self { self.clone() } } pub mod pemfile { enum Either { L(T), R(U), } impl Iterator for Either where T: Iterator, U: Iterator, { type Item = V; fn next(&mut self) -> Option { match self { Self::L(l) => l.next(), Self::R(r) => r.next(), } } } pub fn certs( rd: &mut dyn std::io::BufRead, ) -> impl Iterator> { match rustls_pemfile1::certs(rd) { Ok(v) => Either::L(v.into_iter().map(super::Certificate).map(Ok)), Err(e) => Either::R(core::iter::once(Err(e))), } } pub fn pkcs8_private_keys( rd: &mut dyn std::io::BufRead, ) -> impl Iterator> { match rustls_pemfile1::pkcs8_private_keys(rd) { Ok(v) => Either::L(v.into_iter().map(super::PrivateKey).map(Ok)), Err(e) => Either::R(core::iter::once(Err(e))), } } pub fn private_key( rd: &mut dyn std::io::BufRead, ) -> Result, std::io::Error> { for item in std::iter::from_fn(|| rustls_pemfile1::read_one(rd).transpose()) { match item { Ok(rustls_pemfile1::Item::RSAKey(key)) | Ok(rustls_pemfile1::Item::PKCS8Key(key)) | Ok(rustls_pemfile1::Item::ECKey(key)) => { return Ok(Some(super::PrivateKey(key))) } Err(e) => return Err(e), _ => {} } } Ok(None) } pub fn load_native_certs() -> Result, std::io::Error> { Ok(rustls_native_certs6::load_native_certs()? 
.into_iter() .map(|v| super::Certificate(v.0)) .collect()) } pub fn rootstore_ref_shim(cert: &super::Certificate) -> &super::Certificate { cert } } } #[cfg(feature = "rustls23")] pub use rustls23_shim::*; #[cfg(all(feature = "rustls22", not(feature = "rustls23")))] pub use rustls22_shim::*; #[cfg(all( feature = "rustls21", not(any(feature = "rustls23", feature = "rustls22")) ))] pub use rustls21_shim::*; ntp-proto-1.4.0/test-keys/ec_key.pem000064400000000000000000000003431046102023000154620ustar 00000000000000-----BEGIN EC PRIVATE KEY----- MHcCAQEEIB+uHkwPd9WSCTR9m1ITVFwL8UPGaKWnreDdtMBsk8c7oAoGCCqGSM49 AwEHoUQDQgAEW9lR99aS5JMx8ZI5FsJPLOhfSggg+vngirYItXGB8F2y8CblgQfw PTYuxatX/a49ea2ENluguEDKcDaL2+6iHw== -----END EC PRIVATE KEY----- ntp-proto-1.4.0/test-keys/end.fullchain.pem000064400000000000000000000051141046102023000167360ustar 00000000000000-----BEGIN CERTIFICATE----- MIIDsDCCApigAwIBAgIUeLa0dWVwCQr2akxP7Zrw3RDLAF8wDQYJKoZIhvcNAQEL BQAwVzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEQMA4GA1UEAwwHVGVzdCBDQTAgFw0y MzAxMjAwOTQ3MzhaGA80NzYwMTIxNzA5NDczOFowWTELMAkGA1UEBhMCQVUxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAsZmqWOnowHpN+nsLk0gqvsmZWPuwMBrnJrlDihyUmMXmf28CDXJL /aYDC/3a4EKIAz0uUnH6tCTK6jbmJhouGKnRpo9nS3ee3n0AENgPzcCaBgAoNYMM IT7en4a8olRviwKrMCX91fIorbuaUb0VFQ7BgfJhEvXVJinXcxkdTZJ4fztGE5Cy iqDGuJ1+EEABmDBrWCOr/gpF5HpAl9m6vbdhEWg3UvM02PAcBAn3z0Eno7O11vEK WDjZu6XWRLznY+cFEI0LvF8gLfilC15QgJdtb4+bh5jJsLHCCobBgARBdk50yhbj eQBwDOVMm2OJl5/BUl2OYbD/nK9dSUbT6wIDAQABo3AwbjAfBgNVHSMEGDAWgBR3 Va6VsK3920NVj7trkQittchtpTAJBgNVHRMEAjAAMAsGA1UdDwQEAwIE8DAUBgNV HREEDTALgglsb2NhbGhvc3QwHQYDVR0OBBYEFGWx6Z2EPXqL6pb+65eD/Dl4do/l MA0GCSqGSIb3DQEBCwUAA4IBAQCUEyM1M6EfDOkv9MHL3q1U72JvrKFx6lPDMTWd n/tWTILyQejETXWLmCxhle4JwIC+EQfAS6o/EFumgGvKp2xKuM4lS0ccaIBCCkjf bKkB5WxLppHPznxpv33f1DcU4WRNewBDra3FqJSGYGVjuHAPu4dZbPmU2bqhA22g 
0tdwFZyDC3b32CY40m8gbR7VvcymMufyOeLWImR6GVCm5N6SUVpYEPbL2PFHkvnq Z6SALFAeH/Um/uPsWemBPfxMXjq5dDKWaaigiC4wxdfpPqAfORrYbRWcCOoYQv2U 9BO4LkL8OYBtG0IFuWU9eKpchFZgXbDjeoHFqBHz40yQ2yhk -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIDkTCCAnmgAwIBAgIUSJ4RLbU532cpXBrIPM0dgLjFoRowDQYJKoZIhvcNAQEL BQAwVzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEQMA4GA1UEAwwHVGVzdCBDQTAgFw0y MzAxMjAwOTQzMzdaGA80NzYwMTIxNzA5NDMzN1owVzELMAkGA1UEBhMCQVUxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 IEx0ZDEQMA4GA1UEAwwHVGVzdCBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC AQoCggEBALzqkvECUcCFlg4cldjWKD1/X2e+FPrMBesmUCDExAtGYIjJy2YFovFL 20eNFa4K3QK61MfsmnbhC97Q3Nrm2tFiDXdM1XjnnbGk/GKtTH/cS/v5FQt+8kbj YPKkxfwo02Nhgf8r0Ttsg439tuT+qpw3CymVzEZDllhYFL0EDq5JHAx9Sz5RiXm4 1+4E0ahWpWbTagiG/Ldgk/sXCTZvxsCw7gbULKSVEbaN+cW+pXqkD3YSvrnYCPtk /8OK7llBCtDC9puDIntrd5z6tIxCbj3jnfb9Ek/Pb/AmK04NF5OPw+eUgEwteSde lNInFgNnlEPikNrkDAmBydLuEX7yCO8CAwEAAaNTMFEwHQYDVR0OBBYEFHdVrpWw rf3bQ1WPu2uRCK21yG2lMB8GA1UdIwQYMBaAFHdVrpWwrf3bQ1WPu2uRCK21yG2l MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAFHWNTDdy9BbCoX5 RRvP0S4V0g8HcaWohYuI7uNsDwW/xvOsJ7u+1rjv/Hx3lOCtnEHCAS5peJQenf6Y uQSXbt2BVX7U01TzGKC9y47yxgovpdKJDiJodWSGs6sZP/4x3M5AbGmhmdfSBFAZ /fchAzZPWd5FdYBEaT5J1nnXDCe3G5Aa43zvZzN8i/YCJ376yB7Vt6qUW8L70o9X ++snpnom2bvIKwkO4Z9jBY6njrpYjE212N1OY+eYRLknOdJlFuy6kGO2ipEoPKt/ +vur95a6fTo8WiU2kYQc649XiPNW53v1epWNFJCRoOFietIVrKANWuqQB7xVYuIG Yo0A3Sw= -----END CERTIFICATE----- ntp-proto-1.4.0/test-keys/end.key000064400000000000000000000032541046102023000150040ustar 00000000000000-----BEGIN PRIVATE KEY----- MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCxmapY6ejAek36 ewuTSCq+yZlY+7AwGucmuUOKHJSYxeZ/bwINckv9pgML/drgQogDPS5Scfq0JMrq NuYmGi4YqdGmj2dLd57efQAQ2A/NwJoGACg1gwwhPt6fhryiVG+LAqswJf3V8iit u5pRvRUVDsGB8mES9dUmKddzGR1Nknh/O0YTkLKKoMa4nX4QQAGYMGtYI6v+CkXk ekCX2bq9t2ERaDdS8zTY8BwECffPQSejs7XW8QpYONm7pdZEvOdj5wUQjQu8XyAt 
+KULXlCAl21vj5uHmMmwscIKhsGABEF2TnTKFuN5AHAM5UybY4mXn8FSXY5hsP+c r11JRtPrAgMBAAECggEASD/QKe22bx8SO/T0h40TPpw60xVI3rkDEiDKFhR8aw4P MAZT2m6F9YEkuisicJsAQ/kOsCGIMOLK3a9Jv3RlDkl/bXfnOK9IJRDLBw8ulrBk uE42DVbrh1bRMCqa8JrS6cVDKQo7kl66J7srE1eNjQx8skWNMi5p8OWSrVMpNZXT 1jGXFHpHkZz4i5TSeovSKMPRHEpXB3QQRIV7izysxtyQiBgHyCgmxwva5vjRZGQN BZyZzDcxpRTVR8B8JvUoxKpNSFatPD2+d0JT/g8a26uLymJbwI+K6bDzTY900xR6 ufH0pcw0nCgryAHp1Zt8+1hC43nblnnOxJwAJ6ad0QKBgQDgCCvcrlNO5AxsgkmT CRaSSIoaUF8+QZIhXXUL6wsdGKba5vND8Rjzdow4Ewy4vO/KGH7vvj9sv1mtV2bu kkHk5y6uA2FjUnbGHBxnHzv/mbNa3+q8sYFVF/QqTektqRXdLo98ButOZhTQX6I5 6EIWHvLDeOlgczDcNPBJ5ZdziQKBgQDK8Vt3goNSCqWTsb+4OFS+KsS6Ncakxu9E rdWM2m8TSx5z71Jzp3Lj28rVmPrI5bFhwLFxWcHfipRP+xE1uo2Ga8CqXy6UuvVp 1hbvzltwecAPeEzPD5hs7pWK7DmKLL3iUnl537CREP4UVwKCrHkt4xLZCaiTv5wn m2D36WVK0wKBgQCO0Ua8+TjMmx68cdZbcLi96pZ3rfL5qi1xLbX3MhC0rMl51S8R ifphAprjCGncvz2SNUl+pmaied2+XnCU+BIfzaz5a9hCzAhBxRvqNYQ3LpGjBgoL 3pDXYVzbNy3GWPtCNHNuGq8ZHIR6Te0KQ2EV3wbdzA/i16w3RVxFj6KcGQKBgQCG 7PjW+Bq/DP0QuPiyTiFpXZ31/5LWMr0ZeEmmoAOBXEwe4Fp9MjMccyDj6hWyQ6Qv TaGrrvVK3iPFGTNT+XfmivVJUIbzs2k+uGv/e78nhIrAvkay07ePlQAvoOaQizaj phnFgYcuq5GBjGfK4UifzXzWd6lwsc/sNU2/BZmmqQKBgQCD2PevU6thArwydWc3 ALP+1VvS37n6XOp1ywqECNhUI2rJ8ShbIlF/kbB9ylPnei14cnn4oP04HP9ZFXkR yx+OWiTAepvQUFjIXDst+CIUVmeFhE5RcHpOyHioDA6eQ88fzE6sTOC0SvZFF04u +O3oaJs+4jqY2DpsCuqYZjBEug== -----END PRIVATE KEY----- ntp-proto-1.4.0/test-keys/end.pem000064400000000000000000000024721046102023000147760ustar 00000000000000-----BEGIN CERTIFICATE----- MIIDsDCCApigAwIBAgIUeLa0dWVwCQr2akxP7Zrw3RDLAF8wDQYJKoZIhvcNAQEL BQAwVzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEQMA4GA1UEAwwHVGVzdCBDQTAgFw0y MzAxMjAwOTQ3MzhaGA80NzYwMTIxNzA5NDczOFowWTELMAkGA1UEBhMCQVUxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAsZmqWOnowHpN+nsLk0gqvsmZWPuwMBrnJrlDihyUmMXmf28CDXJL 
/aYDC/3a4EKIAz0uUnH6tCTK6jbmJhouGKnRpo9nS3ee3n0AENgPzcCaBgAoNYMM IT7en4a8olRviwKrMCX91fIorbuaUb0VFQ7BgfJhEvXVJinXcxkdTZJ4fztGE5Cy iqDGuJ1+EEABmDBrWCOr/gpF5HpAl9m6vbdhEWg3UvM02PAcBAn3z0Eno7O11vEK WDjZu6XWRLznY+cFEI0LvF8gLfilC15QgJdtb4+bh5jJsLHCCobBgARBdk50yhbj eQBwDOVMm2OJl5/BUl2OYbD/nK9dSUbT6wIDAQABo3AwbjAfBgNVHSMEGDAWgBR3 Va6VsK3920NVj7trkQittchtpTAJBgNVHRMEAjAAMAsGA1UdDwQEAwIE8DAUBgNV HREEDTALgglsb2NhbGhvc3QwHQYDVR0OBBYEFGWx6Z2EPXqL6pb+65eD/Dl4do/l MA0GCSqGSIb3DQEBCwUAA4IBAQCUEyM1M6EfDOkv9MHL3q1U72JvrKFx6lPDMTWd n/tWTILyQejETXWLmCxhle4JwIC+EQfAS6o/EFumgGvKp2xKuM4lS0ccaIBCCkjf bKkB5WxLppHPznxpv33f1DcU4WRNewBDra3FqJSGYGVjuHAPu4dZbPmU2bqhA22g 0tdwFZyDC3b32CY40m8gbR7VvcymMufyOeLWImR6GVCm5N6SUVpYEPbL2PFHkvnq Z6SALFAeH/Um/uPsWemBPfxMXjq5dDKWaaigiC4wxdfpPqAfORrYbRWcCOoYQv2U 9BO4LkL8OYBtG0IFuWU9eKpchFZgXbDjeoHFqBHz40yQ2yhk -----END CERTIFICATE----- ntp-proto-1.4.0/test-keys/gen-cert.sh000075500000000000000000000026771046102023000155770ustar 00000000000000#! /bin/sh # This script generates a private key/certificate for a server, and signs it with the provided CA key # based on https://docs.ntpd-rs.pendulum-project.org/development/ca/ # Because this script generate keys without passwords set, they should only be used in a development setting. if [ -z "$1" ]; then echo "usage: gen-cert.sh name-of-server [ca-name]" echo echo "This will generate a name-of-server.key, name-of-server.pem and name-of-server.chain.pem file" echo "containing the private key, public certificate, and full certificate chain (respectively)" echo echo "The second argument denotes the name of the CA be used (found in the files ca-name.key and ca-name.pem)" echo "If this is omitted, the name 'testca' will be used." 
exit fi NAME="${1:-ntpd-rs.test}" CA="${2:-testca}" # generate a key openssl genrsa -out "$NAME".key 2048 # generate a certificate signing request openssl req -batch -new -key "$NAME".key -out "$NAME".csr # generate an ext file cat >> "$NAME".ext < "$NAME".chain.pem # cleanup rm "$NAME".csr ntp-proto-1.4.0/test-keys/pkcs8_key.pem000064400000000000000000000003611046102023000161230ustar 00000000000000-----BEGIN PRIVATE KEY----- MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgDi/ejEuJATtM3Y1u zzdOIYXvP0FoKUDD2b0dJD+A1PChRANCAAQVage65def6DD2jTzZ7hu+sNaw9zeQ SbSlApUWht98YHRhVM/hyN3lJ0or0qVyjcW49uSzHyuDm2BtwlcLQjOh -----END PRIVATE KEY----- ntp-proto-1.4.0/test-keys/rsa_key.pem000064400000000000000000000015671046102023000156710ustar 00000000000000-----BEGIN RSA PRIVATE KEY----- MIICXAIBAAKBgQC1Dt8tFmGS76ciuNXvk/QRrV8wCcArWxvl7Ku0aSQXgcFBAav6 P5RD8b+dC9DihSu/r+6OOfjsAZ6oKCq3OTUfmoUhLpoBomxPczJgLyyLD+nQkp5q B1Q3WB6ACL/HJRRjJEIn7lc5u1FVBGbiCAHKMiaP4BDSym8oqimKC6uiaQIDAQAB AoGAGKmY7sxQqDIqwwkIYyT1Jv9FqwZ4/a7gYvZVATMdLnKHP3KZ2XGVoZepcRvt 7R0Us3ykcw0kgglKcj9eaizJtnSuoDPPwt53mDypPN2sU3hZgyk2tPgr49DB3MIp fjoqw4RL/p60ksgGXbDEqBuXqOtH5i61khWlMj+BWL9VDq0CQQDaELWPQGjgs+7X /QyWMJwOF4FXE4jecH/CcPVDB9K1ukllyC1HqTNe44Sp2bIDuSXXWb8yEixrEWBE ci2CSSjXAkEA1I4W9IzwEmAeLtL6VBip9ks52O0JKu373/Xv1F2GYdhnQaFw7IC6 1lSzcYMKGTmDuM8Cj26caldyv19Q0SPmvwJAdRHjZzS9GWWAJJTF3Rvbq/USix0B renXrRvXkFTy2n1YSjxdkstTuO2Mm2M0HquXlTWpX8hB8HkzpYtmwztjoQJAECKl LXVReCOhxu4vIJkqtc6qGoSL8J1WRH8X8KgU3nKeDAZkWx++jyyo3pIS/y01iZ71 U8wSxaPTyyFCMk4mYwJBALjg7g8yDy1Lg9GFfOZvAVzPjqD28jZh/VJsDz9IhYoG z89iHWHkllOisbOm+SeynVC8CoFXmJPc26U65GcjI18= -----END RSA PRIVATE KEY----- ntp-proto-1.4.0/test-keys/testca.key000064400000000000000000000032501046102023000155150ustar 00000000000000-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC86pLxAlHAhZYO HJXY1ig9f19nvhT6zAXrJlAgxMQLRmCIyctmBaLxS9tHjRWuCt0CutTH7Jp24Qve 0Nza5trRYg13TNV4552xpPxirUx/3Ev7+RULfvJG42DypMX8KNNjYYH/K9E7bION 
/bbk/qqcNwsplcxGQ5ZYWBS9BA6uSRwMfUs+UYl5uNfuBNGoVqVm02oIhvy3YJP7 Fwk2b8bAsO4G1CyklRG2jfnFvqV6pA92Er652Aj7ZP/Diu5ZQQrQwvabgyJ7a3ec +rSMQm494532/RJPz2/wJitODReTj8PnlIBMLXknXpTSJxYDZ5RD4pDa5AwJgcnS 7hF+8gjvAgMBAAECggEAAZrFvgbSoSHLqN7lSP7ZLtfkTwpuA7RZeIUQNQmgGW0P 3BFQZA0v8kaImiM8gdb2TC7dKJSGBKImQTW4CXmejxSX7l1H7bsYWHBgHKsYifQw q95QccSuZHJ0zYIGtcMA8e2Zk4Qa/GVzbT7+0QMb1IKuh+mRrbN9hLWsXJTTuYvf GppDVqMdDPy5NibudiZPKdpnMyDCJ/Wxl1+1PX18anifzBHw/G8ZPnLU3OKDqL2T OtEivvk9ZFDiRKKEsHksr+aLcUGhXFswk0zEQJwMj6rFwcDEExTQkMar+xaxshpf qo6AC88SDT9qEffSHHGJzTi73NIGgLNPO1aON4/pwQKBgQDUPo+ZJymo9IunaXWi HywqLLVZJSvqo2x9SrlqqYe3Yz0bROGBoHSMaGQzxiDApeOabdyg24wrU1P24jrC jPt94TWdu8bZKAkZAGOUPvdSGA/5yQkxVSMUK5zZwQxyLWfb77+B+WSvzhxI17Bt bX6od5pcdFSC5OczJ64DjLeHlQKBgQDj3NjsbLnxFu88A121kPD4AdpoMAtgrA5R AWwc7mWzKvL1RZlZCn861QMaRoUThQW4+dxTdoOoL68PXK3L8zuU3imKOBOe33lh j7B+M0gjdWnkcTag5q56qk1VA4YZ0R30LhUw44JxFHXhtuTR00CattI1pOQr6OdK By3kj4NdcwKBgQChOxko1eg+0e6Y8XMMAjQxoZ7tpmAzMYxDrZUm4rwXYsrTwUKx jyuaUd70uai90AcTlCuLAtz7OKTLIlZS3nhZytBJD5Fh+5jVpkb/IcoNUfwo20Ah erRYKT1Q6ebDgZypJfpMCSEksCUqbLc4mXojDiBz5WchvDOp15XIWog89QKBgE3c Vxtig58IETNWixzRrCVyrKjRUfH0mOfBLqosJAA2+tIouB+e4J6/ztGZqztiRvRQ HKNAafh8YrtDFfgM4x0ZVORwCPROtHFL4ikdaNcE9ewja2FLse8kZkxYaehEdpHL dV5BP39YWHeKQWIZZ4f2VJoUAAupB+9ZyKrDB0ZVAoGBALJ0KzHlAizbZyXRtfk+ ThnegTgjbTd6drMTsRlyHdL1Zet0tdx2nhn2keMQVSDep5KEwTvm+Wy41s9EmzZx RyehNaq9hMljLGR6mtr4Em5RtxtkPTwoJcOttHXQXnTgplDbePb8zQ8N084fScek 0dIjCbVBt5X7akmgHaaizIDl -----END PRIVATE KEY----- ntp-proto-1.4.0/test-keys/testca.pem000064400000000000000000000024221046102023000155060ustar 00000000000000-----BEGIN CERTIFICATE----- MIIDkTCCAnmgAwIBAgIUSJ4RLbU532cpXBrIPM0dgLjFoRowDQYJKoZIhvcNAQEL BQAwVzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEQMA4GA1UEAwwHVGVzdCBDQTAgFw0y MzAxMjAwOTQzMzdaGA80NzYwMTIxNzA5NDMzN1owVzELMAkGA1UEBhMCQVUxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 
IEx0ZDEQMA4GA1UEAwwHVGVzdCBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC AQoCggEBALzqkvECUcCFlg4cldjWKD1/X2e+FPrMBesmUCDExAtGYIjJy2YFovFL 20eNFa4K3QK61MfsmnbhC97Q3Nrm2tFiDXdM1XjnnbGk/GKtTH/cS/v5FQt+8kbj YPKkxfwo02Nhgf8r0Ttsg439tuT+qpw3CymVzEZDllhYFL0EDq5JHAx9Sz5RiXm4 1+4E0ahWpWbTagiG/Ldgk/sXCTZvxsCw7gbULKSVEbaN+cW+pXqkD3YSvrnYCPtk /8OK7llBCtDC9puDIntrd5z6tIxCbj3jnfb9Ek/Pb/AmK04NF5OPw+eUgEwteSde lNInFgNnlEPikNrkDAmBydLuEX7yCO8CAwEAAaNTMFEwHQYDVR0OBBYEFHdVrpWw rf3bQ1WPu2uRCK21yG2lMB8GA1UdIwQYMBaAFHdVrpWwrf3bQ1WPu2uRCK21yG2l MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAFHWNTDdy9BbCoX5 RRvP0S4V0g8HcaWohYuI7uNsDwW/xvOsJ7u+1rjv/Hx3lOCtnEHCAS5peJQenf6Y uQSXbt2BVX7U01TzGKC9y47yxgovpdKJDiJodWSGs6sZP/4x3M5AbGmhmdfSBFAZ /fchAzZPWd5FdYBEaT5J1nnXDCe3G5Aa43zvZzN8i/YCJ376yB7Vt6qUW8L70o9X ++snpnom2bvIKwkO4Z9jBY6njrpYjE212N1OY+eYRLknOdJlFuy6kGO2ipEoPKt/ +vur95a6fTo8WiU2kYQc649XiPNW53v1epWNFJCRoOFietIVrKANWuqQB7xVYuIG Yo0A3Sw= -----END CERTIFICATE----- ntp-proto-1.4.0/test-keys/unsafe.nts.client.toml000064400000000000000000000007461046102023000177650ustar 00000000000000[observability] # Other values include trace, debug, warn and error log-level = "info" observation-path = "/var/run/ntpd-rs/observe" # uses an unsecure certificate! 
[[source]] mode = "nts" address = "localhost:4460" certificate-authority = "ntp-proto/test-keys/testca.pem" # System parameters used in filtering and steering the clock: [synchronization] minimum-agreeing-sources = 1 single-step-panic-threshold = 10 startup-step-panic-threshold = { forward = "inf", backward = 86400 } ntp-proto-1.4.0/test-keys/unsafe.nts.server.toml000064400000000000000000000014121046102023000200040ustar 00000000000000[observability] # Other values include trace, debug, warn and error log-level = "info" observation-path = "/var/run/ntpd-rs/observe" # the server will get its time from the NTP pool [[source]] mode = "pool" address = "pool.ntp.org" count = 4 [[server]] listen = "0.0.0.0:123" # System parameters used in filtering and steering the clock: [synchronization] minimum-agreeing-sources = 1 single-step-panic-threshold = 10 startup-step-panic-threshold = { forward = 0, backward = 86400 } # to function as an NTS server, we must also provide key exchange # uses an unsecure certificate chain! [[nts-ke-server]] listen = "0.0.0.0:4460" certificate-chain-path = "ntp-proto/test-keys/end.fullchain.pem" private-key-path = "ntp-proto/test-keys/end.key" key-exchange-timeout-ms = 1000