quinn-proto-0.11.9/.cargo_vcs_info.json0000644000000001510000000000100134140ustar { "git": { "sha1": "d23e4e494f7446e21184bf58acd17a861ae73bba" }, "path_in_vcs": "quinn-proto" }quinn-proto-0.11.9/Cargo.toml0000644000000064470000000000100114300ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2021" rust-version = "1.70.0" name = "quinn-proto" version = "0.11.9" build = false autobins = false autoexamples = false autotests = false autobenches = false description = "State machine for the QUIC transport protocol" readme = false keywords = ["quic"] categories = [ "network-programming", "asynchronous", ] license = "MIT OR Apache-2.0" repository = "https://github.com/quinn-rs/quinn" [package.metadata.docs.rs] all-features = true [lib] name = "quinn_proto" path = "src/lib.rs" [dependencies.arbitrary] version = "1.0.1" features = ["derive"] optional = true [dependencies.aws-lc-rs] version = "1.9" optional = true default-features = false [dependencies.bytes] version = "1" [dependencies.rand] version = "0.8" [dependencies.ring] version = "0.17" optional = true [dependencies.rustc-hash] version = "2" [dependencies.rustls] version = "0.23.5" features = ["std"] optional = true default-features = false [dependencies.rustls-platform-verifier] version = "0.4" optional = true [dependencies.slab] version = "0.4.6" [dependencies.thiserror] version = "2.0.3" [dependencies.tinyvec] version = "1.1" features = [ "alloc", "alloc", ] [dependencies.tracing] version = "0.1.10" features = ["std"] default-features = false [dev-dependencies.assert_matches] version = "1.1" [dev-dependencies.hex-literal] version = "0.4" [dev-dependencies.lazy_static] version = "1" [dev-dependencies.rcgen] version = "0.13" [dev-dependencies.tracing-subscriber] version = "0.3.0" features = [ "env-filter", "fmt", "ansi", "time", "local-time", ] default-features = false [dev-dependencies.wasm-bindgen-test] version = "0.3.45" [features] aws-lc-rs = [ "dep:aws-lc-rs", "aws-lc-rs?/aws-lc-sys", "aws-lc-rs?/prebuilt-nasm", ] aws-lc-rs-fips = [ "aws-lc-rs", "aws-lc-rs?/fips", ] default = [ "rustls-ring", "log", ] log = ["tracing/log"] platform-verifier = ["dep:rustls-platform-verifier"] ring = ["dep:ring"] rustls = ["rustls-ring"] rustls-aws-lc-rs = [ "dep:rustls", "rustls?/aws-lc-rs", "aws-lc-rs", ] rustls-aws-lc-rs-fips = [ "rustls-aws-lc-rs", "aws-lc-rs-fips", ] rustls-log = ["rustls?/logging"] rustls-ring = [ "dep:rustls", "rustls?/ring", "ring", ] [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.getrandom] version = "0.2" features = ["js"] default-features = false [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.ring] version = "0.17" features = ["wasm32_unknown_unknown_js"] [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.rustls-pki-types] version = "1.7" features = ["web"] [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies.web-time] version = "1" [lints.rust.unexpected_cfgs] level = "warn" priority = 0 check-cfg = ["cfg(fuzzing)"] 
quinn-proto-0.11.9/Cargo.toml.orig000064400000000000000000000050411046102023000150760ustar 00000000000000[package] name = "quinn-proto" version = "0.11.9" edition.workspace = true rust-version.workspace = true license.workspace = true repository.workspace = true description = "State machine for the QUIC transport protocol" keywords.workspace = true categories.workspace = true workspace = ".." [package.metadata.docs.rs] all-features = true [features] default = ["rustls-ring", "log"] aws-lc-rs = ["dep:aws-lc-rs", "aws-lc-rs?/aws-lc-sys", "aws-lc-rs?/prebuilt-nasm"] aws-lc-rs-fips = ["aws-lc-rs", "aws-lc-rs?/fips"] # For backwards compatibility, `rustls` forwards to `rustls-ring` rustls = ["rustls-ring"] # Enable rustls with the `aws-lc-rs` crypto provider rustls-aws-lc-rs = ["dep:rustls", "rustls?/aws-lc-rs", "aws-lc-rs"] rustls-aws-lc-rs-fips = ["rustls-aws-lc-rs", "aws-lc-rs-fips"] # Enable rustls with the `ring` crypto provider rustls-ring = ["dep:rustls", "rustls?/ring", "ring"] ring = ["dep:ring"] # Enable rustls ring provider and direct ring usage # Provides `ClientConfig::with_platform_verifier()` convenience method platform-verifier = ["dep:rustls-platform-verifier"] # Configure `tracing` to log events via `log` if no `tracing` subscriber exists. log = ["tracing/log"] # Enable rustls logging rustls-log = ["rustls?/logging"] [dependencies] arbitrary = { workspace = true, optional = true } aws-lc-rs = { workspace = true, optional = true } bytes = { workspace = true } rustc-hash = { workspace = true } rand = { workspace = true } ring = { workspace = true, optional = true } rustls = { workspace = true, optional = true } rustls-platform-verifier = { workspace = true, optional = true } slab = { workspace = true } thiserror = { workspace = true } tinyvec = { workspace = true, features = ["alloc"] } tracing = { workspace = true } # Feature flags & dependencies for wasm # wasm-bindgen is assumed for a wasm*-*-unknown target [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies] ring = { workspace = true, features = ["wasm32_unknown_unknown_js"] } getrandom = { workspace = true, features = ["js"] } rustls-pki-types = { workspace = true, features = ["web"] } # only added as dependency to enforce the `web` feature for this target web-time = { workspace = true } [dev-dependencies] assert_matches = { workspace = true } hex-literal = { workspace = true } rcgen = { workspace = true } tracing-subscriber = { workspace = true } lazy_static = "1" wasm-bindgen-test = { workspace = true } [lints.rust] # https://rust-fuzz.github.io/book/cargo-fuzz/guide.html#cfgfuzzing unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } quinn-proto-0.11.9/LICENSE-APACHE000064400000000000000000000261351046102023000141420ustar 00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. quinn-proto-0.11.9/LICENSE-MIT000064400000000000000000000020501046102023000136400ustar 00000000000000Copyright (c) 2018 The quinn Developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

quinn-proto-0.11.9/src/cid_generator.rs000064400000000000000000000127061046102023000161570ustar 00000000000000
use std::hash::Hasher;

use rand::{Rng, RngCore};

use crate::shared::ConnectionId;
use crate::Duration;
use crate::MAX_CID_SIZE;

/// Generates connection IDs for incoming connections
pub trait ConnectionIdGenerator: Send + Sync {
    /// Generates a new CID
    ///
    /// Connection IDs MUST NOT contain any information that can be used by
    /// an external observer (that is, one that does not cooperate with the
    /// issuer) to correlate them with other connection IDs for the same
    /// connection. They MUST have high entropy, e.g. due to encrypted data
    /// or cryptographic-grade random data.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Quickly determine whether `cid` could have been generated by this generator
    ///
    /// False positives are permitted, but increase the cost of handling invalid packets.
    fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
        Ok(())
    }

    /// Returns the length of a CID for connections created by this generator
    fn cid_len(&self) -> usize;

    /// Returns the lifetime of generated Connection IDs
    ///
    /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
    fn cid_lifetime(&self) -> Option<Duration>;
}

/// The connection ID was not recognized by the [`ConnectionIdGenerator`]
#[derive(Debug, Copy, Clone)]
pub struct InvalidCid;

/// Generates purely random connection IDs of a specified length
///
/// Random CIDs can be smaller than those produced by [`HashedConnectionIdGenerator`], but cannot be
/// usefully [`validate`](ConnectionIdGenerator::validate)d.
#[derive(Debug, Clone, Copy)]
pub struct RandomConnectionIdGenerator {
    cid_len: usize,
    lifetime: Option<Duration>,
}

impl Default for RandomConnectionIdGenerator {
    fn default() -> Self {
        Self {
            cid_len: 8,
            lifetime: None,
        }
    }
}

impl RandomConnectionIdGenerator {
    /// Initialize Random CID generator with a fixed CID length
    ///
    /// The given length must be less than or equal to MAX_CID_SIZE.
    pub fn new(cid_len: usize) -> Self {
        debug_assert!(cid_len <= MAX_CID_SIZE);
        Self {
            cid_len,
            ..Self::default()
        }
    }

    /// Set the lifetime of CIDs created by this generator
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}

impl ConnectionIdGenerator for RandomConnectionIdGenerator {
    fn generate_cid(&mut self) -> ConnectionId {
        let mut bytes_arr = [0; MAX_CID_SIZE];
        rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
        ConnectionId::new(&bytes_arr[..self.cid_len])
    }

    /// Provide the length of dst_cid in short header packet
    fn cid_len(&self) -> usize {
        self.cid_len
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}

/// Generates 8-byte connection IDs that can be efficiently
/// [`validate`](ConnectionIdGenerator::validate)d
///
/// This generator uses a non-cryptographic hash and can therefore still be spoofed, but nonetheless
/// helps prevent Quinn from responding to non-QUIC packets at very low cost.
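///
/// For illustration only: a minimal usage sketch, assuming the crate-root re-exports of
/// `HashedConnectionIdGenerator` and the `ConnectionIdGenerator` trait (`from_key` and
/// `validate` are defined below; the key value here is arbitrary):
///
/// ```
/// use quinn_proto::{ConnectionIdGenerator, HashedConnectionIdGenerator};
///
/// // A fixed key lets `validate` recognize CIDs issued before a restart
/// let mut generator = HashedConnectionIdGenerator::from_key(0x4242_4242);
/// let cid = generator.generate_cid();
/// assert!(generator.validate(&cid).is_ok());
/// ```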
pub struct HashedConnectionIdGenerator {
    key: u64,
    lifetime: Option<Duration>,
}

impl HashedConnectionIdGenerator {
    /// Create a generator with a random key
    pub fn new() -> Self {
        Self::from_key(rand::thread_rng().gen())
    }

    /// Create a generator with a specific key
    ///
    /// Allows [`validate`](ConnectionIdGenerator::validate) to recognize a consistent set of
    /// connection IDs across restarts
    pub fn from_key(key: u64) -> Self {
        Self {
            key,
            lifetime: None,
        }
    }

    /// Set the lifetime of CIDs created by this generator
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}

impl Default for HashedConnectionIdGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl ConnectionIdGenerator for HashedConnectionIdGenerator {
    fn generate_cid(&mut self) -> ConnectionId {
        let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
        rand::thread_rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
        let mut hasher = rustc_hash::FxHasher::default();
        hasher.write_u64(self.key);
        hasher.write(&bytes_arr[..NONCE_LEN]);
        bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
        ConnectionId::new(&bytes_arr)
    }

    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
        let (nonce, signature) = cid.split_at(NONCE_LEN);
        let mut hasher = rustc_hash::FxHasher::default();
        hasher.write_u64(self.key);
        hasher.write(nonce);
        let expected = hasher.finish().to_le_bytes();
        match expected[..SIGNATURE_LEN] == signature[..] {
            true => Ok(()),
            false => Err(InvalidCid),
        }
    }

    fn cid_len(&self) -> usize {
        NONCE_LEN + SIGNATURE_LEN
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}

const NONCE_LEN: usize = 3; // Good for more than 16 million connections
const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        generator.validate(&cid).unwrap();
    }
}

quinn-proto-0.11.9/src/cid_queue.rs000064400000000000000000000232141046102023000153110ustar 00000000000000
use std::ops::Range;

use crate::{frame::NewConnectionId, ConnectionId, ResetToken};

/// DataType stored in CidQueue buffer
type CidData = (ConnectionId, Option<ResetToken>);

/// Sliding window of active Connection IDs
///
/// May contain gaps due to packet loss or reordering
#[derive(Debug)]
pub(crate) struct CidQueue {
    /// Ring buffer indexed by `self.cursor`
    buffer: [Option<CidData>; Self::LEN],
    /// Index at which circular buffer addressing is based
    cursor: usize,
    /// Sequence number of `self.buffer[cursor]`
    ///
    /// The sequence number of the active CID; must be the smallest among CIDs in `buffer`.
    offset: u64,
}

impl CidQueue {
    pub(crate) fn new(cid: ConnectionId) -> Self {
        let mut buffer = [None; Self::LEN];
        buffer[0] = Some((cid, None));
        Self {
            buffer,
            cursor: 0,
            offset: 0,
        }
    }

    /// Handle a `NEW_CONNECTION_ID` frame
    ///
    /// Returns a non-empty range of retired sequence numbers and the reset token of the new active
    /// CID iff any CIDs were retired.
    pub(crate) fn insert(
        &mut self,
        cid: NewConnectionId,
    ) -> Result<Option<(Range<u64>, ResetToken)>, InsertError> {
        // Position of new CID wrt. the current active CID
        let index = match cid.sequence.checked_sub(self.offset) {
            None => return Err(InsertError::Retired),
            Some(x) => x,
        };

        let retired_count = cid.retire_prior_to.saturating_sub(self.offset);
        if index >= Self::LEN as u64 + retired_count {
            return Err(InsertError::ExceedsLimit);
        }

        // Discard retired CIDs, if any
        for i in 0..(retired_count.min(Self::LEN as u64) as usize) {
            self.buffer[(self.cursor + i) % Self::LEN] = None;
        }
        // Record the new CID
        let index = ((self.cursor as u64 + index) % Self::LEN as u64) as usize;
        self.buffer[index] = Some((cid.id, Some(cid.reset_token)));

        if retired_count == 0 {
            return Ok(None);
        }

        // The active CID was retired. Find the first known CID with sequence number of at least
        // retire_prior_to, and inform the caller that all prior CIDs have been retired, and of
        // the new CID's reset token.
        self.cursor = ((self.cursor as u64 + retired_count) % Self::LEN as u64) as usize;
        let (i, (_, token)) = self
            .iter()
            .next()
            .expect("it is impossible to retire a CID without supplying a new one");
        self.cursor = (self.cursor + i) % Self::LEN;
        let orig_offset = self.offset;
        self.offset = cid.retire_prior_to + i as u64;
        // We don't immediately retire CIDs in the range (orig_offset +
        // Self::LEN)..self.offset. These are CIDs that we haven't yet received from a
        // NEW_CONNECTION_ID frame, since having previously received them would violate the
        // connection ID limit we specified based on Self::LEN. If we do receive such a frame
        // in the future, e.g. due to reordering, we'll retire it then. This ensures we can't be
        // made to buffer an arbitrarily large number of RETIRE_CONNECTION_ID frames.
        Ok(Some((
            orig_offset..self.offset.min(orig_offset + Self::LEN as u64),
            token.expect("non-initial CID missing reset token"),
        )))
    }

    /// Switch to next active CID if possible, return
    /// 1) the corresponding ResetToken and 2) a non-empty range preceding it to retire
    pub(crate) fn next(&mut self) -> Option<(ResetToken, Range<u64>)> {
        let (i, cid_data) = self.iter().nth(1)?;
        self.buffer[self.cursor] = None;

        let orig_offset = self.offset;
        self.offset += i as u64;
        self.cursor = (self.cursor + i) % Self::LEN;
        Some((cid_data.1.unwrap(), orig_offset..self.offset))
    }

    /// Iterate CIDs in CidQueue that are not `None`, including the active CID
    fn iter(&self) -> impl Iterator<Item = (usize, CidData)> + '_ {
        (0..Self::LEN).filter_map(move |step| {
            let index = (self.cursor + step) % Self::LEN;
            self.buffer[index].map(|cid_data| (step, cid_data))
        })
    }

    /// Replace the initial CID
    pub(crate) fn update_initial_cid(&mut self, cid: ConnectionId) {
        debug_assert_eq!(self.offset, 0);
        self.buffer[self.cursor] = Some((cid, None));
    }

    /// Return active remote CID itself
    pub(crate) fn active(&self) -> ConnectionId {
        self.buffer[self.cursor].unwrap().0
    }

    /// Return the sequence number of active remote CID
    pub(crate) fn active_seq(&self) -> u64 {
        self.offset
    }

    pub(crate) const LEN: usize = 5;
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum InsertError {
    /// CID was already retired
    Retired,
    /// Sequence number violates the leading edge of the window
    ExceedsLimit,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn cid(sequence: u64, retire_prior_to: u64) -> NewConnectionId {
        NewConnectionId {
            sequence,
            id: ConnectionId::new(&[0xAB; 8]),
            reset_token: ResetToken::from([0xCD; crate::RESET_TOKEN_SIZE]),
            retire_prior_to,
        }
    }

    fn initial_cid() -> ConnectionId {
        ConnectionId::new(&[0xFF; 8])
    }

    #[test]
    fn next_dense() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert!(q.next().is_none());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in 1..CidQueue::LEN as u64 {
            let (_, retire) = q.next().unwrap();
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire.end - retire.start, 1);
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn next_sparse() {
        let mut q = CidQueue::new(initial_cid());
        let seqs = (1..CidQueue::LEN as u64).filter(|x| x % 2 == 0);
        for i in seqs.clone() {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in seqs {
            let (_, retire) = q.next().unwrap();
            dbg!(&retire);
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire, (q.active_seq().saturating_sub(2))..q.active_seq());
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn wrap() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for _ in 1..(CidQueue::LEN as u64 - 1) {
            q.next().unwrap();
        }
        for i in CidQueue::LEN as u64..(CidQueue::LEN as u64 + 3) {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in (CidQueue::LEN as u64 - 1)..(CidQueue::LEN as u64 + 3) {
            q.next().unwrap();
            assert_eq!(q.active_seq(), i);
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn retire_dense() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        assert_eq!(q.active_seq(), 0);
        assert_eq!(q.insert(cid(4, 2)).unwrap().unwrap().0, 0..2);
        assert_eq!(q.active_seq(), 2);
        assert_eq!(q.insert(cid(4, 2)), Ok(None));
        for i in 2..(CidQueue::LEN as u64 - 1) {
            let _ = q.next().unwrap();
            assert_eq!(q.active_seq(), i + 1);
            assert_eq!(q.insert(cid(i + 1, i + 1)), Ok(None));
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn retire_sparse() {
        // Retiring CID 0 when CID 1 is not known should retire CID 1 as we move to CID 2
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(q.insert(cid(3, 1)).unwrap().unwrap().0, 0..2,);
        assert_eq!(q.active_seq(), 2);
    }

    #[test]
    fn retire_many() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(
            q.insert(cid(1_000_000, 1_000_000)).unwrap().unwrap().0,
            0..CidQueue::LEN as u64,
        );
        assert_eq!(q.active_seq(), 1_000_000);
    }

    #[test]
    fn insert_limit() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(q.insert(cid(CidQueue::LEN as u64 - 1, 0)), Ok(None));
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    #[test]
    fn insert_duplicate() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(0, 0)).unwrap();
        q.insert(cid(0, 0)).unwrap();
    }

    #[test]
    fn insert_retired() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(
            q.insert(cid(0, 0)),
            Ok(None),
            "reinserting active CID succeeds"
        );
        assert!(q.next().is_none(), "active CID isn't requeued");
        q.insert(cid(1, 0)).unwrap();
        q.next().unwrap();
        assert_eq!(
            q.insert(cid(0, 0)),
            Err(InsertError::Retired),
            "previous active CID is already retired"
        );
    }

    #[test]
    fn retire_then_insert_next() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        q.next().unwrap();
        q.insert(cid(CidQueue::LEN as u64, 0)).unwrap();
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64 + 1, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    #[test]
    fn always_valid() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert_eq!(q.active(), initial_cid());
        assert_eq!(q.active_seq(), 0);
    }
}

quinn-proto-0.11.9/src/coding.rs000064400000000000000000000054321046102023000146130ustar 00000000000000
use std::net::{Ipv4Addr, Ipv6Addr};

use bytes::{Buf, BufMut};
use thiserror::Error;

use crate::VarInt;

#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)]
#[error("unexpected end of buffer")]
buffer")] pub struct UnexpectedEnd; pub type Result = ::std::result::Result; pub trait Codec: Sized { fn decode(buf: &mut B) -> Result; fn encode(&self, buf: &mut B); } impl Codec for u8 { fn decode(buf: &mut B) -> Result { if buf.remaining() < 1 { return Err(UnexpectedEnd); } Ok(buf.get_u8()) } fn encode(&self, buf: &mut B) { buf.put_u8(*self); } } impl Codec for u16 { fn decode(buf: &mut B) -> Result { if buf.remaining() < 2 { return Err(UnexpectedEnd); } Ok(buf.get_u16()) } fn encode(&self, buf: &mut B) { buf.put_u16(*self); } } impl Codec for u32 { fn decode(buf: &mut B) -> Result { if buf.remaining() < 4 { return Err(UnexpectedEnd); } Ok(buf.get_u32()) } fn encode(&self, buf: &mut B) { buf.put_u32(*self); } } impl Codec for u64 { fn decode(buf: &mut B) -> Result { if buf.remaining() < 8 { return Err(UnexpectedEnd); } Ok(buf.get_u64()) } fn encode(&self, buf: &mut B) { buf.put_u64(*self); } } impl Codec for Ipv4Addr { fn decode(buf: &mut B) -> Result { if buf.remaining() < 4 { return Err(UnexpectedEnd); } let mut octets = [0; 4]; buf.copy_to_slice(&mut octets); Ok(octets.into()) } fn encode(&self, buf: &mut B) { buf.put_slice(&self.octets()); } } impl Codec for Ipv6Addr { fn decode(buf: &mut B) -> Result { if buf.remaining() < 16 { return Err(UnexpectedEnd); } let mut octets = [0; 16]; buf.copy_to_slice(&mut octets); Ok(octets.into()) } fn encode(&self, buf: &mut B) { buf.put_slice(&self.octets()); } } pub trait BufExt { fn get(&mut self) -> Result; fn get_var(&mut self) -> Result; } impl BufExt for T { fn get(&mut self) -> Result { U::decode(self) } fn get_var(&mut self) -> Result { Ok(VarInt::decode(self)?.into_inner()) } } pub trait BufMutExt { fn write(&mut self, x: T); fn write_var(&mut self, x: u64); } impl BufMutExt for T { fn write(&mut self, x: U) { x.encode(self); } fn write_var(&mut self, x: u64) { VarInt::from_u64(x).unwrap().encode(self); } } quinn-proto-0.11.9/src/config.rs000064400000000000000000001364061046102023000146230ustar 00000000000000use std::{ fmt, net::{SocketAddrV4, SocketAddrV6}, num::TryFromIntError, sync::Arc, }; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] use rustls::client::WebPkiServerVerifier; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use thiserror::Error; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] use crate::crypto::rustls::{configured_provider, QuicServerConfig}; use crate::{ cid_generator::{ConnectionIdGenerator, HashedConnectionIdGenerator}, congestion, crypto::{self, HandshakeTokenKey, HmacKey}, shared::ConnectionId, Duration, RandomConnectionIdGenerator, VarInt, VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, INITIAL_MTU, MAX_CID_SIZE, MAX_UDP_PAYLOAD, }; /// Parameters governing the core QUIC state machine /// /// Default values should be suitable for most internet applications. Applications protocols which /// forbid remotely-initiated streams should set `max_concurrent_bidi_streams` and /// `max_concurrent_uni_streams` to zero. /// /// In some cases, performance or resource requirements can be improved by tuning these values to /// suit a particular application and/or network connection. In particular, data window sizes can be /// tuned for a particular expected round trip time, link capacity, and memory availability. Tuning /// for higher bandwidths and latencies increases worst-case memory consumption, but does not impair /// performance at lower bandwidths and latencies. 
/// The default configuration is tuned for a 100Mbps link with a 100ms round trip time.
pub struct TransportConfig {
    pub(crate) max_concurrent_bidi_streams: VarInt,
    pub(crate) max_concurrent_uni_streams: VarInt,
    pub(crate) max_idle_timeout: Option<VarInt>,
    pub(crate) stream_receive_window: VarInt,
    pub(crate) receive_window: VarInt,
    pub(crate) send_window: u64,
    pub(crate) send_fairness: bool,

    pub(crate) packet_threshold: u32,
    pub(crate) time_threshold: f32,
    pub(crate) initial_rtt: Duration,
    pub(crate) initial_mtu: u16,
    pub(crate) min_mtu: u16,
    pub(crate) mtu_discovery_config: Option<MtuDiscoveryConfig>,
    pub(crate) ack_frequency_config: Option<AckFrequencyConfig>,

    pub(crate) persistent_congestion_threshold: u32,
    pub(crate) keep_alive_interval: Option<Duration>,
    pub(crate) crypto_buffer_size: usize,
    pub(crate) allow_spin: bool,
    pub(crate) datagram_receive_buffer_size: Option<usize>,
    pub(crate) datagram_send_buffer_size: usize,
    #[cfg(test)]
    pub(crate) deterministic_packet_numbers: bool,

    pub(crate) congestion_controller_factory:
        Arc<dyn congestion::ControllerFactory + Send + Sync + 'static>,

    pub(crate) enable_segmentation_offload: bool,
}

impl TransportConfig {
    /// Maximum number of incoming bidirectional streams that may be open concurrently
    ///
    /// Must be nonzero for the peer to open any bidirectional streams.
    ///
    /// Worst-case memory use is directly proportional to `max_concurrent_bidi_streams *
    /// stream_receive_window`, with an upper bound proportional to `receive_window`.
    pub fn max_concurrent_bidi_streams(&mut self, value: VarInt) -> &mut Self {
        self.max_concurrent_bidi_streams = value;
        self
    }

    /// Variant of `max_concurrent_bidi_streams` affecting unidirectional streams
    pub fn max_concurrent_uni_streams(&mut self, value: VarInt) -> &mut Self {
        self.max_concurrent_uni_streams = value;
        self
    }

    /// Maximum duration of inactivity to accept before timing out the connection.
    ///
    /// The true idle timeout is the minimum of this and the peer's own max idle timeout. `None`
    /// represents an infinite timeout. Defaults to 30 seconds.
    ///
    /// **WARNING**: If a peer or its network path malfunctions or acts maliciously, an infinite
    /// idle timeout can result in permanently hung futures!
    ///
    /// ```
    /// # use std::{convert::TryInto, time::Duration};
    /// # use quinn_proto::{TransportConfig, VarInt, VarIntBoundsExceeded};
    /// # fn main() -> Result<(), VarIntBoundsExceeded> {
    /// let mut config = TransportConfig::default();
    ///
    /// // Set the idle timeout as `VarInt`-encoded milliseconds
    /// config.max_idle_timeout(Some(VarInt::from_u32(10_000).into()));
    ///
    /// // Set the idle timeout as a `Duration`
    /// config.max_idle_timeout(Some(Duration::from_secs(10).try_into()?));
    /// # Ok(())
    /// # }
    /// ```
    pub fn max_idle_timeout(&mut self, value: Option<IdleTimeout>) -> &mut Self {
        self.max_idle_timeout = value.map(|t| t.0);
        self
    }

    /// Maximum number of bytes the peer may transmit without acknowledgement on any one stream
    /// before becoming blocked.
    ///
    /// This should be set to at least the expected connection latency multiplied by the maximum
    /// desired throughput. Setting this smaller than `receive_window` helps ensure that a single
    /// stream doesn't monopolize receive buffers, which may otherwise occur if the application
    /// chooses not to read from a large stream for a time while still requiring data on other
    /// streams.
    pub fn stream_receive_window(&mut self, value: VarInt) -> &mut Self {
        self.stream_receive_window = value;
        self
    }

    /// Maximum number of bytes the peer may transmit across all streams of a connection before
    /// becoming blocked.
    ///
    /// This should be set to at least the expected connection latency multiplied by the maximum
    /// desired throughput. Larger values can be useful to allow maximum throughput within a
    /// stream while another is blocked.
    pub fn receive_window(&mut self, value: VarInt) -> &mut Self {
        self.receive_window = value;
        self
    }

    /// Maximum number of bytes to transmit to a peer without acknowledgment
    ///
    /// Provides an upper bound on memory when communicating with peers that issue large amounts of
    /// flow control credit. Endpoints that wish to handle large numbers of connections robustly
    /// should take care to set this low enough to guarantee memory exhaustion does not occur if
    /// every connection uses the entire window.
    pub fn send_window(&mut self, value: u64) -> &mut Self {
        self.send_window = value;
        self
    }

    /// Whether to implement fair queuing for send streams having the same priority.
    ///
    /// When enabled, connections schedule data from outgoing streams having the same priority in a
    /// round-robin fashion. When disabled, streams are scheduled in the order they are written to.
    ///
    /// Note that this only affects streams with the same priority. Higher priority streams always
    /// take precedence over lower priority streams.
    ///
    /// Disabling fairness can reduce fragmentation and protocol overhead for workloads that use
    /// many small streams.
    pub fn send_fairness(&mut self, value: bool) -> &mut Self {
        self.send_fairness = value;
        self
    }

    /// Maximum reordering in packet number space before FACK style loss detection considers a
    /// packet lost. Should not be less than 3, per RFC5681.
    pub fn packet_threshold(&mut self, value: u32) -> &mut Self {
        self.packet_threshold = value;
        self
    }

    /// Maximum reordering in time space before time based loss detection considers a packet lost,
    /// as a factor of RTT
    pub fn time_threshold(&mut self, value: f32) -> &mut Self {
        self.time_threshold = value;
        self
    }

    /// The RTT used before an RTT sample is taken
    pub fn initial_rtt(&mut self, value: Duration) -> &mut Self {
        self.initial_rtt = value;
        self
    }

    /// The initial value to be used as the maximum UDP payload size before running MTU discovery
    /// (see [`TransportConfig::mtu_discovery_config`]).
    ///
    /// Must be at least 1200, which is the default, and known to be safe for typical internet
    /// applications. Larger values are more efficient, but increase the risk of packet loss due to
    /// exceeding the network path's IP MTU. If the provided value is higher than what the network
    /// path actually supports, packet loss will eventually trigger black hole detection and bring
    /// it down to [`TransportConfig::min_mtu`].
    pub fn initial_mtu(&mut self, value: u16) -> &mut Self {
        self.initial_mtu = value.max(INITIAL_MTU);
        self
    }

    pub(crate) fn get_initial_mtu(&self) -> u16 {
        self.initial_mtu.max(self.min_mtu)
    }

    /// The maximum UDP payload size guaranteed to be supported by the network.
    ///
    /// Must be at least 1200, which is the default, and lower than or equal to
    /// [`TransportConfig::initial_mtu`].
    ///
    /// Real-world MTUs can vary according to ISP, VPN, and properties of intermediate network links
    /// outside of either endpoint's control. Extreme care should be used when raising this value
    /// outside of private networks where these factors are fully controlled. If the provided value
    /// is higher than what the network path actually supports, the result will be unpredictable and
    /// catastrophic packet loss, without a possibility of repair.
    /// Prefer [`TransportConfig::initial_mtu`] together with
    /// [`TransportConfig::mtu_discovery_config`] to set a maximum UDP payload size that robustly
    /// adapts to the network.
    pub fn min_mtu(&mut self, value: u16) -> &mut Self {
        self.min_mtu = value.max(INITIAL_MTU);
        self
    }

    /// Specifies the MTU discovery config (see [`MtuDiscoveryConfig`] for details).
    ///
    /// Enabled by default.
    pub fn mtu_discovery_config(&mut self, value: Option<MtuDiscoveryConfig>) -> &mut Self {
        self.mtu_discovery_config = value;
        self
    }

    /// Specifies the ACK frequency config (see [`AckFrequencyConfig`] for details)
    ///
    /// The provided configuration will be ignored if the peer does not support the acknowledgement
    /// frequency QUIC extension.
    ///
    /// Defaults to `None`, which disables controlling the peer's acknowledgement frequency. Even
    /// if set to `None`, the local side still supports the acknowledgement frequency QUIC
    /// extension and may use it in other ways.
    pub fn ack_frequency_config(&mut self, value: Option<AckFrequencyConfig>) -> &mut Self {
        self.ack_frequency_config = value;
        self
    }

    /// Number of consecutive PTOs after which network is considered to be experiencing persistent congestion.
    pub fn persistent_congestion_threshold(&mut self, value: u32) -> &mut Self {
        self.persistent_congestion_threshold = value;
        self
    }

    /// Period of inactivity before sending a keep-alive packet
    ///
    /// Keep-alive packets prevent an inactive but otherwise healthy connection from timing out.
    ///
    /// `None` to disable, which is the default. Only one side of any given connection needs keep-alive
    /// enabled for the connection to be preserved. Must be set lower than the idle_timeout of both
    /// peers to be effective.
    pub fn keep_alive_interval(&mut self, value: Option<Duration>) -> &mut Self {
        self.keep_alive_interval = value;
        self
    }

    /// Maximum quantity of out-of-order crypto layer data to buffer
    pub fn crypto_buffer_size(&mut self, value: usize) -> &mut Self {
        self.crypto_buffer_size = value;
        self
    }

    /// Whether the implementation is permitted to set the spin bit on this connection
    ///
    /// This allows passive observers to easily judge the round trip time of a connection, which can
    /// be useful for network administration but sacrifices a small amount of privacy.
    pub fn allow_spin(&mut self, value: bool) -> &mut Self {
        self.allow_spin = value;
        self
    }

    /// Maximum number of incoming application datagram bytes to buffer, or None to disable
    /// incoming datagrams
    ///
    /// The peer is forbidden to send single datagrams larger than this size. If the aggregate size
    /// of all datagrams that have been received from the peer but not consumed by the application
    /// exceeds this value, old datagrams are dropped until it is no longer exceeded.
    pub fn datagram_receive_buffer_size(&mut self, value: Option<usize>) -> &mut Self {
        self.datagram_receive_buffer_size = value;
        self
    }

    /// Maximum number of outgoing application datagram bytes to buffer
    ///
    /// While datagrams are sent ASAP, it is possible for an application to generate data faster
    /// than the link, or even the underlying hardware, can transmit them. This limits the amount of
    /// memory that may be consumed in that case. When the send buffer is full and a new datagram is
    /// sent, older datagrams are dropped until sufficient space is available.
    pub fn datagram_send_buffer_size(&mut self, value: usize) -> &mut Self {
        self.datagram_send_buffer_size = value;
        self
    }

    /// Whether to force every packet number to be used
    ///
    /// By default, packet numbers are occasionally skipped to ensure peers aren't ACKing packets
    /// before they see them.
    #[cfg(test)]
    pub(crate) fn deterministic_packet_numbers(&mut self, enabled: bool) -> &mut Self {
        self.deterministic_packet_numbers = enabled;
        self
    }

    /// How to construct new `congestion::Controller`s
    ///
    /// Typically the refcounted configuration of a `congestion::Controller`,
    /// e.g. a `congestion::NewRenoConfig`.
    ///
    /// # Example
    /// ```
    /// # use quinn_proto::*; use std::sync::Arc;
    /// let mut config = TransportConfig::default();
    /// config.congestion_controller_factory(Arc::new(congestion::NewRenoConfig::default()));
    /// ```
    pub fn congestion_controller_factory(
        &mut self,
        factory: Arc<dyn congestion::ControllerFactory + Send + Sync + 'static>,
    ) -> &mut Self {
        self.congestion_controller_factory = factory;
        self
    }

    /// Whether to use "Generic Segmentation Offload" to accelerate transmits, when supported by the
    /// environment
    ///
    /// Defaults to `true`.
    ///
    /// GSO dramatically reduces CPU consumption when sending large numbers of packets with the same
    /// headers, such as when transmitting bulk data on a connection. However, it is not supported
    /// by all network interface drivers or packet inspection tools. `quinn-udp` will attempt to
    /// disable GSO automatically when unavailable, but this can lead to spurious packet loss at
    /// startup, temporarily degrading performance.
    pub fn enable_segmentation_offload(&mut self, enabled: bool) -> &mut Self {
        self.enable_segmentation_offload = enabled;
        self
    }
}

impl Default for TransportConfig {
    fn default() -> Self {
        const EXPECTED_RTT: u32 = 100; // ms
        const MAX_STREAM_BANDWIDTH: u32 = 12500 * 1000; // bytes/s
        // Window size needed to avoid pipeline stalls
        const STREAM_RWND: u32 = MAX_STREAM_BANDWIDTH / 1000 * EXPECTED_RTT;

        Self {
            max_concurrent_bidi_streams: 100u32.into(),
            max_concurrent_uni_streams: 100u32.into(),
            // 30 second default recommended by RFC 9308 § 3.2
            max_idle_timeout: Some(VarInt(30_000)),
            stream_receive_window: STREAM_RWND.into(),
            receive_window: VarInt::MAX,
            send_window: (8 * STREAM_RWND).into(),
            send_fairness: true,

            packet_threshold: 3,
            time_threshold: 9.0 / 8.0,
            initial_rtt: Duration::from_millis(333), // per spec, intentionally distinct from EXPECTED_RTT
            initial_mtu: INITIAL_MTU,
            min_mtu: INITIAL_MTU,
            mtu_discovery_config: Some(MtuDiscoveryConfig::default()),
            ack_frequency_config: None,

            persistent_congestion_threshold: 3,
            keep_alive_interval: None,
            crypto_buffer_size: 16 * 1024,
            allow_spin: true,
            datagram_receive_buffer_size: Some(STREAM_RWND as usize),
            datagram_send_buffer_size: 1024 * 1024,
            #[cfg(test)]
            deterministic_packet_numbers: false,

            congestion_controller_factory: Arc::new(congestion::CubicConfig::default()),

            enable_segmentation_offload: true,
        }
    }
}

impl fmt::Debug for TransportConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            max_concurrent_bidi_streams,
            max_concurrent_uni_streams,
            max_idle_timeout,
            stream_receive_window,
            receive_window,
            send_window,
            send_fairness,
            packet_threshold,
            time_threshold,
            initial_rtt,
            initial_mtu,
            min_mtu,
            mtu_discovery_config,
            ack_frequency_config,
            persistent_congestion_threshold,
            keep_alive_interval,
            crypto_buffer_size,
            allow_spin,
            datagram_receive_buffer_size,
            datagram_send_buffer_size,
            #[cfg(test)]
            deterministic_packet_numbers: _,
            congestion_controller_factory: _,
            enable_segmentation_offload,
        } = self;
        fmt.debug_struct("TransportConfig")
            .field("max_concurrent_bidi_streams", max_concurrent_bidi_streams)
            .field("max_concurrent_uni_streams", max_concurrent_uni_streams)
            .field("max_idle_timeout", max_idle_timeout)
            .field("stream_receive_window", stream_receive_window)
            .field("receive_window", receive_window)
.field("send_window", send_window) .field("send_fairness", send_fairness) .field("packet_threshold", packet_threshold) .field("time_threshold", time_threshold) .field("initial_rtt", initial_rtt) .field("initial_mtu", initial_mtu) .field("min_mtu", min_mtu) .field("mtu_discovery_config", mtu_discovery_config) .field("ack_frequency_config", ack_frequency_config) .field( "persistent_congestion_threshold", persistent_congestion_threshold, ) .field("keep_alive_interval", keep_alive_interval) .field("crypto_buffer_size", crypto_buffer_size) .field("allow_spin", allow_spin) .field("datagram_receive_buffer_size", datagram_receive_buffer_size) .field("datagram_send_buffer_size", datagram_send_buffer_size) .field("congestion_controller_factory", &"[ opaque ]") .field("enable_segmentation_offload", enable_segmentation_offload) .finish() } } /// Parameters for controlling the peer's acknowledgement frequency /// /// The parameters provided in this config will be sent to the peer at the beginning of the /// connection, so it can take them into account when sending acknowledgements (see each parameter's /// description for details on how it influences acknowledgement frequency). /// /// Quinn's implementation follows the fourth draft of the /// [QUIC Acknowledgement Frequency extension](https://datatracker.ietf.org/doc/html/draft-ietf-quic-ack-frequency-04). /// The defaults produce behavior slightly different than the behavior without this extension, /// because they change the way reordered packets are handled (see /// [`AckFrequencyConfig::reordering_threshold`] for details). #[derive(Clone, Debug)] pub struct AckFrequencyConfig { pub(crate) ack_eliciting_threshold: VarInt, pub(crate) max_ack_delay: Option, pub(crate) reordering_threshold: VarInt, } impl AckFrequencyConfig { /// The ack-eliciting threshold we will request the peer to use /// /// This threshold represents the number of ack-eliciting packets an endpoint may receive /// without immediately sending an ACK. /// /// The remote peer should send at least one ACK frame when more than this number of /// ack-eliciting packets have been received. A value of 0 results in a receiver immediately /// acknowledging every ack-eliciting packet. /// /// Defaults to 1, which sends ACK frames for every other ack-eliciting packet. pub fn ack_eliciting_threshold(&mut self, value: VarInt) -> &mut Self { self.ack_eliciting_threshold = value; self } /// The `max_ack_delay` we will request the peer to use /// /// This parameter represents the maximum amount of time that an endpoint waits before sending /// an ACK when the ack-eliciting threshold hasn't been reached. /// /// The effective `max_ack_delay` will be clamped to be at least the peer's `min_ack_delay` /// transport parameter, and at most the greater of the current path RTT or 25ms. /// /// Defaults to `None`, in which case the peer's original `max_ack_delay` will be used, as /// obtained from its transport parameters. pub fn max_ack_delay(&mut self, value: Option) -> &mut Self { self.max_ack_delay = value; self } /// The reordering threshold we will request the peer to use /// /// This threshold represents the amount of out-of-order packets that will trigger an endpoint /// to send an ACK, without waiting for `ack_eliciting_threshold` to be exceeded or for /// `max_ack_delay` to be elapsed. /// /// A value of 0 indicates out-of-order packets do not elicit an immediate ACK. 
    /// A value of 1 immediately acknowledges any packets that are received out of order (this is
    /// also the behavior when the extension is disabled).
    ///
    /// It is recommended to set this value to [`TransportConfig::packet_threshold`] minus one.
    /// Since the default value for [`TransportConfig::packet_threshold`] is 3, this value defaults
    /// to 2.
    pub fn reordering_threshold(&mut self, value: VarInt) -> &mut Self {
        self.reordering_threshold = value;
        self
    }
}

impl Default for AckFrequencyConfig {
    fn default() -> Self {
        Self {
            ack_eliciting_threshold: VarInt(1),
            max_ack_delay: None,
            reordering_threshold: VarInt(2),
        }
    }
}

/// Parameters governing MTU discovery.
///
/// # The why of MTU discovery
///
/// By design, QUIC ensures during the handshake that the network path between the client and the
/// server is able to transmit unfragmented UDP packets with a body of 1200 bytes. In other words,
/// once the connection is established, we know that the network path's maximum transmission unit
/// (MTU) is of at least 1200 bytes (plus IP and UDP headers). Because of this, a QUIC endpoint can
/// split outgoing data into packets of 1200 bytes, with confidence that the network will be able to
/// deliver them (if the endpoint were to send bigger packets, they could prove too big and end up
/// being dropped).
///
/// There is, however, a significant overhead associated with sending a packet. If the same
/// information can be sent in fewer packets, that results in higher throughput. The number of
/// packets that need to be sent is inversely proportional to the MTU: the higher the MTU, the
/// bigger the packets that can be sent, and the fewer packets that are needed to transmit a given
/// amount of bytes.
///
/// Most networks have an MTU higher than 1200. Through MTU discovery, endpoints can detect the
/// path's MTU and, if it turns out to be higher, start sending bigger packets.
///
/// # MTU discovery internals
///
/// Quinn implements MTU discovery through DPLPMTUD (Datagram Packetization Layer Path MTU
/// Discovery), described in [section 14.3 of RFC
/// 9000](https://www.rfc-editor.org/rfc/rfc9000.html#section-14.3). This method consists of sending
/// QUIC packets padded to a particular size (called PMTU probes), and waiting to see if the remote
/// peer responds with an ACK. If an ACK is received, that means the probe arrived at the remote
/// peer, which in turn means that the network path's MTU is of at least the packet's size. If the
/// probe is lost, it is sent another 2 times before concluding that the MTU is lower than the
/// packet's size.
///
/// MTU discovery runs on a schedule (e.g. every 600 seconds) specified through
/// [`MtuDiscoveryConfig::interval`]. The first run happens right after the handshake, and
/// subsequent discoveries are scheduled to run when the interval has elapsed, starting from the
/// last time when MTU discovery completed.
///
/// Since the search space for MTUs is quite big (the smallest possible MTU is 1200, and the highest
/// is 65527), Quinn performs a binary search to keep the number of probes as low as possible. The
/// lower bound of the search is equal to [`TransportConfig::initial_mtu`] in the
/// initial MTU discovery run, and is equal to the currently discovered MTU in subsequent runs. The
/// upper bound is determined by the minimum of [`MtuDiscoveryConfig::upper_bound`] and the
/// `max_udp_payload_size` transport parameter received from the peer during the handshake.
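///
/// For illustration, a hedged sketch of plugging a custom discovery schedule into a transport
/// config, using the setters defined below (the parameter values are arbitrary examples, not
/// recommendations):
///
/// ```
/// use quinn_proto::{MtuDiscoveryConfig, TransportConfig};
/// use std::time::Duration;
///
/// let mut mtu = MtuDiscoveryConfig::default();
/// mtu.interval(Duration::from_secs(300)); // probe twice as often as the 600s default
/// mtu.upper_bound(1452); // stay within a typical Ethernet MTU
///
/// let mut transport = TransportConfig::default();
/// transport.mtu_discovery_config(Some(mtu));
/// ```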
///
/// # Black hole detection
///
/// If, at some point, the network path no longer accepts packets of the detected size, packet loss
/// will eventually trigger black hole detection and reset the detected MTU to 1200. In that case,
/// MTU discovery will be triggered after [`MtuDiscoveryConfig::black_hole_cooldown`] (ignoring the
/// timer that was set based on [`MtuDiscoveryConfig::interval`]).
///
/// # Interaction between peers
///
/// There is no guarantee that the MTU on the path between A and B is the same as the MTU of the
/// path between B and A. Therefore, each peer in the connection needs to run MTU discovery
/// independently in order to discover the path's MTU.
#[derive(Clone, Debug)]
pub struct MtuDiscoveryConfig {
    pub(crate) interval: Duration,
    pub(crate) upper_bound: u16,
    pub(crate) minimum_change: u16,
    pub(crate) black_hole_cooldown: Duration,
}

impl MtuDiscoveryConfig {
    /// Specifies the time to wait after completing MTU discovery before starting a new MTU
    /// discovery run.
    ///
    /// Defaults to 600 seconds, as recommended by [RFC
    /// 8899](https://www.rfc-editor.org/rfc/rfc8899).
    pub fn interval(&mut self, value: Duration) -> &mut Self {
        self.interval = value;
        self
    }

    /// Specifies the upper bound to the max UDP payload size that MTU discovery will search for.
    ///
    /// Defaults to 1452, to stay within Ethernet's MTU when using IPv4 and IPv6. The highest
    /// allowed value is 65527, which corresponds to the maximum permitted UDP payload on IPv6.
    ///
    /// It is safe to use an arbitrarily high upper bound, regardless of the network path's MTU. The
    /// only drawback is that MTU discovery might take more time to finish.
    pub fn upper_bound(&mut self, value: u16) -> &mut Self {
        self.upper_bound = value.min(MAX_UDP_PAYLOAD);
        self
    }

    /// Specifies the amount of time that MTU discovery should wait after a black hole was detected
    /// before running again. Defaults to one minute.
    ///
    /// Black hole detection can be spuriously triggered in case of congestion, so it makes sense to
    /// try MTU discovery again after a short period of time.
    pub fn black_hole_cooldown(&mut self, value: Duration) -> &mut Self {
        self.black_hole_cooldown = value;
        self
    }

    /// Specifies the minimum MTU change to stop the MTU discovery phase.
    /// Defaults to 20.
    pub fn minimum_change(&mut self, value: u16) -> &mut Self {
        self.minimum_change = value;
        self
    }
}

impl Default for MtuDiscoveryConfig {
    fn default() -> Self {
        Self {
            interval: Duration::from_secs(600),
            upper_bound: 1452,
            black_hole_cooldown: Duration::from_secs(60),
            minimum_change: 20,
        }
    }
}

/// Global configuration for the endpoint, affecting all connections
///
/// Default values should be suitable for most internet applications.
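///
/// For illustration, a hedged configuration sketch using the setters defined below
/// (`EndpointConfig::default()` assumes a crypto provider feature such as `ring` is enabled):
///
/// ```
/// use quinn_proto::EndpointConfig;
///
/// let mut config = EndpointConfig::default();
/// config.max_udp_payload_size(1472).unwrap(); // typical Ethernet-sized payloads
/// config.supported_versions(vec![0x0000_0001]); // QUIC v1 only
/// ```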
#[derive(Clone)] pub struct EndpointConfig { pub(crate) reset_key: Arc<dyn HmacKey>, pub(crate) max_udp_payload_size: VarInt, /// CID generator factory /// /// Creates a CID generator for local CIDs in the `Endpoint` struct pub(crate) connection_id_generator_factory: Arc<dyn Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync>, pub(crate) supported_versions: Vec<u32>, pub(crate) grease_quic_bit: bool, /// Minimum interval between outgoing stateless reset packets pub(crate) min_reset_interval: Duration, /// Optional seed to be used internally for random number generation pub(crate) rng_seed: Option<[u8; 32]>, } impl EndpointConfig { /// Create a default config with a particular `reset_key` pub fn new(reset_key: Arc<dyn HmacKey>) -> Self { let cid_factory = || -> Box<dyn ConnectionIdGenerator> { Box::<HashedConnectionIdGenerator>::default() }; Self { reset_key, max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers connection_id_generator_factory: Arc::new(cid_factory), supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(), grease_quic_bit: true, min_reset_interval: Duration::from_millis(20), rng_seed: None, } } /// Supply a custom connection ID generator factory /// /// Called once by each `Endpoint` constructed from this configuration to obtain the CID /// generator which will be used to generate the CIDs used for incoming packets on all /// connections involving that `Endpoint`. A custom CID generator allows applications to embed /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. pub fn cid_generator<F: Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync + 'static>( &mut self, factory: F, ) -> &mut Self { self.connection_id_generator_factory = Arc::new(factory); self } /// Private key used to send authenticated connection resets to peers who were /// communicating with a previous instance of this endpoint. pub fn reset_key(&mut self, key: Arc<dyn HmacKey>) -> &mut Self { self.reset_key = key; self } /// Maximum UDP payload size accepted from peers (excluding UDP and IP overhead). /// /// Must be greater than or equal to 1200. /// /// Defaults to 1472, which is the largest UDP payload that can be transmitted in the typical /// 1500 byte Ethernet MTU. Deployments on links with larger MTUs (e.g. loopback or Ethernet /// with jumbo frames) can raise this to improve performance at the cost of a linear increase in /// datagram receive buffer size. pub fn max_udp_payload_size(&mut self, value: u16) -> Result<&mut Self, ConfigError> { if !(1200..=65_527).contains(&value) { return Err(ConfigError::OutOfBounds); } self.max_udp_payload_size = value.into(); Ok(self) } /// Get the current value of `max_udp_payload_size` /// /// While most parameters don't need to be readable, this must be exposed to allow higher-level /// layers, e.g. the `quinn` crate, to determine how large a receive buffer to allocate to /// support an externally-defined `EndpointConfig`. /// /// While `get_` accessors are typically unidiomatic in Rust, we favor concision for setters, /// which will be used far more heavily. #[doc(hidden)] pub fn get_max_udp_payload_size(&self) -> u64 { self.max_udp_payload_size.into() } /// Override supported QUIC versions pub fn supported_versions(&mut self, supported_versions: Vec<u32>) -> &mut Self { self.supported_versions = supported_versions; self } /// Whether to accept QUIC packets containing any value for the fixed bit /// /// Enabled by default. Helps protect against protocol ossification and makes traffic less /// identifiable to observers. Disable if helping observers identify this traffic as QUIC is /// desired.
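/// 
/// ```
/// // Hedged sketch (default crate features assumed): opting out of greasing
/// let mut cfg = quinn_proto::EndpointConfig::default();
/// cfg.grease_quic_bit(false);
/// ```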
pub fn grease_quic_bit(&mut self, value: bool) -> &mut Self { self.grease_quic_bit = value; self } /// Minimum interval between outgoing stateless reset packets /// /// Defaults to 20ms. Limits the impact of attacks which flood an endpoint with garbage packets, /// e.g. [ISAKMP/IKE amplification]. Larger values provide a stronger defense, but may delay /// detection of some error conditions by clients. Using a [`ConnectionIdGenerator`] with a low /// rate of false positives in [`validate`](ConnectionIdGenerator::validate) reduces the risk /// incurred by a small minimum reset interval. /// /// [ISAKMP/IKE /// amplification]: https://bughunters.google.com/blog/5960150648750080/preventing-cross-service-udp-loops-in-quic#isakmp-ike-amplification-vs-quic pub fn min_reset_interval(&mut self, value: Duration) -> &mut Self { self.min_reset_interval = value; self } /// Optional seed to be used internally for random number generation /// /// By default, quinn will initialize an endpoint's rng using a platform entropy source. /// However, you can seed the rng yourself through this method (e.g. if you need to run quinn /// deterministically or if you are using quinn in an environment that doesn't have a source of /// entropy available). pub fn rng_seed(&mut self, seed: Option<[u8; 32]>) -> &mut Self { self.rng_seed = seed; self } } impl fmt::Debug for EndpointConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("EndpointConfig") .field("reset_key", &"[ elided ]") .field("max_udp_payload_size", &self.max_udp_payload_size) .field("cid_generator_factory", &"[ elided ]") .field("supported_versions", &self.supported_versions) .field("grease_quic_bit", &self.grease_quic_bit) .field("rng_seed", &self.rng_seed) .finish() } } #[cfg(any(feature = "aws-lc-rs", feature = "ring"))] impl Default for EndpointConfig { fn default() -> Self { #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))] use aws_lc_rs::hmac; use rand::RngCore; #[cfg(feature = "ring")] use ring::hmac; let mut reset_key = [0; 64]; rand::thread_rng().fill_bytes(&mut reset_key); Self::new(Arc::new(hmac::Key::new(hmac::HMAC_SHA256, &reset_key))) } } /// Parameters governing incoming connections /// /// Default values should be suitable for most internet applications. #[derive(Clone)] pub struct ServerConfig { /// Transport configuration to use for incoming connections pub transport: Arc, /// TLS configuration used for incoming connections. /// /// Must be set to use TLS 1.3 only. pub crypto: Arc, /// Used to generate one-time AEAD keys to protect handshake tokens pub(crate) token_key: Arc, /// Microseconds after a stateless retry token was issued for which it's considered valid. pub(crate) retry_token_lifetime: Duration, /// Whether to allow clients to migrate to new addresses /// /// Improves behavior for clients that move between different internet connections or suffer NAT /// rebinding. Enabled by default. 
pub(crate) migration: bool, pub(crate) preferred_address_v4: Option, pub(crate) preferred_address_v6: Option, pub(crate) max_incoming: usize, pub(crate) incoming_buffer_size: u64, pub(crate) incoming_buffer_size_total: u64, } impl ServerConfig { /// Create a default config with a particular handshake token key pub fn new( crypto: Arc, token_key: Arc, ) -> Self { Self { transport: Arc::new(TransportConfig::default()), crypto, token_key, retry_token_lifetime: Duration::from_secs(15), migration: true, preferred_address_v4: None, preferred_address_v6: None, max_incoming: 1 << 16, incoming_buffer_size: 10 << 20, incoming_buffer_size_total: 100 << 20, } } /// Set a custom [`TransportConfig`] pub fn transport_config(&mut self, transport: Arc) -> &mut Self { self.transport = transport; self } /// Private key used to authenticate data included in handshake tokens. pub fn token_key(&mut self, value: Arc) -> &mut Self { self.token_key = value; self } /// Duration after a stateless retry token was issued for which it's considered valid. pub fn retry_token_lifetime(&mut self, value: Duration) -> &mut Self { self.retry_token_lifetime = value; self } /// Whether to allow clients to migrate to new addresses /// /// Improves behavior for clients that move between different internet connections or suffer NAT /// rebinding. Enabled by default. pub fn migration(&mut self, value: bool) -> &mut Self { self.migration = value; self } /// The preferred IPv4 address that will be communicated to clients during handshaking. /// If the client is able to reach this address, it will switch to it. pub fn preferred_address_v4(&mut self, address: Option) -> &mut Self { self.preferred_address_v4 = address; self } /// The preferred IPv6 address that will be communicated to clients during handshaking. /// If the client is able to reach this address, it will switch to it. pub fn preferred_address_v6(&mut self, address: Option) -> &mut Self { self.preferred_address_v6 = address; self } /// Maximum number of [`Incoming`][crate::Incoming] to allow to exist at a time /// /// An [`Incoming`][crate::Incoming] comes into existence when an incoming connection attempt /// is received and stops existing when the application either accepts it or otherwise disposes /// of it. While this limit is reached, new incoming connection attempts are immediately /// refused. Larger values have greater worst-case memory consumption, but accommodate greater /// application latency in handling incoming connection attempts. /// /// The default value is set to 65536. With a typical Ethernet MTU of 1500 bytes, this limits /// memory consumption from this to under 100 MiB--a generous amount that still prevents memory /// exhaustion in most contexts. pub fn max_incoming(&mut self, max_incoming: usize) -> &mut Self { self.max_incoming = max_incoming; self } /// Maximum number of received bytes to buffer for each [`Incoming`][crate::Incoming] /// /// An [`Incoming`][crate::Incoming] comes into existence when an incoming connection attempt /// is received and stops existing when the application either accepts it or otherwise disposes /// of it. This limit governs only packets received within that period, and does not include /// the first packet. Packets received in excess of this limit are dropped, which may cause /// 0-RTT or handshake data to have to be retransmitted. 
/// /// The default value is set to 10 MiB--an amount such that in most situations a client would /// not transmit that much 0-RTT data faster than the server handles the corresponding /// [`Incoming`][crate::Incoming]. pub fn incoming_buffer_size(&mut self, incoming_buffer_size: u64) -> &mut Self { self.incoming_buffer_size = incoming_buffer_size; self } /// Maximum number of received bytes to buffer for all [`Incoming`][crate::Incoming] /// collectively /// /// An [`Incoming`][crate::Incoming] comes into existence when an incoming connection attempt /// is received and stops existing when the application either accepts it or otherwise disposes /// of it. This limit governs only packets received within that period, and does not include /// the first packet. Packets received in excess of this limit are dropped, which may cause /// 0-RTT or handshake data to have to be retransmitted. /// /// The default value is set to 100 MiB--a generous amount that still prevents memory /// exhaustion in most contexts. pub fn incoming_buffer_size_total(&mut self, incoming_buffer_size_total: u64) -> &mut Self { self.incoming_buffer_size_total = incoming_buffer_size_total; self } } #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] impl ServerConfig { /// Create a server config with the given certificate chain to be presented to clients /// /// Uses a randomized handshake token key. pub fn with_single_cert( cert_chain: Vec>, key: PrivateKeyDer<'static>, ) -> Result { Ok(Self::with_crypto(Arc::new(QuicServerConfig::new( cert_chain, key, )?))) } } #[cfg(any(feature = "aws-lc-rs", feature = "ring"))] impl ServerConfig { /// Create a server config with the given [`crypto::ServerConfig`] /// /// Uses a randomized handshake token key. pub fn with_crypto(crypto: Arc) -> Self { #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))] use aws_lc_rs::hkdf; use rand::RngCore; #[cfg(feature = "ring")] use ring::hkdf; let rng = &mut rand::thread_rng(); let mut master_key = [0u8; 64]; rng.fill_bytes(&mut master_key); let master_key = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key); Self::new(crypto, Arc::new(master_key)) } } impl fmt::Debug for ServerConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("ServerConfig") .field("transport", &self.transport) .field("crypto", &"ServerConfig { elided }") .field("token_key", &"[ elided ]") .field("retry_token_lifetime", &self.retry_token_lifetime) .field("migration", &self.migration) .field("preferred_address_v4", &self.preferred_address_v4) .field("preferred_address_v6", &self.preferred_address_v6) .field("max_incoming", &self.max_incoming) .field("incoming_buffer_size", &self.incoming_buffer_size) .field( "incoming_buffer_size_total", &self.incoming_buffer_size_total, ) .finish() } } /// Configuration for outgoing connections /// /// Default values should be suitable for most internet applications. 
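/// 
/// A hedged construction sketch; it assumes the `rustls-ring` feature and a
/// `rustls::RootCertStore` already populated as `roots` (hence not compiled here):
/// 
/// ```ignore
/// use std::sync::Arc;
/// use quinn_proto::{ClientConfig, TransportConfig};
/// 
/// let mut cfg = ClientConfig::with_root_certificates(Arc::new(roots))?;
/// let mut transport = TransportConfig::default();
/// transport.initial_mtu(1360); // e.g. a conservative value for tunneled paths
/// cfg.transport_config(Arc::new(transport));
/// ```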
#[derive(Clone)] #[non_exhaustive] pub struct ClientConfig { /// Transport configuration to use pub(crate) transport: Arc, /// Cryptographic configuration to use pub(crate) crypto: Arc, /// Provider that populates the destination connection ID of Initial Packets pub(crate) initial_dst_cid_provider: Arc ConnectionId + Send + Sync>, /// QUIC protocol version to use pub(crate) version: u32, } impl ClientConfig { /// Create a default config with a particular cryptographic config pub fn new(crypto: Arc) -> Self { Self { transport: Default::default(), crypto, initial_dst_cid_provider: Arc::new(|| { RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid() }), version: 1, } } /// Configure how to populate the destination CID of the initial packet when attempting to /// establish a new connection. /// /// By default, it's populated with random bytes with reasonable length, so unless you have /// a good reason, you do not need to change it. /// /// When prefer to override the default, please note that the generated connection ID MUST be /// at least 8 bytes long and unpredictable, as per section 7.2 of RFC 9000. pub fn initial_dst_cid_provider( &mut self, initial_dst_cid_provider: Arc ConnectionId + Send + Sync>, ) -> &mut Self { self.initial_dst_cid_provider = initial_dst_cid_provider; self } /// Set a custom [`TransportConfig`] pub fn transport_config(&mut self, transport: Arc) -> &mut Self { self.transport = transport; self } /// Set the QUIC version to use pub fn version(&mut self, version: u32) -> &mut Self { self.version = version; self } } #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] impl ClientConfig { /// Create a client configuration that trusts the platform's native roots #[cfg(feature = "platform-verifier")] pub fn with_platform_verifier() -> Self { Self::new(Arc::new(crypto::rustls::QuicClientConfig::new(Arc::new( rustls_platform_verifier::Verifier::new(), )))) } /// Create a client configuration that trusts specified trust anchors pub fn with_root_certificates( roots: Arc, ) -> Result { Ok(Self::new(Arc::new(crypto::rustls::QuicClientConfig::new( WebPkiServerVerifier::builder_with_provider(roots, configured_provider()).build()?, )))) } } impl fmt::Debug for ClientConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("ClientConfig") .field("transport", &self.transport) .field("crypto", &"ClientConfig { elided }") .field("version", &self.version) .finish_non_exhaustive() } } /// Errors in the configuration of an endpoint #[derive(Debug, Error, Clone, PartialEq, Eq)] #[non_exhaustive] pub enum ConfigError { /// Value exceeds supported bounds #[error("value exceeds supported bounds")] OutOfBounds, } impl From for ConfigError { fn from(_: TryFromIntError) -> Self { Self::OutOfBounds } } impl From for ConfigError { fn from(_: VarIntBoundsExceeded) -> Self { Self::OutOfBounds } } /// Maximum duration of inactivity to accept before timing out the connection. /// /// This wraps an underlying [`VarInt`], representing the duration in milliseconds. Values can be /// constructed by converting directly from `VarInt`, or using `TryFrom`. 
/// /// ``` /// # use std::{convert::TryFrom, time::Duration}; /// # use quinn_proto::{IdleTimeout, VarIntBoundsExceeded, VarInt}; /// # fn main() -> Result<(), VarIntBoundsExceeded> { /// // A `VarInt`-encoded value in milliseconds /// let timeout = IdleTimeout::from(VarInt::from_u32(10_000)); /// /// // Try to convert a `Duration` into a `VarInt`-encoded timeout /// let timeout = IdleTimeout::try_from(Duration::from_secs(10))?; /// # Ok(()) /// # } /// ``` #[derive(Default, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct IdleTimeout(VarInt); impl From for IdleTimeout { fn from(inner: VarInt) -> Self { Self(inner) } } impl std::convert::TryFrom for IdleTimeout { type Error = VarIntBoundsExceeded; fn try_from(timeout: Duration) -> Result { let inner = VarInt::try_from(timeout.as_millis())?; Ok(Self(inner)) } } impl fmt::Debug for IdleTimeout { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } quinn-proto-0.11.9/src/congestion/bbr/bw_estimation.rs000064400000000000000000000067621046102023000211400ustar 00000000000000use std::fmt::{Debug, Display, Formatter}; use super::min_max::MinMax; use crate::{Duration, Instant}; #[derive(Clone, Debug)] pub(crate) struct BandwidthEstimation { total_acked: u64, prev_total_acked: u64, acked_time: Option, prev_acked_time: Option, total_sent: u64, prev_total_sent: u64, sent_time: Instant, prev_sent_time: Option, max_filter: MinMax, acked_at_last_window: u64, } impl BandwidthEstimation { pub(crate) fn on_sent(&mut self, now: Instant, bytes: u64) { self.prev_total_sent = self.total_sent; self.total_sent += bytes; self.prev_sent_time = Some(self.sent_time); self.sent_time = now; } pub(crate) fn on_ack( &mut self, now: Instant, _sent: Instant, bytes: u64, round: u64, app_limited: bool, ) { self.prev_total_acked = self.total_acked; self.total_acked += bytes; self.prev_acked_time = self.acked_time; self.acked_time = Some(now); let prev_sent_time = match self.prev_sent_time { Some(prev_sent_time) => prev_sent_time, None => return, }; let send_rate = if self.sent_time > prev_sent_time { Self::bw_from_delta( self.total_sent - self.prev_total_sent, self.sent_time - prev_sent_time, ) .unwrap_or(0) } else { u64::MAX // will take the min of send and ack, so this is just a skip }; let ack_rate = match self.prev_acked_time { Some(prev_acked_time) => Self::bw_from_delta( self.total_acked - self.prev_total_acked, now - prev_acked_time, ) .unwrap_or(0), None => 0, }; let bandwidth = send_rate.min(ack_rate); if !app_limited && self.max_filter.get() < bandwidth { self.max_filter.update_max(round, bandwidth); } } pub(crate) fn bytes_acked_this_window(&self) -> u64 { self.total_acked - self.acked_at_last_window } pub(crate) fn end_acks(&mut self, _current_round: u64, _app_limited: bool) { self.acked_at_last_window = self.total_acked; } pub(crate) fn get_estimate(&self) -> u64 { self.max_filter.get() } pub(crate) const fn bw_from_delta(bytes: u64, delta: Duration) -> Option { let window_duration_ns = delta.as_nanos(); if window_duration_ns == 0 { return None; } let b_ns = bytes * 1_000_000_000; let bytes_per_second = b_ns / (window_duration_ns as u64); Some(bytes_per_second) } } impl Default for BandwidthEstimation { fn default() -> Self { Self { total_acked: 0, prev_total_acked: 0, acked_time: None, prev_acked_time: None, total_sent: 0, prev_total_sent: 0, // The `sent_time` value set here is ignored; it is used in `on_ack()`, but will // have been reset by `on_sent()` before that method is called. 
sent_time: Instant::now(), prev_sent_time: None, max_filter: MinMax::default(), acked_at_last_window: 0, } } } impl Display for BandwidthEstimation { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, "{:.3} MB/s", self.get_estimate() as f32 / (1024 * 1024) as f32 ) } } quinn-proto-0.11.9/src/congestion/bbr/min_max.rs000064400000000000000000000120641046102023000177140ustar 00000000000000/* * Based on Google code released under BSD license here: * https://groups.google.com/forum/#!topic/bbr-dev/3RTgkzi5ZD8 */ /* * Kathleen Nichols' algorithm for tracking the minimum (or maximum) * value of a data stream over some fixed time interval. (E.g., * the minimum RTT over the past five minutes.) It uses constant * space and constant time per update yet almost always delivers * the same minimum as an implementation that has to keep all the * data in the window. * * The algorithm keeps track of the best, 2nd best & 3rd best min * values, maintaining an invariant that the measurement time of * the n'th best >= n-1'th best. It also makes sure that the three * values are widely separated in the time window since that bounds * the worse case error when that data is monotonically increasing * over the window. * * Upon getting a new min, we can forget everything earlier because * it has no value - the new min is <= everything else in the window * by definition and it samples the most recent. So we restart fresh on * every new min and overwrites 2nd & 3rd choices. The same property * holds for 2nd & 3rd best. */ use std::fmt::Debug; #[derive(Copy, Clone, Debug)] pub(super) struct MinMax { /// round count, not a timestamp window: u64, samples: [MinMaxSample; 3], } impl MinMax { pub(super) fn get(&self) -> u64 { self.samples[0].value } fn fill(&mut self, sample: MinMaxSample) { self.samples.fill(sample); } pub(super) fn reset(&mut self) { self.fill(Default::default()) } /// update_min is also defined in the original source, but removed here since it is not used. pub(super) fn update_max(&mut self, current_round: u64, measurement: u64) { let sample = MinMaxSample { time: current_round, value: measurement, }; if self.samples[0].value == 0 /* uninitialised */ || /* found new max? */ sample.value >= self.samples[0].value || /* nothing left in window? */ sample.time - self.samples[2].time > self.window { self.fill(sample); /* forget earlier samples */ return; } if sample.value >= self.samples[1].value { self.samples[2] = sample; self.samples[1] = sample; } else if sample.value >= self.samples[2].value { self.samples[2] = sample; } self.subwin_update(sample); } /* As time advances, update the 1st, 2nd, and 3rd choices. */ fn subwin_update(&mut self, sample: MinMaxSample) { let dt = sample.time - self.samples[0].time; if dt > self.window { /* * Passed entire window without a new sample so make 2nd * choice the new sample & 3rd choice the new 2nd choice. * we may have to iterate this since our 2nd choice * may also be outside the window (we checked on entry * that the third choice was in the window). */ self.samples[0] = self.samples[1]; self.samples[1] = self.samples[2]; self.samples[2] = sample; if sample.time - self.samples[0].time > self.window { self.samples[0] = self.samples[1]; self.samples[1] = self.samples[2]; self.samples[2] = sample; } } else if self.samples[1].time == self.samples[0].time && dt > self.window / 4 { /* * We've passed a quarter of the window without a new sample * so take a 2nd choice from the 2nd quarter of the window. 
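 *
 * (Worked example with the default window of 10 rounds: if no distinct
 * 2nd choice has been recorded yet and a sample arrives more than
 * 10/4 = 2 rounds after the current best, it becomes both the 2nd and
 * 3rd choice.)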
*/ self.samples[2] = sample; self.samples[1] = sample; } else if self.samples[2].time == self.samples[1].time && dt > self.window / 2 { /* * We've passed half the window without finding a new sample * so take a 3rd choice from the last half of the window */ self.samples[2] = sample; } } } impl Default for MinMax { fn default() -> Self { Self { window: 10, samples: [Default::default(); 3], } } } #[derive(Debug, Copy, Clone, Default)] struct MinMaxSample { /// round number, not a timestamp time: u64, value: u64, } #[cfg(test)] mod test { use super::*; #[test] fn test() { let round = 25; let mut min_max = MinMax::default(); min_max.update_max(round + 1, 100); assert_eq!(100, min_max.get()); min_max.update_max(round + 3, 120); assert_eq!(120, min_max.get()); min_max.update_max(round + 5, 160); assert_eq!(160, min_max.get()); min_max.update_max(round + 7, 100); assert_eq!(160, min_max.get()); min_max.update_max(round + 10, 100); assert_eq!(160, min_max.get()); min_max.update_max(round + 14, 100); assert_eq!(160, min_max.get()); min_max.update_max(round + 16, 100); assert_eq!(100, min_max.get()); min_max.update_max(round + 18, 130); assert_eq!(130, min_max.get()); } } quinn-proto-0.11.9/src/congestion/bbr/mod.rs000064400000000000000000000550621046102023000170500ustar 00000000000000use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use rand::{Rng, SeedableRng}; use crate::congestion::bbr::bw_estimation::BandwidthEstimation; use crate::congestion::bbr::min_max::MinMax; use crate::connection::RttEstimator; use crate::{Duration, Instant}; use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE}; mod bw_estimation; mod min_max; /// Experimental! Use at your own risk. /// /// Aims for reduced buffer bloat and improved performance over high bandwidth-delay product networks. /// Based on google's quiche implementation /// of BBR . /// More discussion and links at . 
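/// 
/// A hedged opt-in sketch; `congestion_controller_factory` is the setter on
/// [`TransportConfig`](crate::TransportConfig) in current quinn-proto releases:
/// 
/// ```
/// use std::sync::Arc;
/// use quinn_proto::{congestion::BbrConfig, TransportConfig};
/// 
/// let mut transport = TransportConfig::default();
/// transport.congestion_controller_factory(Arc::new(BbrConfig::default()));
/// ```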
#[derive(Debug, Clone)] pub struct Bbr { config: Arc, current_mtu: u64, max_bandwidth: BandwidthEstimation, acked_bytes: u64, mode: Mode, loss_state: LossState, recovery_state: RecoveryState, recovery_window: u64, is_at_full_bandwidth: bool, pacing_gain: f32, high_gain: f32, drain_gain: f32, cwnd_gain: f32, high_cwnd_gain: f32, last_cycle_start: Option, current_cycle_offset: u8, init_cwnd: u64, min_cwnd: u64, prev_in_flight_count: u64, exit_probe_rtt_at: Option, probe_rtt_last_started_at: Option, min_rtt: Duration, exiting_quiescence: bool, pacing_rate: u64, max_acked_packet_number: u64, max_sent_packet_number: u64, end_recovery_at_packet_number: u64, cwnd: u64, current_round_trip_end_packet_number: u64, round_count: u64, bw_at_last_round: u64, round_wo_bw_gain: u64, ack_aggregation: AckAggregationState, random_number_generator: rand::rngs::StdRng, } impl Bbr { /// Construct a state using the given `config` and current time `now` pub fn new(config: Arc, current_mtu: u16) -> Self { let initial_window = config.initial_window; Self { config, current_mtu: current_mtu as u64, max_bandwidth: BandwidthEstimation::default(), acked_bytes: 0, mode: Mode::Startup, loss_state: Default::default(), recovery_state: RecoveryState::NotInRecovery, recovery_window: 0, is_at_full_bandwidth: false, pacing_gain: K_DEFAULT_HIGH_GAIN, high_gain: K_DEFAULT_HIGH_GAIN, drain_gain: 1.0 / K_DEFAULT_HIGH_GAIN, cwnd_gain: K_DEFAULT_HIGH_GAIN, high_cwnd_gain: K_DEFAULT_HIGH_GAIN, last_cycle_start: None, current_cycle_offset: 0, init_cwnd: initial_window, min_cwnd: calculate_min_window(current_mtu as u64), prev_in_flight_count: 0, exit_probe_rtt_at: None, probe_rtt_last_started_at: None, min_rtt: Default::default(), exiting_quiescence: false, pacing_rate: 0, max_acked_packet_number: 0, max_sent_packet_number: 0, end_recovery_at_packet_number: 0, cwnd: initial_window, current_round_trip_end_packet_number: 0, round_count: 0, bw_at_last_round: 0, round_wo_bw_gain: 0, ack_aggregation: AckAggregationState::default(), random_number_generator: rand::rngs::StdRng::from_entropy(), } } fn enter_startup_mode(&mut self) { self.mode = Mode::Startup; self.pacing_gain = self.high_gain; self.cwnd_gain = self.high_cwnd_gain; } fn enter_probe_bandwidth_mode(&mut self, now: Instant) { self.mode = Mode::ProbeBw; self.cwnd_gain = K_DERIVED_HIGH_CWNDGAIN; self.last_cycle_start = Some(now); // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is // excluded because in that case increased gain and decreased gain would not // follow each other. let mut rand_index = self .random_number_generator .gen_range(0..K_PACING_GAIN.len() as u8 - 1); if rand_index >= 1 { rand_index += 1; } self.current_cycle_offset = rand_index; self.pacing_gain = K_PACING_GAIN[rand_index as usize]; } fn update_recovery_state(&mut self, is_round_start: bool) { // Exit recovery when there are no losses for a round. if self.loss_state.has_losses() { self.end_recovery_at_packet_number = self.max_sent_packet_number; } match self.recovery_state { // Enter conservation on the first loss. RecoveryState::NotInRecovery if self.loss_state.has_losses() => { self.recovery_state = RecoveryState::Conservation; // This will cause the |recovery_window| to be set to the // correct value in CalculateRecoveryWindow(). self.recovery_window = 0; // Since the conservation phase is meant to be lasting for a whole // round, extend the current round as if it were started right now. 
self.current_round_trip_end_packet_number = self.max_sent_packet_number; } RecoveryState::Growth | RecoveryState::Conservation => { if self.recovery_state == RecoveryState::Conservation && is_round_start { self.recovery_state = RecoveryState::Growth; } // Exit recovery if appropriate. if !self.loss_state.has_losses() && self.max_acked_packet_number > self.end_recovery_at_packet_number { self.recovery_state = RecoveryState::NotInRecovery; } } _ => {} } } fn update_gain_cycle_phase(&mut self, now: Instant, in_flight: u64) { // In most cases, the cycle is advanced after an RTT passes. let mut should_advance_gain_cycling = self .last_cycle_start .map(|last_cycle_start| now.duration_since(last_cycle_start) > self.min_rtt) .unwrap_or(false); // If the pacing gain is above 1.0, the connection is trying to probe the // bandwidth by increasing the number of bytes in flight to at least // pacing_gain * BDP. Make sure that it actually reaches the target, as // long as there are no losses suggesting that the buffers are not able to // hold that much. if self.pacing_gain > 1.0 && !self.loss_state.has_losses() && self.prev_in_flight_count < self.get_target_cwnd(self.pacing_gain) { should_advance_gain_cycling = false; } // If pacing gain is below 1.0, the connection is trying to drain the extra // queue which could have been incurred by probing prior to it. If the // number of bytes in flight falls down to the estimated BDP value earlier, // conclude that the queue has been successfully drained and exit this cycle // early. if self.pacing_gain < 1.0 && in_flight <= self.get_target_cwnd(1.0) { should_advance_gain_cycling = true; } if should_advance_gain_cycling { self.current_cycle_offset = (self.current_cycle_offset + 1) % K_PACING_GAIN.len() as u8; self.last_cycle_start = Some(now); // Stay in low gain mode until the target BDP is hit. Low gain mode // will be exited immediately when the target BDP is achieved. if DRAIN_TO_TARGET && self.pacing_gain < 1.0 && (K_PACING_GAIN[self.current_cycle_offset as usize] - 1.0).abs() < f32::EPSILON && in_flight > self.get_target_cwnd(1.0) { return; } self.pacing_gain = K_PACING_GAIN[self.current_cycle_offset as usize]; } } fn maybe_exit_startup_or_drain(&mut self, now: Instant, in_flight: u64) { if self.mode == Mode::Startup && self.is_at_full_bandwidth { self.mode = Mode::Drain; self.pacing_gain = self.drain_gain; self.cwnd_gain = self.high_cwnd_gain; } if self.mode == Mode::Drain && in_flight <= self.get_target_cwnd(1.0) { self.enter_probe_bandwidth_mode(now); } } fn is_min_rtt_expired(&self, now: Instant, app_limited: bool) -> bool { !app_limited && self .probe_rtt_last_started_at .map(|last| now.saturating_duration_since(last) > Duration::from_secs(10)) .unwrap_or(true) } fn maybe_enter_or_exit_probe_rtt( &mut self, now: Instant, is_round_start: bool, bytes_in_flight: u64, app_limited: bool, ) { let min_rtt_expired = self.is_min_rtt_expired(now, app_limited); if min_rtt_expired && !self.exiting_quiescence && self.mode != Mode::ProbeRtt { self.mode = Mode::ProbeRtt; self.pacing_gain = 1.0; // Do not decide on the time to exit ProbeRtt until the // |bytes_in_flight| is at the target small value. self.exit_probe_rtt_at = None; self.probe_rtt_last_started_at = Some(now); } if self.mode == Mode::ProbeRtt { if self.exit_probe_rtt_at.is_none() { // If the window has reached the appropriate size, schedule exiting // ProbeRtt. The CWND during ProbeRtt is // kMinimumCongestionWindow, but we allow an extra packet since QUIC // checks CWND before sending a packet. 
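// (With `PROBE_RTT_BASED_ON_BDP` enabled, `get_probe_rtt_cwnd()` below targets
// 0.75 * the estimated BDP instead of the absolute minimum window, so the
// drain during ProbeRtt is less drastic than in classic BBR.)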
if bytes_in_flight < self.get_probe_rtt_cwnd() + self.current_mtu { const K_PROBE_RTT_TIME: Duration = Duration::from_millis(200); self.exit_probe_rtt_at = Some(now + K_PROBE_RTT_TIME); } } else if is_round_start && now >= self.exit_probe_rtt_at.unwrap() { if !self.is_at_full_bandwidth { self.enter_startup_mode(); } else { self.enter_probe_bandwidth_mode(now); } } } self.exiting_quiescence = false; } fn get_target_cwnd(&self, gain: f32) -> u64 { let bw = self.max_bandwidth.get_estimate(); let bdp = self.min_rtt.as_micros() as u64 * bw; let bdpf = bdp as f64; let cwnd = ((gain as f64 * bdpf) / 1_000_000f64) as u64; // BDP estimate will be zero if no bandwidth samples are available yet. if cwnd == 0 { return self.init_cwnd; } cwnd.max(self.min_cwnd) } fn get_probe_rtt_cwnd(&self) -> u64 { const K_MODERATE_PROBE_RTT_MULTIPLIER: f32 = 0.75; if PROBE_RTT_BASED_ON_BDP { return self.get_target_cwnd(K_MODERATE_PROBE_RTT_MULTIPLIER); } self.min_cwnd } fn calculate_pacing_rate(&mut self) { let bw = self.max_bandwidth.get_estimate(); if bw == 0 { return; } let target_rate = (bw as f64 * self.pacing_gain as f64) as u64; if self.is_at_full_bandwidth { self.pacing_rate = target_rate; return; } // Pace at the rate of initial_window / RTT as soon as RTT measurements are // available. if self.pacing_rate == 0 && self.min_rtt.as_nanos() != 0 { self.pacing_rate = BandwidthEstimation::bw_from_delta(self.init_cwnd, self.min_rtt).unwrap(); return; } // Do not decrease the pacing rate during startup. if self.pacing_rate < target_rate { self.pacing_rate = target_rate; } } fn calculate_cwnd(&mut self, bytes_acked: u64, excess_acked: u64) { if self.mode == Mode::ProbeRtt { return; } let mut target_window = self.get_target_cwnd(self.cwnd_gain); if self.is_at_full_bandwidth { // Add the max recently measured ack aggregation to CWND. target_window += self.ack_aggregation.max_ack_height.get(); } else { // Add the most recent excess acked. Because CWND never decreases in // STARTUP, this will automatically create a very localized max filter. target_window += excess_acked; } // Instead of immediately setting the target CWND as the new one, BBR grows // the CWND towards |target_window| by only increasing it |bytes_acked| at a // time. if self.is_at_full_bandwidth { self.cwnd = target_window.min(self.cwnd + bytes_acked); } else if (self.cwnd_gain < target_window as f32) || (self.acked_bytes < self.init_cwnd) { // If the connection is not yet out of startup phase, do not decrease // the window. self.cwnd += bytes_acked; } // Enforce the limits on the congestion window. if self.cwnd < self.min_cwnd { self.cwnd = self.min_cwnd; } } fn calculate_recovery_window(&mut self, bytes_acked: u64, bytes_lost: u64, in_flight: u64) { if !self.recovery_state.in_recovery() { return; } // Set up the initial recovery window. if self.recovery_window == 0 { self.recovery_window = self.min_cwnd.max(in_flight + bytes_acked); return; } // Remove losses from the recovery window, while accounting for a potential // integer underflow. if self.recovery_window >= bytes_lost { self.recovery_window -= bytes_lost; } else { // k_max_segment_size = current_mtu self.recovery_window = self.current_mtu; } // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, // release additional |bytes_acked| to achieve a slow-start-like behavior. if self.recovery_state == RecoveryState::Growth { self.recovery_window += bytes_acked; } // Sanity checks. 
Ensure that we always allow to send at least an MSS or // |bytes_acked| in response, whichever is larger. self.recovery_window = self .recovery_window .max(in_flight + bytes_acked) .max(self.min_cwnd); } /// fn check_if_full_bw_reached(&mut self, app_limited: bool) { if app_limited { return; } let target = (self.bw_at_last_round as f64 * K_STARTUP_GROWTH_TARGET as f64) as u64; let bw = self.max_bandwidth.get_estimate(); if bw >= target { self.bw_at_last_round = bw; self.round_wo_bw_gain = 0; self.ack_aggregation.max_ack_height.reset(); return; } self.round_wo_bw_gain += 1; if self.round_wo_bw_gain >= K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP as u64 || (self.recovery_state.in_recovery()) { self.is_at_full_bandwidth = true; } } } impl Controller for Bbr { fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) { self.max_sent_packet_number = last_packet_number; self.max_bandwidth.on_sent(now, bytes); } fn on_ack( &mut self, now: Instant, sent: Instant, bytes: u64, app_limited: bool, rtt: &RttEstimator, ) { self.max_bandwidth .on_ack(now, sent, bytes, self.round_count, app_limited); self.acked_bytes += bytes; if self.is_min_rtt_expired(now, app_limited) || self.min_rtt > rtt.min() { self.min_rtt = rtt.min(); } } fn on_end_acks( &mut self, now: Instant, in_flight: u64, app_limited: bool, largest_packet_num_acked: Option, ) { let bytes_acked = self.max_bandwidth.bytes_acked_this_window(); let excess_acked = self.ack_aggregation.update_ack_aggregation_bytes( bytes_acked, now, self.round_count, self.max_bandwidth.get_estimate(), ); self.max_bandwidth.end_acks(self.round_count, app_limited); if let Some(largest_acked_packet) = largest_packet_num_acked { self.max_acked_packet_number = largest_acked_packet; } let mut is_round_start = false; if bytes_acked > 0 { is_round_start = self.max_acked_packet_number > self.current_round_trip_end_packet_number; if is_round_start { self.current_round_trip_end_packet_number = self.max_sent_packet_number; self.round_count += 1; } } self.update_recovery_state(is_round_start); if self.mode == Mode::ProbeBw { self.update_gain_cycle_phase(now, in_flight); } if is_round_start && !self.is_at_full_bandwidth { self.check_if_full_bw_reached(app_limited); } self.maybe_exit_startup_or_drain(now, in_flight); self.maybe_enter_or_exit_probe_rtt(now, is_round_start, in_flight, app_limited); // After the model is updated, recalculate the pacing rate and congestion window. 
self.calculate_pacing_rate(); self.calculate_cwnd(bytes_acked, excess_acked); self.calculate_recovery_window(bytes_acked, self.loss_state.lost_bytes, in_flight); self.prev_in_flight_count = in_flight; self.loss_state.reset(); } fn on_congestion_event( &mut self, _now: Instant, _sent: Instant, _is_persistent_congestion: bool, lost_bytes: u64, ) { self.loss_state.lost_bytes += lost_bytes; } fn on_mtu_update(&mut self, new_mtu: u16) { self.current_mtu = new_mtu as u64; self.min_cwnd = calculate_min_window(self.current_mtu); self.init_cwnd = self.config.initial_window.max(self.min_cwnd); self.cwnd = self.cwnd.max(self.min_cwnd); } fn window(&self) -> u64 { if self.mode == Mode::ProbeRtt { return self.get_probe_rtt_cwnd(); } else if self.recovery_state.in_recovery() && self.mode != Mode::Startup { return self.cwnd.min(self.recovery_window); } self.cwnd } fn clone_box(&self) -> Box { Box::new(self.clone()) } fn initial_window(&self) -> u64 { self.config.initial_window } fn into_any(self: Box) -> Box { self } } /// Configuration for the [`Bbr`] congestion controller #[derive(Debug, Clone)] pub struct BbrConfig { initial_window: u64, } impl BbrConfig { /// Default limit on the amount of outstanding data in bytes. /// /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))` pub fn initial_window(&mut self, value: u64) -> &mut Self { self.initial_window = value; self } } impl Default for BbrConfig { fn default() -> Self { Self { initial_window: K_MAX_INITIAL_CONGESTION_WINDOW * BASE_DATAGRAM_SIZE, } } } impl ControllerFactory for BbrConfig { fn build(self: Arc, _now: Instant, current_mtu: u16) -> Box { Box::new(Bbr::new(self, current_mtu)) } } #[derive(Debug, Default, Copy, Clone)] struct AckAggregationState { max_ack_height: MinMax, aggregation_epoch_start_time: Option, aggregation_epoch_bytes: u64, } impl AckAggregationState { fn update_ack_aggregation_bytes( &mut self, newly_acked_bytes: u64, now: Instant, round: u64, max_bandwidth: u64, ) -> u64 { // Compute how many bytes are expected to be delivered, assuming max // bandwidth is correct. let expected_bytes_acked = max_bandwidth * now .saturating_duration_since(self.aggregation_epoch_start_time.unwrap_or(now)) .as_micros() as u64 / 1_000_000; // Reset the current aggregation epoch as soon as the ack arrival rate is // less than or equal to the max bandwidth. if self.aggregation_epoch_bytes <= expected_bytes_acked { // Reset to start measuring a new aggregation epoch. self.aggregation_epoch_bytes = newly_acked_bytes; self.aggregation_epoch_start_time = Some(now); return 0; } // Compute how many extra bytes were delivered vs max bandwidth. // Include the bytes most recently acknowledged to account for stretch acks. self.aggregation_epoch_bytes += newly_acked_bytes; let diff = self.aggregation_epoch_bytes - expected_bytes_acked; self.max_ack_height.update_max(round, diff); diff } } #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum Mode { // Startup phase of the connection. Startup, // After achieving the highest possible bandwidth during the startup, lower // the pacing rate in order to drain the queue. Drain, // Cruising mode. ProbeBw, // Temporarily slow down sending in order to empty the buffer and measure // the real minimum RTT. ProbeRtt, } // Indicates how the congestion control limits the amount of bytes in flight. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum RecoveryState { // Do not limit. NotInRecovery, // Allow an extra outstanding byte for each byte acknowledged. 
Conservation, // Allow two extra outstanding bytes for each byte acknowledged (slow // start). Growth, } impl RecoveryState { pub(super) fn in_recovery(&self) -> bool { !matches!(self, Self::NotInRecovery) } } #[derive(Debug, Clone, Default)] struct LossState { lost_bytes: u64, } impl LossState { pub(super) fn reset(&mut self) { self.lost_bytes = 0; } pub(super) fn has_losses(&self) -> bool { self.lost_bytes != 0 } } fn calculate_min_window(current_mtu: u64) -> u64 { 4 * current_mtu } // The gain used for the STARTUP, equal to 2/ln(2). const K_DEFAULT_HIGH_GAIN: f32 = 2.885; // The newly derived CWND gain for STARTUP, 2. const K_DERIVED_HIGH_CWNDGAIN: f32 = 2.0; // The cycle of gains used during the ProbeBw stage. const K_PACING_GAIN: [f32; 8] = [1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; const K_STARTUP_GROWTH_TARGET: f32 = 1.25; const K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP: u8 = 3; // Do not allow initial congestion window to be greater than 200 packets. const K_MAX_INITIAL_CONGESTION_WINDOW: u64 = 200; const PROBE_RTT_BASED_ON_BDP: bool = true; const DRAIN_TO_TARGET: bool = true; quinn-proto-0.11.9/src/congestion/cubic.rs000064400000000000000000000177311046102023000166120ustar 00000000000000use std::any::Any; use std::cmp; use std::sync::Arc; use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE}; use crate::connection::RttEstimator; use crate::{Duration, Instant}; /// CUBIC Constants. /// /// These are recommended value in RFC8312. const BETA_CUBIC: f64 = 0.7; const C: f64 = 0.4; /// CUBIC State Variables. /// /// We need to keep those variables across the connection. /// k, w_max are described in the RFC. #[derive(Debug, Default, Clone)] pub(super) struct State { k: f64, w_max: f64, // Store cwnd increment during congestion avoidance. cwnd_inc: u64, } /// CUBIC Functions. /// /// Note that these calculations are based on a count of cwnd as bytes, /// not packets. /// Unit of t (duration) and RTT are based on seconds (f64). impl State { // K = cbrt(w_max * (1 - beta_cubic) / C) (Eq. 2) fn cubic_k(&self, max_datagram_size: u64) -> f64 { let w_max = self.w_max / max_datagram_size as f64; (w_max * (1.0 - BETA_CUBIC) / C).cbrt() } // W_cubic(t) = C * (t - K)^3 - w_max (Eq. 1) fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 { let w_max = self.w_max / max_datagram_size as f64; (C * (t.as_secs_f64() - self.k).powi(3) + w_max) * max_datagram_size as f64 } // W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) * // (t / RTT) (Eq. 4) fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 { let w_max = self.w_max / max_datagram_size as f64; (w_max * BETA_CUBIC + 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC) * t.as_secs_f64() / rtt.as_secs_f64()) * max_datagram_size as f64 } } /// The RFC8312 congestion controller, as widely used for TCP #[derive(Debug, Clone)] pub struct Cubic { config: Arc, /// Maximum number of bytes in flight that may be sent. window: u64, /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is /// slow start and the window grows by the number of bytes acknowledged. ssthresh: u64, /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent /// after this time is acknowledged, QUIC exits recovery. 
recovery_start_time: Option<Instant>, cubic_state: State, current_mtu: u64, } impl Cubic { /// Construct a state using the given `config` and current time `now` pub fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self { Self { window: config.initial_window, ssthresh: u64::MAX, recovery_start_time: None, config, cubic_state: Default::default(), current_mtu: current_mtu as u64, } } fn minimum_window(&self) -> u64 { 2 * self.current_mtu } } impl Controller for Cubic { fn on_ack( &mut self, now: Instant, sent: Instant, bytes: u64, app_limited: bool, rtt: &RttEstimator, ) { if app_limited || self .recovery_start_time .map(|recovery_start_time| sent <= recovery_start_time) .unwrap_or(false) { return; } if self.window < self.ssthresh { // Slow start self.window += bytes; } else { // Congestion avoidance. let ca_start_time; match self.recovery_start_time { Some(t) => ca_start_time = t, None => { // When we get here without congestion_event() having been triggered, // initialize congestion_recovery_start_time, w_max and k. ca_start_time = now; self.recovery_start_time = Some(now); self.cubic_state.w_max = self.window as f64; self.cubic_state.k = 0.0; } } let t = now - ca_start_time; // w_cubic(t + rtt) let w_cubic = self.cubic_state.w_cubic(t + rtt.get(), self.current_mtu); // w_est(t) let w_est = self.cubic_state.w_est(t, rtt.get(), self.current_mtu); let mut cubic_cwnd = self.window; if w_cubic < w_est { // TCP-friendly region. cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64); } else if cubic_cwnd < w_cubic as u64 { // The concave and convex regions use the same increment. let cubic_inc = (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64; cubic_cwnd += cubic_inc as u64; } // Update the increment and increase cwnd by MSS. self.cubic_state.cwnd_inc += cubic_cwnd - self.window; // cwnd_inc can be more than 1 MSS in the late stage of max probing. // However, RFC 9002 §7.3.3 (Congestion Avoidance) limits // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged.
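// (Illustration: with `current_mtu` = 1200 and `cwnd_inc` accumulated to 1500,
// the window grows by exactly one MTU and the accumulator resets to zero; any
// remainder is intentionally discarded.)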
if self.cubic_state.cwnd_inc >= self.current_mtu { self.window += self.current_mtu; self.cubic_state.cwnd_inc = 0; } } } fn on_congestion_event( &mut self, now: Instant, sent: Instant, is_persistent_congestion: bool, _lost_bytes: u64, ) { if self .recovery_start_time .map(|recovery_start_time| sent <= recovery_start_time) .unwrap_or(false) { return; } self.recovery_start_time = Some(now); // Fast convergence #[allow(clippy::branches_sharing_code)] // https://github.com/rust-lang/rust-clippy/issues/7198 if (self.window as f64) < self.cubic_state.w_max { self.cubic_state.w_max = self.window as f64 * (1.0 + BETA_CUBIC) / 2.0; } else { self.cubic_state.w_max = self.window as f64; } self.ssthresh = cmp::max( (self.cubic_state.w_max * BETA_CUBIC) as u64, self.minimum_window(), ); self.window = self.ssthresh; self.cubic_state.k = self.cubic_state.cubic_k(self.current_mtu); self.cubic_state.cwnd_inc = (self.cubic_state.cwnd_inc as f64 * BETA_CUBIC) as u64; if is_persistent_congestion { self.recovery_start_time = None; self.cubic_state.w_max = self.window as f64; // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC self.ssthresh = cmp::max( (self.window as f64 * BETA_CUBIC) as u64, self.minimum_window(), ); self.cubic_state.cwnd_inc = 0; self.window = self.minimum_window(); } } fn on_mtu_update(&mut self, new_mtu: u16) { self.current_mtu = new_mtu as u64; self.window = self.window.max(self.minimum_window()); } fn window(&self) -> u64 { self.window } fn clone_box(&self) -> Box { Box::new(self.clone()) } fn initial_window(&self) -> u64 { self.config.initial_window } fn into_any(self: Box) -> Box { self } } /// Configuration for the `Cubic` congestion controller #[derive(Debug, Clone)] pub struct CubicConfig { initial_window: u64, } impl CubicConfig { /// Default limit on the amount of outstanding data in bytes. /// /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))` pub fn initial_window(&mut self, value: u64) -> &mut Self { self.initial_window = value; self } } impl Default for CubicConfig { fn default() -> Self { Self { initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE), } } } impl ControllerFactory for CubicConfig { fn build(self: Arc, now: Instant, current_mtu: u16) -> Box { Box::new(Cubic::new(self, now, current_mtu)) } } quinn-proto-0.11.9/src/congestion/new_reno.rs000064400000000000000000000116151046102023000173340ustar 00000000000000use std::any::Any; use std::sync::Arc; use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE}; use crate::connection::RttEstimator; use crate::Instant; /// A simple, standard congestion controller #[derive(Debug, Clone)] pub struct NewReno { config: Arc, current_mtu: u64, /// Maximum number of bytes in flight that may be sent. window: u64, /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is /// slow start and the window grows by the number of bytes acknowledged. ssthresh: u64, /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent /// after this time is acknowledged, QUIC exits recovery. 
recovery_start_time: Instant, /// Bytes which had been acked by the peer since leaving slow start bytes_acked: u64, } impl NewReno { /// Construct a state using the given `config` and current time `now` pub fn new(config: Arc<NewRenoConfig>, now: Instant, current_mtu: u16) -> Self { Self { window: config.initial_window, ssthresh: u64::MAX, recovery_start_time: now, current_mtu: current_mtu as u64, config, bytes_acked: 0, } } fn minimum_window(&self) -> u64 { 2 * self.current_mtu } } impl Controller for NewReno { fn on_ack( &mut self, _now: Instant, sent: Instant, bytes: u64, app_limited: bool, _rtt: &RttEstimator, ) { if app_limited || sent <= self.recovery_start_time { return; } if self.window < self.ssthresh { // Slow start self.window += bytes; if self.window >= self.ssthresh { // Exiting slow start // Initialize `bytes_acked` for congestion avoidance. The idea // here is that any bytes over `ssthresh` will already be counted // towards the congestion avoidance phase - independent of how // close to `ssthresh` the `window` was when switching states, // and independent of datagram sizes. self.bytes_acked = self.window - self.ssthresh; } } else { // Congestion avoidance // This implementation uses a method that does not require // floating-point math and increases the window by one datagram // for every round trip. // This mechanism is called Appropriate Byte Counting in // https://tools.ietf.org/html/rfc3465 self.bytes_acked += bytes; if self.bytes_acked >= self.window { self.bytes_acked -= self.window; self.window += self.current_mtu; } } } fn on_congestion_event( &mut self, now: Instant, sent: Instant, is_persistent_congestion: bool, _lost_bytes: u64, ) { if sent <= self.recovery_start_time { return; } self.recovery_start_time = now; self.window = (self.window as f32 * self.config.loss_reduction_factor) as u64; self.window = self.window.max(self.minimum_window()); self.ssthresh = self.window; if is_persistent_congestion { self.window = self.minimum_window(); } } fn on_mtu_update(&mut self, new_mtu: u16) { self.current_mtu = new_mtu as u64; self.window = self.window.max(self.minimum_window()); } fn window(&self) -> u64 { self.window } fn clone_box(&self) -> Box<dyn Controller> { Box::new(self.clone()) } fn initial_window(&self) -> u64 { self.config.initial_window } fn into_any(self: Box<Self>) -> Box<dyn Any> { self } } /// Configuration for the `NewReno` congestion controller #[derive(Debug, Clone)] pub struct NewRenoConfig { initial_window: u64, loss_reduction_factor: f32, } impl NewRenoConfig { /// Default limit on the amount of outstanding data in bytes. /// /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))` pub fn initial_window(&mut self, value: u64) -> &mut Self { self.initial_window = value; self } /// Reduction in congestion window when a new loss event is detected. pub fn loss_reduction_factor(&mut self, value: f32) -> &mut Self { self.loss_reduction_factor = value; self } } impl Default for NewRenoConfig { fn default() -> Self { Self { initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE), loss_reduction_factor: 0.5, } } } impl ControllerFactory for NewRenoConfig { fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> { Box::new(NewReno::new(self, now, current_mtu)) } } quinn-proto-0.11.9/src/congestion.rs000064400000000000000000000047501046102023000155220ustar 00000000000000//!
Logic for controlling the rate at which data is sent use crate::connection::RttEstimator; use crate::Instant; use std::any::Any; use std::sync::Arc; mod bbr; mod cubic; mod new_reno; pub use bbr::{Bbr, BbrConfig}; pub use cubic::{Cubic, CubicConfig}; pub use new_reno::{NewReno, NewRenoConfig}; /// Common interface for different congestion controllers pub trait Controller: Send + Sync { /// One or more packets were just sent #[allow(unused_variables)] fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {} /// Packet deliveries were confirmed /// /// `app_limited` indicates whether the connection was blocked on outgoing /// application data prior to receiving these acknowledgements. #[allow(unused_variables)] fn on_ack( &mut self, now: Instant, sent: Instant, bytes: u64, app_limited: bool, rtt: &RttEstimator, ) { } /// Packets are acked in batches, all with the same `now` argument. This indicates one of those batches has completed. #[allow(unused_variables)] fn on_end_acks( &mut self, now: Instant, in_flight: u64, app_limited: bool, largest_packet_num_acked: Option, ) { } /// Packets were deemed lost or marked congested /// /// `in_persistent_congestion` indicates whether all packets sent within the persistent /// congestion threshold period ending when the most recent packet in this batch was sent were /// lost. /// `lost_bytes` indicates how many bytes were lost. This value will be 0 for ECN triggers. fn on_congestion_event( &mut self, now: Instant, sent: Instant, is_persistent_congestion: bool, lost_bytes: u64, ); /// The known MTU for the current network path has been updated fn on_mtu_update(&mut self, new_mtu: u16); /// Number of ack-eliciting bytes that may be in flight fn window(&self) -> u64; /// Duplicate the controller's state fn clone_box(&self) -> Box; /// Initial congestion window fn initial_window(&self) -> u64; /// Returns Self for use in down-casting to extract implementation details fn into_any(self: Box) -> Box; } /// Constructs controllers on demand pub trait ControllerFactory { /// Construct a fresh `Controller` fn build(self: Arc, now: Instant, current_mtu: u16) -> Box; } const BASE_DATAGRAM_SIZE: u64 = 1200; quinn-proto-0.11.9/src/connection/ack_frequency.rs000064400000000000000000000136431046102023000203310ustar 00000000000000use crate::connection::spaces::PendingAcks; use crate::frame::AckFrequency; use crate::transport_parameters::TransportParameters; use crate::Duration; use crate::{AckFrequencyConfig, TransportError, VarInt, TIMER_GRANULARITY}; /// State associated to ACK frequency pub(super) struct AckFrequencyState { // // Sending ACK_FREQUENCY frames // in_flight_ack_frequency_frame: Option<(u64, Duration)>, next_outgoing_sequence_number: VarInt, pub(super) peer_max_ack_delay: Duration, // // Receiving ACK_FREQUENCY frames // last_ack_frequency_frame: Option, pub(super) max_ack_delay: Duration, } impl AckFrequencyState { pub(super) fn new(default_max_ack_delay: Duration) -> Self { Self { in_flight_ack_frequency_frame: None, next_outgoing_sequence_number: VarInt(0), peer_max_ack_delay: default_max_ack_delay, last_ack_frequency_frame: None, max_ack_delay: default_max_ack_delay, } } /// Returns the `max_ack_delay` that should be requested of the peer when sending an /// ACK_FREQUENCY frame pub(super) fn candidate_max_ack_delay( &self, rtt: Duration, config: &AckFrequencyConfig, peer_params: &TransportParameters, ) -> Duration { // Use the peer's max_ack_delay if no custom max_ack_delay was provided in the config let min_ack_delay = 
Duration::from_micros(peer_params.min_ack_delay.map_or(0, |x| x.into())); config .max_ack_delay .unwrap_or(self.peer_max_ack_delay) .clamp(min_ack_delay, rtt.max(MIN_AUTOMATIC_ACK_DELAY)) } /// Returns the `max_ack_delay` for the purposes of calculating the PTO /// /// This `max_ack_delay` is defined as the maximum of the peer's current `max_ack_delay` and all /// in-flight `max_ack_delay`s (i.e. proposed values that haven't been acknowledged yet, but /// might be already in use by the peer). pub(super) fn max_ack_delay_for_pto(&self) -> Duration { // Note: we have at most one in-flight ACK_FREQUENCY frame if let Some((_, max_ack_delay)) = self.in_flight_ack_frequency_frame { self.peer_max_ack_delay.max(max_ack_delay) } else { self.peer_max_ack_delay } } /// Returns the next sequence number for an ACK_FREQUENCY frame pub(super) fn next_sequence_number(&mut self) -> VarInt { assert!(self.next_outgoing_sequence_number <= VarInt::MAX); let seq = self.next_outgoing_sequence_number; self.next_outgoing_sequence_number.0 += 1; seq } /// Returns true if we should send an ACK_FREQUENCY frame pub(super) fn should_send_ack_frequency( &self, rtt: Duration, config: &AckFrequencyConfig, peer_params: &TransportParameters, ) -> bool { if self.next_outgoing_sequence_number.0 == 0 { // Always send at startup return true; } let current = self .in_flight_ack_frequency_frame .map_or(self.peer_max_ack_delay, |(_, pending)| pending); let desired = self.candidate_max_ack_delay(rtt, config, peer_params); let error = (desired.as_secs_f32() / current.as_secs_f32()) - 1.0; error.abs() > MAX_RTT_ERROR } /// Notifies the [`AckFrequencyState`] that a packet containing an ACK_FREQUENCY frame was sent pub(super) fn ack_frequency_sent(&mut self, pn: u64, requested_max_ack_delay: Duration) { self.in_flight_ack_frequency_frame = Some((pn, requested_max_ack_delay)); } /// Notifies the [`AckFrequencyState`] that a packet has been ACKed pub(super) fn on_acked(&mut self, pn: u64) { match self.in_flight_ack_frequency_frame { Some((number, requested_max_ack_delay)) if number == pn => { self.in_flight_ack_frequency_frame = None; self.peer_max_ack_delay = requested_max_ack_delay; } _ => {} } } /// Notifies the [`AckFrequencyState`] that an ACK_FREQUENCY frame was received /// /// Updates the endpoint's params according to the payload of the ACK_FREQUENCY frame, or /// returns an error in case the requested `max_ack_delay` is invalid. /// /// Returns `true` if the frame was processed and `false` if it was ignored because of being /// stale. pub(super) fn ack_frequency_received( &mut self, frame: &AckFrequency, pending_acks: &mut PendingAcks, ) -> Result { if self .last_ack_frequency_frame .map_or(false, |highest_sequence_nr| { frame.sequence.into_inner() <= highest_sequence_nr }) { return Ok(false); } self.last_ack_frequency_frame = Some(frame.sequence.into_inner()); // Update max_ack_delay let max_ack_delay = Duration::from_micros(frame.request_max_ack_delay.into_inner()); if max_ack_delay < TIMER_GRANULARITY { return Err(TransportError::PROTOCOL_VIOLATION( "Requested Max Ack Delay in ACK_FREQUENCY frame is less than min_ack_delay", )); } self.max_ack_delay = max_ack_delay; // Update the rest of the params pending_acks.set_ack_frequency_params(frame); Ok(true) } } /// Maximum proportion difference between the most recently requested max ACK delay and the /// currently desired one before a new request is sent, when the peer supports the ACK frequency /// extension and an explicit max ACK delay is not configured. 
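/// 
/// For example, if the most recent in-flight request proposed 25 ms and the newly
/// computed candidate is 40 ms, the relative error is 40/25 - 1 = 0.6, which
/// exceeds this threshold, so a fresh ACK_FREQUENCY frame is scheduled.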
const MAX_RTT_ERROR: f32 = 0.2; /// Minimum value to request the peer set max ACK delay to when the peer supports the ACK frequency /// extension and an explicit max ACK delay is not configured. // Keep in sync with `AckFrequencyConfig::max_ack_delay` documentation const MIN_AUTOMATIC_ACK_DELAY: Duration = Duration::from_millis(25); quinn-proto-0.11.9/src/connection/assembler.rs000064400000000000000000000560701046102023000174700ustar 00000000000000use std::{ cmp::Ordering, collections::{binary_heap::PeekMut, BinaryHeap}, mem, }; use bytes::{Buf, Bytes, BytesMut}; use crate::range_set::RangeSet; /// Helper to assemble unordered stream frames into an ordered stream #[derive(Debug, Default)] pub(super) struct Assembler { state: State, data: BinaryHeap<Buffer>, /// Total number of buffered bytes, including duplicates in ordered mode. buffered: usize, /// Estimated number of allocated bytes, will never be less than `buffered`. allocated: usize, /// Number of bytes read by the application. When only ordered reads have been used, this is the /// length of the contiguous prefix of the stream which has been consumed by the application, /// aka the stream offset. bytes_read: u64, end: u64, } impl Assembler { pub(super) fn new() -> Self { Self::default() } /// Reset to the initial state pub(super) fn reinit(&mut self) { let old_data = mem::take(&mut self.data); *self = Self::default(); self.data = old_data; self.data.clear(); } pub(super) fn ensure_ordering(&mut self, ordered: bool) -> Result<(), IllegalOrderedRead> { if ordered && !self.state.is_ordered() { return Err(IllegalOrderedRead); } else if !ordered && self.state.is_ordered() { // Enter unordered mode if !self.data.is_empty() { // Get rid of possible duplicates self.defragment(); } let mut recvd = RangeSet::new(); recvd.insert(0..self.bytes_read); for chunk in &self.data { recvd.insert(chunk.offset..chunk.offset + chunk.bytes.len() as u64); } self.state = State::Unordered { recvd }; } Ok(()) } /// Get the next chunk pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> Option<Chunk> { loop { let mut chunk = self.data.peek_mut()?; if ordered { if chunk.offset > self.bytes_read { // Next chunk is after current read index return None; } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read { // Next chunk is useless as the read index is beyond its end self.buffered -= chunk.bytes.len(); self.allocated -= chunk.allocation_size; PeekMut::pop(chunk); continue; } // Determine `start` and `len` of the slice of useful data in chunk let start = (self.bytes_read - chunk.offset) as usize; if start > 0 { chunk.bytes.advance(start); chunk.offset += start as u64; self.buffered -= start; } } return Some(if max_length < chunk.bytes.len() { self.bytes_read += max_length as u64; let offset = chunk.offset; chunk.offset += max_length as u64; self.buffered -= max_length; Chunk::new(offset, chunk.bytes.split_to(max_length)) } else { self.bytes_read += chunk.bytes.len() as u64; self.buffered -= chunk.bytes.len(); self.allocated -= chunk.allocation_size; let chunk = PeekMut::pop(chunk); Chunk::new(chunk.offset, chunk.bytes) }); } } /// Copy fragmented chunk data to new chunks backed by a single buffer /// /// This makes sure we're not unnecessarily holding on to many larger allocations. /// We merge contiguous chunks in the process of doing so.
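///
/// Sketch of the effect (editorial example, not from the original source): three
/// one-byte chunks at offsets 0, 1 and 2, each pinning its own receive buffer, are
/// copied into a single contiguous allocation covering offsets 0..3, after which
/// `allocated` equals `buffered`.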
fn defragment(&mut self) { let new = BinaryHeap::with_capacity(self.data.len()); let old = mem::replace(&mut self.data, new); let mut buffers = old.into_sorted_vec(); self.buffered = 0; let mut fragmented_buffered = 0; let mut offset = 0; for chunk in buffers.iter_mut().rev() { chunk.try_mark_defragment(offset); let size = chunk.bytes.len(); offset = chunk.offset + size as u64; self.buffered += size; if !chunk.defragmented { fragmented_buffered += size; } } self.allocated = self.buffered; let mut buffer = BytesMut::with_capacity(fragmented_buffered); let mut offset = 0; for chunk in buffers.into_iter().rev() { if chunk.defragmented { // bytes might be empty after try_mark_defragment if !chunk.bytes.is_empty() { self.data.push(chunk); } continue; } // Overlap is resolved by try_mark_defragment if chunk.offset != offset + (buffer.len() as u64) { if !buffer.is_empty() { self.data .push(Buffer::new_defragmented(offset, buffer.split().freeze())); } offset = chunk.offset; } buffer.extend_from_slice(&chunk.bytes); } if !buffer.is_empty() { self.data .push(Buffer::new_defragmented(offset, buffer.split().freeze())); } } // Note: If a packet contains many frames from the same stream, the estimated over-allocation // will be much higher because we are counting the same allocation multiple times. pub(super) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) { debug_assert!( bytes.len() <= allocation_size, "allocation_size less than bytes.len(): {:?} < {:?}", allocation_size, bytes.len() ); self.end = self.end.max(offset + bytes.len() as u64); if let State::Unordered { ref mut recvd } = self.state { // Discard duplicate data for duplicate in recvd.replace(offset..offset + bytes.len() as u64) { if duplicate.start > offset { let buffer = Buffer::new( offset, bytes.split_to((duplicate.start - offset) as usize), allocation_size, ); self.buffered += buffer.bytes.len(); self.allocated += buffer.allocation_size; self.data.push(buffer); offset = duplicate.start; } bytes.advance((duplicate.end - offset) as usize); offset = duplicate.end; } } else if offset < self.bytes_read { if (offset + bytes.len() as u64) <= self.bytes_read { return; } else { let diff = self.bytes_read - offset; offset += diff; bytes.advance(diff as usize); } } if bytes.is_empty() { return; } let buffer = Buffer::new(offset, bytes, allocation_size); self.buffered += buffer.bytes.len(); self.allocated += buffer.allocation_size; self.data.push(buffer); // `self.buffered` also counts duplicate bytes, therefore we use // `self.end - self.bytes_read` as an upper bound of buffered unique // bytes. This will cause a defragmentation if the amount of duplicate // bytes exceeds a proportion of the receive window size. let buffered = self.buffered.min((self.end - self.bytes_read) as usize); let over_allocation = self.allocated - buffered; // Rationale: on the one hand, we want to defragment rarely, ideally never // in non-pathological scenarios. However, a pathological or malicious // peer could send us one-byte frames, and since we use reference-counted // buffers in order to prevent copying, this could result in keeping a lot // of memory allocated. This limits over-allocation in proportion to the // buffered data. The constants are chosen somewhat arbitrarily and try to // balance between defragmentation overhead and over-allocation.
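// For example (illustrative numbers): with 64 KiB of unique buffered data the
// threshold below evaluates to max(32768, 96 KiB) = 96 KiB, so a flood of one-byte
// frames each pinning a full packet buffer triggers defragmentation well before
// memory use becomes pathological.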
let threshold = 32768.max(buffered * 3 / 2); if over_allocation > threshold { self.defragment() } } /// Number of bytes consumed by the application pub(super) fn bytes_read(&self) -> u64 { self.bytes_read } /// Discard all buffered data pub(super) fn clear(&mut self) { self.data.clear(); self.buffered = 0; self.allocated = 0; } } /// A chunk of data from the receive stream #[derive(Debug, PartialEq, Eq)] pub struct Chunk { /// The offset in the stream pub offset: u64, /// The contents of the chunk pub bytes: Bytes, } impl Chunk { fn new(offset: u64, bytes: Bytes) -> Self { Self { offset, bytes } } } #[derive(Debug, Eq)] struct Buffer { offset: u64, bytes: Bytes, /// Size of the allocation behind `bytes`, if `defragmented == false`. /// Otherwise this will be set to `bytes.len()` by `try_mark_defragment`. /// Will never be less than `bytes.len()`. allocation_size: usize, defragmented: bool, } impl Buffer { /// Constructs a new fragmented Buffer fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self { Self { offset, bytes, allocation_size, defragmented: false, } } /// Constructs a new defragmented Buffer fn new_defragmented(offset: u64, bytes: Bytes) -> Self { let allocation_size = bytes.len(); Self { offset, bytes, allocation_size, defragmented: true, } } /// Discards data before `offset` and flags `self` as defragmented if it has good utilization fn try_mark_defragment(&mut self, offset: u64) { let duplicate = offset.saturating_sub(self.offset) as usize; self.offset = self.offset.max(offset); if duplicate >= self.bytes.len() { // All bytes are duplicate self.bytes = Bytes::new(); self.defragmented = true; self.allocation_size = 0; return; } self.bytes.advance(duplicate); // Make sure that fragmented buffers with high utilization become defragmented and // defragmented buffers remain defragmented self.defragmented = self.defragmented || self.bytes.len() * 6 / 5 >= self.allocation_size; if self.defragmented { // Make sure that defragmented buffers do not contribute to over-allocation self.allocation_size = self.bytes.len(); } } } impl Ord for Buffer { // Invert ordering based on offset (max-heap, min offset first), // prioritize longer chunks at the same offset. fn cmp(&self, other: &Self) -> Ordering { self.offset .cmp(&other.offset) .reverse() .then(self.bytes.len().cmp(&other.bytes.len())) } } impl PartialOrd for Buffer { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) } } impl PartialEq for Buffer { fn eq(&self, other: &Self) -> bool { (self.offset, self.bytes.len()) == (other.offset, other.bytes.len()) } } #[derive(Debug)] enum State { Ordered, Unordered { /// The set of offsets that have been received from the peer, including portions not yet /// read by the application. recvd: RangeSet, }, } impl State { fn is_ordered(&self) -> bool { matches!(self, Self::Ordered) } } impl Default for State { fn default() -> Self { Self::Ordered } } /// Error indicating that an ordered read was performed on a stream after an unordered read #[derive(Debug)] pub struct IllegalOrderedRead; #[cfg(test)] mod test { use super::*; use assert_matches::assert_matches; #[test] fn assemble_ordered() { let mut x = Assembler::new(); assert_matches!(next(&mut x, 32), None); x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1"); assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23"); x.insert(3, Bytes::from_static(b"456"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..]
== b"456"); x.insert(6, Bytes::from_static(b"789"), 3); x.insert(9, Bytes::from_static(b"10"), 2); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"789"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_unordered() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); x.insert(3, Bytes::from_static(b"456"), 3); assert_matches!(next(&mut x, 32), None); x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_duplicate() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 3); x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_duplicate_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 3); x.insert(0, Bytes::from_static(b"123"), 3); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contained() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"12345"), 5); x.insert(1, Bytes::from_static(b"234"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contained_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"12345"), 5); x.insert(1, Bytes::from_static(b"234"), 3); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contains() { let mut x = Assembler::new(); x.insert(1, Bytes::from_static(b"234"), 3); x.insert(0, Bytes::from_static(b"12345"), 5); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contains_compact() { let mut x = Assembler::new(); x.insert(1, Bytes::from_static(b"234"), 3); x.insert(0, Bytes::from_static(b"12345"), 5); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_overlapping() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 3); x.insert(1, Bytes::from_static(b"234"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"4"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_overlapping_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 4); x.insert(1, Bytes::from_static(b"234"), 4); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_complex() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"1"), 1); x.insert(2, Bytes::from_static(b"3"), 1); x.insert(4, Bytes::from_static(b"5"), 1); x.insert(0, Bytes::from_static(b"123456"), 6); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"123456"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_complex_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"1"), 1); x.insert(2, Bytes::from_static(b"3"), 1); x.insert(4, Bytes::from_static(b"5"), 1); x.insert(0, Bytes::from_static(b"123456"), 6); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_old() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"1234"), 4); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234"); x.insert(0, Bytes::from_static(b"1234"), 4); assert_matches!(next(&mut x, 32), None); } #[test] fn compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"abc"), 4); x.insert(3, Bytes::from_static(b"def"), 4); x.insert(9, Bytes::from_static(b"jkl"), 4); x.insert(12, Bytes::from_static(b"mno"), 4); x.defragment(); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"abcdef")) ); assert_eq!( next_unordered(&mut x), Chunk::new(9, Bytes::from_static(b"jklmno")) ); } #[test] fn defrag_with_missing_prefix() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def"), 3); x.defragment(); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); } #[test] fn defrag_read_chunk() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def"), 4); x.insert(0, Bytes::from_static(b"abc"), 4); x.insert(7, Bytes::from_static(b"hij"), 4); x.insert(11, Bytes::from_static(b"lmn"), 4); x.defragment(); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef"); x.insert(5, Bytes::from_static(b"fghijklmn"), 9); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn"); x.insert(13, Bytes::from_static(b"nopq"), 4); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq"); x.insert(15, Bytes::from_static(b"pqrs"), 4); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] 
== b"rs"); assert_matches!(x.read(usize::MAX, true), None); } #[test] fn unordered_happy_path() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); x.insert(0, Bytes::from_static(b"abc"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"abc")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); assert_eq!(x.read(usize::MAX, false), None); } #[test] fn unordered_dedup() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(0, Bytes::from_static(b"a"), 1); x.insert(0, Bytes::from_static(b"abcdefghi"), 9); x.insert(0, Bytes::from_static(b"abcd"), 4); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"a")) ); assert_eq!( next_unordered(&mut x), Chunk::new(1, Bytes::from_static(b"bc")) ); assert_eq!( next_unordered(&mut x), Chunk::new(6, Bytes::from_static(b"ghi")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(8, Bytes::from_static(b"ijkl"), 4); assert_eq!( next_unordered(&mut x), Chunk::new(9, Bytes::from_static(b"jkl")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(12, Bytes::from_static(b"mno"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(12, Bytes::from_static(b"mno")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(2, Bytes::from_static(b"cde"), 3); assert_eq!(x.read(usize::MAX, false), None); } #[test] fn chunks_dedup() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!(x.read(usize::MAX, true), None); x.insert(0, Bytes::from_static(b"a"), 1); x.insert(1, Bytes::from_static(b"bcdefghi"), 9); x.insert(0, Bytes::from_static(b"abcd"), 4); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abcd"))) ); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(4, Bytes::from_static(b"efghi"))) ); assert_eq!(x.read(usize::MAX, true), None); x.insert(8, Bytes::from_static(b"ijkl"), 4); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(9, Bytes::from_static(b"jkl"))) ); assert_eq!(x.read(usize::MAX, true), None); x.insert(12, Bytes::from_static(b"mno"), 3); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(12, Bytes::from_static(b"mno"))) ); assert_eq!(x.read(usize::MAX, true), None); x.insert(2, Bytes::from_static(b"cde"), 3); assert_eq!(x.read(usize::MAX, true), None); } #[test] fn ordered_eager_discard() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"abc"), 3); assert_eq!(x.data.len(), 1); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abc"))) ); x.insert(0, Bytes::from_static(b"ab"), 2); assert_eq!(x.data.len(), 0); x.insert(2, Bytes::from_static(b"cd"), 2); assert_eq!( x.data.peek(), Some(&Buffer::new(3, Bytes::from_static(b"d"), 2)) ); } #[test] fn ordered_insert_unordered_read() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"abc"), 3); x.insert(0, Bytes::from_static(b"abc"), 3); x.ensure_ordering(false).unwrap(); assert_eq!( x.read(3, false), Some(Chunk::new(0, Bytes::from_static(b"abc"))) ); assert_eq!(x.read(3, false), None); } fn next_unordered(x: &mut Assembler) -> Chunk { x.read(usize::MAX, false).unwrap() } fn next(x: &mut Assembler, size: usize) -> Option { x.read(size, true).map(|chunk| chunk.bytes) } } 
quinn-proto-0.11.9/src/connection/cid_state.rs000064400000000000000000000205051046102023000174440ustar 00000000000000//! Maintain the state of local connection IDs use std::collections::VecDeque; use rustc_hash::FxHashSet; use tracing::{debug, trace}; use crate::{shared::IssuedCid, Duration, Instant, TransportError}; /// Local connection ID management pub(super) struct CidState { /// Timestamp when issued cids should be retired retire_timestamp: VecDeque<CidTimestamp>, /// Number of local connection IDs that have been issued in NEW_CONNECTION_ID frames. issued: u64, /// Sequence numbers of local connection IDs not yet retired by the peer active_seq: FxHashSet<u64>, /// Sequence number below which the peer has already retired all CIDs at our request via `retire_prior_to` prev_retire_seq: u64, /// Sequence number to set in retire_prior_to field in NEW_CONNECTION_ID frame retire_seq: u64, /// cid length used to decode short packet cid_len: usize, /// cid lifetime cid_lifetime: Option<Duration>, } impl CidState { pub(crate) fn new( cid_len: usize, cid_lifetime: Option<Duration>, now: Instant, issued: u64, ) -> Self { let mut active_seq = FxHashSet::default(); // Add sequence number of CIDs used in handshaking into tracking set for seq in 0..issued { active_seq.insert(seq); } let mut this = Self { retire_timestamp: VecDeque::new(), issued, active_seq, prev_retire_seq: 0, retire_seq: 0, cid_len, cid_lifetime, }; // Track lifetime of CIDs used in handshaking for seq in 0..issued { this.track_lifetime(seq, now); } this } /// Find the next timestamp when previously issued CID should be retired pub(crate) fn next_timeout(&mut self) -> Option<Instant> { self.retire_timestamp.front().map(|nc| { trace!("CID {} will expire at {:?}", nc.sequence, nc.timestamp); nc.timestamp }) } /// Track the lifetime of issued cids in `retire_timestamp` fn track_lifetime(&mut self, new_cid_seq: u64, now: Instant) { let lifetime = match self.cid_lifetime { Some(lifetime) => lifetime, None => return, }; let expire_timestamp = now.checked_add(lifetime); let expire_at = match expire_timestamp { Some(expire_at) => expire_at, None => return, }; let last_record = self.retire_timestamp.back_mut(); if let Some(last) = last_record { // Compare the timestamp with the last inserted record // Combine into a single batch if timestamp of current cid is same as the last record if expire_at == last.timestamp { debug_assert!(new_cid_seq > last.sequence); last.sequence = new_cid_seq; return; } } self.retire_timestamp.push_back(CidTimestamp { sequence: new_cid_seq, timestamp: expire_at, }); } /// Update local CID state when previously issued CID is retired /// /// Returns whether a new CID needs to be pushed that notifies the remote peer to respond with `RETIRE_CONNECTION_ID` pub(crate) fn on_cid_timeout(&mut self) -> bool { // Whether the peer hasn't retired all the CIDs we asked it to yet let unretired_ids_found = (self.prev_retire_seq..self.retire_seq).any(|seq| self.active_seq.contains(&seq)); let current_retire_prior_to = self.retire_seq; let next_retire_sequence = self .retire_timestamp .pop_front() .map(|seq| seq.sequence + 1); // According to RFC: // Endpoints SHOULD NOT issue updates of the Retire Prior To field // before receiving RETIRE_CONNECTION_ID frames that retire all // connection IDs indicated by the previous Retire Prior To value.
// https://tools.ietf.org/html/draft-ietf-quic-transport-29#section-5.1.2 if !unretired_ids_found { // All Cids are retired, `prev_retire_cid_seq` can be assigned to `retire_cid_seq` self.prev_retire_seq = self.retire_seq; // Advance `retire_seq` if next cid that needs to be retired exists if let Some(next_retire_prior_to) = next_retire_sequence { self.retire_seq = next_retire_prior_to; } } // Check if retirement of all CIDs that reach their lifetime is still needed // According to RFC: // An endpoint MUST NOT // provide more connection IDs than the peer's limit. An endpoint MAY // send connection IDs that temporarily exceed a peer's limit if the // NEW_CONNECTION_ID frame also requires the retirement of any excess, // by including a sufficiently large value in the Retire Prior To field. // // If yes (return true), a new CID must be pushed with updated `retire_prior_to` field to remote peer. // If no (return false), it means CIDs that reach the end of lifetime have been retired already. Do not push a new CID in order to avoid violating the above RFC. (current_retire_prior_to..self.retire_seq).any(|seq| self.active_seq.contains(&seq)) } /// Update cid state when `NewIdentifiers` event is received pub(crate) fn new_cids(&mut self, ids: &[IssuedCid], now: Instant) { // `ids.last()` could be `None` once active_connection_id_limit is set to 1 by the peer let last_cid = match ids.last() { Some(cid) => cid, None => return, }; self.issued += ids.len() as u64; // Record the timestamp of CID with the largest seq number let sequence = last_cid.sequence; ids.iter().for_each(|frame| { self.active_seq.insert(frame.sequence); }); self.track_lifetime(sequence, now); } /// Update CidState for receipt of a `RETIRE_CONNECTION_ID` frame /// /// Returns whether a new CID can be issued, or an error if the frame was illegal. pub(crate) fn on_cid_retirement( &mut self, sequence: u64, limit: u64, ) -> Result<bool, TransportError> { if self.cid_len == 0 { return Err(TransportError::PROTOCOL_VIOLATION( "RETIRE_CONNECTION_ID when CIDs aren't in use", )); } if sequence > self.issued { debug!( sequence, "got RETIRE_CONNECTION_ID for unissued sequence number" ); return Err(TransportError::PROTOCOL_VIOLATION( "RETIRE_CONNECTION_ID for unissued sequence number", )); } self.active_seq.remove(&sequence); // Consider a scenario where peer A has active remote cid 0,1,2. // Peer B first sends a NEW_CONNECTION_ID with cid 3 and retire_prior_to set to 1. // Peer A processes this NEW_CONNECTION_ID frame; updates its remote cids to 1,2,3 // and meanwhile sends a RETIRE_CONNECTION_ID to retire cid 0 to peer B.
// If peer B doesn't check the cid limit here and sends a new cid again, peer A will then face CONNECTION_ID_LIMIT_ERROR Ok(limit > self.active_seq.len() as u64) } /// Length of local Connection IDs pub(crate) fn cid_len(&self) -> usize { self.cid_len } /// The value for `retire_prior_to` field in `NEW_CONNECTION_ID` frame pub(crate) fn retire_prior_to(&self) -> u64 { self.retire_seq } #[cfg(test)] pub(crate) fn active_seq(&self) -> (u64, u64) { let mut min = u64::MAX; let mut max = u64::MIN; for n in self.active_seq.iter() { if n < &min { min = *n; } if n > &max { max = *n; } } (min, max) } #[cfg(test)] pub(crate) fn assign_retire_seq(&mut self, v: u64) -> u64 { // Cannot retire more CIDs than what have been issued debug_assert!(v <= *self.active_seq.iter().max().unwrap() + 1); let n = v.checked_sub(self.retire_seq).unwrap(); self.retire_seq = v; n } } /// Data structure that records when issued cids should be retired #[derive(Copy, Clone, Eq, PartialEq)] struct CidTimestamp { /// Highest cid sequence number created in a batch sequence: u64, /// Timestamp when cid needs to be retired timestamp: Instant, } quinn-proto-0.11.9/src/connection/datagrams.rs000064400000000000000000000170561046102023000174550ustar 00000000000000use std::collections::VecDeque; use bytes::Bytes; use thiserror::Error; use tracing::{debug, trace}; use super::Connection; use crate::{ frame::{Datagram, FrameStruct}, TransportError, }; /// API to control datagram traffic pub struct Datagrams<'a> { pub(super) conn: &'a mut Connection, } impl Datagrams<'_> { /// Queue an unreliable, unordered datagram for immediate transmission /// /// If `drop` is true, previously queued datagrams which are still unsent may be discarded to /// make space for this datagram, in order of oldest to newest. If `drop` is false, and there /// isn't enough space due to previously queued datagrams, this function will return /// `SendDatagramError::Blocked`. `Event::DatagramsUnblocked` will be emitted once datagrams /// have been sent. /// /// Returns `Err` iff a `len`-byte datagram cannot currently be sent. pub fn send(&mut self, data: Bytes, drop: bool) -> Result<(), SendDatagramError> { if self.conn.config.datagram_receive_buffer_size.is_none() { return Err(SendDatagramError::Disabled); } let max = self .max_size() .ok_or(SendDatagramError::UnsupportedByPeer)?; if data.len() > max { return Err(SendDatagramError::TooLarge); } if drop { while self.conn.datagrams.outgoing_total > self.conn.config.datagram_send_buffer_size { let prev = self .conn .datagrams .outgoing .pop_front() .expect("datagrams.outgoing_total desynchronized"); trace!(len = prev.data.len(), "dropping outgoing datagram"); self.conn.datagrams.outgoing_total -= prev.data.len(); } } else if self.conn.datagrams.outgoing_total + data.len() > self.conn.config.datagram_send_buffer_size { self.conn.datagrams.send_blocked = true; return Err(SendDatagramError::Blocked(data)); } self.conn.datagrams.outgoing_total += data.len(); self.conn.datagrams.outgoing.push_back(Datagram { data }); Ok(()) } /// Compute the maximum size of datagrams that may be passed to `send_datagram` /// /// Returns `None` if datagrams are unsupported by the peer or disabled locally. /// /// This may change over the lifetime of a connection according to variation in the path MTU /// estimate. The peer can also enforce an arbitrarily small fixed limit, but if the peer's /// limit is large this is guaranteed to be a little over a kilobyte at minimum. /// /// Not necessarily the maximum size of received datagrams.
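///
/// Usage sketch (editorial example, assuming a `Connection` named `conn`; the
/// `ignore` marker reflects that this is illustrative, not a doctest):
/// ```ignore
/// let mut dgrams = conn.datagrams();
/// if let Some(max) = dgrams.max_size() {
///     let payload = bytes::Bytes::from_static(b"hello");
///     // `send` would fail with `SendDatagramError::TooLarge` past `max`
///     if payload.len() <= max {
///         let _ = dgrams.send(payload, false);
///     }
/// }
/// ```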
pub fn max_size(&self) -> Option<usize> { // We use the conservative overhead bound for any packet number, reducing the budget by at // most 3 bytes, so that PN size fluctuations don't cause users sending maximum-size // datagrams to suffer avoidable packet loss. let max_size = self.conn.path.current_mtu() as usize - self.conn.predict_1rtt_overhead(None) - Datagram::SIZE_BOUND; let limit = self .conn .peer_params .max_datagram_frame_size? .into_inner() .saturating_sub(Datagram::SIZE_BOUND as u64); Some(limit.min(max_size as u64) as usize) } /// Receive an unreliable, unordered datagram pub fn recv(&mut self) -> Option<Bytes> { self.conn.datagrams.recv() } /// Bytes available in the outgoing datagram buffer /// /// When greater than zero, [`send`](Self::send)ing a datagram of at most this size is /// guaranteed not to cause older datagrams to be dropped. pub fn send_buffer_space(&self) -> usize { self.conn .config .datagram_send_buffer_size .saturating_sub(self.conn.datagrams.outgoing_total) } } #[derive(Default)] pub(super) struct DatagramState { /// Number of bytes of datagrams that have been received by the local transport but not /// delivered to the application pub(super) recv_buffered: usize, pub(super) incoming: VecDeque<Datagram>, pub(super) outgoing: VecDeque<Datagram>, pub(super) outgoing_total: usize, pub(super) send_blocked: bool, } impl DatagramState { pub(super) fn received( &mut self, datagram: Datagram, window: &Option<usize>, ) -> Result<bool, TransportError> { let window = match window { None => { return Err(TransportError::PROTOCOL_VIOLATION( "unexpected DATAGRAM frame", )); } Some(x) => *x, }; if datagram.data.len() > window { return Err(TransportError::PROTOCOL_VIOLATION("oversized datagram")); } let was_empty = self.recv_buffered == 0; while datagram.data.len() + self.recv_buffered > window { debug!("dropping stale datagram"); self.recv(); } self.recv_buffered += datagram.data.len(); self.incoming.push_back(datagram); Ok(was_empty) } /// Discard outgoing datagrams with a payload larger than `max_payload` bytes /// /// Used to ensure that reductions in MTU don't get us stuck in a state where we have a datagram /// queued but can't send it. pub(super) fn drop_oversized(&mut self, max_payload: usize) { self.outgoing.retain(|datagram| { let result = datagram.data.len() < max_payload; if !result { trace!( "dropping {} byte datagram violating {} byte limit", datagram.data.len(), max_payload ); self.outgoing_total -= datagram.data.len(); } result }); } /// Attempt to write a datagram frame into `buf`, consuming it from `self.outgoing` /// /// Returns whether a frame was written. At most `max_size` bytes will be written, including /// framing.
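///
/// Note that the `true` argument to `size`/`encode` below presumably selects the
/// length-prefixed DATAGRAM encoding, so the size check covers the full encoded
/// frame (type byte and length prefix), not just the payload.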
pub(super) fn write(&mut self, buf: &mut Vec<u8>, max_size: usize) -> bool { let datagram = match self.outgoing.pop_front() { Some(x) => x, None => return false, }; if buf.len() + datagram.size(true) > max_size { // Future work: we could be more clever about cramming small datagrams into // mostly-full packets when a larger one is queued first self.outgoing.push_front(datagram); return false; } trace!(len = datagram.data.len(), "DATAGRAM"); self.outgoing_total -= datagram.data.len(); datagram.encode(true, buf); true } pub(super) fn recv(&mut self) -> Option<Bytes> { let x = self.incoming.pop_front()?.data; self.recv_buffered -= x.len(); Some(x) } } /// Errors that can arise when sending a datagram #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum SendDatagramError { /// The peer does not support receiving datagram frames #[error("datagrams not supported by peer")] UnsupportedByPeer, /// Datagram support is disabled locally #[error("datagram support disabled")] Disabled, /// The datagram is larger than the connection can currently accommodate /// /// Indicates that the path MTU minus overhead or the limit advertised by the peer has been /// exceeded. #[error("datagram too large")] TooLarge, /// Send would block #[error("datagram send blocked")] Blocked(Bytes), } quinn-proto-0.11.9/src/connection/mod.rs000064400000000000000000004550431046102023000162750ustar 00000000000000use std::{ cmp, collections::VecDeque, convert::TryFrom, fmt, io, mem, net::{IpAddr, SocketAddr}, sync::Arc, }; use bytes::{Bytes, BytesMut}; use frame::StreamMetaVec; use rand::{rngs::StdRng, Rng, SeedableRng}; use thiserror::Error; use tracing::{debug, error, trace, trace_span, warn}; use crate::{ cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, frame::{self, Close, Datagram, FrameStruct}, packet::{ FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, }, range_set::ArrayRangeSet, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, EndpointEvent, EndpointEventInner, }, token::ResetToken, transport_parameters::TransportParameters, Dir, Duration, EndpointConfig, Frame, Instant, Side, StreamId, Transmit, TransportError, TransportErrorCode, VarInt, MAX_CID_SIZE, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, TIMER_GRANULARITY, }; mod ack_frequency; use ack_frequency::AckFrequencyState; mod assembler; pub use assembler::Chunk; mod cid_state; use cid_state::CidState; mod datagrams; use datagrams::DatagramState; pub use datagrams::{Datagrams, SendDatagramError}; mod mtud; mod pacing; mod packet_builder; use packet_builder::PacketBuilder; mod packet_crypto; use packet_crypto::{PrevCrypto, ZeroRttCrypto}; mod paths; pub use paths::RttEstimator; use paths::{PathData, PathResponses}; mod send_buffer; mod spaces; #[cfg(fuzzing)] pub use spaces::Retransmits; #[cfg(not(fuzzing))] use spaces::Retransmits; use spaces::{PacketNumberFilter, PacketSpace, SendableFrames, SentPacket, ThinRetransmits}; mod stats; pub use stats::{ConnectionStats, FrameStats, PathStats, UdpStats}; mod streams; #[cfg(fuzzing)] pub use streams::StreamsState; #[cfg(not(fuzzing))] use streams::StreamsState; pub use streams::{ BytesSource, Chunks, ClosedStream, FinishError, ReadError, ReadableError, RecvStream, SendStream, ShouldTransmit, StreamEvent, Streams, WriteError, Written, }; mod timer; use crate::congestion::Controller; use
timer::{Timer, TimerTable}; /// Protocol state and logic for a single QUIC connection /// /// Objects of this type receive [`ConnectionEvent`]s and emit [`EndpointEvent`]s and application /// [`Event`]s to make progress. To handle timeouts, a `Connection` returns timer updates and /// expects timeouts through various methods. A number of simple getter methods are exposed /// to allow callers to inspect some of the connection state. /// /// `Connection` has roughly 4 types of methods: /// /// - A. Simple getters, taking `&self` /// - B. Handlers for incoming events from the network or system, named `handle_*`. /// - C. State machine mutators, for incoming commands from the application. For convenience we /// refer to this as "performing I/O" below, however as per the design of this library none of the /// functions actually perform system-level I/O. For example, [`read`](RecvStream::read) and /// [`write`](SendStream::write), but also things like [`reset`](SendStream::reset). /// - D. Polling functions for outgoing events or actions for the caller to /// take, named `poll_*`. /// /// The simplest way to use this API correctly is to call (B) and (C) whenever /// appropriate, then after each of those calls, as soon as feasible call all /// polling methods (D) and deal with their outputs appropriately, e.g. by /// passing it to the application or by making a system-level I/O call. You /// should call the polling functions in this order: /// /// 1. [`poll_transmit`](Self::poll_transmit) /// 2. [`poll_timeout`](Self::poll_timeout) /// 3. [`poll_endpoint_events`](Self::poll_endpoint_events) /// 4. [`poll`](Self::poll) /// /// Currently the only actual dependency is from (2) to (1), however additional /// dependencies may be added in future, so the above order is recommended. /// /// (A) may be called whenever desired. /// /// Care should be taken to ensure that the input events represent monotonically /// increasing time. Specifically, calling [`handle_timeout`](Self::handle_timeout) /// with events of the same [`Instant`] may be interleaved in any order with a /// call to [`handle_event`](Self::handle_event) at that same instant; however /// events or timeouts with different instants must not be interleaved. pub struct Connection { endpoint_config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>, config: Arc<TransportConfig>, rng: StdRng, crypto: Box<dyn crypto::Session>, /// The CID we initially chose, for use during the handshake handshake_cid: ConnectionId, /// The CID the peer initially chose, for use during the handshake rem_handshake_cid: ConnectionId, /// The "real" local IP address which was used to receive the initial packet. /// This is only populated for the server case, and if known local_ip: Option<IpAddr>, path: PathData, /// Whether MTU detection is supported in this environment allow_mtud: bool, prev_path: Option<(ConnectionId, PathData)>, state: State, side: Side, /// Whether or not 0-RTT was enabled during the handshake. Does not imply acceptance. zero_rtt_enabled: bool, /// Set if 0-RTT is supported, then cleared when no longer needed. zero_rtt_crypto: Option<ZeroRttCrypto>, key_phase: bool, /// How many packets are in the current key phase. Used only for `Data` space.
key_phase_size: u64, /// Transport parameters set by the peer peer_params: TransportParameters, /// Source ConnectionId of the first packet received from the peer orig_rem_cid: ConnectionId, /// Destination ConnectionId sent by the client on the first Initial initial_dst_cid: ConnectionId, /// The value that the server included in the Source Connection ID field of a Retry packet, if /// one was received retry_src_cid: Option<ConnectionId>, /// Total number of outgoing packets that have been deemed lost lost_packets: u64, events: VecDeque<Event>, endpoint_events: VecDeque<EndpointEventInner>, /// Whether the spin bit is in use for this connection spin_enabled: bool, /// Outgoing spin bit state spin: bool, /// Packet number spaces: initial, handshake, 1-RTT spaces: [PacketSpace; 3], /// Highest usable packet number space highest_space: SpaceId, /// 1-RTT keys used prior to a key update prev_crypto: Option<PrevCrypto>, /// 1-RTT keys to be used for the next key update /// /// These are generated in advance to prevent timing attacks and/or DoS by third-party attackers /// spoofing key updates. next_crypto: Option<KeyPair<Box<dyn PacketKey>>>, accepted_0rtt: bool, /// Whether the idle timer should be reset the next time an ack-eliciting packet is transmitted. permit_idle_reset: bool, /// Negotiated idle timeout idle_timeout: Option<Duration>, timers: TimerTable, /// Number of packets received which could not be authenticated authentication_failures: u64, /// Why the connection was lost, if it has been error: Option<ConnectionError>, /// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are /// discarded. retry_token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, // // Queued non-retransmittable 1-RTT data // /// Responses to PATH_CHALLENGE frames path_responses: PathResponses, close: bool, // // ACK frequency // ack_frequency: AckFrequencyState, // // Loss Detection // /// The number of times a PTO has been sent without receiving an ack. pto_count: u32, // // Congestion Control // /// Whether the most recently received packet had an ECN codepoint set receiving_ecn: bool, /// Number of packets authenticated total_authed_packets: u64, /// Whether the last `poll_transmit` call yielded no data because there was /// no outgoing application data. app_limited: bool, streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, // Attributes of CIDs generated by the local peer local_cid_state: CidState, /// State of the unreliable datagram extension datagrams: DatagramState, /// Connection level statistics stats: ConnectionStats, /// QUIC version used for the connection.
version: u32, } impl Connection { pub(crate) fn new( endpoint_config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>, config: Arc<TransportConfig>, init_cid: ConnectionId, loc_cid: ConnectionId, rem_cid: ConnectionId, pref_addr_cid: Option<ConnectionId>, remote: SocketAddr, local_ip: Option<IpAddr>, crypto: Box<dyn crypto::Session>, cid_gen: &dyn ConnectionIdGenerator, now: Instant, version: u32, allow_mtud: bool, rng_seed: [u8; 32], path_validated: bool, ) -> Self { let side = if server_config.is_some() { Side::Server } else { Side::Client }; let initial_space = PacketSpace { crypto: Some(crypto.initial_keys(&init_cid, side)), ..PacketSpace::new(now) }; let state = State::Handshake(state::Handshake { rem_cid_set: side.is_server(), expected_token: Bytes::new(), client_hello: None, }); let mut rng = StdRng::from_seed(rng_seed); let mut this = Self { endpoint_config, server_config, crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, local_cid_state: CidState::new( cid_gen.cid_len(), cid_gen.cid_lifetime(), now, if pref_addr_cid.is_some() { 2 } else { 1 }, ), path: PathData::new(remote, allow_mtud, None, now, path_validated, &config), allow_mtud, local_ip, prev_path: None, side, state, zero_rtt_enabled: false, zero_rtt_crypto: None, key_phase: false, // A small initial key phase size ensures peers that don't handle key updates correctly // fail sooner rather than later. It's okay for both peers to do this, as the first one // to perform an update will reset the other's key phase size in `update_keys`, and a // simultaneous key update by both is just like a regular key update with a really fast // response. Inspired by quic-go's similar behavior of performing the first key update // at the 100th short-header packet. key_phase_size: rng.gen_range(10..1000), peer_params: TransportParameters::default(), orig_rem_cid: rem_cid, initial_dst_cid: init_cid, retry_src_cid: None, lost_packets: 0, events: VecDeque::new(), endpoint_events: VecDeque::new(), spin_enabled: config.allow_spin && rng.gen_ratio(7, 8), spin: false, spaces: [initial_space, PacketSpace::new(now), PacketSpace::new(now)], highest_space: SpaceId::Initial, prev_crypto: None, next_crypto: None, accepted_0rtt: false, permit_idle_reset: true, idle_timeout: match config.max_idle_timeout { None | Some(VarInt(0)) => None, Some(dur) => Some(Duration::from_millis(dur.0)), }, timers: TimerTable::default(), authentication_failures: 0, error: None, retry_token: Bytes::new(), #[cfg(test)] packet_number_filter: match config.deterministic_packet_numbers { false => PacketNumberFilter::new(&mut rng), true => PacketNumberFilter::disabled(), }, #[cfg(not(test))] packet_number_filter: PacketNumberFilter::new(&mut rng), path_responses: PathResponses::default(), close: false, ack_frequency: AckFrequencyState::new(get_max_ack_delay( &TransportParameters::default(), )), pto_count: 0, app_limited: false, receiving_ecn: false, total_authed_packets: 0, streams: StreamsState::new( side, config.max_concurrent_uni_streams, config.max_concurrent_bidi_streams, config.send_window, config.receive_window, config.stream_receive_window, ), datagrams: DatagramState::default(), config, rem_cids: CidQueue::new(rem_cid), rng, stats: ConnectionStats::default(), version, }; if side.is_client() { // Kick off the connection this.write_crypto(); this.init_0rtt(); } this } /// Returns the next time at which `handle_timeout` should be called /// /// The value returned may change after: /// - the application performed some I/O on the connection /// - a call was made to `handle_event` /// - a call to `poll_transmit` returned `Some` /// - a call was made
to `handle_timeout` #[must_use] pub fn poll_timeout(&mut self) -> Option<Instant> { self.timers.next_timeout() } /// Returns application-facing events /// /// Connections should be polled for events after: /// - a call was made to `handle_event` /// - a call was made to `handle_timeout` #[must_use] pub fn poll(&mut self) -> Option<Event> { if let Some(x) = self.events.pop_front() { return Some(x); } if let Some(event) = self.streams.poll() { return Some(Event::Stream(event)); } if let Some(err) = self.error.take() { return Some(Event::ConnectionLost { reason: err }); } None } /// Return endpoint-facing events #[must_use] pub fn poll_endpoint_events(&mut self) -> Option<EndpointEvent> { self.endpoint_events.pop_front().map(EndpointEvent) } /// Provide control over streams #[must_use] pub fn streams(&mut self) -> Streams<'_> { Streams { state: &mut self.streams, conn_state: &self.state, } } /// Provide control over a receive stream #[must_use] pub fn recv_stream(&mut self, id: StreamId) -> RecvStream<'_> { assert!(id.dir() == Dir::Bi || id.initiator() != self.side); RecvStream { id, state: &mut self.streams, pending: &mut self.spaces[SpaceId::Data].pending, } } /// Provide control over a send stream #[must_use] pub fn send_stream(&mut self, id: StreamId) -> SendStream<'_> { assert!(id.dir() == Dir::Bi || id.initiator() == self.side); SendStream { id, state: &mut self.streams, pending: &mut self.spaces[SpaceId::Data].pending, conn_state: &self.state, } } /// Returns packets to transmit /// /// Connections should be polled for transmit after: /// - the application performed some I/O on the connection /// - a call was made to `handle_event` /// - a call was made to `handle_timeout` /// /// `max_datagrams` specifies how many datagrams can be returned inside a /// single Transmit using GSO. This must be at least 1. #[must_use] pub fn poll_transmit( &mut self, now: Instant, max_datagrams: usize, buf: &mut Vec<u8>, ) -> Option<Transmit> { assert!(max_datagrams != 0); let max_datagrams = match self.config.enable_segmentation_offload { false => 1, true => max_datagrams.min(MAX_TRANSMIT_SEGMENTS), }; let mut num_datagrams = 0; // Position in `buf` of the first byte of the current UDP datagram. When coalescing QUIC // packets, this can be earlier than the start of the current QUIC packet. let mut datagram_start = 0; let mut segment_size = usize::from(self.path.current_mtu()); // Send PATH_CHALLENGE for a previous path if necessary if let Some((prev_cid, ref mut prev_path)) = self.prev_path { if prev_path.challenge_pending { prev_path.challenge_pending = false; let token = prev_path .challenge .expect("previous path challenge pending without token"); let destination = prev_path.remote; debug_assert_eq!( self.highest_space, SpaceId::Data, "PATH_CHALLENGE queued without 1-RTT keys" ); buf.reserve(MIN_INITIAL_SIZE as usize); let buf_capacity = buf.capacity(); // Use the previous CID to avoid linking the new path with the previous path. We // don't bother accounting for possible retirement of that prev_cid because this is // sent once, immediately after migration, when the CID is known to be valid. Even // if a post-migration packet caused the CID to be retired, it's fair to pretend // this is sent first.
let mut builder = PacketBuilder::new( now, SpaceId::Data, prev_cid, buf, buf_capacity, 0, false, self, )?; trace!("validating previous path with PATH_CHALLENGE {:08x}", token); buf.write(frame::FrameType::PATH_CHALLENGE); buf.write(token); self.stats.frame_tx.path_challenge += 1; // An endpoint MUST expand datagrams that contain a PATH_CHALLENGE frame // to at least the smallest allowed maximum datagram size of 1200 bytes, // unless the anti-amplification limit for the path does not permit // sending a datagram of this size builder.pad_to(MIN_INITIAL_SIZE); builder.finish(self, buf); self.stats.udp_tx.on_sent(1, buf.len()); return Some(Transmit { destination, size: buf.len(), ecn: None, segment_size: None, src_ip: self.local_ip, }); } } // If we need to send a probe, make sure we have something to send. for space in SpaceId::iter() { let request_immediate_ack = space == SpaceId::Data && self.peer_supports_ack_frequency(); self.spaces[space].maybe_queue_probe(request_immediate_ack, &self.streams); } // Check whether we need to send a close message let close = match self.state { State::Drained => { self.app_limited = true; return None; } State::Draining | State::Closed(_) => { // self.close is only reset once the associated packet had been // encoded successfully if !self.close { self.app_limited = true; return None; } true } _ => false, }; // Check whether we need to send an ACK_FREQUENCY frame if let Some(config) = &self.config.ack_frequency_config { self.spaces[SpaceId::Data].pending.ack_frequency = self .ack_frequency .should_send_ack_frequency(self.path.rtt.get(), config, &self.peer_params) && self.highest_space == SpaceId::Data && self.peer_supports_ack_frequency(); } // Reserving capacity can provide more capacity than we asked for. However, we are not // allowed to write more than `segment_size`. Therefore the maximum capacity is tracked // separately. let mut buf_capacity = 0; let mut coalesce = true; let mut builder_storage: Option<PacketBuilder> = None; let mut sent_frames = None; let mut pad_datagram = false; let mut congestion_blocked = false; // Iterate over all spaces and find data to send let mut space_idx = 0; let spaces = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data]; // This loop will potentially spend multiple iterations in the same `SpaceId`, // so we cannot trivially rewrite it to take advantage of `SpaceId::iter()`. while space_idx < spaces.len() { let space_id = spaces[space_idx]; // Number of bytes available for frames if this is a 1-RTT packet. We're guaranteed to // be able to send an individual frame at least this large in the next 1-RTT // packet. This could be generalized to support every space, but it's only needed to // handle large fixed-size frames, which only exist in 1-RTT (application datagrams). We // don't account for coalesced packets potentially occupying space because frames can // always spill into the next datagram. let pn = self.packet_number_filter.peek(&self.spaces[SpaceId::Data]); let frame_space_1rtt = segment_size.saturating_sub(self.predict_1rtt_overhead(Some(pn))); // Is there data or a close message to send in this space?
let can_send = self.space_can_send(space_id, frame_space_1rtt); if can_send.is_empty() && (!close || self.spaces[space_id].crypto.is_none()) { space_idx += 1; continue; } let mut ack_eliciting = !self.spaces[space_id].pending.is_empty(&self.streams) || self.spaces[space_id].ping_pending || self.spaces[space_id].immediate_ack_pending; if space_id == SpaceId::Data { ack_eliciting |= self.can_send_1rtt(frame_space_1rtt); } // Can we append more data into the current buffer? // It is not safe to assume that `buf.len()` is the end of the data, // since the last packet might not have been finished. let buf_end = if let Some(builder) = &builder_storage { buf.len().max(builder.min_size) + builder.tag_len } else { buf.len() }; let tag_len = if let Some(ref crypto) = self.spaces[space_id].crypto { crypto.packet.local.tag_len() } else if space_id == SpaceId::Data { self.zero_rtt_crypto.as_ref().expect( "sending packets in the application data space requires known 0-RTT or 1-RTT keys", ).packet.tag_len() } else { unreachable!("tried to send {:?} packet without keys", space_id) }; if !coalesce || buf_capacity - buf_end < MIN_PACKET_SPACE + tag_len { // We need to send 1 more datagram and extend the buffer for that. // Is 1 more datagram allowed? if buf_capacity >= segment_size * max_datagrams { // No more datagrams allowed break; } // Anti-amplification is only based on `total_sent`, which gets // updated at the end of this method. Therefore we pass the amount // of bytes for datagrams that are already created, as well as 1 byte // for starting another datagram. If there is any anti-amplification // budget left, we always allow a full MTU to be sent // (see https://github.com/quinn-rs/quinn/issues/1082) if self .path .anti_amplification_blocked(segment_size as u64 * num_datagrams + 1) { trace!("blocked by anti-amplification"); break; } // Congestion control and pacing checks // Tail loss probes must not be blocked by congestion, or a deadlock could arise if ack_eliciting && self.spaces[space_id].loss_probes == 0 { // Assume the current packet will get padded to fill the segment let untracked_bytes = if let Some(builder) = &builder_storage { buf_capacity - builder.partial_encode.start } else { 0 } as u64; debug_assert!(untracked_bytes <= segment_size as u64); let bytes_to_send = segment_size as u64 + untracked_bytes; if self.path.in_flight.bytes + bytes_to_send >= self.path.congestion.window() { space_idx += 1; congestion_blocked = true; // We continue instead of breaking here in order to avoid // blocking loss probes queued for higher spaces. trace!("blocked by congestion control"); continue; } // Check whether the next datagram is blocked by pacing let smoothed_rtt = self.path.rtt.get(); if let Some(delay) = self.path.pacing.delay( smoothed_rtt, bytes_to_send, self.path.current_mtu(), self.path.congestion.window(), now, ) { self.timers.set(Timer::Pacing, delay); congestion_blocked = true; // Loss probes should be subject to pacing, even though // they are not congestion controlled. trace!("blocked by pacing"); break; } } // Finish current packet if let Some(mut builder) = builder_storage.take() { if pad_datagram { builder.pad_to(MIN_INITIAL_SIZE); } if num_datagrams > 1 { // If too many padding bytes would be required to continue the GSO batch // after this packet, end the GSO batch here. Ensures that fixed-size frames // with heterogeneous sizes (e.g. application datagrams) won't inadvertently // waste large amounts of bandwidth. 
The exact threshold is a bit arbitrary // and might benefit from further tuning, though there's no universally // optimal value. const MAX_PADDING: usize = 16; let packet_len_unpadded = cmp::max(builder.min_size, buf.len()) - datagram_start + builder.tag_len; if packet_len_unpadded + MAX_PADDING < segment_size { trace!( "GSO truncated by demand for {} padding bytes", segment_size - packet_len_unpadded ); builder_storage = Some(builder); break; } // Pad the current datagram to GSO segment size so it can be included in the // GSO batch. builder.pad_to(segment_size as u16); } builder.finish_and_track(now, self, sent_frames.take(), buf); if num_datagrams == 1 { // Set the segment size for this GSO batch to the size of the first UDP // datagram in the batch. Larger data that cannot be fragmented // (e.g. application datagrams) will be included in a future batch. When // sending large enough volumes of data for GSO to be useful, we expect // packet sizes to usually be consistent, e.g. populated by max-size STREAM // frames or uniformly sized datagrams. segment_size = buf.len(); // Clip the unused capacity out of the buffer so future packets don't // overrun buf_capacity = buf.len(); // Check whether the data we planned to send will fit in the reduced segment // size. If not, bail out and leave it for the next GSO batch so we don't // end up trying to send an empty packet. We can't easily compute the right // segment size before the original call to `space_can_send`, because at // that time we haven't determined whether we're going to coalesce with the // first datagram or potentially pad it to `MIN_INITIAL_SIZE`. if space_id == SpaceId::Data { let frame_space_1rtt = segment_size.saturating_sub(self.predict_1rtt_overhead(Some(pn))); if self.space_can_send(space_id, frame_space_1rtt).is_empty() { break; } } } } // Allocate space for another datagram buf_capacity += segment_size; if buf.capacity() < buf_capacity { // We reserve the maximum space for sending `max_datagrams` upfront // to avoid any reallocations if more datagrams have to be appended later on. // Benchmarks have shown a 5-10% throughput improvement // compared to continuously resizing the datagram buffer. // While this will lead to over-allocation for small transmits // (e.g. purely containing ACKs), modern memory allocators // (e.g. mimalloc and jemalloc) will pool certain allocation sizes // and therefore this is still rather efficient. buf.reserve(max_datagrams * segment_size); } num_datagrams += 1; coalesce = true; pad_datagram = false; datagram_start = buf.len(); debug_assert_eq!( datagram_start % segment_size, 0, "datagrams in a GSO batch must be aligned to the segment size" ); } else { // We can append/coalesce the next packet into the current // datagram. // Finish current packet without adding extra padding if let Some(builder) = builder_storage.take() { builder.finish_and_track(now, self, sent_frames.take(), buf); } } debug_assert!(buf_capacity - buf.len() >= MIN_PACKET_SPACE); // // From here on, we've determined that a packet will definitely be sent. // if self.spaces[SpaceId::Initial].crypto.is_some() && space_id == SpaceId::Handshake && self.side.is_client() { // A client stops both sending and processing Initial packets when it // sends its first Handshake packet.
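// (Cross-reference: RFC 9001 §4.9.1 requires a client to discard Initial keys
// when it first sends a Handshake packet, which is what the call below does.)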
self.discard_space(now, SpaceId::Initial); } if let Some(ref mut prev) = self.prev_crypto { prev.update_unacked = false; } debug_assert!( builder_storage.is_none() && sent_frames.is_none(), "Previous packet must have been finished" ); let builder = builder_storage.insert(PacketBuilder::new( now, space_id, self.rem_cids.active(), buf, buf_capacity, datagram_start, ack_eliciting, self, )?); coalesce = coalesce && !builder.short_header; // https://tools.ietf.org/html/draft-ietf-quic-transport-34#section-14.1 pad_datagram |= space_id == SpaceId::Initial && (self.side.is_client() || ack_eliciting); if close { trace!("sending CONNECTION_CLOSE"); // Encode ACKs before the ConnectionClose message, to give the receiver // a better approximation of what data has been processed. This is // especially important with ack delay, since the peer might not // have gotten any other ACK for the data earlier on. if !self.spaces[space_id].pending_acks.ranges().is_empty() { Self::populate_acks( now, self.receiving_ecn, &mut SentFrames::default(), &mut self.spaces[space_id], buf, &mut self.stats, ); } // Since there are only 64 ACK frames, there will always be enough space // to encode the ConnectionClose frame too. However, we still have the // check here to prevent crashes if something changes. debug_assert!( buf.len() + frame::ConnectionClose::SIZE_BOUND < builder.max_size, "ACKs should leave space for ConnectionClose" ); if buf.len() + frame::ConnectionClose::SIZE_BOUND < builder.max_size { let max_frame_size = builder.max_size - buf.len(); match self.state { State::Closed(state::Closed { ref reason }) => { if space_id == SpaceId::Data || reason.is_transport_layer() { reason.encode(buf, max_frame_size) } else { frame::ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, frame_type: None, reason: Bytes::new(), } .encode(buf, max_frame_size) } } State::Draining => frame::ConnectionClose { error_code: TransportErrorCode::NO_ERROR, frame_type: None, reason: Bytes::new(), } .encode(buf, max_frame_size), _ => unreachable!( "tried to make a close packet when the connection wasn't closed" ), } } if space_id == self.highest_space { // Don't send another close packet self.close = false; // `CONNECTION_CLOSE` is the final packet break; } else { // Send a close frame in every possible space for robustness, per RFC9000 // "Immediate Close during the Handshake". Don't bother trying to send anything // else. space_idx += 1; continue; } } // Send an off-path PATH_RESPONSE. Prioritized over on-path data to ensure that path // validation can occur while the link is saturated. if space_id == SpaceId::Data && num_datagrams == 1 { if let Some((token, remote)) = self.path_responses.pop_off_path(&self.path.remote) { // `unwrap` guaranteed to succeed because `builder_storage` was populated just // above. let mut builder = builder_storage.take().unwrap(); trace!("PATH_RESPONSE {:08x} (off-path)", token); buf.write(frame::FrameType::PATH_RESPONSE); buf.write(token); self.stats.frame_tx.path_response += 1; builder.pad_to(MIN_INITIAL_SIZE); builder.finish_and_track( now, self, Some(SentFrames { non_retransmits: true, ..SentFrames::default() }), buf, ); self.stats.udp_tx.on_sent(1, buf.len()); return Some(Transmit { destination: remote, size: buf.len(), ecn: None, segment_size: None, src_ip: self.local_ip, }); } } let sent = self.populate_packet(now, space_id, buf, builder.max_size, builder.exact_number); // ACK-only packets should only be sent when explicitly allowed.
If we write them due to // any other reason, there is a bug which leads to one component announcing write // readiness while not writing any data. This degrades performance. The condition is // only checked if the full MTU is available and when potentially large fixed-size // frames aren't queued, so that lack of space in the datagram isn't the reason for just // writing ACKs. debug_assert!( !(sent.is_ack_only(&self.streams) && !can_send.acks && can_send.other && (buf_capacity - builder.datagram_start) == self.path.current_mtu() as usize && self.datagrams.outgoing.is_empty()), "SendableFrames was {can_send:?}, but only ACKs have been written" ); pad_datagram |= sent.requires_padding; if sent.largest_acked.is_some() { self.spaces[space_id].pending_acks.acks_sent(); self.timers.stop(Timer::MaxAckDelay); } // Keep information about the packet around until it gets finalized sent_frames = Some(sent); // Don't increment space_idx. // We stay in the current space and check if there is more data to send. } // Finish the last packet if let Some(mut builder) = builder_storage { if pad_datagram { builder.pad_to(MIN_INITIAL_SIZE); } let last_packet_number = builder.exact_number; builder.finish_and_track(now, self, sent_frames, buf); self.path .congestion .on_sent(now, buf.len() as u64, last_packet_number); } self.app_limited = buf.is_empty() && !congestion_blocked; // Send MTU probe if necessary if buf.is_empty() && self.state.is_established() { let space_id = SpaceId::Data; let probe_size = match self .path .mtud .poll_transmit(now, self.packet_number_filter.peek(&self.spaces[space_id])) { Some(next_probe_size) => next_probe_size, None => return None, }; let buf_capacity = probe_size as usize; buf.reserve(buf_capacity); let mut builder = PacketBuilder::new( now, space_id, self.rem_cids.active(), buf, buf_capacity, 0, true, self, )?; // We implement MTU probes as ping packets padded up to the probe size buf.write(frame::FrameType::PING); self.stats.frame_tx.ping += 1; // If supported by the peer, we want no delays to the probe's ACK if self.peer_supports_ack_frequency() { buf.write(frame::FrameType::IMMEDIATE_ACK); self.stats.frame_tx.immediate_ack += 1; } builder.pad_to(probe_size); let sent_frames = SentFrames { non_retransmits: true, ..Default::default() }; builder.finish_and_track(now, self, Some(sent_frames), buf); self.stats.path.sent_plpmtud_probes += 1; num_datagrams = 1; trace!(?probe_size, "writing MTUD probe"); } if buf.is_empty() { return None; } trace!("sending {} bytes in {} datagrams", buf.len(), num_datagrams); self.path.total_sent = self.path.total_sent.saturating_add(buf.len() as u64); self.stats.udp_tx.on_sent(num_datagrams, buf.len()); Some(Transmit { destination: self.path.remote, size: buf.len(), ecn: if self.path.sending_ecn { Some(EcnCodepoint::Ect0) } else { None }, segment_size: match num_datagrams { 1 => None, _ => Some(segment_size), }, src_ip: self.local_ip, }) } /// Indicate what types of frames are ready to send for the given space fn space_can_send(&self, space_id: SpaceId, frame_space_1rtt: usize) -> SendableFrames { if self.spaces[space_id].crypto.is_none() && (space_id != SpaceId::Data || self.zero_rtt_crypto.is_none() || self.side.is_server()) { // No keys available for this space return SendableFrames::empty(); } let mut can_send = self.spaces[space_id].can_send(&self.streams); if space_id == SpaceId::Data { can_send.other |= self.can_send_1rtt(frame_space_1rtt); } can_send } /// Process `ConnectionEvent`s generated by the associated `Endpoint` /// /// Will execute 
protocol logic upon receipt of a connection event, in turn preparing signals /// (including application `Event`s, `EndpointEvent`s and outgoing datagrams) that should be /// extracted through the relevant methods. pub fn handle_event(&mut self, event: ConnectionEvent) { use self::ConnectionEventInner::*; match event.0 { Datagram(DatagramConnectionEvent { now, remote, ecn, first_decode, remaining, }) => { // If this packet could initiate a migration and we're a client or a server that // forbids migration, drop the datagram. This could be relaxed to heuristically // permit NAT-rebinding-like migration. if remote != self.path.remote && self.server_config.as_ref().map_or(true, |x| !x.migration) { trace!("discarding packet from unrecognized peer {}", remote); return; } let was_anti_amplification_blocked = self.path.anti_amplification_blocked(1); self.stats.udp_rx.datagrams += 1; self.stats.udp_rx.bytes += first_decode.len() as u64; let data_len = first_decode.len(); self.handle_decode(now, remote, ecn, first_decode); // The current `path` might have changed inside `handle_decode`, // since the packet could have triggered a migration. Make sure // the data received is accounted for the most recent path by accessing // `path` after `handle_decode`. self.path.total_recvd = self.path.total_recvd.saturating_add(data_len as u64); if let Some(data) = remaining { self.stats.udp_rx.bytes += data.len() as u64; self.handle_coalesced(now, remote, ecn, data); } if was_anti_amplification_blocked { // A prior attempt to set the loss detection timer may have failed due to // anti-amplification, so ensure it's set now. Prevents a handshake deadlock if // the server's first flight is lost. self.set_loss_detection_timer(now); } } NewIdentifiers(ids, now) => { self.local_cid_state.new_cids(&ids, now); ids.into_iter().rev().for_each(|frame| { self.spaces[SpaceId::Data].pending.new_cids.push(frame); }); // Update Timer::PushNewCid if self .timers .get(Timer::PushNewCid) .map_or(true, |x| x <= now) { self.reset_cid_retirement(); } } } } /// Process timer expirations /// /// Executes protocol logic, potentially preparing signals (including application `Event`s, /// `EndpointEvent`s and outgoing datagrams) that should be extracted through the relevant /// methods. /// /// It is most efficient to call this immediately after the system clock reaches the latest /// `Instant` that was output by `poll_timeout`; however spurious extra calls will simply /// no-op and therefore are safe. 
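// A minimal driver-loop sketch (`sleep_until` is a hypothetical stand-in for
// whatever timer facility the caller uses; the async `quinn` layer normally
// does this for you):
//
//     if let Some(deadline) = conn.poll_timeout() {
//         sleep_until(deadline);
//         conn.handle_timeout(Instant::now());
//     }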
pub fn handle_timeout(&mut self, now: Instant) { for &timer in &Timer::VALUES { if !self.timers.is_expired(timer, now) { continue; } self.timers.stop(timer); trace!(timer = ?timer, "timeout"); match timer { Timer::Close => { self.state = State::Drained; self.endpoint_events.push_back(EndpointEventInner::Drained); } Timer::Idle => { self.kill(ConnectionError::TimedOut); } Timer::KeepAlive => { trace!("sending keep-alive"); self.ping(); } Timer::LossDetection => { self.on_loss_detection_timeout(now); } Timer::KeyDiscard => { self.zero_rtt_crypto = None; self.prev_crypto = None; } Timer::PathValidation => { debug!("path validation failed"); if let Some((_, prev)) = self.prev_path.take() { self.path = prev; } self.path.challenge = None; self.path.challenge_pending = false; } Timer::Pacing => trace!("pacing timer expired"), Timer::PushNewCid => { // Update `retire_prior_to` field in NEW_CONNECTION_ID frame let num_new_cid = self.local_cid_state.on_cid_timeout().into(); if !self.state.is_closed() { trace!( "push a new cid to peer RETIRE_PRIOR_TO field {}", self.local_cid_state.retire_prior_to() ); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, num_new_cid)); } } Timer::MaxAckDelay => { trace!("max ack delay reached"); // This timer is only armed in the Data space self.spaces[SpaceId::Data] .pending_acks .on_max_ack_delay_timeout() } } } } /// Close a connection immediately /// /// This does not ensure delivery of outstanding data. It is the application's responsibility to /// call this only when all important communications have been completed, e.g. by calling /// [`SendStream::finish`] on outstanding streams and waiting for the corresponding /// [`StreamEvent::Finished`] event. /// /// If [`Streams::send_streams`] returns 0, all outstanding stream data has been /// delivered. There may still be data from the peer that has not been received. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished pub fn close(&mut self, now: Instant, error_code: VarInt, reason: Bytes) { self.close_inner( now, Close::Application(frame::ApplicationClose { error_code, reason }), ) } fn close_inner(&mut self, now: Instant, reason: Close) { let was_closed = self.state.is_closed(); if !was_closed { self.close_common(); self.set_close_timer(now); self.close = true; self.state = State::Closed(state::Closed { reason }); } } /// Control datagrams pub fn datagrams(&mut self) -> Datagrams<'_> { Datagrams { conn: self } } /// Returns connection statistics pub fn stats(&self) -> ConnectionStats { let mut stats = self.stats; stats.path.rtt = self.path.rtt.get(); stats.path.cwnd = self.path.congestion.window(); stats.path.current_mtu = self.path.mtud.current_mtu(); stats } /// Ping the remote endpoint /// /// Causes an ACK-eliciting packet to be transmitted. pub fn ping(&mut self) { self.spaces[self.highest_space].ping_pending = true; } #[doc(hidden)] pub fn initiate_key_update(&mut self) { self.update_keys(None, false); } /// Get a session reference pub fn crypto_session(&self) -> &dyn crypto::Session { &*self.crypto } /// Whether the connection is in the process of being established /// /// If this returns `false`, the connection may be either established or closed, signaled by the /// emission of a `Connected` or `ConnectionLost` message respectively. pub fn is_handshaking(&self) -> bool { self.state.is_handshake() } /// Whether the connection is closed /// /// Closed connections cannot transport any further data. 
A connection becomes closed when /// either peer application intentionally closes it, or when either transport layer detects an /// error such as a time-out or certificate validation failure. /// /// A `ConnectionLost` event is emitted with details when the connection becomes closed. pub fn is_closed(&self) -> bool { self.state.is_closed() } /// Whether there is no longer any need to keep the connection around /// /// Closed connections become drained after a brief timeout to absorb any remaining in-flight /// packets from the peer. All drained connections have been closed. pub fn is_drained(&self) -> bool { self.state.is_drained() } /// For clients, if the peer accepted the 0-RTT data packets /// /// The value is meaningless until after the handshake completes. pub fn accepted_0rtt(&self) -> bool { self.accepted_0rtt } /// Whether 0-RTT is/was possible during the handshake pub fn has_0rtt(&self) -> bool { self.zero_rtt_enabled } /// Whether there are any pending retransmits pub fn has_pending_retransmits(&self) -> bool { !self.spaces[SpaceId::Data].pending.is_empty(&self.streams) } /// Look up whether we're the client or server of this Connection pub fn side(&self) -> Side { self.side } /// The latest socket address for this connection's peer pub fn remote_address(&self) -> SocketAddr { self.path.remote } /// The local IP address which was used when the peer established /// the connection /// /// This can be different from the address the endpoint is bound to, in case /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`. /// /// This will return `None` for clients, or when no `local_ip` was passed to /// [`Endpoint::handle()`](crate::Endpoint::handle) for the datagrams establishing this /// connection. pub fn local_ip(&self) -> Option<IpAddr> { self.local_ip } /// Current best estimate of this connection's latency (round-trip-time) pub fn rtt(&self) -> Duration { self.path.rtt.get() } /// Current state of this connection's congestion controller, for debugging purposes pub fn congestion_state(&self) -> &dyn Controller { self.path.congestion.as_ref() } /// Resets path-specific settings. /// /// This will force-reset several subsystems related to a specific network path. /// Currently this is the congestion controller, round-trip estimator, and the MTU /// discovery. /// /// This is useful when it is known the underlying network path has changed and the old /// state of these subsystems is no longer valid or optimal. In this case it might be /// faster or reduce loss to settle on optimal values by restarting from the initial /// configuration in the [`TransportConfig`]. pub fn path_changed(&mut self, now: Instant) { self.path.reset(now, &self.config); } /// Modify the number of remotely initiated streams that may be concurrently open /// /// No streams may be opened by the peer unless fewer than `count` are already open. Large /// `count`s increase both minimum and worst-case memory consumption. pub fn set_max_concurrent_streams(&mut self, dir: Dir, count: VarInt) { self.streams.set_max_concurrent(dir, count); // If the limit was reduced, then a flow control update previously deemed insignificant may // now be significant.
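// (`queue_max_stream_id` below re-evaluates whether a MAX_STREAMS update is
// worth sending, so the peer learns the new effective limit without having
// to wait for another stream to close.)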
let pending = &mut self.spaces[SpaceId::Data].pending; self.streams.queue_max_stream_id(pending); } /// Current number of remotely initiated streams that may be concurrently open /// /// If the target for this limit is reduced using [`set_max_concurrent_streams`](Self::set_max_concurrent_streams), /// it will not change immediately, even if fewer streams are open. Instead, it will /// decrement by one for each time a remotely initiated stream of matching directionality is closed. pub fn max_concurrent_streams(&self, dir: Dir) -> u64 { self.streams.max_concurrent(dir) } /// See [`TransportConfig::receive_window()`] pub fn set_receive_window(&mut self, receive_window: VarInt) { if self.streams.set_receive_window(receive_window) { self.spaces[SpaceId::Data].pending.max_data = true; } } fn on_ack_received( &mut self, now: Instant, space: SpaceId, ack: frame::Ack, ) -> Result<(), TransportError> { if ack.largest >= self.spaces[space].next_packet_number { return Err(TransportError::PROTOCOL_VIOLATION("unsent packet acked")); } let new_largest = { let space = &mut self.spaces[space]; if space .largest_acked_packet .map_or(true, |pn| ack.largest > pn) { space.largest_acked_packet = Some(ack.largest); if let Some(info) = space.sent_packets.get(&ack.largest) { // This should always succeed, but a misbehaving peer might ACK a packet we // haven't sent. At worst, that will result in us spuriously reducing the // congestion window. space.largest_acked_packet_sent = info.time_sent; } true } else { false } }; // Avoid DoS from unreasonably huge ack ranges by filtering out just the new acks. let mut newly_acked = ArrayRangeSet::new(); for range in ack.iter() { self.packet_number_filter.check_ack(space, range.clone())?; for (&pn, _) in self.spaces[space].sent_packets.range(range) { newly_acked.insert_one(pn); } } if newly_acked.is_empty() { return Ok(()); } let mut ack_eliciting_acked = false; for packet in newly_acked.elts() { if let Some(info) = self.spaces[space].take(packet) { if let Some(acked) = info.largest_acked { // Assume ACKs for all packets below the largest acknowledged in `packet` have // been received. This can cause the peer to spuriously retransmit if some of // our earlier ACKs were lost, but allows for simpler state tracking. 
See // discussion at // https://www.rfc-editor.org/rfc/rfc9000.html#name-limiting-ranges-by-tracking self.spaces[space].pending_acks.subtract_below(acked); } ack_eliciting_acked |= info.ack_eliciting; // Notify MTU discovery that a packet was acked, because it might be an MTU probe let mtu_updated = self.path.mtud.on_acked(space, packet, info.size); if mtu_updated { self.path .congestion .on_mtu_update(self.path.mtud.current_mtu()); } // Notify ack frequency that a packet was acked, because it might contain an ACK_FREQUENCY frame self.ack_frequency.on_acked(packet); self.on_packet_acked(now, packet, info); } } self.path.congestion.on_end_acks( now, self.path.in_flight.bytes, self.app_limited, self.spaces[space].largest_acked_packet, ); if new_largest && ack_eliciting_acked { let ack_delay = if space != SpaceId::Data { Duration::from_micros(0) } else { cmp::min( self.ack_frequency.peer_max_ack_delay, Duration::from_micros(ack.delay << self.peer_params.ack_delay_exponent.0), ) }; let rtt = instant_saturating_sub(now, self.spaces[space].largest_acked_packet_sent); self.path.rtt.update(ack_delay, rtt); if self.path.first_packet_after_rtt_sample.is_none() { self.path.first_packet_after_rtt_sample = Some((space, self.spaces[space].next_packet_number)); } } // Must be called before crypto/pto_count are clobbered self.detect_lost_packets(now, space, true); if self.peer_completed_address_validation() { self.pto_count = 0; } // Explicit congestion notification if self.path.sending_ecn { if let Some(ecn) = ack.ecn { // We only examine ECN counters from ACKs that we are certain we received in transmit // order, allowing us to compute an increase in ECN counts to compare against the number // of newly acked packets that remains well-defined in the presence of arbitrary packet // reordering. if new_largest { let sent = self.spaces[space].largest_acked_packet_sent; self.process_ecn(now, space, newly_acked.len() as u64, ecn, sent); } } else { // We always start out sending ECN, so any ack that doesn't acknowledge it disables it. debug!("ECN not acknowledged by peer"); self.path.sending_ecn = false; } } self.set_loss_detection_timer(now); Ok(()) } /// Process a new ECN block from an in-order ACK fn process_ecn( &mut self, now: Instant, space: SpaceId, newly_acked: u64, ecn: frame::EcnCounts, largest_sent_time: Instant, ) { match self.spaces[space].detect_ecn(newly_acked, ecn) { Err(e) => { debug!("halting ECN due to verification failure: {}", e); self.path.sending_ecn = false; // Wipe out the existing value because it might be garbage and could interfere with // future attempts to use ECN on new paths. self.spaces[space].ecn_feedback = frame::EcnCounts::ZERO; } Ok(false) => {} Ok(true) => { self.stats.path.congestion_events += 1; self.path .congestion .on_congestion_event(now, largest_sent_time, false, 0); } } } // Not timing-aware, so it's safe to call this for inferred acks, such as arise from // high-latency handshakes fn on_packet_acked(&mut self, now: Instant, pn: u64, info: SentPacket) { self.remove_in_flight(pn, &info); if info.ack_eliciting && self.path.challenge.is_none() { // Only pass ACKs to the congestion controller if we are not validating the current // path, so as to ignore any ACKs from older paths still coming in. 
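// (While a path challenge is outstanding, `self.path` refers to the new,
// unvalidated path; ACKs arriving now may cover packets sent on the old path
// and would skew the new path's RTT and congestion estimates.)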
self.path.congestion.on_ack( now, info.time_sent, info.size.into(), self.app_limited, &self.path.rtt, ); } // Update state for confirmed delivery of frames if let Some(retransmits) = info.retransmits.get() { for (id, _) in retransmits.reset_stream.iter() { self.streams.reset_acked(*id); } } for frame in info.stream_frames { self.streams.received_ack_of(frame); } } fn set_key_discard_timer(&mut self, now: Instant, space: SpaceId) { let start = if self.zero_rtt_crypto.is_some() { now } else { self.prev_crypto .as_ref() .expect("no previous keys") .end_packet .as_ref() .expect("update not acknowledged yet") .1 }; self.timers .set(Timer::KeyDiscard, start + self.pto(space) * 3); } fn on_loss_detection_timeout(&mut self, now: Instant) { if let Some((_, pn_space)) = self.loss_time_and_space() { // Time threshold loss detection self.detect_lost_packets(now, pn_space, false); self.set_loss_detection_timer(now); return; } let (_, space) = match self.pto_time_and_space(now) { Some(x) => x, None => { error!("PTO expired while unset"); return; } }; trace!( in_flight = self.path.in_flight.bytes, count = self.pto_count, ?space, "PTO fired" ); let count = match self.path.in_flight.ack_eliciting { // A PTO when we're not expecting any ACKs must be due to handshake anti-amplification // deadlock prevention 0 => { debug_assert!(!self.peer_completed_address_validation()); 1 } // Conventional loss probe _ => 2, }; self.spaces[space].loss_probes = self.spaces[space].loss_probes.saturating_add(count); self.pto_count = self.pto_count.saturating_add(1); self.set_loss_detection_timer(now); } fn detect_lost_packets(&mut self, now: Instant, pn_space: SpaceId, due_to_ack: bool) { let mut lost_packets = Vec::<u64>::new(); let mut lost_mtu_probe = None; let in_flight_mtu_probe = self.path.mtud.in_flight_mtu_probe(); let rtt = self.path.rtt.conservative(); let loss_delay = cmp::max(rtt.mul_f32(self.config.time_threshold), TIMER_GRANULARITY); // Packets sent before this time are deemed lost. let lost_send_time = now.checked_sub(loss_delay).unwrap(); let largest_acked_packet = self.spaces[pn_space].largest_acked_packet.unwrap(); let packet_threshold = self.config.packet_threshold as u64; let mut size_of_lost_packets = 0u64; // InPersistentCongestion: Determine if all packets in the time period before the newest // lost packet, including the edges, are marked lost. PTO computation must always // include max ACK delay, i.e. operate as if in Data space (see RFC 9002 §7.6.1).
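// Worked example with illustrative numbers: if `pto(Data)` is 100ms and
// `persistent_congestion_threshold` is 3 (the RFC 9002 default), then
// `congestion_period` is 300ms, and persistent congestion is declared only
// when two ACK-eliciting packets lost more than 300ms apart bracket a run of
// losses with no ACKed packet in between.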
let congestion_period = self.pto(SpaceId::Data) * self.config.persistent_congestion_threshold; let mut persistent_congestion_start: Option<Instant> = None; let mut prev_packet = None; let mut in_persistent_congestion = false; let space = &mut self.spaces[pn_space]; space.loss_time = None; for (&packet, info) in space.sent_packets.range(0..largest_acked_packet) { if prev_packet != Some(packet.wrapping_sub(1)) { // An intervening packet was acknowledged persistent_congestion_start = None; } if info.time_sent <= lost_send_time || largest_acked_packet >= packet + packet_threshold { if Some(packet) == in_flight_mtu_probe { // Lost MTU probes are not included in `lost_packets`, because they should not // trigger a congestion control response lost_mtu_probe = in_flight_mtu_probe; } else { lost_packets.push(packet); size_of_lost_packets += info.size as u64; if info.ack_eliciting && due_to_ack { match persistent_congestion_start { // Two ACK-eliciting packets lost more than congestion_period apart, with no // ACKed packets in between Some(start) if info.time_sent - start > congestion_period => { in_persistent_congestion = true; } // Persistent congestion must start after the first RTT sample None if self .path .first_packet_after_rtt_sample .map_or(false, |x| x < (pn_space, packet)) => { persistent_congestion_start = Some(info.time_sent); } _ => {} } } } } else { let next_loss_time = info.time_sent + loss_delay; space.loss_time = Some( space .loss_time .map_or(next_loss_time, |x| cmp::min(x, next_loss_time)), ); persistent_congestion_start = None; } prev_packet = Some(packet); } // OnPacketsLost if let Some(largest_lost) = lost_packets.last().cloned() { let old_bytes_in_flight = self.path.in_flight.bytes; let largest_lost_sent = self.spaces[pn_space].sent_packets[&largest_lost].time_sent; self.lost_packets += lost_packets.len() as u64; self.stats.path.lost_packets += lost_packets.len() as u64; self.stats.path.lost_bytes += size_of_lost_packets; trace!( "packets lost: {:?}, bytes lost: {}", lost_packets, size_of_lost_packets ); for &packet in &lost_packets { let info = self.spaces[pn_space].take(packet).unwrap(); // safe: lost_packets is populated just above self.remove_in_flight(packet, &info); for frame in info.stream_frames { self.streams.retransmit(frame); } self.spaces[pn_space].pending |= info.retransmits; self.path.mtud.on_non_probe_lost(packet, info.size); } if self.path.mtud.black_hole_detected(now) { self.stats.path.black_holes_detected += 1; self.path .congestion .on_mtu_update(self.path.mtud.current_mtu()); if let Some(max_datagram_size) = self.datagrams().max_size() { self.datagrams.drop_oversized(max_datagram_size); } } // Don't apply congestion penalty for lost ack-only packets let lost_ack_eliciting = old_bytes_in_flight != self.path.in_flight.bytes; if lost_ack_eliciting { self.stats.path.congestion_events += 1; self.path.congestion.on_congestion_event( now, largest_lost_sent, in_persistent_congestion, size_of_lost_packets, ); } } // Handle a lost MTU probe if let Some(packet) = lost_mtu_probe { let info = self.spaces[SpaceId::Data].take(packet).unwrap(); // safe: lost_mtu_probe is omitted from lost_packets, and therefore must not have been removed yet self.remove_in_flight(packet, &info); self.path.mtud.on_probe_lost(); self.stats.path.lost_plpmtud_probes += 1; } } fn loss_time_and_space(&self) -> Option<(Instant, SpaceId)> { SpaceId::iter() .filter_map(|id| Some((self.spaces[id].loss_time?, id))) .min_by_key(|&(time, _)| time) } fn pto_time_and_space(&self, now: Instant) ->
Option<(Instant, SpaceId)> { let backoff = 2u32.pow(self.pto_count.min(MAX_BACKOFF_EXPONENT)); let mut duration = self.path.rtt.pto_base() * backoff; if self.path.in_flight.ack_eliciting == 0 { debug_assert!(!self.peer_completed_address_validation()); let space = match self.highest_space { SpaceId::Handshake => SpaceId::Handshake, _ => SpaceId::Initial, }; return Some((now + duration, space)); } let mut result = None; for space in SpaceId::iter() { if self.spaces[space].in_flight == 0 { continue; } if space == SpaceId::Data { // Skip ApplicationData until handshake completes. if self.is_handshaking() { return result; } // Include max_ack_delay and backoff for ApplicationData. duration += self.ack_frequency.max_ack_delay_for_pto() * backoff; } let last_ack_eliciting = match self.spaces[space].time_of_last_ack_eliciting_packet { Some(time) => time, None => continue, }; let pto = last_ack_eliciting + duration; if result.map_or(true, |(earliest_pto, _)| pto < earliest_pto) { result = Some((pto, space)); } } result } #[allow(clippy::suspicious_operation_groupings)] fn peer_completed_address_validation(&self) -> bool { if self.side.is_server() || self.state.is_closed() { return true; } // The server is guaranteed to have validated our address if any of our handshake or 1-RTT // packets are acknowledged or we've seen HANDSHAKE_DONE and discarded handshake keys. self.spaces[SpaceId::Handshake] .largest_acked_packet .is_some() || self.spaces[SpaceId::Data].largest_acked_packet.is_some() || (self.spaces[SpaceId::Data].crypto.is_some() && self.spaces[SpaceId::Handshake].crypto.is_none()) } fn set_loss_detection_timer(&mut self, now: Instant) { if self.state.is_closed() { // No loss detection takes place on closed connections, and `close_common` already // stopped the timer. Ensure we don't restart it inadvertently, e.g. in response to a // reordered packet being handled by state-insensitive code. return; } if let Some((loss_time, _)) = self.loss_time_and_space() { // Time threshold loss detection. self.timers.set(Timer::LossDetection, loss_time); return; } if self.path.anti_amplification_blocked(1) { // We wouldn't be able to send anything, so don't bother. self.timers.stop(Timer::LossDetection); return; } if self.path.in_flight.ack_eliciting == 0 && self.peer_completed_address_validation() { // There is nothing that could be detected as lost, so no timer is set. However, the client needs to arm // the timer if the server might be blocked by the anti-amplification limit. self.timers.stop(Timer::LossDetection); return; } // Determine which PN space to arm PTO for.
// Calculate PTO duration if let Some((timeout, _)) = self.pto_time_and_space(now) { self.timers.set(Timer::LossDetection, timeout); } else { self.timers.stop(Timer::LossDetection); } } /// Probe Timeout fn pto(&self, space: SpaceId) -> Duration { let max_ack_delay = match space { SpaceId::Initial | SpaceId::Handshake => Duration::new(0, 0), SpaceId::Data => self.ack_frequency.max_ack_delay_for_pto(), }; self.path.rtt.pto_base() + max_ack_delay } fn on_packet_authenticated( &mut self, now: Instant, space_id: SpaceId, ecn: Option<EcnCodepoint>, packet: Option<u64>, spin: bool, is_1rtt: bool, ) { self.total_authed_packets += 1; self.reset_keep_alive(now); self.reset_idle_timeout(now, space_id); self.permit_idle_reset = true; self.receiving_ecn |= ecn.is_some(); if let Some(x) = ecn { let space = &mut self.spaces[space_id]; space.ecn_counters += x; if x.is_ce() { space.pending_acks.set_immediate_ack_required(); } } let packet = match packet { Some(x) => x, None => return, }; if self.side.is_server() { if self.spaces[SpaceId::Initial].crypto.is_some() && space_id == SpaceId::Handshake { // A server stops sending and processing Initial packets when it receives its first Handshake packet. self.discard_space(now, SpaceId::Initial); } if self.zero_rtt_crypto.is_some() && is_1rtt { // Discard 0-RTT keys soon after receiving a 1-RTT packet self.set_key_discard_timer(now, space_id) } } let space = &mut self.spaces[space_id]; space.pending_acks.insert_one(packet, now); if packet >= space.rx_packet { space.rx_packet = packet; // Update outgoing spin bit, inverting iff we're the client self.spin = self.side.is_client() ^ spin; } } fn reset_idle_timeout(&mut self, now: Instant, space: SpaceId) { let timeout = match self.idle_timeout { None => return, Some(dur) => dur, }; if self.state.is_closed() { self.timers.stop(Timer::Idle); return; } let dt = cmp::max(timeout, 3 * self.pto(space)); self.timers.set(Timer::Idle, now + dt); } fn reset_keep_alive(&mut self, now: Instant) { let interval = match self.config.keep_alive_interval { Some(x) if self.state.is_established() => x, _ => return, }; self.timers.set(Timer::KeepAlive, now + interval); } fn reset_cid_retirement(&mut self) { if let Some(t) = self.local_cid_state.next_timeout() { self.timers.set(Timer::PushNewCid, t); } } /// Handle the already-decrypted first packet from the client /// /// Decrypting the first packet in the `Endpoint` allows stateless packet handling to be more /// efficient.
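// (Initial packets are protected with keys derived from the client's
// destination CID and a version-specific salt, per RFC 9001 §5.2, so the
// `Endpoint` can decrypt them without any per-connection state.)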
pub(crate) fn handle_first_packet( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodepoint>, packet_number: u64, packet: InitialPacket, remaining: Option<BytesMut>, ) -> Result<(), ConnectionError> { let span = trace_span!("first recv"); let _guard = span.enter(); debug_assert!(self.side.is_server()); let len = packet.header_data.len() + packet.payload.len(); self.path.total_recvd = len as u64; match self.state { State::Handshake(ref mut state) => { state.expected_token = packet.header.token.clone(); } _ => unreachable!("first packet must be delivered in Handshake state"), } self.on_packet_authenticated( now, SpaceId::Initial, ecn, Some(packet_number), false, false, ); self.process_decrypted_packet(now, remote, Some(packet_number), packet.into())?; if let Some(data) = remaining { self.handle_coalesced(now, remote, ecn, data); } Ok(()) } fn init_0rtt(&mut self) { let (header, packet) = match self.crypto.early_crypto() { Some(x) => x, None => return, }; if self.side.is_client() { match self.crypto.transport_parameters() { Ok(params) => { let params = params .expect("crypto layer didn't supply transport parameters with ticket"); // Certain values must not be cached let params = TransportParameters { initial_src_cid: None, original_dst_cid: None, preferred_address: None, retry_src_cid: None, stateless_reset_token: None, min_ack_delay: None, ack_delay_exponent: TransportParameters::default().ack_delay_exponent, max_ack_delay: TransportParameters::default().max_ack_delay, ..params }; self.set_peer_params(params); } Err(e) => { error!("session ticket has malformed transport parameters: {}", e); return; } } } trace!("0-RTT enabled"); self.zero_rtt_enabled = true; self.zero_rtt_crypto = Some(ZeroRttCrypto { header, packet }); } fn read_crypto( &mut self, space: SpaceId, crypto: &frame::Crypto, payload_len: usize, ) -> Result<(), TransportError> { let expected = if !self.state.is_handshake() { SpaceId::Data } else if self.highest_space == SpaceId::Initial { SpaceId::Initial } else { // On the server, self.highest_space can be Data after receiving the client's first // flight, but we expect Handshake CRYPTO until the handshake is complete. SpaceId::Handshake }; // We can't decrypt Handshake packets when highest_space is Initial, CRYPTO frames in 0-RTT // packets are illegal, and we don't process 1-RTT packets until the handshake is // complete. Therefore, we will never see CRYPTO data from a later-than-expected space. debug_assert!(space <= expected, "received out-of-order CRYPTO data"); let end = crypto.offset + crypto.data.len() as u64; if space < expected && end > self.spaces[space].crypto_stream.bytes_read() { warn!( "received new {:?} CRYPTO data when expecting {:?}", space, expected ); return Err(TransportError::PROTOCOL_VIOLATION( "new data at unexpected encryption level", )); } let space = &mut self.spaces[space]; let max = end.saturating_sub(space.crypto_stream.bytes_read()); if max > self.config.crypto_buffer_size as u64 { return Err(TransportError::CRYPTO_BUFFER_EXCEEDED("")); } space .crypto_stream .insert(crypto.offset, crypto.data.clone(), payload_len); while let Some(chunk) = space.crypto_stream.read(usize::MAX, true) { trace!("consumed {} CRYPTO bytes", chunk.bytes.len()); if self.crypto.read_handshake(&chunk.bytes)?
{ self.events.push_back(Event::HandshakeDataReady); } } Ok(()) } fn write_crypto(&mut self) { loop { let space = self.highest_space; let mut outgoing = Vec::new(); if let Some(crypto) = self.crypto.write_handshake(&mut outgoing) { match space { SpaceId::Initial => { self.upgrade_crypto(SpaceId::Handshake, crypto); } SpaceId::Handshake => { self.upgrade_crypto(SpaceId::Data, crypto); } _ => unreachable!("got updated secrets during 1-RTT"), } } if outgoing.is_empty() { if space == self.highest_space { break; } else { // Keys updated, check for more data to send continue; } } let offset = self.spaces[space].crypto_offset; let outgoing = Bytes::from(outgoing); if let State::Handshake(ref mut state) = self.state { if space == SpaceId::Initial && offset == 0 && self.side.is_client() { state.client_hello = Some(outgoing.clone()); } } self.spaces[space].crypto_offset += outgoing.len() as u64; trace!("wrote {} {:?} CRYPTO bytes", outgoing.len(), space); self.spaces[space].pending.crypto.push_back(frame::Crypto { offset, data: outgoing, }); } } /// Switch to stronger cryptography during handshake fn upgrade_crypto(&mut self, space: SpaceId, crypto: Keys) { debug_assert!( self.spaces[space].crypto.is_none(), "already reached packet space {space:?}" ); trace!("{:?} keys ready", space); if space == SpaceId::Data { // Precompute the first key update self.next_crypto = Some( self.crypto .next_1rtt_keys() .expect("handshake should be complete"), ); } self.spaces[space].crypto = Some(crypto); debug_assert!(space as usize > self.highest_space as usize); self.highest_space = space; if space == SpaceId::Data && self.side.is_client() { // Discard 0-RTT keys because 1-RTT keys are available. self.zero_rtt_crypto = None; } } fn discard_space(&mut self, now: Instant, space_id: SpaceId) { debug_assert!(space_id != SpaceId::Data); trace!("discarding {:?} keys", space_id); if space_id == SpaceId::Initial { // No longer needed self.retry_token = Bytes::new(); } let space = &mut self.spaces[space_id]; space.crypto = None; space.time_of_last_ack_eliciting_packet = None; space.loss_time = None; space.in_flight = 0; let sent_packets = mem::take(&mut space.sent_packets); for (pn, packet) in sent_packets.into_iter() { self.remove_in_flight(pn, &packet); } self.set_loss_detection_timer(now) } fn handle_coalesced( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodepoint>, data: BytesMut, ) { self.path.total_recvd = self.path.total_recvd.saturating_add(data.len() as u64); let mut remaining = Some(data); while let Some(data) = remaining { match PartialDecode::new( data, &FixedLengthConnectionIdParser::new(self.local_cid_state.cid_len()), &[self.version], self.endpoint_config.grease_quic_bit, ) { Ok((partial_decode, rest)) => { remaining = rest; self.handle_decode(now, remote, ecn, partial_decode); } Err(e) => { trace!("malformed header: {}", e); return; } } } } fn handle_decode( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodepoint>, partial_decode: PartialDecode, ) { if let Some(decoded) = packet_crypto::unprotect_header( partial_decode, &self.spaces, self.zero_rtt_crypto.as_ref(), self.peer_params.stateless_reset_token, ) { self.handle_packet(now, remote, ecn, decoded.packet, decoded.stateless_reset); } } fn handle_packet( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodepoint>, packet: Option<Packet>, stateless_reset: bool, ) { self.stats.udp_rx.ios += 1; if let Some(ref packet) = packet { trace!( "got {:?} packet ({} bytes) from {} using id {}", packet.header.space(), packet.payload.len() + packet.header_data.len(), remote,
packet.header.dst_cid(), ); } if self.is_handshaking() && remote != self.path.remote { debug!("discarding packet with unexpected remote during handshake"); return; } let was_closed = self.state.is_closed(); let was_drained = self.state.is_drained(); let decrypted = match packet { None => Err(None), Some(mut packet) => self .decrypt_packet(now, &mut packet) .map(move |number| (packet, number)), }; let result = match decrypted { _ if stateless_reset => { debug!("got stateless reset"); Err(ConnectionError::Reset) } Err(Some(e)) => { warn!("illegal packet: {}", e); Err(e.into()) } Err(None) => { debug!("failed to authenticate packet"); self.authentication_failures += 1; let integrity_limit = self.spaces[self.highest_space] .crypto .as_ref() .unwrap() .packet .local .integrity_limit(); if self.authentication_failures > integrity_limit { Err(TransportError::AEAD_LIMIT_REACHED("integrity limit violated").into()) } else { return; } } Ok((packet, number)) => { let span = match number { Some(pn) => trace_span!("recv", space = ?packet.header.space(), pn), None => trace_span!("recv", space = ?packet.header.space()), }; let _guard = span.enter(); let is_duplicate = |n| self.spaces[packet.header.space()].dedup.insert(n); if number.map_or(false, is_duplicate) { debug!("discarding possible duplicate packet"); return; } else if self.state.is_handshake() && packet.header.is_short() { // TODO: SHOULD buffer these to improve reordering tolerance. trace!("dropping short packet during handshake"); return; } else { if let Header::Initial(InitialHeader { ref token, .. }) = packet.header { if let State::Handshake(ref hs) = self.state { if self.side.is_server() && token != &hs.expected_token { // Clients must send the same retry token in every Initial. Initial // packets can be spoofed, so we discard rather than killing the // connection. warn!("discarding Initial with invalid retry token"); return; } } } if !self.state.is_closed() { let spin = match packet.header { Header::Short { spin, .. } => spin, _ => false, }; self.on_packet_authenticated( now, packet.header.space(), ecn, number, spin, packet.header.is_1rtt(), ); } self.process_decrypted_packet(now, remote, number, packet) } } }; // State transitions for error cases if let Err(conn_err) = result { self.error = Some(conn_err.clone()); self.state = match conn_err { ConnectionError::ApplicationClosed(reason) => State::closed(reason), ConnectionError::ConnectionClosed(reason) => State::closed(reason), ConnectionError::Reset | ConnectionError::TransportError(TransportError { code: TransportErrorCode::AEAD_LIMIT_REACHED, .. }) => State::Drained, ConnectionError::TimedOut => { unreachable!("timeouts aren't generated by packet processing"); } ConnectionError::TransportError(err) => { debug!("closing connection due to transport error: {}", err); State::closed(err) } ConnectionError::VersionMismatch => State::Draining, ConnectionError::LocallyClosed => { unreachable!("LocallyClosed isn't generated by packet processing"); } ConnectionError::CidsExhausted => { unreachable!("CidsExhausted isn't generated by packet processing"); } }; } if !was_closed && self.state.is_closed() { self.close_common(); if !self.state.is_drained() { self.set_close_timer(now); } } if !was_drained && self.state.is_drained() { self.endpoint_events.push_back(EndpointEventInner::Drained); // Close timer may have been started previously, e.g. 
if we sent a close and got a // stateless reset in response self.timers.stop(Timer::Close); } // Transmit CONNECTION_CLOSE if necessary if let State::Closed(_) = self.state { self.close = remote == self.path.remote; } } fn process_decrypted_packet( &mut self, now: Instant, remote: SocketAddr, number: Option<u64>, packet: Packet, ) -> Result<(), ConnectionError> { let state = match self.state { State::Established => { match packet.header.space() { SpaceId::Data => self.process_payload(now, remote, number.unwrap(), packet)?, _ if packet.header.has_frames() => self.process_early_payload(now, packet)?, _ => { trace!("discarding unexpected pre-handshake packet"); } } return Ok(()); } State::Closed(_) => { for result in frame::Iter::new(packet.payload.freeze())? { let frame = match result { Ok(frame) => frame, Err(err) => { debug!("frame decoding error: {err:?}"); continue; } }; if let Frame::Padding = frame { continue; }; self.stats.frame_rx.record(&frame); if let Frame::Close(_) = frame { trace!("draining"); self.state = State::Draining; break; } } return Ok(()); } State::Draining | State::Drained => return Ok(()), State::Handshake(ref mut state) => state, }; match packet.header { Header::Retry { src_cid: rem_cid, .. } => { if self.side.is_server() { return Err(TransportError::PROTOCOL_VIOLATION("client sent Retry").into()); } if self.total_authed_packets > 1 || packet.payload.len() <= 16 // token + 16 byte tag || !self.crypto.is_valid_retry( &self.rem_cids.active(), &packet.header_data, &packet.payload, ) { trace!("discarding invalid Retry"); // - After the client has received and processed an Initial or Retry // packet from the server, it MUST discard any subsequent Retry // packets that it receives. // - A client MUST discard a Retry packet with a zero-length Retry Token // field. // - Clients MUST discard Retry packets that have a Retry Integrity Tag // that cannot be validated return Ok(()); } trace!("retrying with CID {}", rem_cid); let client_hello = state.client_hello.take().unwrap(); self.retry_src_cid = Some(rem_cid); self.rem_cids.update_initial_cid(rem_cid); self.rem_handshake_cid = rem_cid; let space = &mut self.spaces[SpaceId::Initial]; if let Some(info) = space.take(0) { self.on_packet_acked(now, 0, info); }; self.discard_space(now, SpaceId::Initial); // Make sure we clean up after any retransmitted Initials self.spaces[SpaceId::Initial] = PacketSpace { crypto: Some(self.crypto.initial_keys(&rem_cid, self.side)), next_packet_number: self.spaces[SpaceId::Initial].next_packet_number, crypto_offset: client_hello.len() as u64, ..PacketSpace::new(now) }; self.spaces[SpaceId::Initial] .pending .crypto .push_back(frame::Crypto { offset: 0, data: client_hello, }); // Retransmit all 0-RTT data let zero_rtt = mem::take(&mut self.spaces[SpaceId::Data].sent_packets); for (pn, info) in zero_rtt { self.remove_in_flight(pn, &info); self.spaces[SpaceId::Data].pending |= info.retransmits; } self.streams.retransmit_all_for_0rtt(); let token_len = packet.payload.len() - 16; self.retry_token = packet.payload.freeze().split_to(token_len); self.state = State::Handshake(state::Handshake { expected_token: Bytes::new(), rem_cid_set: false, client_hello: None, }); Ok(()) } Header::Long { ty: LongType::Handshake, src_cid: rem_cid, ..
} => { if rem_cid != self.rem_handshake_cid { debug!( "discarding packet with mismatched remote CID: {} != {}", self.rem_handshake_cid, rem_cid ); return Ok(()); } self.path.validated = true; self.process_early_payload(now, packet)?; if self.state.is_closed() { return Ok(()); } if self.crypto.is_handshaking() { trace!("handshake ongoing"); return Ok(()); } if self.side.is_client() { // Client-only because server params were set from the client's Initial let params = self.crypto .transport_parameters()? .ok_or_else(|| TransportError { code: TransportErrorCode::crypto(0x6d), frame: None, reason: "transport parameters missing".into(), })?; if self.has_0rtt() { if !self.crypto.early_data_accepted().unwrap() { debug_assert!(self.side.is_client()); debug!("0-RTT rejected"); self.accepted_0rtt = false; self.streams.zero_rtt_rejected(); // Discard already-queued frames self.spaces[SpaceId::Data].pending = Retransmits::default(); // Discard 0-RTT packets let sent_packets = mem::take(&mut self.spaces[SpaceId::Data].sent_packets); for (pn, packet) in sent_packets { self.remove_in_flight(pn, &packet); } } else { self.accepted_0rtt = true; params.validate_resumption_from(&self.peer_params)?; } } if let Some(token) = params.stateless_reset_token { self.endpoint_events .push_back(EndpointEventInner::ResetToken(self.path.remote, token)); } self.handle_peer_params(params)?; self.issue_first_cids(now); } else { // Server-only self.spaces[SpaceId::Data].pending.handshake_done = true; self.discard_space(now, SpaceId::Handshake); } self.events.push_back(Event::Connected); self.state = State::Established; trace!("established"); Ok(()) } Header::Initial(InitialHeader { src_cid: rem_cid, .. }) => { if !state.rem_cid_set { trace!("switching remote CID to {}", rem_cid); let mut state = state.clone(); self.rem_cids.update_initial_cid(rem_cid); self.rem_handshake_cid = rem_cid; self.orig_rem_cid = rem_cid; state.rem_cid_set = true; self.state = State::Handshake(state); } else if rem_cid != self.rem_handshake_cid { debug!( "discarding packet with mismatched remote CID: {} != {}", self.rem_handshake_cid, rem_cid ); return Ok(()); } let starting_space = self.highest_space; self.process_early_payload(now, packet)?; if self.side.is_server() && starting_space == SpaceId::Initial && self.highest_space != SpaceId::Initial { let params = self.crypto .transport_parameters()? .ok_or_else(|| TransportError { code: TransportErrorCode::crypto(0x6d), frame: None, reason: "transport parameters missing".into(), })?; self.handle_peer_params(params)?; self.issue_first_cids(now); self.init_0rtt(); } Ok(()) } Header::Long { ty: LongType::ZeroRtt, .. } => { self.process_payload(now, remote, number.unwrap(), packet)?; Ok(()) } Header::VersionNegotiate { .. } => { if self.total_authed_packets > 1 { return Ok(()); } let supported = packet .payload .chunks(4) .any(|x| match <[u8; 4]>::try_from(x) { Ok(version) => self.version == u32::from_be_bytes(version), Err(_) => false, }); if supported { return Ok(()); } debug!("remote doesn't support our version"); Err(ConnectionError::VersionMismatch) } Header::Short { .. } => unreachable!( "short packets received during handshake are discarded in handle_packet" ), } } /// Process an Initial or Handshake packet payload fn process_early_payload( &mut self, now: Instant, packet: Packet, ) -> Result<(), TransportError> { debug_assert_ne!(packet.header.space(), SpaceId::Data); let payload_len = packet.payload.len(); let mut ack_eliciting = false; for result in frame::Iter::new(packet.payload.freeze())? 
{ let frame = result?; let span = match frame { Frame::Padding => continue, _ => Some(trace_span!("frame", ty = %frame.ty())), }; self.stats.frame_rx.record(&frame); let _guard = span.as_ref().map(|x| x.enter()); ack_eliciting |= frame.is_ack_eliciting(); // Process frames match frame { Frame::Padding | Frame::Ping => {} Frame::Crypto(frame) => { self.read_crypto(packet.header.space(), &frame, payload_len)?; } Frame::Ack(ack) => { self.on_ack_received(now, packet.header.space(), ack)?; } Frame::Close(reason) => { self.error = Some(reason.into()); self.state = State::Draining; return Ok(()); } _ => { let mut err = TransportError::PROTOCOL_VIOLATION("illegal frame type in handshake"); err.frame = Some(frame.ty()); return Err(err); } } } if ack_eliciting { // In the initial and handshake spaces, ACKs must be sent immediately self.spaces[packet.header.space()] .pending_acks .set_immediate_ack_required(); } self.write_crypto(); Ok(()) } fn process_payload( &mut self, now: Instant, remote: SocketAddr, number: u64, packet: Packet, ) -> Result<(), TransportError> { let payload = packet.payload.freeze(); let mut is_probing_packet = true; let mut close = None; let payload_len = payload.len(); let mut ack_eliciting = false; for result in frame::Iter::new(payload)? { let frame = result?; let span = match frame { Frame::Padding => continue, _ => Some(trace_span!("frame", ty = %frame.ty())), }; self.stats.frame_rx.record(&frame); // Crypto, Stream and Datagram frames are special-cased in order not to pollute // the log with payload data match &frame { Frame::Crypto(f) => { trace!(offset = f.offset, len = f.data.len(), "got crypto frame"); } Frame::Stream(f) => { trace!(id = %f.id, offset = f.offset, len = f.data.len(), fin = f.fin, "got stream frame"); } Frame::Datagram(f) => { trace!(len = f.data.len(), "got datagram frame"); } f => { trace!("got frame {:?}", f); } } let _guard = span.as_ref().map(|x| x.enter()); if packet.header.is_0rtt() { match frame { Frame::Crypto(_) | Frame::Close(Close::Application(_)) => { return Err(TransportError::PROTOCOL_VIOLATION( "illegal frame type in 0-RTT", )); } _ => {} } } ack_eliciting |= frame.is_ack_eliciting(); // Check whether this could be a probing packet match frame { Frame::Padding | Frame::PathChallenge(_) | Frame::PathResponse(_) | Frame::NewConnectionId(_) => {} _ => { is_probing_packet = false; } } match frame { Frame::Crypto(frame) => { self.read_crypto(SpaceId::Data, &frame, payload_len)?; } Frame::Stream(frame) => { if self.streams.received(frame, payload_len)?.should_transmit() { self.spaces[SpaceId::Data].pending.max_data = true; } } Frame::Ack(ack) => { self.on_ack_received(now, SpaceId::Data, ack)?; } Frame::Padding | Frame::Ping => {} Frame::Close(reason) => { close = Some(reason); } Frame::PathChallenge(token) => { self.path_responses.push(number, token, remote); if remote == self.path.remote { // PATH_CHALLENGE on active path, possible off-path packet forwarding // attack. Send a non-probing packet to recover the active path.
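// (Per RFC 9000 §9.3.3, "Off-Path Packet Forwarding", responding with a
// non-probing packet lets genuine traffic from the peer re-establish the
// real path as the active one.)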
match self.peer_supports_ack_frequency() { true => self.immediate_ack(), false => self.ping(), } } } Frame::PathResponse(token) => { if self.path.challenge == Some(token) && remote == self.path.remote { trace!("new path validated"); self.timers.stop(Timer::PathValidation); self.path.challenge = None; self.path.validated = true; if let Some((_, ref mut prev_path)) = self.prev_path { prev_path.challenge = None; prev_path.challenge_pending = false; } } else { debug!(token, "ignoring invalid PATH_RESPONSE"); } } Frame::MaxData(bytes) => { self.streams.received_max_data(bytes); } Frame::MaxStreamData { id, offset } => { self.streams.received_max_stream_data(id, offset)?; } Frame::MaxStreams { dir, count } => { self.streams.received_max_streams(dir, count)?; } Frame::ResetStream(frame) => { if self.streams.received_reset(frame)?.should_transmit() { self.spaces[SpaceId::Data].pending.max_data = true; } } Frame::DataBlocked { offset } => { debug!(offset, "peer claims to be blocked at connection level"); } Frame::StreamDataBlocked { id, offset } => { if id.initiator() == self.side && id.dir() == Dir::Uni { debug!("got STREAM_DATA_BLOCKED on send-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "STREAM_DATA_BLOCKED on send-only stream", )); } debug!( stream = %id, offset, "peer claims to be blocked at stream level" ); } Frame::StreamsBlocked { dir, limit } => { if limit > MAX_STREAM_COUNT { return Err(TransportError::FRAME_ENCODING_ERROR( "unrepresentable stream limit", )); } debug!( "peer claims to be blocked opening more than {} {} streams", limit, dir ); } Frame::StopSending(frame::StopSending { id, error_code }) => { if id.initiator() != self.side { if id.dir() == Dir::Uni { debug!("got STOP_SENDING on recv-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "STOP_SENDING on recv-only stream", )); } } else if self.streams.is_local_unopened(id) { return Err(TransportError::STREAM_STATE_ERROR( "STOP_SENDING on unopened stream", )); } self.streams.received_stop_sending(id, error_code); } Frame::RetireConnectionId { sequence } => { let allow_more_cids = self .local_cid_state .on_cid_retirement(sequence, self.peer_params.issue_cids_limit())?; self.endpoint_events .push_back(EndpointEventInner::RetireConnectionId( now, sequence, allow_more_cids, )); } Frame::NewConnectionId(frame) => { trace!( sequence = frame.sequence, id = %frame.id, retire_prior_to = frame.retire_prior_to, ); if self.rem_cids.active().is_empty() { return Err(TransportError::PROTOCOL_VIOLATION( "NEW_CONNECTION_ID when CIDs aren't in use", )); } if frame.retire_prior_to > frame.sequence { return Err(TransportError::PROTOCOL_VIOLATION( "NEW_CONNECTION_ID retiring unissued CIDs", )); } use crate::cid_queue::InsertError; match self.rem_cids.insert(frame) { Ok(None) => {} Ok(Some((retired, reset_token))) => { let pending_retired = &mut self.spaces[SpaceId::Data].pending.retire_cids; /// Ensure `pending_retired` cannot grow without bound. Limit is /// somewhat arbitrary but very permissive. const MAX_PENDING_RETIRED_CIDS: u64 = CidQueue::LEN as u64 * 10; // We don't bother counting in-flight frames because those are bounded // by congestion control. 
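// (Rough intuition for the bound: each NEW_CONNECTION_ID can retire at most
// the CIDs currently tracked in the queue plus any skipped sequence numbers,
// so a budget of `CidQueue::LEN * 10` tolerates many retirement bursts while
// still capping a peer that churns CIDs faster than we can flush
// RETIRE_CONNECTION_ID frames.)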
if (pending_retired.len() as u64) .saturating_add(retired.end.saturating_sub(retired.start)) > MAX_PENDING_RETIRED_CIDS { return Err(TransportError::CONNECTION_ID_LIMIT_ERROR( "queued too many retired CIDs", )); } pending_retired.extend(retired); self.set_reset_token(reset_token); } Err(InsertError::ExceedsLimit) => { return Err(TransportError::CONNECTION_ID_LIMIT_ERROR("")); } Err(InsertError::Retired) => { trace!("discarding already-retired"); // RETIRE_CONNECTION_ID might not have been previously sent if e.g. a // range of connection IDs larger than the active connection ID limit // was retired all at once via retire_prior_to. self.spaces[SpaceId::Data] .pending .retire_cids .push(frame.sequence); continue; } }; if self.side.is_server() && self.rem_cids.active_seq() == 0 { // We're a server still using the initial remote CID for the client, so // let's switch immediately to enable clientside stateless resets. self.update_rem_cid(); } } Frame::NewToken { token } => { if self.side.is_server() { return Err(TransportError::PROTOCOL_VIOLATION("client sent NEW_TOKEN")); } if token.is_empty() { return Err(TransportError::FRAME_ENCODING_ERROR("empty token")); } trace!("got new token"); // TODO: Cache, or perhaps forward to user? } Frame::Datagram(datagram) => { if self .datagrams .received(datagram, &self.config.datagram_receive_buffer_size)? { self.events.push_back(Event::DatagramReceived); } } Frame::AckFrequency(ack_frequency) => { // This frame can only be sent in the Data space let space = &mut self.spaces[SpaceId::Data]; if !self .ack_frequency .ack_frequency_received(&ack_frequency, &mut space.pending_acks)? { // The AckFrequency frame is stale (we have already received a more recent one) continue; } // Our `max_ack_delay` has been updated, so we may need to adjust its associated // timeout if let Some(timeout) = space .pending_acks .max_ack_delay_timeout(self.ack_frequency.max_ack_delay) { self.timers.set(Timer::MaxAckDelay, timeout); } } Frame::ImmediateAck => { // This frame can only be sent in the Data space self.spaces[SpaceId::Data] .pending_acks .set_immediate_ack_required(); } Frame::HandshakeDone => { if self.side.is_server() { return Err(TransportError::PROTOCOL_VIOLATION( "client sent HANDSHAKE_DONE", )); } if self.spaces[SpaceId::Handshake].crypto.is_some() { self.discard_space(now, SpaceId::Handshake); } } } } let space = &mut self.spaces[SpaceId::Data]; if space .pending_acks .packet_received(now, number, ack_eliciting, &space.dedup) { self.timers .set(Timer::MaxAckDelay, now + self.ack_frequency.max_ack_delay); } // Issue stream ID credit due to ACKs of outgoing finish/resets and incoming finish/resets // on stopped streams. Incoming finishes/resets on open streams are not handled here as they // are only freed, and hence only issue credit, once the application has been notified // during a read on the stream. 
let pending = &mut self.spaces[SpaceId::Data].pending; self.streams.queue_max_stream_id(pending); if let Some(reason) = close { self.error = Some(reason.into()); self.state = State::Draining; self.close = true; } if remote != self.path.remote && !is_probing_packet && number == self.spaces[SpaceId::Data].rx_packet { debug_assert!( self.server_config .as_ref() .expect("packets from unknown remote should be dropped by clients") .migration, "migration-initiating packets should have been dropped immediately" ); self.migrate(now, remote); // Break linkability, if possible self.update_rem_cid(); self.spin = false; } Ok(()) } fn migrate(&mut self, now: Instant, remote: SocketAddr) { trace!(%remote, "migration initiated"); // Reset rtt/congestion state for new path unless it looks like a NAT rebinding. // Note that the congestion window will not grow until validation terminates. Helps mitigate // amplification attacks performed by spoofing source addresses. let mut new_path = if remote.is_ipv4() && remote.ip() == self.path.remote.ip() { PathData::from_previous(remote, &self.path, now) } else { let peer_max_udp_payload_size = u16::try_from(self.peer_params.max_udp_payload_size.into_inner()) .unwrap_or(u16::MAX); PathData::new( remote, self.allow_mtud, Some(peer_max_udp_payload_size), now, false, &self.config, ) }; new_path.challenge = Some(self.rng.gen()); new_path.challenge_pending = true; let prev_pto = self.pto(SpaceId::Data); let mut prev = mem::replace(&mut self.path, new_path); // Don't clobber the original path if the previous one hasn't been validated yet if prev.challenge.is_none() { prev.challenge = Some(self.rng.gen()); prev.challenge_pending = true; // We haven't updated the remote CID yet, this captures the remote CID we were using on // the previous path. self.prev_path = Some((self.rem_cids.active(), prev)); } self.timers.set( Timer::PathValidation, now + 3 * cmp::max(self.pto(SpaceId::Data), prev_pto), ); } /// Handle a change in the local address, i.e. an active migration pub fn local_address_changed(&mut self) { self.update_rem_cid(); self.ping(); } /// Switch to a previously unused remote connection ID, if possible fn update_rem_cid(&mut self) { let (reset_token, retired) = match self.rem_cids.next() { Some(x) => x, None => return, }; // Retire the current remote CID and any CIDs we had to skip. 
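        // Hypothetical example of the rotation performed below: if the active
        // remote CID has sequence 3 and the next unused entry has sequence 6,
        // `rem_cids.next()` activates 6 and yields the retired range 3..6, so
        // RETIRE_CONNECTION_ID frames for sequences 3, 4, and 5 are queued and
        // the reset token matching the new CID is installed.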
self.spaces[SpaceId::Data] .pending .retire_cids .extend(retired); self.set_reset_token(reset_token); } fn set_reset_token(&mut self, reset_token: ResetToken) { self.endpoint_events .push_back(EndpointEventInner::ResetToken( self.path.remote, reset_token, )); self.peer_params.stateless_reset_token = Some(reset_token); } /// Issue an initial set of connection IDs to the peer upon connection fn issue_first_cids(&mut self, now: Instant) { if self.local_cid_state.cid_len() == 0 { return; } // Subtract 1 to account for the CID we supplied while handshaking let n = self.peer_params.issue_cids_limit() - 1; self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, n)); } fn populate_packet( &mut self, now: Instant, space_id: SpaceId, buf: &mut Vec, max_size: usize, pn: u64, ) -> SentFrames { let mut sent = SentFrames::default(); let space = &mut self.spaces[space_id]; let is_0rtt = space_id == SpaceId::Data && space.crypto.is_none(); space.pending_acks.maybe_ack_non_eliciting(); // HANDSHAKE_DONE if !is_0rtt && mem::replace(&mut space.pending.handshake_done, false) { buf.write(frame::FrameType::HANDSHAKE_DONE); sent.retransmits.get_or_create().handshake_done = true; // This is just a u8 counter and the frame is typically just sent once self.stats.frame_tx.handshake_done = self.stats.frame_tx.handshake_done.saturating_add(1); } // PING if mem::replace(&mut space.ping_pending, false) { trace!("PING"); buf.write(frame::FrameType::PING); sent.non_retransmits = true; self.stats.frame_tx.ping += 1; } // IMMEDIATE_ACK if mem::replace(&mut space.immediate_ack_pending, false) { trace!("IMMEDIATE_ACK"); buf.write(frame::FrameType::IMMEDIATE_ACK); sent.non_retransmits = true; self.stats.frame_tx.immediate_ack += 1; } // ACK if space.pending_acks.can_send() { Self::populate_acks( now, self.receiving_ecn, &mut sent, space, buf, &mut self.stats, ); } // ACK_FREQUENCY if mem::replace(&mut space.pending.ack_frequency, false) { let sequence_number = self.ack_frequency.next_sequence_number(); // Safe to unwrap because this is always provided when ACK frequency is enabled let config = self.config.ack_frequency_config.as_ref().unwrap(); // Ensure the delay is within bounds to avoid a PROTOCOL_VIOLATION error let max_ack_delay = self.ack_frequency.candidate_max_ack_delay( self.path.rtt.get(), config, &self.peer_params, ); trace!(?max_ack_delay, "ACK_FREQUENCY"); frame::AckFrequency { sequence: sequence_number, ack_eliciting_threshold: config.ack_eliciting_threshold, request_max_ack_delay: max_ack_delay.as_micros().try_into().unwrap_or(VarInt::MAX), reordering_threshold: config.reordering_threshold, } .encode(buf); sent.retransmits.get_or_create().ack_frequency = true; self.ack_frequency.ack_frequency_sent(pn, max_ack_delay); self.stats.frame_tx.ack_frequency += 1; } // PATH_CHALLENGE if buf.len() + 9 < max_size && space_id == SpaceId::Data { // Transmit challenges with every outgoing frame on an unvalidated path if let Some(token) = self.path.challenge { // But only send a packet solely for that purpose at most once self.path.challenge_pending = false; sent.non_retransmits = true; sent.requires_padding = true; trace!("PATH_CHALLENGE {:08x}", token); buf.write(frame::FrameType::PATH_CHALLENGE); buf.write(token); self.stats.frame_tx.path_challenge += 1; } } // PATH_RESPONSE if buf.len() + 9 < max_size && space_id == SpaceId::Data { if let Some(token) = self.path_responses.pop_on_path(&self.path.remote) { sent.non_retransmits = true; sent.requires_padding = true; trace!("PATH_RESPONSE {:08x}", token); 
buf.write(frame::FrameType::PATH_RESPONSE); buf.write(token); self.stats.frame_tx.path_response += 1; } } // CRYPTO while buf.len() + frame::Crypto::SIZE_BOUND < max_size && !is_0rtt { let mut frame = match space.pending.crypto.pop_front() { Some(x) => x, None => break, }; // Calculate the maximum amount of crypto data we can store in the buffer. // Since the offset is known, we can reserve the exact size required to encode it. // For length we reserve 2bytes which allows to encode up to 2^14, // which is more than what fits into normally sized QUIC frames. let max_crypto_data_size = max_size - buf.len() - 1 // Frame Type - VarInt::size(unsafe { VarInt::from_u64_unchecked(frame.offset) }) - 2; // Maximum encoded length for frame size, given we send less than 2^14 bytes let len = frame .data .len() .min(2usize.pow(14) - 1) .min(max_crypto_data_size); let data = frame.data.split_to(len); let truncated = frame::Crypto { offset: frame.offset, data, }; trace!( "CRYPTO: off {} len {}", truncated.offset, truncated.data.len() ); truncated.encode(buf); self.stats.frame_tx.crypto += 1; sent.retransmits.get_or_create().crypto.push_back(truncated); if !frame.data.is_empty() { frame.offset += len as u64; space.pending.crypto.push_front(frame); } } if space_id == SpaceId::Data { self.streams.write_control_frames( buf, &mut space.pending, &mut sent.retransmits, &mut self.stats.frame_tx, max_size, ); } // NEW_CONNECTION_ID while buf.len() + 44 < max_size { let issued = match space.pending.new_cids.pop() { Some(x) => x, None => break, }; trace!( sequence = issued.sequence, id = %issued.id, "NEW_CONNECTION_ID" ); frame::NewConnectionId { sequence: issued.sequence, retire_prior_to: self.local_cid_state.retire_prior_to(), id: issued.id, reset_token: issued.reset_token, } .encode(buf); sent.retransmits.get_or_create().new_cids.push(issued); self.stats.frame_tx.new_connection_id += 1; } // RETIRE_CONNECTION_ID while buf.len() + frame::RETIRE_CONNECTION_ID_SIZE_BOUND < max_size { let seq = match space.pending.retire_cids.pop() { Some(x) => x, None => break, }; trace!(sequence = seq, "RETIRE_CONNECTION_ID"); buf.write(frame::FrameType::RETIRE_CONNECTION_ID); buf.write_var(seq); sent.retransmits.get_or_create().retire_cids.push(seq); self.stats.frame_tx.retire_connection_id += 1; } // DATAGRAM let mut sent_datagrams = false; while buf.len() + Datagram::SIZE_BOUND < max_size && space_id == SpaceId::Data { match self.datagrams.write(buf, max_size) { true => { sent_datagrams = true; sent.non_retransmits = true; self.stats.frame_tx.datagram += 1; } false => break, } } if self.datagrams.send_blocked && sent_datagrams { self.events.push_back(Event::DatagramsUnblocked); self.datagrams.send_blocked = false; } // STREAM if space_id == SpaceId::Data { sent.stream_frames = self.streams .write_stream_frames(buf, max_size, self.config.send_fairness); self.stats.frame_tx.stream += sent.stream_frames.len() as u64; } sent } /// Write pending ACKs into a buffer /// /// This method assumes ACKs are pending, and should only be called if /// `!PendingAcks::ranges().is_empty()` returns `true`. 
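    ///
    /// Worked example of the delay encoding done here (illustrative numbers):
    /// with a measured ACK delay of 1_000 µs and the default
    /// `ack_delay_exponent` of 3, the frame carries `1_000 >> 3 = 125`; the
    /// peer scales it back up to `125 << 3 = 1_000` µs when estimating RTT.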
fn populate_acks( now: Instant, receiving_ecn: bool, sent: &mut SentFrames, space: &mut PacketSpace, buf: &mut Vec, stats: &mut ConnectionStats, ) { debug_assert!(!space.pending_acks.ranges().is_empty()); // 0-RTT packets must never carry acks (which would have to be of handshake packets) debug_assert!(space.crypto.is_some(), "tried to send ACK in 0-RTT"); let ecn = if receiving_ecn { Some(&space.ecn_counters) } else { None }; sent.largest_acked = space.pending_acks.ranges().max(); let delay_micros = space.pending_acks.ack_delay(now).as_micros() as u64; // TODO: This should come from `TransportConfig` if that gets configurable. let ack_delay_exp = TransportParameters::default().ack_delay_exponent; let delay = delay_micros >> ack_delay_exp.into_inner(); trace!( "ACK {:?}, Delay = {}us", space.pending_acks.ranges(), delay_micros ); frame::Ack::encode(delay as _, space.pending_acks.ranges(), ecn, buf); stats.frame_tx.acks += 1; } fn close_common(&mut self) { trace!("connection closed"); for &timer in &Timer::VALUES { self.timers.stop(timer); } } fn set_close_timer(&mut self, now: Instant) { self.timers .set(Timer::Close, now + 3 * self.pto(self.highest_space)); } /// Handle transport parameters received from the peer fn handle_peer_params(&mut self, params: TransportParameters) -> Result<(), TransportError> { if Some(self.orig_rem_cid) != params.initial_src_cid || (self.side.is_client() && (Some(self.initial_dst_cid) != params.original_dst_cid || self.retry_src_cid != params.retry_src_cid)) { return Err(TransportError::TRANSPORT_PARAMETER_ERROR( "CID authentication failure", )); } self.set_peer_params(params); Ok(()) } fn set_peer_params(&mut self, params: TransportParameters) { self.streams.set_params(¶ms); self.idle_timeout = negotiate_max_idle_timeout(self.config.max_idle_timeout, Some(params.max_idle_timeout)); trace!("negotiated max idle timeout {:?}", self.idle_timeout); if let Some(ref info) = params.preferred_address { self.rem_cids.insert(frame::NewConnectionId { sequence: 1, id: info.connection_id, reset_token: info.stateless_reset_token, retire_prior_to: 0, }).expect("preferred address CID is the first received, and hence is guaranteed to be legal"); } self.ack_frequency.peer_max_ack_delay = get_max_ack_delay(¶ms); self.peer_params = params; self.path.mtud.on_peer_max_udp_payload_size_received( u16::try_from(self.peer_params.max_udp_payload_size.into_inner()).unwrap_or(u16::MAX), ); } fn decrypt_packet( &mut self, now: Instant, packet: &mut Packet, ) -> Result, Option> { let result = packet_crypto::decrypt_packet_body( packet, &self.spaces, self.zero_rtt_crypto.as_ref(), self.key_phase, self.prev_crypto.as_ref(), self.next_crypto.as_ref(), )?; let result = match result { Some(r) => r, None => return Ok(None), }; if result.outgoing_key_update_acked { if let Some(prev) = self.prev_crypto.as_mut() { prev.end_packet = Some((result.number, now)); self.set_key_discard_timer(now, packet.header.space()); } } if result.incoming_key_update { trace!("key update authenticated"); self.update_keys(Some((result.number, now)), true); self.set_key_discard_timer(now, packet.header.space()); } Ok(Some(result.number)) } fn update_keys(&mut self, end_packet: Option<(u64, Instant)>, remote: bool) { trace!("executing key update"); // Generate keys for the key phase after the one we're switching to, store them in // `next_crypto`, make the contents of `next_crypto` current, and move the current keys into // `prev_crypto`. 
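        // Rotation sketch (phase labels K0..K3 are illustrative):
        //
        //   before:  prev = K0,  current = K1,  next = K2
        //   after:   prev = K1,  current = K2,  next = K3 (freshly derived)
        //
        // so packets from the outgoing phase remain decryptable via
        // `prev_crypto` while new packets use the new keys, and `key_phase`
        // flips at the end of this function.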
let new = self .crypto .next_1rtt_keys() .expect("only called for `Data` packets"); self.key_phase_size = new .local .confidentiality_limit() .saturating_sub(KEY_UPDATE_MARGIN); let old = mem::replace( &mut self.spaces[SpaceId::Data] .crypto .as_mut() .unwrap() // safe because update_keys() can only be triggered by short packets .packet, mem::replace(self.next_crypto.as_mut().unwrap(), new), ); self.spaces[SpaceId::Data].sent_with_keys = 0; self.prev_crypto = Some(PrevCrypto { crypto: old, end_packet, update_unacked: remote, }); self.key_phase = !self.key_phase; } fn peer_supports_ack_frequency(&self) -> bool { self.peer_params.min_ack_delay.is_some() } /// Send an IMMEDIATE_ACK frame to the remote endpoint /// /// According to the spec, this will result in an error if the remote endpoint does not support /// the Acknowledgement Frequency extension pub(crate) fn immediate_ack(&mut self) { self.spaces[self.highest_space].immediate_ack_pending = true; } /// Decodes a packet, returning its decrypted payload, so it can be inspected in tests #[cfg(test)] pub(crate) fn decode_packet(&self, event: &ConnectionEvent) -> Option> { let (first_decode, remaining) = match &event.0 { ConnectionEventInner::Datagram(DatagramConnectionEvent { first_decode, remaining, .. }) => (first_decode, remaining), _ => return None, }; if remaining.is_some() { panic!("Packets should never be coalesced in tests"); } let decrypted_header = packet_crypto::unprotect_header( first_decode.clone(), &self.spaces, self.zero_rtt_crypto.as_ref(), self.peer_params.stateless_reset_token, )?; let mut packet = decrypted_header.packet?; packet_crypto::decrypt_packet_body( &mut packet, &self.spaces, self.zero_rtt_crypto.as_ref(), self.key_phase, self.prev_crypto.as_ref(), self.next_crypto.as_ref(), ) .ok()?; Some(packet.payload.to_vec()) } /// The number of bytes of packets containing retransmittable frames that have not been /// acknowledged or declared lost. #[cfg(test)] pub(crate) fn bytes_in_flight(&self) -> u64 { self.path.in_flight.bytes } /// Number of bytes worth of non-ack-only packets that may be sent #[cfg(test)] pub(crate) fn congestion_window(&self) -> u64 { self.path .congestion .window() .saturating_sub(self.path.in_flight.bytes) } /// Whether no timers but keepalive, idle, rtt and pushnewcid are running #[cfg(test)] pub(crate) fn is_idle(&self) -> bool { Timer::VALUES .iter() .filter(|&&t| t != Timer::KeepAlive && t != Timer::PushNewCid) .filter_map(|&t| Some((t, self.timers.get(t)?))) .min_by_key(|&(_, time)| time) .map_or(true, |(timer, _)| timer == Timer::Idle) } /// Total number of outgoing packets that have been deemed lost #[cfg(test)] pub(crate) fn lost_packets(&self) -> u64 { self.lost_packets } /// Whether explicit congestion notification is in use on outgoing packets. 
#[cfg(test)] pub(crate) fn using_ecn(&self) -> bool { self.path.sending_ecn } /// The number of received bytes in the current path #[cfg(test)] pub(crate) fn total_recvd(&self) -> u64 { self.path.total_recvd } #[cfg(test)] pub(crate) fn active_local_cid_seq(&self) -> (u64, u64) { self.local_cid_state.active_seq() } /// Instruct the peer to replace previously issued CIDs by sending a NEW_CONNECTION_ID frame /// with updated `retire_prior_to` field set to `v` #[cfg(test)] pub(crate) fn rotate_local_cid(&mut self, v: u64, now: Instant) { let n = self.local_cid_state.assign_retire_seq(v); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, n)); } /// Check the current active remote CID sequence #[cfg(test)] pub(crate) fn active_rem_cid_seq(&self) -> u64 { self.rem_cids.active_seq() } /// Returns the detected maximum udp payload size for the current path #[cfg(test)] pub(crate) fn path_mtu(&self) -> u16 { self.path.current_mtu() } /// Whether we have 1-RTT data to send /// /// See also `self.space(SpaceId::Data).can_send()` fn can_send_1rtt(&self, max_size: usize) -> bool { self.streams.can_send_stream_data() || self.path.challenge_pending || self .prev_path .as_ref() .map_or(false, |(_, x)| x.challenge_pending) || !self.path_responses.is_empty() || self .datagrams .outgoing .front() .map_or(false, |x| x.size(true) <= max_size) } /// Update counters to account for a packet becoming acknowledged, lost, or abandoned fn remove_in_flight(&mut self, pn: u64, packet: &SentPacket) { // Visit known paths from newest to oldest to find the one `pn` was sent on for path in [&mut self.path] .into_iter() .chain(self.prev_path.as_mut().map(|(_, data)| data)) { if path.remove_in_flight(pn, packet) { return; } } } /// Terminate the connection instantly, without sending a close packet fn kill(&mut self, reason: ConnectionError) { self.close_common(); self.error = Some(reason); self.state = State::Drained; self.endpoint_events.push_back(EndpointEventInner::Drained); } /// Storage size required for the largest packet known to be supported by the current path /// /// Buffers passed to [`Connection::poll_transmit`] should be at least this large. pub fn current_mtu(&self) -> u16 { self.path.current_mtu() } /// Size of non-frame data for a 1-RTT packet /// /// Quantifies space consumed by the QUIC header and AEAD tag. All other bytes in a packet are /// frames. Changes if the length of the remote connection ID changes, which is expected to be /// rare. If `pn` is specified, may additionally change unpredictably due to variations in /// latency and packet loss. fn predict_1rtt_overhead(&self, pn: Option) -> usize { let pn_len = match pn { Some(pn) => PacketNumber::new( pn, self.spaces[SpaceId::Data].largest_acked_packet.unwrap_or(0), ) .len(), // Upper bound None => 4, }; // 1 byte for flags 1 + self.rem_cids.active().len() + pn_len + self.tag_len_1rtt() } fn tag_len_1rtt(&self) -> usize { let key = match self.spaces[SpaceId::Data].crypto.as_ref() { Some(crypto) => Some(&*crypto.packet.local), None => self.zero_rtt_crypto.as_ref().map(|x| &*x.packet), }; // If neither Data nor 0-RTT keys are available, make a reasonable tag length guess. As of // this writing, all QUIC cipher suites use 16-byte tags. We could return `None` instead, // but that would needlessly prevent sending datagrams during 0-RTT. 
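        // Worked example of the 1-RTT overhead computed by
        // `predict_1rtt_overhead` with this 16-byte guess (the CID length is
        // illustrative): 1 flags byte + an 8-byte remote CID + a 4-byte packet
        // number upper bound + a 16-byte tag = 29 bytes of non-frame data.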
key.map_or(16, |x| x.tag_len()) } } impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Connection") .field("handshake_cid", &self.handshake_cid) .finish() } } /// Reasons why a connection might be lost #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum ConnectionError { /// The peer doesn't implement any supported version #[error("peer doesn't implement any supported version")] VersionMismatch, /// The peer violated the QUIC specification as understood by this implementation #[error(transparent)] TransportError(#[from] TransportError), /// The peer's QUIC stack aborted the connection automatically #[error("aborted by peer: {0}")] ConnectionClosed(frame::ConnectionClose), /// The peer closed the connection #[error("closed by peer: {0}")] ApplicationClosed(frame::ApplicationClose), /// The peer is unable to continue processing this connection, usually due to having restarted #[error("reset by peer")] Reset, /// Communication with the peer has lapsed for longer than the negotiated idle timeout /// /// If neither side is sending keep-alives, a connection will time out after a long enough idle /// period even if the peer is still reachable. See also [`TransportConfig::max_idle_timeout()`] /// and [`TransportConfig::keep_alive_interval()`]. #[error("timed out")] TimedOut, /// The local application closed the connection #[error("closed")] LocallyClosed, /// The connection could not be created because not enough of the CID space is available /// /// Try using longer connection IDs. #[error("CIDs exhausted")] CidsExhausted, } impl From for ConnectionError { fn from(x: Close) -> Self { match x { Close::Connection(reason) => Self::ConnectionClosed(reason), Close::Application(reason) => Self::ApplicationClosed(reason), } } } // For compatibility with API consumers impl From for io::Error { fn from(x: ConnectionError) -> Self { use self::ConnectionError::*; let kind = match x { TimedOut => io::ErrorKind::TimedOut, Reset => io::ErrorKind::ConnectionReset, ApplicationClosed(_) | ConnectionClosed(_) => io::ErrorKind::ConnectionAborted, TransportError(_) | VersionMismatch | LocallyClosed | CidsExhausted => { io::ErrorKind::Other } }; Self::new(kind, x) } } #[allow(unreachable_pub)] // fuzzing only #[derive(Clone)] pub enum State { Handshake(state::Handshake), Established, Closed(state::Closed), Draining, /// Waiting for application to call close so we can dispose of the resources Drained, } impl State { fn closed>(reason: R) -> Self { Self::Closed(state::Closed { reason: reason.into(), }) } fn is_handshake(&self) -> bool { matches!(*self, Self::Handshake(_)) } fn is_established(&self) -> bool { matches!(*self, Self::Established) } fn is_closed(&self) -> bool { matches!(*self, Self::Closed(_) | Self::Draining | Self::Drained) } fn is_drained(&self) -> bool { matches!(*self, Self::Drained) } } mod state { use super::*; #[allow(unreachable_pub)] // fuzzing only #[derive(Clone)] pub struct Handshake { /// Whether the remote CID has been set by the peer yet /// /// Always set for servers pub(super) rem_cid_set: bool, /// Stateless retry token received in the first Initial by a server. /// /// Must be present in every Initial. Always empty for clients. 
pub(super) expected_token: Bytes, /// First cryptographic message /// /// Only set for clients pub(super) client_hello: Option, } #[allow(unreachable_pub)] // fuzzing only #[derive(Clone)] pub struct Closed { pub(super) reason: Close, } } /// Events of interest to the application #[derive(Debug)] pub enum Event { /// The connection's handshake data is ready HandshakeDataReady, /// The connection was successfully established Connected, /// The connection was lost /// /// Emitted if the peer closes the connection or an error is encountered. ConnectionLost { /// Reason that the connection was closed reason: ConnectionError, }, /// Stream events Stream(StreamEvent), /// One or more application datagrams have been received DatagramReceived, /// One or more application datagrams have been sent after blocking DatagramsUnblocked, } fn instant_saturating_sub(x: Instant, y: Instant) -> Duration { if x > y { x - y } else { Duration::new(0, 0) } } fn get_max_ack_delay(params: &TransportParameters) -> Duration { Duration::from_micros(params.max_ack_delay.0 * 1000) } // Prevents overflow and improves behavior in extreme circumstances const MAX_BACKOFF_EXPONENT: u32 = 16; /// Minimal remaining size to allow packet coalescing, excluding cryptographic tag /// /// This must be at least as large as the header for a well-formed empty packet to be coalesced, /// plus some space for frames. We only care about handshake headers because short header packets /// necessarily have smaller headers, and initial packets are only ever the first packet in a /// datagram (because we coalesce in ascending packet space order and the only reason to split a /// packet is when packet space changes). const MIN_PACKET_SPACE: usize = MAX_HANDSHAKE_OR_0RTT_HEADER_SIZE + 32; /// Largest amount of space that could be occupied by a Handshake or 0-RTT packet's header /// /// Excludes packet-type-specific fields such as packet number or Initial token // https://www.rfc-editor.org/rfc/rfc9000.html#name-0-rtt: flags + version + dcid len + dcid + // scid len + scid + length + pn const MAX_HANDSHAKE_OR_0RTT_HEADER_SIZE: usize = 1 + 4 + 1 + MAX_CID_SIZE + 1 + MAX_CID_SIZE + VarInt::from_u32(u16::MAX as u32).size() + 4; /// The maximum amount of datagrams that are sent in a single transmit /// /// This can be lower than the maximum platform capabilities, to avoid excessive /// memory allocations when calling `poll_transmit()`. Benchmarks have shown /// that numbers around 10 are a good compromise. const MAX_TRANSMIT_SEGMENTS: usize = 10; /// Perform key updates this many packets before the AEAD confidentiality limit. /// /// Chosen arbitrarily, intended to be large enough to prevent spurious connection loss. const KEY_UPDATE_MARGIN: u64 = 10_000; #[derive(Default)] struct SentFrames { retransmits: ThinRetransmits, largest_acked: Option, stream_frames: StreamMetaVec, /// Whether the packet contains non-retransmittable frames (like datagrams) non_retransmits: bool, requires_padding: bool, } impl SentFrames { /// Returns whether the packet contains only ACKs fn is_ack_only(&self, streams: &StreamsState) -> bool { self.largest_acked.is_some() && !self.non_retransmits && self.stream_frames.is_empty() && self.retransmits.is_empty(streams) } } /// Compute the negotiated idle timeout based on local and remote max_idle_timeout transport parameters. 
/// /// According to the definition of max_idle_timeout, a value of `0` means the timeout is disabled; see /// /// According to the negotiation procedure, either the minimum of the timeouts or one specified is used as the negotiated value; see /// /// Returns the negotiated idle timeout as a `Duration`, or `None` when both endpoints have opted out of idle timeout. fn negotiate_max_idle_timeout(x: Option, y: Option) -> Option { match (x, y) { (Some(VarInt(0)) | None, Some(VarInt(0)) | None) => None, (Some(VarInt(0)) | None, Some(y)) => Some(Duration::from_millis(y.0)), (Some(x), Some(VarInt(0)) | None) => Some(Duration::from_millis(x.0)), (Some(x), Some(y)) => Some(Duration::from_millis(cmp::min(x, y).0)), } } #[cfg(test)] mod tests { use super::*; #[test] fn negotiate_max_idle_timeout_commutative() { let test_params = [ (None, None, None), (None, Some(VarInt(0)), None), (None, Some(VarInt(2)), Some(Duration::from_millis(2))), (Some(VarInt(0)), Some(VarInt(0)), None), ( Some(VarInt(2)), Some(VarInt(0)), Some(Duration::from_millis(2)), ), ( Some(VarInt(1)), Some(VarInt(4)), Some(Duration::from_millis(1)), ), ]; for (left, right, result) in test_params { assert_eq!(negotiate_max_idle_timeout(left, right), result); assert_eq!(negotiate_max_idle_timeout(right, left), result); } } } quinn-proto-0.11.9/src/connection/mtud.rs000064400000000000000000001046241046102023000164630ustar 00000000000000use crate::{packet::SpaceId, Instant, MtuDiscoveryConfig, MAX_UDP_PAYLOAD}; use std::cmp; use tracing::trace; /// Implements Datagram Packetization Layer Path Maximum Transmission Unit Discovery /// /// See [`MtuDiscoveryConfig`] for details #[derive(Clone)] pub(crate) struct MtuDiscovery { /// Detected MTU for the path current_mtu: u16, /// The state of the MTU discovery, if enabled state: Option, /// The state of the black hole detector black_hole_detector: BlackHoleDetector, } impl MtuDiscovery { pub(crate) fn new( initial_plpmtu: u16, min_mtu: u16, peer_max_udp_payload_size: Option, config: MtuDiscoveryConfig, ) -> Self { debug_assert!( initial_plpmtu >= min_mtu, "initial_max_udp_payload_size must be at least {min_mtu}" ); let mut mtud = Self::with_state( initial_plpmtu, min_mtu, Some(EnabledMtuDiscovery::new(config)), ); // We might be migrating an existing connection to a new path, in which case the transport // parameters have already been transmitted, and we already know the value of // `peer_max_udp_payload_size` if let Some(peer_max_udp_payload_size) = peer_max_udp_payload_size { mtud.on_peer_max_udp_payload_size_received(peer_max_udp_payload_size); } mtud } /// MTU discovery will be disabled and the current MTU will be fixed to the provided value pub(crate) fn disabled(plpmtu: u16, min_mtu: u16) -> Self { Self::with_state(plpmtu, min_mtu, None) } fn with_state(current_mtu: u16, min_mtu: u16, state: Option) -> Self { Self { current_mtu, state, black_hole_detector: BlackHoleDetector::new(min_mtu), } } pub(super) fn reset(&mut self, current_mtu: u16, min_mtu: u16) { self.current_mtu = current_mtu; if let Some(state) = self.state.take() { self.state = Some(EnabledMtuDiscovery::new(state.config)); self.on_peer_max_udp_payload_size_received(state.peer_max_udp_payload_size); } self.black_hole_detector = BlackHoleDetector::new(min_mtu); } /// Returns the current MTU pub(crate) fn current_mtu(&self) -> u16 { self.current_mtu } /// Returns the amount of bytes that should be sent as an MTU probe, if any pub(crate) fn poll_transmit(&mut self, now: Instant, next_pn: u64) -> Option { self.state .as_mut() 
.and_then(|state| state.poll_transmit(now, self.current_mtu, next_pn)) } /// Notifies the [`MtuDiscovery`] that the peer's `max_udp_payload_size` transport parameter has /// been received pub(crate) fn on_peer_max_udp_payload_size_received(&mut self, peer_max_udp_payload_size: u16) { self.current_mtu = self.current_mtu.min(peer_max_udp_payload_size); if let Some(state) = self.state.as_mut() { // MTUD is only active after the connection has been fully established, so it is // guaranteed we will receive the peer's transport parameters before we start probing debug_assert!(matches!(state.phase, Phase::Initial)); state.peer_max_udp_payload_size = peer_max_udp_payload_size; } } /// Notifies the [`MtuDiscovery`] that a packet has been ACKed /// /// Returns true if the packet was an MTU probe pub(crate) fn on_acked(&mut self, space: SpaceId, pn: u64, len: u16) -> bool { // MTU probes are only sent in application data space if space != SpaceId::Data { return false; } // Update the state of the MTU search if let Some(new_mtu) = self .state .as_mut() .and_then(|state| state.on_probe_acked(pn)) { self.current_mtu = new_mtu; trace!(current_mtu = self.current_mtu, "new MTU detected"); self.black_hole_detector.on_probe_acked(pn, len); true } else { self.black_hole_detector.on_non_probe_acked(pn, len); false } } /// Returns the packet number of the in-flight MTU probe, if any pub(crate) fn in_flight_mtu_probe(&self) -> Option { match &self.state { Some(EnabledMtuDiscovery { phase: Phase::Searching(search_state), .. }) => search_state.in_flight_probe, _ => None, } } /// Notifies the [`MtuDiscovery`] that the in-flight MTU probe was lost pub(crate) fn on_probe_lost(&mut self) { if let Some(state) = &mut self.state { state.on_probe_lost(); } } /// Notifies the [`MtuDiscovery`] that a non-probe packet was lost /// /// When done notifying of lost packets, [`MtuDiscovery::black_hole_detected`] must be called, to /// ensure the last loss burst is properly processed and to trigger black hole recovery logic if /// necessary. pub(crate) fn on_non_probe_lost(&mut self, pn: u64, len: u16) { self.black_hole_detector.on_non_probe_lost(pn, len); } /// Returns true if a black hole was detected /// /// Calling this function will close the previous loss burst. If a black hole is detected, the /// current MTU will be reset to `min_mtu`. 
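    ///
    /// A hedged usage sketch (hypothetical caller code, not this crate's loss
    /// detection logic):
    ///
    /// ```ignore
    /// for (pn, len) in lost_non_probe_packets {
    ///     mtud.on_non_probe_lost(pn, len);
    /// }
    /// if mtud.black_hole_detected(now) {
    ///     // current_mtu() now reports `min_mtu` again
    /// }
    /// ```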
pub(crate) fn black_hole_detected(&mut self, now: Instant) -> bool { if !self.black_hole_detector.black_hole_detected() { return false; } self.current_mtu = self.black_hole_detector.min_mtu; if let Some(state) = &mut self.state { state.on_black_hole_detected(now); } true } } /// Additional state for enabled MTU discovery #[derive(Debug, Clone)] struct EnabledMtuDiscovery { phase: Phase, peer_max_udp_payload_size: u16, config: MtuDiscoveryConfig, } impl EnabledMtuDiscovery { fn new(config: MtuDiscoveryConfig) -> Self { Self { phase: Phase::Initial, peer_max_udp_payload_size: MAX_UDP_PAYLOAD, config, } } /// Returns the amount of bytes that should be sent as an MTU probe, if any fn poll_transmit(&mut self, now: Instant, current_mtu: u16, next_pn: u64) -> Option { if let Phase::Initial = &self.phase { // Start the first search self.phase = Phase::Searching(SearchState::new( current_mtu, self.peer_max_udp_payload_size, &self.config, )); } else if let Phase::Complete(next_mtud_activation) = &self.phase { if now < *next_mtud_activation { return None; } // Start a new search (we have reached the next activation time) self.phase = Phase::Searching(SearchState::new( current_mtu, self.peer_max_udp_payload_size, &self.config, )); } if let Phase::Searching(state) = &mut self.phase { // Nothing to do while there is a probe in flight if state.in_flight_probe.is_some() { return None; } // Retransmit lost probes, if any if 0 < state.lost_probe_count && state.lost_probe_count < MAX_PROBE_RETRANSMITS { state.in_flight_probe = Some(next_pn); return Some(state.last_probed_mtu); } let last_probe_succeeded = state.lost_probe_count == 0; // The probe is definitely lost (we reached the MAX_PROBE_RETRANSMITS threshold) if !last_probe_succeeded { state.lost_probe_count = 0; state.in_flight_probe = None; } if let Some(probe_udp_payload_size) = state.next_mtu_to_probe(last_probe_succeeded) { state.in_flight_probe = Some(next_pn); state.last_probed_mtu = probe_udp_payload_size; return Some(probe_udp_payload_size); } else { let next_mtud_activation = now + self.config.interval; self.phase = Phase::Complete(next_mtud_activation); return None; } } None } /// Called when a packet is acknowledged in [`SpaceId::Data`] /// /// Returns the new `current_mtu` if the packet number corresponds to the in-flight MTU probe fn on_probe_acked(&mut self, pn: u64) -> Option { match &mut self.phase { Phase::Searching(state) if state.in_flight_probe == Some(pn) => { state.in_flight_probe = None; state.lost_probe_count = 0; Some(state.last_probed_mtu) } _ => None, } } /// Called when the in-flight MTU probe was lost fn on_probe_lost(&mut self) { // We might no longer be searching, e.g. 
if a black hole was detected if let Phase::Searching(state) = &mut self.phase { state.in_flight_probe = None; state.lost_probe_count += 1; } } /// Called when a black hole is detected fn on_black_hole_detected(&mut self, now: Instant) { // Stop searching, if applicable, and reset the timer let next_mtud_activation = now + self.config.black_hole_cooldown; self.phase = Phase::Complete(next_mtud_activation); } } #[derive(Debug, Clone, Copy)] enum Phase { /// We haven't started polling yet Initial, /// We are currently searching for a higher PMTU Searching(SearchState), /// Searching has completed and will be triggered again at the provided instant Complete(Instant), } #[derive(Debug, Clone, Copy)] struct SearchState { /// The lower bound for the current binary search lower_bound: u16, /// The upper bound for the current binary search upper_bound: u16, /// The minimum change to stop the current binary search minimum_change: u16, /// The UDP payload size we last sent a probe for last_probed_mtu: u16, /// Packet number of an in-flight probe (if any) in_flight_probe: Option, /// Lost probes at the current probe size lost_probe_count: usize, } impl SearchState { /// Creates a new search state, with the specified lower bound (the upper bound is derived from /// the config and the peer's `max_udp_payload_size` transport parameter) fn new( mut lower_bound: u16, peer_max_udp_payload_size: u16, config: &MtuDiscoveryConfig, ) -> Self { lower_bound = lower_bound.min(peer_max_udp_payload_size); let upper_bound = config .upper_bound .clamp(lower_bound, peer_max_udp_payload_size); Self { in_flight_probe: None, lost_probe_count: 0, lower_bound, upper_bound, minimum_change: config.minimum_change, // During initialization, we consider the lower bound to have already been // successfully probed last_probed_mtu: lower_bound, } } /// Determines the next MTU to probe using binary search fn next_mtu_to_probe(&mut self, last_probe_succeeded: bool) -> Option { debug_assert_eq!(self.in_flight_probe, None); if last_probe_succeeded { self.lower_bound = self.last_probed_mtu; } else { self.upper_bound = self.last_probed_mtu - 1; } let next_mtu = (self.lower_bound as i32 + self.upper_bound as i32) / 2; // Binary search stopping condition if ((next_mtu - self.last_probed_mtu as i32).unsigned_abs() as u16) < self.minimum_change { // Special case: if the upper bound is far enough, we want to probe it as a last // step (otherwise we will never achieve the upper bound) if self.upper_bound.saturating_sub(self.last_probed_mtu) >= self.minimum_change { return Some(self.upper_bound); } return None; } Some(next_mtu as u16) } } /// Judges whether packet loss might indicate a drop in MTU /// /// Our MTU black hole detection scheme is a heuristic based on the order in which packets were sent /// (the packet number order), their sizes, and which are deemed lost. /// /// First, contiguous groups of lost packets ("loss bursts") are aggregated, because a group of /// packets all lost together were probably lost for the same reason. /// /// A loss burst is deemed "suspicious" if it contains no packets that are (a) smaller than the /// minimum MTU or (b) smaller than a more recent acknowledged packet, because such a burst could be /// fully explained by a reduction in MTU. /// /// When the number of suspicious loss bursts exceeds [`BLACK_HOLE_THRESHOLD`], we judge the /// evidence for an MTU black hole to be sufficient. 
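///
/// Illustrative timeline (sizes in bytes, `min_mtu` = 1200; the numbers are
/// hypothetical):
///
/// ```text
/// lost  pn=1 (1400), pn=2 (1400)   one contiguous burst -> suspicious
/// acked pn=4 (1450)                earlier 1400-byte bursts become non-suspicious
/// lost  pn=6 (1400), pn=7 (1400)   burst after the ACK -> suspicious again
/// ```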
#[derive(Clone)] struct BlackHoleDetector { /// Packet loss bursts currently considered suspicious suspicious_loss_bursts: Vec, /// Loss burst currently being aggregated, if any current_loss_burst: Option, /// Packet number of the biggest packet larger than `min_mtu` which we've received /// acknowledgment of more recently than any suspicious loss burst, if any largest_post_loss_packet: u64, /// The maximum of `min_mtu` and the size of `largest_post_loss_packet`, or exactly `min_mtu` if /// no larger packets have been received since the most recent loss burst. acked_mtu: u16, /// The UDP payload size guaranteed to be supported by the network min_mtu: u16, } impl BlackHoleDetector { fn new(min_mtu: u16) -> Self { Self { suspicious_loss_bursts: Vec::with_capacity(BLACK_HOLE_THRESHOLD + 1), current_loss_burst: None, largest_post_loss_packet: 0, acked_mtu: min_mtu, min_mtu, } } fn on_probe_acked(&mut self, pn: u64, len: u16) { // MTU probes are always larger than the previous MTU, so no previous loss bursts are // suspicious. At most one MTU probe is in flight at a time, so we don't need to worry about // reordering between them. self.suspicious_loss_bursts.clear(); self.acked_mtu = len; // This might go backwards, but that's okay: a successful ACK means we haven't yet judged a // more recently sent packet lost, and we just want to track the largest packet that's been // successfully delivered more recently than a loss. self.largest_post_loss_packet = pn; } fn on_non_probe_acked(&mut self, pn: u64, len: u16) { if len <= self.acked_mtu { // We've already seen a larger packet since the most recent suspicious loss burst; // nothing to do. return; } self.acked_mtu = len; // This might go backwards, but that's okay as described in `on_probe_acked`. self.largest_post_loss_packet = pn; // Loss bursts packets smaller than this are retroactively deemed non-suspicious. self.suspicious_loss_bursts .retain(|burst| burst.smallest_packet_size > len); } fn on_non_probe_lost(&mut self, pn: u64, len: u16) { // A loss burst is a group of consecutive packets that are declared lost, so a distance // greater than 1 indicates a new burst let end_last_burst = self .current_loss_burst .as_ref() .map_or(false, |current| pn - current.latest_non_probe != 1); if end_last_burst { self.finish_loss_burst(); } self.current_loss_burst = Some(CurrentLossBurst { latest_non_probe: pn, smallest_packet_size: self .current_loss_burst .map_or(len, |prev| cmp::min(prev.smallest_packet_size, len)), }); } fn black_hole_detected(&mut self) -> bool { self.finish_loss_burst(); if self.suspicious_loss_bursts.len() <= BLACK_HOLE_THRESHOLD { return false; } self.suspicious_loss_bursts.clear(); true } /// Marks the end of the current loss burst, checking whether it was suspicious fn finish_loss_burst(&mut self) { let Some(burst) = self.current_loss_burst.take() else { return; }; // If a loss burst contains a packet smaller than the minimum MTU or a more recently // transmitted packet, it is not suspicious. if burst.smallest_packet_size < self.min_mtu || (burst.latest_non_probe < self.largest_post_loss_packet && burst.smallest_packet_size < self.acked_mtu) { return; } // The loss burst is now deemed suspicious. // A suspicious loss burst more recent than `largest_post_loss_packet` invalidates it. This // makes `acked_mtu` a conservative approximation. 
        // Ideally we'd update `acked_mtu` and
        // `largest_post_loss_packet` to describe the largest acknowledged packet sent later than
        // this burst, but that would require tracking the size of an unpredictable number of
        // recently acknowledged packets, and erring on the side of false positives is safe.
        if burst.latest_non_probe > self.largest_post_loss_packet {
            self.acked_mtu = self.min_mtu;
        }

        let burst = LossBurst {
            smallest_packet_size: burst.smallest_packet_size,
        };
        if self.suspicious_loss_bursts.len() <= BLACK_HOLE_THRESHOLD {
            self.suspicious_loss_bursts.push(burst);
            return;
        }

        // To limit memory use, only track the most suspicious loss bursts.
        let smallest = self
            .suspicious_loss_bursts
            .iter_mut()
            .min_by_key(|prev| prev.smallest_packet_size)
            .filter(|prev| prev.smallest_packet_size < burst.smallest_packet_size);
        if let Some(smallest) = smallest {
            *smallest = burst;
        }
    }

    #[cfg(test)]
    fn suspicious_loss_burst_count(&self) -> usize {
        self.suspicious_loss_bursts.len()
    }

    #[cfg(test)]
    fn largest_non_probe_lost(&self) -> Option<u64> {
        self.current_loss_burst.as_ref().map(|x| x.latest_non_probe)
    }
}

#[derive(Copy, Clone)]
struct LossBurst {
    smallest_packet_size: u16,
}

#[derive(Copy, Clone)]
struct CurrentLossBurst {
    smallest_packet_size: u16,
    latest_non_probe: u64,
}

// Corresponds to the RFC's `MAX_PROBES` constant (see
// https://www.rfc-editor.org/rfc/rfc8899#section-5.1.2)
const MAX_PROBE_RETRANSMITS: usize = 3;

/// Maximum number of suspicious loss bursts that will not trigger black hole detection
const BLACK_HOLE_THRESHOLD: usize = 3;

#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::SpaceId;
    use crate::Duration;
    use crate::MAX_UDP_PAYLOAD;
    use assert_matches::assert_matches;

    fn default_mtud() -> MtuDiscovery {
        let config = MtuDiscoveryConfig::default();
        MtuDiscovery::new(1_200, 1_200, None, config)
    }

    fn completed(mtud: &MtuDiscovery) -> bool {
        matches!(mtud.state.as_ref().unwrap().phase, Phase::Complete(_))
    }

    /// Drives mtud until it reaches `Phase::Complete`
    fn drive_to_completion(
        mtud: &mut MtuDiscovery,
        now: Instant,
        link_payload_size_limit: u16,
    ) -> Vec<u16> {
        let mut probed_sizes = Vec::new();
        for probe_pn in 1..100 {
            let result = mtud.poll_transmit(now, probe_pn);
            if completed(mtud) {
                break;
            }

            // "Send" next probe
            assert!(result.is_some());
            let probe_size = result.unwrap();
            probed_sizes.push(probe_size);

            if probe_size <= link_payload_size_limit {
                mtud.on_acked(SpaceId::Data, probe_pn, probe_size);
            } else {
                mtud.on_probe_lost();
            }
        }
        probed_sizes
    }

    #[test]
    fn black_hole_detector_ignores_burst_containing_non_suspicious_packet() {
        let mut mtud = default_mtud();
        mtud.on_non_probe_lost(2, 1300);
        mtud.on_non_probe_lost(3, 1300);
        assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), Some(3));
        assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
        mtud.on_non_probe_lost(4, 800);
        assert!(!mtud.black_hole_detected(Instant::now()));
        assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), None);
        assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
    }

    #[test]
    fn black_hole_detector_counts_burst_containing_only_suspicious_packets() {
        let mut mtud = default_mtud();
        mtud.on_non_probe_lost(2, 1300);
        mtud.on_non_probe_lost(3, 1300);
        assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), Some(3));
        assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0);
        assert!(!mtud.black_hole_detected(Instant::now()));
        assert_eq!(mtud.black_hole_detector.largest_non_probe_lost(), None);
        assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(),
1); } #[test] fn black_hole_detector_ignores_empty_burst() { let mut mtud = default_mtud(); assert!(!mtud.black_hole_detected(Instant::now())); assert_eq!(mtud.black_hole_detector.suspicious_loss_burst_count(), 0); } #[test] fn mtu_discovery_disabled_does_nothing() { let mut mtud = MtuDiscovery::disabled(1_200, 1_200); let probe_size = mtud.poll_transmit(Instant::now(), 0); assert_eq!(probe_size, None); } #[test] fn mtu_discovery_disabled_lost_four_packet_bursts_triggers_black_hole_detection() { let mut mtud = MtuDiscovery::disabled(1_400, 1_250); let now = Instant::now(); for i in 0..4 { // The packets are never contiguous, so each one has its own burst mtud.on_non_probe_lost(i * 2, 1300); } assert!(mtud.black_hole_detected(now)); assert_eq!(mtud.current_mtu, 1250); assert_matches!(mtud.state, None); } #[test] fn mtu_discovery_lost_two_packet_bursts_does_not_trigger_black_hole_detection() { let mut mtud = default_mtud(); let now = Instant::now(); for i in 0..2 { mtud.on_non_probe_lost(i, 1300); assert!(!mtud.black_hole_detected(now)); } } #[test] fn mtu_discovery_lost_four_packet_bursts_triggers_black_hole_detection_and_resets_timer() { let mut mtud = default_mtud(); let now = Instant::now(); for i in 0..4 { // The packets are never contiguous, so each one has its own burst mtud.on_non_probe_lost(i * 2, 1300); } assert!(mtud.black_hole_detected(now)); assert_eq!(mtud.current_mtu, 1200); if let Phase::Complete(next_mtud_activation) = mtud.state.unwrap().phase { assert_eq!(next_mtud_activation, now + Duration::from_secs(60)); } else { panic!("Unexpected MTUD phase!"); } } #[test] fn mtu_discovery_after_complete_reactivates_when_interval_elapsed() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(9_000); let mut mtud = MtuDiscovery::new(1_200, 1_200, None, config); let now = Instant::now(); drive_to_completion(&mut mtud, now, 1_500); // Polling right after completion does not cause new packets to be sent assert_eq!(mtud.poll_transmit(now, 42), None); assert!(completed(&mtud)); assert_eq!(mtud.current_mtu, 1_471); // Polling after the interval has passed does (taking the current mtu as lower bound) assert_eq!( mtud.poll_transmit(now + Duration::from_secs(600), 43), Some(5235) ); match mtud.state.unwrap().phase { Phase::Searching(state) => { assert_eq!(state.lower_bound, 1_471); assert_eq!(state.upper_bound, 9_000); } _ => { panic!("Unexpected MTUD phase!") } } } #[test] fn mtu_discovery_lost_three_probes_lowers_probe_size() { let mut mtud = default_mtud(); let mut probe_sizes = (0..4).map(|i| { let probe_size = mtud.poll_transmit(Instant::now(), i); assert!(probe_size.is_some(), "no probe returned for packet {i}"); mtud.on_probe_lost(); probe_size.unwrap() }); // After the first probe is lost, it gets retransmitted twice let first_probe_size = probe_sizes.next().unwrap(); for _ in 0..2 { assert_eq!(probe_sizes.next().unwrap(), first_probe_size) } // After the third probe is lost, we decrement our probe size let fourth_probe_size = probe_sizes.next().unwrap(); assert!(fourth_probe_size < first_probe_size); assert_eq!( fourth_probe_size, first_probe_size - (first_probe_size - 1_200) / 2 - 1 ); } #[test] fn mtu_discovery_with_peer_max_udp_payload_size_clamps_upper_bound() { let mut mtud = default_mtud(); mtud.on_peer_max_udp_payload_size_received(1300); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); assert_eq!(mtud.state.as_ref().unwrap().peer_max_udp_payload_size, 1300); assert_eq!(mtud.current_mtu, 1300); let expected_probed_sizes = &[1250, 1275, 
1300]; assert_eq!(probed_sizes, expected_probed_sizes); assert!(completed(&mtud)); } #[test] fn mtu_discovery_with_previous_peer_max_udp_payload_size_clamps_upper_bound() { let mut mtud = MtuDiscovery::new(1500, 1_200, Some(1400), MtuDiscoveryConfig::default()); assert_eq!(mtud.current_mtu, 1400); assert_eq!(mtud.state.as_ref().unwrap().peer_max_udp_payload_size, 1400); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); assert_eq!(mtud.current_mtu, 1400); assert!(probed_sizes.is_empty()); assert!(completed(&mtud)); } #[test] #[should_panic] fn mtu_discovery_with_peer_max_udp_payload_size_after_search_panics() { let mut mtud = default_mtud(); drive_to_completion(&mut mtud, Instant::now(), 1500); mtud.on_peer_max_udp_payload_size_received(1300); } #[test] fn mtu_discovery_with_1500_limit() { let mut mtud = default_mtud(); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); let expected_probed_sizes = &[1326, 1389, 1420, 1452]; assert_eq!(probed_sizes, expected_probed_sizes); assert_eq!(mtud.current_mtu, 1452); assert!(completed(&mtud)); } #[test] fn mtu_discovery_with_1500_limit_and_10000_upper_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(10_000); let mut mtud = MtuDiscovery::new(1_200, 1_200, None, config); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); let expected_probed_sizes = &[ 5600, 5600, 5600, 3399, 3399, 3399, 2299, 2299, 2299, 1749, 1749, 1749, 1474, 1611, 1611, 1611, 1542, 1542, 1542, 1507, 1507, 1507, ]; assert_eq!(probed_sizes, expected_probed_sizes); assert_eq!(mtud.current_mtu, 1474); assert!(completed(&mtud)); } #[test] fn mtu_discovery_no_lost_probes_finds_maximum_udp_payload() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(MAX_UDP_PAYLOAD); let mut mtud = MtuDiscovery::new(1200, 1200, None, config); drive_to_completion(&mut mtud, Instant::now(), u16::MAX); assert_eq!(mtud.current_mtu, 65527); assert!(completed(&mtud)); } #[test] fn mtu_discovery_lost_half_of_probes_finds_maximum_udp_payload() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(MAX_UDP_PAYLOAD); let mut mtud = MtuDiscovery::new(1200, 1200, None, config); let now = Instant::now(); let mut iterations = 0; for i in 1..100 { iterations += 1; let probe_pn = i * 2 - 1; let other_pn = i * 2; let result = mtud.poll_transmit(Instant::now(), probe_pn); if completed(&mtud) { break; } // "Send" next probe assert!(result.is_some()); assert!(mtud.in_flight_mtu_probe().is_some()); // Nothing else to send while the probe is in-flight assert_matches!(mtud.poll_transmit(now, other_pn), None); if i % 2 == 0 { // ACK probe and ensure it results in an increase of current_mtu let previous_max_size = mtud.current_mtu; mtud.on_acked(SpaceId::Data, probe_pn, result.unwrap()); println!( "ACK packet {}. Previous MTU = {previous_max_size}. 
New MTU = {}", result.unwrap(), mtud.current_mtu ); // assert!(mtud.current_mtu > previous_max_size); } else { mtud.on_probe_lost(); } } assert_eq!(iterations, 25); assert_eq!(mtud.current_mtu, 65527); assert!(completed(&mtud)); } #[test] fn search_state_lower_bound_higher_than_upper_bound_clamps_upper_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(1400); let state = SearchState::new(1500, u16::MAX, &config); assert_eq!(state.lower_bound, 1500); assert_eq!(state.upper_bound, 1500); } #[test] fn search_state_lower_bound_higher_than_peer_max_udp_payload_size_clamps_lower_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(9000); let state = SearchState::new(1500, 1300, &config); assert_eq!(state.lower_bound, 1300); assert_eq!(state.upper_bound, 1300); } #[test] fn search_state_upper_bound_higher_than_peer_max_udp_payload_size_clamps_upper_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(9000); let state = SearchState::new(1200, 1450, &config); assert_eq!(state.lower_bound, 1200); assert_eq!(state.upper_bound, 1450); } // Loss of packets larger than have been acknowledged should indicate a black hole #[test] fn simple_black_hole_detection() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 2, 1300); for i in 0..BLACK_HOLE_THRESHOLD { bhd.on_non_probe_lost(i as u64 * 2, 1400); } // But not before `BLACK_HOLE_THRESHOLD + 1` bursts assert!(!bhd.black_hole_detected()); bhd.on_non_probe_lost(BLACK_HOLE_THRESHOLD as u64 * 2, 1400); assert!(bhd.black_hole_detected()); } // Loss of packets followed in transmission order by confirmation of a larger packet should not // indicate a black hole #[test] fn non_suspicious_bursts() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 2, 1500); for i in 0..(BLACK_HOLE_THRESHOLD + 1) { bhd.on_non_probe_lost(i as u64 * 2, 1400); } assert!(!bhd.black_hole_detected()); } // Loss of packets smaller than have been acknowledged previously should still indicate a black // hole #[test] fn dynamic_mtu_reduction() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked(0, 1500); for i in 0..(BLACK_HOLE_THRESHOLD + 1) { bhd.on_non_probe_lost(i as u64 * 2, 1400); } assert!(bhd.black_hole_detected()); } // Bursts containing heterogeneous packets are judged based on the smallest #[test] fn mixed_non_suspicious_bursts() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 3, 1400); for i in 0..(BLACK_HOLE_THRESHOLD + 1) { bhd.on_non_probe_lost(i as u64 * 3, 1500); bhd.on_non_probe_lost(i as u64 * 3 + 1, 1300); } assert!(!bhd.black_hole_detected()); } // Multi-packet bursts are only counted once #[test] fn bursts_count_once() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 3, 1400); for i in 0..(BLACK_HOLE_THRESHOLD) { bhd.on_non_probe_lost(i as u64 * 3, 1500); bhd.on_non_probe_lost(i as u64 * 3 + 1, 1500); } assert!(!bhd.black_hole_detected()); bhd.on_non_probe_lost(BLACK_HOLE_THRESHOLD as u64 * 3, 1500); assert!(bhd.black_hole_detected()); } // Non-suspicious bursts don't interfere with detection of suspicious bursts #[test] fn interleaved_bursts() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 4, 1400); for i in 0..(BLACK_HOLE_THRESHOLD + 1) { bhd.on_non_probe_lost(i as u64 * 4, 1500); bhd.on_non_probe_lost(i 
as u64 * 4 + 2, 1300); } assert!(bhd.black_hole_detected()); } // Bursts that are non-suspicious before a delivered packet become suspicious past it #[test] fn suspicious_after_acked() { let mut bhd = BlackHoleDetector::new(1200); bhd.on_non_probe_acked((BLACK_HOLE_THRESHOLD + 1) as u64 * 2, 1400); for i in 0..(BLACK_HOLE_THRESHOLD + 1) { bhd.on_non_probe_lost(i as u64 * 2, 1300); } assert!( !bhd.black_hole_detected(), "1300 byte losses preceding a 1400 byte delivery are not suspicious" ); for i in 0..(BLACK_HOLE_THRESHOLD + 1) { bhd.on_non_probe_lost((BLACK_HOLE_THRESHOLD as u64 + 1 + i as u64) * 2, 1300); } assert!( bhd.black_hole_detected(), "1300 byte losses following a 1400 byte delivery are suspicious" ); } // Acknowledgment of a packet marks prior loss bursts with the same packet size as // non-suspicious #[test] fn retroactively_non_suspicious() { let mut bhd = BlackHoleDetector::new(1200); for i in 0..BLACK_HOLE_THRESHOLD { bhd.on_non_probe_lost(i as u64 * 2, 1400); } bhd.on_non_probe_acked(BLACK_HOLE_THRESHOLD as u64 * 2, 1400); bhd.on_non_probe_lost(BLACK_HOLE_THRESHOLD as u64 * 2 + 1, 1400); assert!(!bhd.black_hole_detected()); } } quinn-proto-0.11.9/src/connection/pacing.rs000064400000000000000000000234541046102023000167540ustar 00000000000000//! Pacing of packet transmissions. use crate::{Duration, Instant}; use tracing::warn; /// A simple token-bucket pacer /// /// The pacer's capacity is derived on a fraction of the congestion window /// which can be sent in regular intervals /// Once the bucket is empty, further transmission is blocked. /// The bucket refills at a rate slightly faster /// than one congestion window per RTT, as recommended in /// pub(super) struct Pacer { capacity: u64, last_window: u64, last_mtu: u16, tokens: u64, prev: Instant, } impl Pacer { /// Obtains a new [`Pacer`]. pub(super) fn new(smoothed_rtt: Duration, window: u64, mtu: u16, now: Instant) -> Self { let capacity = optimal_capacity(smoothed_rtt, window, mtu); Self { capacity, last_window: window, last_mtu: mtu, tokens: capacity, prev: now, } } /// Record that a packet has been transmitted. pub(super) fn on_transmit(&mut self, packet_length: u16) { self.tokens = self.tokens.saturating_sub(packet_length.into()) } /// Return how long we need to wait before sending `bytes_to_send` /// /// If we can send a packet right away, this returns `None`. Otherwise, returns `Some(d)`, /// where `d` is the time before this function should be called again. /// /// The 5/4 ratio used here comes from the suggestion that N = 1.25 in the draft IETF RFC for /// QUIC. 
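    ///
    /// Worked example (illustrative numbers): with a 2 MB congestion window and
    /// a 50 ms smoothed RTT, tokens refill at `1.25 * window / rtt` = 50 MB/s,
    /// so an empty bucket accumulates credit for a 1500-byte packet in roughly
    /// 30 µs.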
    pub(super) fn delay(
        &mut self,
        smoothed_rtt: Duration,
        bytes_to_send: u64,
        mtu: u16,
        window: u64,
        now: Instant,
    ) -> Option<Instant> {
        debug_assert_ne!(
            window, 0,
            "zero-sized congestion control window is nonsense"
        );

        if window != self.last_window || mtu != self.last_mtu {
            self.capacity = optimal_capacity(smoothed_rtt, window, mtu);

            // Clamp the tokens
            self.tokens = self.capacity.min(self.tokens);
            self.last_window = window;
            self.last_mtu = mtu;
        }

        // if we can already send a packet, there is no need for delay
        if self.tokens >= bytes_to_send {
            return None;
        }

        // we disable pacing for extremely large windows
        if window > u32::MAX.into() {
            return None;
        }

        let window = window as u32;

        let time_elapsed = now.checked_duration_since(self.prev).unwrap_or_else(|| {
            warn!("received a timestamp earlier than a previously recorded time, ignoring");
            Default::default()
        });

        if smoothed_rtt.as_nanos() == 0 {
            return None;
        }

        let elapsed_rtts = time_elapsed.as_secs_f64() / smoothed_rtt.as_secs_f64();
        let new_tokens = window as f64 * 1.25 * elapsed_rtts;
        self.tokens = self
            .tokens
            .saturating_add(new_tokens as _)
            .min(self.capacity);
        self.prev = now;

        // if we can already send a packet, there is no need for delay
        if self.tokens >= bytes_to_send {
            return None;
        }

        let unscaled_delay = smoothed_rtt
            .checked_mul((bytes_to_send.max(self.capacity) - self.tokens) as _)
            .unwrap_or_else(|| Duration::new(u64::MAX, 999_999_999))
            / window;

        // divisions come before multiplications to prevent overflow
        // this is the time at which the pacing window becomes empty
        Some(self.prev + (unscaled_delay / 5) * 4)
    }
}

/// Calculates a pacer capacity for a certain window and RTT
///
/// The goal is to emit a burst (of size `capacity`) in timer intervals
/// which compromise between
/// - ideally distributing datagrams over time
/// - constantly waking up the connection to produce additional datagrams
///
/// Too short burst intervals mean we will never meet them since the timer
/// accuracy in user-space is not high enough. If we miss the interval by more
/// than 25%, we will lose that part of the congestion window since no additional
/// tokens for the extra-elapsed time can be stored.
///
/// Too long burst intervals make pacing less effective.
fn optimal_capacity(smoothed_rtt: Duration, window: u64, mtu: u16) -> u64 {
    let rtt = smoothed_rtt.as_nanos().max(1);

    let capacity = ((window as u128 * BURST_INTERVAL_NANOS) / rtt) as u64;

    // Small bursts are less efficient (no GSO), could increase latency and don't effectively
    // use the channel's buffer capacity. Large bursts might block the connection on sending.
    capacity.clamp(MIN_BURST_SIZE * mtu as u64, MAX_BURST_SIZE * mtu as u64)
}

/// The burst interval
///
/// The capacity will be refilled in 4/5 of that time.
/// 2ms is chosen here since framework timers might have 1ms precision.
/// If kernel-level pacing is supported later a higher time here might be
/// more applicable.
const BURST_INTERVAL_NANOS: u128 = 2_000_000; // 2ms

/// Allows some usage of GSO, and doesn't slow down the handshake.
const MIN_BURST_SIZE: u64 = 10;

/// Creating 256 packets took 1ms in a benchmark, so larger bursts don't make sense.
const MAX_BURST_SIZE: u64 = 256; #[cfg(test)] mod tests { use super::*; #[test] fn does_not_panic_on_bad_instant() { let old_instant = Instant::now(); let new_instant = old_instant + Duration::from_micros(15); let rtt = Duration::from_micros(400); assert!(Pacer::new(rtt, 30000, 1500, new_instant) .delay(Duration::from_micros(0), 0, 1500, 1, old_instant) .is_none()); assert!(Pacer::new(rtt, 30000, 1500, new_instant) .delay(Duration::from_micros(0), 1600, 1500, 1, old_instant) .is_none()); assert!(Pacer::new(rtt, 30000, 1500, new_instant) .delay(Duration::from_micros(0), 1500, 1500, 3000, old_instant) .is_none()); } #[test] fn derives_initial_capacity() { let window = 2_000_000; let mtu = 1500; let rtt = Duration::from_millis(50); let now = Instant::now(); let pacer = Pacer::new(rtt, window, mtu, now); assert_eq!( pacer.capacity, (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, pacer.capacity); let pacer = Pacer::new(Duration::from_millis(0), window, mtu, now); assert_eq!(pacer.capacity, MAX_BURST_SIZE * mtu as u64); assert_eq!(pacer.tokens, pacer.capacity); let pacer = Pacer::new(rtt, 1, mtu, now); assert_eq!(pacer.capacity, MIN_BURST_SIZE * mtu as u64); assert_eq!(pacer.tokens, pacer.capacity); } #[test] fn adjusts_capacity() { let window = 2_000_000; let mtu = 1500; let rtt = Duration::from_millis(50); let now = Instant::now(); let mut pacer = Pacer::new(rtt, window, mtu, now); assert_eq!( pacer.capacity, (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, pacer.capacity); let initial_tokens = pacer.tokens; pacer.delay(rtt, mtu as u64, mtu, window * 2, now); assert_eq!( pacer.capacity, (2 * window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, initial_tokens); pacer.delay(rtt, mtu as u64, mtu, window / 2, now); assert_eq!( pacer.capacity, (window as u128 / 2 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, initial_tokens / 2); pacer.delay(rtt, mtu as u64, mtu * 2, window, now); assert_eq!( pacer.capacity, (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); pacer.delay(rtt, mtu as u64, 20_000, window, now); assert_eq!(pacer.capacity, 20_000_u64 * MIN_BURST_SIZE); } #[test] fn computes_pause_correctly() { let window = 2_000_000u64; let mtu = 1000; let rtt = Duration::from_millis(50); let old_instant = Instant::now(); let mut pacer = Pacer::new(rtt, window, mtu, old_instant); let packet_capacity = pacer.capacity / mtu as u64; for _ in 0..packet_capacity { assert_eq!( pacer.delay(rtt, mtu as u64, mtu, window, old_instant), None, "When capacity is available packets should be sent immediately" ); pacer.on_transmit(mtu); } let pace_duration = Duration::from_nanos((BURST_INTERVAL_NANOS * 4 / 5) as u64); assert_eq!( pacer .delay(rtt, mtu as u64, mtu, window, old_instant) .expect("Send must be delayed") .duration_since(old_instant), pace_duration ); // Refill half of the tokens assert_eq!( pacer.delay( rtt, mtu as u64, mtu, window, old_instant + pace_duration / 2 ), None ); assert_eq!(pacer.tokens, pacer.capacity / 2); for _ in 0..packet_capacity / 2 { assert_eq!( pacer.delay(rtt, mtu as u64, mtu, window, old_instant), None, "When capacity is available packets should be sent immediately" ); pacer.on_transmit(mtu); } // Refill all capacity by waiting more than the expected duration assert_eq!( pacer.delay( rtt, mtu as u64, mtu, window, old_instant + pace_duration * 3 / 2 ), None ); assert_eq!(pacer.tokens, pacer.capacity); } } 
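// A worked example of `optimal_capacity` (editor's illustrative addition; the
// module name and numbers are arbitrary).
#[cfg(test)]
mod capacity_worked_example {
    use super::*;

    #[test]
    fn mid_range_window_is_unclamped() {
        // One 2 ms burst interval of a 2 MB congestion window at 50 ms RTT covers
        // 2_000_000 * 2_000_000 ns / 50_000_000 ns = 80_000 bytes, which lies inside
        // the clamp range [10 * 1500, 256 * 1500] and is therefore used unchanged.
        assert_eq!(
            optimal_capacity(Duration::from_millis(50), 2_000_000, 1500),
            80_000
        );
    }
}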
quinn-proto-0.11.9/src/connection/packet_builder.rs000064400000000000000000000235201046102023000204620ustar 00000000000000use std::cmp;

use bytes::Bytes;
use rand::Rng;
use tracing::{trace, trace_span};

use super::{spaces::SentPacket, Connection, SentFrames};
use crate::{
    frame::{self, Close},
    packet::{Header, InitialHeader, LongType, PacketNumber, PartialEncode, SpaceId, FIXED_BIT},
    ConnectionId, Instant, TransportError, TransportErrorCode, INITIAL_MTU,
};

pub(super) struct PacketBuilder {
    pub(super) datagram_start: usize,
    pub(super) space: SpaceId,
    pub(super) partial_encode: PartialEncode,
    pub(super) ack_eliciting: bool,
    pub(super) exact_number: u64,
    pub(super) short_header: bool,
    /// Smallest absolute position in the associated buffer that must be occupied by this packet's
    /// frames
    pub(super) min_size: usize,
    /// Largest absolute position in the associated buffer that may be occupied by this packet's
    /// frames
    pub(super) max_size: usize,
    pub(super) tag_len: usize,
    pub(super) _span: tracing::span::EnteredSpan,
}

impl PacketBuilder {
    /// Write a new packet header to `buffer` and determine the packet's properties
    ///
    /// Marks the connection drained and returns `None` if the confidentiality limit would be
    /// violated.
    pub(super) fn new(
        now: Instant,
        space_id: SpaceId,
        dst_cid: ConnectionId,
        buffer: &mut Vec<u8>,
        mut buffer_capacity: usize,
        datagram_start: usize,
        ack_eliciting: bool,
        conn: &mut Connection,
    ) -> Option<Self> {
        let version = conn.version;
        // Initiate key update if we're approaching the confidentiality limit
        let sent_with_keys = conn.spaces[space_id].sent_with_keys;
        if space_id == SpaceId::Data {
            if sent_with_keys >= conn.key_phase_size {
                conn.initiate_key_update();
            }
        } else {
            let confidentiality_limit = conn.spaces[space_id]
                .crypto
                .as_ref()
                .map_or_else(
                    || &conn.zero_rtt_crypto.as_ref().unwrap().packet,
                    |keys| &keys.packet.local,
                )
                .confidentiality_limit();
            if sent_with_keys.saturating_add(1) == confidentiality_limit {
                // We still have time to attempt a graceful close
                conn.close_inner(
                    now,
                    Close::Connection(frame::ConnectionClose {
                        error_code: TransportErrorCode::AEAD_LIMIT_REACHED,
                        frame_type: None,
                        reason: Bytes::from_static(b"confidentiality limit reached"),
                    }),
                )
            } else if sent_with_keys > confidentiality_limit {
                // Confidentiality limit violated and there's nothing we can do
                conn.kill(
                    TransportError::AEAD_LIMIT_REACHED("confidentiality limit reached").into(),
                );
                return None;
            }
        }

        let space = &mut conn.spaces[space_id];
        if space.loss_probes != 0 {
            space.loss_probes -= 1;
            // Clamp the packet size to at most the minimum MTU to ensure that loss probes can get
            // through and enable recovery even if the path MTU has shrunk unexpectedly.
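            // (Editor's note: INITIAL_MTU is the 1200 byte minimum datagram size that
            // RFC 9000 §14 requires every QUIC path to support, so the clamped probe
            // always fits.)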
buffer_capacity = cmp::min(buffer_capacity, datagram_start + usize::from(INITIAL_MTU)); } let exact_number = match space_id { SpaceId::Data => conn.packet_number_filter.allocate(&mut conn.rng, space), _ => space.get_tx_number(), }; let span = trace_span!("send", space = ?space_id, pn = exact_number).entered(); let number = PacketNumber::new(exact_number, space.largest_acked_packet.unwrap_or(0)); let header = match space_id { SpaceId::Data if space.crypto.is_some() => Header::Short { dst_cid, number, spin: if conn.spin_enabled { conn.spin } else { conn.rng.gen() }, key_phase: conn.key_phase, }, SpaceId::Data => Header::Long { ty: LongType::ZeroRtt, src_cid: conn.handshake_cid, dst_cid, number, version, }, SpaceId::Handshake => Header::Long { ty: LongType::Handshake, src_cid: conn.handshake_cid, dst_cid, number, version, }, SpaceId::Initial => Header::Initial(InitialHeader { src_cid: conn.handshake_cid, dst_cid, token: conn.retry_token.clone(), number, version, }), }; let partial_encode = header.encode(buffer); if conn.peer_params.grease_quic_bit && conn.rng.gen() { buffer[partial_encode.start] ^= FIXED_BIT; } let (sample_size, tag_len) = if let Some(ref crypto) = space.crypto { ( crypto.header.local.sample_size(), crypto.packet.local.tag_len(), ) } else if space_id == SpaceId::Data { let zero_rtt = conn.zero_rtt_crypto.as_ref().unwrap(); (zero_rtt.header.sample_size(), zero_rtt.packet.tag_len()) } else { unreachable!(); }; // Each packet must be large enough for header protection sampling, i.e. the combined // lengths of the encoded packet number and protected payload must be at least 4 bytes // longer than the sample required for header protection. Further, each packet should be at // least tag_len + 6 bytes larger than the destination CID on incoming packets so that the // peer may send stateless resets that are indistinguishable from regular traffic. // pn_len + payload_len + tag_len >= sample_size + 4 // payload_len >= sample_size + 4 - pn_len - tag_len let min_size = Ord::max( buffer.len() + (sample_size + 4).saturating_sub(number.len() + tag_len), partial_encode.start + dst_cid.len() + 6, ); let max_size = buffer_capacity - tag_len; debug_assert!(max_size >= min_size); Some(Self { datagram_start, space: space_id, partial_encode, exact_number, short_header: header.is_short(), min_size, max_size, tag_len, ack_eliciting, _span: span, }) } /// Append the minimum amount of padding to the packet such that, after encryption, the /// enclosing datagram will occupy at least `min_size` bytes pub(super) fn pad_to(&mut self, min_size: u16) { // The datagram might already have a larger minimum size than the caller is requesting, if // e.g. we're coalescing packets and have populated more than `min_size` bytes with packets // already. 
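        // Illustrative arithmetic (editor's note): with `datagram_start == 0`,
        // `min_size == 1200`, and a 16 byte AEAD tag, frames must extend to
        // offset 1184 so the datagram reaches 1200 bytes once the tag is added.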
self.min_size = Ord::max( self.min_size, self.datagram_start + (min_size as usize) - self.tag_len, ); } pub(super) fn finish_and_track( self, now: Instant, conn: &mut Connection, sent: Option, buffer: &mut Vec, ) { let ack_eliciting = self.ack_eliciting; let exact_number = self.exact_number; let space_id = self.space; let (size, padded) = self.finish(conn, buffer); let sent = match sent { Some(sent) => sent, None => return, }; let size = match padded || ack_eliciting { true => size as u16, false => 0, }; let packet = SentPacket { largest_acked: sent.largest_acked, time_sent: now, size, ack_eliciting, retransmits: sent.retransmits, stream_frames: sent.stream_frames, }; conn.path .sent(exact_number, packet, &mut conn.spaces[space_id]); conn.stats.path.sent_packets += 1; conn.reset_keep_alive(now); if size != 0 { if ack_eliciting { conn.spaces[space_id].time_of_last_ack_eliciting_packet = Some(now); if conn.permit_idle_reset { conn.reset_idle_timeout(now, space_id); } conn.permit_idle_reset = false; } conn.set_loss_detection_timer(now); conn.path.pacing.on_transmit(size); } } /// Encrypt packet, returning the length of the packet and whether padding was added pub(super) fn finish(self, conn: &mut Connection, buffer: &mut Vec) -> (usize, bool) { let pad = buffer.len() < self.min_size; if pad { trace!("PADDING * {}", self.min_size - buffer.len()); buffer.resize(self.min_size, 0); } let space = &conn.spaces[self.space]; let (header_crypto, packet_crypto) = if let Some(ref crypto) = space.crypto { (&*crypto.header.local, &*crypto.packet.local) } else if self.space == SpaceId::Data { let zero_rtt = conn.zero_rtt_crypto.as_ref().unwrap(); (&*zero_rtt.header, &*zero_rtt.packet) } else { unreachable!("tried to send {:?} packet without keys", self.space); }; debug_assert_eq!( packet_crypto.tag_len(), self.tag_len, "Mismatching crypto tag len" ); buffer.resize(buffer.len() + packet_crypto.tag_len(), 0); let encode_start = self.partial_encode.start; let packet_buf = &mut buffer[encode_start..]; self.partial_encode.finish( packet_buf, header_crypto, Some((self.exact_number, packet_crypto)), ); (buffer.len() - encode_start, pad) } } quinn-proto-0.11.9/src/connection/packet_crypto.rs000064400000000000000000000141661046102023000203620ustar 00000000000000use tracing::{debug, trace}; use crate::connection::spaces::PacketSpace; use crate::crypto::{HeaderKey, KeyPair, PacketKey}; use crate::packet::{Packet, PartialDecode, SpaceId}; use crate::token::ResetToken; use crate::Instant; use crate::{TransportError, RESET_TOKEN_SIZE}; /// Removes header protection of a packet, or returns `None` if the packet was dropped pub(super) fn unprotect_header( partial_decode: PartialDecode, spaces: &[PacketSpace; 3], zero_rtt_crypto: Option<&ZeroRttCrypto>, stateless_reset_token: Option, ) -> Option { let header_crypto = if partial_decode.is_0rtt() { if let Some(crypto) = zero_rtt_crypto { Some(&*crypto.header) } else { debug!("dropping unexpected 0-RTT packet"); return None; } } else if let Some(space) = partial_decode.space() { if let Some(ref crypto) = spaces[space].crypto { Some(&*crypto.header.remote) } else { debug!( "discarding unexpected {:?} packet ({} bytes)", space, partial_decode.len(), ); return None; } } else { // Unprotected packet None }; let packet = partial_decode.data(); let stateless_reset = packet.len() >= RESET_TOKEN_SIZE + 5 && stateless_reset_token.as_deref() == Some(&packet[packet.len() - RESET_TOKEN_SIZE..]); match partial_decode.finish(header_crypto) { Ok(packet) => Some(UnprotectHeaderResult { 
packet: Some(packet), stateless_reset, }), Err(_) if stateless_reset => Some(UnprotectHeaderResult { packet: None, stateless_reset: true, }), Err(e) => { trace!("unable to complete packet decoding: {}", e); None } } } pub(super) struct UnprotectHeaderResult { /// The packet with the now unprotected header (`None` in the case of stateless reset packets /// that fail to be decoded) pub(super) packet: Option, /// Whether the packet was a stateless reset packet pub(super) stateless_reset: bool, } /// Decrypts a packet's body in-place pub(super) fn decrypt_packet_body( packet: &mut Packet, spaces: &[PacketSpace; 3], zero_rtt_crypto: Option<&ZeroRttCrypto>, conn_key_phase: bool, prev_crypto: Option<&PrevCrypto>, next_crypto: Option<&KeyPair>>, ) -> Result, Option> { if !packet.header.is_protected() { // Unprotected packets also don't have packet numbers return Ok(None); } let space = packet.header.space(); let rx_packet = spaces[space].rx_packet; let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1); let packet_key_phase = packet.header.key_phase(); let mut crypto_update = false; let crypto = if packet.header.is_0rtt() { &zero_rtt_crypto.unwrap().packet } else if packet_key_phase == conn_key_phase || space != SpaceId::Data { &spaces[space].crypto.as_ref().unwrap().packet.remote } else if let Some(prev) = prev_crypto.and_then(|crypto| { // If this packet comes prior to acknowledgment of the key update by the peer, if crypto.end_packet.map_or(true, |(pn, _)| number < pn) { // use the previous keys. Some(crypto) } else { // Otherwise, this must be a remotely-initiated key update, so fall through to the // final case. None } }) { &prev.crypto.remote } else { // We're in the Data space with a key phase mismatch and either there is no locally // initiated key update or the locally initiated key update was acknowledged by a // lower-numbered packet. The key phase mismatch must therefore represent a new // remotely-initiated key update. crypto_update = true; &next_crypto.unwrap().remote }; crypto .decrypt(number, &packet.header_data, &mut packet.payload) .map_err(|_| { trace!("decryption failed with packet number {}", number); None })?; if !packet.reserved_bits_valid() { return Err(Some(TransportError::PROTOCOL_VIOLATION( "reserved bits set", ))); } let mut outgoing_key_update_acked = false; if let Some(prev) = prev_crypto { if prev.end_packet.is_none() && packet_key_phase == conn_key_phase { outgoing_key_update_acked = true; } } if crypto_update { // Validate incoming key update if number <= rx_packet || prev_crypto.map_or(false, |x| x.update_unacked) { return Err(Some(TransportError::KEY_UPDATE_ERROR(""))); } } Ok(Some(DecryptPacketResult { number, outgoing_key_update_acked, incoming_key_update: crypto_update, })) } pub(super) struct DecryptPacketResult { /// The packet number pub(super) number: u64, /// Whether a locally initiated key update has been acknowledged by the peer pub(super) outgoing_key_update_acked: bool, /// Whether the peer has initiated a key update pub(super) incoming_key_update: bool, } pub(super) struct PrevCrypto { /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by /// the peer prior to its own key update. pub(super) crypto: KeyPair>, /// The incoming packet that ends the interval for which these keys are applicable, and the time /// of its receipt. /// /// Incoming packets should be decrypted using these keys iff this is `None` or their packet /// number is lower. 
`None` indicates that we have not yet received a packet using newer keys,
    /// which implies that the update was locally initiated.
    pub(super) end_packet: Option<(u64, Instant)>,
    /// Whether the following key phase is from a remotely initiated update that we haven't acked
    pub(super) update_unacked: bool,
}

pub(super) struct ZeroRttCrypto {
    pub(super) header: Box<dyn HeaderKey>,
    pub(super) packet: Box<dyn PacketKey>,
}
quinn-proto-0.11.9/src/connection/paths.rs000064400000000000000000000267171046102023000166330ustar 00000000000000use std::{cmp, net::SocketAddr};

use tracing::trace;

use super::{
    mtud::MtuDiscovery,
    pacing::Pacer,
    spaces::{PacketSpace, SentPacket},
};
use crate::{congestion, packet::SpaceId, Duration, Instant, TransportConfig, TIMER_GRANULARITY};

/// Description of a particular network path
pub(super) struct PathData {
    pub(super) remote: SocketAddr,
    pub(super) rtt: RttEstimator,
    /// Whether we're enabling ECN on outgoing packets
    pub(super) sending_ecn: bool,
    /// Congestion controller state
    pub(super) congestion: Box<dyn congestion::Controller>,
    /// Pacing state
    pub(super) pacing: Pacer,
    pub(super) challenge: Option<u64>,
    pub(super) challenge_pending: bool,
    /// Whether we're certain the peer can both send and receive on this address
    ///
    /// Initially equal to `use_stateless_retry` for servers, and becomes false again on every
    /// migration. Always true for clients.
    pub(super) validated: bool,
    /// Total size of all UDP datagrams sent on this path
    pub(super) total_sent: u64,
    /// Total size of all UDP datagrams received on this path
    pub(super) total_recvd: u64,
    /// The state of the MTU discovery process
    pub(super) mtud: MtuDiscovery,
    /// Packet number of the first packet sent after an RTT sample was collected on this path
    ///
    /// Used in persistent congestion determination.
    pub(super) first_packet_after_rtt_sample: Option<(SpaceId, u64)>,
    pub(super) in_flight: InFlight,
    /// Number of the first packet sent on this path
    ///
    /// Used to determine whether a packet was sent on an earlier path. Insufficient to determine if
    /// a packet was sent on a later path.
first_packet: Option, } impl PathData { pub(super) fn new( remote: SocketAddr, allow_mtud: bool, peer_max_udp_payload_size: Option, now: Instant, validated: bool, config: &TransportConfig, ) -> Self { let congestion = config .congestion_controller_factory .clone() .build(now, config.get_initial_mtu()); Self { remote, rtt: RttEstimator::new(config.initial_rtt), sending_ecn: true, pacing: Pacer::new( config.initial_rtt, congestion.initial_window(), config.get_initial_mtu(), now, ), congestion, challenge: None, challenge_pending: false, validated, total_sent: 0, total_recvd: 0, mtud: config .mtu_discovery_config .as_ref() .filter(|_| allow_mtud) .map_or( MtuDiscovery::disabled(config.get_initial_mtu(), config.min_mtu), |mtud_config| { MtuDiscovery::new( config.get_initial_mtu(), config.min_mtu, peer_max_udp_payload_size, mtud_config.clone(), ) }, ), first_packet_after_rtt_sample: None, in_flight: InFlight::new(), first_packet: None, } } pub(super) fn from_previous(remote: SocketAddr, prev: &Self, now: Instant) -> Self { let congestion = prev.congestion.clone_box(); let smoothed_rtt = prev.rtt.get(); Self { remote, rtt: prev.rtt, pacing: Pacer::new(smoothed_rtt, congestion.window(), prev.current_mtu(), now), sending_ecn: true, congestion, challenge: None, challenge_pending: false, validated: false, total_sent: 0, total_recvd: 0, mtud: prev.mtud.clone(), first_packet_after_rtt_sample: prev.first_packet_after_rtt_sample, in_flight: InFlight::new(), first_packet: None, } } /// Resets RTT, congestion control and MTU states. /// /// This is useful when it is known the underlying path has changed. pub(super) fn reset(&mut self, now: Instant, config: &TransportConfig) { self.rtt = RttEstimator::new(config.initial_rtt); self.congestion = config .congestion_controller_factory .clone() .build(now, config.get_initial_mtu()); self.mtud.reset(config.get_initial_mtu(), config.min_mtu); } /// Indicates whether we're a server that hasn't validated the peer's address and hasn't /// received enough data from the peer to permit sending `bytes_to_send` additional bytes pub(super) fn anti_amplification_blocked(&self, bytes_to_send: u64) -> bool { !self.validated && self.total_recvd * 3 < self.total_sent + bytes_to_send } /// Returns the path's current MTU pub(super) fn current_mtu(&self) -> u16 { self.mtud.current_mtu() } /// Account for transmission of `packet` with number `pn` in `space` pub(super) fn sent(&mut self, pn: u64, packet: SentPacket, space: &mut PacketSpace) { self.in_flight.insert(&packet); if self.first_packet.is_none() { self.first_packet = Some(pn); } self.in_flight.bytes -= space.sent(pn, packet); } /// Remove `packet` with number `pn` from this path's congestion control counters, or return /// `false` if `pn` was sent before this path was established. pub(super) fn remove_in_flight(&mut self, pn: u64, packet: &SentPacket) -> bool { if self.first_packet.map_or(true, |first| first > pn) { return false; } self.in_flight.remove(packet); true } } /// RTT estimation for a particular network path #[derive(Copy, Clone)] pub struct RttEstimator { /// The most recent RTT measurement made when receiving an ack for a previously unacked packet latest: Duration, /// The smoothed RTT of the connection, computed as described in RFC6298 smoothed: Option, /// The RTT variance, computed as described in RFC6298 var: Duration, /// The minimum RTT seen in the connection, ignoring ack delay. 
min: Duration, } impl RttEstimator { fn new(initial_rtt: Duration) -> Self { Self { latest: initial_rtt, smoothed: None, var: initial_rtt / 2, min: initial_rtt, } } /// The current best RTT estimation. pub fn get(&self) -> Duration { self.smoothed.unwrap_or(self.latest) } /// Conservative estimate of RTT /// /// Takes the maximum of smoothed and latest RTT, as recommended /// in 6.1.2 of the recovery spec (draft 29). pub fn conservative(&self) -> Duration { self.get().max(self.latest) } /// Minimum RTT registered so far for this estimator. pub fn min(&self) -> Duration { self.min } // PTO computed as described in RFC9002#6.2.1 pub(crate) fn pto_base(&self) -> Duration { self.get() + cmp::max(4 * self.var, TIMER_GRANULARITY) } pub(crate) fn update(&mut self, ack_delay: Duration, rtt: Duration) { self.latest = rtt; // min_rtt ignores ack delay. self.min = cmp::min(self.min, self.latest); // Based on RFC6298. if let Some(smoothed) = self.smoothed { let adjusted_rtt = if self.min + ack_delay <= self.latest { self.latest - ack_delay } else { self.latest }; let var_sample = if smoothed > adjusted_rtt { smoothed - adjusted_rtt } else { adjusted_rtt - smoothed }; self.var = (3 * self.var + var_sample) / 4; self.smoothed = Some((7 * smoothed + adjusted_rtt) / 8); } else { self.smoothed = Some(self.latest); self.var = self.latest / 2; self.min = self.latest; } } } #[derive(Default)] pub(crate) struct PathResponses { pending: Vec, } impl PathResponses { pub(crate) fn push(&mut self, packet: u64, token: u64, remote: SocketAddr) { /// Arbitrary permissive limit to prevent abuse const MAX_PATH_RESPONSES: usize = 16; let response = PathResponse { packet, token, remote, }; let existing = self.pending.iter_mut().find(|x| x.remote == remote); if let Some(existing) = existing { // Update a queued response if existing.packet <= packet { *existing = response; } return; } if self.pending.len() < MAX_PATH_RESPONSES { self.pending.push(response); } else { // We don't expect to ever hit this with well-behaved peers, so we don't bother dropping // older challenges. trace!("ignoring excessive PATH_CHALLENGE"); } } pub(crate) fn pop_off_path(&mut self, remote: &SocketAddr) -> Option<(u64, SocketAddr)> { let response = *self.pending.last()?; if response.remote == *remote { // We don't bother searching further because we expect that the on-path response will // get drained in the immediate future by a call to `pop_on_path` return None; } self.pending.pop(); Some((response.token, response.remote)) } pub(crate) fn pop_on_path(&mut self, remote: &SocketAddr) -> Option { let response = *self.pending.last()?; if response.remote != *remote { // We don't bother searching further because we expect that the off-path response will // get drained in the immediate future by a call to `pop_off_path` return None; } self.pending.pop(); Some(response.token) } pub(crate) fn is_empty(&self) -> bool { self.pending.is_empty() } } #[derive(Copy, Clone)] struct PathResponse { /// The packet number the corresponding PATH_CHALLENGE was received in packet: u64, token: u64, /// The address the corresponding PATH_CHALLENGE was received from remote: SocketAddr, } /// Summary statistics of packets that have been sent on a particular path, but which have not yet /// been acked or deemed lost pub(super) struct InFlight { /// Sum of the sizes of all sent packets considered "in flight" by congestion control /// /// The size does not include IP or UDP overhead. 
Packets only containing ACK frames do not /// count towards this to ensure congestion control does not impede congestion feedback. pub(super) bytes: u64, /// Number of packets in flight containing frames other than ACK and PADDING /// /// This can be 0 even when bytes is not 0 because PADDING frames cause a packet to be /// considered "in flight" by congestion control. However, if this is nonzero, bytes will always /// also be nonzero. pub(super) ack_eliciting: u64, } impl InFlight { fn new() -> Self { Self { bytes: 0, ack_eliciting: 0, } } fn insert(&mut self, packet: &SentPacket) { self.bytes += u64::from(packet.size); self.ack_eliciting += u64::from(packet.ack_eliciting); } /// Update counters to account for a packet becoming acknowledged, lost, or abandoned fn remove(&mut self, packet: &SentPacket) { self.bytes -= u64::from(packet.size); self.ack_eliciting -= u64::from(packet.ack_eliciting); } } quinn-proto-0.11.9/src/connection/send_buffer.rs000064400000000000000000000327331046102023000177750ustar 00000000000000use std::{collections::VecDeque, ops::Range}; use bytes::{Buf, Bytes}; use crate::{range_set::RangeSet, VarInt}; /// Buffer of outgoing retransmittable stream data #[derive(Default, Debug)] pub(super) struct SendBuffer { /// Data queued by the application but not yet acknowledged. May or may not have been sent. unacked_segments: VecDeque, /// Total size of `unacked_segments` unacked_len: usize, /// The first offset that hasn't been written by the application, i.e. the offset past the end of `unacked` offset: u64, /// The first offset that hasn't been sent /// /// Always lies in (offset - unacked.len())..offset unsent: u64, /// Acknowledged ranges which couldn't be discarded yet as they don't include the earliest /// offset in `unacked` // TODO: Recover storage from these by compacting (#700) acks: RangeSet, /// Previously transmitted ranges deemed lost retransmits: RangeSet, } impl SendBuffer { /// Construct an empty buffer at the initial offset pub(super) fn new() -> Self { Self::default() } /// Append application data to the end of the stream pub(super) fn write(&mut self, data: Bytes) { self.unacked_len += data.len(); self.offset += data.len() as u64; self.unacked_segments.push_back(data); } /// Discard a range of acknowledged stream data pub(super) fn ack(&mut self, mut range: Range) { // Clamp the range to data which is still tracked let base_offset = self.offset - self.unacked_len as u64; range.start = base_offset.max(range.start); range.end = base_offset.max(range.end); self.acks.insert(range); while self.acks.min() == Some(self.offset - self.unacked_len as u64) { let prefix = self.acks.pop_min().unwrap(); let mut to_advance = (prefix.end - prefix.start) as usize; self.unacked_len -= to_advance; while to_advance > 0 { let front = self .unacked_segments .front_mut() .expect("Expected buffered data"); if front.len() <= to_advance { to_advance -= front.len(); self.unacked_segments.pop_front(); if self.unacked_segments.len() * 4 < self.unacked_segments.capacity() { self.unacked_segments.shrink_to_fit(); } } else { front.advance(to_advance); to_advance = 0; } } } } /// Compute the next range to transmit on this stream and update state to account for that /// transmission. /// /// `max_len` here includes the space which is available to transmit the /// offset and length of the data to send. The caller has to guarantee that /// there is at least enough space available to write maximum-sized metadata /// (8 byte offset + 8 byte length). 
/// /// The method returns a tuple: /// - The first return value indicates the range of data to send /// - The second return value indicates whether the length needs to be encoded /// in the STREAM frames metadata (`true`), or whether it can be omitted /// since the selected range will fill the whole packet. pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range, bool) { debug_assert!(max_len >= 8 + 8); let mut encode_length = false; if let Some(range) = self.retransmits.pop_min() { // Retransmit sent data // When the offset is known, we know how many bytes are required to encode it. // Offset 0 requires no space if range.start != 0 { max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) }); } if range.end - range.start < max_len as u64 { encode_length = true; max_len -= 8; } let end = range.end.min((max_len as u64).saturating_add(range.start)); if end != range.end { self.retransmits.insert(end..range.end); } return (range.start..end, encode_length); } // Transmit new data // When the offset is known, we know how many bytes are required to encode it. // Offset 0 requires no space if self.unsent != 0 { max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) }); } if self.offset - self.unsent < max_len as u64 { encode_length = true; max_len -= 8; } let end = self .offset .min((max_len as u64).saturating_add(self.unsent)); let result = self.unsent..end; self.unsent = end; (result, encode_length) } /// Returns data which is associated with a range /// /// This function can return a subset of the range, if the data is stored /// in noncontiguous fashion in the send buffer. In this case callers /// should call the function again with an incremented start offset to /// retrieve more data. pub(super) fn get(&self, offsets: Range) -> &[u8] { let base_offset = self.offset - self.unacked_len as u64; let mut segment_offset = base_offset; for segment in self.unacked_segments.iter() { if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() as u64 { let start = (offsets.start - segment_offset) as usize; let end = (offsets.end - segment_offset) as usize; return &segment[start..end.min(segment.len())]; } segment_offset += segment.len() as u64; } &[] } /// Queue a range of sent but unacknowledged data to be retransmitted pub(super) fn retransmit(&mut self, range: Range) { debug_assert!(range.end <= self.unsent, "unsent data can't be lost"); self.retransmits.insert(range); } pub(super) fn retransmit_all_for_0rtt(&mut self) { debug_assert_eq!(self.offset, self.unacked_len as u64); self.unsent = 0; } /// First stream offset unwritten by the application, i.e. the offset that the next write will /// begin at pub(super) fn offset(&self) -> u64 { self.offset } /// Whether all sent data has been acknowledged pub(super) fn is_fully_acked(&self) -> bool { self.unacked_len == 0 } /// Whether there's data to send /// /// There may be sent unacknowledged data even when this is false. 
pub(super) fn has_unsent_data(&self) -> bool { self.unsent != self.offset || !self.retransmits.is_empty() } /// Compute the amount of data that hasn't been acknowledged pub(super) fn unacked(&self) -> u64 { self.unacked_len as u64 - self.acks.iter().map(|x| x.end - x.start).sum::() } } #[cfg(test)] mod tests { use super::*; #[test] fn fragment_with_length() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world!"; buf.write(MSG.into()); // 0 byte offset => 19 bytes left => 13 byte data isn't enough // with 8 bytes reserved for length 11 payload bytes will fit assert_eq!(buf.poll_transmit(19), (0..11, true)); assert_eq!( buf.poll_transmit(MSG.len() + 16 - 11), (11..MSG.len() as u64, true) ); assert_eq!( buf.poll_transmit(58), (MSG.len() as u64..MSG.len() as u64, true) ); } #[test] fn fragment_without_length() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world with some extra data!"; buf.write(MSG.into()); // 0 byte offset => 19 bytes left => can be filled by 34 bytes payload assert_eq!(buf.poll_transmit(19), (0..19, false)); assert_eq!( buf.poll_transmit(MSG.len() - 19 + 1), (19..MSG.len() as u64, false) ); assert_eq!( buf.poll_transmit(58), (MSG.len() as u64..MSG.len() as u64, true) ); } #[test] fn reserves_encoded_offset() { let mut buf = SendBuffer::new(); // Pretend we have more than 1 GB of data in the buffer let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]); for _ in 0..1025 { buf.write(chunk.clone()); } const SIZE1: u64 = 64; const SIZE2: u64 = 16 * 1024; const SIZE3: u64 = 1024 * 1024 * 1024; // Offset 0 requires no space assert_eq!(buf.poll_transmit(16), (0..16, false)); buf.retransmit(0..16); assert_eq!(buf.poll_transmit(16), (0..16, false)); let mut transmitted = 16u64; // Offset 16 requires 1 byte assert_eq!( buf.poll_transmit((SIZE1 - transmitted + 1) as usize), (transmitted..SIZE1, false) ); buf.retransmit(transmitted..SIZE1); assert_eq!( buf.poll_transmit((SIZE1 - transmitted + 1) as usize), (transmitted..SIZE1, false) ); transmitted = SIZE1; // Offset 64 requires 2 bytes assert_eq!( buf.poll_transmit((SIZE2 - transmitted + 2) as usize), (transmitted..SIZE2, false) ); buf.retransmit(transmitted..SIZE2); assert_eq!( buf.poll_transmit((SIZE2 - transmitted + 2) as usize), (transmitted..SIZE2, false) ); transmitted = SIZE2; // Offset 16384 requires requires 4 bytes assert_eq!( buf.poll_transmit((SIZE3 - transmitted + 4) as usize), (transmitted..SIZE3, false) ); buf.retransmit(transmitted..SIZE3); assert_eq!( buf.poll_transmit((SIZE3 - transmitted + 4) as usize), (transmitted..SIZE3, false) ); transmitted = SIZE3; // Offset 1GB requires 8 bytes assert_eq!( buf.poll_transmit(chunk.len() + 8), (transmitted..transmitted + chunk.len() as u64, false) ); buf.retransmit(transmitted..transmitted + chunk.len() as u64); assert_eq!( buf.poll_transmit(chunk.len() + 8), (transmitted..transmitted + chunk.len() as u64, false) ); } #[test] fn multiple_segments() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world!"; const MSG_LEN: u64 = MSG.len() as u64; const SEG1: &[u8] = b"He"; buf.write(SEG1.into()); const SEG2: &[u8] = b"llo,"; buf.write(SEG2.into()); const SEG3: &[u8] = b" w"; buf.write(SEG3.into()); const SEG4: &[u8] = b"o"; buf.write(SEG4.into()); const SEG5: &[u8] = b"rld!"; buf.write(SEG5.into()); assert_eq!(aggregate_unacked(&buf), MSG); assert_eq!(buf.poll_transmit(16), (0..8, true)); assert_eq!(buf.get(0..5), SEG1); assert_eq!(buf.get(2..8), SEG2); assert_eq!(buf.get(6..8), SEG3); assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, 
true)); assert_eq!(buf.get(8..MSG_LEN), SEG4); assert_eq!(buf.get(9..MSG_LEN), SEG5); assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true)); // Now drain the segments buf.ack(0..1); assert_eq!(aggregate_unacked(&buf), &MSG[1..]); buf.ack(0..3); assert_eq!(aggregate_unacked(&buf), &MSG[3..]); buf.ack(3..5); assert_eq!(aggregate_unacked(&buf), &MSG[5..]); buf.ack(7..9); assert_eq!(aggregate_unacked(&buf), &MSG[5..]); buf.ack(4..7); assert_eq!(aggregate_unacked(&buf), &MSG[9..]); buf.ack(0..MSG_LEN); assert_eq!(aggregate_unacked(&buf), &[]); } #[test] fn retransmit() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world with extra data!"; buf.write(MSG.into()); // Transmit two frames assert_eq!(buf.poll_transmit(16), (0..16, false)); assert_eq!(buf.poll_transmit(16), (16..23, true)); // Lose the first, but not the second buf.retransmit(0..16); // Ensure we only retransmit the lost frame, then continue sending fresh data assert_eq!(buf.poll_transmit(16), (0..16, false)); assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true)); // Lose the second frame buf.retransmit(16..23); assert_eq!(buf.poll_transmit(16), (16..23, true)); } #[test] fn ack() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world!"; buf.write(MSG.into()); assert_eq!(buf.poll_transmit(16), (0..8, true)); buf.ack(0..8); assert_eq!(aggregate_unacked(&buf), &MSG[8..]); } #[test] fn reordered_ack() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world with extra data!"; buf.write(MSG.into()); assert_eq!(buf.poll_transmit(16), (0..16, false)); assert_eq!(buf.poll_transmit(16), (16..23, true)); buf.ack(16..23); assert_eq!(aggregate_unacked(&buf), MSG); buf.ack(0..16); assert_eq!(aggregate_unacked(&buf), &MSG[23..]); assert!(buf.acks.is_empty()); } fn aggregate_unacked(buf: &SendBuffer) -> Vec { let mut result = Vec::new(); for segment in buf.unacked_segments.iter() { result.extend_from_slice(&segment[..]); } result } } quinn-proto-0.11.9/src/connection/spaces.rs000064400000000000000000001165151046102023000167720ustar 00000000000000use std::{ cmp, collections::{BTreeMap, VecDeque}, mem, ops::{Bound, Index, IndexMut}, }; use rand::Rng; use rustc_hash::FxHashSet; use tracing::trace; use super::assembler::Assembler; use crate::{ connection::StreamsState, crypto::Keys, frame, packet::SpaceId, range_set::ArrayRangeSet, shared::IssuedCid, Dir, Duration, Instant, StreamId, TransportError, VarInt, }; pub(super) struct PacketSpace { pub(super) crypto: Option, pub(super) dedup: Dedup, /// Highest received packet number pub(super) rx_packet: u64, /// Data to send pub(super) pending: Retransmits, /// Packet numbers to acknowledge pub(super) pending_acks: PendingAcks, /// The packet number of the next packet that will be sent, if any. In the Data space, the /// packet number stored here is sometimes skipped by [`PacketNumberFilter`] logic. pub(super) next_packet_number: u64, /// The largest packet number the remote peer acknowledged in an ACK frame. 
    pub(super) largest_acked_packet: Option<u64>,
    pub(super) largest_acked_packet_sent: Instant,
    /// The highest-numbered ACK-eliciting packet we've sent
    pub(super) largest_ack_eliciting_sent: u64,
    /// Number of packets in `sent_packets` with numbers above `largest_ack_eliciting_sent`
    pub(super) unacked_non_ack_eliciting_tail: u64,
    /// Transmitted but not acked
    // We use a BTreeMap here so we can efficiently query by range on ACK and for loss detection
    pub(super) sent_packets: BTreeMap<u64, SentPacket>,
    /// Number of explicit congestion notification codepoints seen on incoming packets
    pub(super) ecn_counters: frame::EcnCounts,
    /// Recent ECN counters sent by the peer in ACK frames
    ///
    /// Updated (and inspected) whenever we receive an ACK with a new highest acked packet
    /// number. Stored per-space to simplify verification, which would otherwise have difficulty
    /// distinguishing between ECN bleaching and counts having been updated by a near-simultaneous
    /// ACK already processed in another space.
    pub(super) ecn_feedback: frame::EcnCounts,
    /// Incoming cryptographic handshake stream
    pub(super) crypto_stream: Assembler,
    /// Current offset of outgoing cryptographic handshake stream
    pub(super) crypto_offset: u64,
    /// The time the most recently sent retransmittable packet was sent.
    pub(super) time_of_last_ack_eliciting_packet: Option<Instant>,
    /// The time at which the earliest sent packet in this space will be considered lost based on
    /// exceeding the reordering window in time. Only set for packets numbered prior to a packet
    /// that has been acknowledged.
    pub(super) loss_time: Option<Instant>,
    /// Number of tail loss probes to send
    pub(super) loss_probes: u32,
    pub(super) ping_pending: bool,
    pub(super) immediate_ack_pending: bool,
    /// Number of congestion control "in flight" bytes
    pub(super) in_flight: u64,
    /// Number of packets sent in the current key phase
    pub(super) sent_with_keys: u64,
}

impl PacketSpace {
    pub(super) fn new(now: Instant) -> Self {
        Self {
            crypto: None,
            dedup: Dedup::new(),
            rx_packet: 0,
            pending: Retransmits::default(),
            pending_acks: PendingAcks::new(),
            next_packet_number: 0,
            largest_acked_packet: None,
            largest_acked_packet_sent: now,
            largest_ack_eliciting_sent: 0,
            unacked_non_ack_eliciting_tail: 0,
            sent_packets: BTreeMap::new(),
            ecn_counters: frame::EcnCounts::ZERO,
            ecn_feedback: frame::EcnCounts::ZERO,
            crypto_stream: Assembler::new(),
            crypto_offset: 0,
            time_of_last_ack_eliciting_packet: None,
            loss_time: None,
            loss_probes: 0,
            ping_pending: false,
            immediate_ack_pending: false,
            in_flight: 0,
            sent_with_keys: 0,
        }
    }

    /// Queue data for a tail loss probe (or anti-amplification deadlock prevention) packet
    ///
    /// Probes are sent similarly to normal packets when an expected ACK has not arrived. We never
    /// deem a packet lost until we receive an ACK that should have included it, but if a trailing
    /// run of packets (or their ACKs) are lost, this might not happen in a timely fashion. We send
    /// probe packets to force an ACK, and exempt them from congestion control to prevent a deadlock
    /// when the congestion window is filled with lost tail packets.
    ///
    /// We prefer to send new data, to make the most efficient use of bandwidth. If there's no data
    /// waiting to be sent, then we retransmit in-flight data to reduce odds of loss. If there's no
    /// in-flight data either, we're probably a client guarding against a handshake
    /// anti-amplification deadlock and we just make something up.
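    ///
    /// In short, the probe's payload is chosen in this order: pending fresh data first, then
    /// retransmittable data taken from the oldest in-flight packet, then a bare PING.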
pub(super) fn maybe_queue_probe( &mut self, request_immediate_ack: bool, streams: &StreamsState, ) { if self.loss_probes == 0 { return; } if request_immediate_ack { // The probe should be ACKed without delay (should only be used in the Data space and // when the peer supports the acknowledgement frequency extension) self.immediate_ack_pending = true; } // Retransmit the data of the oldest in-flight packet if !self.pending.is_empty(streams) { // There's real data to send here, no need to make something up return; } for packet in self.sent_packets.values_mut() { if !packet.retransmits.is_empty(streams) { // Remove retransmitted data from the old packet so we don't end up retransmitting // it *again* even if the copy we're sending now gets acknowledged. self.pending |= mem::take(&mut packet.retransmits); return; } } // Nothing new to send and nothing to retransmit, so fall back on a ping. This should only // happen in rare cases during the handshake when the server becomes blocked by // anti-amplification. self.ping_pending = true; } /// Get the next outgoing packet number in this space /// /// In the Data space, the connection's [`PacketNumberFilter`] must be used rather than calling /// this directly. pub(super) fn get_tx_number(&mut self) -> u64 { // TODO: Handle packet number overflow gracefully assert!(self.next_packet_number < 2u64.pow(62)); let x = self.next_packet_number; self.next_packet_number += 1; self.sent_with_keys += 1; x } pub(super) fn can_send(&self, streams: &StreamsState) -> SendableFrames { let acks = self.pending_acks.can_send(); let other = !self.pending.is_empty(streams) || self.ping_pending || self.immediate_ack_pending; SendableFrames { acks, other } } /// Verifies sanity of an ECN block and returns whether congestion was encountered. pub(super) fn detect_ecn( &mut self, newly_acked: u64, ecn: frame::EcnCounts, ) -> Result { let ect0_increase = ecn .ect0 .checked_sub(self.ecn_feedback.ect0) .ok_or("peer ECT(0) count regression")?; let ect1_increase = ecn .ect1 .checked_sub(self.ecn_feedback.ect1) .ok_or("peer ECT(1) count regression")?; let ce_increase = ecn .ce .checked_sub(self.ecn_feedback.ce) .ok_or("peer CE count regression")?; let total_increase = ect0_increase + ect1_increase + ce_increase; if total_increase < newly_acked { return Err("ECN bleaching"); } if (ect0_increase + ce_increase) < newly_acked || ect1_increase != 0 { return Err("ECN corruption"); } // If total_increase > newly_acked (which happens when ACKs are lost), this is required by // the draft so that long-term drift does not occur. If =, then the only question is whether // to count CE packets as CE or ECT0. Recording them as CE is more consistent and keeps the // congestion check obvious. self.ecn_feedback = ecn; Ok(ce_increase != 0) } /// Stop tracking sent packet `number`, and return what we knew about it pub(super) fn take(&mut self, number: u64) -> Option { let packet = self.sent_packets.remove(&number)?; self.in_flight -= u64::from(packet.size); if !packet.ack_eliciting && number > self.largest_ack_eliciting_sent { self.unacked_non_ack_eliciting_tail = self.unacked_non_ack_eliciting_tail.checked_sub(1).unwrap(); } Some(packet) } /// Returns the number of bytes to *remove* from the connection's in-flight count pub(super) fn sent(&mut self, number: u64, packet: SentPacket) -> u64 { // Retain state for at most this many non-ACK-eliciting packets sent after the most recently // sent ACK-eliciting packet. 
We're never guaranteed to receive an ACK for those, and we
        // can't judge them as lost without an ACK, so to limit memory in applications which
        // receive packets but don't send ACK-eliciting data for long periods we must eventually
        // start forgetting about them, although it might also be reasonable to just kill the
        // connection due to weird peer behavior.
        const MAX_UNACKED_NON_ACK_ELICITING_TAIL: u64 = 1_000;

        let mut forgotten_bytes = 0;
        if packet.ack_eliciting {
            self.unacked_non_ack_eliciting_tail = 0;
            self.largest_ack_eliciting_sent = number;
        } else if self.unacked_non_ack_eliciting_tail > MAX_UNACKED_NON_ACK_ELICITING_TAIL {
            let oldest_after_ack_eliciting = *self
                .sent_packets
                .range((
                    Bound::Excluded(self.largest_ack_eliciting_sent),
                    Bound::Unbounded,
                ))
                .next()
                .unwrap()
                .0;
            // Per https://www.rfc-editor.org/rfc/rfc9000.html#name-frames-and-frame-types,
            // non-ACK-eliciting packets must only contain PADDING, ACK, and CONNECTION_CLOSE
            // frames, which require no special handling on ACK or loss beyond removal from
            // in-flight counters if padded.
            let packet = self
                .sent_packets
                .remove(&oldest_after_ack_eliciting)
                .unwrap();
            forgotten_bytes = u64::from(packet.size);
            self.in_flight -= forgotten_bytes;
        } else {
            self.unacked_non_ack_eliciting_tail += 1;
        }

        self.in_flight += u64::from(packet.size);
        self.sent_packets.insert(number, packet);
        forgotten_bytes
    }
}

impl Index<SpaceId> for [PacketSpace; 3] {
    type Output = PacketSpace;
    fn index(&self, space: SpaceId) -> &PacketSpace {
        &self.as_ref()[space as usize]
    }
}

impl IndexMut<SpaceId> for [PacketSpace; 3] {
    fn index_mut(&mut self, space: SpaceId) -> &mut PacketSpace {
        &mut self.as_mut()[space as usize]
    }
}

/// Represents one or more packets subject to retransmission
#[derive(Debug, Clone)]
pub(super) struct SentPacket {
    /// The time the packet was sent.
    pub(super) time_sent: Instant,
    /// The number of bytes sent in the packet, not including UDP or IP overhead, but including QUIC
    /// framing overhead. Zero if this packet is not counted towards congestion control, i.e. not an
    /// "in flight" packet.
    pub(super) size: u16,
    /// Whether an acknowledgement is expected directly in response to this packet.
    pub(super) ack_eliciting: bool,
    /// The largest packet number acknowledged by this packet
    pub(super) largest_acked: Option<u64>,
    /// Data which needs to be retransmitted in case the packet is lost.
    /// The data is boxed to minimize `SentPacket` size for the typical case of
    /// packets only containing ACKs and STREAM frames.
    pub(super) retransmits: ThinRetransmits,
    /// Metadata for stream frames in a packet
    ///
    /// The actual application data is stored with the stream state.
pub(super) stream_frames: frame::StreamMetaVec, } /// Retransmittable data queue #[allow(unreachable_pub)] // fuzzing only #[derive(Debug, Default, Clone)] pub struct Retransmits { pub(super) max_data: bool, pub(super) max_stream_id: [bool; 2], pub(super) reset_stream: Vec<(StreamId, VarInt)>, pub(super) stop_sending: Vec, pub(super) max_stream_data: FxHashSet, pub(super) crypto: VecDeque, pub(super) new_cids: Vec, pub(super) retire_cids: Vec, pub(super) ack_frequency: bool, pub(super) handshake_done: bool, } impl Retransmits { pub(super) fn is_empty(&self, streams: &StreamsState) -> bool { !self.max_data && !self.max_stream_id.into_iter().any(|x| x) && self.reset_stream.is_empty() && self.stop_sending.is_empty() && self .max_stream_data .iter() .all(|&id| !streams.can_send_flow_control(id)) && self.crypto.is_empty() && self.new_cids.is_empty() && self.retire_cids.is_empty() && !self.ack_frequency && !self.handshake_done } } impl ::std::ops::BitOrAssign for Retransmits { fn bitor_assign(&mut self, rhs: Self) { // We reduce in-stream head-of-line blocking by queueing retransmits before other data for // STREAM and CRYPTO frames. self.max_data |= rhs.max_data; for dir in Dir::iter() { self.max_stream_id[dir as usize] |= rhs.max_stream_id[dir as usize]; } self.reset_stream.extend_from_slice(&rhs.reset_stream); self.stop_sending.extend_from_slice(&rhs.stop_sending); self.max_stream_data.extend(&rhs.max_stream_data); for crypto in rhs.crypto.into_iter().rev() { self.crypto.push_front(crypto); } self.new_cids.extend(&rhs.new_cids); self.retire_cids.extend(rhs.retire_cids); self.ack_frequency |= rhs.ack_frequency; self.handshake_done |= rhs.handshake_done; } } impl ::std::ops::BitOrAssign for Retransmits { fn bitor_assign(&mut self, rhs: ThinRetransmits) { if let Some(retransmits) = rhs.retransmits { self.bitor_assign(*retransmits) } } } impl ::std::iter::FromIterator for Retransmits { fn from_iter(iter: T) -> Self where T: IntoIterator, { let mut result = Self::default(); for packet in iter { result |= packet; } result } } /// A variant of `Retransmits` which only allocates storage when required #[derive(Debug, Default, Clone)] pub(super) struct ThinRetransmits { retransmits: Option>, } impl ThinRetransmits { /// Returns `true` if no retransmits are necessary pub(super) fn is_empty(&self, streams: &StreamsState) -> bool { match &self.retransmits { Some(retransmits) => retransmits.is_empty(streams), None => true, } } /// Returns a reference to the retransmits stored in this box pub(super) fn get(&self) -> Option<&Retransmits> { self.retransmits.as_deref() } /// Returns a mutable reference to the stored retransmits /// /// This function will allocate a backing storage if required. pub(super) fn get_or_create(&mut self) -> &mut Retransmits { if self.retransmits.is_none() { self.retransmits = Some(Box::default()); } self.retransmits.as_deref_mut().unwrap() } } /// RFC4303-style sliding window packet number deduplicator. /// /// A contiguous bitfield, where each bit corresponds to a packet number and the rightmost bit is /// always set. A set bit represents a packet that has been successfully authenticated. Bits left of /// the window are assumed to be set. /// /// ```text /// ...xxxxxxxxx 1 0 /// ^ ^ ^ /// window highest next /// ``` pub(super) struct Dedup { window: Window, /// Lowest packet number higher than all yet authenticated. next: u64, } /// Inner bitfield type. 
/// /// Because QUIC never reuses packet numbers, this only needs to be large enough to deal with /// packets that are reordered but still delivered in a timely manner. type Window = u128; /// Number of packets tracked by `Dedup`. const WINDOW_SIZE: u64 = 1 + mem::size_of::() as u64 * 8; impl Dedup { /// Construct an empty window positioned at the start. pub(super) fn new() -> Self { Self { window: 0, next: 0 } } /// Highest packet number authenticated. fn highest(&self) -> u64 { self.next - 1 } /// Record a newly authenticated packet number. /// /// Returns whether the packet might be a duplicate. pub(super) fn insert(&mut self, packet: u64) -> bool { if let Some(diff) = packet.checked_sub(self.next) { // Right of window self.window = (self.window << 1 | 1) .checked_shl(cmp::min(diff, u64::from(u32::MAX)) as u32) .unwrap_or(0); self.next = packet + 1; false } else if self.highest() - packet < WINDOW_SIZE { // Within window if let Some(bit) = (self.highest() - packet).checked_sub(1) { // < highest let mask = 1 << bit; let duplicate = self.window & mask != 0; self.window |= mask; duplicate } else { // == highest true } } else { // Left of window true } } /// Returns the packet number of the smallest packet missing between the provided interval /// /// If there are no missing packets, returns `None` fn smallest_missing_in_interval(&self, lower_bound: u64, upper_bound: u64) -> Option { debug_assert!(lower_bound <= upper_bound); debug_assert!(upper_bound <= self.highest()); const BITFIELD_SIZE: u64 = (mem::size_of::() * 8) as u64; // Since we already know the packets at the boundaries have been received, we only need to // check those in between them (this removes the necessity of extra logic to deal with the // highest packet, which is stored outside the bitfield) let lower_bound = lower_bound + 1; let upper_bound = upper_bound.saturating_sub(1); // Note: the offsets are counted from the right // The highest packet is not included in the bitfield, so we subtract 1 to account for that let start_offset = (self.highest() - upper_bound).max(1) - 1; if start_offset >= BITFIELD_SIZE { // The start offset is outside of the window. All packets outside of the window are // considered to be received. 
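            // (Illustrative: with `highest() == 300` and a 128 bit window, an adjusted
            // `upper_bound` of 171 or less yields `start_offset >= 128`, i.e. the whole
            // interval lies left of the window.)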
return None; } let end_offset_exclusive = self.highest().saturating_sub(lower_bound); // The range is clamped at the edge of the window, because any earlier packets are // considered to be received let range_len = end_offset_exclusive .saturating_sub(start_offset) .min(BITFIELD_SIZE); if range_len == 0 { return None; } // Ensure the shift is within bounds (we already know start_offset < BITFIELD_SIZE, // because of the early return) let mask = if range_len == BITFIELD_SIZE { u128::MAX } else { ((1u128 << range_len) - 1) << start_offset }; let gaps = !self.window & mask; let smallest_missing_offset = 128 - gaps.leading_zeros() as u64; let smallest_missing_packet = self.highest() - smallest_missing_offset; if smallest_missing_packet <= upper_bound { Some(smallest_missing_packet) } else { None } } /// Returns true if there are any missing packets between the provided interval /// /// The provided packet numbers must have been received before calling this function fn missing_in_interval(&self, lower_bound: u64, upper_bound: u64) -> bool { self.smallest_missing_in_interval(lower_bound, upper_bound) .is_some() } } /// Indicates which data is available for sending #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub(super) struct SendableFrames { pub(super) acks: bool, pub(super) other: bool, } impl SendableFrames { /// Returns that no data is available for sending pub(super) fn empty() -> Self { Self { acks: false, other: false, } } /// Whether no data is sendable pub(super) fn is_empty(&self) -> bool { !self.acks && !self.other } } #[derive(Debug)] pub(super) struct PendingAcks { /// Whether we should send an ACK immediately, even if that means sending an ACK-only packet /// /// When `immediate_ack_required` is false, the normal behavior is to send ACK frames only when /// there is other data to send, or when the `MaxAckDelay` timer expires. 
immediate_ack_required: bool, /// The number of ack-eliciting packets received since the last ACK frame was sent /// /// Once the count _exceeds_ `ack_eliciting_threshold`, an immediate ACK is required ack_eliciting_since_last_ack_sent: u64, non_ack_eliciting_since_last_ack_sent: u64, ack_eliciting_threshold: u64, /// The reordering threshold, controlling how we respond to out-of-order ack-eliciting packets /// /// Different values enable different behavior: /// /// * `0`: no special action is taken /// * `1`: an ACK is immediately sent if it is out-of-order according to RFC 9000 /// * `>1`: an ACK is immediately sent if it is out-of-order according to the ACK frequency draft reordering_threshold: u64, /// The earliest ack-eliciting packet since the last ACK was sent, used to calculate the moment /// upon which `max_ack_delay` elapses earliest_ack_eliciting_since_last_ack_sent: Option, /// The packet number ranges of ack-eliciting packets the peer hasn't confirmed receipt of ACKs /// for ranges: ArrayRangeSet, /// The packet with the largest packet number, and the time upon which it was received (used to /// calculate ACK delay in [`PendingAcks::ack_delay`]) largest_packet: Option<(u64, Instant)>, /// The ack-eliciting packet we have received with the largest packet number largest_ack_eliciting_packet: Option, /// The largest acknowledged packet number sent in an ACK frame largest_acked: Option, } impl PendingAcks { fn new() -> Self { Self { immediate_ack_required: false, ack_eliciting_since_last_ack_sent: 0, non_ack_eliciting_since_last_ack_sent: 0, ack_eliciting_threshold: 1, reordering_threshold: 1, earliest_ack_eliciting_since_last_ack_sent: None, ranges: ArrayRangeSet::default(), largest_packet: None, largest_ack_eliciting_packet: None, largest_acked: None, } } pub(super) fn set_ack_frequency_params(&mut self, frame: &frame::AckFrequency) { self.ack_eliciting_threshold = frame.ack_eliciting_threshold.into_inner(); self.reordering_threshold = frame.reordering_threshold.into_inner(); } pub(super) fn set_immediate_ack_required(&mut self) { self.immediate_ack_required = true; } pub(super) fn on_max_ack_delay_timeout(&mut self) { self.immediate_ack_required = self.ack_eliciting_since_last_ack_sent > 0; } pub(super) fn max_ack_delay_timeout(&self, max_ack_delay: Duration) -> Option { self.earliest_ack_eliciting_since_last_ack_sent .map(|earliest_unacked| earliest_unacked + max_ack_delay) } /// Whether any ACK frames can be sent pub(super) fn can_send(&self) -> bool { self.immediate_ack_required && !self.ranges.is_empty() } /// Returns the delay since the packet with the largest packet number was received pub(super) fn ack_delay(&self, now: Instant) -> Duration { self.largest_packet .map_or(Duration::default(), |(_, received)| now - received) } /// Handle receipt of a new packet /// /// Returns true if the max ack delay timer should be armed pub(super) fn packet_received( &mut self, now: Instant, packet_number: u64, ack_eliciting: bool, dedup: &Dedup, ) -> bool { if !ack_eliciting { self.non_ack_eliciting_since_last_ack_sent += 1; return false; } let prev_largest_ack_eliciting = self.largest_ack_eliciting_packet.unwrap_or(0); // Track largest ack-eliciting packet self.largest_ack_eliciting_packet = self .largest_ack_eliciting_packet .map(|pn| pn.max(packet_number)) .or(Some(packet_number)); // Handle ack_eliciting_threshold self.ack_eliciting_since_last_ack_sent += 1; self.immediate_ack_required |= self.ack_eliciting_since_last_ack_sent > self.ack_eliciting_threshold; // Handle out-of-order 
packets self.immediate_ack_required |= self.is_out_of_order(packet_number, prev_largest_ack_eliciting, dedup); // Arm max_ack_delay timer if necessary if self.earliest_ack_eliciting_since_last_ack_sent.is_none() && !self.can_send() { self.earliest_ack_eliciting_since_last_ack_sent = Some(now); return true; } false } fn is_out_of_order( &self, packet_number: u64, prev_largest_ack_eliciting: u64, dedup: &Dedup, ) -> bool { match self.reordering_threshold { 0 => false, 1 => { // From https://www.rfc-editor.org/rfc/rfc9000#section-13.2.1-7 packet_number < prev_largest_ack_eliciting || dedup.missing_in_interval(prev_largest_ack_eliciting, packet_number) } _ => { // From acknowledgement frequency draft, section 6.1: send an ACK immediately if // doing so would cause the sender to detect a new packet loss let Some((largest_acked, largest_unacked)) = self.largest_acked.zip(self.largest_ack_eliciting_packet) else { return false; }; if self.reordering_threshold > largest_acked { return false; } // The largest packet number that could be declared lost without a new ACK being // sent let largest_reported = largest_acked - self.reordering_threshold + 1; let Some(smallest_missing_unreported) = dedup.smallest_missing_in_interval(largest_reported, largest_unacked) else { return false; }; largest_unacked - smallest_missing_unreported >= self.reordering_threshold } } } /// Should be called whenever ACKs have been sent /// /// This will suppress sending further ACKs until additional ACK eliciting frames arrive pub(super) fn acks_sent(&mut self) { // It is possible (though unlikely) that the ACKs we just sent do not cover all the // ACK-eliciting packets we have received (e.g. if there is not enough room in the packet to // fit all the ranges). To keep things simple, however, we assume they do. If there are // indeed some ACKs that weren't covered, the packets might be ACKed later anyway, because // they are still contained in `self.ranges`. If we somehow fail to send the ACKs at a later // moment, the peer will assume the packets got lost and will retransmit their frames in a // new packet, which is suboptimal, because we already received them. Our assumption here is // that simplicity results in code that is more performant, even in the presence of // occasional redundant retransmits. self.immediate_ack_required = false; self.ack_eliciting_since_last_ack_sent = 0; self.non_ack_eliciting_since_last_ack_sent = 0; self.earliest_ack_eliciting_since_last_ack_sent = None; self.largest_acked = self.largest_ack_eliciting_packet; } /// Insert one packet that needs to be acknowledged pub(super) fn insert_one(&mut self, packet: u64, now: Instant) { self.ranges.insert_one(packet); if self.largest_packet.map_or(true, |(pn, _)| packet > pn) { self.largest_packet = Some((packet, now)); } if self.ranges.len() > MAX_ACK_BLOCKS { self.ranges.pop_min(); } } /// Remove ACKs of packets numbered at or below `max` from the set of pending ACKs pub(super) fn subtract_below(&mut self, max: u64) { self.ranges.remove(0..(max + 1)); } /// Returns the set of currently pending ACK ranges pub(super) fn ranges(&self) -> &ArrayRangeSet { &self.ranges } /// Queue an ACK if a significant number of non-ACK-eliciting packets have not yet been /// acknowledged /// /// Should be called immediately before a non-probing packet is composed, when we've already /// committed to sending a packet regardless. 
pub(super) fn maybe_ack_non_eliciting(&mut self) { // If we're going to send a packet anyway, and we've received a significant number of // non-ACK-eliciting packets, then include an ACK to help the peer perform timely loss // detection even if they're not sending any ACK-eliciting packets themselves. Exact // threshold chosen somewhat arbitrarily. const LAZY_ACK_THRESHOLD: u64 = 10; if self.non_ack_eliciting_since_last_ack_sent > LAZY_ACK_THRESHOLD { self.immediate_ack_required = true; } } } /// Helper for mitigating [optimistic ACK attacks] /// /// A malicious peer could prompt the local application to begin a large data transfer, and then /// send ACKs without first waiting for data to be received. This could defeat congestion control, /// allowing the connection to consume disproportionate resources. We therefore occasionally skip /// packet numbers, and classify any ACK referencing a skipped packet number as a transport error. /// /// Skipped packet numbers occur only in the application data space (where costly transfers might /// take place) and are distributed exponentially to reflect the reduced likelihood and impact of /// bad behavior from a peer that has been well-behaved for an extended period. /// /// ACKs for packet numbers that have not yet been allocated are also a transport error, but an /// attacker with knowledge of the congestion control algorithm in use could time falsified ACKs to /// arrive after the packets they reference are sent. /// /// [optimistic ACK attacks]: https://www.rfc-editor.org/rfc/rfc9000.html#name-optimistic-ack-attack pub(super) struct PacketNumberFilter { /// Next outgoing packet number to skip next_skipped_packet_number: u64, /// Most recently skipped packet number prev_skipped_packet_number: Option<u64>, /// Next packet number to skip is randomly selected from 2^n..2^(n+1) exponent: u32, } impl PacketNumberFilter { pub(super) fn new(rng: &mut (impl Rng + ?Sized)) -> Self { // First skipped PN is in 0..64 let exponent = 6; Self { next_skipped_packet_number: rng.gen_range(0..2u64.saturating_pow(exponent)), prev_skipped_packet_number: None, exponent, } } #[cfg(test)] pub(super) fn disabled() -> Self { Self { next_skipped_packet_number: u64::MAX, prev_skipped_packet_number: None, exponent: u32::MAX, } } pub(super) fn peek(&self, space: &PacketSpace) -> u64 { let n = space.next_packet_number; if n != self.next_skipped_packet_number { return n; } n + 1 } pub(super) fn allocate( &mut self, rng: &mut (impl Rng + ?Sized), space: &mut PacketSpace, ) -> u64 { let n = space.get_tx_number(); if n != self.next_skipped_packet_number { return n; } trace!("skipping pn {n}"); // Skip this packet number, and choose the next one to skip self.prev_skipped_packet_number = Some(self.next_skipped_packet_number); let next_exponent = self.exponent.saturating_add(1); self.next_skipped_packet_number = rng.gen_range(2u64.saturating_pow(self.exponent)..2u64.saturating_pow(next_exponent)); self.exponent = next_exponent; space.get_tx_number() } pub(super) fn check_ack( &self, space_id: SpaceId, range: std::ops::RangeInclusive<u64>, ) -> Result<(), TransportError> { if space_id == SpaceId::Data && self .prev_skipped_packet_number .map_or(false, |x| range.contains(&x)) { return Err(TransportError::PROTOCOL_VIOLATION("unsent packet acked")); } Ok(()) } } /// Ensures we can always fit all our ACKs in a single minimum-MTU packet with room to spare const MAX_ACK_BLOCKS: usize = 64; #[cfg(test)] mod test { use super::*; #[test] fn sanity() { let mut dedup = Dedup::new();
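// The assertions below exercise the sliding-window layout of `Dedup`: `next` is the lowest
// packet number not yet covered by the window, and `window` is a bitmask of received packets
// ending just below `next` (bit 0 is packet `next - 1`). For example, after inserting 0, 1, 2
// and 4, `next == 5` and `window == 0b11110`: bit 1 (packet 3) is the only gap.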
assert!(!dedup.insert(0)); assert_eq!(dedup.next, 1); assert_eq!(dedup.window, 0b1); assert!(dedup.insert(0)); assert_eq!(dedup.next, 1); assert_eq!(dedup.window, 0b1); assert!(!dedup.insert(1)); assert_eq!(dedup.next, 2); assert_eq!(dedup.window, 0b11); assert!(!dedup.insert(2)); assert_eq!(dedup.next, 3); assert_eq!(dedup.window, 0b111); assert!(!dedup.insert(4)); assert_eq!(dedup.next, 5); assert_eq!(dedup.window, 0b11110); assert!(!dedup.insert(7)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_0100); assert!(dedup.insert(4)); assert!(!dedup.insert(3)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_1100); assert!(!dedup.insert(6)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_1101); assert!(!dedup.insert(5)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_1111); } #[test] fn happypath() { let mut dedup = Dedup::new(); for i in 0..(2 * WINDOW_SIZE) { assert!(!dedup.insert(i)); for j in 0..=i { assert!(dedup.insert(j)); } } } #[test] fn jump() { let mut dedup = Dedup::new(); dedup.insert(2 * WINDOW_SIZE); assert!(dedup.insert(WINDOW_SIZE)); assert_eq!(dedup.next, 2 * WINDOW_SIZE + 1); assert_eq!(dedup.window, 0); assert!(!dedup.insert(WINDOW_SIZE + 1)); assert_eq!(dedup.next, 2 * WINDOW_SIZE + 1); assert_eq!(dedup.window, 1 << (WINDOW_SIZE - 2)); } #[test] fn dedup_has_missing() { let mut dedup = Dedup::new(); dedup.insert(0); assert!(!dedup.missing_in_interval(0, 0)); dedup.insert(1); assert!(!dedup.missing_in_interval(0, 1)); dedup.insert(3); assert!(dedup.missing_in_interval(1, 3)); dedup.insert(4); assert!(!dedup.missing_in_interval(3, 4)); assert!(dedup.missing_in_interval(0, 4)); dedup.insert(2); assert!(!dedup.missing_in_interval(0, 4)); } #[test] fn dedup_outside_of_window_has_missing() { let mut dedup = Dedup::new(); for i in 0..140 { dedup.insert(i); } // 0 and 4 are outside of the window assert!(!dedup.missing_in_interval(0, 4)); dedup.insert(160); assert!(!dedup.missing_in_interval(0, 4)); assert!(!dedup.missing_in_interval(0, 140)); assert!(dedup.missing_in_interval(0, 160)); } #[test] fn dedup_smallest_missing() { let mut dedup = Dedup::new(); dedup.insert(0); assert_eq!(dedup.smallest_missing_in_interval(0, 0), None); dedup.insert(1); assert_eq!(dedup.smallest_missing_in_interval(0, 1), None); dedup.insert(5); dedup.insert(7); assert_eq!(dedup.smallest_missing_in_interval(0, 7), Some(2)); assert_eq!(dedup.smallest_missing_in_interval(5, 7), Some(6)); dedup.insert(2); assert_eq!(dedup.smallest_missing_in_interval(1, 7), Some(3)); dedup.insert(170); dedup.insert(172); dedup.insert(300); assert_eq!(dedup.smallest_missing_in_interval(170, 172), None); dedup.insert(500); assert_eq!(dedup.smallest_missing_in_interval(0, 500), Some(372)); assert_eq!(dedup.smallest_missing_in_interval(0, 373), Some(372)); assert_eq!(dedup.smallest_missing_in_interval(0, 372), None); } #[test] fn pending_acks_first_packet_is_not_considered_reordered() { let mut acks = PendingAcks::new(); let mut dedup = Dedup::new(); dedup.insert(0); acks.packet_received(Instant::now(), 0, true, &dedup); assert!(!acks.immediate_ack_required); } #[test] fn pending_acks_after_immediate_ack_set() { let mut acks = PendingAcks::new(); let mut dedup = Dedup::new(); // Receive ack-eliciting packet dedup.insert(0); let now = Instant::now(); acks.insert_one(0, now); acks.packet_received(now, 0, true, &dedup); // Sanity check assert!(!acks.ranges.is_empty()); assert!(!acks.can_send()); // Can send ACK after max_ack_delay exceeded acks.set_immediate_ack_required(); 
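// With a pending range queued and the immediate-ack flag now set, both halves of the
// `can_send` condition are satisfied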
assert!(acks.can_send()); } #[test] fn pending_acks_ack_delay() { let mut acks = PendingAcks::new(); let mut dedup = Dedup::new(); let t1 = Instant::now(); let t2 = t1 + Duration::from_millis(2); let t3 = t2 + Duration::from_millis(5); assert_eq!(acks.ack_delay(t1), Duration::from_millis(0)); assert_eq!(acks.ack_delay(t2), Duration::from_millis(0)); assert_eq!(acks.ack_delay(t3), Duration::from_millis(0)); // In-order packet dedup.insert(0); acks.insert_one(0, t1); acks.packet_received(t1, 0, true, &dedup); assert_eq!(acks.ack_delay(t1), Duration::from_millis(0)); assert_eq!(acks.ack_delay(t2), Duration::from_millis(2)); assert_eq!(acks.ack_delay(t3), Duration::from_millis(7)); // Out of order (higher than expected) dedup.insert(3); acks.insert_one(3, t2); acks.packet_received(t2, 3, true, &dedup); assert_eq!(acks.ack_delay(t2), Duration::from_millis(0)); assert_eq!(acks.ack_delay(t3), Duration::from_millis(5)); // Out of order (lower than expected, so previous instant is kept) dedup.insert(2); acks.insert_one(2, t3); acks.packet_received(t3, 2, true, &dedup); assert_eq!(acks.ack_delay(t3), Duration::from_millis(5)); } #[test] fn sent_packet_size() { // The tracking state of sent packets should be minimal, and not grow // over time. assert!(std::mem::size_of::<SentPacket>() <= 128); } } quinn-proto-0.11.9/src/connection/stats.rs000064400000000000000000000147061046102023000166510ustar 00000000000000//! Connection statistics use crate::{frame::Frame, Dir, Duration}; /// Statistics about UDP datagrams transmitted or received on a connection #[derive(Default, Debug, Copy, Clone)] #[non_exhaustive] pub struct UdpStats { /// The amount of UDP datagrams observed pub datagrams: u64, /// The total amount of bytes which have been transferred inside UDP datagrams pub bytes: u64, /// The amount of I/O operations executed /// /// Can be less than `datagrams` when GSO, GRO, and/or batched system calls are in use. pub ios: u64, } impl UdpStats { pub(crate) fn on_sent(&mut self, datagrams: u64, bytes: usize) { self.datagrams += datagrams; self.bytes += bytes as u64; self.ios += 1; } } /// Number of frames transmitted of each frame type #[derive(Default, Copy, Clone)] #[non_exhaustive] #[allow(missing_docs)] pub struct FrameStats { pub acks: u64, pub ack_frequency: u64, pub crypto: u64, pub connection_close: u64, pub data_blocked: u64, pub datagram: u64, pub handshake_done: u8, pub immediate_ack: u64, pub max_data: u64, pub max_stream_data: u64, pub max_streams_bidi: u64, pub max_streams_uni: u64, pub new_connection_id: u64, pub new_token: u64, pub path_challenge: u64, pub path_response: u64, pub ping: u64, pub reset_stream: u64, pub retire_connection_id: u64, pub stream_data_blocked: u64, pub streams_blocked_bidi: u64, pub streams_blocked_uni: u64, pub stop_sending: u64, pub stream: u64, } impl FrameStats { pub(crate) fn record(&mut self, frame: &Frame) { match frame { Frame::Padding => {} Frame::Ping => self.ping += 1, Frame::Ack(_) => self.acks += 1, Frame::ResetStream(_) => self.reset_stream += 1, Frame::StopSending(_) => self.stop_sending += 1, Frame::Crypto(_) => self.crypto += 1, Frame::Datagram(_) => self.datagram += 1, Frame::NewToken { .. } => self.new_token += 1, Frame::MaxData(_) => self.max_data += 1, Frame::MaxStreamData { .. } => self.max_stream_data += 1, Frame::MaxStreams { dir, .. } => { if *dir == Dir::Bi { self.max_streams_bidi += 1; } else { self.max_streams_uni += 1; } } Frame::DataBlocked { .. } => self.data_blocked += 1, Frame::Stream(_) => self.stream += 1, Frame::StreamDataBlocked { ..
} => self.stream_data_blocked += 1, Frame::StreamsBlocked { dir, .. } => { if *dir == Dir::Bi { self.streams_blocked_bidi += 1; } else { self.streams_blocked_uni += 1; } } Frame::NewConnectionId(_) => self.new_connection_id += 1, Frame::RetireConnectionId { .. } => self.retire_connection_id += 1, Frame::PathChallenge(_) => self.path_challenge += 1, Frame::PathResponse(_) => self.path_response += 1, Frame::Close(_) => self.connection_close += 1, Frame::AckFrequency(_) => self.ack_frequency += 1, Frame::ImmediateAck => self.immediate_ack += 1, Frame::HandshakeDone => self.handshake_done = self.handshake_done.saturating_add(1), } } } impl std::fmt::Debug for FrameStats { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("FrameStats") .field("ACK", &self.acks) .field("ACK_FREQUENCY", &self.ack_frequency) .field("CONNECTION_CLOSE", &self.connection_close) .field("CRYPTO", &self.crypto) .field("DATA_BLOCKED", &self.data_blocked) .field("DATAGRAM", &self.datagram) .field("HANDSHAKE_DONE", &self.handshake_done) .field("IMMEDIATE_ACK", &self.immediate_ack) .field("MAX_DATA", &self.max_data) .field("MAX_STREAM_DATA", &self.max_stream_data) .field("MAX_STREAMS_BIDI", &self.max_streams_bidi) .field("MAX_STREAMS_UNI", &self.max_streams_uni) .field("NEW_CONNECTION_ID", &self.new_connection_id) .field("NEW_TOKEN", &self.new_token) .field("PATH_CHALLENGE", &self.path_challenge) .field("PATH_RESPONSE", &self.path_response) .field("PING", &self.ping) .field("RESET_STREAM", &self.reset_stream) .field("RETIRE_CONNECTION_ID", &self.retire_connection_id) .field("STREAM_DATA_BLOCKED", &self.stream_data_blocked) .field("STREAMS_BLOCKED_BIDI", &self.streams_blocked_bidi) .field("STREAMS_BLOCKED_UNI", &self.streams_blocked_uni) .field("STOP_SENDING", &self.stop_sending) .field("STREAM", &self.stream) .finish() } } /// Statistics related to a transmission path #[derive(Debug, Default, Copy, Clone)] #[non_exhaustive] pub struct PathStats { /// Current best estimate of this connection's latency (round-trip-time) pub rtt: Duration, /// Current congestion window of the connection pub cwnd: u64, /// Congestion events on the connection pub congestion_events: u64, /// The amount of packets lost on this path pub lost_packets: u64, /// The amount of bytes lost on this path pub lost_bytes: u64, /// The amount of packets sent on this path pub sent_packets: u64, /// The amount of PLPMTUD probe packets sent on this path (also counted by `sent_packets`) pub sent_plpmtud_probes: u64, /// The amount of PLPMTUD probe packets lost on this path (ignored by `lost_packets` and /// `lost_bytes`) pub lost_plpmtud_probes: u64, /// The number of times a black hole was detected in the path pub black_holes_detected: u64, /// Largest UDP payload size the path currently supports pub current_mtu: u16, } /// Connection statistics #[derive(Debug, Default, Copy, Clone)] #[non_exhaustive] pub struct ConnectionStats { /// Statistics about UDP datagrams transmitted on a connection pub udp_tx: UdpStats, /// Statistics about UDP datagrams received on a connection pub udp_rx: UdpStats, /// Statistics about frames transmitted on a connection pub frame_tx: FrameStats, /// Statistics about frames received on a connection pub frame_rx: FrameStats, /// Statistics related to the current transmission path pub path: PathStats, } quinn-proto-0.11.9/src/connection/streams/mod.rs000064400000000000000000000457651046102023000177610ustar 00000000000000use std::{ collections::{hash_map, BinaryHeap}, io, }; use bytes::Bytes; use 
thiserror::Error; use tracing::trace; use super::spaces::{Retransmits, ThinRetransmits}; use crate::{ connection::streams::state::{get_or_insert_recv, get_or_insert_send}, frame, Dir, StreamId, VarInt, }; mod recv; use recv::Recv; pub use recv::{Chunks, ReadError, ReadableError}; mod send; pub(crate) use send::{ByteSlice, BytesArray}; pub use send::{BytesSource, FinishError, WriteError, Written}; use send::{Send, SendState}; mod state; #[allow(unreachable_pub)] // fuzzing only pub use state::StreamsState; /// Access to streams pub struct Streams<'a> { pub(super) state: &'a mut StreamsState, pub(super) conn_state: &'a super::State, } #[allow(clippy::needless_lifetimes)] // Needed for cfg(fuzzing) impl<'a> Streams<'a> { #[cfg(fuzzing)] pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self { Self { state, conn_state } } /// Open a single stream if possible /// /// Returns `None` if the streams in the given direction are currently exhausted. pub fn open(&mut self, dir: Dir) -> Option<StreamId> { if self.conn_state.is_closed() { return None; } // TODO: Queue STREAM_ID_BLOCKED if this fails if self.state.next[dir as usize] >= self.state.max[dir as usize] { return None; } self.state.next[dir as usize] += 1; let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1); self.state.insert(false, id); self.state.send_streams += 1; Some(id) } /// Accept a remotely initiated stream of a certain directionality, if possible /// /// Returns `None` if there are no new incoming streams for this connection. /// Has no impact on the data flow-control or stream concurrency limits. pub fn accept(&mut self, dir: Dir) -> Option<StreamId> { if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] { return None; } let x = self.state.next_reported_remote[dir as usize]; self.state.next_reported_remote[dir as usize] = x + 1; if dir == Dir::Bi { self.state.send_streams += 1; } Some(StreamId::new(!self.state.side, dir, x)) } #[cfg(fuzzing)] pub fn state(&mut self) -> &mut StreamsState { self.state } /// The number of streams that may have unacknowledged data. pub fn send_streams(&self) -> usize { self.state.send_streams } /// The number of remotely initiated open streams of a certain directionality. /// /// Includes remotely initiated streams, which have not been accepted via [`accept`](Self::accept). /// These streams count against the respective concurrency limit reported by /// [`Connection::max_concurrent_streams`](super::Connection::max_concurrent_streams). pub fn remote_open_streams(&self, dir: Dir) -> u64 { // total opened - total closed = total opened - ( total permitted - total permitted unclosed ) self.state.next_remote[dir as usize] - (self.state.max_remote[dir as usize] - self.state.allocated_remote_count[dir as usize]) } } /// Access to streams pub struct RecvStream<'a> { pub(super) id: StreamId, pub(super) state: &'a mut StreamsState, pub(super) pending: &'a mut Retransmits, } impl RecvStream<'_> { /// Read from the given recv stream /// /// `max_length` limits the maximum size of the returned `Bytes` value; passing `usize::MAX` /// will yield the best performance. `ordered` will make sure the returned chunk's offset will /// have an offset exactly equal to the previously returned offset plus the previously returned /// bytes' length. /// /// Yields `Ok(None)` if the stream was finished. Otherwise, yields a segment of data and its /// offset in the stream.
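/// A sketch of a typical ordered read loop (the `recv_stream` binding and the error handling
/// here are illustrative assumptions, not part of this crate's documented API):
/// ```ignore
/// let mut chunks = recv_stream.read(true)?;
/// // an `Err(ReadError::Blocked)` from `next` merely means no more data is buffered right now
/// while let Some(chunk) = chunks.next(usize::MAX)? {
///     // `chunk.offset` locates `chunk.bytes` within the stream
/// }
/// let _ = chunks.finalize();
/// ```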
If `ordered` is `false`, segments may be received in any order, and /// the `Chunk`'s `offset` field can be used to determine ordering in the caller. /// /// While most applications will prefer to consume stream data in order, unordered reads can /// improve performance when packet loss occurs and data cannot be retransmitted before the flow /// control window is filled. On any given stream, you can switch from ordered to unordered /// reads, but ordered reads on streams that have seen previous unordered reads will return /// `ReadError::IllegalOrderedRead`. pub fn read(&mut self, ordered: bool) -> Result<Chunks<'_>, ReadableError> { Chunks::new(self.id, ordered, self.state, self.pending) } /// Stop accepting data on the given receive stream /// /// Discards unread data and notifies the peer to stop transmitting. Once stopped, further /// attempts to operate on a stream will yield `ClosedStream` errors. pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { let mut entry = match self.state.recv.entry(self.id) { hash_map::Entry::Occupied(s) => s, hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }), }; let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut()); let (read_credits, stop_sending) = stream.stop()?; if stop_sending.should_transmit() { self.pending.stop_sending.push(frame::StopSending { id: self.id, error_code, }); } // We need to keep stopped streams around until they're finished or reset so we can update // connection-level flow control to account for discarded data. Otherwise, we can discard // state immediately. if !stream.final_offset_unknown() { let recv = entry.remove().expect("must have recv when stopping"); self.state.stream_recv_freed(self.id, recv); } if self.state.add_read_credits(read_credits).should_transmit() { self.pending.max_data = true; } Ok(()) } /// Check whether this stream has been reset by the peer, returning the reset error code if so /// /// After returning `Ok(Some(_))` once, stream state will be discarded and all future calls will /// return `Err(ClosedStream)`. pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> { let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else { return Err(ClosedStream { _private: () }); }; let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else { return Ok(None); }; if s.stopped { return Err(ClosedStream { _private: () }); } let Some(code) = s.reset_code() else { return Ok(None); }; // Clean up state after application observes the reset, since there's no reason for the // application to attempt to read or stop the stream once it knows it's reset let (_, recv) = entry.remove_entry(); self.state .stream_recv_freed(self.id, recv.expect("must have recv on reset")); self.state.queue_max_stream_id(self.pending); Ok(Some(code)) } } /// Access to streams pub struct SendStream<'a> { pub(super) id: StreamId, pub(super) state: &'a mut StreamsState, pub(super) pending: &'a mut Retransmits, pub(super) conn_state: &'a super::State, } #[allow(clippy::needless_lifetimes)] // Needed for cfg(fuzzing) impl<'a> SendStream<'a> { #[cfg(fuzzing)] pub fn new( id: StreamId, state: &'a mut StreamsState, pending: &'a mut Retransmits, conn_state: &'a super::State, ) -> Self { Self { id, state, pending, conn_state, } } /// Send data on the given stream /// /// Returns the number of bytes successfully written.
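/// A sketch of handling a partial or blocked write (the `send_stream` binding is an
/// illustrative assumption):
/// ```ignore
/// match send_stream.write(b"hello") {
///     Ok(n) => {} // n <= 5; resubmit the unwritten tail later
///     Err(WriteError::Blocked) => {} // wait for a StreamEvent::Writable
///     Err(e) => { /* closed or stopped; give up on this stream */ }
/// }
/// ```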
pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> { Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes) } /// Send data on the given stream /// /// Returns the number of bytes and chunks successfully written. /// Note that this method might also write a partial chunk. In this case /// [`Written::chunks`] will not count this chunk as fully written. However /// the chunk will be advanced and contain only non-written data after the call. pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> { self.write_source(&mut BytesArray::from_chunks(data)) } fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> { if self.conn_state.is_closed() { trace!(%self.id, "write blocked; connection draining"); return Err(WriteError::Blocked); } let limit = self.state.write_limit(); let max_send_data = self.state.max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(WriteError::ClosedStream)?; if limit == 0 { trace!( stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent, "write blocked by connection-level flow control or send window" ); if !stream.connection_blocked { stream.connection_blocked = true; self.state.connection_blocked.push(self.id); } return Err(WriteError::Blocked); } let was_pending = stream.is_pending(); let written = stream.write(source, limit)?; self.state.data_sent += written.bytes as u64; self.state.unacked_data += written.bytes as u64; trace!(stream = %self.id, "wrote {} bytes", written.bytes); if !was_pending { self.state.pending.push_pending(self.id, stream.priority); } Ok(written) } /// Check if this stream was stopped, get the reason if it was pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> { match self.state.send.get(&self.id).as_ref() { Some(Some(s)) => Ok(s.stop_reason), Some(None) => Ok(None), None => Err(ClosedStream { _private: () }), } } /// Finish a send stream, signalling that no more data will be sent. /// /// If this fails, no [`StreamEvent::Finished`] will be generated. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished pub fn finish(&mut self) -> Result<(), FinishError> { let max_send_data = self.state.max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(FinishError::ClosedStream)?; let was_pending = stream.is_pending(); stream.finish()?; if !was_pending { self.state.pending.push_pending(self.id, stream.priority); } Ok(()) } /// Abandon transmitting data on a stream /// /// # Panics /// - when applied to a receive stream pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { let max_send_data = self.state.max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(ClosedStream { _private: () })?; if matches!(stream.state, SendState::ResetSent) { // Redundant reset call return Err(ClosedStream { _private: () }); } // Restore the portion of the send window consumed by the data that we aren't about to // send. We leave flow control alone because the peer's responsible for issuing additional // credit based on the final offset communicated in the RESET_STREAM frame we send.
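// Illustrative numbers: if 10 kB were submitted on this stream and 6 kB of them are not yet
// acknowledged, those 6 kB are released from `unacked_data` here, since a reset stream
// retransmits nothing.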
self.state.unacked_data -= stream.pending.unacked(); stream.reset(); self.pending.reset_stream.push((self.id, error_code)); // Don't reopen an already-closed stream we haven't forgotten yet Ok(()) } /// Set the priority of a stream /// /// # Panics /// - when applied to a receive stream pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> { let max_send_data = self.state.max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(ClosedStream { _private: () })?; stream.priority = priority; Ok(()) } /// Get the priority of a stream /// /// # Panics /// - when applied to a receive stream pub fn priority(&self) -> Result<i32, ClosedStream> { let stream = self .state .send .get(&self.id) .ok_or(ClosedStream { _private: () })?; Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default()) } } /// A queue of streams with pending outgoing data, sorted by priority struct PendingStreamsQueue { streams: BinaryHeap<PendingStream>, /// The next stream to write out. This is `Some` when `TransportConfig::send_fairness(false)` and writing a stream is /// interrupted while the stream still has some pending data. See `reinsert_pending()`. next: Option<PendingStream>, /// A monotonically decreasing counter, used to implement round-robin scheduling for streams of the same priority. /// Underflowing is not a practical concern, as it is initialized to u64::MAX and only decremented by 1 in `push_pending` recency: u64, } impl PendingStreamsQueue { fn new() -> Self { Self { streams: BinaryHeap::new(), next: None, recency: u64::MAX, } } /// Reinsert a stream that was pending and still contains unsent data. fn reinsert_pending(&mut self, id: StreamId, priority: i32) { assert!(self.next.is_none()); self.next = Some(PendingStream { priority, recency: self.recency, // the value here doesn't really matter id, }); } /// Push a pending stream ID with the given priority, queued after any already-queued streams for the priority fn push_pending(&mut self, id: StreamId, priority: i32) { // Note that in the case where fairness is disabled, if we have a reinserted stream we don't // bump it even if priority > next.priority. In order to minimize fragmentation we // always try to complete a stream once part of it has been written. // As the recency counter is monotonically decreasing, we know that using its value to sort this stream will queue it // after all other queued streams of the same priority. // This is enough to implement round-robin scheduling for streams that are still pending even after being handled, // as in that case they are removed from the `BinaryHeap`, handled, and then immediately reinserted.
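// Concretely: pushing streams A, B, C at equal priority assigns recencies u64::MAX - 1,
// u64::MAX - 2, u64::MAX - 3, so the max-heap pops A, then B, then C, giving FIFO order
// within a priority level.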
self.recency -= 1; self.streams.push(PendingStream { priority, recency: self.recency, id, }); } fn pop(&mut self) -> Option<PendingStream> { self.next.take().or_else(|| self.streams.pop()) } fn clear(&mut self) { self.next = None; self.streams.clear(); } fn iter(&self) -> impl Iterator<Item = &PendingStream> { self.next.iter().chain(self.streams.iter()) } #[cfg(test)] fn len(&self) -> usize { self.streams.len() + self.next.is_some() as usize } } /// The [`StreamId`] of a stream with pending data queued, ordered by its priority and recency #[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] struct PendingStream { /// The priority of the stream // Note that this field should be kept above the `recency` field, in order for the `Ord` derive to be correct // (See https://doc.rust-lang.org/stable/std/cmp/trait.Ord.html#derivable) priority: i32, /// A tie-breaker for streams of the same priority, used to improve fairness by implementing round-robin scheduling: /// Larger values are prioritized, so it is initialised to `u64::MAX`, and when a stream writes data, we know /// that it currently has the highest recency value, so it is deprioritized by setting its recency to 1 less than the /// previous lowest recency value, such that all other streams of this priority will get processed once before we get back /// round to this one recency: u64, /// The ID of the stream // The way this type is used ensures that every instance has a unique `recency` value, so this field should be kept below // the `priority` and `recency` fields, so that it does not interfere with the behaviour of the `Ord` derive id: StreamId, } /// Application events about streams #[derive(Debug, PartialEq, Eq)] pub enum StreamEvent { /// One or more new streams have been opened and might be readable Opened { /// Directionality for which streams have been opened dir: Dir, }, /// A currently open stream likely has data or errors waiting to be read Readable { /// Which stream is now readable id: StreamId, }, /// A formerly write-blocked stream might be ready for a write or have been stopped /// /// Only generated for streams that are currently open. Writable { /// Which stream is now writable id: StreamId, }, /// A finished stream has been fully acknowledged or stopped Finished { /// Which stream has been finished id: StreamId, }, /// The peer asked us to stop sending on an outgoing stream Stopped { /// Which stream has been stopped id: StreamId, /// Error code supplied by the peer error_code: VarInt, }, /// At least one new stream of a certain directionality may be opened Available { /// Directionality for which streams are newly available dir: Dir, }, } /// Indicates whether a frame needs to be transmitted /// /// This type wraps around bool and uses the `#[must_use]` attribute in order /// to prevent accidental loss of the frame transmission requirement.
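/// A sketch of the intended call pattern (the transmit scheduling around it is assumed):
/// ```ignore
/// let should = chunks.finalize();
/// if should.should_transmit() {
///     // e.g. drive `Connection::poll_transmit` to get the MAX_DATA / MAX_STREAM_DATA out
/// }
/// ```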
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] #[must_use = "A frame might need to be enqueued"] pub struct ShouldTransmit(bool); impl ShouldTransmit { /// Returns whether a frame should be transmitted pub fn should_transmit(self) -> bool { self.0 } } /// Error indicating that a stream has not been opened or has already been finished or reset #[derive(Debug, Error, Clone, PartialEq, Eq)] #[error("closed stream")] pub struct ClosedStream { _private: (), } impl ClosedStream { #[doc(hidden)] // For use in quinn only pub fn new() -> Self { Self { _private: () } } } impl From<ClosedStream> for io::Error { fn from(x: ClosedStream) -> Self { Self::new(io::ErrorKind::NotConnected, x) } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum StreamHalf { Send, Recv, } quinn-proto-0.11.9/src/connection/streams/recv.rs000064400000000000000000000474561046102023000201340ustar 00000000000000use std::collections::hash_map::Entry; use std::mem; use thiserror::Error; use tracing::debug; use super::state::get_or_insert_recv; use super::{ClosedStream, Retransmits, ShouldTransmit, StreamId, StreamsState}; use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead}; use crate::connection::streams::state::StreamRecv; use crate::{frame, TransportError, VarInt}; #[derive(Debug, Default)] pub(super) struct Recv { // NB: when adding or removing fields, remember to update `reinit`. state: RecvState, pub(super) assembler: Assembler, sent_max_stream_data: u64, pub(super) end: u64, pub(super) stopped: bool, } impl Recv { pub(super) fn new(initial_max_data: u64) -> Box<Self> { Box::new(Self { state: RecvState::default(), assembler: Assembler::new(), sent_max_stream_data: initial_max_data, end: 0, stopped: false, }) } /// Reset to the initial state pub(super) fn reinit(&mut self, initial_max_data: u64) { self.state = RecvState::default(); self.assembler.reinit(); self.sent_max_stream_data = initial_max_data; self.end = 0; self.stopped = false; } /// Process a STREAM frame /// /// Return value is `(number_of_new_bytes_ingested, stream_is_closed)` pub(super) fn ingest( &mut self, frame: frame::Stream, payload_len: usize, received: u64, max_data: u64, ) -> Result<(u64, bool), TransportError> { let end = frame.offset + frame.data.len() as u64; if end >= 2u64.pow(62) { return Err(TransportError::FLOW_CONTROL_ERROR( "maximum stream offset too large", )); } if let Some(final_offset) = self.final_offset() { if end > final_offset || (frame.fin && end != final_offset) { debug!(end, final_offset, "final size error"); return Err(TransportError::FINAL_SIZE_ERROR("")); } } let new_bytes = self.credit_consumed_by(end, received, max_data)?; // Stopped streams don't need to wait for the actual data, they just need to know // how much there was.
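// (For scale: `credit_consumed_by` above counts credit by high-water mark, so a first frame
// covering bytes 3..6 of a stream consumes 6 bytes of connection-level credit, offsets 0..6;
// the `reordered_frames_while_stopped` test below exercises exactly this.)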
if frame.fin && !self.stopped { if let RecvState::Recv { ref mut size } = self.state { *size = Some(end); } } self.end = self.end.max(end); // Don't bother storing data or releasing stream-level flow control credit if the stream's // already stopped if !self.stopped { self.assembler.insert(frame.offset, frame.data, payload_len); } Ok((new_bytes, frame.fin && self.stopped)) } pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), ClosedStream> { if self.stopped { return Err(ClosedStream { _private: () }); } self.stopped = true; self.assembler.clear(); // Issue flow control credit for unread data let read_credits = self.end - self.assembler.bytes_read(); // This may send a spurious STOP_SENDING if we've already received all data, but it's a bit // fiddly to distinguish that from the case where we've received a FIN but are missing some // data that the peer might still be trying to retransmit, in which case a STOP_SENDING is // still useful. Ok((read_credits, ShouldTransmit(self.is_receiving()))) } /// Returns the window that should be advertised in a `MAX_STREAM_DATA` frame /// /// The method returns a tuple which consists of the window that should be /// announced, as well as a boolean parameter which indicates if a new /// transmission of the value is recommended. If the boolean value is /// `false` the new window should only be transmitted if a previous transmission /// had failed. pub(super) fn max_stream_data(&mut self, stream_receive_window: u64) -> (u64, ShouldTransmit) { let max_stream_data = self.assembler.bytes_read() + stream_receive_window; // Only announce a window update if it's significant enough // to make it worthwhile sending a MAX_STREAM_DATA frame. // We use here a fraction of the configured stream receive window to make // the decision, and accommodate for streams using bigger windows requiring // less updates. A fixed size would also work - but it would need to be // smaller than `stream_receive_window` in order to make sure the stream // does not get stuck. let diff = max_stream_data - self.sent_max_stream_data; let transmit = self.can_send_flow_control() && diff >= (stream_receive_window / 8); (max_stream_data, ShouldTransmit(transmit)) } /// Records that a `MAX_STREAM_DATA` announcing a certain window was sent /// /// This will suppress enqueuing further `MAX_STREAM_DATA` frames unless /// either the previous transmission was not acknowledged or the window /// further increased. pub(super) fn record_sent_max_stream_data(&mut self, sent_value: u64) { if sent_value > self.sent_max_stream_data { self.sent_max_stream_data = sent_value; } } /// Whether the total amount of data that the peer will send on this stream is unknown /// /// True until we've received either a reset or the final frame. /// /// Implies that the sender might benefit from stream-level flow control updates, and we might /// need to issue connection-level flow control updates due to flow control budget use by this /// stream in the future, even if it's been stopped. pub(super) fn final_offset_unknown(&self) -> bool { matches!(self.state, RecvState::Recv { size: None }) } /// Whether stream-level flow control updates should be sent for this stream pub(super) fn can_send_flow_control(&self) -> bool { // Stream-level flow control is redundant if the sender has already sent the whole stream, // and moot if we no longer want data on this stream. 
self.final_offset_unknown() && !self.stopped } /// Whether data is still being accepted from the peer pub(super) fn is_receiving(&self) -> bool { matches!(self.state, RecvState::Recv { .. }) } fn final_offset(&self) -> Option<u64> { match self.state { RecvState::Recv { size } => size, RecvState::ResetRecvd { size, .. } => Some(size), } } /// Returns `false` iff the reset was redundant pub(super) fn reset( &mut self, error_code: VarInt, final_offset: VarInt, received: u64, max_data: u64, ) -> Result<bool, TransportError> { // Validate final_offset if let Some(offset) = self.final_offset() { if offset != final_offset.into_inner() { return Err(TransportError::FINAL_SIZE_ERROR("inconsistent value")); } } else if self.end > final_offset.into() { return Err(TransportError::FINAL_SIZE_ERROR( "lower than high water mark", )); } self.credit_consumed_by(final_offset.into(), received, max_data)?; if matches!(self.state, RecvState::ResetRecvd { .. }) { return Ok(false); } self.state = RecvState::ResetRecvd { size: final_offset.into(), error_code, }; // Nuke buffers so that future reads fail immediately, which ensures future reads don't // issue flow control credit redundant to that already issued. We could instead special-case // reset streams during read, but it's unclear if there's any benefit to retaining data for // reset streams. self.assembler.clear(); Ok(true) } pub(super) fn reset_code(&self) -> Option<VarInt> { match self.state { RecvState::ResetRecvd { error_code, .. } => Some(error_code), _ => None, } } /// Compute the amount of flow control credit consumed, or return an error if more was consumed /// than issued fn credit_consumed_by( &self, offset: u64, received: u64, max_data: u64, ) -> Result<u64, TransportError> { let prev_end = self.end; let new_bytes = offset.saturating_sub(prev_end); if offset > self.sent_max_stream_data || received + new_bytes > max_data { debug!( received, new_bytes, max_data, offset, stream_max_data = self.sent_max_stream_data, "flow control error" ); return Err(TransportError::FLOW_CONTROL_ERROR("")); } Ok(new_bytes) } } /// Chunks returned from [`RecvStream::read()`][crate::RecvStream::read]. /// /// ### Note: Finalization Needed /// Bytes read from the stream are not released from the congestion window until /// either [`Self::finalize()`] is called, or this type is dropped. /// /// It is recommended that you call [`Self::finalize()`] because it returns a flag /// telling you whether reading from the stream has resulted in the need to transmit a packet. /// /// If this type is leaked, the stream will remain blocked on the remote peer until /// another read from the stream is done. pub struct Chunks<'a> { id: StreamId, ordered: bool, streams: &'a mut StreamsState, pending: &'a mut Retransmits, state: ChunksState, read: u64, } impl<'a> Chunks<'a> { pub(super) fn new( id: StreamId, ordered: bool, streams: &'a mut StreamsState, pending: &'a mut Retransmits, ) -> Result<Self, ReadableError> { let mut entry = match streams.recv.entry(id) { Entry::Occupied(entry) => entry, Entry::Vacant(_) => return Err(ReadableError::ClosedStream), }; let mut recv = match get_or_insert_recv(streams.stream_receive_window)(entry.get_mut()).stopped { true => return Err(ReadableError::ClosedStream), false => entry.remove().unwrap().into_inner(), // this can't fail due to the previous get_or_insert_with }; recv.assembler.ensure_ordering(ordered)?; Ok(Self { id, ordered, streams, pending, state: ChunksState::Readable(recv), read: 0, }) } /// Next /// /// Should call finalize() when done calling this.
pub fn next(&mut self, max_length: usize) -> Result<Option<Chunk>, ReadError> { let rs = match self.state { ChunksState::Readable(ref mut rs) => rs, ChunksState::Reset(error_code) => { return Err(ReadError::Reset(error_code)); } ChunksState::Finished => { return Ok(None); } ChunksState::Finalized => panic!("must not call next() after finalize()"), }; if let Some(chunk) = rs.assembler.read(max_length, self.ordered) { self.read += chunk.bytes.len() as u64; return Ok(Some(chunk)); } match rs.state { RecvState::ResetRecvd { error_code, .. } => { debug_assert_eq!(self.read, 0, "reset streams have empty buffers"); let state = mem::replace(&mut self.state, ChunksState::Reset(error_code)); // At this point if we have `rs` self.state must be `ChunksState::Readable` let recv = match state { ChunksState::Readable(recv) => StreamRecv::Open(recv), _ => unreachable!("state must be ChunkState::Readable"), }; self.streams.stream_recv_freed(self.id, recv); Err(ReadError::Reset(error_code)) } RecvState::Recv { size } => { if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end { let state = mem::replace(&mut self.state, ChunksState::Finished); // At this point if we have `rs` self.state must be `ChunksState::Readable` let recv = match state { ChunksState::Readable(recv) => StreamRecv::Open(recv), _ => unreachable!("state must be ChunkState::Readable"), }; self.streams.stream_recv_freed(self.id, recv); Ok(None) } else { // We don't need a distinct `ChunksState` variant for a blocked stream because // retrying a read harmlessly re-traces our steps back to returning // `Err(Blocked)` again. The buffers can't refill and the stream's own state // can't change so long as this `Chunks` exists. Err(ReadError::Blocked) } } } } /// Mark the read data as consumed from the stream. /// /// The number of read bytes will be released from the congestion window, /// allowing the remote peer to send more data if it was previously blocked. /// /// If [`ShouldTransmit::should_transmit()`] returns `true`, /// a packet needs to be sent to the peer informing them that the stream is unblocked. /// This means that you should call [`Connection::poll_transmit()`][crate::Connection::poll_transmit] /// and send the returned packet as soon as is reasonable, to unblock the remote peer. pub fn finalize(mut self) -> ShouldTransmit { self.finalize_inner() } fn finalize_inner(&mut self) -> ShouldTransmit { let state = mem::replace(&mut self.state, ChunksState::Finalized); if let ChunksState::Finalized = state { // Noop on repeated calls return ShouldTransmit(false); } // We issue additional stream ID credit after the application is notified that a previously // open stream has finished or been reset and we've therefore disposed of its state, as // recorded by `stream_freed` calls in `next`.
let mut should_transmit = self.streams.queue_max_stream_id(self.pending); // If the stream hasn't finished, we may need to issue stream-level flow control credit if let ChunksState::Readable(mut rs) = state { let (_, max_stream_data) = rs.max_stream_data(self.streams.stream_receive_window); should_transmit |= max_stream_data.0; if max_stream_data.0 { self.pending.max_stream_data.insert(self.id); } // Return the stream to storage for future use self.streams .recv .insert(self.id, Some(StreamRecv::Open(rs))); } // Issue connection-level flow control credit for any data we read regardless of state let max_data = self.streams.add_read_credits(self.read); self.pending.max_data |= max_data.0; should_transmit |= max_data.0; ShouldTransmit(should_transmit) } } impl Drop for Chunks<'_> { fn drop(&mut self) { let _ = self.finalize_inner(); } } enum ChunksState { Readable(Box<Recv>), Reset(VarInt), Finished, Finalized, } /// Errors triggered when reading from a recv stream #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum ReadError { /// No more data is currently available on this stream. /// /// If more data on this stream is received from the peer, an `Event::StreamReadable` will be /// generated for this stream, indicating that retrying the read might succeed. #[error("blocked")] Blocked, /// The peer abandoned transmitting data on this stream. /// /// Carries an application-defined error code. #[error("reset by peer: code {0}")] Reset(VarInt), } /// Errors triggered when opening a recv stream for reading #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum ReadableError { /// The stream has not been opened or was already stopped, finished, or reset #[error("closed stream")] ClosedStream, /// Attempted an ordered read following an unordered read /// /// Performing an unordered read allows discontinuities to arise in the receive buffer of a /// stream which cannot be recovered, making further ordered reads impossible.
#[error("ordered read after unordered read")] IllegalOrderedRead, } impl From for ReadableError { fn from(_: IllegalOrderedRead) -> Self { Self::IllegalOrderedRead } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum RecvState { Recv { size: Option }, ResetRecvd { size: u64, error_code: VarInt }, } impl Default for RecvState { fn default() -> Self { Self::Recv { size: None } } } #[cfg(test)] mod tests { use bytes::Bytes; use crate::{Dir, Side}; use super::*; #[test] fn reordered_frames_while_stopped() { const INITIAL_BYTES: u64 = 3; const INITIAL_OFFSET: u64 = 3; const RECV_WINDOW: u64 = 8; let mut s = Recv::new(RECV_WINDOW); let mut data_recvd = 0; // Receive bytes 3..6 let (new_bytes, is_closed) = s .ingest( frame::Stream { id: StreamId::new(Side::Client, Dir::Uni, 0), offset: INITIAL_OFFSET, fin: false, data: Bytes::from_static(&[0; INITIAL_BYTES as usize]), }, 123, data_recvd, data_recvd + 1024, ) .unwrap(); data_recvd += new_bytes; assert_eq!(new_bytes, INITIAL_OFFSET + INITIAL_BYTES); assert!(!is_closed); let (credits, transmit) = s.stop().unwrap(); assert!(transmit.should_transmit()); assert_eq!( credits, INITIAL_OFFSET + INITIAL_BYTES, "full connection flow control credit is issued by stop" ); let (max_stream_data, transmit) = s.max_stream_data(RECV_WINDOW); assert!(!transmit.should_transmit()); assert_eq!( max_stream_data, RECV_WINDOW, "stream flow control credit isn't issued by stop" ); // Receive byte 7 let (new_bytes, is_closed) = s .ingest( frame::Stream { id: StreamId::new(Side::Client, Dir::Uni, 0), offset: RECV_WINDOW - 1, fin: false, data: Bytes::from_static(&[0; 1]), }, 123, data_recvd, data_recvd + 1024, ) .unwrap(); data_recvd += new_bytes; assert_eq!(new_bytes, RECV_WINDOW - (INITIAL_OFFSET + INITIAL_BYTES)); assert!(!is_closed); let (max_stream_data, transmit) = s.max_stream_data(RECV_WINDOW); assert!(!transmit.should_transmit()); assert_eq!( max_stream_data, RECV_WINDOW, "stream flow control credit isn't issued after stop" ); // Receive bytes 0..3 let (new_bytes, is_closed) = s .ingest( frame::Stream { id: StreamId::new(Side::Client, Dir::Uni, 0), offset: 0, fin: false, data: Bytes::from_static(&[0; INITIAL_OFFSET as usize]), }, 123, data_recvd, data_recvd + 1024, ) .unwrap(); assert_eq!( new_bytes, 0, "reordered frames don't issue connection-level flow control for stopped streams" ); assert!(!is_closed); let (max_stream_data, transmit) = s.max_stream_data(RECV_WINDOW); assert!(!transmit.should_transmit()); assert_eq!( max_stream_data, RECV_WINDOW, "stream flow control credit isn't issued after stop" ); } } quinn-proto-0.11.9/src/connection/streams/send.rs000064400000000000000000000316261046102023000201220ustar 00000000000000use bytes::Bytes; use thiserror::Error; use crate::{connection::send_buffer::SendBuffer, frame, VarInt}; #[derive(Debug)] pub(super) struct Send { pub(super) max_data: u64, pub(super) state: SendState, pub(super) pending: SendBuffer, pub(super) priority: i32, /// Whether a frame containing a FIN bit must be transmitted, even if we don't have any new data pub(super) fin_pending: bool, /// Whether this stream is in the `connection_blocked` list of `Streams` pub(super) connection_blocked: bool, /// The reason the peer wants us to stop, if `STOP_SENDING` was received pub(super) stop_reason: Option, } impl Send { pub(super) fn new(max_data: VarInt) -> Box { Box::new(Self { max_data: max_data.into(), state: SendState::Ready, pending: SendBuffer::new(), priority: 0, fin_pending: false, connection_blocked: false, stop_reason: None, }) } /// Whether the 
stream has been reset pub(super) fn is_reset(&self) -> bool { matches!(self.state, SendState::ResetSent { .. }) } pub(super) fn finish(&mut self) -> Result<(), FinishError> { if let Some(error_code) = self.stop_reason { Err(FinishError::Stopped(error_code)) } else if self.state == SendState::Ready { self.state = SendState::DataSent { finish_acked: false, }; self.fin_pending = true; Ok(()) } else { Err(FinishError::ClosedStream) } } pub(super) fn write<S: BytesSource>( &mut self, source: &mut S, limit: u64, ) -> Result<Written, WriteError> { if !self.is_writable() { return Err(WriteError::ClosedStream); } if let Some(error_code) = self.stop_reason { return Err(WriteError::Stopped(error_code)); } let budget = self.max_data - self.pending.offset(); if budget == 0 { return Err(WriteError::Blocked); } let mut limit = limit.min(budget) as usize; let mut result = Written::default(); loop { let (chunk, chunks_consumed) = source.pop_chunk(limit); result.chunks += chunks_consumed; result.bytes += chunk.len(); if chunk.is_empty() { break; } limit -= chunk.len(); self.pending.write(chunk); } Ok(result) } /// Update stream state due to a reset sent by the local application pub(super) fn reset(&mut self) { use SendState::*; if let DataSent { .. } | Ready = self.state { self.state = ResetSent; } } /// Handle STOP_SENDING /// /// Returns true if the stream was stopped due to this frame, and false /// if it had been stopped before pub(super) fn try_stop(&mut self, error_code: VarInt) -> bool { if self.stop_reason.is_none() { self.stop_reason = Some(error_code); true } else { false } } /// Returns whether the stream has been finished and all data has been acknowledged by the peer pub(super) fn ack(&mut self, frame: frame::StreamMeta) -> bool { self.pending.ack(frame.offsets); match self.state { SendState::DataSent { ref mut finish_acked, } => { *finish_acked |= frame.fin; *finish_acked && self.pending.is_fully_acked() } _ => false, } } /// Handle increase to stream-level flow control limit /// /// Returns whether the stream was unblocked pub(super) fn increase_max_data(&mut self, offset: u64) -> bool { if offset <= self.max_data || self.state != SendState::Ready { return false; } let was_blocked = self.pending.offset() == self.max_data; self.max_data = offset; was_blocked } pub(super) fn offset(&self) -> u64 { self.pending.offset() } pub(super) fn is_pending(&self) -> bool { self.pending.has_unsent_data() || self.fin_pending } pub(super) fn is_writable(&self) -> bool { matches!(self.state, SendState::Ready) } } /// A [`BytesSource`] implementation for `&'a mut [Bytes]` /// /// The type allows dequeuing [`Bytes`] chunks from an array of chunks, up to /// a configured limit.
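/// A small usage sketch (mirroring the `bytes_array` test below):
/// ```ignore
/// let mut chunks = [Bytes::from_static(b"he"), Bytes::from_static(b"llo")];
/// let mut src = BytesArray::from_chunks(&mut chunks);
/// assert_eq!(src.pop_chunk(4), (Bytes::from_static(b"he"), 1)); // whole chunk consumed
/// assert_eq!(src.pop_chunk(2), (Bytes::from_static(b"ll"), 0)); // truncated, not consumed
/// ```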
pub(crate) struct BytesArray<'a> { /// The wrapped slice of `Bytes` chunks: &'a mut [Bytes], /// The amount of chunks consumed from this source consumed: usize, } impl<'a> BytesArray<'a> { pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self { Self { chunks, consumed: 0, } } } impl BytesSource for BytesArray<'_> { fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) { // The loop exists to skip empty chunks while still marking them as // consumed let mut chunks_consumed = 0; while self.consumed < self.chunks.len() { let chunk = &mut self.chunks[self.consumed]; if chunk.len() <= limit { let chunk = std::mem::take(chunk); self.consumed += 1; chunks_consumed += 1; if chunk.is_empty() { continue; } return (chunk, chunks_consumed); } else if limit > 0 { let chunk = chunk.split_to(limit); return (chunk, chunks_consumed); } else { break; } } (Bytes::new(), chunks_consumed) } } /// A [`BytesSource`] implementation for `&[u8]` /// /// The type allows dequeuing a single [`Bytes`] chunk, which will be lazily /// created from a reference. This allows deferring the allocation until it is /// known how much data needs to be copied. pub(crate) struct ByteSlice<'a> { /// The wrapped byte slice data: &'a [u8], } impl<'a> ByteSlice<'a> { pub(crate) fn from_slice(data: &'a [u8]) -> Self { Self { data } } } impl BytesSource for ByteSlice<'_> { fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) { let limit = limit.min(self.data.len()); if limit == 0 { return (Bytes::new(), 0); } let chunk = Bytes::from(self.data[..limit].to_owned()); self.data = &self.data[chunk.len()..]; let chunks_consumed = usize::from(self.data.is_empty()); (chunk, chunks_consumed) } } /// A source of one or more buffers which can be converted into `Bytes` buffers on demand /// /// The purpose of this data type is to defer conversion as long as possible, /// so that no heap allocation is required in case no data is writable. pub trait BytesSource { /// Returns the next chunk from the source of owned chunks. /// /// This method will consume parts of the source. /// Calling it will yield `Bytes` elements up to the configured `limit`. /// /// The method returns a tuple: /// - The first item is the yielded `Bytes` element. The element will be /// empty if the limit is zero or no more data is available. /// - The second item returns how many complete chunks inside the source /// had been consumed. This can be less than 1, if a chunk inside the /// source had been truncated in order to adhere to the limit. It can also /// be more than 1, if zero-length chunks had been skipped. fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize); } /// Indicates how many bytes and chunks had been transferred in a write operation #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct Written { /// The amount of bytes which had been written pub bytes: usize, /// The amount of full chunks which had been written /// /// If a chunk was only partially written, it will not be counted by this field. pub chunks: usize, } /// Errors triggered while writing to a send stream #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum WriteError { /// The peer is not able to accept additional data, or the connection is congested. /// /// If the peer issues additional flow control credit, a [`StreamEvent::Writable`] event will /// be generated, indicating that retrying the write might succeed.
/// /// [`StreamEvent::Writable`]: crate::StreamEvent::Writable #[error("unable to accept further writes")] Blocked, /// The peer is no longer accepting data on this stream, and it has been implicitly reset. The /// stream cannot be finished or further written to. /// /// Carries an application-defined error code. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished #[error("stopped by peer: code {0}")] Stopped(VarInt), /// The stream has not been opened or has already been finished or reset #[error("closed stream")] ClosedStream, } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(super) enum SendState { /// Sending new data Ready, /// Stream was finished; now sending retransmits only DataSent { finish_acked: bool }, /// Sent RESET ResetSent, } /// Reasons why attempting to finish a stream might fail #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum FinishError { /// The peer is no longer accepting data on this stream. No /// [`StreamEvent::Finished`] event will be emitted for this stream. /// /// Carries an application-defined error code. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished #[error("stopped by peer: code {0}")] Stopped(VarInt), /// The stream has not been opened or was already finished or reset #[error("closed stream")] ClosedStream, } #[cfg(test)] mod tests { use super::*; #[test] fn bytes_array() { let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned(); for limit in 0..full.len() { let mut chunks = [ Bytes::from_static(b""), Bytes::from_static(b"Hello "), Bytes::from_static(b"Wo"), Bytes::from_static(b""), Bytes::from_static(b"r"), Bytes::from_static(b"ld"), Bytes::from_static(b""), Bytes::from_static(b" 12345678"), Bytes::from_static(b"9 ABCDE"), Bytes::from_static(b"F"), Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"), ]; let num_chunks = chunks.len(); let last_chunk_len = chunks[chunks.len() - 1].len(); let mut array = BytesArray::from_chunks(&mut chunks); let mut buf = Vec::new(); let mut chunks_popped = 0; let mut chunks_consumed = 0; let mut remaining = limit; loop { let (chunk, consumed) = array.pop_chunk(remaining); chunks_consumed += consumed; if !chunk.is_empty() { buf.extend_from_slice(&chunk); remaining -= chunk.len(); chunks_popped += 1; } else { break; } } assert_eq!(&buf[..], &full[..limit]); if limit == full.len() { // Full consumption of the last chunk assert_eq!(chunks_consumed, num_chunks); // Since there are empty chunks, we consume more than there are popped assert_eq!(chunks_consumed, chunks_popped + 3); } else if limit > full.len() - last_chunk_len { // Partial consumption of the last chunk assert_eq!(chunks_consumed, num_chunks - 1); assert_eq!(chunks_consumed, chunks_popped + 2); } } } #[test] fn byte_slice() { let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned(); for limit in 0..full.len() { let mut array = ByteSlice::from_slice(&full[..]); let mut buf = Vec::new(); let mut chunks_popped = 0; let mut chunks_consumed = 0; let mut remaining = limit; loop { let (chunk, consumed) = array.pop_chunk(remaining); chunks_consumed += consumed; if !chunk.is_empty() { buf.extend_from_slice(&chunk); remaining -= chunk.len(); chunks_popped += 1; } else { break; } } assert_eq!(&buf[..], &full[..limit]); if limit != 0 { assert_eq!(chunks_popped, 1); } else { assert_eq!(chunks_popped, 0); } if limit == full.len() { assert_eq!(chunks_consumed, 1); } else { assert_eq!(chunks_consumed, 0); } } } } 
quinn-proto-0.11.9/src/connection/streams/state.rs

use std::{
    collections::{hash_map, VecDeque},
    convert::TryFrom,
    mem,
};

use bytes::BufMut;
use rustc_hash::FxHashMap;
use tracing::{debug, trace};

use super::{
    PendingStreamsQueue, Recv, Retransmits, Send, SendState, ShouldTransmit, StreamEvent,
    StreamHalf, ThinRetransmits,
};
use crate::{
    coding::BufMutExt,
    connection::stats::FrameStats,
    frame::{self, FrameStruct, StreamMetaVec},
    transport_parameters::TransportParameters,
    Dir, Side, StreamId, TransportError, VarInt, MAX_STREAM_COUNT,
};

/// Wrapper around `Recv` that facilitates reusing `Recv` instances
#[derive(Debug)]
pub(super) enum StreamRecv {
    /// A `Recv` that is ready to be opened
    Free(Box<Recv>),
    /// A `Recv` that has been opened
    Open(Box<Recv>),
}

impl StreamRecv {
    /// Returns a reference to the inner `Recv` if the stream is open
    pub(super) fn as_open_recv(&self) -> Option<&Recv> {
        match self {
            Self::Open(r) => Some(r),
            _ => None,
        }
    }

    // Returns a mutable reference to the inner `Recv` if the stream is open
    pub(super) fn as_open_recv_mut(&mut self) -> Option<&mut Recv> {
        match self {
            Self::Open(r) => Some(r),
            _ => None,
        }
    }

    // Returns the inner `Recv`
    pub(super) fn into_inner(self) -> Box<Recv> {
        match self {
            Self::Free(r) | Self::Open(r) => r,
        }
    }

    // Reinitialize the stream so the inner `Recv` can be reused
    pub(super) fn free(self, initial_max_data: u64) -> Self {
        match self {
            Self::Free(_) => unreachable!("Self::Free on reinit()"),
            Self::Open(mut recv) => {
                recv.reinit(initial_max_data);
                Self::Free(recv)
            }
        }
    }
}

#[allow(unreachable_pub)] // fuzzing only
pub struct StreamsState {
    pub(super) side: Side,
    // Set of streams that are currently open, or could be immediately opened by the peer
    pub(super) send: FxHashMap<StreamId, Option<Box<Send>>>,
    pub(super) recv: FxHashMap<StreamId, Option<StreamRecv>>,
    pub(super) free_recv: Vec<StreamRecv>,
    pub(super) next: [u64; 2],
    /// Maximum number of locally-initiated streams that may be opened over the lifetime of the
    /// connection so far, per direction
    pub(super) max: [u64; 2],
    /// Maximum number of remotely-initiated streams that may be opened over the lifetime of the
    /// connection so far, per direction
    pub(super) max_remote: [u64; 2],
    /// Value of `max_remote` most recently transmitted to the peer in a `MAX_STREAMS` frame
    sent_max_remote: [u64; 2],
    /// Number of streams that we've given the peer permission to open and which aren't fully closed
    pub(super) allocated_remote_count: [u64; 2],
    /// Size of the desired stream flow control window. May be smaller than
    /// `allocated_remote_count` due to `set_max_concurrent` calls.
    max_concurrent_remote_count: [u64; 2],
    /// Whether `max_concurrent_remote_count` has ever changed
    flow_control_adjusted: bool,
    /// Lowest remotely-initiated stream index that hasn't actually been opened by the peer
    pub(super) next_remote: [u64; 2],
    /// Whether the remote endpoint has opened any streams the application doesn't know about yet,
    /// per directionality
    opened: [bool; 2],
    // Next to report to the application, once opened
    pub(super) next_reported_remote: [u64; 2],
    /// Number of outbound streams
    ///
    /// This differs from `self.send.len()` in that it does not include streams that the peer is
    /// permitted to open but which have not yet been opened.
    pub(super) send_streams: usize,
    /// Streams with outgoing data queued, sorted by priority
    pub(super) pending: PendingStreamsQueue,

    events: VecDeque<StreamEvent>,
    /// Streams blocked on connection-level flow control or stream window space
    ///
    /// Streams are only added to this list when a write fails.
    pub(super) connection_blocked: Vec<StreamId>,
    /// Connection-level flow control budget dictated by the peer
    pub(super) max_data: u64,
    /// The initial receive window
    receive_window: u64,
    /// Limit on incoming data, which is transmitted through `MAX_DATA` frames
    local_max_data: u64,
    /// The last value of `MAX_DATA` which had been queued for transmission in
    /// an outgoing `MAX_DATA` frame
    sent_max_data: VarInt,
    /// Sum of current offsets of all send streams.
    pub(super) data_sent: u64,
    /// Sum of end offsets of all receive streams. Includes gaps, so it's an upper bound.
    data_recvd: u64,
    /// Total quantity of unacknowledged outgoing data
    pub(super) unacked_data: u64,
    /// Configured upper bound for `unacked_data`
    pub(super) send_window: u64,
    /// Configured upper bound for how much unacked data the peer can send us per stream
    pub(super) stream_receive_window: u64,
    // Pertinent state from the TransportParameters supplied by the peer
    initial_max_stream_data_uni: VarInt,
    initial_max_stream_data_bidi_local: VarInt,
    initial_max_stream_data_bidi_remote: VarInt,
    /// The shrink to be applied to `local_max_data` when `receive_window` is shrunk
    receive_window_shrink_debt: u64,
}

impl StreamsState {
    #[allow(unreachable_pub)] // fuzzing only
    pub fn new(
        side: Side,
        max_remote_uni: VarInt,
        max_remote_bi: VarInt,
        send_window: u64,
        receive_window: VarInt,
        stream_receive_window: VarInt,
    ) -> Self {
        let mut this = Self {
            side,
            send: FxHashMap::default(),
            recv: FxHashMap::default(),
            free_recv: Vec::new(),
            next: [0, 0],
            max: [0, 0],
            max_remote: [max_remote_bi.into(), max_remote_uni.into()],
            sent_max_remote: [max_remote_bi.into(), max_remote_uni.into()],
            allocated_remote_count: [max_remote_bi.into(), max_remote_uni.into()],
            max_concurrent_remote_count: [max_remote_bi.into(), max_remote_uni.into()],
            flow_control_adjusted: false,
            next_remote: [0, 0],
            opened: [false, false],
            next_reported_remote: [0, 0],
            send_streams: 0,
            pending: PendingStreamsQueue::new(),
            events: VecDeque::new(),
            connection_blocked: Vec::new(),
            max_data: 0,
            receive_window: receive_window.into(),
            local_max_data: receive_window.into(),
            sent_max_data: receive_window,
            data_sent: 0,
            data_recvd: 0,
            unacked_data: 0,
            send_window,
            stream_receive_window: stream_receive_window.into(),
            initial_max_stream_data_uni: 0u32.into(),
            initial_max_stream_data_bidi_local: 0u32.into(),
            initial_max_stream_data_bidi_remote: 0u32.into(),
            receive_window_shrink_debt: 0,
        };

        for dir in Dir::iter() {
            for i in 0..this.max_remote[dir as usize] {
                this.insert(true, StreamId::new(!side, dir, i));
            }
        }

        this
    }

    pub(crate) fn set_params(&mut self, params: &TransportParameters) {
        self.initial_max_stream_data_uni = params.initial_max_stream_data_uni;
        self.initial_max_stream_data_bidi_local = params.initial_max_stream_data_bidi_local;
        self.initial_max_stream_data_bidi_remote = params.initial_max_stream_data_bidi_remote;
        self.max[Dir::Bi as usize] = params.initial_max_streams_bidi.into();
        self.max[Dir::Uni as usize] = params.initial_max_streams_uni.into();
        self.received_max_data(params.initial_max_data);
        for i in 0..self.max_remote[Dir::Bi as usize] {
            let id = StreamId::new(!self.side, Dir::Bi, i);
            if let Some(s) = self.send.get_mut(&id).and_then(|s| s.as_mut()) {
                s.max_data =
params.initial_max_stream_data_bidi_local.into(); } } } /// Ensure we have space for at least a full flow control window of remotely-initiated streams /// to be open, and notify the peer if the window has moved fn ensure_remote_streams(&mut self, dir: Dir) { let new_count = self.max_concurrent_remote_count[dir as usize] .saturating_sub(self.allocated_remote_count[dir as usize]); for i in 0..new_count { let id = StreamId::new(!self.side, dir, self.max_remote[dir as usize] + i); self.insert(true, id); } self.allocated_remote_count[dir as usize] += new_count; self.max_remote[dir as usize] += new_count; } pub(crate) fn zero_rtt_rejected(&mut self) { // Revert to initial state for outgoing streams for dir in Dir::iter() { for i in 0..self.next[dir as usize] { // We don't bother calling `stream_freed` here because we explicitly reset affected // counters below. let id = StreamId::new(self.side, dir, i); self.send.remove(&id).unwrap(); if let Dir::Bi = dir { self.recv.remove(&id).unwrap(); } } self.next[dir as usize] = 0; // If 0-RTT was rejected, any flow control frames we sent were lost. if self.flow_control_adjusted { // Conservative approximation of whatever we sent in transport parameters self.sent_max_remote[dir as usize] = 0; } } self.pending.clear(); self.send_streams = 0; self.data_sent = 0; self.connection_blocked.clear(); } /// Process incoming stream frame /// /// If successful, returns whether a `MAX_DATA` frame needs to be transmitted pub(crate) fn received( &mut self, frame: frame::Stream, payload_len: usize, ) -> Result { let id = frame.id; self.validate_receive_id(id).map_err(|e| { debug!("received illegal STREAM frame"); e })?; let rs = match self .recv .get_mut(&id) .map(get_or_insert_recv(self.stream_receive_window)) { Some(rs) => rs, None => { trace!("dropping frame for closed stream"); return Ok(ShouldTransmit(false)); } }; if !rs.is_receiving() { trace!("dropping frame for finished stream"); return Ok(ShouldTransmit(false)); } let (new_bytes, closed) = rs.ingest(frame, payload_len, self.data_recvd, self.local_max_data)?; self.data_recvd = self.data_recvd.saturating_add(new_bytes); if !rs.stopped { self.on_stream_frame(true, id); return Ok(ShouldTransmit(false)); } // Stopped streams become closed instantly on FIN, so check whether we need to clean up if closed { let rs = self.recv.remove(&id).flatten().unwrap(); self.stream_recv_freed(id, rs); } // We don't buffer data on stopped streams, so issue flow control credit immediately Ok(self.add_read_credits(new_bytes)) } /// Process incoming RESET_STREAM frame /// /// If successful, returns whether a `MAX_DATA` frame needs to be transmitted #[allow(unreachable_pub)] // fuzzing only pub fn received_reset( &mut self, frame: frame::ResetStream, ) -> Result { let frame::ResetStream { id, error_code, final_offset, } = frame; self.validate_receive_id(id).map_err(|e| { debug!("received illegal RESET_STREAM frame"); e })?; let rs = match self .recv .get_mut(&id) .map(get_or_insert_recv(self.stream_receive_window)) { Some(stream) => stream, None => { trace!("received RESET_STREAM on closed stream"); return Ok(ShouldTransmit(false)); } }; // State transition if !rs.reset( error_code, final_offset, self.data_recvd, self.local_max_data, )? 
{ // Redundant reset return Ok(ShouldTransmit(false)); } let bytes_read = rs.assembler.bytes_read(); let stopped = rs.stopped; let end = rs.end; if stopped { // Stopped streams should be disposed immediately on reset let rs = self.recv.remove(&id).flatten().unwrap(); self.stream_recv_freed(id, rs); } self.on_stream_frame(!stopped, id); // Update connection-level flow control Ok(if bytes_read != final_offset.into_inner() { // bytes_read is always <= end, so this won't underflow. self.data_recvd = self .data_recvd .saturating_add(u64::from(final_offset) - end); self.add_read_credits(u64::from(final_offset) - bytes_read) } else { ShouldTransmit(false) }) } /// Process incoming `STOP_SENDING` frame #[allow(unreachable_pub)] // fuzzing only pub fn received_stop_sending(&mut self, id: StreamId, error_code: VarInt) { let max_send_data = self.max_send_data(id); let stream = match self .send .get_mut(&id) .map(get_or_insert_send(max_send_data)) { Some(ss) => ss, None => return, }; if stream.try_stop(error_code) { self.events .push_back(StreamEvent::Stopped { id, error_code }); self.on_stream_frame(false, id); } } pub(crate) fn reset_acked(&mut self, id: StreamId) { match self.send.entry(id) { hash_map::Entry::Vacant(_) => {} hash_map::Entry::Occupied(e) => { if let Some(SendState::ResetSent) = e.get().as_ref().map(|s| s.state) { e.remove_entry(); self.stream_freed(id, StreamHalf::Send); } } } } /// Whether any stream data is queued, regardless of control frames pub(crate) fn can_send_stream_data(&self) -> bool { // Reset streams may linger in the pending stream list, but will never produce stream frames self.pending.iter().any(|stream| { self.send .get(&stream.id) .and_then(|s| s.as_ref()) .map_or(false, |s| !s.is_reset()) }) } /// Whether MAX_STREAM_DATA frames could be sent for stream `id` pub(crate) fn can_send_flow_control(&self, id: StreamId) -> bool { self.recv .get(&id) .and_then(|s| s.as_ref()) .and_then(|s| s.as_open_recv()) .map_or(false, |s| s.can_send_flow_control()) } pub(in crate::connection) fn write_control_frames( &mut self, buf: &mut Vec, pending: &mut Retransmits, retransmits: &mut ThinRetransmits, stats: &mut FrameStats, max_size: usize, ) { // RESET_STREAM while buf.len() + frame::ResetStream::SIZE_BOUND < max_size { let (id, error_code) = match pending.reset_stream.pop() { Some(x) => x, None => break, }; let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; trace!(stream = %id, "RESET_STREAM"); retransmits .get_or_create() .reset_stream .push((id, error_code)); frame::ResetStream { id, error_code, final_offset: VarInt::try_from(stream.offset()).expect("impossibly large offset"), } .encode(buf); stats.reset_stream += 1; } // STOP_SENDING while buf.len() + frame::StopSending::SIZE_BOUND < max_size { let frame = match pending.stop_sending.pop() { Some(x) => x, None => break, }; // We may need to transmit STOP_SENDING even for streams whose state we have discarded, // because we are able to discard local state for stopped streams immediately upon // receiving FIN, even if the peer still has arbitrarily large amounts of data to // (re)transmit due to loss or unconventional sending strategy. We could fine-tune this // a little by dropping the frame if we specifically know the stream's been reset by the // peer, but we discard that information as soon as the application consumes it, so it // can't be relied upon regardless. 
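            // (A STOP_SENDING frame is tiny: a frame type plus two varints,
            // the stream ID and the application error code, per RFC 9000
            // §19.5. Retransmitting it unconditionally is therefore cheap.)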
trace!(stream = %frame.id, "STOP_SENDING"); frame.encode(buf); retransmits.get_or_create().stop_sending.push(frame); stats.stop_sending += 1; } // MAX_DATA if pending.max_data && buf.len() + 9 < max_size { pending.max_data = false; // `local_max_data` can grow bigger than `VarInt`. // For transmission inside QUIC frames we need to clamp it to the // maximum allowed `VarInt` size. let max = VarInt::try_from(self.local_max_data).unwrap_or(VarInt::MAX); trace!(value = max.into_inner(), "MAX_DATA"); if max > self.sent_max_data { // Record that a `MAX_DATA` announcing a certain window was sent. This will // suppress enqueuing further `MAX_DATA` frames unless either the previous // transmission was not acknowledged or the window further increased. self.sent_max_data = max; } retransmits.get_or_create().max_data = true; buf.write(frame::FrameType::MAX_DATA); buf.write(max); stats.max_data += 1; } // MAX_STREAM_DATA while buf.len() + 17 < max_size { let id = match pending.max_stream_data.iter().next() { Some(x) => *x, None => break, }; pending.max_stream_data.remove(&id); let rs = match self .recv .get_mut(&id) .and_then(|s| s.as_mut()) .and_then(|s| s.as_open_recv_mut()) { Some(x) => x, None => continue, }; if !rs.can_send_flow_control() { continue; } retransmits.get_or_create().max_stream_data.insert(id); let (max, _) = rs.max_stream_data(self.stream_receive_window); rs.record_sent_max_stream_data(max); trace!(stream = %id, max = max, "MAX_STREAM_DATA"); buf.write(frame::FrameType::MAX_STREAM_DATA); buf.write(id); buf.write_var(max); stats.max_stream_data += 1; } // MAX_STREAMS for dir in Dir::iter() { if !pending.max_stream_id[dir as usize] || buf.len() + 9 >= max_size { continue; } pending.max_stream_id[dir as usize] = false; retransmits.get_or_create().max_stream_id[dir as usize] = true; self.sent_max_remote[dir as usize] = self.max_remote[dir as usize]; trace!( value = self.max_remote[dir as usize], "MAX_STREAMS ({:?})", dir ); buf.write(match dir { Dir::Uni => frame::FrameType::MAX_STREAMS_UNI, Dir::Bi => frame::FrameType::MAX_STREAMS_BIDI, }); buf.write_var(self.max_remote[dir as usize]); match dir { Dir::Uni => stats.max_streams_uni += 1, Dir::Bi => stats.max_streams_bidi += 1, } } } pub(crate) fn write_stream_frames( &mut self, buf: &mut Vec, max_buf_size: usize, fair: bool, ) -> StreamMetaVec { let mut stream_frames = StreamMetaVec::new(); while buf.len() + frame::Stream::SIZE_BOUND < max_buf_size { if max_buf_size .checked_sub(buf.len() + frame::Stream::SIZE_BOUND) .is_none() { break; } // Pop the stream of the highest priority that currently has pending data // If the stream still has some pending data left after writing, it will be reinserted, otherwise not let Some(stream) = self.pending.pop() else { break; }; let id = stream.id; let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(s) => s, // Stream was reset with pending data and the reset was acknowledged None => continue, }; // Reset streams aren't removed from the pending list and still exist while the peer // hasn't acknowledged the reset, but should not generate STREAM frames, so we need to // check for them explicitly. if stream.is_reset() { continue; } // Now that we know the `StreamId`, we can better account for how many bytes // are required to encode it. let max_buf_size = max_buf_size - buf.len() - 1 - VarInt::size(id.into()); let (offsets, encode_length) = stream.pending.poll_transmit(max_buf_size); let fin = offsets.end == stream.pending.offset() && matches!(stream.state, SendState::DataSent { .. 
}); if fin { stream.fin_pending = false; } if stream.is_pending() { // If the stream still has pending data, reinsert it, possibly with an updated priority value // Fairness with other streams is achieved by implementing round-robin scheduling, // so that the other streams will have a chance to write data // before we touch this stream again. if fair { self.pending.push_pending(id, stream.priority); } else { self.pending.reinsert_pending(id, stream.priority); } } let meta = frame::StreamMeta { id, offsets, fin }; trace!(id = %meta.id, off = meta.offsets.start, len = meta.offsets.end - meta.offsets.start, fin = meta.fin, "STREAM"); meta.encode(encode_length, buf); // The range might not be retrievable in a single `get` if it is // stored in noncontiguous fashion. Therefore this loop iterates // until the range is fully copied into the frame. let mut offsets = meta.offsets.clone(); while offsets.start != offsets.end { let data = stream.pending.get(offsets.clone()); offsets.start += data.len() as u64; buf.put_slice(data); } stream_frames.push(meta); } stream_frames } /// Notify the application that new streams were opened or a stream became readable. fn on_stream_frame(&mut self, notify_readable: bool, stream: StreamId) { if stream.initiator() == self.side { // Notifying about the opening of locally-initiated streams would be redundant. if notify_readable { self.events.push_back(StreamEvent::Readable { id: stream }); } return; } let next = &mut self.next_remote[stream.dir() as usize]; if stream.index() >= *next { *next = stream.index() + 1; self.opened[stream.dir() as usize] = true; } else if notify_readable { self.events.push_back(StreamEvent::Readable { id: stream }); } } pub(crate) fn received_ack_of(&mut self, frame: frame::StreamMeta) { let mut entry = match self.send.entry(frame.id) { hash_map::Entry::Vacant(_) => return, hash_map::Entry::Occupied(e) => e, }; let stream = match entry.get_mut().as_mut() { Some(s) => s, None => { // Because we only call this after sending data on this stream, // this closure should be unreachable. If we did somehow screw that up, // then we might hit an underflow below with unpredictable effects down // the line. Best to short-circuit. 
return; } }; if stream.is_reset() { // We account for outstanding data on reset streams at time of reset return; } let id = frame.id; self.unacked_data -= frame.offsets.end - frame.offsets.start; if !stream.ack(frame) { // The stream is unfinished or may still need retransmits return; } entry.remove_entry(); self.stream_freed(id, StreamHalf::Send); self.events.push_back(StreamEvent::Finished { id }); } pub(crate) fn retransmit(&mut self, frame: frame::StreamMeta) { let stream = match self.send.get_mut(&frame.id).and_then(|s| s.as_mut()) { // Loss of data on a closed stream is a noop None => return, Some(x) => x, }; if !stream.is_pending() { self.pending.push_pending(frame.id, stream.priority); } stream.fin_pending |= frame.fin; stream.pending.retransmit(frame.offsets); } pub(crate) fn retransmit_all_for_0rtt(&mut self) { for dir in Dir::iter() { for index in 0..self.next[dir as usize] { let id = StreamId::new(Side::Client, dir, index); let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(stream) => stream, None => continue, }; if stream.pending.is_fully_acked() && !stream.fin_pending { // Stream data can't be acked in 0-RTT, so we must not have sent anything on // this stream continue; } if !stream.is_pending() { self.pending.push_pending(id, stream.priority); } stream.pending.retransmit_all_for_0rtt(); } } } pub(crate) fn received_max_streams( &mut self, dir: Dir, count: u64, ) -> Result<(), TransportError> { if count > MAX_STREAM_COUNT { return Err(TransportError::FRAME_ENCODING_ERROR( "unrepresentable stream limit", )); } let current = &mut self.max[dir as usize]; if count > *current { *current = count; self.events.push_back(StreamEvent::Available { dir }); } Ok(()) } /// Handle increase to connection-level flow control limit pub(crate) fn received_max_data(&mut self, n: VarInt) { self.max_data = self.max_data.max(n.into()); } pub(crate) fn received_max_stream_data( &mut self, id: StreamId, offset: u64, ) -> Result<(), TransportError> { if id.initiator() != self.side && id.dir() == Dir::Uni { debug!("got MAX_STREAM_DATA on recv-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "MAX_STREAM_DATA on recv-only stream", )); } let write_limit = self.write_limit(); let max_send_data = self.max_send_data(id); if let Some(ss) = self .send .get_mut(&id) .map(get_or_insert_send(max_send_data)) { if ss.increase_max_data(offset) { if write_limit > 0 { self.events.push_back(StreamEvent::Writable { id }); } else if !ss.connection_blocked { // The stream is still blocked on the connection flow control // window. In order to get unblocked when the window relaxes // it needs to be in the connection blocked list. 
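                    // (Typical lifecycle: a write fails with `WriteError::Blocked`,
                    // the stream is parked in `connection_blocked`, and a later
                    // MAX_DATA increase lets `poll()` pop it back out and emit
                    // `StreamEvent::Writable`.)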
ss.connection_blocked = true; self.connection_blocked.push(id); } } } else if id.initiator() == self.side && self.is_local_unopened(id) { debug!("got MAX_STREAM_DATA on unopened {}", id); return Err(TransportError::STREAM_STATE_ERROR( "MAX_STREAM_DATA on unopened stream", )); } self.on_stream_frame(false, id); Ok(()) } /// Returns the maximum amount of data this is allowed to be written on the connection pub(crate) fn write_limit(&self) -> u64 { (self.max_data - self.data_sent).min(self.send_window - self.unacked_data) } /// Yield stream events pub(crate) fn poll(&mut self) -> Option { if let Some(dir) = Dir::iter().find(|&i| mem::replace(&mut self.opened[i as usize], false)) { return Some(StreamEvent::Opened { dir }); } if self.write_limit() > 0 { while let Some(id) = self.connection_blocked.pop() { let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { None => continue, Some(s) => s, }; debug_assert!(stream.connection_blocked); stream.connection_blocked = false; // If it's no longer sensible to write to a stream (even to detect an error) then don't // report it. if stream.is_writable() && stream.max_data > stream.offset() { return Some(StreamEvent::Writable { id }); } } } self.events.pop_front() } /// Queues MAX_STREAM_ID frames in `pending` if needed /// /// Returns whether any frames were queued. pub(crate) fn queue_max_stream_id(&mut self, pending: &mut Retransmits) -> bool { let mut queued = false; for dir in Dir::iter() { let diff = self.max_remote[dir as usize] - self.sent_max_remote[dir as usize]; // To reduce traffic, only announce updates if at least 1/8 of the flow control window // has been consumed. if diff > self.max_concurrent_remote_count[dir as usize] / 8 { pending.max_stream_id[dir as usize] = true; queued = true; } } queued } /// Check for errors entailed by the peer's use of `id` as a send stream fn validate_receive_id(&mut self, id: StreamId) -> Result<(), TransportError> { if self.side == id.initiator() { match id.dir() { Dir::Uni => { return Err(TransportError::STREAM_STATE_ERROR( "illegal operation on send-only stream", )); } Dir::Bi if id.index() >= self.next[Dir::Bi as usize] => { return Err(TransportError::STREAM_STATE_ERROR( "operation on unopened stream", )); } Dir::Bi => {} }; } else { let limit = self.max_remote[id.dir() as usize]; if id.index() >= limit { return Err(TransportError::STREAM_LIMIT_ERROR("")); } } Ok(()) } /// Whether a locally initiated stream has never been open pub(crate) fn is_local_unopened(&self, id: StreamId) -> bool { id.index() >= self.next[id.dir() as usize] } pub(crate) fn set_max_concurrent(&mut self, dir: Dir, count: VarInt) { self.flow_control_adjusted = true; self.max_concurrent_remote_count[dir as usize] = count.into(); self.ensure_remote_streams(dir); } pub(crate) fn max_concurrent(&self, dir: Dir) -> u64 { self.allocated_remote_count[dir as usize] } /// Set the receive_window and returns whether the receive_window has been /// expanded or shrunk: true if expanded, false if shrunk. 
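    ///
    /// An illustrative sketch of the intended semantics (`state` and `WINDOW`
    /// are hypothetical; shrinking takes effect only as future read credits
    /// pay off the recorded debt):
    ///
    /// ```ignore
    /// // Growing the window immediately extends `local_max_data`...
    /// assert!(state.set_receive_window(VarInt::from_u32(2 * WINDOW)));
    /// // ...while shrinking merely records a shrink debt and changes nothing
    /// // until enough data has been read.
    /// assert!(!state.set_receive_window(VarInt::from_u32(WINDOW / 2)));
    /// ```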
pub(crate) fn set_receive_window(&mut self, receive_window: VarInt) -> bool { let receive_window = receive_window.into(); let mut expanded = false; if receive_window > self.receive_window { self.local_max_data = self .local_max_data .saturating_add(receive_window - self.receive_window); expanded = true; } else { let diff = self.receive_window - receive_window; self.receive_window_shrink_debt = self.receive_window_shrink_debt.saturating_add(diff); } self.receive_window = receive_window; expanded } pub(super) fn insert(&mut self, remote: bool, id: StreamId) { let bi = id.dir() == Dir::Bi; // bidirectional OR (unidirectional AND NOT remote) if bi || !remote { assert!(self.send.insert(id, None).is_none()); } // bidirectional OR (unidirectional AND remote) if bi || remote { let recv = self.free_recv.pop(); assert!(self.recv.insert(id, recv).is_none()); } } /// Adds credits to the connection flow control window /// /// Returns whether a `MAX_DATA` frame should be enqueued as soon as possible. /// This will only be the case if the window update would is significant /// enough. As soon as a window update with a `MAX_DATA` frame has been /// queued, the [`Recv::record_sent_max_stream_data`] function should be called to /// suppress sending further updates until the window increases significantly /// again. pub(super) fn add_read_credits(&mut self, credits: u64) -> ShouldTransmit { if credits > self.receive_window_shrink_debt { let net_credits = credits - self.receive_window_shrink_debt; self.local_max_data = self.local_max_data.saturating_add(net_credits); self.receive_window_shrink_debt = 0; } else { self.receive_window_shrink_debt -= credits; } if self.local_max_data > VarInt::MAX.into_inner() { return ShouldTransmit(false); } // Only announce a window update if it's significant enough // to make it worthwhile sending a MAX_DATA frame. // We use a fraction of the configured connection receive window to make // the decision, to accommodate for connection using bigger windows requiring // less updates. let diff = self.local_max_data - self.sent_max_data.into_inner(); ShouldTransmit(diff >= (self.receive_window / 8)) } /// Update counters for removal of a stream pub(super) fn stream_freed(&mut self, id: StreamId, half: StreamHalf) { if id.initiator() != self.side { let fully_free = id.dir() == Dir::Uni || match half { StreamHalf::Send => !self.recv.contains_key(&id), StreamHalf::Recv => !self.send.contains_key(&id), }; if fully_free { self.allocated_remote_count[id.dir() as usize] -= 1; self.ensure_remote_streams(id.dir()); } } if half == StreamHalf::Send { self.send_streams -= 1; } } pub(super) fn stream_recv_freed(&mut self, id: StreamId, recv: StreamRecv) { self.free_recv.push(recv.free(self.stream_receive_window)); self.stream_freed(id, StreamHalf::Recv); } pub(super) fn max_send_data(&self, id: StreamId) -> VarInt { let remote = self.side != id.initiator(); match id.dir() { Dir::Uni => self.initial_max_stream_data_uni, // Remote/local appear reversed here because the transport parameters are named from // the perspective of the peer. 
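            // For example, if we are the server, data we send on a
            // client-initiated bidirectional stream is limited by the client's
            // `initial_max_stream_data_bidi_local`: that stream is "local" from
            // the client's perspective, even though the send half is ours.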
Dir::Bi if remote => self.initial_max_stream_data_bidi_local, Dir::Bi => self.initial_max_stream_data_bidi_remote, } } } #[inline] pub(super) fn get_or_insert_send( max_data: VarInt, ) -> impl Fn(&mut Option>) -> &mut Box { move |opt| opt.get_or_insert_with(|| Send::new(max_data)) } #[inline] pub(super) fn get_or_insert_recv( initial_max_data: u64, ) -> impl FnMut(&mut Option) -> &mut Recv { move |opt| { *opt = opt.take().map(|s| match s { StreamRecv::Free(recv) => StreamRecv::Open(recv), s => s, }); opt.get_or_insert_with(|| StreamRecv::Open(Recv::new(initial_max_data))) .as_open_recv_mut() .unwrap() } } #[cfg(test)] mod tests { use super::*; use crate::{ connection::State as ConnState, connection::Streams, ReadableError, RecvStream, SendStream, TransportErrorCode, WriteError, }; use bytes::Bytes; fn make(side: Side) -> StreamsState { StreamsState::new( side, 128u32.into(), 128u32.into(), 1024 * 1024, (1024 * 1024u32).into(), (1024 * 1024u32).into(), ) } #[test] fn trivial_flow_control() { let mut client = StreamsState::new( Side::Client, 1u32.into(), 1u32.into(), 1024 * 1024, (1024 * 1024u32).into(), (1024 * 1024u32).into(), ); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; const MESSAGE_SIZE: usize = 2048; assert_eq!( client .received( frame::Stream { id, offset: 0, fin: true, data: Bytes::from_static(&[0; MESSAGE_SIZE]), }, 2048 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 2048); assert_eq!(client.local_max_data - initial_max, 0); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; let mut chunks = recv.read(true).unwrap(); assert_eq!( chunks.next(MESSAGE_SIZE).unwrap().unwrap().bytes.len(), MESSAGE_SIZE ); assert!(chunks.next(0).unwrap().is_none()); let should_transmit = chunks.finalize(); assert!(should_transmit.0); assert!(pending.max_stream_id[Dir::Uni as usize]); assert_eq!(client.local_max_data - initial_max, MESSAGE_SIZE as u64); } #[test] fn reset_flow_control() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; assert_eq!( client .received( frame::Stream { id, offset: 0, fin: false, data: Bytes::from_static(&[0; 2048]), }, 2048 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 2048); assert_eq!(client.local_max_data - initial_max, 0); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; let mut chunks = recv.read(true).unwrap(); chunks.next(1024).unwrap(); let _ = chunks.finalize(); assert_eq!(client.local_max_data - initial_max, 1024); assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); assert_eq!(client.local_max_data - initial_max, 4096); // Ensure reading after a reset doesn't issue redundant credit let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; let mut chunks = recv.read(true).unwrap(); assert_eq!( chunks.next(1024).unwrap_err(), crate::ReadError::Reset(0u32.into()) ); let _ = chunks.finalize(); assert_eq!(client.data_recvd, 4096); assert_eq!(client.local_max_data - initial_max, 4096); } #[test] fn reset_after_empty_frame_flow_control() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; assert_eq!( client .received( frame::Stream { 
id, offset: 4096, fin: false, data: Bytes::from_static(&[0; 0]), }, 0 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); assert_eq!(client.local_max_data - initial_max, 0); assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); assert_eq!(client.local_max_data - initial_max, 4096); } #[test] fn duplicate_reset_flow_control() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); } #[test] fn recv_stopped() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; assert_eq!( client .received( frame::Stream { id, offset: 0, fin: false, data: Bytes::from_static(&[0; 32]), }, 32 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.local_max_data, initial_max); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; recv.stop(0u32.into()).unwrap(); assert_eq!(recv.pending.stop_sending.len(), 1); assert!(!recv.pending.max_data); assert!(recv.stop(0u32.into()).is_err()); assert_eq!(recv.read(true).err(), Some(ReadableError::ClosedStream)); assert_eq!(recv.read(false).err(), Some(ReadableError::ClosedStream)); assert_eq!(client.local_max_data - initial_max, 32); assert_eq!( client .received( frame::Stream { id, offset: 32, fin: true, data: Bytes::from_static(&[0; 16]), }, 16 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.local_max_data - initial_max, 48); assert!(!client.recv.contains_key(&id)); } #[test] fn stopped_reset() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); // Server opens stream assert_eq!( client .received( frame::Stream { id, offset: 0, fin: false, data: Bytes::from_static(&[0; 32]) }, 32 ) .unwrap(), ShouldTransmit(false) ); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; recv.stop(0u32.into()).unwrap(); assert_eq!(pending.stop_sending.len(), 1); assert!(!pending.max_data); // Server complies let prev_max = client.max_remote[Dir::Uni as usize]; assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 32u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert!(!client.recv.contains_key(&id), "stream state is freed"); assert_eq!(client.max_remote[Dir::Uni as usize], prev_max + 1); } #[test] fn send_stopped() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_uni: 1u32.into(), initial_max_data: 42u32.into(), initial_max_stream_data_uni: 42u32.into(), ..TransportParameters::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let id = Streams { state: &mut server, conn_state: &state, } .open(Dir::Uni) .unwrap(); let mut stream = SendStream { id, state: &mut server, pending: &mut pending, conn_state: &state, }; let error_code = 0u32.into(); stream.state.received_stop_sending(id, error_code); assert!(stream .state .events 
.contains(&StreamEvent::Stopped { id, error_code })); stream.state.events.clear(); assert_eq!(stream.write(&[]), Err(WriteError::Stopped(error_code))); stream.reset(0u32.into()).unwrap(); assert_eq!(stream.write(&[]), Err(WriteError::ClosedStream)); // A duplicate frame is a no-op stream.state.received_stop_sending(id, error_code); assert!(stream.state.events.is_empty()); } #[test] fn final_offset_flow_control() { let mut client = make(Side::Client); assert_eq!( client .received_reset(frame::ResetStream { id: StreamId::new(Side::Server, Dir::Uni, 0), error_code: 0u32.into(), final_offset: VarInt::MAX, }) .unwrap_err() .code, TransportErrorCode::FLOW_CONTROL_ERROR ); } #[test] fn stream_priority() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_bidi: 3u32.into(), initial_max_data: 10u32.into(), initial_max_stream_data_bidi_remote: 10u32.into(), ..TransportParameters::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; let id_high = streams.open(Dir::Bi).unwrap(); let id_mid = streams.open(Dir::Bi).unwrap(); let id_low = streams.open(Dir::Bi).unwrap(); let mut mid = SendStream { id: id_mid, state: &mut server, pending: &mut pending, conn_state: &state, }; mid.write(b"mid").unwrap(); let mut low = SendStream { id: id_low, state: &mut server, pending: &mut pending, conn_state: &state, }; low.set_priority(-1).unwrap(); low.write(b"low").unwrap(); let mut high = SendStream { id: id_high, state: &mut server, pending: &mut pending, conn_state: &state, }; high.set_priority(1).unwrap(); high.write(b"high").unwrap(); let mut buf = Vec::with_capacity(40); let meta = server.write_stream_frames(&mut buf, 40, true); assert_eq!(meta[0].id, id_high); assert_eq!(meta[1].id, id_mid); assert_eq!(meta[2].id, id_low); assert!(!server.can_send_stream_data()); assert_eq!(server.pending.len(), 0); } #[test] fn requeue_stream_priority() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_bidi: 3u32.into(), initial_max_data: 1000u32.into(), initial_max_stream_data_bidi_remote: 1000u32.into(), ..TransportParameters::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; let id_high = streams.open(Dir::Bi).unwrap(); let id_mid = streams.open(Dir::Bi).unwrap(); let mut mid = SendStream { id: id_mid, state: &mut server, pending: &mut pending, conn_state: &state, }; assert_eq!(mid.write(b"mid").unwrap(), 3); assert_eq!(server.pending.len(), 1); let mut high = SendStream { id: id_high, state: &mut server, pending: &mut pending, conn_state: &state, }; high.set_priority(1).unwrap(); assert_eq!(high.write(&[0; 200]).unwrap(), 200); assert_eq!(server.pending.len(), 2); // Requeue the high priority stream to lowest priority. The initial send // still uses high priority since it's queued that way. After that it will // switch to low priority let mut high = SendStream { id: id_high, state: &mut server, pending: &mut pending, conn_state: &state, }; high.set_priority(-1).unwrap(); let mut buf = Vec::with_capacity(1000); let meta = server.write_stream_frames(&mut buf, 40, true); assert_eq!(meta.len(), 1); assert_eq!(meta[0].id, id_high); // After requeuing we should end up with 2 priorities - not 3 assert_eq!(server.pending.len(), 2); // Send the remaining data. 
The initial mid priority one should go first now let meta = server.write_stream_frames(&mut buf, 1000, true); assert_eq!(meta.len(), 2); assert_eq!(meta[0].id, id_mid); assert_eq!(meta[1].id, id_high); assert!(!server.can_send_stream_data()); assert_eq!(server.pending.len(), 0); } #[test] fn same_stream_priority() { for fair in [true, false] { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_bidi: 3u32.into(), initial_max_data: 300u32.into(), initial_max_stream_data_bidi_remote: 300u32.into(), ..TransportParameters::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; // a, b and c all have the same priority let id_a = streams.open(Dir::Bi).unwrap(); let id_b = streams.open(Dir::Bi).unwrap(); let id_c = streams.open(Dir::Bi).unwrap(); let mut stream_a = SendStream { id: id_a, state: &mut server, pending: &mut pending, conn_state: &state, }; stream_a.write(&[b'a'; 100]).unwrap(); let mut stream_b = SendStream { id: id_b, state: &mut server, pending: &mut pending, conn_state: &state, }; stream_b.write(&[b'b'; 100]).unwrap(); let mut stream_c = SendStream { id: id_c, state: &mut server, pending: &mut pending, conn_state: &state, }; stream_c.write(&[b'c'; 100]).unwrap(); let mut metas = vec![]; let mut buf = Vec::with_capacity(1024); // loop until all the streams are written loop { let buf_len = buf.len(); let meta = server.write_stream_frames(&mut buf, buf_len + 40, fair); if meta.is_empty() { break; } metas.extend(meta); } assert!(!server.can_send_stream_data()); assert_eq!(server.pending.len(), 0); let stream_ids = metas.iter().map(|m| m.id).collect::>(); if fair { // When fairness is enabled, if we run out of buffer space to write out a stream, // the stream is re-queued after all the streams with the same priority. assert_eq!( stream_ids, vec![id_a, id_b, id_c, id_a, id_b, id_c, id_a, id_b, id_c] ); } else { // When fairness is disabled the stream is re-queued before all the other streams // with the same priority. 
assert_eq!( stream_ids, vec![id_a, id_a, id_a, id_b, id_b, id_b, id_c, id_c, id_c] ); } } } #[test] fn unfair_priority_bump() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_bidi: 3u32.into(), initial_max_data: 300u32.into(), initial_max_stream_data_bidi_remote: 300u32.into(), ..TransportParameters::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; // a, and b have the same priority, c has higher priority let id_a = streams.open(Dir::Bi).unwrap(); let id_b = streams.open(Dir::Bi).unwrap(); let id_c = streams.open(Dir::Bi).unwrap(); let mut stream_a = SendStream { id: id_a, state: &mut server, pending: &mut pending, conn_state: &state, }; stream_a.write(&[b'a'; 100]).unwrap(); let mut stream_b = SendStream { id: id_b, state: &mut server, pending: &mut pending, conn_state: &state, }; stream_b.write(&[b'b'; 100]).unwrap(); let mut metas = vec![]; let mut buf = Vec::with_capacity(1024); // Write the first chunk of stream_a let buf_len = buf.len(); let meta = server.write_stream_frames(&mut buf, buf_len + 40, false); assert!(!meta.is_empty()); metas.extend(meta); // Queue stream_c which has higher priority let mut stream_c = SendStream { id: id_c, state: &mut server, pending: &mut pending, conn_state: &state, }; stream_c.set_priority(1).unwrap(); stream_c.write(&[b'b'; 100]).unwrap(); // loop until all the streams are written loop { let buf_len = buf.len(); let meta = server.write_stream_frames(&mut buf, buf_len + 40, false); if meta.is_empty() { break; } metas.extend(meta); } assert!(!server.can_send_stream_data()); assert_eq!(server.pending.len(), 0); let stream_ids = metas.iter().map(|m| m.id).collect::>(); assert_eq!( stream_ids, // stream_c bumps stream_b but doesn't bump stream_a which had already been partly // written out vec![id_a, id_a, id_a, id_c, id_c, id_c, id_b, id_b, id_b] ); } #[test] fn stop_finished() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); // Server finishes stream let _ = client .received( frame::Stream { id, offset: 0, fin: true, data: Bytes::from_static(&[0; 32]), }, 32, ) .unwrap(); let mut pending = Retransmits::default(); let mut stream = RecvStream { id, state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); assert!(client.recv.get_mut(&id).is_none(), "stream is freed"); } // Verify that a stream that's been reset doesn't cause the appearance of pending data #[test] fn reset_stream_cannot_send() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_uni: 1u32.into(), initial_max_data: 42u32.into(), initial_max_stream_data_uni: 42u32.into(), ..TransportParameters::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; let id = streams.open(Dir::Uni).unwrap(); let mut stream = SendStream { id, state: &mut server, pending: &mut pending, conn_state: &state, }; stream.write(b"hello").unwrap(); stream.reset(0u32.into()).unwrap(); assert_eq!(pending.reset_stream, &[(id, 0u32.into())]); assert!(!server.can_send_stream_data()); } #[test] fn stream_limit_fixed() { let mut client = make(Side::Client); // Open streams 0-127 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 127), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), 
Ok(ShouldTransmit(false)) ); // Try to open stream 128, exceeding limit assert_eq!( client .received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ) .unwrap_err() .code, TransportErrorCode::STREAM_LIMIT_ERROR ); // Free stream 127 let mut pending = Retransmits::default(); let mut stream = RecvStream { id: StreamId::new(Side::Server, Dir::Uni, 127), state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); // Open stream 128 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); } #[test] fn stream_limit_grows() { let mut client = make(Side::Client); // Open streams 0-127 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 127), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); // Try to open stream 128, exceeding limit assert_eq!( client .received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ) .unwrap_err() .code, TransportErrorCode::STREAM_LIMIT_ERROR ); // Relax limit by one client.set_max_concurrent(Dir::Uni, 129u32.into()); // Open stream 128 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); } #[test] fn stream_limit_shrinks() { let mut client = make(Side::Client); // Open streams 0-127 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 127), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); // Tighten limit by one client.set_max_concurrent(Dir::Uni, 127u32.into()); // Free stream 127 let mut pending = Retransmits::default(); let mut stream = RecvStream { id: StreamId::new(Side::Server, Dir::Uni, 127), state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); // Try to open stream 128, still exceeding limit assert_eq!( client .received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ) .unwrap_err() .code, TransportErrorCode::STREAM_LIMIT_ERROR ); // Free stream 126 assert_eq!( client.received_reset(frame::ResetStream { id: StreamId::new(Side::Server, Dir::Uni, 126), error_code: 0u32.into(), final_offset: 0u32.into(), }), Ok(ShouldTransmit(false)) ); let mut pending = Retransmits::default(); let mut stream = RecvStream { id: StreamId::new(Side::Server, Dir::Uni, 126), state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); // Open stream 128 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); } #[test] fn remote_stream_capacity() { let mut client = make(Side::Client); for _ in 0..2 { client.set_max_concurrent(Dir::Uni, 200u32.into()); client.set_max_concurrent(Dir::Bi, 201u32.into()); assert_eq!(client.recv.len(), 200 + 201); assert_eq!(client.max_remote[Dir::Uni as usize], 200); assert_eq!(client.max_remote[Dir::Bi as usize], 201); } } #[test] fn expand_receive_window() { let mut server = make(Side::Server); let new_receive_window = 2 * server.receive_window as u32; let expanded = server.set_receive_window(new_receive_window.into()); assert!(expanded); 
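        // Expanding the window should credit the full delta to `local_max_data`
        // at once; the assertions below check that both fields moved together.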
assert_eq!(server.receive_window, new_receive_window as u64); assert_eq!(server.local_max_data, new_receive_window as u64); assert_eq!(server.receive_window_shrink_debt, 0); let prev_local_max_data = server.local_max_data; // credit, expecting all of them added to local_max_data let credits = 1024u64; let should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, prev_local_max_data + credits); assert!(should_transmit.should_transmit()); } #[test] fn shrink_receive_window() { let mut server = make(Side::Server); let new_receive_window = server.receive_window as u32 / 2; let prev_local_max_data = server.local_max_data; // shrink the receive_winbow, local_max_data is not expected to be changed let shrink_diff = server.receive_window - new_receive_window as u64; let expanded = server.set_receive_window(new_receive_window.into()); assert!(!expanded); assert_eq!(server.receive_window, new_receive_window as u64); assert_eq!(server.local_max_data, prev_local_max_data); assert_eq!(server.receive_window_shrink_debt, shrink_diff); let prev_local_max_data = server.local_max_data; // credit twice, local_max_data does not change as it is absorbed by receive_window_shrink_debt let credits = 1024u64; for _ in 0..2 { let expected_receive_window_shrink_debt = server.receive_window_shrink_debt - credits; let should_transmit = server.add_read_credits(credits); assert_eq!( server.receive_window_shrink_debt, expected_receive_window_shrink_debt ); assert_eq!(server.local_max_data, prev_local_max_data); assert!(!should_transmit.should_transmit()); } // credit again which exceeds all remaining expected_receive_window_shrink_debt let credits = 1024 * 512; let prev_local_max_data = server.local_max_data; let expected_local_max_data = server.local_max_data + (credits - server.receive_window_shrink_debt); let _should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, expected_local_max_data); assert!(server.local_max_data > prev_local_max_data); // credit again, all should be added to local_max_data let credits = 1024 * 512; let expected_local_max_data = server.local_max_data + credits; let should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, expected_local_max_data); assert!(should_transmit.should_transmit()); } } quinn-proto-0.11.9/src/connection/timer.rs000064400000000000000000000037321046102023000166300ustar 00000000000000use crate::Instant; #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) enum Timer { /// When to send an ack-eliciting probe packet or declare unacked packets lost LossDetection = 0, /// When to close the connection after no activity Idle = 1, /// When the close timer expires, the connection has been gracefully terminated. 
    Close = 2,
    /// When keys are discarded because they should not be needed anymore
    KeyDiscard = 3,
    /// When to give up on validating a new path to the peer
    PathValidation = 4,
    /// When to send a `PING` frame to keep the connection alive
    KeepAlive = 5,
    /// When pacing will allow us to send a packet
    Pacing = 6,
    /// When to invalidate old CID and proactively push new one via NEW_CONNECTION_ID frame
    PushNewCid = 7,
    /// When to send an immediate ACK if there are unacked ack-eliciting packets of the peer
    MaxAckDelay = 8,
}

impl Timer {
    pub(crate) const VALUES: [Self; 9] = [
        Self::LossDetection,
        Self::Idle,
        Self::Close,
        Self::KeyDiscard,
        Self::PathValidation,
        Self::KeepAlive,
        Self::Pacing,
        Self::PushNewCid,
        Self::MaxAckDelay,
    ];
}

/// A table of data associated with each distinct kind of `Timer`
#[derive(Debug, Copy, Clone, Default)]
pub(crate) struct TimerTable {
    data: [Option<Instant>; 10],
}

impl TimerTable {
    pub(super) fn set(&mut self, timer: Timer, time: Instant) {
        self.data[timer as usize] = Some(time);
    }

    pub(super) fn get(&self, timer: Timer) -> Option<Instant> {
        self.data[timer as usize]
    }

    pub(super) fn stop(&mut self, timer: Timer) {
        self.data[timer as usize] = None;
    }

    pub(super) fn next_timeout(&self) -> Option<Instant> {
        self.data.iter().filter_map(|&x| x).min()
    }

    pub(super) fn is_expired(&self, timer: Timer, after: Instant) -> bool {
        self.data[timer as usize].map_or(false, |x| x <= after)
    }
}

quinn-proto-0.11.9/src/constant_time.rs

// This function is non-inline to prevent the optimizer from looking inside it.
#[inline(never)]
fn constant_time_ne(a: &[u8], b: &[u8]) -> u8 {
    assert!(a.len() == b.len());

    // These useless slices make the optimizer elide the bounds checks.
    // See the comment in clone_from_slice() added on Rust commit 6a7bc47.
    let len = a.len();
    let a = &a[..len];
    let b = &b[..len];

    let mut tmp = 0;
    for i in 0..len {
        tmp |= a[i] ^ b[i];
    }
    tmp // The compare with 0 must happen outside this function.
}

/// Compares byte strings in constant time.
pub(crate) fn eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && constant_time_ne(a, b) == 0
}

quinn-proto-0.11.9/src/crypto/ring_like.rs

#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use aws_lc_rs::{aead, error, hkdf, hmac};
#[cfg(feature = "ring")]
use ring::{aead, error, hkdf, hmac};

use crate::crypto::{self, CryptoError};

impl crypto::HmacKey for hmac::Key {
    fn sign(&self, data: &[u8], out: &mut [u8]) {
        out.copy_from_slice(hmac::sign(self, data).as_ref());
    }

    fn signature_len(&self) -> usize {
        32
    }

    fn verify(&self, data: &[u8], signature: &[u8]) -> Result<(), CryptoError> {
        Ok(hmac::verify(self, data, signature)?)
    }
}

impl crypto::HandshakeTokenKey for hkdf::Prk {
    fn aead_from_hkdf(&self, random_bytes: &[u8]) -> Box<dyn crypto::AeadKey> {
        let mut key_buffer = [0u8; 32];
        let info = [random_bytes];
        let okm = self.expand(&info, hkdf::HKDF_SHA256).unwrap();
        okm.fill(&mut key_buffer).unwrap();

        let key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_buffer).unwrap();
        Box::new(aead::LessSafeKey::new(key))
    }
}

impl crypto::AeadKey for aead::LessSafeKey {
    fn seal(&self, data: &mut Vec<u8>, additional_data: &[u8]) -> Result<(), CryptoError> {
        let aad = aead::Aad::from(additional_data);
        let zero_nonce = aead::Nonce::assume_unique_for_key([0u8; 12]);
        Ok(self.seal_in_place_append_tag(zero_nonce, aad, data)?)
    }

    fn open<'a>(
        &self,
        data: &'a mut [u8],
        additional_data: &[u8],
    ) -> Result<&'a mut [u8], CryptoError> {
        let aad = aead::Aad::from(additional_data);
        let zero_nonce = aead::Nonce::assume_unique_for_key([0u8; 12]);
        Ok(self.open_in_place(zero_nonce, aad, data)?)
    }
}

impl From<error::Unspecified> for CryptoError {
    fn from(_: error::Unspecified) -> Self {
        Self
    }
}
quinn-proto-0.11.9/src/crypto/rustls.rs000064400000000000000000000533641046102023000162210ustar 00000000000000use std::{any::Any, io, str, sync::Arc};

#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use aws_lc_rs::aead;
use bytes::BytesMut;
#[cfg(feature = "ring")]
use ring::aead;
pub use rustls::Error;
use rustls::{
    self,
    client::danger::ServerCertVerifier,
    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
    quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
    CipherSuite,
};

use crate::{
    crypto::{
        self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys,
        UnsupportedVersion,
    },
    transport_parameters::TransportParameters,
    ConnectError, ConnectionId, Side, TransportError, TransportErrorCode,
};

impl From<Side> for rustls::Side {
    fn from(s: Side) -> Self {
        match s {
            Side::Client => Self::Client,
            Side::Server => Self::Server,
        }
    }
}

/// A rustls TLS session
pub struct TlsSession {
    version: Version,
    got_handshake_data: bool,
    next_secrets: Option<Secrets>,
    inner: Connection,
    suite: Suite,
}

impl TlsSession {
    fn side(&self) -> Side {
        match self.inner {
            Connection::Client(_) => Side::Client,
            Connection::Server(_) => Side::Server,
        }
    }
}

impl crypto::Session for TlsSession {
    fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
        initial_keys(self.version, dst_cid, side, &self.suite)
    }

    fn handshake_data(&self) -> Option<Box<dyn Any>> {
        if !self.got_handshake_data {
            return None;
        }
        Some(Box::new(HandshakeData {
            protocol: self.inner.alpn_protocol().map(|x| x.into()),
            server_name: match self.inner {
                Connection::Client(_) => None,
                Connection::Server(ref session) => session.server_name().map(|x| x.into()),
            },
        }))
    }

    /// For the rustls `TlsSession`, the `Any` type is `Vec<CertificateDer<'static>>`
    fn peer_identity(&self) -> Option<Box<dyn Any>> {
        self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
            Box::new(
                v.iter()
                    .map(|v| v.clone().into_owned())
                    .collect::<Vec<CertificateDer<'static>>>(),
            )
        })
    }

    fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
        let keys = self.inner.zero_rtt_keys()?;
        Some((Box::new(keys.header), Box::new(keys.packet)))
    }

    fn early_data_accepted(&self) -> Option<bool> {
        match self.inner {
            Connection::Client(ref session) => Some(session.is_early_data_accepted()),
            _ => None,
        }
    }

    fn is_handshaking(&self) -> bool {
        self.inner.is_handshaking()
    }

    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
        self.inner.read_hs(buf).map_err(|e| {
            if let Some(alert) = self.inner.alert() {
                TransportError {
                    code: TransportErrorCode::crypto(alert.into()),
                    frame: None,
                    reason: e.to_string(),
                }
            } else {
                TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
            }
        })?;
        if !self.got_handshake_data {
            // Hack around the lack of an explicit signal from rustls to reflect ClientHello being
            // ready on incoming connections, or ALPN negotiation completing on outgoing
            // connections.
            let have_server_name = match self.inner {
                Connection::Client(_) => false,
                Connection::Server(ref session) => session.server_name().is_some(),
            };
            if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
                self.got_handshake_data = true;
                return Ok(true);
            }
        }
        Ok(false)
    }

    fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
        match self.inner.quic_transport_parameters() {
            None => Ok(None),
            Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
                Ok(params) => Ok(Some(params)),
                Err(e) => Err(e.into()),
            },
        }
    }

    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
        let keys = match self.inner.write_hs(buf)? {
            KeyChange::Handshake { keys } => keys,
            KeyChange::OneRtt { keys, next } => {
                self.next_secrets = Some(next);
                keys
            }
        };
        Some(Keys {
            header: KeyPair {
                local: Box::new(keys.local.header),
                remote: Box::new(keys.remote.header),
            },
            packet: KeyPair {
                local: Box::new(keys.local.packet),
                remote: Box::new(keys.remote.packet),
            },
        })
    }

    fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
        let secrets = self.next_secrets.as_mut()?;
        let keys = secrets.next_packet_keys();
        Some(KeyPair {
            local: Box::new(keys.local),
            remote: Box::new(keys.remote),
        })
    }

    fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool {
        let tag_start = match payload.len().checked_sub(16) {
            Some(x) => x,
            None => return false,
        };

        let mut pseudo_packet =
            Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
        pseudo_packet.push(orig_dst_cid.len() as u8);
        pseudo_packet.extend_from_slice(orig_dst_cid);
        pseudo_packet.extend_from_slice(header);
        let tag_start = tag_start + pseudo_packet.len();
        pseudo_packet.extend_from_slice(payload);

        let (nonce, key) = match self.version {
            Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
            Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
            _ => unreachable!(),
        };

        let nonce = aead::Nonce::assume_unique_for_key(nonce);
        let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());

        let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
        key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
    }

    fn export_keying_material(
        &self,
        output: &mut [u8],
        label: &[u8],
        context: &[u8],
    ) -> Result<(), ExportKeyingMaterialError> {
        self.inner
            .export_keying_material(output, label, Some(context))
            .map_err(|_| ExportKeyingMaterialError)?;
        Ok(())
    }
}

const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];

impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
    fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
        let (header, sample) = packet.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let pn_end = Ord::min(pn_offset + 3, rest.len());
        self.decrypt_in_place(
            &sample[..self.sample_size()],
            &mut first[0],
            &mut rest[pn_offset - 1..pn_end],
        )
        .unwrap();
    }

    fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
        let (header, sample) = packet.split_at_mut(pn_offset + 4);
        let (first, rest) = header.split_at_mut(1);
        let pn_end = Ord::min(pn_offset + 3, rest.len());
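        // Header protection (RFC 9001 §5.4): the sample is taken starting 4 bytes
        // past the start of the packet number field, then the protected bits of the
        // first byte and up to four packet number bytes are masked in place.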
self.encrypt_in_place( &sample[..self.sample_size()], &mut first[0], &mut rest[pn_offset - 1..pn_end], ) .unwrap(); } fn sample_size(&self) -> usize { self.sample_len() } } /// Authentication data for (rustls) TLS session pub struct HandshakeData { /// The negotiated application protocol, if ALPN is in use /// /// Guaranteed to be set if a nonempty list of protocols was specified for this connection. pub protocol: Option>, /// The server name specified by the client, if any /// /// Always `None` for outgoing connections pub server_name: Option, } /// A QUIC-compatible TLS client configuration /// /// Quinn implicitly constructs a `QuicClientConfig` with reasonable defaults within /// [`ClientConfig::with_root_certificates()`][root_certs] and [`ClientConfig::with_platform_verifier()`][platform]. /// Alternatively, `QuicClientConfig`'s [`TryFrom`] implementation can be used to wrap around a /// custom [`rustls::ClientConfig`], in which case care should be taken around certain points: /// /// - If `enable_early_data` is not set to true, then sending 0-RTT data will not be possible on /// outgoing connections. /// - The [`rustls::ClientConfig`] must have TLS 1.3 support enabled for conversion to succeed. /// /// The object in the `resumption` field of the inner [`rustls::ClientConfig`] determines whether /// calling `into_0rtt` on outgoing connections returns `Ok` or `Err`. It typically allows /// `into_0rtt` to proceed if it recognizes the server name, and defaults to an in-memory cache of /// 256 server names. /// /// [root_certs]: crate::config::ClientConfig::with_root_certificates() /// [platform]: crate::config::ClientConfig::with_platform_verifier() pub struct QuicClientConfig { pub(crate) inner: Arc, initial: Suite, } impl QuicClientConfig { /// Initialize a sane QUIC-compatible TLS client configuration /// /// QUIC requires that TLS 1.3 be enabled. Advanced users can use any [`rustls::ClientConfig`] that /// satisfies this requirement. pub(crate) fn new(verifier: Arc) -> Self { let inner = Self::inner(verifier); Self { // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256 initial: initial_suite_from_provider(inner.crypto_provider()) .expect("no initial cipher suite found"), inner: Arc::new(inner), } } /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite /// /// This is useful if you want to avoid the initial cipher suite for traffic encryption. 
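    ///
    /// A minimal sketch of obtaining a `Suite` and calling this method; `client_config` is
    /// assumed to be a TLS 1.3-capable [`rustls::ClientConfig`] built elsewhere:
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use quinn_proto::crypto::rustls::QuicClientConfig;
    ///
    /// // Locate the QUIC suite for TLS13_AES_128_GCM_SHA256 in a rustls provider
    /// let initial = rustls::crypto::ring::default_provider()
    ///     .cipher_suites
    ///     .iter()
    ///     .find_map(|cs| match (cs.suite(), cs.tls13()) {
    ///         (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => suite.quic_suite(),
    ///         _ => None,
    ///     })
    ///     .expect("the default provider includes TLS13_AES_128_GCM_SHA256");
    /// let config = QuicClientConfig::with_initial(Arc::new(client_config), initial)
    ///     .expect("TLS13_AES_128_GCM_SHA256 is the valid initial cipher suite");
    /// ```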
pub fn with_initial( inner: Arc, initial: Suite, ) -> Result { match initial.suite.common.suite { CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }), _ => Err(NoInitialCipherSuite { specific: true }), } } pub(crate) fn inner(verifier: Arc) -> rustls::ClientConfig { let mut config = rustls::ClientConfig::builder_with_provider(configured_provider()) .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap() // The default providers support TLS 1.3 .dangerous() .with_custom_certificate_verifier(verifier) .with_no_client_auth(); config.enable_early_data = true; config } } impl crypto::ClientConfig for QuicClientConfig { fn start_session( self: Arc, version: u32, server_name: &str, params: &TransportParameters, ) -> Result, ConnectError> { let version = interpret_version(version)?; Ok(Box::new(TlsSession { version, got_handshake_data: false, next_secrets: None, inner: rustls::quic::Connection::Client( rustls::quic::ClientConnection::new( self.inner.clone(), version, ServerName::try_from(server_name) .map_err(|_| ConnectError::InvalidServerName(server_name.into()))? .to_owned(), to_vec(params), ) .unwrap(), ), suite: self.initial, })) } } impl TryFrom for QuicClientConfig { type Error = NoInitialCipherSuite; fn try_from(inner: rustls::ClientConfig) -> Result { Arc::new(inner).try_into() } } impl TryFrom> for QuicClientConfig { type Error = NoInitialCipherSuite; fn try_from(inner: Arc) -> Result { Ok(Self { initial: initial_suite_from_provider(inner.crypto_provider()) .ok_or(NoInitialCipherSuite { specific: false })?, inner, }) } } /// The initial cipher suite (AES-128-GCM-SHA256) is not available /// /// When the cipher suite is supplied `with_initial()`, it must be /// [`CipherSuite::TLS13_AES_128_GCM_SHA256`]. When the cipher suite is derived from a config's /// [`CryptoProvider`][provider], that provider must reference a cipher suite with the same ID. /// /// [provider]: rustls::crypto::CryptoProvider #[derive(Clone, Debug)] pub struct NoInitialCipherSuite { /// Whether the initial cipher suite was supplied by the caller specific: bool, } impl std::fmt::Display for NoInitialCipherSuite { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.write_str(match self.specific { true => "invalid cipher suite specified", false => "no initial cipher suite found", }) } } impl std::error::Error for NoInitialCipherSuite {} /// A QUIC-compatible TLS server configuration /// /// Quinn implicitly constructs a `QuicServerConfig` with reasonable defaults within /// [`ServerConfig::with_single_cert()`][single]. Alternatively, `QuicServerConfig`'s [`TryFrom`] /// implementation or `with_initial` method can be used to wrap around a custom /// [`rustls::ServerConfig`], in which case care should be taken around certain points: /// /// - If `max_early_data_size` is not set to `u32::MAX`, the server will not be able to accept /// incoming 0-RTT data. QUIC prohibits `max_early_data_size` values other than 0 or `u32::MAX`. /// - The `rustls::ServerConfig` must have TLS 1.3 support enabled for conversion to succeed. 
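///
/// As an illustration (a sketch, assuming `server_config` is a TLS 1.3-capable
/// [`rustls::ServerConfig`] with `max_early_data_size` already set appropriately):
///
/// ```ignore
/// use std::sync::Arc;
/// use quinn_proto::crypto::rustls::QuicServerConfig;
///
/// let quic: QuicServerConfig = Arc::new(server_config).try_into()?;
/// ```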
/// /// [single]: crate::config::ServerConfig::with_single_cert() pub struct QuicServerConfig { inner: Arc, initial: Suite, } impl QuicServerConfig { pub(crate) fn new( cert_chain: Vec>, key: PrivateKeyDer<'static>, ) -> Result { let inner = Self::inner(cert_chain, key)?; Ok(Self { // We're confident that the *ring* default provider contains TLS13_AES_128_GCM_SHA256 initial: initial_suite_from_provider(inner.crypto_provider()) .expect("no initial cipher suite found"), inner: Arc::new(inner), }) } /// Initialize a QUIC-compatible TLS client configuration with a separate initial cipher suite /// /// This is useful if you want to avoid the initial cipher suite for traffic encryption. pub fn with_initial( inner: Arc, initial: Suite, ) -> Result { match initial.suite.common.suite { CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }), _ => Err(NoInitialCipherSuite { specific: true }), } } /// Initialize a sane QUIC-compatible TLS server configuration /// /// QUIC requires that TLS 1.3 be enabled, and that the maximum early data size is either 0 or /// `u32::MAX`. Advanced users can use any [`rustls::ServerConfig`] that satisfies these /// requirements. pub(crate) fn inner( cert_chain: Vec>, key: PrivateKeyDer<'static>, ) -> Result { let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider()) .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap() // The *ring* default provider supports TLS 1.3 .with_no_client_auth() .with_single_cert(cert_chain, key)?; inner.max_early_data_size = u32::MAX; Ok(inner) } } impl TryFrom for QuicServerConfig { type Error = NoInitialCipherSuite; fn try_from(inner: rustls::ServerConfig) -> Result { Arc::new(inner).try_into() } } impl TryFrom> for QuicServerConfig { type Error = NoInitialCipherSuite; fn try_from(inner: Arc) -> Result { Ok(Self { initial: initial_suite_from_provider(inner.crypto_provider()) .ok_or(NoInitialCipherSuite { specific: false })?, inner, }) } } impl crypto::ServerConfig for QuicServerConfig { fn start_session( self: Arc, version: u32, params: &TransportParameters, ) -> Box { // Safe: `start_session()` is never called if `initial_keys()` rejected `version` let version = interpret_version(version).unwrap(); Box::new(TlsSession { version, got_handshake_data: false, next_secrets: None, inner: rustls::quic::Connection::Server( rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params)) .unwrap(), ), suite: self.initial, }) } fn initial_keys( &self, version: u32, dst_cid: &ConnectionId, ) -> Result { let version = interpret_version(version)?; Ok(initial_keys(version, dst_cid, Side::Server, &self.initial)) } fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] { // Safe: `start_session()` is never called if `initial_keys()` rejected `version` let version = interpret_version(version).unwrap(); let (nonce, key) = match version { Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1), Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT), _ => unreachable!(), }; let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1); pseudo_packet.push(orig_dst_cid.len() as u8); pseudo_packet.extend_from_slice(orig_dst_cid); pseudo_packet.extend_from_slice(packet); let nonce = aead::Nonce::assume_unique_for_key(nonce); let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap()); let tag = key .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut []) .unwrap(); 
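        // The Retry integrity tag (RFC 9001 §5.8) is a 16-byte AEAD tag computed over
        // the pseudo-packet; it is appended verbatim to the Retry packet.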
let mut result = [0; 16]; result.copy_from_slice(tag.as_ref()); result } } pub(crate) fn initial_suite_from_provider( provider: &Arc, ) -> Option { provider .cipher_suites .iter() .find_map(|cs| match (cs.suite(), cs.tls13()) { (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => { Some(suite.quic_suite()) } _ => None, }) .flatten() } pub(crate) fn configured_provider() -> Arc { #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))] let provider = rustls::crypto::aws_lc_rs::default_provider(); #[cfg(feature = "rustls-ring")] let provider = rustls::crypto::ring::default_provider(); Arc::new(provider) } fn to_vec(params: &TransportParameters) -> Vec { let mut bytes = Vec::new(); params.write(&mut bytes); bytes } pub(crate) fn initial_keys( version: Version, dst_cid: &ConnectionId, side: Side, suite: &Suite, ) -> Keys { let keys = suite.keys(dst_cid, side.into(), version); Keys { header: KeyPair { local: Box::new(keys.local.header), remote: Box::new(keys.remote.header), }, packet: KeyPair { local: Box::new(keys.local.packet), remote: Box::new(keys.remote.packet), }, } } impl crypto::PacketKey for Box { fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) { let (header, payload_tag) = buf.split_at_mut(header_len); let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len()); let tag = self.encrypt_in_place(packet, &*header, payload).unwrap(); tag_storage.copy_from_slice(tag.as_ref()); } fn decrypt( &self, packet: u64, header: &[u8], payload: &mut BytesMut, ) -> Result<(), CryptoError> { let plain = self .decrypt_in_place(packet, header, payload.as_mut()) .map_err(|_| CryptoError)?; let plain_len = plain.len(); payload.truncate(plain_len); Ok(()) } fn tag_len(&self) -> usize { (**self).tag_len() } fn confidentiality_limit(&self) -> u64 { (**self).confidentiality_limit() } fn integrity_limit(&self) -> u64 { (**self).integrity_limit() } } fn interpret_version(version: u32) -> Result { match version { 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft), 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1), _ => Err(UnsupportedVersion), } } quinn-proto-0.11.9/src/crypto.rs000064400000000000000000000200171046102023000146640ustar 00000000000000//! Traits and implementations for the QUIC cryptography protocol //! //! The protocol logic in Quinn is contained in types that abstract over the actual //! cryptographic protocol used. This module contains the traits used for this //! abstraction layer as well as a single implementation of these traits that uses //! *ring* and rustls to implement the TLS protocol support. //! //! Note that usage of any protocol (version) other than TLS 1.3 does not conform to any //! published versions of the specification, and will not be supported in QUIC v1. 
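//!
//! As a rough sketch of how the pieces fit together (hypothetical driver code, not a
//! working configuration): an endpoint holds a [`ClientConfig`] or [`ServerConfig`],
//! which produces a [`Session`] that drives the handshake:
//!
//! ```ignore
//! use std::sync::Arc;
//! use quinn_proto::{crypto, transport_parameters::TransportParameters};
//!
//! fn start(cfg: Arc<dyn crypto::ClientConfig>, params: &TransportParameters) {
//!     // 0x0000_0001 is QUIC version 1
//!     let session = cfg.start_session(0x0000_0001, "example.com", params).unwrap();
//!     assert!(session.is_handshaking());
//! }
//! ```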
use std::{any::Any, str, sync::Arc}; use bytes::BytesMut; use crate::{ shared::ConnectionId, transport_parameters::TransportParameters, ConnectError, Side, TransportError, }; /// Cryptography interface based on *ring* #[cfg(any(feature = "aws-lc-rs", feature = "ring"))] pub(crate) mod ring_like; /// TLS interface based on rustls #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] pub mod rustls; /// A cryptographic session (commonly TLS) pub trait Session: Send + Sync + 'static { /// Create the initial set of keys given the client's initial destination ConnectionId fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys; /// Get data negotiated during the handshake, if available /// /// Returns `None` until the connection emits `HandshakeDataReady`. fn handshake_data(&self) -> Option>; /// Get the peer's identity, if available fn peer_identity(&self) -> Option>; /// Get the 0-RTT keys if available (clients only) /// /// On the client side, this method can be used to see if 0-RTT key material is available /// to start sending data before the protocol handshake has completed. /// /// Returns `None` if the key material is not available. This might happen if you have /// not connected to this server before. fn early_crypto(&self) -> Option<(Box, Box)>; /// If the 0-RTT-encrypted data has been accepted by the peer fn early_data_accepted(&self) -> Option; /// Returns `true` until the connection is fully established. fn is_handshaking(&self) -> bool; /// Read bytes of handshake data /// /// This should be called with the contents of `CRYPTO` frames. If it returns `Ok`, the /// caller should call `write_handshake()` to check if the crypto protocol has anything /// to send to the peer. This method will only return `true` the first time that /// handshake data is available. Future calls will always return false. /// /// On success, returns `true` iff `self.handshake_data()` has been populated. fn read_handshake(&mut self, buf: &[u8]) -> Result; /// The peer's QUIC transport parameters /// /// These are only available after the first flight from the peer has been received. fn transport_parameters(&self) -> Result, TransportError>; /// Writes handshake bytes into the given buffer and optionally returns the negotiated keys /// /// When the handshake proceeds to the next phase, this method will return a new set of /// keys to encrypt data with. fn write_handshake(&mut self, buf: &mut Vec) -> Option; /// Compute keys for the next key update fn next_1rtt_keys(&mut self) -> Option>>; /// Verify the integrity of a retry packet fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool; /// Fill `output` with `output.len()` bytes of keying material derived /// from the [Session]'s secrets, using `label` and `context` for domain /// separation. /// /// This function will fail, returning [ExportKeyingMaterialError], /// if the requested output length is too large. 
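    ///
    /// For example (a hypothetical call, once a connection's `session` is established):
    ///
    /// ```ignore
    /// let mut key = [0u8; 32];
    /// session.export_keying_material(&mut key, b"my-label", b"my-context")?;
    /// ```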
fn export_keying_material( &self, output: &mut [u8], label: &[u8], context: &[u8], ) -> Result<(), ExportKeyingMaterialError>; } /// A pair of keys for bidirectional communication pub struct KeyPair { /// Key for encrypting data pub local: T, /// Key for decrypting data pub remote: T, } /// A complete set of keys for a certain packet space pub struct Keys { /// Header protection keys pub header: KeyPair>, /// Packet protection keys pub packet: KeyPair>, } /// Client-side configuration for the crypto protocol pub trait ClientConfig: Send + Sync { /// Start a client session with this configuration fn start_session( self: Arc, version: u32, server_name: &str, params: &TransportParameters, ) -> Result, ConnectError>; } /// Server-side configuration for the crypto protocol pub trait ServerConfig: Send + Sync { /// Create the initial set of keys given the client's initial destination ConnectionId fn initial_keys( &self, version: u32, dst_cid: &ConnectionId, ) -> Result; /// Generate the integrity tag for a retry packet /// /// Never called if `initial_keys` rejected `version`. fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16]; /// Start a server session with this configuration /// /// Never called if `initial_keys` rejected `version`. fn start_session( self: Arc, version: u32, params: &TransportParameters, ) -> Box; } /// Keys used to protect packet payloads pub trait PacketKey: Send + Sync { /// Encrypt the packet payload with the given packet number fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize); /// Decrypt the packet payload with the given packet number fn decrypt( &self, packet: u64, header: &[u8], payload: &mut BytesMut, ) -> Result<(), CryptoError>; /// The length of the AEAD tag appended to packets on encryption fn tag_len(&self) -> usize; /// Maximum number of packets that may be sent using a single key fn confidentiality_limit(&self) -> u64; /// Maximum number of incoming packets that may fail decryption before the connection must be /// abandoned fn integrity_limit(&self) -> u64; } /// Keys used to protect packet headers pub trait HeaderKey: Send + Sync { /// Decrypt the given packet's header fn decrypt(&self, pn_offset: usize, packet: &mut [u8]); /// Encrypt the given packet's header fn encrypt(&self, pn_offset: usize, packet: &mut [u8]); /// The sample size used for this key's algorithm fn sample_size(&self) -> usize; } /// A key for signing with HMAC-based algorithms pub trait HmacKey: Send + Sync { /// Method for signing a message fn sign(&self, data: &[u8], signature_out: &mut [u8]); /// Length of `sign`'s output fn signature_len(&self) -> usize; /// Method for verifying a message fn verify(&self, data: &[u8], signature: &[u8]) -> Result<(), CryptoError>; } /// Error returned by [Session::export_keying_material]. /// /// This error occurs if the requested output length is too large. 
#[derive(Debug, PartialEq, Eq)] pub struct ExportKeyingMaterialError; /// A pseudo random key for HKDF pub trait HandshakeTokenKey: Send + Sync { /// Derive AEAD using hkdf fn aead_from_hkdf(&self, random_bytes: &[u8]) -> Box; } /// A key for sealing data with AEAD-based algorithms pub trait AeadKey { /// Method for sealing message `data` fn seal(&self, data: &mut Vec, additional_data: &[u8]) -> Result<(), CryptoError>; /// Method for opening a sealed message `data` fn open<'a>( &self, data: &'a mut [u8], additional_data: &[u8], ) -> Result<&'a mut [u8], CryptoError>; } /// Generic crypto errors #[derive(Debug)] pub struct CryptoError; /// Error indicating that the specified QUIC version is not supported #[derive(Debug)] pub struct UnsupportedVersion; impl From for ConnectError { fn from(_: UnsupportedVersion) -> Self { Self::UnsupportedVersion } } quinn-proto-0.11.9/src/endpoint.rs000064400000000000000000001365161046102023000152000ustar 00000000000000use std::{ collections::{hash_map, HashMap}, convert::TryFrom, fmt, mem, net::{IpAddr, SocketAddr}, ops::{Index, IndexMut}, sync::Arc, }; use bytes::{BufMut, Bytes, BytesMut}; use rand::{rngs::StdRng, Rng, RngCore, SeedableRng}; use rustc_hash::FxHashMap; use slab::Slab; use thiserror::Error; use tracing::{debug, error, trace, warn}; use crate::{ cid_generator::ConnectionIdGenerator, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, connection::{Connection, ConnectionError}, crypto::{self, Keys, UnsupportedVersion}, frame, packet::{ FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, Packet, PacketDecodeError, PacketNumber, PartialDecode, ProtectedInitialHeader, }, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint, EndpointEvent, EndpointEventInner, IssuedCid, }, token::TokenDecodeError, transport_parameters::{PreferredAddress, TransportParameters}, Instant, ResetToken, RetryToken, Side, SystemTime, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, }; /// The main entry point to the library /// /// This object performs no I/O whatsoever. Instead, it consumes incoming packets and /// connection-generated events via `handle` and `handle_event`. pub struct Endpoint { rng: StdRng, index: ConnectionIndex, connections: Slab, local_cid_generator: Box, config: Arc, server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets allow_mtud: bool, /// Time at which a stateless reset was most recently sent last_stateless_reset: Option, /// Buffered Initial and 0-RTT messages for pending incoming connections incoming_buffers: Slab, all_incoming_buffers_total_bytes: u64, } impl Endpoint { /// Create a new endpoint /// /// `allow_mtud` enables path MTU detection when requested by `Connection` configuration for /// better performance. This requires that outgoing packets are never fragmented, which can be /// achieved via e.g. the `IPV6_DONTFRAG` socket option. /// /// If `rng_seed` is provided, it will be used to initialize the endpoint's rng (having priority /// over the rng seed configured in [`EndpointConfig`]). Note that the `rng_seed` parameter will /// be removed in a future release, so prefer setting it to `None` and configuring rng seeds /// using [`EndpointConfig::rng_seed`]. 
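    ///
    /// A minimal sketch of constructing a client-only endpoint with the default configuration:
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use quinn_proto::{Endpoint, EndpointConfig};
    ///
    /// let endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None, true, None);
    /// ```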
pub fn new( config: Arc, server_config: Option>, allow_mtud: bool, rng_seed: Option<[u8; 32]>, ) -> Self { let rng_seed = rng_seed.or(config.rng_seed); Self { rng: rng_seed.map_or(StdRng::from_entropy(), StdRng::from_seed), index: ConnectionIndex::default(), connections: Slab::new(), local_cid_generator: (config.connection_id_generator_factory.as_ref())(), config, server_config, allow_mtud, last_stateless_reset: None, incoming_buffers: Slab::new(), all_incoming_buffers_total_bytes: 0, } } /// Replace the server configuration, affecting new incoming connections only pub fn set_server_config(&mut self, server_config: Option>) { self.server_config = server_config; } /// Process `EndpointEvent`s emitted from related `Connection`s /// /// In turn, processing this event may return a `ConnectionEvent` for the same `Connection`. pub fn handle_event( &mut self, ch: ConnectionHandle, event: EndpointEvent, ) -> Option { use EndpointEventInner::*; match event.0 { NeedIdentifiers(now, n) => { return Some(self.send_new_identifiers(now, ch, n)); } ResetToken(remote, token) => { if let Some(old) = self.connections[ch].reset_token.replace((remote, token)) { self.index.connection_reset_tokens.remove(old.0, old.1); } if self.index.connection_reset_tokens.insert(remote, token, ch) { warn!("duplicate reset token"); } } RetireConnectionId(now, seq, allow_more_cids) => { if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) { trace!("peer retired CID {}: {}", seq, cid); self.index.retire(&cid); if allow_more_cids { return Some(self.send_new_identifiers(now, ch, 1)); } } } Drained => { if let Some(conn) = self.connections.try_remove(ch.0) { self.index.remove(&conn); } else { // This indicates a bug in downstream code, which could cause spurious // connection loss instead of this error if the CID was (re)allocated prior to // the illegal call. 
error!(id = ch.0, "unknown connection drained"); } } } None } /// Process an incoming UDP datagram pub fn handle( &mut self, now: Instant, remote: SocketAddr, local_ip: Option, ecn: Option, data: BytesMut, buf: &mut Vec, ) -> Option { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()), &self.config.supported_versions, self.config.grease_quic_bit, ) { Ok(x) => x, Err(PacketDecodeError::UnsupportedVersion { src_cid, dst_cid, version, }) => { if self.server_config.is_none() { debug!("dropping packet with unsupported version"); return None; } trace!("sending version negotiation"); // Negotiate versions Header::VersionNegotiate { random: self.rng.gen::() | 0x40, src_cid: dst_cid, dst_cid: src_cid, } .encode(buf); // Grease with a reserved version if version != 0x0a1a_2a3a { buf.write::(0x0a1a_2a3a); } else { buf.write::(0x0a1a_2a4a); } for &version in &self.config.supported_versions { buf.write(version); } return Some(DatagramEvent::Response(Transmit { destination: remote, ecn: None, size: buf.len(), segment_size: None, src_ip: local_ip, })); } Err(e) => { trace!("malformed header: {}", e); return None; } }; // // Handle packet on existing connection, if any // let addresses = FourTuple { remote, local_ip }; if let Some(route_to) = self.index.get(&addresses, &first_decode) { let event = DatagramConnectionEvent { now, remote: addresses.remote, ecn, first_decode, remaining, }; match route_to { RouteDatagramTo::Incoming(incoming_idx) => { let incoming_buffer = &mut self.incoming_buffers[incoming_idx]; let config = &self.server_config.as_ref().unwrap(); if incoming_buffer .total_bytes .checked_add(datagram_len as u64) .map_or(false, |n| n <= config.incoming_buffer_size) && self .all_incoming_buffers_total_bytes .checked_add(datagram_len as u64) .map_or(false, |n| n <= config.incoming_buffer_size_total) { incoming_buffer.datagrams.push(event); incoming_buffer.total_bytes += datagram_len as u64; self.all_incoming_buffers_total_bytes += datagram_len as u64; } return None; } RouteDatagramTo::Connection(ch) => { return Some(DatagramEvent::ConnectionEvent( ch, ConnectionEvent(ConnectionEventInner::Datagram(event)), )) } } } // // Potentially create a new connection // let dst_cid = first_decode.dst_cid(); let server_config = match &self.server_config { Some(config) => config, None => { debug!("packet for unrecognized connection {}", dst_cid); return self .stateless_reset(now, datagram_len, addresses, dst_cid, buf) .map(DatagramEvent::Response); } }; if let Some(header) = first_decode.initial_header() { if datagram_len < MIN_INITIAL_SIZE as usize { debug!("ignoring short initial for connection {}", dst_cid); return None; } let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) { Ok(keys) => keys, Err(UnsupportedVersion) => { // This probably indicates that the user set supported_versions incorrectly in // `EndpointConfig`. 
debug!( "ignoring initial packet version {:#x} unsupported by cryptographic layer", header.version ); return None; } }; if let Err(reason) = self.early_validate_first_packet(header) { return Some(DatagramEvent::Response(self.initial_close( header.version, addresses, &crypto, &header.src_cid, reason, buf, ))); } return match first_decode.finish(Some(&*crypto.header.remote)) { Ok(packet) => { self.handle_first_packet(addresses, ecn, packet, remaining, crypto, buf) } Err(e) => { trace!("unable to decode initial packet: {}", e); None } }; } else if first_decode.has_long_header() { debug!( "ignoring non-initial packet for unknown connection {}", dst_cid ); return None; } // // If we got this far, we're a server receiving a seemingly valid packet for an unknown // connection. Send a stateless reset if possible. // if !first_decode.is_initial() && self .local_cid_generator .validate(first_decode.dst_cid()) .is_err() { debug!("dropping packet with invalid CID"); return None; } if !dst_cid.is_empty() { return self .stateless_reset(now, datagram_len, addresses, dst_cid, buf) .map(DatagramEvent::Response); } trace!("dropping unrecognized short packet without ID"); None } fn stateless_reset( &mut self, now: Instant, inciting_dgram_len: usize, addresses: FourTuple, dst_cid: &ConnectionId, buf: &mut Vec, ) -> Option { if self .last_stateless_reset .map_or(false, |last| last + self.config.min_reset_interval > now) { debug!("ignoring unexpected packet within minimum stateless reset interval"); return None; } /// Minimum amount of padding for the stateless reset to look like a short-header packet const MIN_PADDING_LEN: usize = 5; // Prevent amplification attacks and reset loops by ensuring we pad to at most 1 byte // smaller than the inciting packet. let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) { Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1, _ => { debug!("ignoring unexpected {} byte packet: not larger than minimum stateless reset size", inciting_dgram_len); return None; } }; debug!( "sending stateless reset for {} to {}", dst_cid, addresses.remote ); self.last_stateless_reset = Some(now); // Resets with at least this much padding can't possibly be distinguished from real packets const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE; let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN { max_padding_len } else { self.rng.gen_range(IDEAL_MIN_PADDING_LEN..max_padding_len) }; buf.reserve(padding_len + RESET_TOKEN_SIZE); buf.resize(padding_len, 0); self.rng.fill_bytes(&mut buf[0..padding_len]); buf[0] = 0b0100_0000 | buf[0] >> 2; buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid)); debug_assert!(buf.len() < inciting_dgram_len); Some(Transmit { destination: addresses.remote, ecn: None, size: buf.len(), segment_size: None, src_ip: addresses.local_ip, }) } /// Initiate a connection pub fn connect( &mut self, now: Instant, config: ClientConfig, remote: SocketAddr, server_name: &str, ) -> Result<(ConnectionHandle, Connection), ConnectError> { if self.cids_exhausted() { return Err(ConnectError::CidsExhausted); } if remote.port() == 0 || remote.ip().is_unspecified() { return Err(ConnectError::InvalidRemoteAddress(remote)); } if !self.config.supported_versions.contains(&config.version) { return Err(ConnectError::UnsupportedVersion); } let remote_id = (config.initial_dst_cid_provider)(); trace!(initial_dcid = %remote_id); let ch = ConnectionHandle(self.connections.vacant_key()); let loc_cid = self.new_cid(ch); let params = 
TransportParameters::new( &config.transport, &self.config, self.local_cid_generator.as_ref(), loc_cid, None, ); let tls = config .crypto .start_session(config.version, server_name, ¶ms)?; let conn = self.add_connection( ch, config.version, remote_id, loc_cid, remote_id, None, FourTuple { remote, local_ip: None, }, now, tls, None, config.transport, true, ); Ok((ch, conn)) } fn send_new_identifiers( &mut self, now: Instant, ch: ConnectionHandle, num: u64, ) -> ConnectionEvent { let mut ids = vec![]; for _ in 0..num { let id = self.new_cid(ch); let meta = &mut self.connections[ch]; let sequence = meta.cids_issued; meta.cids_issued += 1; meta.loc_cids.insert(sequence, id); ids.push(IssuedCid { sequence, id, reset_token: ResetToken::new(&*self.config.reset_key, &id), }); } ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now)) } /// Generate a connection ID for `ch` fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId { loop { let cid = self.local_cid_generator.generate_cid(); if cid.len() == 0 { // Zero-length CID; nothing to track debug_assert_eq!(self.local_cid_generator.cid_len(), 0); return cid; } if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) { e.insert(ch); break cid; } } } fn handle_first_packet( &mut self, addresses: FourTuple, ecn: Option, packet: Packet, rest: Option, crypto: Keys, buf: &mut Vec, ) -> Option { if !packet.reserved_bits_valid() { debug!("dropping connection attempt with invalid reserved bits"); return None; } let Header::Initial(header) = packet.header else { panic!("non-initial packet in handle_first_packet()"); }; let server_config = self.server_config.as_ref().unwrap().clone(); let (retry_src_cid, orig_dst_cid) = if header.token.is_empty() { (None, header.dst_cid) } else { match RetryToken::from_bytes( &*server_config.token_key, &addresses.remote, &header.dst_cid, &header.token, ) { Ok(token) if token.issued + server_config.retry_token_lifetime > SystemTime::now() => { (Some(header.dst_cid), token.orig_dst_cid) } Err(TokenDecodeError::UnknownToken) => { // Token may have been generated by an incompatible endpoint, e.g. a // different version or a neighbor behind the same load balancer. We // can't interpret it, so we proceed as if there was no token. (None, header.dst_cid) } _ => { debug!("rejecting invalid stateless retry token"); return Some(DatagramEvent::Response(self.initial_close( header.version, addresses, &crypto, &header.src_cid, TransportError::INVALID_TOKEN(""), buf, ))); } } }; let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default()); self.index .insert_initial_incoming(header.dst_cid, incoming_idx); Some(DatagramEvent::NewConnection(Incoming { addresses, ecn, packet: InitialPacket { header, header_data: packet.header_data, payload: packet.payload, }, rest, crypto, retry_src_cid, orig_dst_cid, incoming_idx, improper_drop_warner: IncomingImproperDropWarner, })) } /// Attempt to accept this incoming connection (an error may still occur) pub fn accept( &mut self, mut incoming: Incoming, now: Instant, buf: &mut Vec, server_config: Option>, ) -> Result<(ConnectionHandle, Connection), AcceptError> { let remote_address_validated = incoming.remote_address_validated(); incoming.improper_drop_warner.dismiss(); let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx); self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes; let packet_number = incoming.packet.header.number.expand(0); let InitialHeader { src_cid, dst_cid, version, .. 
} = incoming.packet.header; if self.cids_exhausted() { debug!("refusing connection"); self.index.remove_initial(dst_cid); return Err(AcceptError { cause: ConnectionError::CidsExhausted, response: Some(self.initial_close( version, incoming.addresses, &incoming.crypto, &src_cid, TransportError::CONNECTION_REFUSED(""), buf, )), }); } let server_config = server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone()); if incoming .crypto .packet .remote .decrypt( packet_number, &incoming.packet.header_data, &mut incoming.packet.payload, ) .is_err() { debug!(packet_number, "failed to authenticate initial packet"); self.index.remove_initial(dst_cid); return Err(AcceptError { cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(), response: None, }); }; let ch = ConnectionHandle(self.connections.vacant_key()); let loc_cid = self.new_cid(ch); let mut params = TransportParameters::new( &server_config.transport, &self.config, self.local_cid_generator.as_ref(), loc_cid, Some(&server_config), ); params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid)); params.original_dst_cid = Some(incoming.orig_dst_cid); params.retry_src_cid = incoming.retry_src_cid; let mut pref_addr_cid = None; if server_config.preferred_address_v4.is_some() || server_config.preferred_address_v6.is_some() { let cid = self.new_cid(ch); pref_addr_cid = Some(cid); params.preferred_address = Some(PreferredAddress { address_v4: server_config.preferred_address_v4, address_v6: server_config.preferred_address_v6, connection_id: cid, stateless_reset_token: ResetToken::new(&*self.config.reset_key, &cid), }); } let tls = server_config.crypto.clone().start_session(version, ¶ms); let transport_config = server_config.transport.clone(); let mut conn = self.add_connection( ch, version, dst_cid, loc_cid, src_cid, pref_addr_cid, incoming.addresses, now, tls, Some(server_config), transport_config, remote_address_validated, ); self.index.insert_initial(dst_cid, ch); match conn.handle_first_packet( now, incoming.addresses.remote, incoming.ecn, packet_number, incoming.packet, incoming.rest, ) { Ok(()) => { trace!(id = ch.0, icid = %dst_cid, "new connection"); for event in incoming_buffer.datagrams { conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event))) } Ok((ch, conn)) } Err(e) => { debug!("handshake failed: {}", e); self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained)); let response = match e { ConnectionError::TransportError(ref e) => Some(self.initial_close( version, incoming.addresses, &incoming.crypto, &src_cid, e.clone(), buf, )), _ => None, }; Err(AcceptError { cause: e, response }) } } } /// Check if we should refuse a connection attempt regardless of the packet's contents fn early_validate_first_packet( &mut self, header: &ProtectedInitialHeader, ) -> Result<(), TransportError> { let config = &self.server_config.as_ref().unwrap(); if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming { return Err(TransportError::CONNECTION_REFUSED("")); } // RFC9000 ยง7.2 dictates that initial (client-chosen) destination CIDs must be at least 8 // bytes. If this is a Retry packet, then the length must instead match our usual CID // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll // also need to validate CID length for those after decoding the token. 
if header.dst_cid.len() < 8 && (header.token_pos.is_empty() || header.dst_cid.len() != self.local_cid_generator.cid_len()) { debug!( "rejecting connection due to invalid DCID length {}", header.dst_cid.len() ); return Err(TransportError::PROTOCOL_VIOLATION( "invalid destination CID length", )); } Ok(()) } /// Reject this incoming connection attempt pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec) -> Transmit { self.clean_up_incoming(&incoming); incoming.improper_drop_warner.dismiss(); self.initial_close( incoming.packet.header.version, incoming.addresses, &incoming.crypto, &incoming.packet.header.src_cid, TransportError::CONNECTION_REFUSED(""), buf, ) } /// Respond with a retry packet, requiring the client to retry with address validation /// /// Errors if `incoming.remote_address_validated()` is true. pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec) -> Result { if incoming.remote_address_validated() { return Err(RetryError(incoming)); } self.clean_up_incoming(&incoming); incoming.improper_drop_warner.dismiss(); let server_config = self.server_config.as_ref().unwrap(); // First Initial // The peer will use this as the DCID of its following Initials. Initial DCIDs are // looked up separately from Handshake/Data DCIDs, so there is no risk of collision // with established connections. In the unlikely event that a collision occurs // between two connections in the initial phase, both will fail fast and may be // retried by the application layer. let loc_cid = self.local_cid_generator.generate_cid(); let token = RetryToken { orig_dst_cid: incoming.packet.header.dst_cid, issued: SystemTime::now(), } .encode( &*server_config.token_key, &incoming.addresses.remote, &loc_cid, ); let header = Header::Retry { src_cid: loc_cid, dst_cid: incoming.packet.header.src_cid, version: incoming.packet.header.version, }; let encode = header.encode(buf); buf.put_slice(&token); buf.extend_from_slice(&server_config.crypto.retry_tag( incoming.packet.header.version, &incoming.packet.header.dst_cid, buf, )); encode.finish(buf, &*incoming.crypto.header.local, None); Ok(Transmit { destination: incoming.addresses.remote, ecn: None, size: buf.len(), segment_size: None, src_ip: incoming.addresses.local_ip, }) } /// Ignore this incoming connection attempt, not sending any packet in response /// /// Doing this actively, rather than merely dropping the [`Incoming`], is necessary to prevent /// memory leaks due to state within [`Endpoint`] tracking the incoming connection. pub fn ignore(&mut self, incoming: Incoming) { self.clean_up_incoming(&incoming); incoming.improper_drop_warner.dismiss(); } /// Clean up endpoint data structures associated with an `Incoming`. 
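    ///
    /// This removes the initial-CID routing entry and releases the buffered-datagram byte
    /// accounting; callers are expected to separately dismiss `incoming.improper_drop_warner`.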
fn clean_up_incoming(&mut self, incoming: &Incoming) { self.index.remove_initial(incoming.packet.header.dst_cid); let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx); self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes; } fn add_connection( &mut self, ch: ConnectionHandle, version: u32, init_cid: ConnectionId, loc_cid: ConnectionId, rem_cid: ConnectionId, pref_addr_cid: Option, addresses: FourTuple, now: Instant, tls: Box, server_config: Option>, transport_config: Arc, path_validated: bool, ) -> Connection { let mut rng_seed = [0; 32]; self.rng.fill_bytes(&mut rng_seed); let side = match server_config.is_some() { true => Side::Server, false => Side::Client, }; let conn = Connection::new( self.config.clone(), server_config, transport_config, init_cid, loc_cid, rem_cid, pref_addr_cid, addresses.remote, addresses.local_ip, tls, self.local_cid_generator.as_ref(), now, version, self.allow_mtud, rng_seed, path_validated, ); let mut cids_issued = 0; let mut loc_cids = FxHashMap::default(); loc_cids.insert(cids_issued, loc_cid); cids_issued += 1; if let Some(cid) = pref_addr_cid { debug_assert_eq!(cids_issued, 1, "preferred address cid seq must be 1"); loc_cids.insert(cids_issued, cid); cids_issued += 1; } let id = self.connections.insert(ConnectionMeta { init_cid, cids_issued, loc_cids, addresses, side, reset_token: None, }); debug_assert_eq!(id, ch.0, "connection handle allocation out of sync"); self.index.insert_conn(addresses, loc_cid, ch, side); conn } fn initial_close( &mut self, version: u32, addresses: FourTuple, crypto: &Keys, remote_id: &ConnectionId, reason: TransportError, buf: &mut Vec, ) -> Transmit { // We don't need to worry about CID collisions in initial closes because the peer // shouldn't respond, and if it does, and the CID collides, we'll just drop the // unexpected response. let local_id = self.local_cid_generator.generate_cid(); let number = PacketNumber::U8(0); let header = Header::Initial(InitialHeader { dst_cid: *remote_id, src_cid: local_id, number, token: Bytes::new(), version, }); let partial_encode = header.encode(buf); let max_len = INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len(); frame::Close::from(reason).encode(buf, max_len); buf.resize(buf.len() + crypto.packet.local.tag_len(), 0); partial_encode.finish(buf, &*crypto.header.local, Some((0, &*crypto.packet.local))); Transmit { destination: addresses.remote, ecn: None, size: buf.len(), segment_size: None, src_ip: addresses.local_ip, } } /// Access the configuration used by this endpoint pub fn config(&self) -> &EndpointConfig { &self.config } /// Number of connections that are currently open pub fn open_connections(&self) -> usize { self.connections.len() } /// Counter for the number of bytes currently used /// in the buffers for Initial and 0-RTT messages for pending incoming connections pub fn incoming_buffer_bytes(&self) -> u64 { self.all_incoming_buffers_total_bytes } #[cfg(test)] pub(crate) fn known_connections(&self) -> usize { let x = self.connections.len(); debug_assert_eq!(x, self.index.connection_ids_initial.len()); // Not all connections have known reset tokens debug_assert!(x >= self.index.connection_reset_tokens.0.len()); // Not all connections have unique remotes, and 0-length CIDs might not be in use. 
debug_assert!(x >= self.index.incoming_connection_remotes.len()); debug_assert!(x >= self.index.outgoing_connection_remotes.len()); x } #[cfg(test)] pub(crate) fn known_cids(&self) -> usize { self.index.connection_ids.len() } /// Whether we've used up 3/4 of the available CID space /// /// We leave some space unused so that `new_cid` can be relied upon to finish quickly. We don't /// bother to check when CID longer than 4 bytes are used because 2^40 connections is a lot. fn cids_exhausted(&self) -> bool { self.local_cid_generator.cid_len() <= 4 && self.local_cid_generator.cid_len() != 0 && (2usize.pow(self.local_cid_generator.cid_len() as u32 * 8) - self.index.connection_ids.len()) < 2usize.pow(self.local_cid_generator.cid_len() as u32 * 8 - 2) } } impl fmt::Debug for Endpoint { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Endpoint") .field("rng", &self.rng) .field("index", &self.index) .field("connections", &self.connections) .field("config", &self.config) .field("server_config", &self.server_config) // incoming_buffers too large .field("incoming_buffers.len", &self.incoming_buffers.len()) .field( "all_incoming_buffers_total_bytes", &self.all_incoming_buffers_total_bytes, ) .finish() } } /// Buffered Initial and 0-RTT messages for a pending incoming connection #[derive(Default)] struct IncomingBuffer { datagrams: Vec, total_bytes: u64, } /// Part of protocol state incoming datagrams can be routed to #[derive(Copy, Clone, Debug)] enum RouteDatagramTo { Incoming(usize), Connection(ConnectionHandle), } /// Maps packets to existing connections #[derive(Default, Debug)] struct ConnectionIndex { /// Identifies connections based on the initial DCID the peer utilized /// /// Uses a standard `HashMap` to protect against hash collision attacks. /// /// Used by the server, not the client. connection_ids_initial: HashMap, /// Identifies connections based on locally created CIDs /// /// Uses a cheaper hash function since keys are locally created connection_ids: FxHashMap, /// Identifies incoming connections with zero-length CIDs /// /// Uses a standard `HashMap` to protect against hash collision attacks. incoming_connection_remotes: HashMap, /// Identifies outgoing connections with zero-length CIDs /// /// We don't yet support explicit source addresses for client connections, and zero-length CIDs /// require a unique four-tuple, so at most one client connection with zero-length local CIDs /// may be established per remote. We must omit the local address from the key because we don't /// necessarily know what address we're sending from, and hence receiving at. /// /// Uses a standard `HashMap` to protect against hash collision attacks. outgoing_connection_remotes: HashMap, /// Reset tokens provided by the peer for the CID each connection is currently sending to /// /// Incoming stateless resets do not have correct CIDs, so we need this to identify the correct /// recipient, if any. 
connection_reset_tokens: ResetTokenTable, } impl ConnectionIndex { /// Associate an incoming connection with its initial destination CID fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) { if dst_cid.len() == 0 { return; } self.connection_ids_initial .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key)); } /// Remove an association with an initial destination CID fn remove_initial(&mut self, dst_cid: ConnectionId) { if dst_cid.len() == 0 { return; } let removed = self.connection_ids_initial.remove(&dst_cid); debug_assert!(removed.is_some()); } /// Associate a connection with its initial destination CID fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) { if dst_cid.len() == 0 { return; } self.connection_ids_initial .insert(dst_cid, RouteDatagramTo::Connection(connection)); } /// Associate a connection with its first locally-chosen destination CID if used, or otherwise /// its current 4-tuple fn insert_conn( &mut self, addresses: FourTuple, dst_cid: ConnectionId, connection: ConnectionHandle, side: Side, ) { match dst_cid.len() { 0 => match side { Side::Server => { self.incoming_connection_remotes .insert(addresses, connection); } Side::Client => { self.outgoing_connection_remotes .insert(addresses.remote, connection); } }, _ => { self.connection_ids.insert(dst_cid, connection); } } } /// Discard a connection ID fn retire(&mut self, dst_cid: &ConnectionId) { self.connection_ids.remove(dst_cid); } /// Remove all references to a connection fn remove(&mut self, conn: &ConnectionMeta) { if conn.side.is_server() { self.remove_initial(conn.init_cid); } for cid in conn.loc_cids.values() { self.connection_ids.remove(cid); } self.incoming_connection_remotes.remove(&conn.addresses); self.outgoing_connection_remotes .remove(&conn.addresses.remote); if let Some((remote, token)) = conn.reset_token { self.connection_reset_tokens.remove(remote, token); } } /// Find the existing connection that `datagram` should be routed to, if any fn get(&self, addresses: &FourTuple, datagram: &PartialDecode) -> Option { if datagram.dst_cid().len() != 0 { if let Some(&ch) = self.connection_ids.get(datagram.dst_cid()) { return Some(RouteDatagramTo::Connection(ch)); } } if datagram.is_initial() || datagram.is_0rtt() { if let Some(&ch) = self.connection_ids_initial.get(datagram.dst_cid()) { return Some(ch); } } if datagram.dst_cid().len() == 0 { if let Some(&ch) = self.incoming_connection_remotes.get(addresses) { return Some(RouteDatagramTo::Connection(ch)); } if let Some(&ch) = self.outgoing_connection_remotes.get(&addresses.remote) { return Some(RouteDatagramTo::Connection(ch)); } } let data = datagram.data(); if data.len() < RESET_TOKEN_SIZE { return None; } self.connection_reset_tokens .get(addresses.remote, &data[data.len() - RESET_TOKEN_SIZE..]) .cloned() .map(RouteDatagramTo::Connection) } } #[derive(Debug)] pub(crate) struct ConnectionMeta { init_cid: ConnectionId, /// Number of local connection IDs that have been issued in NEW_CONNECTION_ID frames. cids_issued: u64, loc_cids: FxHashMap, /// Remote/local addresses the connection began with /// /// Only needed to support connections with zero-length CIDs, which cannot migrate, so we don't /// bother keeping it up to date. 
addresses: FourTuple, side: Side, /// Reset token provided by the peer for the CID we're currently sending to, and the address /// being sent to reset_token: Option<(SocketAddr, ResetToken)>, } /// Internal identifier for a `Connection` currently associated with an endpoint #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct ConnectionHandle(pub usize); impl From for usize { fn from(x: ConnectionHandle) -> Self { x.0 } } impl Index for Slab { type Output = ConnectionMeta; fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta { &self[ch.0] } } impl IndexMut for Slab { fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta { &mut self[ch.0] } } /// Event resulting from processing a single datagram #[allow(clippy::large_enum_variant)] // Not passed around extensively pub enum DatagramEvent { /// The datagram is redirected to its `Connection` ConnectionEvent(ConnectionHandle, ConnectionEvent), /// The datagram may result in starting a new `Connection` NewConnection(Incoming), /// Response generated directly by the endpoint Response(Transmit), } /// An incoming connection for which the server has not yet begun its part of the handshake. pub struct Incoming { addresses: FourTuple, ecn: Option, packet: InitialPacket, rest: Option, crypto: Keys, retry_src_cid: Option, orig_dst_cid: ConnectionId, incoming_idx: usize, improper_drop_warner: IncomingImproperDropWarner, } impl Incoming { /// The local IP address which was used when the peer established /// the connection /// /// This has the same behavior as [`Connection::local_ip`] pub fn local_ip(&self) -> Option { self.addresses.local_ip } /// The peer's UDP address. pub fn remote_address(&self) -> SocketAddr { self.addresses.remote } /// Whether the socket address that is initiating this connection has been validated. /// /// This means that the sender of the initial packet has proved that they can receive traffic /// sent to `self.remote_address()`. pub fn remote_address_validated(&self) -> bool { self.retry_src_cid.is_some() } /// The original destination connection ID sent by the client pub fn orig_dst_cid(&self) -> &ConnectionId { &self.orig_dst_cid } } impl fmt::Debug for Incoming { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Incoming") .field("addresses", &self.addresses) .field("ecn", &self.ecn) // packet doesn't implement debug // rest is too big and not meaningful enough .field("retry_src_cid", &self.retry_src_cid) .field("orig_dst_cid", &self.orig_dst_cid) .field("incoming_idx", &self.incoming_idx) // improper drop warner contains no information .finish_non_exhaustive() } } struct IncomingImproperDropWarner; impl IncomingImproperDropWarner { fn dismiss(self) { mem::forget(self); } } impl Drop for IncomingImproperDropWarner { fn drop(&mut self) { warn!("quinn_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \ (may cause memory leak and eventual inability to accept new connections)"); } } /// Errors in the parameters being used to create a new connection /// /// These arise before any I/O has been performed. #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum ConnectError { /// The endpoint can no longer create new connections /// /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled. 
#[error("endpoint stopping")] EndpointStopping, /// The connection could not be created because not enough of the CID space is available /// /// Try using longer connection IDs #[error("CIDs exhausted")] CidsExhausted, /// The given server name was malformed #[error("invalid server name: {0}")] InvalidServerName(String), /// The remote [`SocketAddr`] supplied was malformed /// /// Examples include attempting to connect to port 0, or using an inappropriate address family. #[error("invalid remote address: {0}")] InvalidRemoteAddress(SocketAddr), /// No default client configuration was set up /// /// Use `Endpoint::connect_with` to specify a client configuration. #[error("no default client config")] NoDefaultClientConfig, /// The local endpoint does not support the QUIC version specified in the client configuration #[error("unsupported QUIC version")] UnsupportedVersion, } /// Error type for attempting to accept an [`Incoming`] #[derive(Debug)] pub struct AcceptError { /// Underlying error describing reason for failure pub cause: ConnectionError, /// Optional response to transmit back pub response: Option, } /// Error for attempting to retry an [`Incoming`] which already bears an address /// validation token from a previous retry #[derive(Debug, Error)] #[error("retry() with validated Incoming")] pub struct RetryError(Incoming); impl RetryError { /// Get the [`Incoming`] pub fn into_incoming(self) -> Incoming { self.0 } } /// Reset Tokens which are associated with peer socket addresses /// /// The standard `HashMap` is used since both `SocketAddr` and `ResetToken` are /// peer generated and might be usable for hash collision attacks. #[derive(Default, Debug)] struct ResetTokenTable(HashMap>); impl ResetTokenTable { fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool { self.0 .entry(remote) .or_default() .insert(token, ch) .is_some() } fn remove(&mut self, remote: SocketAddr, token: ResetToken) { use std::collections::hash_map::Entry; match self.0.entry(remote) { Entry::Vacant(_) => {} Entry::Occupied(mut e) => { e.get_mut().remove(&token); if e.get().is_empty() { e.remove_entry(); } } } } fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> { let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?); self.0.get(&remote)?.get(&token) } } /// Identifies a connection by the combination of remote and local addresses /// /// Including the local ensures good behavior when the host has multiple IP addresses on the same /// subnet and zero-length connection IDs are in use. 
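// Standalone sketch of the two-level token table above (hypothetical names;
// see `ResetTokenTable` for the real thing). The outer map narrows by peer
// address and the inner map by the 16-byte token, and std's default SipHash
// `HashMap` is kept deliberately because both keys are attacker-influenced.
fn _reset_token_table_sketch() {
    use std::collections::HashMap;
    use std::net::{Ipv4Addr, SocketAddr};

    const TOKEN_SIZE: usize = 16;
    let mut table: HashMap<SocketAddr, HashMap<[u8; TOKEN_SIZE], usize>> = HashMap::new();
    let remote = SocketAddr::from((Ipv4Addr::LOCALHOST, 4433));

    // Insert: `entry().or_default()` creates the inner map on first use.
    table.entry(remote).or_default().insert([7; TOKEN_SIZE], 42);
    let found = table.get(&remote).and_then(|m| m.get(&[7; TOKEN_SIZE]));
    assert_eq!(found, Some(&42));

    // Remove: dropping the last token for an address also drops the inner
    // map, mirroring `ResetTokenTable::remove`'s `remove_entry` cleanup.
    if let Some(inner) = table.get_mut(&remote) {
        inner.remove(&[7; TOKEN_SIZE]);
        if inner.is_empty() {
            table.remove(&remote);
        }
    }
    assert!(table.is_empty());
}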
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)] struct FourTuple { remote: SocketAddr, // A single socket can only listen on a single port, so no need to store it explicitly local_ip: Option, } quinn-proto-0.11.9/src/frame.rs000064400000000000000000000731601046102023000144450ustar 00000000000000use std::{ fmt::{self, Write}, io, mem, ops::{Range, RangeInclusive}, }; use bytes::{Buf, BufMut, Bytes}; use tinyvec::TinyVec; use crate::{ coding::{self, BufExt, BufMutExt, UnexpectedEnd}, range_set::ArrayRangeSet, shared::{ConnectionId, EcnCodepoint}, Dir, ResetToken, StreamId, TransportError, TransportErrorCode, VarInt, MAX_CID_SIZE, RESET_TOKEN_SIZE, }; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; /// A QUIC frame type #[derive(Copy, Clone, Eq, PartialEq)] pub struct FrameType(u64); impl FrameType { fn stream(self) -> Option { if STREAM_TYS.contains(&self.0) { Some(StreamInfo(self.0 as u8)) } else { None } } fn datagram(self) -> Option { if DATAGRAM_TYS.contains(&self.0) { Some(DatagramInfo(self.0 as u8)) } else { None } } } impl coding::Codec for FrameType { fn decode(buf: &mut B) -> coding::Result { Ok(Self(buf.get_var()?)) } fn encode(&self, buf: &mut B) { buf.write_var(self.0); } } pub(crate) trait FrameStruct { /// Smallest number of bytes this type of frame is guaranteed to fit within. const SIZE_BOUND: usize; } macro_rules! frame_types { {$($name:ident = $val:expr,)*} => { impl FrameType { $(pub(crate) const $name: FrameType = FrameType($val);)* } impl fmt::Debug for FrameType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str(stringify!($name)),)* _ => write!(f, "Type({:02x})", self.0) } } } impl fmt::Display for FrameType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str(stringify!($name)),)* x if STREAM_TYS.contains(&x) => f.write_str("STREAM"), x if DATAGRAM_TYS.contains(&x) => f.write_str("DATAGRAM"), _ => write!(f, "", self.0), } } } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] struct StreamInfo(u8); impl StreamInfo { fn fin(self) -> bool { self.0 & 0x01 != 0 } fn len(self) -> bool { self.0 & 0x02 != 0 } fn off(self) -> bool { self.0 & 0x04 != 0 } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] struct DatagramInfo(u8); impl DatagramInfo { fn len(self) -> bool { self.0 & 0x01 != 0 } } frame_types! 
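// Worked example (comments only) of the STREAM type-bit scheme used by
// `StreamInfo` above: STREAM frame types occupy 0x08..=0x0f, with the low
// three bits acting as flags. For example, frame type 0x0f is
// 0x08 | OFF (0x04) | LEN (0x02) | FIN (0x01), so:
//
//     let info = StreamInfo(0x0f);
//     assert!(info.fin() && info.len() && info.off());
//
// while type 0x08 has no flags set and therefore carries data starting at
// offset 0 and extending to the end of the packet. DATAGRAM types work the
// same way over 0x30..=0x31, with bit 0x01 signalling a length prefix.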
{ PADDING = 0x00, PING = 0x01, ACK = 0x02, ACK_ECN = 0x03, RESET_STREAM = 0x04, STOP_SENDING = 0x05, CRYPTO = 0x06, NEW_TOKEN = 0x07, // STREAM MAX_DATA = 0x10, MAX_STREAM_DATA = 0x11, MAX_STREAMS_BIDI = 0x12, MAX_STREAMS_UNI = 0x13, DATA_BLOCKED = 0x14, STREAM_DATA_BLOCKED = 0x15, STREAMS_BLOCKED_BIDI = 0x16, STREAMS_BLOCKED_UNI = 0x17, NEW_CONNECTION_ID = 0x18, RETIRE_CONNECTION_ID = 0x19, PATH_CHALLENGE = 0x1a, PATH_RESPONSE = 0x1b, CONNECTION_CLOSE = 0x1c, APPLICATION_CLOSE = 0x1d, HANDSHAKE_DONE = 0x1e, // ACK Frequency ACK_FREQUENCY = 0xaf, IMMEDIATE_ACK = 0x1f, // DATAGRAM } const STREAM_TYS: RangeInclusive = RangeInclusive::new(0x08, 0x0f); const DATAGRAM_TYS: RangeInclusive = RangeInclusive::new(0x30, 0x31); #[derive(Debug)] pub(crate) enum Frame { Padding, Ping, Ack(Ack), ResetStream(ResetStream), StopSending(StopSending), Crypto(Crypto), NewToken { token: Bytes }, Stream(Stream), MaxData(VarInt), MaxStreamData { id: StreamId, offset: u64 }, MaxStreams { dir: Dir, count: u64 }, DataBlocked { offset: u64 }, StreamDataBlocked { id: StreamId, offset: u64 }, StreamsBlocked { dir: Dir, limit: u64 }, NewConnectionId(NewConnectionId), RetireConnectionId { sequence: u64 }, PathChallenge(u64), PathResponse(u64), Close(Close), Datagram(Datagram), AckFrequency(AckFrequency), ImmediateAck, HandshakeDone, } impl Frame { pub(crate) fn ty(&self) -> FrameType { use self::Frame::*; match *self { Padding => FrameType::PADDING, ResetStream(_) => FrameType::RESET_STREAM, Close(self::Close::Connection(_)) => FrameType::CONNECTION_CLOSE, Close(self::Close::Application(_)) => FrameType::APPLICATION_CLOSE, MaxData(_) => FrameType::MAX_DATA, MaxStreamData { .. } => FrameType::MAX_STREAM_DATA, MaxStreams { dir: Dir::Bi, .. } => FrameType::MAX_STREAMS_BIDI, MaxStreams { dir: Dir::Uni, .. } => FrameType::MAX_STREAMS_UNI, Ping => FrameType::PING, DataBlocked { .. } => FrameType::DATA_BLOCKED, StreamDataBlocked { .. } => FrameType::STREAM_DATA_BLOCKED, StreamsBlocked { dir: Dir::Bi, .. } => FrameType::STREAMS_BLOCKED_BIDI, StreamsBlocked { dir: Dir::Uni, .. } => FrameType::STREAMS_BLOCKED_UNI, StopSending { .. } => FrameType::STOP_SENDING, RetireConnectionId { .. } => FrameType::RETIRE_CONNECTION_ID, Ack(_) => FrameType::ACK, Stream(ref x) => { let mut ty = *STREAM_TYS.start(); if x.fin { ty |= 0x01; } if x.offset != 0 { ty |= 0x04; } FrameType(ty) } PathChallenge(_) => FrameType::PATH_CHALLENGE, PathResponse(_) => FrameType::PATH_RESPONSE, NewConnectionId { .. } => FrameType::NEW_CONNECTION_ID, Crypto(_) => FrameType::CRYPTO, NewToken { .. 
} => FrameType::NEW_TOKEN, Datagram(_) => FrameType(*DATAGRAM_TYS.start()), AckFrequency(_) => FrameType::ACK_FREQUENCY, ImmediateAck => FrameType::IMMEDIATE_ACK, HandshakeDone => FrameType::HANDSHAKE_DONE, } } pub(crate) fn is_ack_eliciting(&self) -> bool { !matches!(*self, Self::Ack(_) | Self::Padding | Self::Close(_)) } } #[derive(Clone, Debug)] pub enum Close { Connection(ConnectionClose), Application(ApplicationClose), } impl Close { pub(crate) fn encode(&self, out: &mut W, max_len: usize) { match *self { Self::Connection(ref x) => x.encode(out, max_len), Self::Application(ref x) => x.encode(out, max_len), } } pub(crate) fn is_transport_layer(&self) -> bool { matches!(*self, Self::Connection(_)) } } impl From for Close { fn from(x: TransportError) -> Self { Self::Connection(x.into()) } } impl From for Close { fn from(x: ConnectionClose) -> Self { Self::Connection(x) } } impl From for Close { fn from(x: ApplicationClose) -> Self { Self::Application(x) } } /// Reason given by the transport for closing the connection #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConnectionClose { /// Class of error as encoded in the specification pub error_code: TransportErrorCode, /// Type of frame that caused the close pub frame_type: Option, /// Human-readable reason for the close pub reason: Bytes, } impl fmt::Display for ConnectionClose { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.error_code.fmt(f)?; if !self.reason.as_ref().is_empty() { f.write_str(": ")?; f.write_str(&String::from_utf8_lossy(&self.reason))?; } Ok(()) } } impl From for ConnectionClose { fn from(x: TransportError) -> Self { Self { error_code: x.code, frame_type: x.frame, reason: x.reason.into(), } } } impl FrameStruct for ConnectionClose { const SIZE_BOUND: usize = 1 + 8 + 8 + 8; } impl ConnectionClose { pub(crate) fn encode(&self, out: &mut W, max_len: usize) { out.write(FrameType::CONNECTION_CLOSE); // 1 byte out.write(self.error_code); // <= 8 bytes let ty = self.frame_type.map_or(0, |x| x.0); out.write_var(ty); // <= 8 bytes let max_len = max_len - 3 - VarInt::from_u64(ty).unwrap().size() - VarInt::from_u64(self.reason.len() as u64).unwrap().size(); let actual_len = self.reason.len().min(max_len); out.write_var(actual_len as u64); // <= 8 bytes out.put_slice(&self.reason[0..actual_len]); // whatever's left } } /// Reason given by an application for closing the connection #[derive(Debug, Clone, PartialEq, Eq)] pub struct ApplicationClose { /// Application-specific reason code pub error_code: VarInt, /// Human-readable reason for the close pub reason: Bytes, } impl fmt::Display for ApplicationClose { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if !self.reason.as_ref().is_empty() { f.write_str(&String::from_utf8_lossy(&self.reason))?; f.write_str(" (code ")?; self.error_code.fmt(f)?; f.write_str(")")?; } else { self.error_code.fmt(f)?; } Ok(()) } } impl FrameStruct for ApplicationClose { const SIZE_BOUND: usize = 1 + 8 + 8; } impl ApplicationClose { pub(crate) fn encode(&self, out: &mut W, max_len: usize) { out.write(FrameType::APPLICATION_CLOSE); // 1 byte out.write(self.error_code); // <= 8 bytes let max_len = max_len - 3 - VarInt::from_u64(self.reason.len() as u64).unwrap().size(); let actual_len = self.reason.len().min(max_len); out.write_var(actual_len as u64); // <= 8 bytes out.put_slice(&self.reason[0..actual_len]); // whatever's left } } #[derive(Clone, Eq, PartialEq)] pub struct Ack { pub largest: u64, pub delay: u64, pub additional: Bytes, pub ecn: Option, } impl fmt::Debug for Ack 
{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut ranges = "[".to_string(); let mut first = true; for range in self.iter() { if !first { ranges.push(','); } write!(ranges, "{range:?}").unwrap(); first = false; } ranges.push(']'); f.debug_struct("Ack") .field("largest", &self.largest) .field("delay", &self.delay) .field("ecn", &self.ecn) .field("ranges", &ranges) .finish() } } impl<'a> IntoIterator for &'a Ack { type Item = RangeInclusive; type IntoIter = AckIter<'a>; fn into_iter(self) -> AckIter<'a> { AckIter::new(self.largest, &self.additional[..]) } } impl Ack { pub fn encode( delay: u64, ranges: &ArrayRangeSet, ecn: Option<&EcnCounts>, buf: &mut W, ) { let mut rest = ranges.iter().rev(); let first = rest.next().unwrap(); let largest = first.end - 1; let first_size = first.end - first.start; buf.write(if ecn.is_some() { FrameType::ACK_ECN } else { FrameType::ACK }); buf.write_var(largest); buf.write_var(delay); buf.write_var(ranges.len() as u64 - 1); buf.write_var(first_size - 1); let mut prev = first.start; for block in rest { let size = block.end - block.start; buf.write_var(prev - block.end - 1); buf.write_var(size - 1); prev = block.start; } if let Some(x) = ecn { x.encode(buf) } } pub fn iter(&self) -> AckIter<'_> { self.into_iter() } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct EcnCounts { pub ect0: u64, pub ect1: u64, pub ce: u64, } impl std::ops::AddAssign for EcnCounts { fn add_assign(&mut self, rhs: EcnCodepoint) { match rhs { EcnCodepoint::Ect0 => { self.ect0 += 1; } EcnCodepoint::Ect1 => { self.ect1 += 1; } EcnCodepoint::Ce => { self.ce += 1; } } } } impl EcnCounts { pub const ZERO: Self = Self { ect0: 0, ect1: 0, ce: 0, }; pub fn encode(&self, out: &mut W) { out.write_var(self.ect0); out.write_var(self.ect1); out.write_var(self.ce); } } #[derive(Debug, Clone)] pub(crate) struct Stream { pub(crate) id: StreamId, pub(crate) offset: u64, pub(crate) fin: bool, pub(crate) data: Bytes, } impl FrameStruct for Stream { const SIZE_BOUND: usize = 1 + 8 + 8 + 8; } /// Metadata from a stream frame #[derive(Debug, Clone)] pub(crate) struct StreamMeta { pub(crate) id: StreamId, pub(crate) offsets: Range, pub(crate) fin: bool, } // This manual implementation exists because `Default` is not implemented for `StreamId` impl Default for StreamMeta { fn default() -> Self { Self { id: StreamId(0), offsets: 0..0, fin: false, } } } impl StreamMeta { pub(crate) fn encode(&self, length: bool, out: &mut W) { let mut ty = *STREAM_TYS.start(); if self.offsets.start != 0 { ty |= 0x04; } if length { ty |= 0x02; } if self.fin { ty |= 0x01; } out.write_var(ty); // 1 byte out.write(self.id); // <=8 bytes if self.offsets.start != 0 { out.write_var(self.offsets.start); // <=8 bytes } if length { out.write_var(self.offsets.end - self.offsets.start); // <=8 bytes } } } /// A vector of [`StreamMeta`] with optimization for the single element case pub(crate) type StreamMetaVec = TinyVec<[StreamMeta; 1]>; #[derive(Debug, Clone)] pub(crate) struct Crypto { pub(crate) offset: u64, pub(crate) data: Bytes, } impl Crypto { pub(crate) const SIZE_BOUND: usize = 17; pub(crate) fn encode(&self, out: &mut W) { out.write(FrameType::CRYPTO); out.write_var(self.offset); out.write_var(self.data.len() as u64); out.put_slice(&self.data); } } pub(crate) struct Iter { // TODO: ditch io::Cursor after bytes 0.5 bytes: io::Cursor, last_ty: Option, } impl Iter { pub(crate) fn new(payload: Bytes) -> Result { if payload.is_empty() { // "An endpoint MUST treat receipt of a packet containing no frames as a 
// connection error of type PROTOCOL_VIOLATION." // https://www.rfc-editor.org/rfc/rfc9000.html#name-frames-and-frame-types return Err(TransportError::PROTOCOL_VIOLATION( "packet payload is empty", )); } Ok(Self { bytes: io::Cursor::new(payload), last_ty: None, }) } fn take_len(&mut self) -> Result { let len = self.bytes.get_var()?; if len > self.bytes.remaining() as u64 { return Err(UnexpectedEnd); } let start = self.bytes.position() as usize; self.bytes.advance(len as usize); Ok(self.bytes.get_ref().slice(start..(start + len as usize))) } fn try_next(&mut self) -> Result { let ty = self.bytes.get::()?; self.last_ty = Some(ty); Ok(match ty { FrameType::PADDING => Frame::Padding, FrameType::RESET_STREAM => Frame::ResetStream(ResetStream { id: self.bytes.get()?, error_code: self.bytes.get()?, final_offset: self.bytes.get()?, }), FrameType::CONNECTION_CLOSE => Frame::Close(Close::Connection(ConnectionClose { error_code: self.bytes.get()?, frame_type: { let x = self.bytes.get_var()?; if x == 0 { None } else { Some(FrameType(x)) } }, reason: self.take_len()?, })), FrameType::APPLICATION_CLOSE => Frame::Close(Close::Application(ApplicationClose { error_code: self.bytes.get()?, reason: self.take_len()?, })), FrameType::MAX_DATA => Frame::MaxData(self.bytes.get()?), FrameType::MAX_STREAM_DATA => Frame::MaxStreamData { id: self.bytes.get()?, offset: self.bytes.get_var()?, }, FrameType::MAX_STREAMS_BIDI => Frame::MaxStreams { dir: Dir::Bi, count: self.bytes.get_var()?, }, FrameType::MAX_STREAMS_UNI => Frame::MaxStreams { dir: Dir::Uni, count: self.bytes.get_var()?, }, FrameType::PING => Frame::Ping, FrameType::DATA_BLOCKED => Frame::DataBlocked { offset: self.bytes.get_var()?, }, FrameType::STREAM_DATA_BLOCKED => Frame::StreamDataBlocked { id: self.bytes.get()?, offset: self.bytes.get_var()?, }, FrameType::STREAMS_BLOCKED_BIDI => Frame::StreamsBlocked { dir: Dir::Bi, limit: self.bytes.get_var()?, }, FrameType::STREAMS_BLOCKED_UNI => Frame::StreamsBlocked { dir: Dir::Uni, limit: self.bytes.get_var()?, }, FrameType::STOP_SENDING => Frame::StopSending(StopSending { id: self.bytes.get()?, error_code: self.bytes.get()?, }), FrameType::RETIRE_CONNECTION_ID => Frame::RetireConnectionId { sequence: self.bytes.get_var()?, }, FrameType::ACK | FrameType::ACK_ECN => { let largest = self.bytes.get_var()?; let delay = self.bytes.get_var()?; let extra_blocks = self.bytes.get_var()? as usize; let start = self.bytes.position() as usize; scan_ack_blocks(&mut self.bytes, largest, extra_blocks)?; let end = self.bytes.position() as usize; Frame::Ack(Ack { delay, largest, additional: self.bytes.get_ref().slice(start..end), ecn: if ty != FrameType::ACK_ECN { None } else { Some(EcnCounts { ect0: self.bytes.get_var()?, ect1: self.bytes.get_var()?, ce: self.bytes.get_var()?, }) }, }) } FrameType::PATH_CHALLENGE => Frame::PathChallenge(self.bytes.get()?), FrameType::PATH_RESPONSE => Frame::PathResponse(self.bytes.get()?), FrameType::NEW_CONNECTION_ID => { let sequence = self.bytes.get_var()?; let retire_prior_to = self.bytes.get_var()?; if retire_prior_to > sequence { return Err(IterErr::Malformed); } let length = self.bytes.get::()? 
as usize; if length > MAX_CID_SIZE || length == 0 { return Err(IterErr::Malformed); } if length > self.bytes.remaining() { return Err(IterErr::UnexpectedEnd); } let mut stage = [0; MAX_CID_SIZE]; self.bytes.copy_to_slice(&mut stage[0..length]); let id = ConnectionId::new(&stage[..length]); if self.bytes.remaining() < 16 { return Err(IterErr::UnexpectedEnd); } let mut reset_token = [0; RESET_TOKEN_SIZE]; self.bytes.copy_to_slice(&mut reset_token); Frame::NewConnectionId(NewConnectionId { sequence, retire_prior_to, id, reset_token: reset_token.into(), }) } FrameType::CRYPTO => Frame::Crypto(Crypto { offset: self.bytes.get_var()?, data: self.take_len()?, }), FrameType::NEW_TOKEN => Frame::NewToken { token: self.take_len()?, }, FrameType::HANDSHAKE_DONE => Frame::HandshakeDone, FrameType::ACK_FREQUENCY => Frame::AckFrequency(AckFrequency { sequence: self.bytes.get()?, ack_eliciting_threshold: self.bytes.get()?, request_max_ack_delay: self.bytes.get()?, reordering_threshold: self.bytes.get()?, }), FrameType::IMMEDIATE_ACK => Frame::ImmediateAck, _ => { if let Some(s) = ty.stream() { Frame::Stream(Stream { id: self.bytes.get()?, offset: if s.off() { self.bytes.get_var()? } else { 0 }, fin: s.fin(), data: if s.len() { self.take_len()? } else { self.take_remaining() }, }) } else if let Some(d) = ty.datagram() { Frame::Datagram(Datagram { data: if d.len() { self.take_len()? } else { self.take_remaining() }, }) } else { return Err(IterErr::InvalidFrameId); } } }) } fn take_remaining(&mut self) -> Bytes { let mut x = mem::replace(self.bytes.get_mut(), Bytes::new()); x.advance(self.bytes.position() as usize); self.bytes.set_position(0); x } } impl Iterator for Iter { type Item = Result; fn next(&mut self) -> Option { if !self.bytes.has_remaining() { return None; } match self.try_next() { Ok(x) => Some(Ok(x)), Err(e) => { // Corrupt frame, skip it and everything that follows self.bytes = io::Cursor::new(Bytes::new()); Some(Err(InvalidFrame { ty: self.last_ty, reason: e.reason(), })) } } } } #[derive(Debug)] pub(crate) struct InvalidFrame { pub(crate) ty: Option, pub(crate) reason: &'static str, } impl From for TransportError { fn from(err: InvalidFrame) -> Self { let mut te = Self::FRAME_ENCODING_ERROR(err.reason); te.frame = err.ty; te } } fn scan_ack_blocks(buf: &mut io::Cursor, largest: u64, n: usize) -> Result<(), IterErr> { let first_block = buf.get_var()?; let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?; for _ in 0..n { let gap = buf.get_var()?; smallest = smallest.checked_sub(gap + 2).ok_or(IterErr::Malformed)?; let block = buf.get_var()?; smallest = smallest.checked_sub(block).ok_or(IterErr::Malformed)?; } Ok(()) } enum IterErr { UnexpectedEnd, InvalidFrameId, Malformed, } impl IterErr { fn reason(&self) -> &'static str { use self::IterErr::*; match *self { UnexpectedEnd => "unexpected end", InvalidFrameId => "invalid frame ID", Malformed => "malformed", } } } impl From for IterErr { fn from(_: UnexpectedEnd) -> Self { Self::UnexpectedEnd } } #[derive(Debug, Clone)] pub struct AckIter<'a> { largest: u64, data: io::Cursor<&'a [u8]>, } impl<'a> AckIter<'a> { fn new(largest: u64, payload: &'a [u8]) -> Self { let data = io::Cursor::new(payload); Self { largest, data } } } impl Iterator for AckIter<'_> { type Item = RangeInclusive; fn next(&mut self) -> Option> { if !self.data.has_remaining() { return None; } let block = self.data.get_var().unwrap(); let largest = self.largest; if let Ok(gap) = self.data.get_var() { self.largest -= block + gap + 2; } Some(largest - 
block..=largest) } } #[allow(unreachable_pub)] // fuzzing only #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone)] pub struct ResetStream { pub(crate) id: StreamId, pub(crate) error_code: VarInt, pub(crate) final_offset: VarInt, } impl FrameStruct for ResetStream { const SIZE_BOUND: usize = 1 + 8 + 8 + 8; } impl ResetStream { pub(crate) fn encode(&self, out: &mut W) { out.write(FrameType::RESET_STREAM); // 1 byte out.write(self.id); // <= 8 bytes out.write(self.error_code); // <= 8 bytes out.write(self.final_offset); // <= 8 bytes } } #[derive(Debug, Copy, Clone)] pub(crate) struct StopSending { pub(crate) id: StreamId, pub(crate) error_code: VarInt, } impl FrameStruct for StopSending { const SIZE_BOUND: usize = 1 + 8 + 8; } impl StopSending { pub(crate) fn encode(&self, out: &mut W) { out.write(FrameType::STOP_SENDING); // 1 byte out.write(self.id); // <= 8 bytes out.write(self.error_code) // <= 8 bytes } } #[derive(Debug, Copy, Clone)] pub(crate) struct NewConnectionId { pub(crate) sequence: u64, pub(crate) retire_prior_to: u64, pub(crate) id: ConnectionId, pub(crate) reset_token: ResetToken, } impl NewConnectionId { pub(crate) fn encode(&self, out: &mut W) { out.write(FrameType::NEW_CONNECTION_ID); out.write_var(self.sequence); out.write_var(self.retire_prior_to); out.write(self.id.len() as u8); out.put_slice(&self.id); out.put_slice(&self.reset_token); } } /// Smallest number of bytes this type of frame is guaranteed to fit within. pub(crate) const RETIRE_CONNECTION_ID_SIZE_BOUND: usize = 9; /// An unreliable datagram #[derive(Debug, Clone)] pub struct Datagram { /// Payload pub data: Bytes, } impl FrameStruct for Datagram { const SIZE_BOUND: usize = 1 + 8; } impl Datagram { pub(crate) fn encode(&self, length: bool, out: &mut Vec) { out.write(FrameType(*DATAGRAM_TYS.start() | u64::from(length))); // 1 byte if length { // Safe to unwrap because we check length sanity before queueing datagrams out.write(VarInt::from_u64(self.data.len() as u64).unwrap()); // <= 8 bytes } out.extend_from_slice(&self.data); } pub(crate) fn size(&self, length: bool) -> usize { 1 + if length { VarInt::from_u64(self.data.len() as u64).unwrap().size() } else { 0 } + self.data.len() } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) struct AckFrequency { pub(crate) sequence: VarInt, pub(crate) ack_eliciting_threshold: VarInt, pub(crate) request_max_ack_delay: VarInt, pub(crate) reordering_threshold: VarInt, } impl AckFrequency { pub(crate) fn encode(&self, buf: &mut W) { buf.write(FrameType::ACK_FREQUENCY); buf.write(self.sequence); buf.write(self.ack_eliciting_threshold); buf.write(self.request_max_ack_delay); buf.write(self.reordering_threshold); } } #[cfg(test)] mod test { use super::*; use crate::coding::Codec; use assert_matches::assert_matches; fn frames(buf: Vec) -> Vec { Iter::new(Bytes::from(buf)) .unwrap() .collect::, _>>() .unwrap() } #[test] #[allow(clippy::range_plus_one)] fn ack_coding() { const PACKETS: &[u64] = &[1, 2, 3, 5, 10, 11, 14]; let mut ranges = ArrayRangeSet::new(); for &packet in PACKETS { ranges.insert(packet..packet + 1); } let mut buf = Vec::new(); const ECN: EcnCounts = EcnCounts { ect0: 42, ect1: 24, ce: 12, }; Ack::encode(42, &ranges, Some(&ECN), &mut buf); let frames = frames(buf); assert_eq!(frames.len(), 1); match frames[0] { Frame::Ack(ref ack) => { let mut packets = ack.iter().flatten().collect::>(); packets.sort_unstable(); assert_eq!(&packets[..], PACKETS); assert_eq!(ack.ecn, Some(ECN)); } ref x => panic!("incorrect frame {x:?}"), 
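// Worked example (comments only) of what `Ack::encode` produces for the
// PACKETS set above ({1,2,3}, {5}, {10,11}, {14}), matching the gap/size
// arithmetic in `encode` and `AckIter::next`:
//
//     largest = 14, range count - 1 = 3, first size - 1 = 0   // {14}
//     gap = 14 - 12 - 1 = 1, size - 1 = 1                     // {10, 11}
//     gap = 10 - 6  - 1 = 3, size - 1 = 0                     // {5}
//     gap = 5  - 4  - 1 = 0, size - 1 = 2                     // {1, 2, 3}
//
// Decoding walks the same values in reverse: each step subtracts
// `block + gap + 2` from the previous range's smallest packet number.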
} } #[test] fn ack_frequency_coding() { let mut buf = Vec::new(); let original = AckFrequency { sequence: VarInt(42), ack_eliciting_threshold: VarInt(20), request_max_ack_delay: VarInt(50_000), reordering_threshold: VarInt(1), }; original.encode(&mut buf); let frames = frames(buf); assert_eq!(frames.len(), 1); match &frames[0] { Frame::AckFrequency(decoded) => assert_eq!(decoded, &original), x => panic!("incorrect frame {x:?}"), } } #[test] fn immediate_ack_coding() { let mut buf = Vec::new(); FrameType::IMMEDIATE_ACK.encode(&mut buf); let frames = frames(buf); assert_eq!(frames.len(), 1); assert_matches!(&frames[0], Frame::ImmediateAck); } } quinn-proto-0.11.9/src/lib.rs000064400000000000000000000225661046102023000141250ustar 00000000000000//! Low-level protocol logic for the QUIC protoocol //! //! quinn-proto contains a fully deterministic implementation of QUIC protocol logic. It contains //! no networking code and does not get any relevant timestamps from the operating system. Most //! users may want to use the futures-based quinn API instead. //! //! The quinn-proto API might be of interest if you want to use it from a C or C++ project //! through C bindings or if you want to use a different event loop than the one tokio provides. //! //! The most important types are `Endpoint`, which conceptually represents the protocol state for //! a single socket and mostly manages configuration and dispatches incoming datagrams to the //! related `Connection`. `Connection` types contain the bulk of the protocol logic related to //! managing a single connection and all the related state (such as streams). #![cfg_attr(not(fuzzing), warn(missing_docs))] #![cfg_attr(test, allow(dead_code))] // Fixes welcome: #![warn(unreachable_pub)] #![allow(clippy::cognitive_complexity)] #![allow(clippy::too_many_arguments)] #![warn(clippy::use_self)] use std::{ fmt, net::{IpAddr, SocketAddr}, ops, }; mod cid_queue; #[doc(hidden)] pub mod coding; mod constant_time; mod range_set; #[cfg(all(test, any(feature = "rustls-aws-lc-rs", feature = "rustls-ring")))] mod tests; pub mod transport_parameters; mod varint; pub use varint::{VarInt, VarIntBoundsExceeded}; mod connection; pub use crate::connection::{ BytesSource, Chunk, Chunks, ClosedStream, Connection, ConnectionError, ConnectionStats, Datagrams, Event, FinishError, FrameStats, PathStats, ReadError, ReadableError, RecvStream, RttEstimator, SendDatagramError, SendStream, ShouldTransmit, StreamEvent, Streams, UdpStats, WriteError, Written, }; #[cfg(feature = "rustls")] pub use rustls; mod config; pub use config::{ AckFrequencyConfig, ClientConfig, ConfigError, EndpointConfig, IdleTimeout, MtuDiscoveryConfig, ServerConfig, TransportConfig, }; pub mod crypto; mod frame; use crate::frame::Frame; pub use crate::frame::{ApplicationClose, ConnectionClose, Datagram, FrameType}; mod endpoint; pub use crate::endpoint::{ AcceptError, ConnectError, ConnectionHandle, DatagramEvent, Endpoint, Incoming, RetryError, }; mod packet; pub use packet::{ ConnectionIdParser, FixedLengthConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader, ProtectedInitialHeader, }; mod shared; pub use crate::shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent}; mod transport_error; pub use crate::transport_error::{Code as TransportErrorCode, Error as TransportError}; pub mod congestion; mod cid_generator; pub use crate::cid_generator::{ ConnectionIdGenerator, HashedConnectionIdGenerator, InvalidCid, RandomConnectionIdGenerator, }; mod token; use 
token::{ResetToken, RetryToken}; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; // Deal with time #[cfg(not(all(target_family = "wasm", target_os = "unknown")))] pub(crate) use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; #[cfg(all(target_family = "wasm", target_os = "unknown"))] pub(crate) use web_time::{Duration, Instant, SystemTime, UNIX_EPOCH}; #[doc(hidden)] #[cfg(fuzzing)] pub mod fuzzing { pub use crate::connection::{Retransmits, State as ConnectionState, StreamsState}; pub use crate::frame::ResetStream; pub use crate::packet::PartialDecode; pub use crate::transport_parameters::TransportParameters; pub use bytes::{BufMut, BytesMut}; #[cfg(feature = "arbitrary")] use arbitrary::{Arbitrary, Result, Unstructured}; #[cfg(feature = "arbitrary")] impl<'arbitrary> Arbitrary<'arbitrary> for TransportParameters { fn arbitrary(u: &mut Unstructured<'arbitrary>) -> Result { Ok(Self { initial_max_streams_bidi: u.arbitrary()?, initial_max_streams_uni: u.arbitrary()?, ack_delay_exponent: u.arbitrary()?, max_udp_payload_size: u.arbitrary()?, ..Self::default() }) } } #[derive(Debug)] pub struct PacketParams { pub local_cid_len: usize, pub buf: BytesMut, pub grease_quic_bit: bool, } #[cfg(feature = "arbitrary")] impl<'arbitrary> Arbitrary<'arbitrary> for PacketParams { fn arbitrary(u: &mut Unstructured<'arbitrary>) -> Result { let local_cid_len: usize = u.int_in_range(0..=crate::MAX_CID_SIZE)?; let bytes: Vec = Vec::arbitrary(u)?; let mut buf = BytesMut::new(); buf.put_slice(&bytes[..]); Ok(Self { local_cid_len, buf, grease_quic_bit: bool::arbitrary(u)?, }) } } } /// The QUIC protocol version implemented. pub const DEFAULT_SUPPORTED_VERSIONS: &[u32] = &[ 0x00000001, 0xff00_001d, 0xff00_001e, 0xff00_001f, 0xff00_0020, 0xff00_0021, 0xff00_0022, ]; /// Whether an endpoint was the initiator of a connection #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum Side { /// The initiator of a connection Client = 0, /// The acceptor of a connection Server = 1, } impl Side { #[inline] /// Shorthand for `self == Side::Client` pub fn is_client(self) -> bool { self == Self::Client } #[inline] /// Shorthand for `self == Side::Server` pub fn is_server(self) -> bool { self == Self::Server } } impl ops::Not for Side { type Output = Self; fn not(self) -> Self { match self { Self::Client => Self::Server, Self::Server => Self::Client, } } } /// Whether a stream communicates data in both directions or only from the initiator #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum Dir { /// Data flows in both directions Bi = 0, /// Data flows only from the stream's initiator Uni = 1, } impl Dir { fn iter() -> impl Iterator { [Self::Bi, Self::Uni].iter().cloned() } } impl fmt::Display for Dir { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::Dir::*; f.pad(match *self { Bi => "bidirectional", Uni => "unidirectional", }) } } /// Identifier for a stream within a particular connection #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct StreamId(#[doc(hidden)] pub u64); impl fmt::Display for StreamId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let initiator = match self.initiator() { Side::Client => "client", Side::Server => "server", }; let dir = match self.dir() { Dir::Uni => "uni", Dir::Bi => "bi", }; write!( f, "{} {}directional stream 
{}", initiator, dir, self.index() ) } } impl StreamId { /// Create a new StreamId pub fn new(initiator: Side, dir: Dir, index: u64) -> Self { Self(index << 2 | (dir as u64) << 1 | initiator as u64) } /// Which side of a connection initiated the stream pub fn initiator(self) -> Side { if self.0 & 0x1 == 0 { Side::Client } else { Side::Server } } /// Which directions data flows in pub fn dir(self) -> Dir { if self.0 & 0x2 == 0 { Dir::Bi } else { Dir::Uni } } /// Distinguishes streams of the same initiator and directionality pub fn index(self) -> u64 { self.0 >> 2 } } impl From for VarInt { fn from(x: StreamId) -> Self { unsafe { Self::from_u64_unchecked(x.0) } } } impl From for StreamId { fn from(v: VarInt) -> Self { Self(v.0) } } impl coding::Codec for StreamId { fn decode(buf: &mut B) -> coding::Result { VarInt::decode(buf).map(|x| Self(x.into_inner())) } fn encode(&self, buf: &mut B) { VarInt::from_u64(self.0).unwrap().encode(buf); } } /// An outgoing packet #[derive(Debug)] #[must_use] pub struct Transmit { /// The socket this datagram should be sent to pub destination: SocketAddr, /// Explicit congestion notification bits to set on the packet pub ecn: Option, /// Amount of data written to the caller-supplied buffer pub size: usize, /// The segment size if this transmission contains multiple datagrams. /// This is `None` if the transmit only contains a single datagram pub segment_size: Option, /// Optional source IP address for the datagram pub src_ip: Option, } // // Useful internal constants // /// The maximum number of CIDs we bother to issue per connection const LOC_CID_COUNT: u64 = 8; const RESET_TOKEN_SIZE: usize = 16; const MAX_CID_SIZE: usize = 20; const MIN_INITIAL_SIZE: u16 = 1200; /// const INITIAL_MTU: u16 = 1200; const MAX_UDP_PAYLOAD: u16 = 65527; const TIMER_GRANULARITY: Duration = Duration::from_millis(1); /// Maximum number of streams that can be uniquely identified by a stream ID const MAX_STREAM_COUNT: u64 = 1 << 60; quinn-proto-0.11.9/src/packet.rs000064400000000000000000000746071046102023000146310ustar 00000000000000use std::{cmp::Ordering, io, ops::Range, str}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use thiserror::Error; use crate::{ coding::{self, BufExt, BufMutExt}, crypto, ConnectionId, }; /// Decodes a QUIC packet's invariant header /// /// Due to packet number encryption, it is impossible to fully decode a header /// (which includes a variable-length packet number) without crypto context. /// The crypto context (represented by the `Crypto` type in Quinn) is usually /// part of the `Connection`, or can be derived from the destination CID for /// Initial packets. /// /// To cope with this, we decode the invariant header (which should be stable /// across QUIC versions), which gives us the destination CID and allows us /// to inspect the version and packet type (which depends on the version). /// This information allows us to fully decode and decrypt the packet. 
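// Worked example of the `StreamId` bit layout defined in lib.rs above: bit 0
// encodes the initiator (0 = client, 1 = server), bit 1 the direction (0 =
// bidirectional, 1 = unidirectional), and the remaining bits the index. A
// standalone sketch that avoids the crate's types:
fn _stream_id_bits_sketch() {
    fn raw(index: u64, uni: bool, server: bool) -> u64 {
        index << 2 | (uni as u64) << 1 | server as u64
    }
    // Client-initiated bidirectional stream with index 5: 0b10100 = 20.
    assert_eq!(raw(5, false, false), 20);
    // Server-initiated unidirectional stream with index 0: 0b011 = 3.
    assert_eq!(raw(0, true, true), 3);
    // Recover the parts the way `initiator`/`dir`/`index` do.
    let id = raw(5, true, true);
    assert_eq!(id & 0x1, 1); // server-initiated
    assert_eq!((id >> 1) & 0x1, 1); // unidirectional
    assert_eq!(id >> 2, 5); // index
}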
#[cfg_attr(test, derive(Clone))] #[derive(Debug)] pub struct PartialDecode { plain_header: ProtectedHeader, buf: io::Cursor, } #[allow(clippy::len_without_is_empty)] impl PartialDecode { /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet pub fn new( bytes: BytesMut, cid_parser: &(impl ConnectionIdParser + ?Sized), supported_versions: &[u32], grease_quic_bit: bool, ) -> Result<(Self, Option), PacketDecodeError> { let mut buf = io::Cursor::new(bytes); let plain_header = ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?; let dgram_len = buf.get_ref().len(); let packet_len = plain_header .payload_len() .map(|len| (buf.position() + len) as usize) .unwrap_or(dgram_len); match dgram_len.cmp(&packet_len) { Ordering::Equal => Ok((Self { plain_header, buf }, None)), Ordering::Less => Err(PacketDecodeError::InvalidHeader( "packet too short to contain payload length", )), Ordering::Greater => { let rest = Some(buf.get_mut().split_off(packet_len)); Ok((Self { plain_header, buf }, rest)) } } } /// The underlying partially-decoded packet data pub(crate) fn data(&self) -> &[u8] { self.buf.get_ref() } pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> { self.plain_header.as_initial() } pub(crate) fn has_long_header(&self) -> bool { !matches!(self.plain_header, ProtectedHeader::Short { .. }) } pub(crate) fn is_initial(&self) -> bool { self.space() == Some(SpaceId::Initial) } pub(crate) fn space(&self) -> Option { use self::ProtectedHeader::*; match self.plain_header { Initial { .. } => Some(SpaceId::Initial), Long { ty: LongType::Handshake, .. } => Some(SpaceId::Handshake), Long { ty: LongType::ZeroRtt, .. } => Some(SpaceId::Data), Short { .. } => Some(SpaceId::Data), _ => None, } } pub(crate) fn is_0rtt(&self) -> bool { match self.plain_header { ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt, _ => false, } } /// The destination connection ID of the packet pub fn dst_cid(&self) -> &ConnectionId { self.plain_header.dst_cid() } /// Length of QUIC packet being decoded #[allow(unreachable_pub)] // fuzzing only pub fn len(&self) -> usize { self.buf.get_ref().len() } pub(crate) fn finish( self, header_crypto: Option<&dyn crypto::HeaderKey>, ) -> Result { use self::ProtectedHeader::*; let Self { plain_header, mut buf, } = self; if let Initial(ProtectedInitialHeader { dst_cid, src_cid, token_pos, version, .. }) = plain_header { let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?; let header_len = buf.position() as usize; let mut bytes = buf.into_inner(); let header_data = bytes.split_to(header_len).freeze(); let token = header_data.slice(token_pos.start..token_pos.end); return Ok(Packet { header: Header::Initial(InitialHeader { dst_cid, src_cid, token, number, version, }), header_data, payload: bytes, }); } let header = match plain_header { Long { ty, dst_cid, src_cid, version, .. } => Header::Long { ty, dst_cid, src_cid, number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?, version, }, Retry { dst_cid, src_cid, version, } => Header::Retry { dst_cid, src_cid, version, }, Short { spin, dst_cid, .. } => { let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?; let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0; Header::Short { spin, key_phase, dst_cid, number, } } VersionNegotiate { random, dst_cid, src_cid, } => Header::VersionNegotiate { random, dst_cid, src_cid, }, Initial { .. 
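// Comment-only worked example of the coalescing logic in `PartialDecode::new`
// above: suppose a 1200-byte datagram starts with an Initial packet whose
// header ends at cursor position 18 and whose length field says 100. Then
// `packet_len = 18 + 100 = 118 < 1200`, so `split_off(118)` hands bytes
// 118..1200 back to the caller as `rest` for another decode pass. Equal
// lengths mean a lone packet, and a datagram shorter than `packet_len` is
// rejected as "packet too short to contain payload length".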
} => unreachable!(), }; let header_len = buf.position() as usize; let mut bytes = buf.into_inner(); Ok(Packet { header, header_data: bytes.split_to(header_len).freeze(), payload: bytes, }) } fn decrypt_header( buf: &mut io::Cursor, header_crypto: &dyn crypto::HeaderKey, ) -> Result { let packet_length = buf.get_ref().len(); let pn_offset = buf.position() as usize; if packet_length < pn_offset + 4 + header_crypto.sample_size() { return Err(PacketDecodeError::InvalidHeader( "packet too short to extract header protection sample", )); } header_crypto.decrypt(pn_offset, buf.get_mut()); let len = PacketNumber::decode_len(buf.get_ref()[0]); PacketNumber::decode(len, buf) } } pub(crate) struct Packet { pub(crate) header: Header, pub(crate) header_data: Bytes, pub(crate) payload: BytesMut, } impl Packet { pub(crate) fn reserved_bits_valid(&self) -> bool { let mask = match self.header { Header::Short { .. } => SHORT_RESERVED_BITS, _ => LONG_RESERVED_BITS, }; self.header_data[0] & mask == 0 } } pub(crate) struct InitialPacket { pub(crate) header: InitialHeader, pub(crate) header_data: Bytes, pub(crate) payload: BytesMut, } impl From for Packet { fn from(x: InitialPacket) -> Self { Self { header: Header::Initial(x.header), header_data: x.header_data, payload: x.payload, } } } #[cfg_attr(test, derive(Clone))] #[derive(Debug)] pub(crate) enum Header { Initial(InitialHeader), Long { ty: LongType, dst_cid: ConnectionId, src_cid: ConnectionId, number: PacketNumber, version: u32, }, Retry { dst_cid: ConnectionId, src_cid: ConnectionId, version: u32, }, Short { spin: bool, key_phase: bool, dst_cid: ConnectionId, number: PacketNumber, }, VersionNegotiate { random: u8, src_cid: ConnectionId, dst_cid: ConnectionId, }, } impl Header { pub(crate) fn encode(&self, w: &mut Vec) -> PartialEncode { use self::Header::*; let start = w.len(); match *self { Initial(InitialHeader { ref dst_cid, ref src_cid, ref token, number, version, }) => { w.write(u8::from(LongHeaderType::Initial) | number.tag()); w.write(version); dst_cid.encode_long(w); src_cid.encode_long(w); w.write_var(token.len() as u64); w.put_slice(token); w.write::(0); // Placeholder for payload length; see `set_payload_length` number.encode(w); PartialEncode { start, header_len: w.len() - start, pn: Some((number.len(), true)), } } Long { ty, ref dst_cid, ref src_cid, number, version, } => { w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag()); w.write(version); dst_cid.encode_long(w); src_cid.encode_long(w); w.write::(0); // Placeholder for payload length; see `set_payload_length` number.encode(w); PartialEncode { start, header_len: w.len() - start, pn: Some((number.len(), true)), } } Retry { ref dst_cid, ref src_cid, version, } => { w.write(u8::from(LongHeaderType::Retry)); w.write(version); dst_cid.encode_long(w); src_cid.encode_long(w); PartialEncode { start, header_len: w.len() - start, pn: None, } } Short { spin, key_phase, ref dst_cid, number, } => { w.write( FIXED_BIT | if key_phase { KEY_PHASE_BIT } else { 0 } | if spin { SPIN_BIT } else { 0 } | number.tag(), ); w.put_slice(dst_cid); number.encode(w); PartialEncode { start, header_len: w.len() - start, pn: Some((number.len(), false)), } } VersionNegotiate { ref random, ref dst_cid, ref src_cid, } => { w.write(0x80u8 | random); w.write::(0); dst_cid.encode_long(w); src_cid.encode_long(w); PartialEncode { start, header_len: w.len() - start, pn: None, } } } } /// Whether the packet is encrypted on the wire pub(crate) fn is_protected(&self) -> bool { !matches!(*self, Self::Retry { .. 
} | Self::VersionNegotiate { .. }) } pub(crate) fn number(&self) -> Option { use self::Header::*; Some(match *self { Initial(InitialHeader { number, .. }) => number, Long { number, .. } => number, Short { number, .. } => number, _ => { return None; } }) } pub(crate) fn space(&self) -> SpaceId { use self::Header::*; match *self { Short { .. } => SpaceId::Data, Long { ty: LongType::ZeroRtt, .. } => SpaceId::Data, Long { ty: LongType::Handshake, .. } => SpaceId::Handshake, _ => SpaceId::Initial, } } pub(crate) fn key_phase(&self) -> bool { match *self { Self::Short { key_phase, .. } => key_phase, _ => false, } } pub(crate) fn is_short(&self) -> bool { matches!(*self, Self::Short { .. }) } pub(crate) fn is_1rtt(&self) -> bool { self.is_short() } pub(crate) fn is_0rtt(&self) -> bool { matches!( *self, Self::Long { ty: LongType::ZeroRtt, .. } ) } pub(crate) fn dst_cid(&self) -> &ConnectionId { use self::Header::*; match *self { Initial(InitialHeader { ref dst_cid, .. }) => dst_cid, Long { ref dst_cid, .. } => dst_cid, Retry { ref dst_cid, .. } => dst_cid, Short { ref dst_cid, .. } => dst_cid, VersionNegotiate { ref dst_cid, .. } => dst_cid, } } /// Whether the payload of this packet contains QUIC frames pub(crate) fn has_frames(&self) -> bool { use Header::*; match *self { Initial(_) => true, Long { .. } => true, Retry { .. } => false, Short { .. } => true, VersionNegotiate { .. } => false, } } } pub(crate) struct PartialEncode { pub(crate) start: usize, pub(crate) header_len: usize, // Packet number length, payload length needed pn: Option<(usize, bool)>, } impl PartialEncode { pub(crate) fn finish( self, buf: &mut [u8], header_crypto: &dyn crypto::HeaderKey, crypto: Option<(u64, &dyn crypto::PacketKey)>, ) { let Self { header_len, pn, .. } = self; let (pn_len, write_len) = match pn { Some((pn_len, write_len)) => (pn_len, write_len), None => return, }; let pn_pos = header_len - pn_len; if write_len { let len = buf.len() - header_len + pn_len; assert!(len < 2usize.pow(14)); // Fits in reserved space let mut slice = &mut buf[pn_pos - 2..pn_pos]; slice.put_u16(len as u16 | 0b01 << 14); } if let Some((number, crypto)) = crypto { crypto.encrypt(number, buf, header_len); } debug_assert!( pn_pos + 4 + header_crypto.sample_size() <= buf.len(), "packet must be padded to at least {} bytes for header protection sampling", pn_pos + 4 + header_crypto.sample_size() ); header_crypto.encrypt(pn_pos, buf); } } /// Plain packet header #[derive(Clone, Debug)] pub enum ProtectedHeader { /// An Initial packet header Initial(ProtectedInitialHeader), /// A Long packet header, as used during the handshake Long { /// Type of the Long header packet ty: LongType, /// Destination Connection ID dst_cid: ConnectionId, /// Source Connection ID src_cid: ConnectionId, /// Length of the packet payload len: u64, /// QUIC version version: u32, }, /// A Retry packet header Retry { /// Destination Connection ID dst_cid: ConnectionId, /// Source Connection ID src_cid: ConnectionId, /// QUIC version version: u32, }, /// A short packet header, as used during the data phase Short { /// Spin bit spin: bool, /// Destination Connection ID dst_cid: ConnectionId, }, /// A Version Negotiation packet header VersionNegotiate { /// Random value random: u8, /// Destination Connection ID dst_cid: ConnectionId, /// Source Connection ID src_cid: ConnectionId, }, } impl ProtectedHeader { fn as_initial(&self) -> Option<&ProtectedInitialHeader> { match self { Self::Initial(x) => Some(x), _ => None, } } /// The destination Connection ID of the packet 
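// Worked sketch of the length placeholder rewrite in `PartialEncode::finish`
// above: the reserved two bytes are overwritten with a QUIC variable-length
// integer whose top two bits are 0b01, marking a 2-byte encoding (hence the
// assertion that the length fits in 14 bits). Standalone arithmetic check:
fn _payload_len_varint_sketch() {
    let len: u16 = 1200; // packet number + payload + AEAD tag, for example
    assert!(len < 1 << 14); // must fit the 2-byte varint form
    let encoded = (len | 0b01 << 14).to_be_bytes();
    assert_eq!(encoded, [0x44, 0xb0]); // 0x4000 | 0x04b0
}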
pub fn dst_cid(&self) -> &ConnectionId { use self::ProtectedHeader::*; match self { Initial(header) => &header.dst_cid, Long { dst_cid, .. } => dst_cid, Retry { dst_cid, .. } => dst_cid, Short { dst_cid, .. } => dst_cid, VersionNegotiate { dst_cid, .. } => dst_cid, } } fn payload_len(&self) -> Option { use self::ProtectedHeader::*; match self { Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len), _ => None, } } /// Decode a plain header from given buffer, with given [`ConnectionIdParser`]. pub fn decode( buf: &mut io::Cursor, cid_parser: &(impl ConnectionIdParser + ?Sized), supported_versions: &[u32], grease_quic_bit: bool, ) -> Result { let first = buf.get::()?; if !grease_quic_bit && first & FIXED_BIT == 0 { return Err(PacketDecodeError::InvalidHeader("fixed bit unset")); } if first & LONG_HEADER_FORM == 0 { let spin = first & SPIN_BIT != 0; Ok(Self::Short { spin, dst_cid: cid_parser.parse(buf)?, }) } else { let version = buf.get::()?; let dst_cid = ConnectionId::decode_long(buf) .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?; let src_cid = ConnectionId::decode_long(buf) .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?; // TODO: Support long CIDs for compatibility with future QUIC versions if version == 0 { let random = first & !LONG_HEADER_FORM; return Ok(Self::VersionNegotiate { random, dst_cid, src_cid, }); } if !supported_versions.contains(&version) { return Err(PacketDecodeError::UnsupportedVersion { src_cid, dst_cid, version, }); } match LongHeaderType::from_byte(first)? { LongHeaderType::Initial => { let token_len = buf.get_var()? as usize; let token_start = buf.position() as usize; if token_len > buf.remaining() { return Err(PacketDecodeError::InvalidHeader("token out of bounds")); } buf.advance(token_len); let len = buf.get_var()?; Ok(Self::Initial(ProtectedInitialHeader { dst_cid, src_cid, token_pos: token_start..token_start + token_len, len, version, })) } LongHeaderType::Retry => Ok(Self::Retry { dst_cid, src_cid, version, }), LongHeaderType::Standard(ty) => Ok(Self::Long { ty, dst_cid, src_cid, len: buf.get_var()?, version, }), } } } } /// Header of an Initial packet, before decryption #[derive(Clone, Debug)] pub struct ProtectedInitialHeader { /// Destination Connection ID pub dst_cid: ConnectionId, /// Source Connection ID pub src_cid: ConnectionId, /// The position of a token in the packet buffer pub token_pos: Range, /// Length of the packet payload pub len: u64, /// QUIC version pub version: u32, } #[derive(Clone, Debug)] pub(crate) struct InitialHeader { pub(crate) dst_cid: ConnectionId, pub(crate) src_cid: ConnectionId, pub(crate) token: Bytes, pub(crate) number: PacketNumber, pub(crate) version: u32, } // An encoded packet number #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) enum PacketNumber { U8(u8), U16(u16), U24(u32), U32(u32), } impl PacketNumber { pub(crate) fn new(n: u64, largest_acked: u64) -> Self { let range = (n - largest_acked) * 2; if range < 1 << 8 { Self::U8(n as u8) } else if range < 1 << 16 { Self::U16(n as u16) } else if range < 1 << 24 { Self::U24(n as u32) } else if range < 1 << 32 { Self::U32(n as u32) } else { panic!("packet number too large to encode") } } pub(crate) fn len(self) -> usize { use self::PacketNumber::*; match self { U8(_) => 1, U16(_) => 2, U24(_) => 3, U32(_) => 4, } } pub(crate) fn encode(self, w: &mut W) { use self::PacketNumber::*; match self { U8(x) => w.write(x), U16(x) => w.write(x), U24(x) => w.put_uint(u64::from(x), 3), U32(x) => w.write(x), } } pub(crate) fn 
decode(len: usize, r: &mut R) -> Result { use self::PacketNumber::*; let pn = match len { 1 => U8(r.get()?), 2 => U16(r.get()?), 3 => U24(r.get_uint(3) as u32), 4 => U32(r.get()?), _ => unreachable!(), }; Ok(pn) } pub(crate) fn decode_len(tag: u8) -> usize { 1 + (tag & 0x03) as usize } fn tag(self) -> u8 { use self::PacketNumber::*; match self { U8(_) => 0b00, U16(_) => 0b01, U24(_) => 0b10, U32(_) => 0b11, } } pub(crate) fn expand(self, expected: u64) -> u64 { // From Appendix A use self::PacketNumber::*; let truncated = match self { U8(x) => u64::from(x), U16(x) => u64::from(x), U24(x) => u64::from(x), U32(x) => u64::from(x), }; let nbits = self.len() * 8; let win = 1 << nbits; let hwin = win / 2; let mask = win - 1; // The incoming packet number should be greater than expected - hwin and less than or equal // to expected + hwin // // This means we can't just strip the trailing bits from expected and add the truncated // because that might yield a value outside the window. // // The following code calculates a candidate value and makes sure it's within the packet // number window. let candidate = (expected & !mask) | truncated; if expected.checked_sub(hwin).map_or(false, |x| candidate <= x) { candidate + win } else if candidate > expected + hwin && candidate > win { candidate - win } else { candidate } } } /// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length pub struct FixedLengthConnectionIdParser { expected_len: usize, } impl FixedLengthConnectionIdParser { /// Create a new instance of `FixedLengthConnectionIdParser` pub fn new(expected_len: usize) -> Self { Self { expected_len } } } impl ConnectionIdParser for FixedLengthConnectionIdParser { fn parse(&self, buffer: &mut dyn Buf) -> Result { (buffer.remaining() >= self.expected_len) .then(|| ConnectionId::from_buf(buffer, self.expected_len)) .ok_or(PacketDecodeError::InvalidHeader("packet too small")) } } /// Parse connection id in short header packet pub trait ConnectionIdParser { /// Parse a connection id from given buffer fn parse(&self, buf: &mut dyn Buf) -> Result; } /// Long packet type including non-uniform cases #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongHeaderType { Initial, Retry, Standard(LongType), } impl LongHeaderType { fn from_byte(b: u8) -> Result { use self::{LongHeaderType::*, LongType::*}; debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet"); Ok(match (b & 0x30) >> 4 { 0x0 => Initial, 0x1 => Standard(ZeroRtt), 0x2 => Standard(Handshake), 0x3 => Retry, _ => unreachable!(), }) } } impl From for u8 { fn from(ty: LongHeaderType) -> Self { use self::{LongHeaderType::*, LongType::*}; match ty { Initial => LONG_HEADER_FORM | FIXED_BIT, Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4), Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4), Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4), } } } /// Long packet types with uniform header structure #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum LongType { /// Handshake packet Handshake, /// 0-RTT packet ZeroRtt, } /// Packet decode error #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum PacketDecodeError { /// Packet uses a QUIC version that is not supported #[error("unsupported version {version:x}")] UnsupportedVersion { /// Source Connection ID src_cid: ConnectionId, /// Destination Connection ID dst_cid: ConnectionId, /// The version that was unsupported version: u32, }, /// The packet header is invalid #[error("invalid header: {0}")] 
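// Comment-only worked example of `PacketNumber::expand` above, using the
// numbers from RFC 9000 Appendix A: with `expected = 0xa82f30e2` and a
// received 2-byte packet number `0x9b32`,
//
//     mask = 0xffff, hwin = 0x8000
//     candidate = (0xa82f30e2 & !0xffff) | 0x9b32 = 0xa82f9b32
//
// The candidate already lies within (expected - hwin, expected + hwin], so
// neither window correction fires and 0xa82f9b32 is returned unchanged.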
InvalidHeader(&'static str), } impl From for PacketDecodeError { fn from(_: coding::UnexpectedEnd) -> Self { Self::InvalidHeader("unexpected end of packet") } } pub(crate) const LONG_HEADER_FORM: u8 = 0x80; pub(crate) const FIXED_BIT: u8 = 0x40; pub(crate) const SPIN_BIT: u8 = 0x20; const SHORT_RESERVED_BITS: u8 = 0x18; const LONG_RESERVED_BITS: u8 = 0x0c; const KEY_PHASE_BIT: u8 = 0x04; /// Packet number space identifiers #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum SpaceId { /// Unprotected packets, used to bootstrap the handshake Initial = 0, Handshake = 1, /// Application data space, used for 0-RTT and post-handshake/1-RTT packets Data = 2, } impl SpaceId { pub fn iter() -> impl Iterator { [Self::Initial, Self::Handshake, Self::Data].iter().cloned() } } #[cfg(test)] mod tests { use super::*; use hex_literal::hex; use std::io; fn check_pn(typed: PacketNumber, encoded: &[u8]) { let mut buf = Vec::new(); typed.encode(&mut buf); assert_eq!(&buf[..], encoded); let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap(); assert_eq!(typed, decoded); } #[test] fn roundtrip_packet_numbers() { check_pn(PacketNumber::U8(0x7f), &hex!("7f")); check_pn(PacketNumber::U16(0x80), &hex!("0080")); check_pn(PacketNumber::U16(0x3fff), &hex!("3fff")); check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000")); check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff")); } #[test] fn pn_encode() { check_pn(PacketNumber::new(0x10, 0), &hex!("10")); check_pn(PacketNumber::new(0x100, 0), &hex!("0100")); check_pn(PacketNumber::new(0x10000, 0), &hex!("010000")); } #[test] fn pn_expand_roundtrip() { for expected in 0..1024 { for actual in expected..1024 { assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected)); } } } #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] #[test] fn header_encoding() { use crate::crypto::rustls::{initial_keys, initial_suite_from_provider}; use crate::Side; #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))] use rustls::crypto::aws_lc_rs::default_provider; #[cfg(feature = "rustls-ring")] use rustls::crypto::ring::default_provider; use rustls::quic::Version; let dcid = ConnectionId::new(&hex!("06b858ec6f80452b")); let provider = default_provider(); let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap(); let client = initial_keys(Version::V1, &dcid, Side::Client, &suite); let mut buf = Vec::new(); let header = Header::Initial(InitialHeader { number: PacketNumber::U8(0), src_cid: ConnectionId::new(&[]), dst_cid: dcid, token: Bytes::new(), version: crate::DEFAULT_SUPPORTED_VERSIONS[0], }); let encode = header.encode(&mut buf); let header_len = buf.len(); buf.resize(header_len + 16 + client.packet.local.tag_len(), 0); encode.finish( &mut buf, &*client.header.local, Some((0, &*client.packet.local)), ); for byte in &buf { print!("{byte:02x}"); } println!(); assert_eq!( buf[..], hex!( "c8000000010806b858ec6f80452b00004021be 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1" )[..] ); let server = initial_keys(Version::V1, &dcid, Side::Server, &suite); let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec(); let decode = PartialDecode::new( buf.as_slice().into(), &FixedLengthConnectionIdParser::new(0), &supported_versions, false, ) .unwrap() .0; let mut packet = decode.finish(Some(&*server.header.remote)).unwrap(); assert_eq!( packet.header_data[..], hex!("c0000000010806b858ec6f80452b0000402100")[..] 
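// Comment-only reading of the first (unprotected) byte 0xc0 asserted above,
// using the bit constants defined earlier in this file:
//
//     0xc0 = LONG_HEADER_FORM (0x80) | FIXED_BIT (0x40)
//     (0xc0 & 0x30) >> 4 = 0  => LongHeaderType::Initial
//     0xc0 & LONG_RESERVED_BITS (0x0c) = 0  => reserved bits valid
//     (0xc0 & 0x03) + 1 = 1  => 1-byte packet number
//
// The on-the-wire byte 0xc8 seen in the encrypted buffer differs only in the
// low bits scrambled by header protection.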
); server .packet .remote .decrypt(0, &packet.header_data, &mut packet.payload) .unwrap(); assert_eq!(packet.payload[..], [0; 16]); match packet.header { Header::Initial(InitialHeader { number: PacketNumber::U8(0), .. }) => {} _ => { panic!("unexpected header {:?}", packet.header); } } } } quinn-proto-0.11.9/src/range_set/array_range_set.rs000064400000000000000000000147041046102023000204660ustar 00000000000000use std::ops::Range; use tinyvec::TinyVec; /// A set of u64 values optimized for long runs and random insert/delete/contains /// /// `ArrayRangeSet` uses an array representation, where each array entry represents /// a range. /// /// The array-based RangeSet provides 2 benefits: /// - There exists an inline representation, which avoids the need of heap /// allocating ACK ranges for SentFrames for small ranges. /// - Iterating over ranges should usually be faster since there is only /// a single cache-friendly contiguous range. /// /// `ArrayRangeSet` is especially useful for tracking ACK ranges where the amount /// of ranges is usually very low (since ACK numbers are in consecutive fashion /// unless reordering or packet loss occur). #[derive(Debug, Default)] pub struct ArrayRangeSet(TinyVec<[Range; ARRAY_RANGE_SET_INLINE_CAPACITY]>); /// The capacity of elements directly stored in [`ArrayRangeSet`] /// /// An inline capacity of 2 is chosen to keep `SentFrame` below 128 bytes. const ARRAY_RANGE_SET_INLINE_CAPACITY: usize = 2; impl Clone for ArrayRangeSet { fn clone(&self) -> Self { // tinyvec keeps the heap representation after clones. // We rather prefer the inline representation for clones if possible, // since clones (e.g. for storage in `SentFrames`) are rarely mutated if self.0.is_inline() || self.0.len() > ARRAY_RANGE_SET_INLINE_CAPACITY { return Self(self.0.clone()); } let mut vec = TinyVec::new(); vec.extend_from_slice(self.0.as_slice()); Self(vec) } } impl ArrayRangeSet { pub fn new() -> Self { Default::default() } pub fn iter(&self) -> impl DoubleEndedIterator> + '_ { self.0.iter().cloned() } pub fn elts(&self) -> impl Iterator + '_ { self.iter().flatten() } pub fn len(&self) -> usize { self.0.len() } pub fn contains(&self, x: u64) -> bool { for range in self.0.iter() { if range.start > x { // We only get here if there was no prior range that contained x return false; } else if range.contains(&x) { return true; } } false } pub fn subtract(&mut self, other: &Self) { // TODO: This can potentially be made more efficient, since the we know // individual ranges are not overlapping, and the next range must start // after the last one finished for range in &other.0 { self.remove(range.clone()); } } pub fn insert_one(&mut self, x: u64) -> bool { self.insert(x..x + 1) } pub fn insert(&mut self, x: Range) -> bool { let mut result = false; if x.is_empty() { // Don't try to deal with ranges where x.end <= x.start return false; } let mut idx = 0; while idx != self.0.len() { let range = &mut self.0[idx]; if range.start > x.end { // The range is fully before this range and therefore not extensible. // Add a new range to the left self.0.insert(idx, x); return true; } else if range.start > x.start { // The new range starts before this range but overlaps. 
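// (Worked example of `insert` overall, in comments: with an existing set
// {0..2, 4..6}, inserting 1..5 takes neither of the "new range starts
// before" branches; instead the overlap arm further down grows 0..2 into
// 0..5, and the follow-up merge loop absorbs 4..6, leaving the single
// range {0..6} and returning `true`.)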
                // Extend the current range to the left
                // Note that we don't have to merge a potential left range, since
                // this case would have been captured by merging the right range
                // in the previous loop iteration
                result = true;
                range.start = x.start;
            }

            // At this point we have handled all parts of the new range which
            // are in front of the current range. Now we handle everything from
            // the start of the current range
            if x.end <= range.end {
                // Fully contained
                return result;
            } else if x.start <= range.end {
                // Extend the current range to the end of the new range.
                // Since it's not contained it must be bigger
                range.end = x.end;

                // Merge all follow-up ranges which overlap
                while idx != self.0.len() - 1 {
                    let curr = self.0[idx].clone();
                    let next = self.0[idx + 1].clone();
                    if curr.end >= next.start {
                        self.0[idx].end = next.end.max(curr.end);
                        self.0.remove(idx + 1);
                    } else {
                        break;
                    }
                }
                return true;
            }

            idx += 1;
        }

        // Insert a range at the end
        self.0.push(x);
        true
    }

    pub fn remove(&mut self, x: Range<u64>) -> bool {
        let mut result = false;

        if x.is_empty() {
            // Don't try to deal with ranges where x.end <= x.start
            return false;
        }

        let mut idx = 0;
        while idx != self.0.len() && x.start != x.end {
            let range = self.0[idx].clone();

            if x.end <= range.start {
                // The range is fully before this range
                return result;
            } else if x.start >= range.end {
                // The range is fully after this range
                idx += 1;
                continue;
            }

            // The range overlaps with this range
            result = true;
            let left = range.start..x.start;
            let right = x.end..range.end;

            if left.is_empty() && right.is_empty() {
                self.0.remove(idx);
            } else if left.is_empty() {
                self.0[idx] = right;
                idx += 1;
            } else if right.is_empty() {
                self.0[idx] = left;
                idx += 1;
            } else {
                self.0[idx] = right;
                self.0.insert(idx, left);
                idx += 2;
            }
        }

        result
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    pub fn pop_min(&mut self) -> Option<Range<u64>> {
        if !self.0.is_empty() {
            Some(self.0.remove(0))
        } else {
            None
        }
    }

    pub fn min(&self) -> Option<u64> {
        self.iter().next().map(|x| x.start)
    }

    pub fn max(&self) -> Option<u64> {
        self.iter().next_back().map(|x| x.end - 1)
    }
}
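// A minimal test-style sketch (using only the APIs defined above) of the
// clone behavior described on `Clone for ArrayRangeSet`: a heap-backed set
// with few enough ranges clones back into the inline representation. The
// specific values are illustrative only.
#[cfg(test)]
mod clone_repr_sketch {
    use super::*;

    #[test]
    fn clone_prefers_inline() {
        let mut set = ArrayRangeSet::new();
        // Three disjoint ranges force a spill to the heap (inline capacity is 2)
        set.insert(0..1);
        set.insert(2..3);
        set.insert(4..5);
        assert!(!set.0.is_inline());
        // Dropping back to two ranges leaves the original on the heap...
        set.remove(2..3);
        assert!(!set.0.is_inline());
        // ...but a clone is rebuilt in the inline representation
        assert!(set.clone().0.is_inline());
    }
}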
quinn-proto-0.11.9/src/range_set/btree_range_set.rs

use std::{
    cmp,
    cmp::Ordering,
    collections::{btree_map, BTreeMap},
    ops::{
        Bound::{Excluded, Included},
        Range,
    },
};

/// A set of u64 values optimized for long runs and random insert/delete/contains
#[derive(Debug, Default, Clone)]
pub struct RangeSet(BTreeMap<u64, u64>);

impl RangeSet {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn contains(&self, x: u64) -> bool {
        self.pred(x).map_or(false, |(_, end)| end > x)
    }

    pub fn insert_one(&mut self, x: u64) -> bool {
        if let Some((start, end)) = self.pred(x) {
            match end.cmp(&x) {
                // Wholly contained
                Ordering::Greater => {
                    return false;
                }
                Ordering::Equal => {
                    // Extend existing
                    self.0.remove(&start);
                    let mut new_end = x + 1;
                    if let Some((next_start, next_end)) = self.succ(x) {
                        if next_start == new_end {
                            self.0.remove(&next_start);
                            new_end = next_end;
                        }
                    }
                    self.0.insert(start, new_end);
                    return true;
                }
                _ => {}
            }
        }
        let mut new_end = x + 1;
        if let Some((next_start, next_end)) = self.succ(x) {
            if next_start == new_end {
                self.0.remove(&next_start);
                new_end = next_end;
            }
        }
        self.0.insert(x, new_end);
        true
    }

    pub fn insert(&mut self, mut x: Range<u64>) -> bool {
        if x.is_empty() {
            return false;
        }
        if let Some((start, end)) = self.pred(x.start) {
            if end >= x.end {
                // Wholly contained
                return false;
            } else if end >= x.start {
                // Extend overlapping predecessor
                self.0.remove(&start);
                x.start = start;
            }
        }
        while let Some((next_start, next_end)) = self.succ(x.start) {
            if next_start > x.end {
                break;
            }
            // Overlaps with successor
            self.0.remove(&next_start);
            x.end = cmp::max(next_end, x.end);
        }
        self.0.insert(x.start, x.end);
        true
    }

    /// Find closest range to `x` that begins at or before it
    fn pred(&self, x: u64) -> Option<(u64, u64)> {
        self.0
            .range((Included(0), Included(x)))
            .next_back()
            .map(|(&x, &y)| (x, y))
    }

    /// Find the closest range to `x` that begins after it
    fn succ(&self, x: u64) -> Option<(u64, u64)> {
        self.0
            .range((Excluded(x), Included(u64::MAX)))
            .next()
            .map(|(&x, &y)| (x, y))
    }

    pub fn remove(&mut self, x: Range<u64>) -> bool {
        if x.is_empty() {
            return false;
        }

        let before = match self.pred(x.start) {
            Some((start, end)) if end > x.start => {
                self.0.remove(&start);
                if start < x.start {
                    self.0.insert(start, x.start);
                }
                if end > x.end {
                    self.0.insert(x.end, end);
                }
                // Short-circuit if we cannot possibly overlap with another range
                if end >= x.end {
                    return true;
                }
                true
            }
            Some(_) | None => false,
        };

        let mut after = false;
        while let Some((start, end)) = self.succ(x.start) {
            if start >= x.end {
                break;
            }
            after = true;
            self.0.remove(&start);
            if end > x.end {
                self.0.insert(x.end, end);
                break;
            }
        }

        before || after
    }

    /// Add a range to the set, returning the intersection of current ranges with the new one
    pub fn replace(&mut self, mut range: Range<u64>) -> Replace<'_> {
        let pred = if let Some((prev_start, prev_end)) = self
            .pred(range.start)
            .filter(|&(_, end)| end >= range.start)
        {
            self.0.remove(&prev_start);
            let replaced_start = range.start;
            range.start = range.start.min(prev_start);
            let replaced_end = range.end.min(prev_end);
            range.end = range.end.max(prev_end);
            if replaced_start != replaced_end {
                Some(replaced_start..replaced_end)
            } else {
                None
            }
        } else {
            None
        };
        Replace {
            set: self,
            range,
            pred,
        }
    }

    pub fn add(&mut self, other: &Self) {
        for (&start, &end) in &other.0 {
            self.insert(start..end);
        }
    }

    pub fn subtract(&mut self, other: &Self) {
        for (&start, &end) in &other.0 {
            self.remove(start..end);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    pub fn min(&self) -> Option<u64> {
        self.0.first_key_value().map(|(&start, _)| start)
    }

    pub fn max(&self) -> Option<u64> {
        self.0.last_key_value().map(|(_, &end)| end - 1)
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn iter(&self) -> Iter<'_> {
        Iter(self.0.iter())
    }

    pub fn elts(&self) -> EltIter<'_> {
        EltIter {
            inner: self.0.iter(),
            next: 0,
            end: 0,
        }
    }

    pub fn peek_min(&self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.iter().next()?;
        Some(start..end)
    }

    pub fn pop_min(&mut self) -> Option<Range<u64>> {
        let result = self.peek_min()?;
        self.0.remove(&result.start);
        Some(result)
    }
}

pub struct Iter<'a>(btree_map::Iter<'a, u64, u64>);

impl Iterator for Iter<'_> {
    type Item = Range<u64>;
    fn next(&mut self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.next()?;
        Some(start..end)
    }
}

impl DoubleEndedIterator for Iter<'_> {
    fn next_back(&mut self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.next_back()?;
        Some(start..end)
    }
}

impl<'a> IntoIterator for &'a RangeSet {
    type Item = Range<u64>;
    type IntoIter = Iter<'a>;
    fn into_iter(self) -> Iter<'a> {
        self.iter()
    }
}

pub struct EltIter<'a> {
    inner: btree_map::Iter<'a, u64, u64>,
    next: u64,
    end: u64,
}

impl Iterator for EltIter<'_> {
    type Item = u64;
    fn next(&mut self) -> Option<u64> {
        if self.next == self.end {
            let (&start, &end) = self.inner.next()?;
            self.next = start;
            self.end = end;
        }
        let x = self.next;
        self.next += 1;
        Some(x)
    }
}
impl DoubleEndedIterator for EltIter<'_> {
    fn next_back(&mut self) -> Option<u64> {
        if self.next == self.end {
            let (&start, &end) = self.inner.next_back()?;
            self.next = start;
            self.end = end;
        }
        self.end -= 1;
        Some(self.end)
    }
}

/// Iterator returned by `RangeSet::replace`
pub struct Replace<'a> {
    set: &'a mut RangeSet,
    /// Portion of the intersection arising from a range beginning at or before the newly inserted
    /// range
    pred: Option<Range<u64>>,
    /// Union of the input range and all ranges that have been visited by the iterator so far
    range: Range<u64>,
}

impl Iterator for Replace<'_> {
    type Item = Range<u64>;
    fn next(&mut self) -> Option<Range<u64>> {
        if let Some(pred) = self.pred.take() {
            // If a range starting before the inserted range overlapped with it, return the
            // corresponding overlap first
            return Some(pred);
        }

        let (next_start, next_end) = self.set.succ(self.range.start)?;
        if next_start > self.range.end {
            // If the next successor range starts after the current range ends, there can be no
            // more overlaps. This is sound even when `self.range.end` is increased because
            // `RangeSet` is guaranteed not to contain pairs of ranges that could be simplified.
            return None;
        }
        // Remove the redundant range...
        self.set.0.remove(&next_start);
        // ...and handle the case where the redundant range ends later than the new range.
        let replaced_end = self.range.end.min(next_end);
        self.range.end = self.range.end.max(next_end);
        if next_start == replaced_end {
            // If the redundant range started exactly where the new range ended, there was no
            // overlap with it or any later range.
            None
        } else {
            Some(next_start..replaced_end)
        }
    }
}

impl Drop for Replace<'_> {
    fn drop(&mut self) {
        // Ensure we drain all remaining overlapping ranges
        for _ in &mut *self {}
        // Insert the final aggregate range
        self.set.0.insert(self.range.start, self.range.end);
    }
}

/// This module contains tests which only apply for this `RangeSet` implementation
///
/// Tests which apply for all implementations can be found in the `tests.rs` module
#[cfg(test)]
mod tests {
    #![allow(clippy::single_range_in_vec_init)] // https://github.com/rust-lang/rust-clippy/issues/11086
    use super::*;

    #[test]
    fn replace_contained() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(1..5).collect::<Vec<_>>(), &[2..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..5);
    }

    #[test]
    fn replace_contains() {
        let mut set = RangeSet::new();
        set.insert(1..5);
        assert_eq!(set.replace(2..4).collect::<Vec<_>>(), &[2..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..5);
    }

    #[test]
    fn replace_pred() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(3..5).collect::<Vec<_>>(), &[3..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 2..5);
    }

    #[test]
    fn replace_succ() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(1..3).collect::<Vec<_>>(), &[2..3]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..4);
    }

    #[test]
    fn replace_exact_pred() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(4..6).collect::<Vec<_>>(), &[]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 2..6);
    }

    #[test]
    fn replace_exact_succ() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(0..2).collect::<Vec<_>>(), &[]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 0..4);
    }
}
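// A minimal sketch of the `Drop`-based contract of `RangeSet::replace`
// implemented above: even if the caller never iterates the returned
// `Replace`, dropping it still drains the overlaps and inserts the aggregate
// range. The values used are illustrative only.
#[cfg(test)]
mod replace_drop_sketch {
    use super::RangeSet;

    #[test]
    fn replace_applies_even_if_unconsumed() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        set.insert(6..8);
        // Drop the iterator without consuming the reported overlaps
        drop(set.replace(3..7));
        // The set was still updated to the merged range
        assert_eq!(set.peek_min().unwrap(), 2..8);
        assert_eq!(set.len(), 1);
    }
}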
quinn-proto-0.11.9/src/range_set/mod.rs

mod array_range_set;
mod btree_range_set;

#[cfg(test)]
mod tests;

pub(crate) use array_range_set::ArrayRangeSet;
pub(crate) use btree_range_set::RangeSet;

quinn-proto-0.11.9/src/range_set/tests.rs

use std::ops::Range;

use super::*;

macro_rules! common_set_tests {
    ($set_name:ident, $set_type:ident) => {
        mod $set_name {
            use super::*;

            #[test]
            fn merge_and_split() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(2..4));
                assert!(!set.insert(1..3));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3]);
                assert!(!set.contains(4));
                assert!(set.remove(2..3));
                assert_eq!(set.len(), 2);
                assert!(!set.contains(2));
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 3]);
            }

            #[test]
            fn double_merge_exact() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(2..4));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3, 4, 5]);
            }

            #[test]
            fn single_merge_low() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(2..3));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 4, 5]);
            }

            #[test]
            fn single_merge_high() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(3..4));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 3, 4, 5]);
            }

            #[test]
            fn double_merge_wide() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(1..5));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3, 4, 5]);
            }

            #[test]
            fn double_remove() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert!(set.remove(1..5));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 5]);
            }

            #[test]
            fn insert_multiple() {
                let mut set = $set_type::new();
                assert!(set.insert(0..1));
                assert!(set.insert(2..3));
                assert!(set.insert(4..5));
                assert!(set.insert(0..5));
                assert_eq!(set.len(), 1);
            }

            #[test]
            fn remove_multiple() {
                let mut set = $set_type::new();
                assert!(set.insert(0..1));
                assert!(set.insert(2..3));
                assert!(set.insert(4..5));
                assert!(set.remove(0..5));
                assert!(set.is_empty());
            }

            #[test]
            fn double_insert() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(!set.insert(0..2));
                assert!(set.insert(2..4));
                assert!(!set.insert(2..4));
                assert!(!set.insert(0..4));
                assert!(!set.insert(1..2));
                assert!(!set.insert(1..3));
                assert!(!set.insert(1..4));
                assert_eq!(set.len(), 1);
            }

            #[test]
            fn skip_empty_ranges() {
                let mut set = $set_type::new();
                assert!(!set.insert(2..2));
                assert_eq!(set.len(), 0);
                assert!(!set.insert(4..4));
                assert_eq!(set.len(), 0);
                assert!(!set.insert(0..0));
                assert_eq!(set.len(), 0);
            }

            #[test]
            fn compare_insert_to_reference() {
                const MAX_RANGE: u64 = 50;
                for start in 0..=MAX_RANGE {
                    for end in 0..=MAX_RANGE {
                        println!("insert({}..{})", start, end);
                        let (mut set, mut reference) = create_initial_sets(MAX_RANGE);
                        assert_eq!(set.insert(start..end), reference.insert(start..end));
                        assert_sets_equal(&set, &reference);
                    }
                }
            }

            #[test]
            fn compare_remove_to_reference() {
                const MAX_RANGE: u64 = 50;
                for start in 0..=MAX_RANGE {
                    for end in 0..=MAX_RANGE {
                        println!("remove({}..{})", start, end);
                        let (mut set, mut reference) = create_initial_sets(MAX_RANGE);
                        assert_eq!(set.remove(start..end), reference.remove(start..end));
                        assert_sets_equal(&set, &reference);
                    }
                }
            }
            #[test]
            fn min_max() {
                let mut set = $set_type::new();
                set.insert(1..3);
                set.insert(4..5);
                set.insert(6..10);
                assert_eq!(set.min(), Some(1));
                assert_eq!(set.max(), Some(9));
            }

            fn create_initial_sets(max_range: u64) -> ($set_type, RefRangeSet) {
                let mut set = $set_type::new();
                let mut reference = RefRangeSet::new(max_range as usize);
                assert_sets_equal(&set, &reference);

                assert_eq!(set.insert(2..6), reference.insert(2..6));
                assert_eq!(set.insert(10..14), reference.insert(10..14));
                assert_eq!(set.insert(14..14), reference.insert(14..14));
                assert_eq!(set.insert(18..19), reference.insert(18..19));
                assert_eq!(set.insert(20..21), reference.insert(20..21));
                assert_eq!(set.insert(22..24), reference.insert(22..24));
                assert_eq!(set.insert(26..30), reference.insert(26..30));
                assert_eq!(set.insert(34..38), reference.insert(34..38));
                assert_eq!(set.insert(42..44), reference.insert(42..44));
                assert_sets_equal(&set, &reference);

                (set, reference)
            }

            fn assert_sets_equal(set: &$set_type, reference: &RefRangeSet) {
                assert_eq!(set.len(), reference.len());
                assert_eq!(set.is_empty(), reference.is_empty());
                assert_eq!(set.elts().collect::<Vec<_>>()[..], reference.elts()[..]);
            }
        }
    };
}

common_set_tests!(range_set, RangeSet);
common_set_tests!(array_range_set, ArrayRangeSet);

/// A very simple reference implementation of a RangeSet
struct RefRangeSet {
    data: Vec<bool>,
}

impl RefRangeSet {
    fn new(capacity: usize) -> Self {
        Self {
            data: vec![false; capacity],
        }
    }

    fn len(&self) -> usize {
        let mut last = false;
        let mut count = 0;
        for v in self.data.iter() {
            if !last && *v {
                count += 1;
            }
            last = *v;
        }
        count
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    fn insert(&mut self, x: Range<u64>) -> bool {
        let mut result = false;
        assert!(x.end <= self.data.len() as u64);
        for i in x {
            let i = i as usize;
            if !self.data[i] {
                result = true;
                self.data[i] = true;
            }
        }
        result
    }

    fn remove(&mut self, x: Range<u64>) -> bool {
        let mut result = false;
        assert!(x.end <= self.data.len() as u64);
        for i in x {
            let i = i as usize;
            if self.data[i] {
                result = true;
                self.data[i] = false;
            }
        }
        result
    }

    fn elts(&self) -> Vec<u64> {
        self.data
            .iter()
            .enumerate()
            .filter_map(|(i, e)| if *e { Some(i as u64) } else { None })
            .collect()
    }
}
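// A sketch of the same differential idea applied across the two real
// implementations: drive `RangeSet` and `ArrayRangeSet` with one
// pseudo-random operation sequence and require identical observable state.
// Illustrative only; the exhaustive loops above already cover small inputs.
// Assumes the `rand` dependency declared in Cargo.toml.
#[cfg(test)]
mod cross_impl_sketch {
    use super::*;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn random_ops_agree() {
        let mut rng = StdRng::seed_from_u64(42);
        let (mut a, mut b) = (RangeSet::new(), ArrayRangeSet::new());
        for _ in 0..1000 {
            let start: u64 = rng.gen_range(0..64);
            let end = rng.gen_range(start..=64);
            if rng.gen_bool(0.7) {
                assert_eq!(a.insert(start..end), b.insert(start..end));
            } else {
                assert_eq!(a.remove(start..end), b.remove(start..end));
            }
            assert_eq!(
                a.elts().collect::<Vec<_>>(),
                b.elts().collect::<Vec<_>>()
            );
        }
    }
}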
quinn-proto-0.11.9/src/shared.rs

use std::{fmt, net::SocketAddr};

use bytes::{Buf, BufMut, BytesMut};

use crate::{coding::BufExt, packet::PartialDecode, Instant, ResetToken, MAX_CID_SIZE};

/// Events sent from an Endpoint to a Connection
#[derive(Debug)]
pub struct ConnectionEvent(pub(crate) ConnectionEventInner);

#[derive(Debug)]
pub(crate) enum ConnectionEventInner {
    /// A datagram has been received for the Connection
    Datagram(DatagramConnectionEvent),
    /// New connection identifiers have been issued for the Connection
    NewIdentifiers(Vec<IssuedCid>, Instant),
}

/// Variant of [`ConnectionEventInner`].
#[derive(Debug)]
pub(crate) struct DatagramConnectionEvent {
    pub(crate) now: Instant,
    pub(crate) remote: SocketAddr,
    pub(crate) ecn: Option<EcnCodepoint>,
    pub(crate) first_decode: PartialDecode,
    pub(crate) remaining: Option<BytesMut>,
}

/// Events sent from a Connection to an Endpoint
#[derive(Debug)]
pub struct EndpointEvent(pub(crate) EndpointEventInner);

impl EndpointEvent {
    /// Construct an event indicating that a `Connection` will no longer emit events
    ///
    /// Useful for notifying an `Endpoint` that a `Connection` has been destroyed outside of the
    /// usual state machine flow, e.g. when being dropped by the user.
    pub fn drained() -> Self {
        Self(EndpointEventInner::Drained)
    }

    /// Determine whether this is the last event a `Connection` will emit
    ///
    /// Useful for determining when connection-related event loop state can be freed.
    pub fn is_drained(&self) -> bool {
        self.0 == EndpointEventInner::Drained
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum EndpointEventInner {
    /// The connection has been drained
    Drained,
    /// The reset token and/or address eligible for generating resets has been updated
    ResetToken(SocketAddr, ResetToken),
    /// The connection needs connection identifiers
    NeedIdentifiers(Instant, u64),
    /// Stop routing connection ID for this sequence number to the connection
    /// When `bool == true`, a new connection ID will be issued to the peer
    RetireConnectionId(Instant, u64, bool),
}

/// Protocol-level identifier for a connection.
///
/// Mainly useful for identifying this connection's packets on the wire with tools like Wireshark.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ConnectionId {
    /// length of CID
    len: u8,
    /// CID in byte array
    bytes: [u8; MAX_CID_SIZE],
}

impl ConnectionId {
    /// Construct cid from byte array
    pub fn new(bytes: &[u8]) -> Self {
        debug_assert!(bytes.len() <= MAX_CID_SIZE);
        let mut res = Self {
            len: bytes.len() as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        res.bytes[..bytes.len()].copy_from_slice(bytes);
        res
    }

    /// Constructs cid by reading `len` bytes from a `Buf`
    ///
    /// Callers must ensure that `buf.remaining() >= len`
    pub fn from_buf(buf: &mut (impl Buf + ?Sized), len: usize) -> Self {
        debug_assert!(len <= MAX_CID_SIZE);
        let mut res = Self {
            len: len as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        buf.copy_to_slice(&mut res[..len]);
        res
    }

    /// Decode from long header format
    pub(crate) fn decode_long(buf: &mut impl Buf) -> Option<Self> {
        let len = buf.get::<u8>().ok()? as usize;
        match len > MAX_CID_SIZE || buf.remaining() < len {
            false => Some(Self::from_buf(buf, len)),
            true => None,
        }
    }

    /// Encode in long header format
    pub(crate) fn encode_long(&self, buf: &mut impl BufMut) {
        buf.put_u8(self.len() as u8);
        buf.put_slice(self);
    }
}
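// A small sketch of the long-header wire format handled above: a connection
// ID is encoded as a one-byte length prefix followed by that many bytes, and
// `decode_long` rejects lengths above `MAX_CID_SIZE` or truncated input.
// Test-only illustration using the APIs defined in this file; the chosen
// bytes are arbitrary.
#[cfg(test)]
mod cid_wire_sketch {
    use super::ConnectionId;

    #[test]
    fn long_header_roundtrip() {
        let cid = ConnectionId::new(&[0xab; 8]);
        let mut buf = Vec::new();
        cid.encode_long(&mut buf);
        assert_eq!(buf.len(), 1 + 8); // length prefix + CID bytes
        let mut cursor = &buf[..];
        let decoded = ConnectionId::decode_long(&mut cursor).unwrap();
        assert_eq!(decoded, cid);
    }
}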
impl ::std::ops::Deref for ConnectionId {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &self.bytes[0..self.len as usize]
    }
}

impl ::std::ops::DerefMut for ConnectionId {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut self.bytes[0..self.len as usize]
    }
}

impl fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.bytes[0..self.len as usize].fmt(f)
    }
}

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for byte in self.iter() {
            write!(f, "{byte:02x}")?;
        }
        Ok(())
    }
}

/// Explicit congestion notification codepoint
#[repr(u8)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum EcnCodepoint {
    #[doc(hidden)]
    Ect0 = 0b10,
    #[doc(hidden)]
    Ect1 = 0b01,
    #[doc(hidden)]
    Ce = 0b11,
}

impl EcnCodepoint {
    /// Create new object from the given bits
    pub fn from_bits(x: u8) -> Option<Self> {
        use self::EcnCodepoint::*;
        Some(match x & 0b11 {
            0b10 => Ect0,
            0b01 => Ect1,
            0b11 => Ce,
            _ => {
                return None;
            }
        })
    }

    /// Returns whether the codepoint is a CE, signalling that congestion was experienced
    pub fn is_ce(self) -> bool {
        matches!(self, Self::Ce)
    }
}

#[derive(Debug, Copy, Clone)]
pub(crate) struct IssuedCid {
    pub(crate) sequence: u64,
    pub(crate) id: ConnectionId,
    pub(crate) reset_token: ResetToken,
}

quinn-proto-0.11.9/src/tests/mod.rs

use std::{
    convert::TryInto,
    mem,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
    sync::Arc,
};

use assert_matches::assert_matches;
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
use aws_lc_rs::hmac;
use bytes::{Bytes, BytesMut};
use hex_literal::hex;
use rand::RngCore;
#[cfg(feature = "ring")]
use ring::hmac;
#[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
use rustls::crypto::aws_lc_rs::default_provider;
#[cfg(feature = "rustls-ring")]
use rustls::crypto::ring::default_provider;
use rustls::{
    pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
    server::WebPkiClientVerifier,
    AlertDescription, RootCertStore,
};
use tracing::info;

use super::*;
use crate::{
    cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
    crypto::rustls::QuicServerConfig,
    frame::FrameStruct,
    transport_parameters::TransportParameters,
    Duration, Instant,
};

mod util;
use util::*;

#[cfg(all(target_family = "wasm", target_os = "unknown"))]
use wasm_bindgen_test::wasm_bindgen_test as test;

// Enable this if you want to run these tests in the browser.
// Unfortunately it's either-or: enable this and you can run in the browser, disable to run in nodejs.
// #[cfg(all(target_family = "wasm", target_os = "unknown"))]
// wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

#[test]
fn version_negotiate_server() {
    let _guard = subscribe();
    let client_addr = "[::2]:7890".parse().unwrap();
    let mut server = Endpoint::new(
        Default::default(),
        Some(Arc::new(server_config())),
        true,
        None,
    );
    let now = Instant::now();
    let mut buf = Vec::with_capacity(server.config().get_max_udp_payload_size() as usize);
    let event = server.handle(
        now,
        client_addr,
        None,
        None,
        // Long-header packet with reserved version number
        hex!("80 0a1a2a3a 04 00000000 04 00000000 00")[..].into(),
        &mut buf,
    );
    let Some(DatagramEvent::Response(Transmit { ..
})) = event else { panic!("expected a response"); }; assert_ne!(buf[0] & 0x80, 0); assert_eq!(&buf[1..15], hex!("00000000 04 00000000 04 00000000")); assert!(buf[15..].chunks(4).any(|x| { DEFAULT_SUPPORTED_VERSIONS.contains(&u32::from_be_bytes(x.try_into().unwrap())) })); } #[test] fn version_negotiate_client() { let _guard = subscribe(); let server_addr = "[::2]:7890".parse().unwrap(); // Configure client to use empty CIDs so we can easily hardcode a server version negotiation // packet let cid_generator_factory: fn() -> Box = || Box::new(RandomConnectionIdGenerator::new(0)); let mut client = Endpoint::new( Arc::new(EndpointConfig { connection_id_generator_factory: Arc::new(cid_generator_factory), ..Default::default() }), None, true, None, ); let (_, mut client_ch) = client .connect(Instant::now(), client_config(), server_addr, "localhost") .unwrap(); let now = Instant::now(); let mut buf = Vec::with_capacity(client.config().get_max_udp_payload_size() as usize); let opt_event = client.handle( now, server_addr, None, None, // Version negotiation packet for reserved version, with empty DCID hex!( "80 00000000 00 04 00000000 0a1a2a3a" )[..] .into(), &mut buf, ); if let Some(DatagramEvent::ConnectionEvent(_, event)) = opt_event { client_ch.handle_event(event); } assert_matches!( client_ch.poll(), Some(Event::ConnectionLost { reason: ConnectionError::VersionMismatch, }) ); } #[test] fn lifecycle() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert!(pair.client_conn_mut(client_ch).using_ecn()); assert!(pair.server_conn_mut(server_ch).using_ecn()); const REASON: &[u8] = b"whee"; info!("closing"); pair.client.connections.get_mut(&client_ch).unwrap().close( pair.time, VarInt(42), REASON.into(), ); pair.drive(); assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ApplicationClosed( ApplicationClose { error_code: VarInt(42), ref reason } )}) if reason == REASON); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!(pair.client.known_connections(), 0); assert_eq!(pair.client.known_cids(), 0); assert_eq!(pair.server.known_connections(), 0); assert_eq!(pair.server.known_cids(), 0); } #[test] fn draft_version_compat() { let _guard = subscribe(); let mut client_config = client_config(); client_config.version(0xff00_0020); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect_with(client_config); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert!(pair.client_conn_mut(client_ch).using_ecn()); assert!(pair.server_conn_mut(server_ch).using_ecn()); const REASON: &[u8] = b"whee"; info!("closing"); pair.client.connections.get_mut(&client_ch).unwrap().close( pair.time, VarInt(42), REASON.into(), ); pair.drive(); assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ApplicationClosed( ApplicationClose { error_code: VarInt(42), ref reason } )}) if reason == REASON); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!(pair.client.known_connections(), 0); assert_eq!(pair.client.known_cids(), 0); assert_eq!(pair.server.known_connections(), 0); assert_eq!(pair.server.known_cids(), 0); } #[test] fn stateless_retry() { let _guard = subscribe(); let mut pair = Pair::default(); pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; let (client_ch, _server_ch) = pair.connect(); 
pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(42), Bytes::new()); pair.drive(); assert_eq!(pair.client.known_connections(), 0); assert_eq!(pair.client.known_cids(), 0); assert_eq!(pair.server.known_connections(), 0); assert_eq!(pair.server.known_cids(), 0); } #[test] fn server_stateless_reset() { let _guard = subscribe(); let mut key_material = vec![0; 64]; let mut rng = rand::thread_rng(); rng.fill_bytes(&mut key_material); let reset_key = hmac::Key::new(hmac::HMAC_SHA256, &key_material); rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); let (client_ch, _) = pair.connect(); pair.drive(); // Flush any post-handshake frames pair.server.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true, None); // Force the server to generate the smallest possible stateless reset pair.client.connections.get_mut(&client_ch).unwrap().ping(); info!("resetting"); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::Reset }) ); } #[test] fn client_stateless_reset() { let _guard = subscribe(); let mut key_material = vec![0; 64]; let mut rng = rand::thread_rng(); rng.fill_bytes(&mut key_material); let reset_key = hmac::Key::new(hmac::HMAC_SHA256, &key_material); rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); let (_, server_ch) = pair.connect(); pair.client.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true, None); // Send something big enough to allow room for a smaller stateless reset. 
pair.server.connections.get_mut(&server_ch).unwrap().close( pair.time, VarInt(42), (&[0xab; 128][..]).into(), ); info!("resetting"); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::Reset }) ); } /// Verify that stateless resets are rate-limited #[test] fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); endpoint_config.cid_generator(move || Box::new(RandomConnectionIdGenerator::new(8))); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), Some(Arc::new(server_config())), true, None, ); let time = Instant::now(); let mut buf = Vec::new(); let event = endpoint.handle(time, remote, None, None, [0u8; 1024][..].into(), &mut buf); assert!(matches!(event, Some(DatagramEvent::Response(_)))); let event = endpoint.handle(time, remote, None, None, [0u8; 1024][..].into(), &mut buf); assert!(event.is_none()); let event = endpoint.handle( time + endpoint_config.min_reset_interval - Duration::from_nanos(1), remote, None, None, [0u8; 1024][..].into(), &mut buf, ); assert!(event.is_none()); let event = endpoint.handle( time + endpoint_config.min_reset_interval, remote, None, None, [0u8; 1024][..].into(), &mut buf, ); assert!(matches!(event, Some(DatagramEvent::Response(_)))); } #[test] fn export_keying_material() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); const LABEL: &[u8] = b"test_label"; const CONTEXT: &[u8] = b"test_context"; // client keying material let mut client_buf = [0u8; 64]; pair.client_conn_mut(client_ch) .crypto_session() .export_keying_material(&mut client_buf, LABEL, CONTEXT) .unwrap(); // server keying material let mut server_buf = [0u8; 64]; pair.server_conn_mut(server_ch) .crypto_session() .export_keying_material(&mut server_buf, LABEL, CONTEXT) .unwrap(); assert_eq!(&client_buf[..], &server_buf[..]); } #[test] fn finish_stream_simple() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); assert_eq!(pair.client_streams(client_ch).send_streams(), 1); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!(pair.client_streams(client_ch).send_streams(), 0); assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 0); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); // Receive-only streams do not get `StreamFinished` events assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 0); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] fn reset_stream() { let _guard = subscribe(); let mut pair = Pair::default(); 
let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); info!("resetting stream"); const ERROR: VarInt = VarInt(42); pair.client_send(client_ch, s).reset(ERROR).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Err(ReadError::Reset(ERROR))); let _ = chunks.finalize(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); } #[test] fn stop_stream() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); info!("stopping stream"); const ERROR: VarInt = VarInt(42); pair.server_recv(server_ch, s).stop(ERROR).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!( pair.client_send(client_ch, s).write(b"foo"), Err(WriteError::Stopped(ERROR)) ); assert_matches!( pair.client_send(client_ch, s).finish(), Err(FinishError::Stopped(ERROR)) ); } #[test] fn reject_self_signed_server_cert() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); // Create a self-signed certificate with a different distinguished name than the default one, // such that path building cannot confuse the default root the server is using and the one // the client is trusting (in which case we'd get a different error). let mut cert = rcgen::CertificateParams::new(["localhost".into()]).unwrap(); let mut issuer = rcgen::DistinguishedName::new(); issuer.push( rcgen::DnType::OrganizationName, "Crazy Quinn's House of Certificates", ); cert.distinguished_name = issuer; let cert = cert .self_signed(&rcgen::KeyPair::generate().unwrap()) .unwrap(); let client_ch = pair.begin_connect(client_config_with_certs(vec![cert.into()])); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TransportError(ref error)}) if error.code == TransportErrorCode::crypto(AlertDescription::UnknownCA.into())); } #[test] fn reject_missing_client_cert() { let _guard = subscribe(); let mut store = RootCertStore::empty(); // `WebPkiClientVerifier` requires a non-empty store, so we stick our own certificate into it // because it's convenient. 
store.add(CERTIFIED_KEY.cert.der().clone()).unwrap(); let key = PrivatePkcs8KeyDer::from(CERTIFIED_KEY.key_pair.serialize_der()); let cert = CERTIFIED_KEY.cert.der().clone(); let provider = Arc::new(default_provider()); let config = rustls::ServerConfig::builder_with_provider(provider.clone()) .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap() .with_client_cert_verifier( WebPkiClientVerifier::builder_with_provider(Arc::new(store), provider) .build() .unwrap(), ) .with_single_cert(vec![cert], PrivateKeyDer::from(key)) .unwrap(); let config = QuicServerConfig::try_from(config).unwrap(); let mut pair = Pair::new( Default::default(), ServerConfig::with_crypto(Arc::new(config)), ); info!("connecting"); let client_ch = pair.begin_connect(client_config()); pair.drive(); // The client completes the connection, but finds it immediately closed assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected) ); assert_matches!(pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ref close)}) if close.error_code == TransportErrorCode::crypto(AlertDescription::CertificateRequired.into())); // The server never completes the connection let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TransportError(ref error)}) if error.code == TransportErrorCode::crypto(AlertDescription::CertificateRequired.into())); } #[test] fn congestion() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); const TARGET: u64 = 2048; assert!(pair.client_conn_mut(client_ch).congestion_window() > TARGET); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); // Send data without receiving ACKs until the congestion state falls below target while pair.client_conn_mut(client_ch).congestion_window() > TARGET { let n = pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); assert_eq!(n, 1024); pair.drive_client(); } // Ensure that the congestion state recovers after receiving the ACKs pair.drive(); assert!(pair.client_conn_mut(client_ch).congestion_window() >= TARGET); pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); } #[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527 #[test] fn high_latency_handshake() { let _guard = subscribe(); let mut pair = Pair::default(); pair.latency = Duration::from_micros(200 * 1000); let (client_ch, server_ch) = pair.connect(); assert_eq!(pair.client_conn_mut(client_ch).bytes_in_flight(), 0); assert_eq!(pair.server_conn_mut(server_ch).bytes_in_flight(), 0); assert!(pair.client_conn_mut(client_ch).using_ecn()); assert!(pair.server_conn_mut(server_ch).using_ecn()); } #[test] fn zero_rtt_happypath() { let _guard = subscribe(); let mut pair = Pair::default(); pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; let config = client_config(); // Establish normal connection let client_ch = pair.begin_connect(config.clone()); pair.drive(); pair.server.assert_accept(); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), [][..].into()); pair.drive(); pair.client.addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); info!("resuming 
session"); let client_ch = pair.begin_connect(config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"Hello, 0-RTT!"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected) ); assert!(pair.client_conn_mut(client_ch).accepted_0rtt()); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); // We don't currently preserve stream event order wrt. connection events assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); } #[test] fn zero_rtt_rejection() { let _guard = subscribe(); let server_config = ServerConfig::with_crypto(Arc::new(server_crypto_with_alpn(vec![ "foo".into(), "bar".into(), ]))); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let mut client_crypto = Arc::new(client_crypto_with_alpn(vec!["foo".into()])); let client_config = ClientConfig::new(client_crypto.clone()); // Establish normal connection let client_ch = pair.begin_connect(client_config); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), [][..].into()); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { .. }) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); pair.client.connections.clear(); pair.server.connections.clear(); // We want to have a TLS client config with the existing session cache (so resumption could // happen), but with different ALPN protocols (so that the server must reject it). Reuse // the existing `ClientConfig` and change the ALPN protocols to make that happen. 
let this = Arc::get_mut(&mut client_crypto).expect("QuicClientConfig is shared"); let inner = Arc::get_mut(&mut this.inner).expect("QuicClientConfig.inner is shared"); inner.alpn_protocols = vec!["bar".into()]; // Changing protocols invalidates 0-RTT let client_config = ClientConfig::new(client_crypto); info!("resuming session"); let client_ch = pair.begin_connect(client_config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"Hello, 0-RTT!"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); assert!(!pair.client_conn_mut(client_ch).accepted_0rtt()); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let s2 = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); assert_eq!(s, s2); let mut recv = pair.server_recv(server_ch, s2); let mut chunks = recv.read(false).unwrap(); assert_eq!(chunks.next(usize::MAX), Err(ReadError::Blocked)); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); } fn test_zero_rtt_incoming_limit(configure_server: F) { // caller sets the server limit to 4000 bytes // the client writes 8000 bytes const CLIENT_WRITES: usize = 8000; // this gets split across 8 packets // the first packet is stored in the Incoming // the next three are incoming-buffered, bringing the incoming buffer size to 3600 bytes // the last four are dropped due to the buffering limit and must be retransmitted const EXPECTED_DROPPED: u64 = 4; let _guard = subscribe(); let mut server_config = server_config(); configure_server(&mut server_config); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let config = client_config(); // Establish normal connection let client_ch = pair.begin_connect(config.clone()); pair.drive(); pair.server.assert_accept(); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), [][..].into()); pair.drive(); pair.client.addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); info!("resuming session"); pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Wait; let client_ch = pair.begin_connect(config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s) .write(&vec![0; CLIENT_WRITES]) .unwrap(); pair.drive(); let incoming = pair.server.waiting_incoming.pop().unwrap(); assert!(pair.server.waiting_incoming.is_empty()); let _ = pair.server.try_accept(incoming, pair.time); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected) ); assert!(pair.client_conn_mut(client_ch).accepted_0rtt()); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); // We don't currently preserve stream event order wrt. 
connection events assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); let mut offset = 0; loop { match chunks.next(usize::MAX) { Ok(Some(chunk)) => { assert_eq!(chunk.offset as usize, offset); offset += chunk.bytes.len(); } Err(ReadError::Blocked) => break, Ok(None) => panic!("unexpected stream end"), Err(e) => panic!("{}", e), } } assert_eq!(offset, CLIENT_WRITES); let _ = chunks.finalize(); assert_eq!( pair.client_conn_mut(client_ch).lost_packets(), EXPECTED_DROPPED ); } #[test] fn zero_rtt_incoming_buffer_size() { test_zero_rtt_incoming_limit(|config| { config.incoming_buffer_size(4000); }); } #[test] fn zero_rtt_incoming_buffer_size_total() { test_zero_rtt_incoming_limit(|config| { config.incoming_buffer_size_total(4000); }); } #[test] fn alpn_success() { let _guard = subscribe(); let server_config = ServerConfig::with_crypto(Arc::new(server_crypto_with_alpn(vec![ "foo".into(), "bar".into(), "baz".into(), ]))); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let client_config = ClientConfig::new(Arc::new(client_crypto_with_alpn(vec![ "bar".into(), "quux".into(), "corge".into(), ]))); // Establish normal connection let client_ch = pair.begin_connect(client_config); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); let hd = pair .client_conn_mut(client_ch) .crypto_session() .handshake_data() .unwrap() .downcast::() .unwrap(); assert_eq!(hd.protocol.unwrap(), &b"bar"[..]); } #[test] fn server_alpn_unset() { let _guard = subscribe(); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config()); let client_config = ClientConfig::new(Arc::new(client_crypto_with_alpn(vec!["foo".into()]))); let client_ch = pair.begin_connect(client_config); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(err) }) if err.error_code == TransportErrorCode::crypto(0x78) ); } #[test] fn client_alpn_unset() { let _guard = subscribe(); let server_config = ServerConfig::with_crypto(Arc::new(server_crypto_with_alpn(vec![ "foo".into(), "bar".into(), "baz".into(), ]))); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let client_ch = pair.begin_connect(client_config()); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(err) }) if err.error_code == TransportErrorCode::crypto(0x78) ); } #[test] fn alpn_mismatch() { let _guard = subscribe(); let server_config = ServerConfig::with_crypto(Arc::new(server_crypto_with_alpn(vec![ "foo".into(), "bar".into(), "baz".into(), ]))); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let client_ch = pair.begin_connect(ClientConfig::new(Arc::new(client_crypto_with_alpn(vec![ "quux".into(), "corge".into(), ])))); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(err) }) if err.error_code == TransportErrorCode::crypto(0x78) ); } #[test] fn stream_id_limit() { let _guard = subscribe(); let server = 
ServerConfig { transport: Arc::new(TransportConfig { max_concurrent_uni_streams: 1u32.into(), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Uni) .expect("couldn't open first stream"); assert_eq!( pair.client_streams(client_ch).open(Dir::Uni), None, "only one stream is permitted at a time" ); // Generate some activity to allow the server to see the stream const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_eq!( pair.client_streams(client_ch).open(Dir::Uni), None, "server does not immediately grant additional credit" ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_eq!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); // Server will only send MAX_STREAM_ID now that the application's been notified pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Available { dir: Dir::Uni })) ); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); // Try opening the second stream again, now that we've made room let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Uni) .expect("didn't get stream id budget"); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); // Make sure the server actually processes data on the newly-available stream assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] fn key_update_simple() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Bi) .expect("couldn't open first stream"); const MSG1: &[u8] = b"hello1"; pair.client_send(client_ch, s).write(MSG1).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Bi })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Bi), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG1 ); let _ = chunks.finalize(); info!("initiating key update"); pair.client_conn_mut(client_ch).initiate_key_update(); const MSG2: &[u8] = b"hello2"; pair.client_send(client_ch, s).write(MSG2).unwrap(); pair.drive(); 
assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Readable { id })) if id == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 6 && chunk.bytes == MSG2 ); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0); } #[test] fn key_update_reordered() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Bi) .expect("couldn't open first stream"); const MSG1: &[u8] = b"1"; pair.client_send(client_ch, s).write(MSG1).unwrap(); pair.client.drive(pair.time, pair.server.addr); assert!(!pair.client.outbound.is_empty()); pair.client.delay_outbound(); pair.client_conn_mut(client_ch).initiate_key_update(); info!("updated keys"); const MSG2: &[u8] = b"two"; pair.client_send(client_ch, s).write(MSG2).unwrap(); pair.client.drive(pair.time, pair.server.addr); pair.client.finish_delay(); pair.drive(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Bi })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Bi), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); let buf1 = chunks.next(usize::MAX).unwrap().unwrap(); assert_matches!(&*buf1.bytes, MSG1); let buf2 = chunks.next(usize::MAX).unwrap().unwrap(); assert_eq!(buf2.bytes, MSG2); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0); } #[test] fn initial_retransmit() { let _guard = subscribe(); let mut pair = Pair::default(); let client_ch = pair.begin_connect(client_config()); pair.client.drive(pair.time, pair.server.addr); pair.client.outbound.clear(); // Drop initial pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); } #[test] fn instant_close_1() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); let client_ch = pair.begin_connect(client_config()); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), Bytes::new()); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, .. }), }) ); } #[test] fn instant_close_2() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); let client_ch = pair.begin_connect(client_config()); // Unlike `instant_close`, the server sees a valid Initial packet first. 
pair.drive_client(); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(42), Bytes::new()); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, .. }), }) ); } #[test] fn instant_server_close() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); pair.begin_connect(client_config()); pair.drive_client(); pair.server.drive_incoming(pair.time, pair.client.addr); let server_ch = pair.server.assert_accept(); info!("closing"); pair.server .connections .get_mut(&server_ch) .unwrap() .close(pair.time, VarInt(42), Bytes::new()); pair.drive(); assert_matches!( pair.client_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, .. }), }) ); } #[test] fn idle_timeout() { let _guard = subscribe(); const IDLE_TIMEOUT: u64 = 100; let server = ServerConfig { transport: Arc::new(TransportConfig { max_idle_timeout: Some(VarInt(IDLE_TIMEOUT)), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); pair.client_conn_mut(client_ch).ping(); let start = pair.time; while !pair.client_conn_mut(client_ch).is_closed() || !pair.server_conn_mut(server_ch).is_closed() { if !pair.step() { if let Some(t) = min_opt(pair.client.next_wakeup(), pair.server.next_wakeup()) { pair.time = t; } } pair.client.inbound.clear(); // Simulate total S->C packet loss } assert!(pair.time - start < Duration::from_millis(2 * IDLE_TIMEOUT)); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TimedOut, }) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TimedOut, }) ); } #[test] fn connection_close_sends_acks() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _server_ch) = pair.connect(); let client_acks = pair.client_conn_mut(client_ch).stats().frame_rx.acks; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); let time = pair.time; pair.server_conn_mut(client_ch) .close(time, VarInt(42), Bytes::new()); pair.drive(); let client_acks_2 = pair.client_conn_mut(client_ch).stats().frame_rx.acks; assert!( client_acks_2 > client_acks, "Connection close should send pending ACKs" ); } #[test] fn server_hs_retransmit() { let _guard = subscribe(); let mut pair = Pair::default(); let client_ch = pair.begin_connect(client_config()); pair.step(); assert!(!pair.client.inbound.is_empty()); // Initial + Handshakes pair.client.inbound.clear(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. 
}) ); } #[test] fn migration() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); pair.drive(); let client_stats_after_connect = pair.client_conn_mut(client_ch).stats(); pair.client.addr = SocketAddr::new( Ipv4Addr::new(127, 0, 0, 1).into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); pair.client_conn_mut(client_ch).ping(); // Assert that just receiving the ping message counts toward the server's // anti-amplification budget pair.drive_client(); pair.drive_server(); assert_ne!(pair.server_conn_mut(server_ch).total_recvd(), 0); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!( pair.server_conn_mut(server_ch).remote_address(), pair.client.addr ); // Assert that the client's response to the PATH_CHALLENGE was an IMMEDIATE_ACK, instead of a // second ping let client_stats_after_migrate = pair.client_conn_mut(client_ch).stats(); assert_eq!( client_stats_after_migrate.frame_tx.ping - client_stats_after_connect.frame_tx.ping, 1 ); assert_eq!( client_stats_after_migrate.frame_tx.immediate_ack - client_stats_after_connect.frame_tx.immediate_ack, 1 ); } fn test_flow_control(config: TransportConfig, window_size: usize) { let _guard = subscribe(); let mut pair = Pair::new( Default::default(), ServerConfig { transport: Arc::new(config), ..server_config() }, ); let (client_ch, server_ch) = pair.connect(); let msg = vec![0xAB; window_size + 10]; // Stream reset before read let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); info!("writing"); assert_eq!(pair.client_send(client_ch, s).write(&msg), Ok(window_size)); assert_eq!( pair.client_send(client_ch, s).write(&msg[window_size..]), Err(WriteError::Blocked) ); pair.drive(); info!("resetting"); pair.client_send(client_ch, s).reset(VarInt(42)).unwrap(); pair.drive(); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); assert_eq!( chunks.next(usize::MAX).err(), Some(ReadError::Reset(VarInt(42))) ); let _ = chunks.finalize(); // Happy path info!("writing"); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); assert_eq!(pair.client_send(client_ch, s).write(&msg), Ok(window_size)); assert_eq!( pair.client_send(client_ch, s).write(&msg[window_size..]), Err(WriteError::Blocked) ); pair.drive(); let mut cursor = 0; let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); loop { match chunks.next(usize::MAX) { Ok(Some(chunk)) => { cursor += chunk.bytes.len(); } Ok(None) => { panic!("end of stream"); } Err(ReadError::Blocked) => { break; } Err(e) => { panic!("{}", e); } } } let _ = chunks.finalize(); info!("finished reading"); assert_eq!(cursor, window_size); pair.drive(); info!("writing"); assert_eq!(pair.client_send(client_ch, s).write(&msg), Ok(window_size)); assert_eq!( pair.client_send(client_ch, s).write(&msg[window_size..]), Err(WriteError::Blocked) ); pair.drive(); let mut cursor = 0; let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); loop { match chunks.next(usize::MAX) { Ok(Some(chunk)) => { cursor += chunk.bytes.len(); } Ok(None) => { panic!("end of stream"); } Err(ReadError::Blocked) => { break; } Err(e) => { panic!("{}", e); } } } assert_eq!(cursor, window_size); let _ = chunks.finalize(); info!("finished reading"); } #[test] fn stream_flow_control() { test_flow_control( TransportConfig { stream_receive_window: 2000u32.into(), ..TransportConfig::default() }, 2000, ); } #[test] fn conn_flow_control() { test_flow_control(
TransportConfig { receive_window: 2000u32.into(), ..TransportConfig::default() }, 2000, ); } #[test] fn stop_opens_bidi() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); assert_eq!(pair.client_streams(client_ch).send_streams(), 0); let s = pair.client_streams(client_ch).open(Dir::Bi).unwrap(); assert_eq!(pair.client_streams(client_ch).send_streams(), 1); const ERROR: VarInt = VarInt(42); pair.client .connections .get_mut(&server_ch) .unwrap() .recv_stream(s) .stop(ERROR) .unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Bi })) ); assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 0); assert_matches!(pair.server_streams(server_ch).accept(Dir::Bi), Some(stream) if stream == s); assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 1); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked)); let _ = chunks.finalize(); assert_matches!( pair.server_send(server_ch, s).write(b"foo"), Err(WriteError::Stopped(ERROR)) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Stopped { id: _, error_code: ERROR })) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); } #[test] fn implicit_open() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s1 = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); let s2 = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s2).write(b"hello").unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_eq!(pair.server_streams(server_ch).accept(Dir::Uni), Some(s1)); assert_eq!(pair.server_streams(server_ch).accept(Dir::Uni), Some(s2)); assert_eq!(pair.server_streams(server_ch).accept(Dir::Uni), None); } #[test] fn zero_length_cid() { let _guard = subscribe(); let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> = || Box::new(RandomConnectionIdGenerator::new(0)); let mut pair = Pair::new( Arc::new(EndpointConfig { connection_id_generator_factory: Arc::new(cid_generator_factory), ..EndpointConfig::default() }), server_config(), ); let (client_ch, server_ch) = pair.connect(); // Ensure we can reconnect after a previous connection is cleaned up info!("closing"); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(42), Bytes::new()); pair.drive(); pair.server .connections .get_mut(&server_ch) .unwrap() .close(pair.time, VarInt(42), Bytes::new()); pair.connect(); } #[test] fn keep_alive() { let _guard = subscribe(); const IDLE_TIMEOUT: u64 = 10; let server = ServerConfig { transport: Arc::new(TransportConfig { keep_alive_interval: Some(Duration::from_millis(IDLE_TIMEOUT / 2)), max_idle_timeout: Some(VarInt(IDLE_TIMEOUT)), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); // Run a good while longer than the idle timeout let end = pair.time + Duration::from_millis(20 * IDLE_TIMEOUT); while pair.time < end { if !pair.step() { if let Some(time) = min_opt(pair.client.next_wakeup(), pair.server.next_wakeup()) { pair.time = time; } } assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); } } #[test] fn
cid_rotation() { let _guard = subscribe(); const CID_TIMEOUT: Duration = Duration::from_secs(2); let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> = || Box::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT)); // Only test cid rotation on server side to have a clear output trace let server = Endpoint::new( Arc::new(EndpointConfig { connection_id_generator_factory: Arc::new(cid_generator_factory), ..EndpointConfig::default() }), Some(Arc::new(server_config())), true, None, ); let client = Endpoint::new(Arc::new(EndpointConfig::default()), None, true, None); let mut pair = Pair::new_from_endpoint(client, server); let (_, server_ch) = pair.connect(); let mut round: u64 = 1; let mut stop = pair.time; let end = pair.time + 5 * CID_TIMEOUT; use crate::cid_queue::CidQueue; use crate::LOC_CID_COUNT; let mut active_cid_num = CidQueue::LEN as u64 + 1; active_cid_num = active_cid_num.min(LOC_CID_COUNT); let mut left_bound = 0; let mut right_bound = active_cid_num - 1; while pair.time < end { stop += CID_TIMEOUT; // Run a while until PushNewCID timer fires while pair.time < stop { if !pair.step() { if let Some(time) = min_opt(pair.client.next_wakeup(), pair.server.next_wakeup()) { pair.time = time; } } } info!( "Checking active cid sequence range before {:?} seconds", round * CID_TIMEOUT.as_secs() ); let _bound = (left_bound, right_bound); assert_matches!( pair.server_conn_mut(server_ch).active_local_cid_seq(), _bound ); round += 1; left_bound += active_cid_num; right_bound += active_cid_num; pair.drive_server(); } } #[test] fn cid_retirement() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); // Server retires current active remote CIDs pair.server_conn_mut(server_ch) .rotate_local_cid(1, Instant::now()); pair.drive(); // Any unexpected behavior may trigger TransportError::CONNECTION_ID_LIMIT_ERROR assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); assert_matches!(pair.client_conn_mut(client_ch).active_rem_cid_seq(), 1); use crate::cid_queue::CidQueue; use crate::LOC_CID_COUNT; let mut active_cid_num = CidQueue::LEN as u64; active_cid_num = active_cid_num.min(LOC_CID_COUNT); let next_retire_prior_to = active_cid_num + 1; pair.client_conn_mut(client_ch).ping(); // Server retires all valid remote CIDs pair.server_conn_mut(server_ch) .rotate_local_cid(next_retire_prior_to, Instant::now()); pair.drive(); assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); assert_matches!( pair.client_conn_mut(client_ch).active_rem_cid_seq(), _next_retire_prior_to ); } #[test] fn finish_stream_flow_control_reordered() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive_client(); // Send stream data pair.server.drive(pair.time, pair.client.addr); // Receive // Issue flow control credit let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); let _ = chunks.finalize(); pair.server.drive(pair.time, pair.client.addr); pair.server.delay_outbound(); // Delay it pair.client_send(client_ch, s).finish().unwrap(); pair.drive_client(); // Send FIN pair.server.drive(pair.time, pair.client.addr); // Acknowledge
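// A small arithmetic sketch of the sequence-window bookkeeping verified by
// `cid_rotation` above: the number of simultaneously active local CIDs is
// min(CidQueue::LEN + 1, LOC_CID_COUNT), and every rotation round shifts the
// active window forward by that amount. The constants below are hypothetical
// stand-ins for the crate's values, chosen only to make the arithmetic concrete.
{
    const QUEUE_LEN: u64 = 5; // stand-in for CidQueue::LEN
    const MAX_LOCAL_CIDS: u64 = 8; // stand-in for crate::LOC_CID_COUNT
    fn active_window(round: u64) -> (u64, u64) {
        let active = (QUEUE_LEN + 1).min(MAX_LOCAL_CIDS);
        (round * active, round * active + active - 1)
    }
    // Round 0 starts at sequence 0; each round advances by the active count.
    assert_eq!(active_window(0), (0, 5));
    assert_eq!(active_window(1), (6, 11));
}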
pair.server.finish_delay(); // Add flow control packets after pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] fn handshake_1rtt_handling() { let _guard = subscribe(); let mut pair = Pair::default(); let client_ch = pair.begin_connect(client_config()); pair.drive_client(); pair.drive_server(); let server_ch = pair.server.assert_accept(); // Server now has 1-RTT keys, but remains in Handshake state until the TLS CFIN has // authenticated the client. Delay the final client handshake flight so that doesn't happen yet. pair.client.drive(pair.time, pair.server.addr); pair.client.delay_outbound(); // Send some 1-RTT data which will be received first. let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.client_send(client_ch, s).finish().unwrap(); pair.client.drive(pair.time, pair.server.addr); // Add the handshake flight back on. pair.client.finish_delay(); pair.drive(); assert!(pair.client_conn_mut(client_ch).lost_packets() != 0); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); let _ = chunks.finalize(); } #[test] fn stop_before_finish() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); info!("stopping stream"); const ERROR: VarInt = VarInt(42); pair.server_recv(server_ch, s).stop(ERROR).unwrap(); pair.drive(); assert_matches!( pair.client_send(client_ch, s).finish(), Err(FinishError::Stopped(ERROR)) ); } #[test] fn stop_during_finish() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); info!("stopping and finishing stream"); const ERROR: VarInt = VarInt(42); pair.server_recv(server_ch, s).stop(ERROR).unwrap(); pair.drive_server(); pair.client_send(client_ch, s).finish().unwrap(); pair.drive_client(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Stopped { id, error_code: ERROR })) if id == s ); } // Ensure we can recover from loss of tail packets when the congestion window is full #[test] fn congested_tail_loss() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); const TARGET: u64 = 2048; assert!(pair.client_conn_mut(client_ch).congestion_window() > TARGET); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); // Send data without receiving ACKs until the congestion state falls below target while 
pair.client_conn_mut(client_ch).congestion_window() > TARGET { let n = pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); assert_eq!(n, 1024); pair.drive_client(); } assert!(!pair.server.inbound.is_empty()); pair.server.inbound.clear(); // Ensure that the congestion state recovers after retransmits occur and are ACKed info!("recovering"); pair.drive(); assert!(pair.client_conn_mut(client_ch).congestion_window() > TARGET); pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); } #[test] fn datagram_send_recv() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.client_datagrams(client_ch).max_size(), Some(x) if x > 0); const DATA: &[u8] = b"whee"; pair.client_datagrams(client_ch) .send(DATA.into(), true) .unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::DatagramReceived) ); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA); assert_matches!(pair.server_datagrams(server_ch).recv(), None); } #[test] fn datagram_recv_buffer_overflow() { let _guard = subscribe(); const WINDOW: usize = 100; let server = ServerConfig { transport: Arc::new(TransportConfig { datagram_receive_buffer_size: Some(WINDOW), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_eq!( pair.client_conn_mut(client_ch).datagrams().max_size(), Some(WINDOW - Datagram::SIZE_BOUND) ); const DATA1: &[u8] = &[0xAB; (WINDOW / 3) + 1]; const DATA2: &[u8] = &[0xBC; (WINDOW / 3) + 1]; const DATA3: &[u8] = &[0xCD; (WINDOW / 3) + 1]; pair.client_datagrams(client_ch) .send(DATA1.into(), true) .unwrap(); pair.client_datagrams(client_ch) .send(DATA2.into(), true) .unwrap(); pair.client_datagrams(client_ch) .send(DATA3.into(), true) .unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::DatagramReceived) ); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA2); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA3); assert_matches!(pair.server_datagrams(server_ch).recv(), None); pair.client_datagrams(client_ch) .send(DATA1.into(), true) .unwrap(); pair.drive(); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA1); assert_matches!(pair.server_datagrams(server_ch).recv(), None); } #[test] fn datagram_unsupported() { let _guard = subscribe(); let server = ServerConfig { transport: Arc::new(TransportConfig { datagram_receive_buffer_size: None, ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.client_datagrams(client_ch).max_size(), None); match pair.client_datagrams(client_ch).send(Bytes::new(), true) { Err(SendDatagramError::UnsupportedByPeer) => {} Err(e) => panic!("unexpected error: {e}"), Ok(_) => panic!("unexpected success"), } } #[test] fn large_initial() { let _guard = subscribe(); let server_config = ServerConfig::with_crypto(Arc::new(server_crypto_with_alpn(vec![vec![0, 0, 0, 42]]))); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let client_crypto = client_crypto_with_alpn((0..1000u32).map(|x| x.to_be_bytes().to_vec()).collect()); let cfg = 
ClientConfig::new(Arc::new(client_crypto)); let client_ch = pair.begin_connect(cfg); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected { .. }) ); } #[test] /// Ensure that we don't yield a finish event before the actual FIN is acked so the peer isn't left /// hanging fn finish_acked() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); info!("client sends data to server"); pair.drive_client(); // send data to server info!("server acknowledges data"); pair.drive_server(); // process data and send data ack // Receive data assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked)); let _ = chunks.finalize(); // Finish before receiving data ack pair.client_send(client_ch, s).finish().unwrap(); // Send FIN, receive data ack info!("client receives ACK, sends FIN"); pair.drive_client(); // Check for premature finish from data ack assert_matches!(pair.client_conn_mut(client_ch).poll(), None); // Process FIN ack info!("server ACKs FIN"); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] /// Ensure that we don't yield a finish event while there's still unacknowledged data fn finish_retransmit() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive_client(); // send data to server pair.server.inbound.clear(); // Lose it // Send FIN pair.client_send(client_ch, s).finish().unwrap(); pair.drive_client(); // Process FIN pair.drive_server(); // Receive FIN ack, but no data ack pair.drive_client(); // Check for premature finish from FIN ack assert_matches!(pair.client_conn_mut(client_ch).poll(), None); // Recover pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && 
chunk.bytes == MSG ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } /// Ensures that exchanging data on a client-initiated bidirectional stream works past the initial /// stream window. #[test] fn repeated_request_response() { let _guard = subscribe(); let server = ServerConfig { transport: Arc::new(TransportConfig { max_concurrent_bidi_streams: 1u32.into(), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); const REQUEST: &[u8] = b"hello"; const RESPONSE: &[u8] = b"world"; for _ in 0..3 { let s = pair.client_streams(client_ch).open(Dir::Bi).unwrap(); pair.client_send(client_ch, s).write(REQUEST).unwrap(); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); assert_eq!(pair.server_streams(server_ch).accept(Dir::Bi), Some(s)); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == REQUEST ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); pair.server_send(server_ch, s).write(RESPONSE).unwrap(); pair.server_send(server_ch, s).finish().unwrap(); pair.drive(); let mut recv = pair.client_recv(client_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == RESPONSE ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } } /// Ensures that the client sends an anti-deadlock probe after an incomplete server's first flight #[test] fn handshake_anti_deadlock_probe() { let _guard = subscribe(); let (cert, key) = big_cert_and_key(); let server = server_config_with_cert(cert.clone(), key); let client = client_config_with_certs(vec![cert]); let mut pair = Pair::new(Default::default(), server); let client_ch = pair.begin_connect(client); // Client sends initial pair.drive_client(); // Server sends first flight, gets blocked on anti-amplification pair.drive_server(); // Client acks... pair.drive_client(); // ...but it's lost, so the server doesn't get anti-amplification credit from it pair.server.inbound.clear(); // Client sends an anti-deadlock probe, and the handshake completes as usual. pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); } /// Ensures that the server can respond with 3 initial packets during the handshake /// before the anti-amplification limit kicks in when MTUs are similar. #[test] fn server_can_send_3_initial_packets() { let _guard = subscribe(); let (cert, key) = big_cert_and_key(); let server = server_config_with_cert(cert.clone(), key); let client = client_config_with_certs(vec![cert]); let mut pair = Pair::new(Default::default(), server); let client_ch = pair.begin_connect(client); // Client sends initial pair.drive_client(); // Server sends first flight, gets blocked on anti-amplification pair.drive_server(); // Server should have queued 3 packets at this time assert_eq!(pair.client.inbound.len(), 3); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { ..
}) ); } /// Generate a big fat certificate that can't fit inside the initial anti-amplification limit fn big_cert_and_key() -> (CertificateDer<'static>, PrivateKeyDer<'static>) { let cert = rcgen::generate_simple_self_signed( Some("localhost".into()) .into_iter() .chain((0..1000).map(|x| format!("foo_{x}"))) .collect::<Vec<_>>(), ) .unwrap(); ( cert.cert.into(), PrivateKeyDer::Pkcs8(cert.key_pair.serialize_der().into()), ) } #[test] fn malformed_token_len() { let _guard = subscribe(); let client_addr = "[::2]:7890".parse().unwrap(); let mut server = Endpoint::new( Default::default(), Some(Arc::new(server_config())), true, None, ); let mut buf = Vec::with_capacity(server.config().get_max_udp_payload_size() as usize); server.handle( Instant::now(), client_addr, None, None, hex!("8900 0000 0101 0000 1b1b 841b 0000 0000 3f00")[..].into(), &mut buf, ); } #[test] fn loss_probe_requests_immediate_ack() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); pair.drive(); let stats_after_connect = pair.client_conn_mut(client_ch).stats(); // Lose a ping let default_mtu = mem::replace(&mut pair.mtu, 0); pair.client_conn_mut(client_ch).ping(); pair.drive_client(); pair.mtu = default_mtu; // Drive the connection further so a loss probe is sent pair.drive(); // Assert that two IMMEDIATE_ACKs were sent (two loss probes) let stats_after_recovery = pair.client_conn_mut(client_ch).stats(); assert_eq!( stats_after_recovery.frame_tx.immediate_ack - stats_after_connect.frame_tx.immediate_ack, 2 ); } #[test] /// This is mostly a sanity check to ensure our testing code is correctly dropping packets above the /// pmtu fn connect_too_low_mtu() { let _guard = subscribe(); let mut pair = Pair::default(); // The maximum payload size is lower than 1200, so no packets will get through!
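// Sketch of the invariant this test leans on: RFC 9000 §14.1 requires clients
// to expand UDP datagrams carrying Initial packets to at least 1200 bytes, and
// the test harness drops anything exceeding `pair.mtu`. With a 1000-byte link
// no Initial can ever be delivered, so the server never accepts. The constant
// and helper are local illustrations, not crate exports.
{
    const MIN_INITIAL_DATAGRAM: usize = 1200;
    fn link_delivers_initial(link_mtu: usize) -> bool {
        // An Initial-bearing datagram is at least 1200 bytes, so it only
        // survives links that can carry at least that much.
        link_mtu >= MIN_INITIAL_DATAGRAM
    }
    assert!(!link_delivers_initial(1000)); // this test's link: all Initials dropped
    assert!(link_delivers_initial(1200));
}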
pair.mtu = 1000; pair.begin_connect(client_config()); pair.drive(); pair.server.assert_no_accept(); } #[test] fn connect_lost_mtu_probes_do_not_trigger_congestion_control() { let _guard = subscribe(); let mut pair = Pair::default(); pair.mtu = 1200; let (client_ch, server_ch) = pair.connect(); pair.drive(); let client_stats = pair.client_conn_mut(client_ch).stats(); let server_stats = pair.server_conn_mut(server_ch).stats(); // Sanity check (all MTU probes should have been lost) assert_eq!(client_stats.path.sent_plpmtud_probes, 9); assert_eq!(client_stats.path.lost_plpmtud_probes, 9); assert_eq!(server_stats.path.sent_plpmtud_probes, 9); assert_eq!(server_stats.path.lost_plpmtud_probes, 9); // No congestion events assert_eq!(client_stats.path.congestion_events, 0); assert_eq!(server_stats.path.congestion_events, 0); } #[test] fn connect_detects_mtu() { let _guard = subscribe(); let max_udp_payload_and_expected_mtu = &[(1200, 1200), (1400, 1389), (1500, 1452)]; for &(pair_max_udp, expected_mtu) in max_udp_payload_and_expected_mtu { let mut pair = Pair::default(); pair.mtu = pair_max_udp; let (client_ch, server_ch) = pair.connect(); pair.drive(); assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), expected_mtu); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), expected_mtu); } } #[test] fn migrate_detects_new_mtu_and_respects_original_peer_max_udp_payload_size() { let _guard = subscribe(); let client_max_udp_payload_size: u16 = 1400; // Set up a client with a max payload size of 1400 (and use the defaults for the server) let server_endpoint_config = EndpointConfig::default(); let server = Endpoint::new( Arc::new(server_endpoint_config), Some(Arc::new(server_config())), true, None, ); let client_endpoint_config = EndpointConfig { max_udp_payload_size: VarInt::from(client_max_udp_payload_size), ..EndpointConfig::default() }; let client = Endpoint::new(Arc::new(client_endpoint_config), None, true, None); let mut pair = Pair::new_from_endpoint(client, server); pair.mtu = 1300; // Connect let (client_ch, server_ch) = pair.connect(); pair.drive(); // Sanity check: MTUD ran to completion (the numbers differ because binary search stops when // changes are smaller than 20, otherwise both endpoints would converge at the same MTU of 1300) assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1293); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1300); // Migrate client to a different port (and simulate a higher path MTU) pair.mtu = 1500; pair.client.addr = SocketAddr::new( Ipv4Addr::new(127, 0, 0, 1).into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); pair.client_conn_mut(client_ch).ping(); pair.drive(); // Sanity check: the server saw that the client address was updated assert_eq!( pair.server_conn_mut(server_ch).remote_address(), pair.client.addr ); // MTU detection has successfully run after migrating assert_eq!( pair.server_conn_mut(server_ch).path_mtu(), client_max_udp_payload_size ); // Sanity check: the client keeps the old MTU, because migration is triggered by incoming // packets from a different address assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1293); } #[test] fn connect_runs_mtud_again_after_600_seconds() { let _guard = subscribe(); let mut server_config = server_config(); let mut client_config = client_config(); // Note: we use an infinite idle timeout to ensure we can wait 600 seconds without the // connection closing Arc::get_mut(&mut server_config.transport) .unwrap() .max_idle_timeout(None); Arc::get_mut(&mut client_config.transport) .unwrap() 
.max_idle_timeout(None); let mut pair = Pair::new(Default::default(), server_config); pair.mtu = 1400; let (client_ch, server_ch) = pair.connect_with(client_config); pair.drive(); // Sanity check: the mtu has been discovered let client_conn = pair.client_conn_mut(client_ch); assert_eq!(client_conn.path_mtu(), 1389); assert_eq!(client_conn.stats().path.sent_plpmtud_probes, 5); assert_eq!(client_conn.stats().path.lost_plpmtud_probes, 3); let server_conn = pair.server_conn_mut(server_ch); assert_eq!(server_conn.path_mtu(), 1389); assert_eq!(server_conn.stats().path.sent_plpmtud_probes, 5); assert_eq!(server_conn.stats().path.lost_plpmtud_probes, 3); // Sanity check: the mtu does not change after the fact, even though the link now supports a // higher UDP payload size pair.mtu = 1500; pair.drive(); assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1389); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1389); // The MTU changes after 600 seconds, because now MTUD runs for the second time pair.time += Duration::from_secs(600); pair.drive(); assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1452); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1452); } #[test] fn blackhole_after_mtu_change_repairs_itself() { let _guard = subscribe(); let mut pair = Pair::default(); pair.mtu = 1500; let (client_ch, server_ch) = pair.connect(); pair.drive(); // Sanity check assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1452); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1452); // Back to the base MTU pair.mtu = 1200; // The payload will be sent in a single packet, because the detected MTU was 1452, but it will // be dropped because the link no longer supports that packet size!
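// The assertions at the end of this test expect at least three lost packets
// before the black hole is declared and the path drops back to the base MTU.
// A minimal counter model of that idea, assuming (as the stats below suggest)
// that a run of consecutive losses of MTU-sized packets triggers detection;
// the threshold constant is illustrative, not quinn-proto's exact rule.
{
    const BLACK_HOLE_THRESHOLD: u32 = 3;
    fn black_hole_detected(consecutive_mtu_sized_losses: u32) -> bool {
        consecutive_mtu_sized_losses >= BLACK_HOLE_THRESHOLD
    }
    assert!(!black_hole_detected(2));
    assert!(black_hole_detected(3)); // matches `black_holes_detected == 1` below
}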
let payload = vec![42; 1300]; let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(&payload).unwrap(); let out_of_bounds = pair.drive_bounded(); if out_of_bounds { panic!("Connections never reached an idle state"); } let recv = pair.server_recv(server_ch, s); let buf = stream_chunks(recv); // The whole packet arrived in the end assert_eq!(buf.len(), 1300); // Sanity checks (black hole detected after 3 lost packets) let client_stats = pair.client_conn_mut(client_ch).stats(); assert!(client_stats.path.lost_packets >= 3); assert!(client_stats.path.congestion_events >= 3); assert_eq!(client_stats.path.black_holes_detected, 1); } #[test] fn mtud_probes_include_immediate_ack() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); pair.drive(); let stats = pair.client_conn_mut(client_ch).stats(); assert_eq!(stats.path.sent_plpmtud_probes, 4); // Each probe contains a ping and an immediate ack assert_eq!(stats.frame_tx.ping, 4); assert_eq!(stats.frame_tx.immediate_ack, 4); } #[test] fn packet_splitting_with_default_mtu() { let _guard = subscribe(); // The payload needs to be split in 2 in order to be sent, because it is larger than the max MTU let payload = vec![42; 1300]; let mut pair = Pair::default(); pair.mtu = 1200; let (client_ch, _) = pair.connect(); pair.drive(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(&payload).unwrap(); pair.client.drive(pair.time, pair.server.addr); assert_eq!(pair.client.outbound.len(), 2); pair.drive_client(); assert_eq!(pair.server.inbound.len(), 2); } #[test] fn packet_splitting_not_necessary_after_higher_mtu_discovered() { let _guard = subscribe(); let payload = vec![42; 1300]; let mut pair = Pair::default(); pair.mtu = 1500; let (client_ch, _) = pair.connect(); pair.drive(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(&payload).unwrap(); pair.client.drive(pair.time, pair.server.addr); assert_eq!(pair.client.outbound.len(), 1); pair.drive_client(); assert_eq!(pair.server.inbound.len(), 1); } #[test] fn single_ack_eliciting_packet_triggers_ack_after_delay() { let _guard = subscribe(); let mut pair = Pair::default_with_deterministic_pns(); let (client_ch, _) = pair.connect_with(client_config_with_deterministic_pns()); pair.drive(); let stats_after_connect = pair.client_conn_mut(client_ch).stats(); let start = pair.time; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Send ping pair.drive_server(); // Process ping pair.drive_client(); // Give the client a chance to process an ack, so our assertion can fail // Sanity check: the time hasn't advanced in the meantime assert_eq!(pair.time, start); let stats_after_ping = pair.client_conn_mut(client_ch).stats(); assert_eq!( stats_after_ping.frame_tx.ping - stats_after_connect.frame_tx.ping, 1 ); assert_eq!( stats_after_ping.frame_rx.acks - stats_after_connect.frame_rx.acks, 0 ); pair.client.capture_inbound_packets = true; pair.drive(); let stats_after_drive = pair.client_conn_mut(client_ch).stats(); assert_eq!( stats_after_drive.frame_rx.acks - stats_after_ping.frame_rx.acks, 1 ); // The time is start + max_ack_delay let default_max_ack_delay_ms = TransportParameters::default().max_ack_delay.into_inner(); assert_eq!( pair.time, start + Duration::from_millis(default_max_ack_delay_ms) ); // The ACK delay is properly calculated assert_eq!(pair.client.captured_packets.len(), 1); let mut frames =
frame::Iter::new(pair.client.captured_packets.remove(0).into()) .unwrap() .collect::<Result<Vec<_>, _>>() .unwrap(); assert_eq!(frames.len(), 1); if let Frame::Ack(ack) = frames.remove(0) { let ack_delay_exp = TransportParameters::default().ack_delay_exponent; let delay = ack.delay << ack_delay_exp.into_inner(); assert_eq!(delay, default_max_ack_delay_ms * 1_000); } else { panic!("Expected ACK frame"); } // Sanity check: no loss probe was sent, because the delayed ACK was received on time assert_eq!( stats_after_drive.frame_tx.ping - stats_after_connect.frame_tx.ping, 1 ); } #[test] fn immediate_ack_triggers_ack() { let _guard = subscribe(); let mut pair = Pair::default_with_deterministic_pns(); let (client_ch, _) = pair.connect_with(client_config_with_deterministic_pns()); pair.drive(); let acks_after_connect = pair.client_conn_mut(client_ch).stats().frame_rx.acks; pair.client_conn_mut(client_ch).immediate_ack(); pair.drive_client(); // Send immediate ack pair.drive_server(); // Process immediate ack pair.drive_client(); // Give the client a chance to process the ack let acks_after_ping = pair.client_conn_mut(client_ch).stats().frame_rx.acks; assert_eq!(acks_after_ping - acks_after_connect, 1); } #[test] fn out_of_order_ack_eliciting_packet_triggers_ack() { let _guard = subscribe(); let mut pair = Pair::default_with_deterministic_pns(); let (client_ch, server_ch) = pair.connect_with(client_config_with_deterministic_pns()); pair.drive(); let default_mtu = pair.mtu; let client_stats_after_connect = pair.client_conn_mut(client_ch).stats(); let server_stats_after_connect = pair.server_conn_mut(server_ch).stats(); // Send a packet that won't arrive right away (it will be dropped and be re-sent later) pair.mtu = 0; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Sanity check (ping sent, no ACK received) let client_stats_after_first_ping = pair.client_conn_mut(client_ch).stats(); assert_eq!( client_stats_after_first_ping.frame_tx.ping - client_stats_after_connect.frame_tx.ping, 1 ); assert_eq!( client_stats_after_first_ping.frame_rx.acks - client_stats_after_connect.frame_rx.acks, 0 ); // Restore the default MTU and send another ping, which will arrive earlier than the dropped one pair.mtu = default_mtu; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); pair.drive_server(); pair.drive_client(); // Client sanity check (ping sent, one ACK received) let client_stats_after_second_ping = pair.client_conn_mut(client_ch).stats(); assert_eq!( client_stats_after_second_ping.frame_tx.ping - client_stats_after_connect.frame_tx.ping, 2 ); assert_eq!( client_stats_after_second_ping.frame_rx.acks - client_stats_after_connect.frame_rx.acks, 1 ); // Server checks (single ping received, ACK sent) let server_stats_after_second_ping = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after_second_ping.frame_rx.ping - server_stats_after_connect.frame_rx.ping, 1 ); assert_eq!( server_stats_after_second_ping.frame_tx.acks - server_stats_after_connect.frame_tx.acks, 1 ); } #[test] fn single_ack_eliciting_packet_with_ce_bit_triggers_immediate_ack() { let _guard = subscribe(); let mut pair = Pair::default_with_deterministic_pns(); let (client_ch, _) = pair.connect_with(client_config_with_deterministic_pns()); pair.drive(); let stats_after_connect = pair.client_conn_mut(client_ch).stats(); let start = pair.time; pair.client_conn_mut(client_ch).ping(); pair.congestion_experienced = true; pair.drive_client(); // Send ping pair.congestion_experienced = false; pair.drive_server(); // Process
ping, send ACK in response to congestion pair.drive_client(); // Process ACK // Sanity check: the time hasn't advanced in the meantime assert_eq!(pair.time, start); let stats_after_ping = pair.client_conn_mut(client_ch).stats(); assert_eq!( stats_after_ping.frame_tx.ping - stats_after_connect.frame_tx.ping, 1 ); assert_eq!( stats_after_ping.frame_rx.acks - stats_after_connect.frame_rx.acks, 1 ); assert_eq!( stats_after_ping.path.congestion_events - stats_after_connect.path.congestion_events, 1 ); } fn setup_ack_frequency_test(max_ack_delay: Duration) -> (Pair, ConnectionHandle, ConnectionHandle) { let mut client_config = client_config_with_deterministic_pns(); let mut ack_freq_config = AckFrequencyConfig::default(); ack_freq_config .ack_eliciting_threshold(10u32.into()) .max_ack_delay(Some(max_ack_delay)); Arc::get_mut(&mut client_config.transport) .unwrap() .ack_frequency_config(Some(ack_freq_config)) .mtu_discovery_config(None); // To keep traffic cleaner let mut pair = Pair::default_with_deterministic_pns(); pair.latency = Duration::from_millis(10); // Need latency to avoid an RTT = 0 let (client_ch, server_ch) = pair.connect_with(client_config); pair.drive(); assert_eq!( pair.client_conn_mut(client_ch) .stats() .frame_tx .ack_frequency, 1 ); assert_eq!(pair.client_conn_mut(client_ch).stats().frame_tx.ping, 0); (pair, client_ch, server_ch) } /// Verify that max ACK delay is counted from the first ACK-eliciting packet #[test] fn ack_frequency_ack_delayed_from_first_of_flight() { let _guard = subscribe(); let (mut pair, client_ch, server_ch) = setup_ack_frequency_test(Duration::from_millis(30)); // The client sends the following frames: // // * 0 ms: ping // * 5 ms: ping x2 pair.client_conn_mut(client_ch).ping(); pair.drive_client(); pair.time += Duration::from_millis(5); for _ in 0..2 { pair.client_conn_mut(client_ch).ping(); pair.drive_client(); } pair.time += Duration::from_millis(5); // Server: receive the first ping and send no ACK let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); // Server: receive the second and third pings and send no ACK pair.time += Duration::from_millis(10); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 2 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); // Server: Send an ACK after ACK delay expires pair.time += Duration::from_millis(20); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 1 ); } #[test] fn ack_frequency_ack_sent_after_max_ack_delay() { let _guard = subscribe(); let max_ack_delay = Duration::from_millis(30); let (mut pair, client_ch, server_ch) = setup_ack_frequency_test(max_ack_delay); // Client sends a ping pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Server: receive the ping, send no ACK pair.time += pair.latency; let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after =
pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); // Server: send an ack after max_ack_delay has elapsed pair.time += max_ack_delay; let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 0 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 1 ); } #[test] fn ack_frequency_ack_sent_after_packets_above_threshold() { let _guard = subscribe(); let max_ack_delay = Duration::from_millis(30); let (mut pair, client_ch, server_ch) = setup_ack_frequency_test(max_ack_delay); // The client sends the following frames: // // * 0 ms: ping // * 5 ms: ping (11x) pair.client_conn_mut(client_ch).ping(); pair.drive_client(); pair.time += Duration::from_millis(5); for _ in 0..11 { pair.client_conn_mut(client_ch).ping(); pair.drive_client(); } // Server: receive the first ping, send no ACK pair.time += Duration::from_millis(5); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); // Server: receive the remaining pings, send ACK pair.time += Duration::from_millis(5); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 11 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 1 ); } #[test] fn ack_frequency_ack_sent_after_reordered_packets_below_threshold() { let _guard = subscribe(); let max_ack_delay = Duration::from_millis(30); let (mut pair, client_ch, server_ch) = setup_ack_frequency_test(max_ack_delay); // The client sends the following frames: // // * 0 ms: ping // * 5 ms: ping (lost) // * 5 ms: ping pair.client_conn_mut(client_ch).ping(); pair.drive_client(); pair.time += Duration::from_millis(5); // Send and lose an ack-eliciting packet pair.mtu = 0; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Restore the default MTU and send another ping, which will arrive earlier than the dropped one pair.mtu = DEFAULT_MTU; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Server: receive first ping, send no ACK pair.time += Duration::from_millis(5); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); // Server: receive second ping, send no ACK pair.time += Duration::from_millis(5); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); } #[test] fn 
ack_frequency_ack_sent_after_reordered_packets_above_threshold() { let _guard = subscribe(); let max_ack_delay = Duration::from_millis(30); let (mut pair, client_ch, server_ch) = setup_ack_frequency_test(max_ack_delay); // Send a ping pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Send and lose two ack-eliciting packets pair.time += Duration::from_millis(5); pair.mtu = 0; for _ in 0..2 { pair.client_conn_mut(client_ch).ping(); pair.drive_client(); } // Restore the default MTU and send another ping, which will arrive earlier than the dropped ones pair.mtu = DEFAULT_MTU; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); // Server: receive first ping, send no ACK pair.time += Duration::from_millis(5); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 0 ); // Server: receive remaining ping, send ACK pair.time += Duration::from_millis(5); let server_stats_before = pair.server_conn_mut(server_ch).stats(); pair.drive_server(); let server_stats_after = pair.server_conn_mut(server_ch).stats(); assert_eq!( server_stats_after.frame_rx.ping - server_stats_before.frame_rx.ping, 1 ); assert_eq!( server_stats_after.frame_tx.acks - server_stats_before.frame_tx.acks, 1 ); } #[test] fn ack_frequency_update_max_delay() { let _guard = subscribe(); let (mut pair, client_ch, server_ch) = setup_ack_frequency_test(Duration::from_millis(200)); // Ack frequency was sent initially assert_eq!( pair.server_conn_mut(server_ch) .stats() .frame_rx .ack_frequency, 1 ); // Client sends a PING info!("first ping"); pair.client_conn_mut(client_ch).ping(); pair.drive(); // No change in ACK frequency assert_eq!( pair.server_conn_mut(server_ch) .stats() .frame_rx .ack_frequency, 1 ); // RTT jumps, client sends another ping info!("delayed ping"); pair.latency *= 10; pair.client_conn_mut(client_ch).ping(); pair.drive(); // ACK frequency updated assert!( pair.server_conn_mut(server_ch) .stats() .frame_rx .ack_frequency >= 2 ); } fn stream_chunks(mut recv: RecvStream) -> Vec<u8> { let mut buf = Vec::new(); let mut chunks = recv.read(true).unwrap(); while let Ok(Some(chunk)) = chunks.next(usize::MAX) { buf.extend(chunk.bytes); } let _ = chunks.finalize(); buf } /// Verify that an endpoint which receives but does not send ACK-eliciting data still receives ACKs /// occasionally. This is not required for conformance, but makes loss detection more responsive and /// reduces receiver memory use. #[test] fn pure_sender_voluntarily_acks() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let receiver_acks_initial = pair.server_conn_mut(server_ch).stats().frame_rx.acks; for _ in 0..100 { const MSG: &[u8] = b"hello"; pair.client_datagrams(client_ch) .send(Bytes::from_static(MSG), true) .unwrap(); pair.drive(); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), MSG); } let receiver_acks_final = pair.server_conn_mut(server_ch).stats().frame_rx.acks; assert!(receiver_acks_final > receiver_acks_initial); } #[test] fn reject_manually() { let _guard = subscribe(); let mut pair = Pair::default(); pair.server.incoming_connection_behavior = IncomingConnectionBehavior::RejectAll; // The server should now reject incoming connections.
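// An aside pinned down by the delayed-ACK test above: the ACK Delay field is
// transmitted right-shifted by `ack_delay_exponent` (default 3 per RFC 9000
// §18.2), so the receiver reconstructs microseconds as `delay << exponent`,
// exactly the shift the captured-packet check performs. A self-contained
// round-trip of that encoding; the helpers are illustrative, not crate API.
{
    const ACK_DELAY_EXPONENT: u64 = 3;
    fn encode_ack_delay(micros: u64) -> u64 {
        micros >> ACK_DELAY_EXPONENT
    }
    fn decode_ack_delay(field: u64) -> u64 {
        field << ACK_DELAY_EXPONENT
    }
    // 25 ms (the default max_ack_delay) is a multiple of 2^3 microseconds,
    // so it round-trips losslessly.
    assert_eq!(decode_ack_delay(encode_ack_delay(25_000)), 25_000);
}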
let client_ch = pair.begin_connect(client_config()); pair.drive(); pair.server.assert_no_accept(); let client = pair.client.connections.get_mut(&client_ch).unwrap(); assert!(client.is_closed()); assert!(matches!( client.poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(close) }) if close.error_code == TransportErrorCode::CONNECTION_REFUSED )); } #[test] fn validate_then_reject_manually() { let _guard = subscribe(); let mut pair = Pair::default(); pair.server.incoming_connection_behavior = IncomingConnectionBehavior::ValidateThenReject; // The server should now retry and reject incoming connections. let client_ch = pair.begin_connect(client_config()); pair.drive(); pair.server.assert_no_accept(); let client = pair.client.connections.get_mut(&client_ch).unwrap(); assert!(client.is_closed()); assert!(matches!( client.poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(close) }) if close.error_code == TransportErrorCode::CONNECTION_REFUSED )); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!(pair.client.known_connections(), 0); assert_eq!(pair.client.known_cids(), 0); assert_eq!(pair.server.known_connections(), 0); assert_eq!(pair.server.known_cids(), 0); } #[test] fn endpoint_and_connection_impl_send_sync() { const fn is_send_sync<T: Send + Sync>() {} is_send_sync::<Endpoint>(); is_send_sync::<Connection>(); } #[test] fn stream_gso() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); let initial_ios = pair.client_conn_mut(client_ch).stats().udp_tx.ios; // Send 20KiB of stream data, which comfortably fits inside two `tests::util::MAX_DATAGRAMS` // datagram batches info!("sending"); for _ in 0..20 { pair.client_send(client_ch, s).write(&[0; 1024]).unwrap(); } pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); let final_ios = pair.client_conn_mut(client_ch).stats().udp_tx.ios; assert_eq!(final_ios - initial_ios, 2); } #[test] fn datagram_gso() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); let initial_ios = pair.client_conn_mut(client_ch).stats().udp_tx.ios; let initial_bytes = pair.client_conn_mut(client_ch).stats().udp_tx.bytes; // Send 10 datagrams above half the MTU, which fits inside a `tests::util::MAX_DATAGRAMS` // datagram batch info!("sending"); const DATAGRAM_LEN: usize = 1024; const DATAGRAMS: usize = 10; for _ in 0..DATAGRAMS { pair.client_datagrams(client_ch) .send(Bytes::from_static(&[0; DATAGRAM_LEN]), false) .unwrap(); } pair.drive(); let final_ios = pair.client_conn_mut(client_ch).stats().udp_tx.ios; let final_bytes = pair.client_conn_mut(client_ch).stats().udp_tx.bytes; assert_eq!(final_ios - initial_ios, 1); // Expected overhead: flags + CID + PN + tag + frame type + frame length = 1 + 8 + 1 + 16 + 1 + 2 = 29 assert_eq!( final_bytes - initial_bytes, ((29 + DATAGRAM_LEN) * DATAGRAMS) as u64 ); } #[test] fn gso_truncation() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let initial_ios = pair.client_conn_mut(client_ch).stats().udp_tx.ios; // Send three application datagrams such that each is too large to be combined with another in a // single MTU, and the second datagram would require an unreasonably large amount of padding to // produce a QUIC packet of the same length as the first.
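// A sketch of the UDP GSO constraint this test exercises: within one batched
// send, every segment must have the same size except the last, which may be
// shorter. Under that assumption the sizes [1024, 768, 768] cannot share a
// single batch, matching the `ios` delta of 2 asserted below. The counting
// helper is illustrative only, not the harness's batching logic.
{
    fn gso_batches(sizes: &[usize]) -> usize {
        let mut batches = 0;
        let mut i = 0;
        while i < sizes.len() {
            batches += 1;
            // The batch's segment size is fixed by its first datagram
            let seg = sizes[i];
            i += 1;
            // Extend with equal-sized segments...
            while i < sizes.len() && sizes[i] == seg {
                i += 1;
            }
            // ...plus at most one trailing shorter segment, which ends the batch
            if i < sizes.len() && sizes[i] < seg {
                i += 1;
            }
        }
        batches
    }
    assert_eq!(gso_batches(&[1024, 768, 768]), 2);
}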
info!("sending"); const SIZES: [usize; 3] = [1024, 768, 768]; for len in SIZES { pair.client_datagrams(client_ch) .send(vec![0; len].into(), false) .unwrap(); } pair.drive(); let final_ios = pair.client_conn_mut(client_ch).stats().udp_tx.ios; assert_eq!(final_ios - initial_ios, 2); for len in SIZES { assert_eq!( pair.server_datagrams(server_ch) .recv() .expect("datagram lost") .len(), len ); } } /// Verify that a large application datagram is sent successfully when an ACK frame too large to fit /// alongside it is also queued, in exactly 2 UDP datagrams. #[test] fn large_datagram_with_acks() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); // Force the client to generate a large ACK frame by dropping several packets for _ in 0..10 { pair.server_conn_mut(server_ch).ping(); pair.drive_server(); pair.client.inbound.pop_back(); pair.server_conn_mut(server_ch).ping(); pair.drive_server(); } let max_size = pair.client_datagrams(client_ch).max_size().unwrap(); let msg = Bytes::from(vec![0; max_size]); pair.client_datagrams(client_ch) .send(msg.clone(), true) .unwrap(); let initial_datagrams = pair.client_conn_mut(client_ch).stats().udp_tx.datagrams; pair.drive(); let final_datagrams = pair.client_conn_mut(client_ch).stats().udp_tx.datagrams; assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), msg); assert_eq!(final_datagrams - initial_datagrams, 2); } /// Verify that an ACK prompted by receipt of many non-ACK-eliciting packets is sent alongside /// outgoing application datagrams too large to coexist in the same packet with it. #[test] fn voluntary_ack_with_large_datagrams() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); // Prompt many large ACKs from the server let initial_datagrams = pair.client_conn_mut(client_ch).stats().udp_tx.datagrams; // Send enough packets that we're confident some packet numbers will be skipped, ensuring that // larger ACKs occur const COUNT: usize = 256; for _ in 0..COUNT { let max_size = pair.client_datagrams(client_ch).max_size().unwrap(); pair.client_datagrams(client_ch) .send(vec![0; max_size].into(), true) .unwrap(); pair.drive(); } let final_datagrams = pair.client_conn_mut(client_ch).stats().udp_tx.datagrams; // Failure may indicate `max_size` is too small and ACKs are reliably being packed into the same // datagram, which is reasonable behavior but makes this test ineffective. assert_ne!( final_datagrams - initial_datagrams, COUNT as u64, "client should have sent some ACK-only packets" ); } #[test] fn reject_short_idcid() { let _guard = subscribe(); let client_addr = "[::2]:7890".parse().unwrap(); let mut server = Endpoint::new( Default::default(), Some(Arc::new(server_config())), true, None, ); let now = Instant::now(); let mut buf = Vec::with_capacity(server.config().get_max_udp_payload_size() as usize); // Initial header that has an empty DCID but is otherwise well-formed let mut initial = BytesMut::from(hex!("c4 00000001 00 00 00 3f").as_ref()); initial.resize(MIN_INITIAL_SIZE.into(), 0); let event = server.handle(now, client_addr, None, None, initial, &mut buf); let Some(DatagramEvent::Response(Transmit { .. 
})) = event else { panic!("expected an initial close"); }; } quinn-proto-0.11.9/src/tests/util.rs000064400000000000000000000571311046102023000154720ustar 00000000000000use std::{ cmp, collections::{HashMap, VecDeque}, env, io::{self, Write}, mem, net::{Ipv6Addr, SocketAddr, UdpSocket}, ops::RangeFrom, str, sync::{Arc, Mutex}, }; use assert_matches::assert_matches; use bytes::BytesMut; use lazy_static::lazy_static; use rustls::{ client::WebPkiServerVerifier, pki_types::{CertificateDer, PrivateKeyDer}, KeyLogFile, }; use tracing::{info_span, trace}; use super::crypto::rustls::{configured_provider, QuicClientConfig, QuicServerConfig}; use super::*; use crate::{Duration, Instant}; pub(super) const DEFAULT_MTU: usize = 1452; pub(super) struct Pair { pub(super) server: TestEndpoint, pub(super) client: TestEndpoint, /// Start time epoch: Instant, /// Current time pub(super) time: Instant, /// Simulates the maximum size allowed for UDP payloads by the link (packets exceeding this size will be dropped) pub(super) mtu: usize, /// Simulates explicit congestion notification pub(super) congestion_experienced: bool, // One-way pub(super) latency: Duration, /// Number of spin bit flips pub(super) spins: u64, last_spin: bool, } impl Pair { pub(super) fn default_with_deterministic_pns() -> Self { let mut cfg = server_config(); let mut transport = TransportConfig::default(); transport.deterministic_packet_numbers(true); cfg.transport = Arc::new(transport); Self::new(Default::default(), cfg) } pub(super) fn new(endpoint_config: Arc<EndpointConfig>, server_config: ServerConfig) -> Self { let server = Endpoint::new( endpoint_config.clone(), Some(Arc::new(server_config)), true, None, ); let client = Endpoint::new(endpoint_config, None, true, None); Self::new_from_endpoint(client, server) } pub(super) fn new_from_endpoint(client: Endpoint, server: Endpoint) -> Self { let server_addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), SERVER_PORTS.lock().unwrap().next().unwrap(), ); let client_addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); let now = Instant::now(); Self { server: TestEndpoint::new(server, server_addr), client: TestEndpoint::new(client, client_addr), epoch: now, time: now, mtu: DEFAULT_MTU, latency: Duration::new(0, 0), spins: 0, last_spin: false, congestion_experienced: false, } } /// Returns whether the connection is not idle pub(super) fn step(&mut self) -> bool { self.drive_client(); self.drive_server(); if self.client.is_idle() && self.server.is_idle() { return false; } let client_t = self.client.next_wakeup(); let server_t = self.server.next_wakeup(); match min_opt(client_t, server_t) { Some(t) if Some(t) == client_t => { if t != self.time { self.time = self.time.max(t); trace!("advancing to {:?} for client", self.time - self.epoch); } true } Some(t) if Some(t) == server_t => { if t != self.time { self.time = self.time.max(t); trace!("advancing to {:?} for server", self.time - self.epoch); } true } Some(_) => unreachable!(), None => false, } } /// Advance time until both connections are idle pub(super) fn drive(&mut self) { while self.step() {} } /// Advance time until both connections are idle, or after 100 steps have been executed /// /// Returns true if the number of steps exceeds the bounds, because the connections never became /// idle pub(super) fn drive_bounded(&mut self) -> bool { for _ in 0..100 { if !self.step() { return false; } } true } pub(super) fn drive_client(&mut self) { let span = info_span!("client"); let _guard = span.enter();
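// The loop below implements the simulated link: datagrams larger than
// `self.mtu` are dropped, everything else reaches the peer's inbound queue
// after the one-way `latency`, optionally marked congestion-experienced. A
// condensed, self-contained model of that delivery rule (the types here are
// local stand-ins, not part of the harness):
{
    struct LinkModel {
        mtu: usize,
        latency_ms: u64,
    }
    impl LinkModel {
        /// Arrival time for a datagram, or None if the link drops it for
        /// exceeding the MTU.
        fn deliver(&self, size: usize, sent_at_ms: u64) -> Option<u64> {
            (size <= self.mtu).then(|| sent_at_ms + self.latency_ms)
        }
    }
    let link = LinkModel { mtu: 1452, latency_ms: 10 };
    assert_eq!(link.deliver(1300, 0), Some(10)); // fits: arrives after latency
    assert_eq!(link.deliver(1500, 0), None); // exceeds the MTU: dropped
}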
self.client.drive(self.time, self.server.addr); for (packet, buffer) in self.client.outbound.drain(..) { let packet_size = packet_size(&packet, &buffer); if packet_size > self.mtu { info!(packet_size, "dropping packet (max size exceeded)"); continue; } if buffer[0] & packet::LONG_HEADER_FORM == 0 { let spin = buffer[0] & packet::SPIN_BIT != 0; self.spins += (spin == self.last_spin) as u64; self.last_spin = spin; } if let Some(ref socket) = self.client.socket { socket.send_to(&buffer, packet.destination).unwrap(); } if self.server.addr == packet.destination { let ecn = set_congestion_experienced(packet.ecn, self.congestion_experienced); self.server.inbound.push_back(( self.time + self.latency, ecn, buffer.as_ref().into(), )); } } } pub(super) fn drive_server(&mut self) { let span = info_span!("server"); let _guard = span.enter(); self.server.drive(self.time, self.client.addr); for (packet, buffer) in self.server.outbound.drain(..) { let packet_size = packet_size(&packet, &buffer); if packet_size > self.mtu { info!(packet_size, "dropping packet (max size exceeded)"); continue; } if let Some(ref socket) = self.server.socket { socket.send_to(&buffer, packet.destination).unwrap(); } if self.client.addr == packet.destination { let ecn = set_congestion_experienced(packet.ecn, self.congestion_experienced); self.client.inbound.push_back(( self.time + self.latency, ecn, buffer.as_ref().into(), )); } } } pub(super) fn connect(&mut self) -> (ConnectionHandle, ConnectionHandle) { self.connect_with(client_config()) } pub(super) fn connect_with( &mut self, config: ClientConfig, ) -> (ConnectionHandle, ConnectionHandle) { info!("connecting"); let client_ch = self.begin_connect(config); self.drive(); let server_ch = self.server.assert_accept(); self.finish_connect(client_ch, server_ch); (client_ch, server_ch) } /// Just start connecting the client pub(super) fn begin_connect(&mut self, config: ClientConfig) -> ConnectionHandle { let span = info_span!("client"); let _guard = span.enter(); let (client_ch, client_conn) = self .client .connect(self.time, config, self.server.addr, "localhost") .unwrap(); self.client.connections.insert(client_ch, client_conn); client_ch } fn finish_connect(&mut self, client_ch: ConnectionHandle, server_ch: ConnectionHandle) { assert_matches!( self.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( self.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); assert_matches!( self.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( self.server_conn_mut(server_ch).poll(), Some(Event::Connected { .. 
}) ); } pub(super) fn client_conn_mut(&mut self, ch: ConnectionHandle) -> &mut Connection { self.client.connections.get_mut(&ch).unwrap() } pub(super) fn client_streams(&mut self, ch: ConnectionHandle) -> Streams<'_> { self.client_conn_mut(ch).streams() } pub(super) fn client_send(&mut self, ch: ConnectionHandle, s: StreamId) -> SendStream<'_> { self.client_conn_mut(ch).send_stream(s) } pub(super) fn client_recv(&mut self, ch: ConnectionHandle, s: StreamId) -> RecvStream<'_> { self.client_conn_mut(ch).recv_stream(s) } pub(super) fn client_datagrams(&mut self, ch: ConnectionHandle) -> Datagrams<'_> { self.client_conn_mut(ch).datagrams() } pub(super) fn server_conn_mut(&mut self, ch: ConnectionHandle) -> &mut Connection { self.server.connections.get_mut(&ch).unwrap() } pub(super) fn server_streams(&mut self, ch: ConnectionHandle) -> Streams<'_> { self.server_conn_mut(ch).streams() } pub(super) fn server_send(&mut self, ch: ConnectionHandle, s: StreamId) -> SendStream<'_> { self.server_conn_mut(ch).send_stream(s) } pub(super) fn server_recv(&mut self, ch: ConnectionHandle, s: StreamId) -> RecvStream<'_> { self.server_conn_mut(ch).recv_stream(s) } pub(super) fn server_datagrams(&mut self, ch: ConnectionHandle) -> Datagrams<'_> { self.server_conn_mut(ch).datagrams() } } impl Default for Pair { fn default() -> Self { Self::new(Default::default(), server_config()) } } pub(super) struct TestEndpoint { pub(super) endpoint: Endpoint, pub(super) addr: SocketAddr, socket: Option, timeout: Option, pub(super) outbound: VecDeque<(Transmit, Bytes)>, delayed: VecDeque<(Transmit, Bytes)>, pub(super) inbound: VecDeque<(Instant, Option, BytesMut)>, accepted: Option>, pub(super) connections: HashMap, conn_events: HashMap>, pub(super) captured_packets: Vec>, pub(super) capture_inbound_packets: bool, pub(super) incoming_connection_behavior: IncomingConnectionBehavior, pub(super) waiting_incoming: Vec, } #[derive(Debug, Copy, Clone)] pub(super) enum IncomingConnectionBehavior { AcceptAll, RejectAll, Validate, ValidateThenReject, Wait, } impl TestEndpoint { fn new(endpoint: Endpoint, addr: SocketAddr) -> Self { let socket = if env::var_os("SSLKEYLOGFILE").is_some() { let socket = UdpSocket::bind(addr).expect("failed to bind UDP socket"); socket .set_read_timeout(Some(Duration::new(0, 10_000_000))) .unwrap(); Some(socket) } else { None }; Self { endpoint, addr, socket, timeout: None, outbound: VecDeque::new(), delayed: VecDeque::new(), inbound: VecDeque::new(), accepted: None, connections: HashMap::default(), conn_events: HashMap::default(), captured_packets: Vec::new(), capture_inbound_packets: false, incoming_connection_behavior: IncomingConnectionBehavior::AcceptAll, waiting_incoming: Vec::new(), } } pub(super) fn drive(&mut self, now: Instant, remote: SocketAddr) { self.drive_incoming(now, remote); self.drive_outgoing(now); } pub(super) fn drive_incoming(&mut self, now: Instant, remote: SocketAddr) { if let Some(ref socket) = self.socket { loop { let mut buf = [0; 8192]; if socket.recv_from(&mut buf).is_err() { break; } } } let buffer_size = self.endpoint.config().get_max_udp_payload_size() as usize; let mut buf = Vec::with_capacity(buffer_size); while self.inbound.front().map_or(false, |x| x.0 <= now) { let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap(); if let Some(event) = self .endpoint .handle(recv_time, remote, None, ecn, packet, &mut buf) { match event { DatagramEvent::NewConnection(incoming) => { match self.incoming_connection_behavior { IncomingConnectionBehavior::AcceptAll => { let _ = 
self.try_accept(incoming, now); } IncomingConnectionBehavior::RejectAll => { self.reject(incoming); } IncomingConnectionBehavior::Validate => { if incoming.remote_address_validated() { let _ = self.try_accept(incoming, now); } else { self.retry(incoming); } } IncomingConnectionBehavior::ValidateThenReject => { if incoming.remote_address_validated() { self.reject(incoming); } else { self.retry(incoming); } } IncomingConnectionBehavior::Wait => { self.waiting_incoming.push(incoming); } } } DatagramEvent::ConnectionEvent(ch, event) => { if self.capture_inbound_packets { let packet = self.connections[&ch].decode_packet(&event); self.captured_packets.extend(packet); } self.conn_events.entry(ch).or_default().push_back(event); } DatagramEvent::Response(transmit) => { let size = transmit.size; self.outbound.extend(split_transmit(transmit, &buf[..size])); buf.clear(); } } } } } pub(super) fn drive_outgoing(&mut self, now: Instant) { let buffer_size = self.endpoint.config().get_max_udp_payload_size() as usize; let mut buf = Vec::with_capacity(buffer_size); loop { let mut endpoint_events: Vec<(ConnectionHandle, EndpointEvent)> = vec![]; for (ch, conn) in self.connections.iter_mut() { if self.timeout.map_or(false, |x| x <= now) { self.timeout = None; conn.handle_timeout(now); } for (_, mut events) in self.conn_events.drain() { for event in events.drain(..) { conn.handle_event(event); } } while let Some(event) = conn.poll_endpoint_events() { endpoint_events.push((*ch, event)); } while let Some(transmit) = conn.poll_transmit(now, MAX_DATAGRAMS, &mut buf) { let size = transmit.size; self.outbound.extend(split_transmit(transmit, &buf[..size])); buf.clear(); } self.timeout = conn.poll_timeout(); } if endpoint_events.is_empty() { break; } for (ch, event) in endpoint_events { if let Some(event) = self.handle_event(ch, event) { if let Some(conn) = self.connections.get_mut(&ch) { conn.handle_event(event); } } } } } pub(super) fn next_wakeup(&self) -> Option { let next_inbound = self.inbound.front().map(|x| x.0); min_opt(self.timeout, next_inbound) } fn is_idle(&self) -> bool { self.connections.values().all(|x| x.is_idle()) } pub(super) fn delay_outbound(&mut self) { assert!(self.delayed.is_empty()); mem::swap(&mut self.delayed, &mut self.outbound); } pub(super) fn finish_delay(&mut self) { self.outbound.extend(self.delayed.drain(..)); } pub(super) fn try_accept( &mut self, incoming: Incoming, now: Instant, ) -> Result { let mut buf = Vec::new(); match self.endpoint.accept(incoming, now, &mut buf, None) { Ok((ch, conn)) => { self.connections.insert(ch, conn); self.accepted = Some(Ok(ch)); Ok(ch) } Err(error) => { if let Some(transmit) = error.response { let size = transmit.size; self.outbound.extend(split_transmit(transmit, &buf[..size])); } self.accepted = Some(Err(error.cause.clone())); Err(error.cause) } } } pub(super) fn retry(&mut self, incoming: Incoming) { let mut buf = Vec::new(); let transmit = self.endpoint.retry(incoming, &mut buf).unwrap(); let size = transmit.size; self.outbound.extend(split_transmit(transmit, &buf[..size])); } pub(super) fn reject(&mut self, incoming: Incoming) { let mut buf = Vec::new(); let transmit = self.endpoint.refuse(incoming, &mut buf); let size = transmit.size; self.outbound.extend(split_transmit(transmit, &buf[..size])); } pub(super) fn assert_accept(&mut self) -> ConnectionHandle { self.accepted .take() .expect("server didn't try connecting") .expect("server experienced error connecting") } pub(super) fn assert_accept_error(&mut self) -> ConnectionError { self.accepted 
            .take()
            .expect("server didn't try connecting")
            .expect_err("server did unexpectedly connect without error")
    }

    pub(super) fn assert_no_accept(&self) {
        assert!(self.accepted.is_none(), "server did unexpectedly connect")
    }
}

impl ::std::ops::Deref for TestEndpoint {
    type Target = Endpoint;
    fn deref(&self) -> &Endpoint {
        &self.endpoint
    }
}

impl ::std::ops::DerefMut for TestEndpoint {
    fn deref_mut(&mut self) -> &mut Endpoint {
        &mut self.endpoint
    }
}

pub(super) fn subscribe() -> tracing::subscriber::DefaultGuard {
    let builder = tracing_subscriber::FmtSubscriber::builder()
        .with_max_level(tracing::Level::TRACE)
        .with_writer(|| TestWriter);
    // tracing uses std::time to trace time, which panics in wasm.
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    let builder = builder.without_time();
    tracing::subscriber::set_default(builder.finish())
}

struct TestWriter;

impl Write for TestWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        print!(
            "{}",
            str::from_utf8(buf).expect("tried to log invalid UTF-8")
        );
        Ok(buf.len())
    }
    fn flush(&mut self) -> io::Result<()> {
        io::stdout().flush()
    }
}

pub(super) fn server_config() -> ServerConfig {
    ServerConfig::with_crypto(Arc::new(server_crypto()))
}

pub(super) fn server_config_with_cert(
    cert: CertificateDer<'static>,
    key: PrivateKeyDer<'static>,
) -> ServerConfig {
    ServerConfig::with_crypto(Arc::new(server_crypto_with_cert(cert, key)))
}

pub(super) fn server_crypto() -> QuicServerConfig {
    server_crypto_inner(None, None)
}

pub(super) fn server_crypto_with_alpn(alpn: Vec<Vec<u8>>) -> QuicServerConfig {
    server_crypto_inner(None, Some(alpn))
}

pub(super) fn server_crypto_with_cert(
    cert: CertificateDer<'static>,
    key: PrivateKeyDer<'static>,
) -> QuicServerConfig {
    server_crypto_inner(Some((cert, key)), None)
}

fn server_crypto_inner(
    identity: Option<(CertificateDer<'static>, PrivateKeyDer<'static>)>,
    alpn: Option<Vec<Vec<u8>>>,
) -> QuicServerConfig {
    let (cert, key) = identity.unwrap_or_else(|| {
        (
            CERTIFIED_KEY.cert.der().clone(),
            PrivateKeyDer::Pkcs8(CERTIFIED_KEY.key_pair.serialize_der().into()),
        )
    });
    let mut config = QuicServerConfig::inner(vec![cert], key).unwrap();
    if let Some(alpn) = alpn {
        config.alpn_protocols = alpn;
    }
    config.try_into().unwrap()
}

pub(super) fn client_config() -> ClientConfig {
    ClientConfig::new(Arc::new(client_crypto()))
}

pub(super) fn client_config_with_deterministic_pns() -> ClientConfig {
    let mut cfg = ClientConfig::new(Arc::new(client_crypto()));
    let mut transport = TransportConfig::default();
    transport.deterministic_packet_numbers(true);
    cfg.transport = Arc::new(transport);
    cfg
}

pub(super) fn client_config_with_certs(certs: Vec<CertificateDer<'static>>) -> ClientConfig {
    ClientConfig::new(Arc::new(client_crypto_inner(Some(certs), None)))
}

pub(super) fn client_crypto() -> QuicClientConfig {
    client_crypto_inner(None, None)
}

pub(super) fn client_crypto_with_alpn(protocols: Vec<Vec<u8>>) -> QuicClientConfig {
    client_crypto_inner(None, Some(protocols))
}

fn client_crypto_inner(
    certs: Option<Vec<CertificateDer<'static>>>,
    alpn: Option<Vec<Vec<u8>>>,
) -> QuicClientConfig {
    let mut roots = rustls::RootCertStore::empty();
    for cert in certs.unwrap_or_else(|| vec![CERTIFIED_KEY.cert.der().clone()]) {
        roots.add(cert).unwrap();
    }
    let mut inner = QuicClientConfig::inner(
        WebPkiServerVerifier::builder_with_provider(Arc::new(roots), configured_provider())
            .build()
            .unwrap(),
    );
    inner.key_log = Arc::new(KeyLogFile::new());
    if let Some(alpn) = alpn {
        inner.alpn_protocols = alpn;
    }
    inner.try_into().unwrap()
}

pub(super) fn min_opt<T: Ord>(x: Option<T>, y: Option<T>) -> Option<T> {
    match (x, y) {
        (Some(x), Some(y)) => Some(cmp::min(x, y)),
        (Some(x), _) => Some(x),
        (_, Some(y)) => Some(y),
        _ => None,
    }
}

/// The maximum number of datagrams TestEndpoint will produce via `poll_transmit`
const MAX_DATAGRAMS: usize = 10;

/// Split a segmented (GSO) transmit into individual datagrams of at most `segment_size` bytes
fn split_transmit(transmit: Transmit, buffer: &[u8]) -> Vec<(Transmit, Bytes)> {
    let mut buffer = Bytes::copy_from_slice(buffer);
    let segment_size = match transmit.segment_size {
        Some(segment_size) => segment_size,
        _ => return vec![(transmit, buffer)],
    };

    let mut transmits = Vec::new();
    while !buffer.is_empty() {
        let end = segment_size.min(buffer.len());
        let contents = buffer.split_to(end);
        transmits.push((
            Transmit {
                destination: transmit.destination,
                size: contents.len(),
                ecn: transmit.ecn,
                segment_size: None,
                src_ip: transmit.src_ip,
            },
            contents,
        ));
    }

    transmits
}

fn packet_size(transmit: &Transmit, buffer: &Bytes) -> usize {
    if transmit.segment_size.is_some() {
        panic!("This transmit is meant to be split into multiple packets!");
    }

    buffer.len()
}

fn set_congestion_experienced(
    x: Option<EcnCodepoint>,
    congestion_experienced: bool,
) -> Option<EcnCodepoint> {
    x.map(|codepoint| match congestion_experienced {
        true => EcnCodepoint::Ce,
        false => codepoint,
    })
}

lazy_static! {
    pub static ref SERVER_PORTS: Mutex<RangeFrom<u16>> = Mutex::new(4433..);
    pub static ref CLIENT_PORTS: Mutex<RangeFrom<u16>> = Mutex::new(44433..);
    pub(crate) static ref CERTIFIED_KEY: rcgen::CertifiedKey =
        rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
}

quinn-proto-0.11.9/src/token.rs

use std::{
    fmt, io,
    net::{IpAddr, SocketAddr},
};

use bytes::{Buf, BufMut};

use crate::{
    coding::{BufExt, BufMutExt},
    crypto::{CryptoError, HandshakeTokenKey, HmacKey},
    shared::ConnectionId,
    Duration, SystemTime, RESET_TOKEN_SIZE, UNIX_EPOCH,
};

pub(crate) struct RetryToken {
    /// The destination connection ID set in the very first packet from the client
    pub(crate) orig_dst_cid: ConnectionId,
    /// The time at which this token was issued
    pub(crate) issued: SystemTime,
}

impl RetryToken {
    pub(crate) fn encode(
        &self,
        key: &dyn HandshakeTokenKey,
        address: &SocketAddr,
        retry_src_cid: &ConnectionId,
    ) -> Vec<u8> {
        let aead_key = key.aead_from_hkdf(retry_src_cid);

        // Plaintext layout: peer address, original destination CID, issue time in seconds
        let mut buf = Vec::new();
        encode_addr(&mut buf, address);
        self.orig_dst_cid.encode_long(&mut buf);
        buf.write::<u64>(
            self.issued
                .duration_since(UNIX_EPOCH)
                .map(|x| x.as_secs())
                .unwrap_or(0),
        );

        aead_key.seal(&mut buf, &[]).unwrap();

        buf
    }

    pub(crate) fn from_bytes(
        key: &dyn HandshakeTokenKey,
        address: &SocketAddr,
        retry_src_cid: &ConnectionId,
        raw_token_bytes: &[u8],
    ) -> Result<Self, TokenDecodeError> {
        let aead_key = key.aead_from_hkdf(retry_src_cid);
        let mut sealed_token = raw_token_bytes.to_vec();

        let data = aead_key.open(&mut sealed_token, &[])?;
        let mut reader = io::Cursor::new(data);
        let token_addr = decode_addr(&mut reader).ok_or(TokenDecodeError::UnknownToken)?;
        if token_addr != *address {
            return Err(TokenDecodeError::WrongAddress);
        }
        let orig_dst_cid =
            ConnectionId::decode_long(&mut reader).ok_or(TokenDecodeError::UnknownToken)?;
        let issued = UNIX_EPOCH
            + Duration::new(
                reader
                    .get::<u64>()
                    .map_err(|_| TokenDecodeError::UnknownToken)?,
                0,
            );

        Ok(Self {
            orig_dst_cid,
            issued,
        })
    }
}

// Addresses are encoded as a one-byte family tag (0 = IPv4, 1 = IPv6), the raw address
// octets, and a big-endian u16 port
fn encode_addr(buf: &mut Vec<u8>, address: &SocketAddr) {
    match address.ip() {
        IpAddr::V4(x) => {
            buf.put_u8(0);
            buf.put_slice(&x.octets());
        }
        IpAddr::V6(x) => {
            buf.put_u8(1);
            buf.put_slice(&x.octets());
        }
    }
    buf.put_u16(address.port());
}

fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
    let ip = match buf.get_u8() {
        0 => IpAddr::V4(buf.get().ok()?),
        1 => IpAddr::V6(buf.get().ok()?),
        _ => return None,
    };
    let port = buf.get_u16();
Some(SocketAddr::new(ip, port)) } /// Reasons why a retry token might fail to validate a client's address #[derive(Debug, Copy, Clone)] pub(crate) enum TokenDecodeError { /// Token was not recognized. It should be silently ignored. UnknownToken, /// Token was well-formed but associated with an incorrect address. The connection cannot be /// established. WrongAddress, } impl From for TokenDecodeError { fn from(CryptoError: CryptoError) -> Self { Self::UnknownToken } } /// Stateless reset token /// /// Used for an endpoint to securely communicate that it has lost state for a connection. #[allow(clippy::derived_hash_with_manual_eq)] // Custom PartialEq impl matches derived semantics #[derive(Debug, Copy, Clone, Hash)] pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]); impl ResetToken { pub(crate) fn new(key: &dyn HmacKey, id: &ConnectionId) -> Self { let mut signature = vec![0; key.signature_len()]; key.sign(id, &mut signature); // TODO: Server ID?? let mut result = [0; RESET_TOKEN_SIZE]; result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]); result.into() } } impl PartialEq for ResetToken { fn eq(&self, other: &Self) -> bool { crate::constant_time::eq(&self.0, &other.0) } } impl Eq for ResetToken {} impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken { fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self { Self(x) } } impl std::ops::Deref for ResetToken { type Target = [u8]; fn deref(&self) -> &[u8] { &self.0 } } impl fmt::Display for ResetToken { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for byte in self.iter() { write!(f, "{byte:02x}")?; } Ok(()) } } #[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))] mod test { #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))] use aws_lc_rs::hkdf; #[cfg(feature = "ring")] use ring::hkdf; #[test] fn token_sanity() { use super::*; use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}; use crate::MAX_CID_SIZE; use crate::{Duration, UNIX_EPOCH}; use rand::RngCore; use std::net::Ipv6Addr; let rng = &mut rand::thread_rng(); let mut master_key = [0; 64]; rng.fill_bytes(&mut master_key); let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key); let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433); let retry_src_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); let token = RetryToken { orig_dst_cid: RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(), issued: UNIX_EPOCH + Duration::new(42, 0), // Fractional seconds would be lost }; let encoded = token.encode(&prk, &addr, &retry_src_cid); let decoded = RetryToken::from_bytes(&prk, &addr, &retry_src_cid, &encoded) .expect("token didn't validate"); assert_eq!(token.orig_dst_cid, decoded.orig_dst_cid); assert_eq!(token.issued, decoded.issued); } #[test] fn invalid_token_returns_err() { use super::*; use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}; use crate::MAX_CID_SIZE; use rand::RngCore; use std::net::Ipv6Addr; let rng = &mut rand::thread_rng(); let mut master_key = [0; 64]; rng.fill_bytes(&mut master_key); let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key); let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433); let retry_src_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); let mut invalid_token = Vec::new(); let mut random_data = [0; 32]; rand::thread_rng().fill_bytes(&mut random_data); invalid_token.put_slice(&random_data); // Assert: garbage sealed data returns err assert!(RetryToken::from_bytes(&prk, &addr, &retry_src_cid, 
&invalid_token).is_err()); } } quinn-proto-0.11.9/src/transport_error.rs000064400000000000000000000114031046102023000166100ustar 00000000000000use std::fmt; use bytes::{Buf, BufMut}; use crate::{ coding::{self, BufExt, BufMutExt}, frame, }; /// Transport-level errors occur when a peer violates the protocol specification #[derive(Debug, Clone, Eq, PartialEq)] pub struct Error { /// Type of error pub code: Code, /// Frame type that triggered the error pub frame: Option, /// Human-readable explanation of the reason pub reason: String, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.code.fmt(f)?; if let Some(frame) = self.frame { write!(f, " in {frame}")?; } if !self.reason.is_empty() { write!(f, ": {}", self.reason)?; } Ok(()) } } impl std::error::Error for Error {} impl From for Error { fn from(x: Code) -> Self { Self { code: x, frame: None, reason: "".to_string(), } } } /// Transport-level error code #[derive(Copy, Clone, Eq, PartialEq)] pub struct Code(u64); impl Code { /// Create QUIC error code from TLS alert code pub fn crypto(code: u8) -> Self { Self(0x100 | u64::from(code)) } } impl coding::Codec for Code { fn decode(buf: &mut B) -> coding::Result { Ok(Self(buf.get_var()?)) } fn encode(&self, buf: &mut B) { buf.write_var(self.0) } } impl From for u64 { fn from(x: Code) -> Self { x.0 } } macro_rules! errors { {$($name:ident($val:expr) $desc:expr;)*} => { #[allow(non_snake_case, unused)] impl Error { $( pub(crate) fn $name(reason: T) -> Self where T: Into { Self { code: Code::$name, frame: None, reason: reason.into(), } } )* } impl Code { $(#[doc = $desc] pub const $name: Self = Code($val);)* } impl fmt::Debug for Code { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str(stringify!($name)),)* x if (0x100..0x200).contains(&x) => write!(f, "Code::crypto({:02x})", self.0 as u8), _ => write!(f, "Code({:x})", self.0), } } } impl fmt::Display for Code { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str($desc),)* // We're trying to be abstract over the crypto protocol, so human-readable descriptions here is tricky. _ if self.0 >= 0x100 && self.0 < 0x200 => write!(f, "the cryptographic handshake failed: error {}", self.0 & 0xFF), _ => f.write_str("unknown error"), } } } } } errors! 
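// The table below covers the transport error codes defined in RFC 9000 §20.1
// (0x00 through 0x10). TLS alerts surface in the reserved 0x100..0x1ff range via
// `Code::crypto` above; for example, `Code::crypto(40)` (handshake_failure)
// displays as "the cryptographic handshake failed: error 40".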
{
    NO_ERROR(0x0) "the connection is being closed abruptly in the absence of any error";
    INTERNAL_ERROR(0x1) "the endpoint encountered an internal error and cannot continue with the connection";
    CONNECTION_REFUSED(0x2) "the server refused to accept a new connection";
    FLOW_CONTROL_ERROR(0x3) "received more data than permitted in advertised data limits";
    STREAM_LIMIT_ERROR(0x4) "received a frame for a stream identifier that exceeded the advertised stream limit for the corresponding stream type";
    STREAM_STATE_ERROR(0x5) "received a frame for a stream that was not in a state that permitted that frame";
    FINAL_SIZE_ERROR(0x6) "received a STREAM frame or a RESET_STREAM frame containing a different final size from the one already established";
    FRAME_ENCODING_ERROR(0x7) "received a frame that was badly formatted";
    TRANSPORT_PARAMETER_ERROR(0x8) "received transport parameters that were badly formatted, included an invalid value, omitted a mandatory parameter, included a forbidden parameter, or were otherwise in error";
    CONNECTION_ID_LIMIT_ERROR(0x9) "the number of connection IDs provided by the peer exceeds the advertised active_connection_id_limit";
    PROTOCOL_VIOLATION(0xA) "detected an error with protocol compliance that was not covered by more specific error codes";
    INVALID_TOKEN(0xB) "received an invalid Retry Token in a client Initial";
    APPLICATION_ERROR(0xC) "the application or application protocol caused the connection to be closed during the handshake";
    CRYPTO_BUFFER_EXCEEDED(0xD) "received more data in CRYPTO frames than can be buffered";
    KEY_UPDATE_ERROR(0xE) "detected errors in performing key updates";
    AEAD_LIMIT_REACHED(0xF) "the endpoint has reached the confidentiality or integrity limit for the AEAD algorithm";
    NO_VIABLE_PATH(0x10) "no viable network path exists";
}

quinn-proto-0.11.9/src/transport_parameters.rs

//! QUIC connection transport parameters
//!
//! The `TransportParameters` type is used to represent the transport parameters
//! negotiated by peers while establishing a QUIC connection. This process
//! happens as part of the establishment of the TLS session. As such, the types
//! contained in this module should generally only be referred to by custom
//! implementations of the `crypto::Session` trait.

use std::{
    convert::TryFrom,
    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
};

use bytes::{Buf, BufMut};
use thiserror::Error;

use crate::{
    cid_generator::ConnectionIdGenerator,
    cid_queue::CidQueue,
    coding::{BufExt, BufMutExt, UnexpectedEnd},
    config::{EndpointConfig, ServerConfig, TransportConfig},
    shared::ConnectionId,
    ResetToken, Side, TransportError, VarInt, LOC_CID_COUNT, MAX_CID_SIZE, MAX_STREAM_COUNT,
    RESET_TOKEN_SIZE, TIMER_GRANULARITY,
};

// Apply a given macro to a list of all the transport parameters having integer types, along with
// their codes and default values. Using this helps us avoid error-prone duplication of the
// contained information across decoding, encoding, and the `Default` impl. Whenever we want to do
// something with transport parameters, we'll handle the bulk of cases by writing a macro that
// takes a list of arguments in this form, then passing it to this macro.
macro_rules! apply_params {
    ($macro:ident) => {
        $macro!
{ // #[doc] name (id) = default, /// Milliseconds, disabled if zero max_idle_timeout(0x0001) = 0, /// Limits the size of UDP payloads that the endpoint is willing to receive max_udp_payload_size(0x0003) = 65527, /// Initial value for the maximum amount of data that can be sent on the connection initial_max_data(0x0004) = 0, /// Initial flow control limit for locally-initiated bidirectional streams initial_max_stream_data_bidi_local(0x0005) = 0, /// Initial flow control limit for peer-initiated bidirectional streams initial_max_stream_data_bidi_remote(0x0006) = 0, /// Initial flow control limit for unidirectional streams initial_max_stream_data_uni(0x0007) = 0, /// Initial maximum number of bidirectional streams the peer may initiate initial_max_streams_bidi(0x0008) = 0, /// Initial maximum number of unidirectional streams the peer may initiate initial_max_streams_uni(0x0009) = 0, /// Exponent used to decode the ACK Delay field in the ACK frame ack_delay_exponent(0x000a) = 3, /// Maximum amount of time in milliseconds by which the endpoint will delay sending /// acknowledgments max_ack_delay(0x000b) = 25, /// Maximum number of connection IDs from the peer that an endpoint is willing to store active_connection_id_limit(0x000e) = 2, } }; } macro_rules! make_struct { {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => { /// Transport parameters used to negotiate connection-level preferences between peers #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct TransportParameters { $($(#[$doc])* pub(crate) $name : VarInt,)* /// Does the endpoint support active connection migration pub(crate) disable_active_migration: bool, /// Maximum size for datagram frames pub(crate) max_datagram_frame_size: Option, /// The value that the endpoint included in the Source Connection ID field of the first /// Initial packet it sends for the connection pub(crate) initial_src_cid: Option, /// The endpoint is willing to receive QUIC packets containing any value for the fixed /// bit pub(crate) grease_quic_bit: bool, /// Minimum amount of time in microseconds by which the endpoint is able to delay /// sending acknowledgments /// /// If a value is provided, it implies that the endpoint supports QUIC Acknowledgement /// Frequency pub(crate) min_ack_delay: Option, // Server-only /// The value of the Destination Connection ID field from the first Initial packet sent /// by the client pub(crate) original_dst_cid: Option, /// The value that the server included in the Source Connection ID field of a Retry /// packet pub(crate) retry_src_cid: Option, /// Token used by the client to verify a stateless reset from the server pub(crate) stateless_reset_token: Option, /// The server's preferred address for communication after handshake completion pub(crate) preferred_address: Option, } // We deliberately don't implement the `Default` trait, since that would be public, and // downstream crates should never construct `TransportParameters` except by decoding those // supplied by a peer. impl TransportParameters { /// Standard defaults, used if the peer does not supply a given parameter. 
pub(crate) fn default() -> Self { Self { $($name: VarInt::from_u32($default),)* disable_active_migration: false, max_datagram_frame_size: None, initial_src_cid: None, grease_quic_bit: false, min_ack_delay: None, original_dst_cid: None, retry_src_cid: None, stateless_reset_token: None, preferred_address: None, } } } } } apply_params!(make_struct); impl TransportParameters { pub(crate) fn new( config: &TransportConfig, endpoint_config: &EndpointConfig, cid_gen: &dyn ConnectionIdGenerator, initial_src_cid: ConnectionId, server_config: Option<&ServerConfig>, ) -> Self { Self { initial_src_cid: Some(initial_src_cid), initial_max_streams_bidi: config.max_concurrent_bidi_streams, initial_max_streams_uni: config.max_concurrent_uni_streams, initial_max_data: config.receive_window, initial_max_stream_data_bidi_local: config.stream_receive_window, initial_max_stream_data_bidi_remote: config.stream_receive_window, initial_max_stream_data_uni: config.stream_receive_window, max_udp_payload_size: endpoint_config.max_udp_payload_size, max_idle_timeout: config.max_idle_timeout.unwrap_or(VarInt(0)), disable_active_migration: server_config.map_or(false, |c| !c.migration), active_connection_id_limit: if cid_gen.cid_len() == 0 { 2 // i.e. default, i.e. unsent } else { CidQueue::LEN as u32 } .into(), max_datagram_frame_size: config .datagram_receive_buffer_size .map(|x| (x.min(u16::MAX.into()) as u16).into()), grease_quic_bit: endpoint_config.grease_quic_bit, min_ack_delay: Some( VarInt::from_u64(u64::try_from(TIMER_GRANULARITY.as_micros()).unwrap()).unwrap(), ), ..Self::default() } } /// Check that these parameters are legal when resuming from /// certain cached parameters pub(crate) fn validate_resumption_from(&self, cached: &Self) -> Result<(), TransportError> { if cached.active_connection_id_limit > self.active_connection_id_limit || cached.initial_max_data > self.initial_max_data || cached.initial_max_stream_data_bidi_local > self.initial_max_stream_data_bidi_local || cached.initial_max_stream_data_bidi_remote > self.initial_max_stream_data_bidi_remote || cached.initial_max_stream_data_uni > self.initial_max_stream_data_uni || cached.initial_max_streams_bidi > self.initial_max_streams_bidi || cached.initial_max_streams_uni > self.initial_max_streams_uni || cached.max_datagram_frame_size > self.max_datagram_frame_size || cached.grease_quic_bit && !self.grease_quic_bit { return Err(TransportError::PROTOCOL_VIOLATION( "0-RTT accepted with incompatible transport parameters", )); } Ok(()) } /// Maximum number of CIDs to issue to this peer /// /// Consider both a) the active_connection_id_limit from the other end; and /// b) LOC_CID_COUNT used locally pub(crate) fn issue_cids_limit(&self) -> u64 { self.active_connection_id_limit.0.min(LOC_CID_COUNT) } } /// A server's preferred address /// /// This is communicated as a transport parameter during TLS session establishment. 
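///
/// On the wire (per the `write`/`read` impls below): a 4-byte IPv4 address and 2-byte
/// port, a 16-byte IPv6 address and 2-byte port, a 1-byte connection ID length followed
/// by the connection ID itself, and a 16-byte stateless reset token. An all-zero
/// address and port indicates that no address of that family is offered.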
#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) struct PreferredAddress { pub(crate) address_v4: Option, pub(crate) address_v6: Option, pub(crate) connection_id: ConnectionId, pub(crate) stateless_reset_token: ResetToken, } impl PreferredAddress { fn wire_size(&self) -> u16 { 4 + 2 + 16 + 2 + 1 + self.connection_id.len() as u16 + 16 } fn write(&self, w: &mut W) { w.write(self.address_v4.map_or(Ipv4Addr::UNSPECIFIED, |x| *x.ip())); w.write::(self.address_v4.map_or(0, |x| x.port())); w.write(self.address_v6.map_or(Ipv6Addr::UNSPECIFIED, |x| *x.ip())); w.write::(self.address_v6.map_or(0, |x| x.port())); w.write::(self.connection_id.len() as u8); w.put_slice(&self.connection_id); w.put_slice(&self.stateless_reset_token); } fn read(r: &mut R) -> Result { let ip_v4 = r.get::()?; let port_v4 = r.get::()?; let ip_v6 = r.get::()?; let port_v6 = r.get::()?; let cid_len = r.get::()?; if r.remaining() < cid_len as usize || cid_len > MAX_CID_SIZE as u8 { return Err(Error::Malformed); } let mut stage = [0; MAX_CID_SIZE]; r.copy_to_slice(&mut stage[0..cid_len as usize]); let cid = ConnectionId::new(&stage[0..cid_len as usize]); if r.remaining() < 16 { return Err(Error::Malformed); } let mut token = [0; RESET_TOKEN_SIZE]; r.copy_to_slice(&mut token); let address_v4 = if ip_v4.is_unspecified() && port_v4 == 0 { None } else { Some(SocketAddrV4::new(ip_v4, port_v4)) }; let address_v6 = if ip_v6.is_unspecified() && port_v6 == 0 { None } else { Some(SocketAddrV6::new(ip_v6, port_v6, 0, 0)) }; if address_v4.is_none() && address_v6.is_none() { return Err(Error::IllegalValue); } Ok(Self { address_v4, address_v6, connection_id: cid, stateless_reset_token: token.into(), }) } } /// Errors encountered while decoding `TransportParameters` #[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] pub enum Error { /// Parameters that are semantically invalid #[error("parameter had illegal value")] IllegalValue, /// Catch-all error for problems while decoding transport parameters #[error("parameters were malformed")] Malformed, } impl From for TransportError { fn from(e: Error) -> Self { match e { Error::IllegalValue => Self::TRANSPORT_PARAMETER_ERROR("illegal value"), Error::Malformed => Self::TRANSPORT_PARAMETER_ERROR("malformed"), } } } impl From for Error { fn from(_: UnexpectedEnd) -> Self { Self::Malformed } } impl TransportParameters { /// Encode `TransportParameters` into buffer pub fn write(&self, w: &mut W) { macro_rules! 
write_params { {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => { $( if self.$name.0 != $default { w.write_var($code); w.write(VarInt::try_from(self.$name.size()).unwrap()); w.write(self.$name); } )* } } apply_params!(write_params); // Add a reserved parameter to keep people on their toes w.write_var(31 * 5 + 27); w.write_var(0); if let Some(ref x) = self.stateless_reset_token { w.write_var(0x02); w.write_var(16); w.put_slice(x); } if self.disable_active_migration { w.write_var(0x0c); w.write_var(0); } if let Some(x) = self.max_datagram_frame_size { w.write_var(0x20); w.write_var(x.size() as u64); w.write(x); } if let Some(ref x) = self.preferred_address { w.write_var(0x000d); w.write_var(x.wire_size() as u64); x.write(w); } for &(tag, cid) in &[ (0x00, &self.original_dst_cid), (0x0f, &self.initial_src_cid), (0x10, &self.retry_src_cid), ] { if let Some(ref cid) = *cid { w.write_var(tag); w.write_var(cid.len() as u64); w.put_slice(cid); } } if self.grease_quic_bit { w.write_var(0x2ab2); w.write_var(0); } if let Some(x) = self.min_ack_delay { w.write_var(0xff04de1b); w.write_var(x.size() as u64); w.write(x); } } /// Decode `TransportParameters` from buffer pub fn read(side: Side, r: &mut R) -> Result { // Initialize to protocol-specified defaults let mut params = Self::default(); // State to check for duplicate transport parameters. macro_rules! param_state { {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => {{ struct ParamState { $($name: bool,)* } ParamState { $($name: false,)* } }} } let mut got = apply_params!(param_state); while r.has_remaining() { let id = r.get_var()?; let len = r.get_var()?; if (r.remaining() as u64) < len { return Err(Error::Malformed); } let len = len as usize; match id { 0x00 => decode_cid(len, &mut params.original_dst_cid, r)?, 0x02 => { if len != 16 || params.stateless_reset_token.is_some() { return Err(Error::Malformed); } let mut tok = [0; RESET_TOKEN_SIZE]; r.copy_to_slice(&mut tok); params.stateless_reset_token = Some(tok.into()); } 0x0c => { if len != 0 || params.disable_active_migration { return Err(Error::Malformed); } params.disable_active_migration = true; } 0x0d => { if params.preferred_address.is_some() { return Err(Error::Malformed); } params.preferred_address = Some(PreferredAddress::read(&mut r.take(len))?); } 0x0f => decode_cid(len, &mut params.initial_src_cid, r)?, 0x10 => decode_cid(len, &mut params.retry_src_cid, r)?, 0x20 => { if len > 8 || params.max_datagram_frame_size.is_some() { return Err(Error::Malformed); } params.max_datagram_frame_size = Some(r.get().unwrap()); } 0x2ab2 => match len { 0 => params.grease_quic_bit = true, _ => return Err(Error::Malformed), }, 0xff04de1b => params.min_ack_delay = Some(r.get().unwrap()), _ => { macro_rules! 
parse { {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => { match id { $($code => { let value = r.get::()?; if len != value.size() || got.$name { return Err(Error::Malformed); } params.$name = value.into(); got.$name = true; })* _ => r.advance(len as usize), } } } apply_params!(parse); } } } // Semantic validation // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.26.1 if params.ack_delay_exponent.0 > 20 // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.28.1 || params.max_ack_delay.0 >= 1 << 14 // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-6.2.1 || params.active_connection_id_limit.0 < 2 // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.10.1 || params.max_udp_payload_size.0 < 1200 // https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2 || params.initial_max_streams_bidi.0 > MAX_STREAM_COUNT || params.initial_max_streams_uni.0 > MAX_STREAM_COUNT // https://www.ietf.org/archive/id/draft-ietf-quic-ack-frequency-08.html#section-3-4 || params.min_ack_delay.map_or(false, |min_ack_delay| { // min_ack_delay uses microseconds, whereas max_ack_delay uses milliseconds min_ack_delay.0 > params.max_ack_delay.0 * 1_000 }) // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-8 || (side.is_server() && (params.original_dst_cid.is_some() || params.preferred_address.is_some() || params.retry_src_cid.is_some() || params.stateless_reset_token.is_some())) // https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2-4.38.1 || params .preferred_address .map_or(false, |x| x.connection_id.is_empty()) { return Err(Error::IllegalValue); } Ok(params) } } fn decode_cid(len: usize, value: &mut Option, r: &mut impl Buf) -> Result<(), Error> { if len > MAX_CID_SIZE || value.is_some() || r.remaining() < len { return Err(Error::Malformed); } *value = Some(ConnectionId::from_buf(r, len)); Ok(()) } #[cfg(test)] mod test { use super::*; #[test] fn coding() { let mut buf = Vec::new(); let params = TransportParameters { initial_src_cid: Some(ConnectionId::new(&[])), original_dst_cid: Some(ConnectionId::new(&[])), initial_max_streams_bidi: 16u32.into(), initial_max_streams_uni: 16u32.into(), ack_delay_exponent: 2u32.into(), max_udp_payload_size: 1200u32.into(), preferred_address: Some(PreferredAddress { address_v4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 42)), address_v6: Some(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 24, 0, 0)), connection_id: ConnectionId::new(&[0x42]), stateless_reset_token: [0xab; RESET_TOKEN_SIZE].into(), }), grease_quic_bit: true, min_ack_delay: Some(2_000u32.into()), ..TransportParameters::default() }; params.write(&mut buf); assert_eq!( TransportParameters::read(Side::Client, &mut buf.as_slice()).unwrap(), params ); } #[test] fn read_semantic_validation() { #[allow(clippy::type_complexity)] let illegal_params_builders: Vec> = vec![ Box::new(|t| { // This min_ack_delay is bigger than max_ack_delay! 
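                // (Note the units: max_ack_delay is expressed in milliseconds but
                // min_ack_delay in microseconds, so `* 1_000` converts before adding 1.)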
                let min_ack_delay = t.max_ack_delay.0 * 1_000 + 1;
                t.min_ack_delay = Some(VarInt::from_u64(min_ack_delay).unwrap())
            }),
            Box::new(|t| {
                // Preferred address may only be sent by servers (and we are reading the
                // transport params as a server, so our peer is a client)
                t.preferred_address = Some(PreferredAddress {
                    address_v4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 42)),
                    address_v6: None,
                    connection_id: ConnectionId::new(&[]),
                    stateless_reset_token: [0xab; RESET_TOKEN_SIZE].into(),
                })
            }),
        ];
        for mut builder in illegal_params_builders {
            let mut buf = Vec::new();
            let mut params = TransportParameters::default();
            builder(&mut params);
            params.write(&mut buf);
            assert_eq!(
                TransportParameters::read(Side::Server, &mut buf.as_slice()),
                Err(Error::IllegalValue)
            );
        }
    }

    #[test]
    fn resumption_params_validation() {
        let high_limit = TransportParameters {
            initial_max_streams_uni: 32u32.into(),
            ..TransportParameters::default()
        };
        let low_limit = TransportParameters {
            initial_max_streams_uni: 16u32.into(),
            ..TransportParameters::default()
        };
        high_limit.validate_resumption_from(&low_limit).unwrap();
        low_limit.validate_resumption_from(&high_limit).unwrap_err();
    }
}

quinn-proto-0.11.9/src/varint.rs

use std::{convert::TryInto, fmt};

use bytes::{Buf, BufMut};
use thiserror::Error;

use crate::coding::{self, Codec, UnexpectedEnd};

#[cfg(feature = "arbitrary")]
use arbitrary::Arbitrary;

/// An integer less than 2^62
///
/// Values of this type are suitable for encoding as a QUIC variable-length integer.
// It would be neat if we could express to Rust that the top two bits are available for
// use as enum discriminants
#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct VarInt(pub(crate) u64);

impl VarInt {
    /// The largest representable value
    pub const MAX: Self = Self((1 << 62) - 1);
    /// The largest encoded value length
    pub const MAX_SIZE: usize = 8;

    /// Construct a `VarInt` infallibly
    pub const fn from_u32(x: u32) -> Self {
        Self(x as u64)
    }

    /// Succeeds iff `x` < 2^62
    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
        if x < 2u64.pow(62) {
            Ok(Self(x))
        } else {
            Err(VarIntBoundsExceeded)
        }
    }

    /// Create a VarInt without ensuring it's in range
    ///
    /// # Safety
    ///
    /// `x` must be less than 2^62.
    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
        Self(x)
    }

    /// Extract the integer value
    pub const fn into_inner(self) -> u64 {
        self.0
    }

    /// Compute the number of bytes needed to encode this value
    pub(crate) const fn size(self) -> usize {
        let x = self.0;
        if x < 2u64.pow(6) {
            1
        } else if x < 2u64.pow(14) {
            2
        } else if x < 2u64.pow(30) {
            4
        } else if x < 2u64.pow(62) {
            8
        } else {
            panic!("malformed VarInt");
        }
    }
}

impl From<VarInt> for u64 {
    fn from(x: VarInt) -> Self {
        x.0
    }
}

impl From<u8> for VarInt {
    fn from(x: u8) -> Self {
        Self(x.into())
    }
}

impl From<u16> for VarInt {
    fn from(x: u16) -> Self {
        Self(x.into())
    }
}

impl From<u32> for VarInt {
    fn from(x: u32) -> Self {
        Self(x.into())
    }
}

impl std::convert::TryFrom<u64> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
        Self::from_u64(x)
    }
}

impl std::convert::TryFrom<u128> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
} } impl std::convert::TryFrom for VarInt { type Error = VarIntBoundsExceeded; /// Succeeds iff `x` < 2^62 fn try_from(x: usize) -> Result { Self::try_from(x as u64) } } impl fmt::Debug for VarInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } impl fmt::Display for VarInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } #[cfg(feature = "arbitrary")] impl<'arbitrary> Arbitrary<'arbitrary> for VarInt { fn arbitrary(u: &mut arbitrary::Unstructured<'arbitrary>) -> arbitrary::Result { Ok(Self(u.int_in_range(0..=Self::MAX.0)?)) } } /// Error returned when constructing a `VarInt` from a value >= 2^62 #[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] #[error("value too large for varint encoding")] pub struct VarIntBoundsExceeded; impl Codec for VarInt { fn decode(r: &mut B) -> coding::Result { if !r.has_remaining() { return Err(UnexpectedEnd); } let mut buf = [0; 8]; buf[0] = r.get_u8(); let tag = buf[0] >> 6; buf[0] &= 0b0011_1111; let x = match tag { 0b00 => u64::from(buf[0]), 0b01 => { if r.remaining() < 1 { return Err(UnexpectedEnd); } r.copy_to_slice(&mut buf[1..2]); u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) } 0b10 => { if r.remaining() < 3 { return Err(UnexpectedEnd); } r.copy_to_slice(&mut buf[1..4]); u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) } 0b11 => { if r.remaining() < 7 { return Err(UnexpectedEnd); } r.copy_to_slice(&mut buf[1..8]); u64::from_be_bytes(buf) } _ => unreachable!(), }; Ok(Self(x)) } fn encode(&self, w: &mut B) { let x = self.0; if x < 2u64.pow(6) { w.put_u8(x as u8); } else if x < 2u64.pow(14) { w.put_u16(0b01 << 14 | x as u16); } else if x < 2u64.pow(30) { w.put_u32(0b10 << 30 | x as u32); } else if x < 2u64.pow(62) { w.put_u64(0b11 << 62 | x); } else { unreachable!("malformed VarInt") } } }
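
// Illustrative sketch (not from the upstream file): a minimal test exercising the four
// encoding tiers implemented by `Codec` above, using the example values from RFC 9000
// Appendix A. The top two bits of the first byte select a 1-, 2-, 4-, or 8-byte
// encoding; the module name `encoding_example` is arbitrary.
#[cfg(test)]
mod encoding_example {
    use super::VarInt;
    use crate::coding::Codec;

    #[test]
    fn four_tiers_round_trip() {
        for (value, encoded_len) in [
            (37u64, 1usize),
            (15_293, 2),
            (494_878_333, 4),
            (151_288_809_941_952_652, 8),
        ] {
            let x = VarInt::from_u64(value).unwrap();
            let mut buf = Vec::new();
            x.encode(&mut buf);
            // `size` predicts the encoded length, and decoding restores the value
            assert_eq!(buf.len(), encoded_len);
            assert_eq!(x.size(), encoded_len);
            assert_eq!(VarInt::decode(&mut buf.as_slice()).unwrap(), x);
        }
    }
}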