quinn-proto-0.10.6/.cargo_vcs_info.json0000644000000001510000000000100134100ustar { "git": { "sha1": "db2df614fcab3e9c17b3e5f325eb197920489779" }, "path_in_vcs": "quinn-proto" }quinn-proto-0.10.6/Cargo.toml0000644000000037250000000000100114200ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2021" rust-version = "1.63" name = "quinn-proto" version = "0.10.6" description = "State machine for the QUIC transport protocol" keywords = ["quic"] categories = [ "network-programming", "asynchronous", ] license = "MIT OR Apache-2.0" repository = "https://github.com/quinn-rs/quinn" [package.metadata.docs.rs] all-features = true [dependencies.arbitrary] version = "1.0.1" features = ["derive"] optional = true [dependencies.bytes] version = "1" [dependencies.rand] version = "0.8" [dependencies.ring] version = "0.16.7" optional = true [dependencies.rustc-hash] version = "1.1" [dependencies.rustls] version = "0.21.0" features = ["quic"] optional = true default-features = false [dependencies.rustls-native-certs] version = "0.6" optional = true [dependencies.slab] version = "0.4" [dependencies.thiserror] version = "1.0.21" [dependencies.tinyvec] version = "1.1" features = ["alloc"] [dependencies.tracing] version = "0.1.10" [dev-dependencies.assert_matches] version = "1.1" [dev-dependencies.hex-literal] version = "0.4.0" [dev-dependencies.lazy_static] version = "1" [dev-dependencies.rcgen] version = "0.10.0" [dev-dependencies.tracing-subscriber] version = "0.3.0" features = [ "env-filter", "fmt", "ansi", "time", "local-time", ] default-features = false [features] default = [ "tls-rustls", "log", ] log = ["tracing/log"] native-certs = ["rustls-native-certs"] tls-rustls = [ "rustls", "ring", ] [badges.maintenance] status = "experimental" quinn-proto-0.10.6/Cargo.toml.orig000064400000000000000000000025301046102023000150720ustar 00000000000000[package] name = "quinn-proto" version = "0.10.6" edition = "2021" rust-version = "1.63" license = "MIT OR Apache-2.0" repository = "https://github.com/quinn-rs/quinn" description = "State machine for the QUIC transport protocol" keywords = ["quic"] categories = [ "network-programming", "asynchronous" ] workspace = ".." 
[package.metadata.docs.rs]
all-features = true

[badges]
maintenance = { status = "experimental" }

[features]
default = ["tls-rustls", "log"]
tls-rustls = ["rustls", "ring"]
# Provides `ClientConfig::with_native_roots()` convenience method
native-certs = ["rustls-native-certs"]
# Write logs via the `log` crate when no `tracing` subscriber exists
log = ["tracing/log"]

[dependencies]
arbitrary = { version = "1.0.1", features = ["derive"], optional = true }
bytes = "1"
rustc-hash = "1.1"
rand = "0.8"
ring = { version = "0.16.7", optional = true }
rustls = { version = "0.21.0", default-features = false, features = ["quic"], optional = true }
rustls-native-certs = { version = "0.6", optional = true }
slab = "0.4"
thiserror = "1.0.21"
tinyvec = { version = "1.1", features = ["alloc"] }
tracing = "0.1.10"

[dev-dependencies]
assert_matches = "1.1"
hex-literal = "0.4.0"
rcgen = "0.10.0"
tracing-subscriber = { version = "0.3.0", default-features = false, features = ["env-filter", "fmt", "ansi", "time", "local-time"] }
lazy_static = "1"

quinn-proto-0.10.6/src/cid_generator.rs000064400000000000000000000042551046102023000161530ustar 00000000000000
use std::time::Duration;

use rand::RngCore;

use crate::shared::ConnectionId;
use crate::MAX_CID_SIZE;

/// Generates connection IDs for incoming connections
pub trait ConnectionIdGenerator: Send {
    /// Generates a new CID
    ///
    /// Connection IDs MUST NOT contain any information that can be used by
    /// an external observer (that is, one that does not cooperate with the
    /// issuer) to correlate them with other connection IDs for the same
    /// connection.
    fn generate_cid(&mut self) -> ConnectionId;
    /// Returns the length of a CID for connections created by this generator
    fn cid_len(&self) -> usize;
    /// Returns the lifetime of generated Connection IDs
    ///
    /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
    fn cid_lifetime(&self) -> Option<Duration>;
}

/// Generates purely random connection IDs of a certain length
#[derive(Debug, Clone, Copy)]
pub struct RandomConnectionIdGenerator {
    cid_len: usize,
    lifetime: Option<Duration>,
}

impl Default for RandomConnectionIdGenerator {
    fn default() -> Self {
        Self {
            cid_len: 8,
            lifetime: None,
        }
    }
}

impl RandomConnectionIdGenerator {
    /// Initialize Random CID generator with a fixed CID length
    ///
    /// The given length must be less than or equal to MAX_CID_SIZE.
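    ///
    /// A minimal sketch of constructing and tuning a generator; the 10-byte
    /// length and 2-minute lifetime are illustrative values, not recommendations:
    ///
    /// ```
    /// use std::time::Duration;
    /// use quinn_proto::RandomConnectionIdGenerator;
    ///
    /// let mut generator = RandomConnectionIdGenerator::new(10);
    /// generator.set_lifetime(Duration::from_secs(120));
    /// ```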
    pub fn new(cid_len: usize) -> Self {
        debug_assert!(cid_len <= MAX_CID_SIZE);
        Self {
            cid_len,
            ..Self::default()
        }
    }

    /// Set the lifetime of CIDs created by this generator
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}

impl ConnectionIdGenerator for RandomConnectionIdGenerator {
    fn generate_cid(&mut self) -> ConnectionId {
        let mut bytes_arr = [0; MAX_CID_SIZE];
        rand::thread_rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
        ConnectionId::new(&bytes_arr[..self.cid_len])
    }

    /// Provide the length of dst_cid in short header packet
    fn cid_len(&self) -> usize {
        self.cid_len
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}

quinn-proto-0.10.6/src/cid_queue.rs000064400000000000000000000232141046102023000153050ustar 00000000000000
use std::ops::Range;

use crate::{frame::NewConnectionId, ConnectionId, ResetToken};

/// Data type stored in the CidQueue buffer
type CidData = (ConnectionId, Option<ResetToken>);

/// Sliding window of active Connection IDs
///
/// May contain gaps due to packet loss or reordering
#[derive(Debug)]
pub(crate) struct CidQueue {
    /// Ring buffer indexed by `self.cursor`
    buffer: [Option<CidData>; Self::LEN],
    /// Index at which circular buffer addressing is based
    cursor: usize,
    /// Sequence number of `self.buffer[cursor]`
    ///
    /// The sequence number of the active CID; must be the smallest among CIDs in `buffer`.
    offset: u64,
}

impl CidQueue {
    pub(crate) fn new(cid: ConnectionId) -> Self {
        let mut buffer = [None; Self::LEN];
        buffer[0] = Some((cid, None));
        Self {
            buffer,
            cursor: 0,
            offset: 0,
        }
    }

    /// Handle a `NEW_CONNECTION_ID` frame
    ///
    /// Returns a non-empty range of retired sequence numbers and the reset token of the new active
    /// CID iff any CIDs were retired.
    pub(crate) fn insert(
        &mut self,
        cid: NewConnectionId,
    ) -> Result<Option<(Range<u64>, ResetToken)>, InsertError> {
        // Position of new CID wrt. the current active CID
        let index = match cid.sequence.checked_sub(self.offset) {
            None => return Err(InsertError::Retired),
            Some(x) => x,
        };

        let retired_count = cid.retire_prior_to.saturating_sub(self.offset);
        if index >= Self::LEN as u64 + retired_count {
            return Err(InsertError::ExceedsLimit);
        }

        // Discard retired CIDs, if any
        for i in 0..(retired_count.min(Self::LEN as u64) as usize) {
            self.buffer[(self.cursor + i) % Self::LEN] = None;
        }
        // Record the new CID
        let index = ((self.cursor as u64 + index) % Self::LEN as u64) as usize;
        self.buffer[index] = Some((cid.id, Some(cid.reset_token)));

        if retired_count == 0 {
            return Ok(None);
        }

        // The active CID was retired. Find the first known CID with sequence number of at least
        // retire_prior_to, and inform the caller that all prior CIDs have been retired, and of
        // the new CID's reset token.
        self.cursor = ((self.cursor as u64 + retired_count) % Self::LEN as u64) as usize;
        let (i, (_, token)) = self
            .iter()
            .next()
            .expect("it is impossible to retire a CID without supplying a new one");
        self.cursor = (self.cursor + i) % Self::LEN;
        let orig_offset = self.offset;
        self.offset = cid.retire_prior_to + i as u64;
        // We don't immediately retire CIDs in the range (orig_offset +
        // Self::LEN)..self.offset. These are CIDs that we haven't yet received from a
        // NEW_CONNECTION_ID frame, since having previously received them would violate the
        // connection ID limit we specified based on Self::LEN. If we do receive such a frame
        // in the future, e.g. due to reordering, we'll retire it then. This ensures we can't be
        // made to buffer an arbitrarily large number of RETIRE_CONNECTION_ID frames.
        Ok(Some((
            orig_offset..self.offset.min(orig_offset + Self::LEN as u64),
            token.expect("non-initial CID missing reset token"),
        )))
    }

    /// Switch to next active CID if possible, return
    /// 1) the corresponding ResetToken and 2) a non-empty range preceding it to retire
    pub(crate) fn next(&mut self) -> Option<(ResetToken, Range<u64>)> {
        let (i, cid_data) = self.iter().nth(1)?;
        self.buffer[self.cursor] = None;

        let orig_offset = self.offset;
        self.offset += i as u64;
        self.cursor = (self.cursor + i) % Self::LEN;
        Some((cid_data.1.unwrap(), orig_offset..self.offset))
    }

    /// Iterate CIDs in CidQueue that are not `None`, including the active CID
    fn iter(&self) -> impl Iterator<Item = (usize, CidData)> + '_ {
        (0..Self::LEN).filter_map(move |step| {
            let index = (self.cursor + step) % Self::LEN;
            self.buffer[index].map(|cid_data| (step, cid_data))
        })
    }

    /// Replace the initial CID
    pub(crate) fn update_initial_cid(&mut self, cid: ConnectionId) {
        debug_assert_eq!(self.offset, 0);
        self.buffer[self.cursor] = Some((cid, None));
    }

    /// Return active remote CID itself
    pub(crate) fn active(&self) -> ConnectionId {
        self.buffer[self.cursor].unwrap().0
    }

    /// Return the sequence number of active remote CID
    pub(crate) fn active_seq(&self) -> u64 {
        self.offset
    }

    pub(crate) const LEN: usize = 5;
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum InsertError {
    /// CID was already retired
    Retired,
    /// Sequence number violates the leading edge of the window
    ExceedsLimit,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn cid(sequence: u64, retire_prior_to: u64) -> NewConnectionId {
        NewConnectionId {
            sequence,
            id: ConnectionId::new(&[0xAB; 8]),
            reset_token: ResetToken::from([0xCD; crate::RESET_TOKEN_SIZE]),
            retire_prior_to,
        }
    }

    fn initial_cid() -> ConnectionId {
        ConnectionId::new(&[0xFF; 8])
    }

    #[test]
    fn next_dense() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert!(q.next().is_none());

        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in 1..CidQueue::LEN as u64 {
            let (_, retire) = q.next().unwrap();
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire.end - retire.start, 1);
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn next_sparse() {
        let mut q = CidQueue::new(initial_cid());
        let seqs = (1..CidQueue::LEN as u64).filter(|x| x % 2 == 0);
        for i in seqs.clone() {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in seqs {
            let (_, retire) = q.next().unwrap();
            dbg!(&retire);
            assert_eq!(q.active_seq(), i);
            assert_eq!(retire, (q.active_seq().saturating_sub(2))..q.active_seq());
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn wrap() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        for _ in 1..(CidQueue::LEN as u64 - 1) {
            q.next().unwrap();
        }
        for i in CidQueue::LEN as u64..(CidQueue::LEN as u64 + 3) {
            q.insert(cid(i, 0)).unwrap();
        }
        for i in (CidQueue::LEN as u64 - 1)..(CidQueue::LEN as u64 + 3) {
            q.next().unwrap();
            assert_eq!(q.active_seq(), i);
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn retire_dense() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        assert_eq!(q.active_seq(), 0);
        assert_eq!(q.insert(cid(4, 2)).unwrap().unwrap().0, 0..2);
        assert_eq!(q.active_seq(), 2);
        assert_eq!(q.insert(cid(4, 2)), Ok(None));

        for i in 2..(CidQueue::LEN as u64 - 1) {
            let _ = q.next().unwrap();
            assert_eq!(q.active_seq(), i + 1);
            assert_eq!(q.insert(cid(i + 1, i + 1)), Ok(None));
        }
        assert!(q.next().is_none());
    }

    #[test]
    fn retire_sparse() {
        // Retiring CID 0 when CID 1 is not known should retire CID 1 as we move to CID 2
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(q.insert(cid(3, 1)).unwrap().unwrap().0, 0..2,);
        assert_eq!(q.active_seq(), 2);
    }

    #[test]
    fn retire_many() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(2, 0)).unwrap();
        assert_eq!(
            q.insert(cid(1_000_000, 1_000_000)).unwrap().unwrap().0,
            0..CidQueue::LEN as u64,
        );
        assert_eq!(q.active_seq(), 1_000_000);
    }

    #[test]
    fn insert_limit() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(q.insert(cid(CidQueue::LEN as u64 - 1, 0)), Ok(None));
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    #[test]
    fn insert_duplicate() {
        let mut q = CidQueue::new(initial_cid());
        q.insert(cid(0, 0)).unwrap();
        q.insert(cid(0, 0)).unwrap();
    }

    #[test]
    fn insert_retired() {
        let mut q = CidQueue::new(initial_cid());
        assert_eq!(
            q.insert(cid(0, 0)),
            Ok(None),
            "reinserting active CID succeeds"
        );
        assert!(q.next().is_none(), "active CID isn't requeued");
        q.insert(cid(1, 0)).unwrap();
        q.next().unwrap();
        assert_eq!(
            q.insert(cid(0, 0)),
            Err(InsertError::Retired),
            "previous active CID is already retired"
        );
    }

    #[test]
    fn retire_then_insert_next() {
        let mut q = CidQueue::new(initial_cid());
        for i in 1..CidQueue::LEN as u64 {
            q.insert(cid(i, 0)).unwrap();
        }
        q.next().unwrap();
        q.insert(cid(CidQueue::LEN as u64, 0)).unwrap();
        assert_eq!(
            q.insert(cid(CidQueue::LEN as u64 + 1, 0)),
            Err(InsertError::ExceedsLimit)
        );
    }

    #[test]
    fn always_valid() {
        let mut q = CidQueue::new(initial_cid());
        assert!(q.next().is_none());
        assert_eq!(q.active(), initial_cid());
        assert_eq!(q.active_seq(), 0);
    }
}

quinn-proto-0.10.6/src/coding.rs000064400000000000000000000054321046102023000146070ustar 00000000000000
use std::net::{Ipv4Addr, Ipv6Addr};

use bytes::{Buf, BufMut};
use thiserror::Error;

use crate::VarInt;

#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)]
#[error("unexpected end of buffer")]
pub struct UnexpectedEnd;

pub type Result<T> = ::std::result::Result<T, UnexpectedEnd>;

pub trait Codec: Sized {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self>;
    fn encode<B: BufMut>(&self, buf: &mut B);
}

impl Codec for u8 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 1 {
            return Err(UnexpectedEnd);
        }
        Ok(buf.get_u8())
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u8(*self);
    }
}

impl Codec for u16 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 2 {
            return Err(UnexpectedEnd);
        }
        Ok(buf.get_u16())
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u16(*self);
    }
}

impl Codec for u32 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 4 {
            return Err(UnexpectedEnd);
        }
        Ok(buf.get_u32())
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u32(*self);
    }
}

impl Codec for u64 {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 8 {
            return Err(UnexpectedEnd);
        }
        Ok(buf.get_u64())
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_u64(*self);
    }
}

impl Codec for Ipv4Addr {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 4 {
            return Err(UnexpectedEnd);
        }
        let mut octets = [0; 4];
        buf.copy_to_slice(&mut octets);
        Ok(octets.into())
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_slice(&self.octets());
    }
}

impl Codec for Ipv6Addr {
    fn decode<B: Buf>(buf: &mut B) -> Result<Self> {
        if buf.remaining() < 16 {
            return Err(UnexpectedEnd);
        }
        let mut octets = [0; 16];
        buf.copy_to_slice(&mut octets);
        Ok(octets.into())
    }
    fn encode<B: BufMut>(&self, buf: &mut B) {
        buf.put_slice(&self.octets());
    }
}

pub trait BufExt {
    fn get<T: Codec>(&mut self) -> Result<T>;
    fn get_var(&mut self) -> Result<u64>;
}

impl<T: Buf> BufExt for T {
    fn get<U: Codec>(&mut self) -> Result<U> {
        U::decode(self)
    }

    fn get_var(&mut self) -> Result<u64> {
        Ok(VarInt::decode(self)?.into_inner())
    }
}
pub trait BufMutExt {
    fn write<T: Codec>(&mut self, x: T);
    fn write_var(&mut self, x: u64);
}

impl<T: BufMut> BufMutExt for T {
    fn write<U: Codec>(&mut self, x: U) {
        x.encode(self);
    }

    fn write_var(&mut self, x: u64) {
        VarInt::from_u64(x).unwrap().encode(self);
    }
}

quinn-proto-0.10.6/src/config.rs000064400000000000000000001076001046102023000146110ustar 00000000000000
use std::{fmt, num::TryFromIntError, sync::Arc, time::Duration};

use thiserror::Error;

#[cfg(feature = "ring")]
use rand::RngCore;

use crate::{
    cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
    congestion,
    crypto::{self, HandshakeTokenKey, HmacKey},
    VarInt, VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, INITIAL_MTU, MAX_UDP_PAYLOAD,
};

/// Parameters governing the core QUIC state machine
///
/// Default values should be suitable for most internet applications. Application protocols which
/// forbid remotely-initiated streams should set `max_concurrent_bidi_streams` and
/// `max_concurrent_uni_streams` to zero.
///
/// In some cases, performance or resource requirements can be improved by tuning these values to
/// suit a particular application and/or network connection. In particular, data window sizes can be
/// tuned for a particular expected round trip time, link capacity, and memory availability. Tuning
/// for higher bandwidths and latencies increases worst-case memory consumption, but does not impair
/// performance at lower bandwidths and latencies. The default configuration is tuned for a 100Mbps
/// link with a 100ms round trip time.
pub struct TransportConfig {
    pub(crate) max_concurrent_bidi_streams: VarInt,
    pub(crate) max_concurrent_uni_streams: VarInt,
    pub(crate) max_idle_timeout: Option<VarInt>,
    pub(crate) stream_receive_window: VarInt,
    pub(crate) receive_window: VarInt,
    pub(crate) send_window: u64,

    pub(crate) max_tlps: u32,
    pub(crate) packet_threshold: u32,
    pub(crate) time_threshold: f32,
    pub(crate) initial_rtt: Duration,
    pub(crate) initial_mtu: u16,
    pub(crate) min_mtu: u16,
    pub(crate) mtu_discovery_config: Option<MtuDiscoveryConfig>,

    pub(crate) persistent_congestion_threshold: u32,
    pub(crate) keep_alive_interval: Option<Duration>,
    pub(crate) crypto_buffer_size: usize,
    pub(crate) allow_spin: bool,
    pub(crate) datagram_receive_buffer_size: Option<usize>,
    pub(crate) datagram_send_buffer_size: usize,
    pub(crate) congestion_controller_factory:
        Box<dyn congestion::ControllerFactory + Send + Sync + 'static>,
    pub(crate) enable_segmentation_offload: bool,
}

impl TransportConfig {
    /// Maximum number of incoming bidirectional streams that may be open concurrently
    ///
    /// Must be nonzero for the peer to open any bidirectional streams.
    ///
    /// Worst-case memory use is directly proportional to `max_concurrent_bidi_streams *
    /// stream_receive_window`, with an upper bound proportional to `receive_window`.
    pub fn max_concurrent_bidi_streams(&mut self, value: VarInt) -> &mut Self {
        self.max_concurrent_bidi_streams = value;
        self
    }

    /// Variant of `max_concurrent_bidi_streams` affecting unidirectional streams
    pub fn max_concurrent_uni_streams(&mut self, value: VarInt) -> &mut Self {
        self.max_concurrent_uni_streams = value;
        self
    }

    /// Maximum duration of inactivity to accept before timing out the connection.
    ///
    /// The true idle timeout is the minimum of this and the peer's own max idle timeout. `None`
    /// represents an infinite timeout.
    ///
    /// **WARNING**: If a peer or its network path malfunctions or acts maliciously, an infinite
    /// idle timeout can result in permanently hung futures!
    ///
    /// ```
    /// # use std::{convert::TryInto, time::Duration};
    /// # use quinn_proto::{TransportConfig, VarInt, VarIntBoundsExceeded};
    /// # fn main() -> Result<(), VarIntBoundsExceeded> {
    /// let mut config = TransportConfig::default();
    ///
    /// // Set the idle timeout as `VarInt`-encoded milliseconds
    /// config.max_idle_timeout(Some(VarInt::from_u32(10_000).into()));
    ///
    /// // Set the idle timeout as a `Duration`
    /// config.max_idle_timeout(Some(Duration::from_secs(10).try_into()?));
    /// # Ok(())
    /// # }
    /// ```
    pub fn max_idle_timeout(&mut self, value: Option<IdleTimeout>) -> &mut Self {
        self.max_idle_timeout = value.map(|t| t.0);
        self
    }

    /// Maximum number of bytes the peer may transmit without acknowledgement on any one stream
    /// before becoming blocked.
    ///
    /// This should be set to at least the expected connection latency multiplied by the maximum
    /// desired throughput. Setting this smaller than `receive_window` helps ensure that a single
    /// stream doesn't monopolize receive buffers, which may otherwise occur if the application
    /// chooses not to read from a large stream for a time while still requiring data on other
    /// streams.
    pub fn stream_receive_window(&mut self, value: VarInt) -> &mut Self {
        self.stream_receive_window = value;
        self
    }

    /// Maximum number of bytes the peer may transmit across all streams of a connection before
    /// becoming blocked.
    ///
    /// This should be set to at least the expected connection latency multiplied by the maximum
    /// desired throughput. Larger values can be useful to allow maximum throughput within a
    /// stream while another is blocked.
    pub fn receive_window(&mut self, value: VarInt) -> &mut Self {
        self.receive_window = value;
        self
    }

    /// Maximum number of bytes to transmit to a peer without acknowledgment
    ///
    /// Provides an upper bound on memory when communicating with peers that issue large amounts of
    /// flow control credit. Endpoints that wish to handle large numbers of connections robustly
    /// should take care to set this low enough to guarantee memory exhaustion does not occur if
    /// every connection uses the entire window.
    pub fn send_window(&mut self, value: u64) -> &mut Self {
        self.send_window = value;
        self
    }

    /// Maximum number of tail loss probes before an RTO fires.
    pub fn max_tlps(&mut self, value: u32) -> &mut Self {
        self.max_tlps = value;
        self
    }

    /// Maximum reordering in packet number space before FACK style loss detection considers a
    /// packet lost. Should not be less than 3, per RFC5681.
    pub fn packet_threshold(&mut self, value: u32) -> &mut Self {
        self.packet_threshold = value;
        self
    }

    /// Maximum reordering in time space before time based loss detection considers a packet lost,
    /// as a factor of RTT
    pub fn time_threshold(&mut self, value: f32) -> &mut Self {
        self.time_threshold = value;
        self
    }

    /// The RTT used before an RTT sample is taken
    pub fn initial_rtt(&mut self, value: Duration) -> &mut Self {
        self.initial_rtt = value;
        self
    }

    /// The initial value to be used as the maximum UDP payload size before running MTU discovery
    /// (see [`TransportConfig::mtu_discovery_config`]).
    ///
    /// Must be at least 1200, which is the default, and known to be safe for typical internet
    /// applications. Larger values are more efficient, but increase the risk of packet loss due to
    /// exceeding the network path's IP MTU. If the provided value is higher than what the network
    /// path actually supports, packet loss will eventually trigger black hole detection and bring
    /// it down to [`TransportConfig::min_mtu`].
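    ///
    /// A minimal usage sketch; 1300 is an illustrative value, not a recommendation:
    ///
    /// ```
    /// let mut config = quinn_proto::TransportConfig::default();
    /// config.initial_mtu(1300);
    /// ```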
    pub fn initial_mtu(&mut self, value: u16) -> &mut Self {
        self.initial_mtu = value.max(INITIAL_MTU);
        self
    }

    pub(crate) fn get_initial_mtu(&self) -> u16 {
        self.initial_mtu.max(self.min_mtu)
    }

    /// The maximum UDP payload size guaranteed to be supported by the network.
    ///
    /// Must be at least 1200, which is the default, and lower than or equal to
    /// [`TransportConfig::initial_mtu`].
    ///
    /// Real-world MTUs can vary according to ISP, VPN, and properties of intermediate network links
    /// outside of either endpoint's control. Extreme care should be used when raising this value
    /// outside of private networks where these factors are fully controlled. If the provided value
    /// is higher than what the network path actually supports, the result will be unpredictable and
    /// catastrophic packet loss, without a possibility of repair. Prefer
    /// [`TransportConfig::initial_mtu`] together with
    /// [`TransportConfig::mtu_discovery_config`] to set a maximum UDP payload size that robustly
    /// adapts to the network.
    pub fn min_mtu(&mut self, value: u16) -> &mut Self {
        self.min_mtu = value.max(INITIAL_MTU);
        self
    }

    /// Specifies the MTU discovery config (see [`MtuDiscoveryConfig`] for details).
    ///
    /// Defaults to `None`, which disables MTU discovery altogether.
    ///
    /// # Important
    ///
    /// MTU discovery depends on platform support for disabling UDP packet fragmentation, which is
    /// not always available. If the platform allows fragmenting UDP packets, MTU discovery may end
    /// up "discovering" an MTU that is not really supported by the network, causing packet loss
    /// down the line.
    ///
    /// The `quinn` crate provides the `Endpoint::server` and `Endpoint::client` constructors that
    /// automatically disable UDP packet fragmentation on Linux and Windows. When using these
    /// constructors, MTU discovery will reliably work, unless the code is compiled targeting an
    /// unsupported platform (e.g. iOS). In the latter case, it is advisable to keep MTU discovery
    /// disabled.
    ///
    /// Users of `quinn-proto` and authors of custom `AsyncUdpSocket` implementations should ensure
    /// to disable UDP packet fragmentation (this is strongly recommended by [RFC
    /// 9000](https://www.rfc-editor.org/rfc/rfc9000.html#section-14-7), regardless of MTU
    /// discovery). They can build on top of the `quinn-udp` crate, used by `quinn` itself, which
    /// provides Linux, Windows, macOS, and FreeBSD support for disabling packet fragmentation.
    pub fn mtu_discovery_config(&mut self, value: Option<MtuDiscoveryConfig>) -> &mut Self {
        self.mtu_discovery_config = value;
        self
    }

    /// Number of consecutive PTOs after which network is considered to be experiencing persistent congestion.
    pub fn persistent_congestion_threshold(&mut self, value: u32) -> &mut Self {
        self.persistent_congestion_threshold = value;
        self
    }

    /// Period of inactivity before sending a keep-alive packet
    ///
    /// Keep-alive packets prevent an inactive but otherwise healthy connection from timing out.
    ///
    /// `None` to disable, which is the default. Only one side of any given connection needs keep-alive
    /// enabled for the connection to be preserved. Must be set lower than the idle_timeout of both
    /// peers to be effective.
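    ///
    /// A minimal sketch; the 5-second interval is illustrative and should sit well below both
    /// peers' idle timeouts:
    ///
    /// ```
    /// # use std::time::Duration;
    /// let mut config = quinn_proto::TransportConfig::default();
    /// config.keep_alive_interval(Some(Duration::from_secs(5)));
    /// ```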
    pub fn keep_alive_interval(&mut self, value: Option<Duration>) -> &mut Self {
        self.keep_alive_interval = value;
        self
    }

    /// Maximum quantity of out-of-order crypto layer data to buffer
    pub fn crypto_buffer_size(&mut self, value: usize) -> &mut Self {
        self.crypto_buffer_size = value;
        self
    }

    /// Whether the implementation is permitted to set the spin bit on this connection
    ///
    /// This allows passive observers to easily judge the round trip time of a connection, which can
    /// be useful for network administration but sacrifices a small amount of privacy.
    pub fn allow_spin(&mut self, value: bool) -> &mut Self {
        self.allow_spin = value;
        self
    }

    /// Maximum number of incoming application datagram bytes to buffer, or None to disable
    /// incoming datagrams
    ///
    /// The peer is forbidden to send single datagrams larger than this size. If the aggregate size
    /// of all datagrams that have been received from the peer but not consumed by the application
    /// exceeds this value, old datagrams are dropped until it is no longer exceeded.
    pub fn datagram_receive_buffer_size(&mut self, value: Option<usize>) -> &mut Self {
        self.datagram_receive_buffer_size = value;
        self
    }

    /// Maximum number of outgoing application datagram bytes to buffer
    ///
    /// While datagrams are sent ASAP, it is possible for an application to generate data faster
    /// than the link, or even the underlying hardware, can transmit them. This limits the amount of
    /// memory that may be consumed in that case. When the send buffer is full and a new datagram is
    /// sent, older datagrams are dropped until sufficient space is available.
    pub fn datagram_send_buffer_size(&mut self, value: usize) -> &mut Self {
        self.datagram_send_buffer_size = value;
        self
    }

    /// How to construct new `congestion::Controller`s
    ///
    /// Typically the refcounted configuration of a `congestion::Controller`,
    /// e.g. a `congestion::NewRenoConfig`.
    ///
    /// # Example
    /// ```
    /// # use quinn_proto::*; use std::sync::Arc;
    /// let mut config = TransportConfig::default();
    /// config.congestion_controller_factory(Arc::new(congestion::NewRenoConfig::default()));
    /// ```
    pub fn congestion_controller_factory(
        &mut self,
        factory: impl congestion::ControllerFactory + Send + Sync + 'static,
    ) -> &mut Self {
        self.congestion_controller_factory = Box::new(factory);
        self
    }

    /// Whether to use "Generic Segmentation Offload" to accelerate transmits, when supported by the
    /// environment
    ///
    /// Defaults to `true`.
    ///
    /// GSO dramatically reduces CPU consumption when sending large numbers of packets with the same
    /// headers, such as when transmitting bulk data on a connection. However, it is not supported
    /// by all network interface drivers or packet inspection tools. `quinn-udp` will attempt to
    /// disable GSO automatically when unavailable, but this can lead to spurious packet loss at
    /// startup, temporarily degrading performance.
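    ///
    /// For example, to opt out when targeting an environment known to mishandle GSO:
    ///
    /// ```
    /// let mut config = quinn_proto::TransportConfig::default();
    /// config.enable_segmentation_offload(false);
    /// ```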
    pub fn enable_segmentation_offload(&mut self, enabled: bool) -> &mut Self {
        self.enable_segmentation_offload = enabled;
        self
    }
}

impl Default for TransportConfig {
    fn default() -> Self {
        const EXPECTED_RTT: u32 = 100; // ms
        const MAX_STREAM_BANDWIDTH: u32 = 12500 * 1000; // bytes/s
        // Window size needed to avoid pipeline stalls
        const STREAM_RWND: u32 = MAX_STREAM_BANDWIDTH / 1000 * EXPECTED_RTT;

        Self {
            max_concurrent_bidi_streams: 100u32.into(),
            max_concurrent_uni_streams: 100u32.into(),
            max_idle_timeout: Some(VarInt(10_000)),
            stream_receive_window: STREAM_RWND.into(),
            receive_window: VarInt::MAX,
            send_window: (8 * STREAM_RWND).into(),

            max_tlps: 2,
            packet_threshold: 3,
            time_threshold: 9.0 / 8.0,
            initial_rtt: Duration::from_millis(333), // per spec, intentionally distinct from EXPECTED_RTT
            initial_mtu: INITIAL_MTU,
            min_mtu: INITIAL_MTU,
            mtu_discovery_config: Some(MtuDiscoveryConfig::default()),

            persistent_congestion_threshold: 3,
            keep_alive_interval: None,
            crypto_buffer_size: 16 * 1024,
            allow_spin: true,
            datagram_receive_buffer_size: Some(STREAM_RWND as usize),
            datagram_send_buffer_size: 1024 * 1024,
            congestion_controller_factory: Box::new(Arc::new(congestion::CubicConfig::default())),
            enable_segmentation_offload: true,
        }
    }
}

impl fmt::Debug for TransportConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("TransportConfig")
            .field(
                "max_concurrent_bidi_streams",
                &self.max_concurrent_bidi_streams,
            )
            .field(
                "max_concurrent_uni_streams",
                &self.max_concurrent_uni_streams,
            )
            .field("max_idle_timeout", &self.max_idle_timeout)
            .field("stream_receive_window", &self.stream_receive_window)
            .field("receive_window", &self.receive_window)
            .field("send_window", &self.send_window)
            .field("max_tlps", &self.max_tlps)
            .field("packet_threshold", &self.packet_threshold)
            .field("time_threshold", &self.time_threshold)
            .field("initial_rtt", &self.initial_rtt)
            .field(
                "persistent_congestion_threshold",
                &self.persistent_congestion_threshold,
            )
            .field("keep_alive_interval", &self.keep_alive_interval)
            .field("crypto_buffer_size", &self.crypto_buffer_size)
            .field("allow_spin", &self.allow_spin)
            .field(
                "datagram_receive_buffer_size",
                &self.datagram_receive_buffer_size,
            )
            .field("datagram_send_buffer_size", &self.datagram_send_buffer_size)
            .field("congestion_controller_factory", &"[ opaque ]")
            .finish()
    }
}

/// Parameters governing MTU discovery.
///
/// # The why of MTU discovery
///
/// By design, QUIC ensures during the handshake that the network path between the client and the
/// server is able to transmit unfragmented UDP packets with a body of 1200 bytes. In other words,
/// once the connection is established, we know that the network path's maximum transmission unit
/// (MTU) is of at least 1200 bytes (plus IP and UDP headers). Because of this, a QUIC endpoint can
/// split outgoing data into packets of 1200 bytes, with confidence that the network will be able to
/// deliver them (if the endpoint were to send bigger packets, they could prove too big and end up
/// being dropped).
///
/// There is, however, a significant overhead associated with sending a packet. If the same
/// information can be sent in fewer packets, that results in higher throughput. The number of
/// packets that need to be sent is inversely proportional to the MTU: the higher the MTU, the
/// bigger the packets that can be sent, and the fewer packets that are needed to transmit a given
/// amount of bytes.
///
/// Most networks have an MTU higher than 1200.
Through MTU discovery, endpoints can detect the /// path's MTU and, if it turns out to be higher, start sending bigger packets. /// /// # MTU discovery internals /// /// Quinn implements MTU discovery through DPLPMTUD (Datagram Packetization Layer Path MTU /// Discovery), described in [section 14.3 of RFC /// 9000](https://www.rfc-editor.org/rfc/rfc9000.html#section-14.3). This method consists of sending /// QUIC packets padded to a particular size (called PMTU probes), and waiting to see if the remote /// peer responds with an ACK. If an ACK is received, that means the probe arrived at the remote /// peer, which in turn means that the network path's MTU is of at least the packet's size. If the /// probe is lost, it is sent another 2 times before concluding that the MTU is lower than the /// packet's size. /// /// MTU discovery runs on a schedule (e.g. every 600 seconds) specified through /// [`MtuDiscoveryConfig::interval`]. The first run happens right after the handshake, and /// subsequent discoveries are scheduled to run when the interval has elapsed, starting from the /// last time when MTU discovery completed. /// /// Since the search space for MTUs is quite big (the smallest possible MTU is 1200, and the highest /// is 65527), Quinn performs a binary search to keep the number of probes as low as possible. The /// lower bound of the search is equal to [`TransportConfig::initial_mtu`] in the /// initial MTU discovery run, and is equal to the currently discovered MTU in subsequent runs. The /// upper bound is determined by the minimum of [`MtuDiscoveryConfig::upper_bound`] and the /// `max_udp_payload_size` transport parameter received from the peer during the handshake. /// /// # Black hole detection /// /// If, at some point, the network path no longer accepts packets of the detected size, packet loss /// will eventually trigger black hole detection and reset the detected MTU to 1200. In that case, /// MTU discovery will be triggered after [`MtuDiscoveryConfig::black_hole_cooldown`] (ignoring the /// timer that was set based on [`MtuDiscoveryConfig::interval`]). /// /// # Interaction between peers /// /// There is no guarantee that the MTU on the path between A and B is the same as the MTU of the /// path between B and A. Therefore, each peer in the connection needs to run MTU discovery /// independently in order to discover the path's MTU. #[derive(Clone, Debug)] pub struct MtuDiscoveryConfig { pub(crate) interval: Duration, pub(crate) upper_bound: u16, pub(crate) black_hole_cooldown: Duration, } impl MtuDiscoveryConfig { /// Specifies the time to wait after completing MTU discovery before starting a new MTU /// discovery run. /// /// Defaults to 600 seconds, as recommended by [RFC /// 8899](https://www.rfc-editor.org/rfc/rfc8899). pub fn interval(&mut self, value: Duration) -> &mut Self { self.interval = value; self } /// Specifies the upper bound to the max UDP payload size that MTU discovery will search for. /// /// Defaults to 1452, to stay within Ethernet's MTU when using IPv4 and IPv6. The highest /// allowed value is 65527, which corresponds to the maximum permitted UDP payload on IPv6. /// /// It is safe to use an arbitrarily high upper bound, regardless of the network path's MTU. The /// only drawback is that MTU discovery might take more time to finish. 
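    ///
    /// A sketch of a custom discovery schedule; the interval and bound are illustrative values:
    ///
    /// ```
    /// # use std::time::Duration;
    /// let mut mtu_config = quinn_proto::MtuDiscoveryConfig::default();
    /// mtu_config.interval(Duration::from_secs(300));
    /// mtu_config.upper_bound(9000);
    ///
    /// let mut transport = quinn_proto::TransportConfig::default();
    /// transport.mtu_discovery_config(Some(mtu_config));
    /// ```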
    pub fn upper_bound(&mut self, value: u16) -> &mut Self {
        self.upper_bound = value.min(MAX_UDP_PAYLOAD);
        self
    }

    /// Specifies the amount of time that MTU discovery should wait after a black hole was detected
    /// before running again. Defaults to one minute.
    ///
    /// Black hole detection can be spuriously triggered in case of congestion, so it makes sense to
    /// try MTU discovery again after a short period of time.
    pub fn black_hole_cooldown(&mut self, value: Duration) -> &mut Self {
        self.black_hole_cooldown = value;
        self
    }
}

impl Default for MtuDiscoveryConfig {
    fn default() -> Self {
        Self {
            interval: Duration::from_secs(600),
            upper_bound: 1452,
            black_hole_cooldown: Duration::from_secs(60),
        }
    }
}

/// Global configuration for the endpoint, affecting all connections
///
/// Default values should be suitable for most internet applications.
#[derive(Clone)]
pub struct EndpointConfig {
    pub(crate) reset_key: Arc<dyn HmacKey>,
    pub(crate) max_udp_payload_size: VarInt,
    /// CID generator factory
    ///
    /// Create a cid generator for local cid in Endpoint struct
    pub(crate) connection_id_generator_factory:
        Arc<dyn Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync>,
    pub(crate) supported_versions: Vec<u32>,
    pub(crate) grease_quic_bit: bool,
}

impl EndpointConfig {
    /// Create a default config with a particular `reset_key`
    pub fn new(reset_key: Arc<dyn HmacKey>) -> Self {
        let cid_factory: fn() -> Box<dyn ConnectionIdGenerator> =
            || Box::<RandomConnectionIdGenerator>::default();
        Self {
            reset_key,
            max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers
            connection_id_generator_factory: Arc::new(cid_factory),
            supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(),
            grease_quic_bit: true,
        }
    }

    /// Supply a custom connection ID generator factory
    ///
    /// Called once by each `Endpoint` constructed from this configuration to obtain the CID
    /// generator which will be used to generate the CIDs used for incoming packets on all
    /// connections involving that `Endpoint`. A custom CID generator allows applications to embed
    /// information in local connection IDs, e.g. to support stateless packet-level load balancers.
    ///
    /// `EndpointConfig::new()` applies a default random CID generator factory. This function
    /// accepts any customized factory whose generators implement the `ConnectionIdGenerator`
    /// trait, replacing the default.
    pub fn cid_generator<F: Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync + 'static>(
        &mut self,
        factory: F,
    ) -> &mut Self {
        self.connection_id_generator_factory = Arc::new(factory);
        self
    }

    /// Private key used to send authenticated connection resets to peers who were
    /// communicating with a previous instance of this endpoint.
    pub fn reset_key(&mut self, key: Arc<dyn HmacKey>) -> &mut Self {
        self.reset_key = key;
        self
    }

    /// Maximum UDP payload size accepted from peers (excluding UDP and IP overhead).
    ///
    /// Must be greater than or equal to 1200.
    ///
    /// Defaults to 1472, which is the largest UDP payload that can be transmitted in the typical
    /// 1500 byte Ethernet MTU. Deployments on links with larger MTUs (e.g. loopback or Ethernet
    /// with jumbo frames) can raise this to improve performance at the cost of a linear increase in
    /// datagram receive buffer size.
    pub fn max_udp_payload_size(&mut self, value: u16) -> Result<&mut Self, ConfigError> {
        if !(1200..=65_527).contains(&value) {
            return Err(ConfigError::OutOfBounds);
        }

        self.max_udp_payload_size = value.into();
        Ok(self)
    }

    /// Get the current value of `max_udp_payload_size`
    ///
    /// While most parameters don't need to be readable, this must be exposed to allow higher-level
    /// layers, e.g. the `quinn` crate, to determine how large a receive buffer to allocate to
    /// support an externally-defined `EndpointConfig`.
    ///
    /// While `get_` accessors are typically unidiomatic in Rust, we favor concision for setters,
    /// which will be used far more heavily.
    #[doc(hidden)]
    pub fn get_max_udp_payload_size(&self) -> u64 {
        self.max_udp_payload_size.into()
    }

    /// Override supported QUIC versions
    pub fn supported_versions(&mut self, supported_versions: Vec<u32>) -> &mut Self {
        self.supported_versions = supported_versions;
        self
    }

    /// Whether to accept QUIC packets containing any value for the fixed bit
    ///
    /// Enabled by default. Helps protect against protocol ossification and makes traffic less
    /// identifiable to observers. Disable if helping observers identify this traffic as QUIC is
    /// desired.
    pub fn grease_quic_bit(&mut self, value: bool) -> &mut Self {
        self.grease_quic_bit = value;
        self
    }
}

impl fmt::Debug for EndpointConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("EndpointConfig")
            .field("reset_key", &"[ elided ]")
            .field("max_udp_payload_size", &self.max_udp_payload_size)
            .field("cid_generator_factory", &"[ elided ]")
            .field("supported_versions", &self.supported_versions)
            .field("grease_quic_bit", &self.grease_quic_bit)
            .finish()
    }
}

#[cfg(feature = "ring")]
impl Default for EndpointConfig {
    fn default() -> Self {
        let mut reset_key = [0; 64];
        rand::thread_rng().fill_bytes(&mut reset_key);

        Self::new(Arc::new(ring::hmac::Key::new(
            ring::hmac::HMAC_SHA256,
            &reset_key,
        )))
    }
}

/// Parameters governing incoming connections
///
/// Default values should be suitable for most internet applications.
#[derive(Clone)]
pub struct ServerConfig {
    /// Transport configuration to use for incoming connections
    pub transport: Arc<TransportConfig>,

    /// TLS configuration used for incoming connections.
    ///
    /// Must be set to use TLS 1.3 only.
    pub crypto: Arc<dyn crypto::ServerConfig>,

    /// Used to generate one-time AEAD keys to protect handshake tokens
    pub(crate) token_key: Arc<dyn HandshakeTokenKey>,

    /// Whether to require clients to prove ownership of an address before committing resources.
    ///
    /// Introduces an additional round-trip to the handshake to make denial of service attacks more difficult.
    pub(crate) use_retry: bool,
    /// Microseconds after a stateless retry token was issued for which it's considered valid.
    pub(crate) retry_token_lifetime: Duration,

    /// Maximum number of concurrent connections
    pub(crate) concurrent_connections: u32,

    /// Whether to allow clients to migrate to new addresses
    ///
    /// Improves behavior for clients that move between different internet connections or suffer NAT
    /// rebinding. Enabled by default.
    pub(crate) migration: bool,
}

impl ServerConfig {
    /// Create a default config with a particular handshake token key
    pub fn new(
        crypto: Arc<dyn crypto::ServerConfig>,
        token_key: Arc<dyn HandshakeTokenKey>,
    ) -> Self {
        Self {
            transport: Arc::new(TransportConfig::default()),
            crypto,
            token_key,
            use_retry: false,
            retry_token_lifetime: Duration::from_secs(15),
            concurrent_connections: 100_000,
            migration: true,
        }
    }

    /// Set a custom [`TransportConfig`]
    pub fn transport_config(&mut self, transport: Arc<TransportConfig>) -> &mut Self {
        self.transport = transport;
        self
    }

    /// Private key used to authenticate data included in handshake tokens.
    pub fn token_key(&mut self, value: Arc<dyn HandshakeTokenKey>) -> &mut Self {
        self.token_key = value;
        self
    }

    /// Whether to require clients to prove ownership of an address before committing resources.
    ///
    /// Introduces an additional round-trip to the handshake to make denial of service attacks more difficult.
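    ///
    /// A minimal sketch of enabling retry on an existing config; the `harden` helper is
    /// hypothetical:
    ///
    /// ```
    /// fn harden(server_config: &mut quinn_proto::ServerConfig) {
    ///     server_config.use_retry(true);
    /// }
    /// ```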
    pub fn use_retry(&mut self, value: bool) -> &mut Self {
        self.use_retry = value;
        self
    }

    /// Duration after a stateless retry token was issued for which it's considered valid.
    pub fn retry_token_lifetime(&mut self, value: Duration) -> &mut Self {
        self.retry_token_lifetime = value;
        self
    }

    /// Maximum number of simultaneous connections to accept.
    ///
    /// New incoming connections are only accepted if the total number of incoming or outgoing
    /// connections is less than this. Outgoing connections are unaffected.
    pub fn concurrent_connections(&mut self, value: u32) -> &mut Self {
        self.concurrent_connections = value;
        self
    }

    /// Whether to allow clients to migrate to new addresses
    ///
    /// Improves behavior for clients that move between different internet connections or suffer NAT
    /// rebinding. Enabled by default.
    pub fn migration(&mut self, value: bool) -> &mut Self {
        self.migration = value;
        self
    }
}

#[cfg(feature = "rustls")]
impl ServerConfig {
    /// Create a server config with the given certificate chain to be presented to clients
    ///
    /// Uses a randomized handshake token key.
    pub fn with_single_cert(
        cert_chain: Vec<rustls::Certificate>,
        key: rustls::PrivateKey,
    ) -> Result<Self, rustls::Error> {
        let crypto = crypto::rustls::server_config(cert_chain, key)?;
        Ok(Self::with_crypto(Arc::new(crypto)))
    }
}

#[cfg(feature = "ring")]
impl ServerConfig {
    /// Create a server config with the given [`crypto::ServerConfig`]
    ///
    /// Uses a randomized handshake token key.
    pub fn with_crypto(crypto: Arc<dyn crypto::ServerConfig>) -> Self {
        let rng = &mut rand::thread_rng();
        let mut master_key = [0u8; 64];
        rng.fill_bytes(&mut master_key);
        let master_key = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, &[]).extract(&master_key);
        Self::new(crypto, Arc::new(master_key))
    }
}

impl fmt::Debug for ServerConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("ServerConfig")
            .field("transport", &self.transport)
            .field("crypto", &"ServerConfig { elided }")
            .field("token_key", &"[ elided ]")
            .field("use_retry", &self.use_retry)
            .field("retry_token_lifetime", &self.retry_token_lifetime)
            .field("concurrent_connections", &self.concurrent_connections)
            .field("migration", &self.migration)
            .finish()
    }
}

/// Configuration for outgoing connections
///
/// Default values should be suitable for most internet applications.
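///
/// A sketch of building a config from a caller-supplied crypto config; the `configure` helper is
/// hypothetical:
///
/// ```
/// # use std::sync::Arc;
/// fn configure(crypto: Arc<dyn quinn_proto::crypto::ClientConfig>) -> quinn_proto::ClientConfig {
///     quinn_proto::ClientConfig::new(crypto)
/// }
/// ```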
#[derive(Clone)]
#[non_exhaustive]
pub struct ClientConfig {
    /// Transport configuration to use
    pub(crate) transport: Arc<TransportConfig>,

    /// Cryptographic configuration to use
    pub(crate) crypto: Arc<dyn crypto::ClientConfig>,

    /// QUIC protocol version to use
    pub(crate) version: u32,
}

impl ClientConfig {
    /// Create a default config with a particular cryptographic config
    pub fn new(crypto: Arc<dyn crypto::ClientConfig>) -> Self {
        Self {
            transport: Default::default(),
            crypto,
            version: 1,
        }
    }

    /// Set a custom [`TransportConfig`]
    pub fn transport_config(&mut self, transport: Arc<TransportConfig>) -> &mut Self {
        self.transport = transport;
        self
    }

    /// Set the QUIC version to use
    pub fn version(&mut self, version: u32) -> &mut Self {
        self.version = version;
        self
    }
}

#[cfg(feature = "rustls")]
impl ClientConfig {
    /// Create a client configuration that trusts the platform's native roots
    #[cfg(feature = "native-certs")]
    pub fn with_native_roots() -> Self {
        let mut roots = rustls::RootCertStore::empty();
        match rustls_native_certs::load_native_certs() {
            Ok(certs) => {
                for cert in certs {
                    if let Err(e) = roots.add(&rustls::Certificate(cert.0)) {
                        tracing::warn!("failed to parse trust anchor: {}", e);
                    }
                }
            }
            Err(e) => {
                tracing::warn!("couldn't load any default trust roots: {}", e);
            }
        };

        Self::with_root_certificates(roots)
    }

    /// Create a client configuration that trusts specified trust anchors
    pub fn with_root_certificates(roots: rustls::RootCertStore) -> Self {
        Self::new(Arc::new(crypto::rustls::client_config(roots)))
    }
}

impl fmt::Debug for ClientConfig {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("ClientConfig")
            .field("transport", &self.transport)
            .field("crypto", &"ClientConfig { elided }")
            .field("version", &self.version)
            .finish()
    }
}

/// Errors in the configuration of an endpoint
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConfigError {
    /// Value exceeds supported bounds
    #[error("value exceeds supported bounds")]
    OutOfBounds,
}

impl From<TryFromIntError> for ConfigError {
    fn from(_: TryFromIntError) -> Self {
        Self::OutOfBounds
    }
}

impl From<VarIntBoundsExceeded> for ConfigError {
    fn from(_: VarIntBoundsExceeded) -> Self {
        Self::OutOfBounds
    }
}

/// Maximum duration of inactivity to accept before timing out the connection.
///
/// This wraps an underlying [`VarInt`], representing the duration in milliseconds. Values can be
/// constructed by converting directly from `VarInt`, or using `TryFrom<Duration>`.
///
/// ```
/// # use std::{convert::TryFrom, time::Duration};
/// # use quinn_proto::{IdleTimeout, VarIntBoundsExceeded, VarInt};
/// # fn main() -> Result<(), VarIntBoundsExceeded> {
/// // A `VarInt`-encoded value in milliseconds
/// let timeout = IdleTimeout::from(VarInt::from_u32(10_000));
///
/// // Try to convert a `Duration` into a `VarInt`-encoded timeout
/// let timeout = IdleTimeout::try_from(Duration::from_secs(10))?;
/// # Ok(())
/// # }
/// ```
#[derive(Default, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct IdleTimeout(VarInt);

impl From<VarInt> for IdleTimeout {
    fn from(inner: VarInt) -> Self {
        Self(inner)
    }
}

impl std::convert::TryFrom<Duration> for IdleTimeout {
    type Error = VarIntBoundsExceeded;

    fn try_from(timeout: Duration) -> Result<Self, Self::Error> {
        let inner = VarInt::try_from(timeout.as_millis())?;
        Ok(Self(inner))
    }
}

quinn-proto-0.10.6/src/congestion/bbr/bw_estimation.rs000064400000000000000000000067661046102023000211340ustar 00000000000000
use std::fmt::{Debug, Display, Formatter};
use std::time::{Duration, Instant};

use super::min_max::MinMax;

#[derive(Clone, Debug)]
pub(crate) struct BandwidthEstimation {
    total_acked: u64,
    prev_total_acked: u64,
    acked_time: Option<Instant>,
    prev_acked_time: Option<Instant>,
    total_sent: u64,
    prev_total_sent: u64,
    sent_time: Instant,
    prev_sent_time: Option<Instant>,
    max_filter: MinMax,
    acked_at_last_window: u64,
}

impl BandwidthEstimation {
    pub(crate) fn on_sent(&mut self, now: Instant, bytes: u64) {
        self.prev_total_sent = self.total_sent;
        self.total_sent += bytes;
        self.prev_sent_time = Some(self.sent_time);
        self.sent_time = now;
    }

    pub(crate) fn on_ack(
        &mut self,
        now: Instant,
        _sent: Instant,
        bytes: u64,
        round: u64,
        app_limited: bool,
    ) {
        self.prev_total_acked = self.total_acked;
        self.total_acked += bytes;
        self.prev_acked_time = self.acked_time;
        self.acked_time = Some(now);

        let prev_sent_time = match self.prev_sent_time {
            Some(prev_sent_time) => prev_sent_time,
            None => return,
        };

        let send_rate = if self.sent_time > prev_sent_time {
            Self::bw_from_delta(
                self.total_sent - self.prev_total_sent,
                self.sent_time - prev_sent_time,
            )
            .unwrap_or(0)
        } else {
            u64::MAX // will take the min of send and ack, so this is just a skip
        };

        let ack_rate = match self.prev_acked_time {
            Some(prev_acked_time) => Self::bw_from_delta(
                self.total_acked - self.prev_total_acked,
                now - prev_acked_time,
            )
            .unwrap_or(0),
            None => 0,
        };

        let bandwidth = send_rate.min(ack_rate);
        if !app_limited && self.max_filter.get() < bandwidth {
            self.max_filter.update_max(round, bandwidth);
        }
    }

    pub(crate) fn bytes_acked_this_window(&self) -> u64 {
        self.total_acked - self.acked_at_last_window
    }

    pub(crate) fn end_acks(&mut self, _current_round: u64, _app_limited: bool) {
        self.acked_at_last_window = self.total_acked;
    }

    pub(crate) fn get_estimate(&self) -> u64 {
        self.max_filter.get()
    }

    pub(crate) const fn bw_from_delta(bytes: u64, delta: Duration) -> Option<u64> {
        let window_duration_ns = delta.as_nanos();
        if window_duration_ns == 0 {
            return None;
        }
        let b_ns = bytes * 1_000_000_000;
        let bytes_per_second = b_ns / (window_duration_ns as u64);
        Some(bytes_per_second)
    }
}

impl Default for BandwidthEstimation {
    fn default() -> Self {
        Self {
            total_acked: 0,
            prev_total_acked: 0,
            acked_time: None,
            prev_acked_time: None,
            total_sent: 0,
            prev_total_sent: 0,
            // The `sent_time` value set here is ignored; it is used in `on_ack()`, but will
            // have been reset by `on_sent()` before that method is called.
sent_time: Instant::now(), prev_sent_time: None, max_filter: MinMax::default(), acked_at_last_window: 0, } } } impl Display for BandwidthEstimation { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, "{:.3} MB/s", self.get_estimate() as f32 / (1024 * 1024) as f32 ) } } quinn-proto-0.10.6/src/congestion/bbr/min_max.rs000064400000000000000000000120641046102023000177100ustar 00000000000000/* * Based on Google code released under BSD license here: * https://groups.google.com/forum/#!topic/bbr-dev/3RTgkzi5ZD8 */ /* * Kathleen Nichols' algorithm for tracking the minimum (or maximum) * value of a data stream over some fixed time interval. (E.g., * the minimum RTT over the past five minutes.) It uses constant * space and constant time per update yet almost always delivers * the same minimum as an implementation that has to keep all the * data in the window. * * The algorithm keeps track of the best, 2nd best & 3rd best min * values, maintaining an invariant that the measurement time of * the n'th best >= n-1'th best. It also makes sure that the three * values are widely separated in the time window since that bounds * the worse case error when that data is monotonically increasing * over the window. * * Upon getting a new min, we can forget everything earlier because * it has no value - the new min is <= everything else in the window * by definition and it samples the most recent. So we restart fresh on * every new min and overwrites 2nd & 3rd choices. The same property * holds for 2nd & 3rd best. */ use std::fmt::Debug; #[derive(Copy, Clone, Debug)] pub(super) struct MinMax { /// round count, not a timestamp window: u64, samples: [MinMaxSample; 3], } impl MinMax { pub(super) fn get(&self) -> u64 { self.samples[0].value } fn fill(&mut self, sample: MinMaxSample) { self.samples.fill(sample); } pub(super) fn reset(&mut self) { self.fill(Default::default()) } /// update_min is also defined in the original source, but removed here since it is not used. pub(super) fn update_max(&mut self, current_round: u64, measurement: u64) { let sample = MinMaxSample { time: current_round, value: measurement, }; if self.samples[0].value == 0 /* uninitialised */ || /* found new max? */ sample.value >= self.samples[0].value || /* nothing left in window? */ sample.time - self.samples[2].time > self.window { self.fill(sample); /* forget earlier samples */ return; } if sample.value >= self.samples[1].value { self.samples[2] = sample; self.samples[1] = sample; } else if sample.value >= self.samples[2].value { self.samples[2] = sample; } self.subwin_update(sample); } /* As time advances, update the 1st, 2nd, and 3rd choices. */ fn subwin_update(&mut self, sample: MinMaxSample) { let dt = sample.time - self.samples[0].time; if dt > self.window { /* * Passed entire window without a new sample so make 2nd * choice the new sample & 3rd choice the new 2nd choice. * we may have to iterate this since our 2nd choice * may also be outside the window (we checked on entry * that the third choice was in the window). */ self.samples[0] = self.samples[1]; self.samples[1] = self.samples[2]; self.samples[2] = sample; if sample.time - self.samples[0].time > self.window { self.samples[0] = self.samples[1]; self.samples[1] = self.samples[2]; self.samples[2] = sample; } } else if self.samples[1].time == self.samples[0].time && dt > self.window / 4 { /* * We've passed a quarter of the window without a new sample * so take a 2nd choice from the 2nd quarter of the window. 
             */
            self.samples[2] = sample;
            self.samples[1] = sample;
        } else if self.samples[2].time == self.samples[1].time && dt > self.window / 2 {
            /*
             * We've passed half the window without finding a new sample
             * so take a 3rd choice from the last half of the window
             */
            self.samples[2] = sample;
        }
    }
}

impl Default for MinMax {
    fn default() -> Self {
        Self {
            window: 10,
            samples: [Default::default(); 3],
        }
    }
}

#[derive(Debug, Copy, Clone, Default)]
struct MinMaxSample {
    /// round number, not a timestamp
    time: u64,
    value: u64,
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test() {
        let round = 25;
        let mut min_max = MinMax::default();

        min_max.update_max(round + 1, 100);
        assert_eq!(100, min_max.get());
        min_max.update_max(round + 3, 120);
        assert_eq!(120, min_max.get());
        min_max.update_max(round + 5, 160);
        assert_eq!(160, min_max.get());

        min_max.update_max(round + 7, 100);
        assert_eq!(160, min_max.get());
        min_max.update_max(round + 10, 100);
        assert_eq!(160, min_max.get());
        min_max.update_max(round + 14, 100);
        assert_eq!(160, min_max.get());
        min_max.update_max(round + 16, 100);
        assert_eq!(100, min_max.get());
        min_max.update_max(round + 18, 130);
        assert_eq!(130, min_max.get());
    }
}

quinn-proto-0.10.6/src/congestion/bbr/mod.rs000064400000000000000000000550711046102023000170440ustar 00000000000000
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::{Duration, Instant};

use rand::{Rng, SeedableRng};

use crate::congestion::bbr::bw_estimation::BandwidthEstimation;
use crate::congestion::bbr::min_max::MinMax;
use crate::connection::RttEstimator;

use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE};

mod bw_estimation;
mod min_max;

/// Experimental! Use at your own risk.
///
/// Aims for reduced buffer bloat and improved performance over high bandwidth-delay product
/// networks. Based on google's quiche implementation of BBR. More discussion and links can be
/// found on the bbr-dev mailing list.
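///
/// A sketch of opting in to BBR, mirroring the `NewRenoConfig` example on
/// `TransportConfig::congestion_controller_factory`; `BbrConfig` is defined later in this module:
///
/// ```
/// # use std::sync::Arc;
/// let mut config = quinn_proto::TransportConfig::default();
/// config.congestion_controller_factory(Arc::new(quinn_proto::congestion::BbrConfig::default()));
/// ```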
#[derive(Debug, Clone)]
pub struct Bbr {
    config: Arc<BbrConfig>,
    current_mtu: u64,
    max_bandwidth: BandwidthEstimation,
    acked_bytes: u64,
    mode: Mode,
    loss_state: LossState,
    recovery_state: RecoveryState,
    recovery_window: u64,
    is_at_full_bandwidth: bool,
    pacing_gain: f32,
    high_gain: f32,
    drain_gain: f32,
    cwnd_gain: f32,
    high_cwnd_gain: f32,
    last_cycle_start: Option<Instant>,
    current_cycle_offset: u8,
    init_cwnd: u64,
    min_cwnd: u64,
    prev_in_flight_count: u64,
    exit_probe_rtt_at: Option<Instant>,
    probe_rtt_last_started_at: Option<Instant>,
    min_rtt: Duration,
    exiting_quiescence: bool,
    pacing_rate: u64,
    max_acked_packet_number: u64,
    max_sent_packet_number: u64,
    end_recovery_at_packet_number: u64,
    cwnd: u64,
    current_round_trip_end_packet_number: u64,
    round_count: u64,
    bw_at_last_round: u64,
    round_wo_bw_gain: u64,
    ack_aggregation: AckAggregationState,
    random_number_generator: rand::rngs::StdRng,
}

impl Bbr {
    /// Construct a state using the given `config` and current time `now`
    pub fn new(config: Arc<BbrConfig>, current_mtu: u16) -> Self {
        let initial_window = config.initial_window;
        Self {
            config,
            current_mtu: current_mtu as u64,
            max_bandwidth: BandwidthEstimation::default(),
            acked_bytes: 0,
            mode: Mode::Startup,
            loss_state: Default::default(),
            recovery_state: RecoveryState::NotInRecovery,
            recovery_window: 0,
            is_at_full_bandwidth: false,
            pacing_gain: K_DEFAULT_HIGH_GAIN,
            high_gain: K_DEFAULT_HIGH_GAIN,
            drain_gain: 1.0 / K_DEFAULT_HIGH_GAIN,
            cwnd_gain: K_DEFAULT_HIGH_GAIN,
            high_cwnd_gain: K_DEFAULT_HIGH_GAIN,
            last_cycle_start: None,
            current_cycle_offset: 0,
            init_cwnd: initial_window,
            min_cwnd: calculate_min_window(current_mtu as u64),
            prev_in_flight_count: 0,
            exit_probe_rtt_at: None,
            probe_rtt_last_started_at: None,
            min_rtt: Default::default(),
            exiting_quiescence: false,
            pacing_rate: 0,
            max_acked_packet_number: 0,
            max_sent_packet_number: 0,
            end_recovery_at_packet_number: 0,
            cwnd: initial_window,
            current_round_trip_end_packet_number: 0,
            round_count: 0,
            bw_at_last_round: 0,
            round_wo_bw_gain: 0,
            ack_aggregation: AckAggregationState::default(),
            random_number_generator: rand::rngs::StdRng::from_entropy(),
        }
    }

    fn enter_startup_mode(&mut self) {
        self.mode = Mode::Startup;
        self.pacing_gain = self.high_gain;
        self.cwnd_gain = self.high_cwnd_gain;
    }

    fn enter_probe_bandwidth_mode(&mut self, now: Instant) {
        self.mode = Mode::ProbeBw;
        self.cwnd_gain = K_DERIVED_HIGH_CWNDGAIN;
        self.last_cycle_start = Some(now);
        // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
        // excluded because in that case increased gain and decreased gain would not
        // follow each other.
        let mut rand_index = self
            .random_number_generator
            .gen_range(0..K_PACING_GAIN.len() as u8 - 1);
        if rand_index >= 1 {
            rand_index += 1;
        }
        self.current_cycle_offset = rand_index;
        self.pacing_gain = K_PACING_GAIN[rand_index as usize];
    }

    fn update_recovery_state(&mut self, is_round_start: bool) {
        // Exit recovery when there are no losses for a round.
        if self.loss_state.has_losses() {
            self.end_recovery_at_packet_number = self.max_sent_packet_number;
        }
        match self.recovery_state {
            // Enter conservation on the first loss.
            RecoveryState::NotInRecovery if self.loss_state.has_losses() => {
                self.recovery_state = RecoveryState::Conservation;
                // This will cause the |recovery_window| to be set to the
                // correct value in CalculateRecoveryWindow().
                self.recovery_window = 0;
                // Since the conservation phase is meant to be lasting for a whole
                // round, extend the current round as if it were started right now.
self.current_round_trip_end_packet_number = self.max_sent_packet_number; } RecoveryState::Growth | RecoveryState::Conservation => { if self.recovery_state == RecoveryState::Conservation && is_round_start { self.recovery_state = RecoveryState::Growth; } // Exit recovery if appropriate. if !self.loss_state.has_losses() && self.max_acked_packet_number > self.end_recovery_at_packet_number { self.recovery_state = RecoveryState::NotInRecovery; } } _ => {} } } fn update_gain_cycle_phase(&mut self, now: Instant, in_flight: u64) { // In most cases, the cycle is advanced after an RTT passes. let mut should_advance_gain_cycling = self .last_cycle_start .map(|last_cycle_start| now.duration_since(last_cycle_start) > self.min_rtt) .unwrap_or(false); // If the pacing gain is above 1.0, the connection is trying to probe the // bandwidth by increasing the number of bytes in flight to at least // pacing_gain * BDP. Make sure that it actually reaches the target, as // long as there are no losses suggesting that the buffers are not able to // hold that much. if self.pacing_gain > 1.0 && !self.loss_state.has_losses() && self.prev_in_flight_count < self.get_target_cwnd(self.pacing_gain) { should_advance_gain_cycling = false; } // If pacing gain is below 1.0, the connection is trying to drain the extra // queue which could have been incurred by probing prior to it. If the // number of bytes in flight falls down to the estimated BDP value earlier, // conclude that the queue has been successfully drained and exit this cycle // early. if self.pacing_gain < 1.0 && in_flight <= self.get_target_cwnd(1.0) { should_advance_gain_cycling = true; } if should_advance_gain_cycling { self.current_cycle_offset = (self.current_cycle_offset + 1) % K_PACING_GAIN.len() as u8; self.last_cycle_start = Some(now); // Stay in low gain mode until the target BDP is hit. Low gain mode // will be exited immediately when the target BDP is achieved. if DRAIN_TO_TARGET && self.pacing_gain < 1.0 && (K_PACING_GAIN[self.current_cycle_offset as usize] - 1.0).abs() < f32::EPSILON && in_flight > self.get_target_cwnd(1.0) { return; } self.pacing_gain = K_PACING_GAIN[self.current_cycle_offset as usize]; } } fn maybe_exit_startup_or_drain(&mut self, now: Instant, in_flight: u64) { if self.mode == Mode::Startup && self.is_at_full_bandwidth { self.mode = Mode::Drain; self.pacing_gain = self.drain_gain; self.cwnd_gain = self.high_cwnd_gain; } if self.mode == Mode::Drain && in_flight <= self.get_target_cwnd(1.0) { self.enter_probe_bandwidth_mode(now); } } fn is_min_rtt_expired(&self, now: Instant, app_limited: bool) -> bool { !app_limited && self .probe_rtt_last_started_at .map(|last| now.saturating_duration_since(last) > Duration::from_secs(10)) .unwrap_or(true) } fn maybe_enter_or_exit_probe_rtt( &mut self, now: Instant, is_round_start: bool, bytes_in_flight: u64, app_limited: bool, ) { let min_rtt_expired = self.is_min_rtt_expired(now, app_limited); if min_rtt_expired && !self.exiting_quiescence && self.mode != Mode::ProbeRtt { self.mode = Mode::ProbeRtt; self.pacing_gain = 1.0; // Do not decide on the time to exit ProbeRtt until the // |bytes_in_flight| is at the target small value. self.exit_probe_rtt_at = None; self.probe_rtt_last_started_at = Some(now); } if self.mode == Mode::ProbeRtt { if self.exit_probe_rtt_at.is_none() { // If the window has reached the appropriate size, schedule exiting // ProbeRtt. The CWND during ProbeRtt is // kMinimumCongestionWindow, but we allow an extra packet since QUIC // checks CWND before sending a packet. 
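                // (Illustrative numbers, not from the original comment: with
                // PROBE_RTT_BASED_ON_BDP enabled, get_probe_rtt_cwnd() is 0.75 * BDP;
                // at a 1_000_000 B/s bandwidth estimate and min_rtt = 50 ms, BDP is
                // 50_000 bytes, so the exit is scheduled once fewer than
                // 0.75 * 50_000 + current_mtu = 38_700 bytes (1_200-byte MTU) remain
                // in flight.)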
                if bytes_in_flight < self.get_probe_rtt_cwnd() + self.current_mtu {
                    const K_PROBE_RTT_TIME: Duration = Duration::from_millis(200);
                    self.exit_probe_rtt_at = Some(now + K_PROBE_RTT_TIME);
                }
            } else if is_round_start && now >= self.exit_probe_rtt_at.unwrap() {
                if !self.is_at_full_bandwidth {
                    self.enter_startup_mode();
                } else {
                    self.enter_probe_bandwidth_mode(now);
                }
            }
        }
        self.exiting_quiescence = false;
    }

    fn get_target_cwnd(&self, gain: f32) -> u64 {
        let bw = self.max_bandwidth.get_estimate();
        let bdp = self.min_rtt.as_micros() as u64 * bw;
        let bdpf = bdp as f64;
        let cwnd = ((gain as f64 * bdpf) / 1_000_000f64) as u64;
        // BDP estimate will be zero if no bandwidth samples are available yet.
        if cwnd == 0 {
            return self.init_cwnd;
        }
        cwnd.max(self.min_cwnd)
    }

    fn get_probe_rtt_cwnd(&self) -> u64 {
        const K_MODERATE_PROBE_RTT_MULTIPLIER: f32 = 0.75;
        if PROBE_RTT_BASED_ON_BDP {
            return self.get_target_cwnd(K_MODERATE_PROBE_RTT_MULTIPLIER);
        }
        self.min_cwnd
    }

    fn calculate_pacing_rate(&mut self) {
        let bw = self.max_bandwidth.get_estimate();
        if bw == 0 {
            return;
        }
        let target_rate = (bw as f64 * self.pacing_gain as f64) as u64;
        if self.is_at_full_bandwidth {
            self.pacing_rate = target_rate;
            return;
        }
        // Pace at the rate of initial_window / RTT as soon as RTT measurements are
        // available.
        if self.pacing_rate == 0 && self.min_rtt.as_nanos() != 0 {
            self.pacing_rate =
                BandwidthEstimation::bw_from_delta(self.init_cwnd, self.min_rtt).unwrap();
            return;
        }
        // Do not decrease the pacing rate during startup.
        if self.pacing_rate < target_rate {
            self.pacing_rate = target_rate;
        }
    }

    fn calculate_cwnd(&mut self, bytes_acked: u64, excess_acked: u64) {
        if self.mode == Mode::ProbeRtt {
            return;
        }
        let mut target_window = self.get_target_cwnd(self.cwnd_gain);
        if self.is_at_full_bandwidth {
            // Add the max recently measured ack aggregation to CWND.
            target_window += self.ack_aggregation.max_ack_height.get();
        } else {
            // Add the most recent excess acked. Because CWND never decreases in
            // STARTUP, this will automatically create a very localized max filter.
            target_window += excess_acked;
        }
        // Instead of immediately setting the target CWND as the new one, BBR grows
        // the CWND towards |target_window| by only increasing it |bytes_acked| at a
        // time.
        if self.is_at_full_bandwidth {
            self.cwnd = target_window.min(self.cwnd + bytes_acked);
        } else if self.cwnd < target_window || self.acked_bytes < self.init_cwnd {
            // If the connection is not yet out of startup phase, do not decrease
            // the window.
            self.cwnd += bytes_acked;
        }
        // Enforce the limits on the congestion window.
        if self.cwnd < self.min_cwnd {
            self.cwnd = self.min_cwnd;
        }
    }

    fn calculate_recovery_window(&mut self, bytes_acked: u64, bytes_lost: u64, in_flight: u64) {
        if !self.recovery_state.in_recovery() {
            return;
        }
        // Set up the initial recovery window.
        if self.recovery_window == 0 {
            self.recovery_window = self.min_cwnd.max(in_flight + bytes_acked);
            return;
        }
        // Remove losses from the recovery window, while accounting for a potential
        // integer underflow.
        if self.recovery_window >= bytes_lost {
            self.recovery_window -= bytes_lost;
        } else {
            // k_max_segment_size = current_mtu
            self.recovery_window = self.current_mtu;
        }
        // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH,
        // release additional |bytes_acked| to achieve a slow-start-like behavior.
        if self.recovery_state == RecoveryState::Growth {
            self.recovery_window += bytes_acked;
        }
        // Sanity checks. Ensure that we always allow sending at least an MSS or
        // |bytes_acked| in response, whichever is larger.
        self.recovery_window = self
            .recovery_window
            .max(in_flight + bytes_acked)
            .max(self.min_cwnd);
    }

    ///
    fn check_if_full_bw_reached(&mut self, app_limited: bool) {
        if app_limited {
            return;
        }
        let target = (self.bw_at_last_round as f64 * K_STARTUP_GROWTH_TARGET as f64) as u64;
        let bw = self.max_bandwidth.get_estimate();
        if bw >= target {
            self.bw_at_last_round = bw;
            self.round_wo_bw_gain = 0;
            self.ack_aggregation.max_ack_height.reset();
            return;
        }
        self.round_wo_bw_gain += 1;
        if self.round_wo_bw_gain >= K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP as u64
            || self.recovery_state.in_recovery()
        {
            self.is_at_full_bandwidth = true;
        }
    }
}

impl Controller for Bbr {
    fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {
        self.max_sent_packet_number = last_packet_number;
        self.max_bandwidth.on_sent(now, bytes);
    }

    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        self.max_bandwidth
            .on_ack(now, sent, bytes, self.round_count, app_limited);
        self.acked_bytes += bytes;
        if self.is_min_rtt_expired(now, app_limited) || self.min_rtt > rtt.min() {
            self.min_rtt = rtt.min();
        }
    }

    fn on_end_acks(
        &mut self,
        now: Instant,
        in_flight: u64,
        app_limited: bool,
        largest_packet_num_acked: Option<u64>,
    ) {
        let bytes_acked = self.max_bandwidth.bytes_acked_this_window();
        let excess_acked = self.ack_aggregation.update_ack_aggregation_bytes(
            bytes_acked,
            now,
            self.round_count,
            self.max_bandwidth.get_estimate(),
        );
        self.max_bandwidth.end_acks(self.round_count, app_limited);
        if let Some(largest_acked_packet) = largest_packet_num_acked {
            self.max_acked_packet_number = largest_acked_packet;
        }
        let mut is_round_start = false;
        if bytes_acked > 0 {
            is_round_start =
                self.max_acked_packet_number > self.current_round_trip_end_packet_number;
            if is_round_start {
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
                self.round_count += 1;
            }
        }
        self.update_recovery_state(is_round_start);
        if self.mode == Mode::ProbeBw {
            self.update_gain_cycle_phase(now, in_flight);
        }
        if is_round_start && !self.is_at_full_bandwidth {
            self.check_if_full_bw_reached(app_limited);
        }
        self.maybe_exit_startup_or_drain(now, in_flight);
        self.maybe_enter_or_exit_probe_rtt(now, is_round_start, in_flight, app_limited);
        // After the model is updated, recalculate the pacing rate and congestion window.
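        // (Worked example with assumed values: if the bandwidth filter reports
        // 2_000_000 B/s and min_rtt is 25 ms, the BDP is 50_000 bytes; with
        // cwnd_gain = 2.0, calculate_cwnd() steers the window toward roughly
        // 100_000 bytes, and calculate_pacing_rate() targets pacing_gain * 2_000_000 B/s.)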
self.calculate_pacing_rate(); self.calculate_cwnd(bytes_acked, excess_acked); self.calculate_recovery_window(bytes_acked, self.loss_state.lost_bytes, in_flight); self.prev_in_flight_count = in_flight; self.loss_state.reset(); } fn on_congestion_event( &mut self, _now: Instant, _sent: Instant, _is_persistent_congestion: bool, lost_bytes: u64, ) { self.loss_state.lost_bytes += lost_bytes; } fn on_mtu_update(&mut self, new_mtu: u16) { self.current_mtu = new_mtu as u64; self.min_cwnd = calculate_min_window(self.current_mtu); self.init_cwnd = self.config.initial_window.max(self.min_cwnd); self.cwnd = self.cwnd.max(self.min_cwnd); } fn window(&self) -> u64 { if self.mode == Mode::ProbeRtt { return self.get_probe_rtt_cwnd(); } else if self.recovery_state.in_recovery() && self.mode != Mode::Startup { return self.cwnd.min(self.recovery_window); } self.cwnd } fn clone_box(&self) -> Box { Box::new(self.clone()) } fn initial_window(&self) -> u64 { self.config.initial_window } fn into_any(self: Box) -> Box { self } } /// Configuration for the [`Bbr`] congestion controller #[derive(Debug, Clone)] pub struct BbrConfig { initial_window: u64, } impl BbrConfig { /// Default limit on the amount of outstanding data in bytes. /// /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))` pub fn initial_window(&mut self, value: u64) -> &mut Self { self.initial_window = value; self } } impl Default for BbrConfig { fn default() -> Self { Self { initial_window: K_MAX_INITIAL_CONGESTION_WINDOW * BASE_DATAGRAM_SIZE, } } } impl ControllerFactory for Arc { fn build(&self, _now: Instant, current_mtu: u16) -> Box { Box::new(Bbr::new(self.clone(), current_mtu)) } } #[derive(Debug, Default, Copy, Clone)] struct AckAggregationState { max_ack_height: MinMax, aggregation_epoch_start_time: Option, aggregation_epoch_bytes: u64, } impl AckAggregationState { fn update_ack_aggregation_bytes( &mut self, newly_acked_bytes: u64, now: Instant, round: u64, max_bandwidth: u64, ) -> u64 { // Compute how many bytes are expected to be delivered, assuming max // bandwidth is correct. let expected_bytes_acked = max_bandwidth * now .saturating_duration_since(self.aggregation_epoch_start_time.unwrap_or(now)) .as_micros() as u64 / 1_000_000; // Reset the current aggregation epoch as soon as the ack arrival rate is // less than or equal to the max bandwidth. if self.aggregation_epoch_bytes <= expected_bytes_acked { // Reset to start measuring a new aggregation epoch. self.aggregation_epoch_bytes = newly_acked_bytes; self.aggregation_epoch_start_time = Some(now); return 0; } // Compute how many extra bytes were delivered vs max bandwidth. // Include the bytes most recently acknowledged to account for stretch acks. self.aggregation_epoch_bytes += newly_acked_bytes; let diff = self.aggregation_epoch_bytes - expected_bytes_acked; self.max_ack_height.update_max(round, diff); diff } } #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum Mode { // Startup phase of the connection. Startup, // After achieving the highest possible bandwidth during the startup, lower // the pacing rate in order to drain the queue. Drain, // Cruising mode. ProbeBw, // Temporarily slow down sending in order to empty the buffer and measure // the real minimum RTT. ProbeRtt, } // Indicates how the congestion control limits the amount of bytes in flight. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum RecoveryState { // Do not limit. NotInRecovery, // Allow an extra outstanding byte for each byte acknowledged. 
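    // (Added note: during Conservation the recovery window is held near
    // in_flight + bytes_acked, while Growth additionally releases `bytes_acked`
    // per ack in calculate_recovery_window(), approximating slow start while in
    // recovery.)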
Conservation, // Allow two extra outstanding bytes for each byte acknowledged (slow // start). Growth, } impl RecoveryState { pub(super) fn in_recovery(&self) -> bool { !matches!(self, Self::NotInRecovery) } } #[derive(Debug, Clone, Default)] struct LossState { lost_bytes: u64, } impl LossState { pub(super) fn reset(&mut self) { self.lost_bytes = 0; } pub(super) fn has_losses(&self) -> bool { self.lost_bytes != 0 } } fn calculate_min_window(current_mtu: u64) -> u64 { 4 * current_mtu } // The gain used for the STARTUP, equal to 2/ln(2). const K_DEFAULT_HIGH_GAIN: f32 = 2.885; // The newly derived CWND gain for STARTUP, 2. const K_DERIVED_HIGH_CWNDGAIN: f32 = 2.0; // The cycle of gains used during the ProbeBw stage. const K_PACING_GAIN: [f32; 8] = [1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; const K_STARTUP_GROWTH_TARGET: f32 = 1.25; const K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP: u8 = 3; // Do not allow initial congestion window to be greater than 200 packets. const K_MAX_INITIAL_CONGESTION_WINDOW: u64 = 200; const PROBE_RTT_BASED_ON_BDP: bool = true; const DRAIN_TO_TARGET: bool = true; quinn-proto-0.10.6/src/congestion/cubic.rs000064400000000000000000000177401046102023000166060ustar 00000000000000use std::any::Any; use std::sync::Arc; use std::time::{Duration, Instant}; use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE}; use crate::connection::RttEstimator; use std::cmp; /// CUBIC Constants. /// /// These are recommended value in RFC8312. const BETA_CUBIC: f64 = 0.7; const C: f64 = 0.4; /// CUBIC State Variables. /// /// We need to keep those variables across the connection. /// k, w_max are described in the RFC. #[derive(Debug, Default, Clone)] pub(super) struct State { k: f64, w_max: f64, // Store cwnd increment during congestion avoidance. cwnd_inc: u64, } /// CUBIC Functions. /// /// Note that these calculations are based on a count of cwnd as bytes, /// not packets. /// Unit of t (duration) and RTT are based on seconds (f64). impl State { // K = cbrt(w_max * (1 - beta_cubic) / C) (Eq. 2) fn cubic_k(&self, max_datagram_size: u64) -> f64 { let w_max = self.w_max / max_datagram_size as f64; (w_max * (1.0 - BETA_CUBIC) / C).cbrt() } // W_cubic(t) = C * (t - K)^3 - w_max (Eq. 1) fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 { let w_max = self.w_max / max_datagram_size as f64; (C * (t.as_secs_f64() - self.k).powi(3) + w_max) * max_datagram_size as f64 } // W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) * // (t / RTT) (Eq. 4) fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 { let w_max = self.w_max / max_datagram_size as f64; (w_max * BETA_CUBIC + 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC) * t.as_secs_f64() / rtt.as_secs_f64()) * max_datagram_size as f64 } } /// The RFC8312 congestion controller, as widely used for TCP #[derive(Debug, Clone)] pub struct Cubic { config: Arc, /// Maximum number of bytes in flight that may be sent. window: u64, /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is /// slow start and the window grows by the number of bytes acknowledged. ssthresh: u64, /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent /// after this time is acknowledged, QUIC exits recovery. 
recovery_start_time: Option, cubic_state: State, current_mtu: u64, } impl Cubic { /// Construct a state using the given `config` and current time `now` pub fn new(config: Arc, _now: Instant, current_mtu: u16) -> Self { Self { window: config.initial_window, ssthresh: u64::MAX, recovery_start_time: None, config, cubic_state: Default::default(), current_mtu: current_mtu as u64, } } fn minimum_window(&self) -> u64 { 2 * self.current_mtu } } impl Controller for Cubic { fn on_ack( &mut self, now: Instant, sent: Instant, bytes: u64, app_limited: bool, rtt: &RttEstimator, ) { if app_limited || self .recovery_start_time .map(|recovery_start_time| sent <= recovery_start_time) .unwrap_or(false) { return; } if self.window < self.ssthresh { // Slow start self.window += bytes; } else { // Congestion avoidance. let ca_start_time; match self.recovery_start_time { Some(t) => ca_start_time = t, None => { // When we come here without congestion_event() triggered, // initialize congestion_recovery_start_time, w_max and k. ca_start_time = now; self.recovery_start_time = Some(now); self.cubic_state.w_max = self.window as f64; self.cubic_state.k = 0.0; } } let t = now - ca_start_time; // w_cubic(t + rtt) let w_cubic = self.cubic_state.w_cubic(t + rtt.get(), self.current_mtu); // w_est(t) let w_est = self.cubic_state.w_est(t, rtt.get(), self.current_mtu); let mut cubic_cwnd = self.window; if w_cubic < w_est { // TCP friendly region. cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64); } else if cubic_cwnd < w_cubic as u64 { // Concave region or convex region use same increment. let cubic_inc = (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64; cubic_cwnd += cubic_inc as u64; } // Update the increment and increase cwnd by MSS. self.cubic_state.cwnd_inc += cubic_cwnd - self.window; // cwnd_inc can be more than 1 MSS in the late stage of max probing. // however RFC9002 §7.3.3 (Congestion Avoidance) limits // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged. 
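            // (Worked example, values assumed: with current_mtu = 1_200, cubic
            // increments of 700 and then 800 bytes leave cwnd_inc at 1_500 after the
            // second ack batch, so the window grows by exactly one MTU and cwnd_inc
            // resets to 0.)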
if self.cubic_state.cwnd_inc >= self.current_mtu { self.window += self.current_mtu; self.cubic_state.cwnd_inc = 0; } } } fn on_congestion_event( &mut self, now: Instant, sent: Instant, is_persistent_congestion: bool, _lost_bytes: u64, ) { if self .recovery_start_time .map(|recovery_start_time| sent <= recovery_start_time) .unwrap_or(false) { return; } self.recovery_start_time = Some(now); // Fast convergence #[allow(clippy::branches_sharing_code)] // https://github.com/rust-lang/rust-clippy/issues/7198 if (self.window as f64) < self.cubic_state.w_max { self.cubic_state.w_max = self.window as f64 * (1.0 + BETA_CUBIC) / 2.0; } else { self.cubic_state.w_max = self.window as f64; } self.ssthresh = cmp::max( (self.cubic_state.w_max * BETA_CUBIC) as u64, self.minimum_window(), ); self.window = self.ssthresh; self.cubic_state.k = self.cubic_state.cubic_k(self.current_mtu); self.cubic_state.cwnd_inc = (self.cubic_state.cwnd_inc as f64 * BETA_CUBIC) as u64; if is_persistent_congestion { self.recovery_start_time = None; self.cubic_state.w_max = self.window as f64; // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC self.ssthresh = cmp::max( (self.window as f64 * BETA_CUBIC) as u64, self.minimum_window(), ); self.cubic_state.cwnd_inc = 0; self.window = self.minimum_window(); } } fn on_mtu_update(&mut self, new_mtu: u16) { self.current_mtu = new_mtu as u64; self.window = self.window.max(self.minimum_window()); } fn window(&self) -> u64 { self.window } fn clone_box(&self) -> Box { Box::new(self.clone()) } fn initial_window(&self) -> u64 { self.config.initial_window } fn into_any(self: Box) -> Box { self } } /// Configuration for the `Cubic` congestion controller #[derive(Debug, Clone)] pub struct CubicConfig { initial_window: u64, } impl CubicConfig { /// Default limit on the amount of outstanding data in bytes. /// /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))` pub fn initial_window(&mut self, value: u64) -> &mut Self { self.initial_window = value; self } } impl Default for CubicConfig { fn default() -> Self { Self { initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE), } } } impl ControllerFactory for Arc { fn build(&self, now: Instant, current_mtu: u16) -> Box { Box::new(Cubic::new(self.clone(), now, current_mtu)) } } quinn-proto-0.10.6/src/congestion/new_reno.rs000064400000000000000000000116341046102023000173310ustar 00000000000000use std::any::Any; use std::sync::Arc; use std::time::Instant; use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE}; use crate::connection::RttEstimator; /// A simple, standard congestion controller #[derive(Debug, Clone)] pub struct NewReno { config: Arc, current_mtu: u64, /// Maximum number of bytes in flight that may be sent. window: u64, /// Slow start threshold in bytes. When the congestion window is below ssthresh, the mode is /// slow start and the window grows by the number of bytes acknowledged. ssthresh: u64, /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent /// after this time is acknowledged, QUIC exits recovery. 
    recovery_start_time: Instant,
    /// Bytes which had been acked by the peer since leaving slow start
    bytes_acked: u64,
}

impl NewReno {
    /// Construct a state using the given `config` and current time `now`
    pub fn new(config: Arc<NewRenoConfig>, now: Instant, current_mtu: u16) -> Self {
        Self {
            window: config.initial_window,
            ssthresh: u64::MAX,
            recovery_start_time: now,
            current_mtu: current_mtu as u64,
            config,
            bytes_acked: 0,
        }
    }

    fn minimum_window(&self) -> u64 {
        2 * self.current_mtu
    }
}

impl Controller for NewReno {
    fn on_ack(
        &mut self,
        _now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        _rtt: &RttEstimator,
    ) {
        if app_limited || sent <= self.recovery_start_time {
            return;
        }
        if self.window < self.ssthresh {
            // Slow start
            self.window += bytes;
            if self.window >= self.ssthresh {
                // Exiting slow start
                // Initialize `bytes_acked` for congestion avoidance. The idea
                // here is that any bytes over `ssthresh` will already be counted
                // towards the congestion avoidance phase - independent of how
                // close to `ssthresh` the `window` was when switching states,
                // and independent of datagram sizes.
                self.bytes_acked = self.window - self.ssthresh;
            }
        } else {
            // Congestion avoidance
            // This implementation uses the method which does not require
            // floating point math, which also increases the window by 1 datagram
            // for every round trip.
            // This mechanism is called Appropriate Byte Counting in
            // https://tools.ietf.org/html/rfc3465
            self.bytes_acked += bytes;
            if self.bytes_acked >= self.window {
                self.bytes_acked -= self.window;
                self.window += self.current_mtu;
            }
        }
    }

    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        _lost_bytes: u64,
    ) {
        if sent <= self.recovery_start_time {
            return;
        }
        self.recovery_start_time = now;
        self.window = (self.window as f32 * self.config.loss_reduction_factor) as u64;
        self.window = self.window.max(self.minimum_window());
        self.ssthresh = self.window;
        if is_persistent_congestion {
            self.window = self.minimum_window();
        }
    }

    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = new_mtu as u64;
        self.window = self.window.max(self.minimum_window());
    }

    fn window(&self) -> u64 {
        self.window
    }

    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }

    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

/// Configuration for the `NewReno` congestion controller
#[derive(Debug, Clone)]
pub struct NewRenoConfig {
    initial_window: u64,
    loss_reduction_factor: f32,
}

impl NewRenoConfig {
    /// Default limit on the amount of outstanding data in bytes.
    ///
    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }

    /// Reduction in congestion window when a new loss event is detected.
    pub fn loss_reduction_factor(&mut self, value: f32) -> &mut Self {
        self.loss_reduction_factor = value;
        self
    }
}

impl Default for NewRenoConfig {
    fn default() -> Self {
        Self {
            initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE),
            loss_reduction_factor: 0.5,
        }
    }
}

impl ControllerFactory for Arc<NewRenoConfig> {
    fn build(&self, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        Box::new(NewReno::new(self.clone(), now, current_mtu))
    }
}
quinn-proto-0.10.6/src/congestion.rs000064400000000000000000000047071046102023000155160ustar 00000000000000//! Logic for controlling the rate at which data is sent

use crate::connection::RttEstimator;
use std::any::Any;
use std::time::Instant;

mod bbr;
mod cubic;
mod new_reno;

pub use bbr::{Bbr, BbrConfig};
pub use cubic::{Cubic, CubicConfig};
pub use new_reno::{NewReno, NewRenoConfig};

/// Common interface for different congestion controllers
pub trait Controller: Send {
    /// One or more packets were just sent
    #[allow(unused_variables)]
    fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {}

    /// Packet deliveries were confirmed
    ///
    /// `app_limited` indicates whether the connection was blocked on outgoing
    /// application data prior to receiving these acknowledgements.
    #[allow(unused_variables)]
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
    }

    /// Packets are acked in batches, all with the same `now` argument. This indicates one of those batches has completed.
    #[allow(unused_variables)]
    fn on_end_acks(
        &mut self,
        now: Instant,
        in_flight: u64,
        app_limited: bool,
        largest_packet_num_acked: Option<u64>,
    ) {
    }

    /// Packets were deemed lost or marked congested
    ///
    /// `is_persistent_congestion` indicates whether all packets sent within the persistent
    /// congestion threshold period ending when the most recent packet in this batch was sent were
    /// lost.
    /// `lost_bytes` indicates how many bytes were lost. This value will be 0 for ECN triggers.
    fn on_congestion_event(
        &mut self,
        now: Instant,
        sent: Instant,
        is_persistent_congestion: bool,
        lost_bytes: u64,
    );

    /// The known MTU for the current network path has been updated
    fn on_mtu_update(&mut self, new_mtu: u16);

    /// Number of ack-eliciting bytes that may be in flight
    fn window(&self) -> u64;

    /// Duplicate the controller's state
    fn clone_box(&self) -> Box<dyn Controller>;

    /// Initial congestion window
    fn initial_window(&self) -> u64;

    /// Returns Self for use in down-casting to extract implementation details
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}

/// Constructs controllers on demand
pub trait ControllerFactory {
    /// Construct a fresh `Controller`
    fn build(&self, now: Instant, current_mtu: u16) -> Box<dyn Controller>;
}

const BASE_DATAGRAM_SIZE: u64 = 1200;
quinn-proto-0.10.6/src/connection/assembler.rs000064400000000000000000000556721046102023000174610ustar 00000000000000use std::{
    cmp::Ordering,
    collections::{binary_heap::PeekMut, BinaryHeap},
    mem,
};

use bytes::{Buf, Bytes, BytesMut};

use crate::range_set::RangeSet;

/// Helper to assemble unordered stream frames into an ordered stream
#[derive(Debug, Default)]
pub(super) struct Assembler {
    state: State,
    data: BinaryHeap<Buffer>,
    /// Total number of buffered bytes, including duplicates in ordered mode.
    buffered: usize,
    /// Estimated number of allocated bytes, will never be less than `buffered`.
    allocated: usize,
    /// Number of bytes read by the application. When only ordered reads have been used, this is the
    /// length of the contiguous prefix of the stream which has been consumed by the application,
    /// aka the stream offset.
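    ///
    /// For example (illustrative): after ordered reads have consumed 7 bytes,
    /// `bytes_read` is 7, and `insert` trims the first two bytes of a chunk
    /// arriving at offset 5 before buffering it.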
    bytes_read: u64,
    end: u64,
}

impl Assembler {
    pub(super) fn new() -> Self {
        Self::default()
    }

    pub(super) fn ensure_ordering(&mut self, ordered: bool) -> Result<(), IllegalOrderedRead> {
        if ordered && !self.state.is_ordered() {
            return Err(IllegalOrderedRead);
        } else if !ordered && self.state.is_ordered() {
            // Enter unordered mode
            if !self.data.is_empty() {
                // Get rid of possible duplicates
                self.defragment();
            }
            let mut recvd = RangeSet::new();
            recvd.insert(0..self.bytes_read);
            for chunk in &self.data {
                recvd.insert(chunk.offset..chunk.offset + chunk.bytes.len() as u64);
            }
            self.state = State::Unordered { recvd };
        }
        Ok(())
    }

    /// Get the next chunk
    pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> Option<Chunk> {
        loop {
            let mut chunk = self.data.peek_mut()?;

            if ordered {
                if chunk.offset > self.bytes_read {
                    // Next chunk is after current read index
                    return None;
                } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read {
                    // Next chunk is useless as the read index is beyond its end
                    self.buffered -= chunk.bytes.len();
                    self.allocated -= chunk.allocation_size;
                    PeekMut::pop(chunk);
                    continue;
                }

                // Determine `start` and `len` of the slice of useful data in chunk
                let start = (self.bytes_read - chunk.offset) as usize;
                if start > 0 {
                    chunk.bytes.advance(start);
                    chunk.offset += start as u64;
                    self.buffered -= start;
                }
            }

            return Some(if max_length < chunk.bytes.len() {
                self.bytes_read += max_length as u64;
                let offset = chunk.offset;
                chunk.offset += max_length as u64;
                self.buffered -= max_length;
                Chunk::new(offset, chunk.bytes.split_to(max_length))
            } else {
                self.bytes_read += chunk.bytes.len() as u64;
                self.buffered -= chunk.bytes.len();
                self.allocated -= chunk.allocation_size;
                let chunk = PeekMut::pop(chunk);
                Chunk::new(chunk.offset, chunk.bytes)
            });
        }
    }

    /// Copy fragmented chunk data to new chunks backed by a single buffer
    ///
    /// This makes sure we're not unnecessarily holding on to many larger allocations.
    /// We merge contiguous chunks in the process of doing so.
    fn defragment(&mut self) {
        let new = BinaryHeap::with_capacity(self.data.len());
        let old = mem::replace(&mut self.data, new);
        let mut buffers = old.into_sorted_vec();
        self.buffered = 0;
        let mut fragmented_buffered = 0;
        let mut offset = 0;
        for chunk in buffers.iter_mut().rev() {
            chunk.try_mark_defragment(offset);
            let size = chunk.bytes.len();
            offset = chunk.offset + size as u64;
            self.buffered += size;
            if !chunk.defragmented {
                fragmented_buffered += size;
            }
        }
        self.allocated = self.buffered;
        let mut buffer = BytesMut::with_capacity(fragmented_buffered);
        let mut offset = 0;
        for chunk in buffers.into_iter().rev() {
            if chunk.defragmented {
                // bytes might be empty after try_mark_defragment
                if !chunk.bytes.is_empty() {
                    self.data.push(chunk);
                }
                continue;
            }
            // Overlap is resolved by try_mark_defragment
            if chunk.offset != offset + (buffer.len() as u64) {
                if !buffer.is_empty() {
                    self.data
                        .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
                }
                offset = chunk.offset;
            }
            buffer.extend_from_slice(&chunk.bytes);
        }
        if !buffer.is_empty() {
            self.data
                .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
        }
    }

    // Note: If a packet contains many frames from the same stream, the estimated over-allocation
    // will be much higher because we are counting the same allocation multiple times.
    pub(super) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) {
        debug_assert!(
            bytes.len() <= allocation_size,
            "allocation_size less than bytes.len(): {:?} < {:?}",
            allocation_size,
            bytes.len()
        );
        self.end = self.end.max(offset + bytes.len() as u64);
        if let State::Unordered { ref mut recvd } = self.state {
            // Discard duplicate data
            for duplicate in recvd.replace(offset..offset + bytes.len() as u64) {
                if duplicate.start > offset {
                    let buffer = Buffer::new(
                        offset,
                        bytes.split_to((duplicate.start - offset) as usize),
                        allocation_size,
                    );
                    self.buffered += buffer.bytes.len();
                    self.allocated += buffer.allocation_size;
                    self.data.push(buffer);
                    offset = duplicate.start;
                }
                bytes.advance((duplicate.end - offset) as usize);
                offset = duplicate.end;
            }
        } else if offset < self.bytes_read {
            if (offset + bytes.len() as u64) <= self.bytes_read {
                return;
            } else {
                let diff = self.bytes_read - offset;
                offset += diff;
                bytes.advance(diff as usize);
            }
        }

        if bytes.is_empty() {
            return;
        }
        let buffer = Buffer::new(offset, bytes, allocation_size);
        self.buffered += buffer.bytes.len();
        self.allocated += buffer.allocation_size;
        self.data.push(buffer);

        // `self.buffered` also counts duplicate bytes, therefore we use
        // `self.end - self.bytes_read` as an upper bound of buffered unique
        // bytes. This will cause a defragmentation if the amount of duplicate
        // bytes exceeds a proportion of the receive window size.
        let buffered = self.buffered.min((self.end - self.bytes_read) as usize);
        let over_allocation = self.allocated - buffered;
        // Rationale: on the one hand, we want to defragment rarely, ideally never
        // in non-pathological scenarios. However, a pathological or malicious
        // peer could send us one-byte frames, and since we use reference-counted
        // buffers in order to prevent copying, this could result in keeping a lot
        // of memory allocated. This limits over-allocation in proportion to the
        // buffered data. The constants are chosen somewhat arbitrarily and try to
        // balance between defragmentation overhead and over-allocation.
        let threshold = 32768.max(buffered * 3 / 2);
        if over_allocation > threshold {
            self.defragment()
        }
    }

    pub(super) fn set_bytes_read(&mut self, new: u64) {
        self.bytes_read = new;
    }

    /// Number of bytes consumed by the application
    pub(super) fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Discard all buffered data
    pub(super) fn clear(&mut self) {
        self.data.clear();
        self.buffered = 0;
        self.allocated = 0;
    }
}

/// A chunk of data from the receive stream
#[derive(Debug, PartialEq, Eq)]
pub struct Chunk {
    /// The offset in the stream
    pub offset: u64,
    /// The contents of the chunk
    pub bytes: Bytes,
}

impl Chunk {
    fn new(offset: u64, bytes: Bytes) -> Self {
        Self { offset, bytes }
    }
}

#[derive(Debug, Eq)]
struct Buffer {
    offset: u64,
    bytes: Bytes,
    /// Size of the allocation behind `bytes`, if `defragmented == false`.
    /// Otherwise this will be set to `bytes.len()` by `try_mark_defragment`.
    /// Will never be less than `bytes.len()`.
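    ///
    /// E.g. (illustrative): a 3-byte fragment sliced out of a 1200-byte receive
    /// buffer keeps `allocation_size = 1200`, which is what lets `Assembler::insert`
    /// detect pathological over-allocation and trigger `defragment`.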
allocation_size: usize, defragmented: bool, } impl Buffer { /// Constructs a new fragmented Buffer fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self { Self { offset, bytes, allocation_size, defragmented: false, } } /// Constructs a new defragmented Buffer fn new_defragmented(offset: u64, bytes: Bytes) -> Self { let allocation_size = bytes.len(); Self { offset, bytes, allocation_size, defragmented: true, } } /// Discards data before `offset` and flags `self` as defragmented if it has good utilization fn try_mark_defragment(&mut self, offset: u64) { let duplicate = offset.saturating_sub(self.offset) as usize; self.offset = self.offset.max(offset); if duplicate >= self.bytes.len() { // All bytes are duplicate self.bytes = Bytes::new(); self.defragmented = true; self.allocation_size = 0; return; } self.bytes.advance(duplicate); // Make sure that fragmented buffers with high utilization become defragmented and // defragmented buffers remain defragmented self.defragmented = self.defragmented || self.bytes.len() * 6 / 5 >= self.allocation_size; if self.defragmented { // Make sure that defragmented buffers do not contribute to over-allocation self.allocation_size = self.bytes.len(); } } } impl Ord for Buffer { // Invert ordering based on offset (max-heap, min offset first), // prioritize longer chunks at the same offset. fn cmp(&self, other: &Self) -> Ordering { self.offset .cmp(&other.offset) .reverse() .then(self.bytes.len().cmp(&other.bytes.len())) } } impl PartialOrd for Buffer { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl PartialEq for Buffer { fn eq(&self, other: &Self) -> bool { (self.offset, self.bytes.len()) == (other.offset, other.bytes.len()) } } #[derive(Debug)] enum State { Ordered, Unordered { /// The set of offsets that have been received from the peer, including portions not yet /// read by the application. recvd: RangeSet, }, } impl State { fn is_ordered(&self) -> bool { matches!(self, Self::Ordered) } } impl Default for State { fn default() -> Self { Self::Ordered } } /// Error indicating that an ordered read was performed on a stream after an unordered read #[derive(Debug)] pub struct IllegalOrderedRead; #[cfg(test)] mod test { use super::*; use assert_matches::assert_matches; #[test] fn assemble_ordered() { let mut x = Assembler::new(); assert_matches!(next(&mut x, 32), None); x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1"); assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23"); x.insert(3, Bytes::from_static(b"456"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456"); x.insert(6, Bytes::from_static(b"789"), 3); x.insert(9, Bytes::from_static(b"10"), 2); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"789"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_unordered() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); x.insert(3, Bytes::from_static(b"456"), 3); assert_matches!(next(&mut x, 32), None); x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_duplicate() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 3); x.insert(0, Bytes::from_static(b"123"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"123"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_duplicate_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 3); x.insert(0, Bytes::from_static(b"123"), 3); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contained() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"12345"), 5); x.insert(1, Bytes::from_static(b"234"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contained_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"12345"), 5); x.insert(1, Bytes::from_static(b"234"), 3); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contains() { let mut x = Assembler::new(); x.insert(1, Bytes::from_static(b"234"), 3); x.insert(0, Bytes::from_static(b"12345"), 5); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_contains_compact() { let mut x = Assembler::new(); x.insert(1, Bytes::from_static(b"234"), 3); x.insert(0, Bytes::from_static(b"12345"), 5); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_overlapping() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 3); x.insert(1, Bytes::from_static(b"234"), 3); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123"); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"4"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_overlapping_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"123"), 4); x.insert(1, Bytes::from_static(b"234"), 4); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_complex() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"1"), 1); x.insert(2, Bytes::from_static(b"3"), 1); x.insert(4, Bytes::from_static(b"5"), 1); x.insert(0, Bytes::from_static(b"123456"), 6); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_complex_compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"1"), 1); x.insert(2, Bytes::from_static(b"3"), 1); x.insert(4, Bytes::from_static(b"5"), 1); x.insert(0, Bytes::from_static(b"123456"), 6); x.defragment(); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456"); assert_matches!(next(&mut x, 32), None); } #[test] fn assemble_old() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"1234"), 4); assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] 
== b"1234"); x.insert(0, Bytes::from_static(b"1234"), 4); assert_matches!(next(&mut x, 32), None); } #[test] fn compact() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"abc"), 4); x.insert(3, Bytes::from_static(b"def"), 4); x.insert(9, Bytes::from_static(b"jkl"), 4); x.insert(12, Bytes::from_static(b"mno"), 4); x.defragment(); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"abcdef")) ); assert_eq!( next_unordered(&mut x), Chunk::new(9, Bytes::from_static(b"jklmno")) ); } #[test] fn defrag_with_missing_prefix() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def"), 3); x.defragment(); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); } #[test] fn defrag_read_chunk() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def"), 4); x.insert(0, Bytes::from_static(b"abc"), 4); x.insert(7, Bytes::from_static(b"hij"), 4); x.insert(11, Bytes::from_static(b"lmn"), 4); x.defragment(); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef"); x.insert(5, Bytes::from_static(b"fghijklmn"), 9); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn"); x.insert(13, Bytes::from_static(b"nopq"), 4); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq"); x.insert(15, Bytes::from_static(b"pqrs"), 4); assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"rs"); assert_matches!(x.read(usize::MAX, true), None); } #[test] fn unordered_happy_path() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); x.insert(0, Bytes::from_static(b"abc"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"abc")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); assert_eq!(x.read(usize::MAX, false), None); } #[test] fn unordered_dedup() { let mut x = Assembler::new(); x.ensure_ordering(false).unwrap(); x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(3, Bytes::from_static(b"def")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(0, Bytes::from_static(b"a"), 1); x.insert(0, Bytes::from_static(b"abcdefghi"), 9); x.insert(0, Bytes::from_static(b"abcd"), 4); assert_eq!( next_unordered(&mut x), Chunk::new(0, Bytes::from_static(b"a")) ); assert_eq!( next_unordered(&mut x), Chunk::new(1, Bytes::from_static(b"bc")) ); assert_eq!( next_unordered(&mut x), Chunk::new(6, Bytes::from_static(b"ghi")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(8, Bytes::from_static(b"ijkl"), 4); assert_eq!( next_unordered(&mut x), Chunk::new(9, Bytes::from_static(b"jkl")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(12, Bytes::from_static(b"mno"), 3); assert_eq!( next_unordered(&mut x), Chunk::new(12, Bytes::from_static(b"mno")) ); assert_eq!(x.read(usize::MAX, false), None); x.insert(2, Bytes::from_static(b"cde"), 3); assert_eq!(x.read(usize::MAX, false), None); } #[test] fn chunks_dedup() { let mut x = Assembler::new(); x.insert(3, Bytes::from_static(b"def"), 3); assert_eq!(x.read(usize::MAX, true), None); x.insert(0, Bytes::from_static(b"a"), 1); x.insert(1, Bytes::from_static(b"bcdefghi"), 9); x.insert(0, Bytes::from_static(b"abcd"), 4); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abcd"))) ); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(4, Bytes::from_static(b"efghi"))) ); 
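        // (Added note: offsets 0..9 are consumed at this point, so another ordered
        // read yields None, and the insert at offset 8 below is trimmed so that only
        // the new suffix "jkl" at offset 9 is surfaced.)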
assert_eq!(x.read(usize::MAX, true), None); x.insert(8, Bytes::from_static(b"ijkl"), 4); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(9, Bytes::from_static(b"jkl"))) ); assert_eq!(x.read(usize::MAX, true), None); x.insert(12, Bytes::from_static(b"mno"), 3); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(12, Bytes::from_static(b"mno"))) ); assert_eq!(x.read(usize::MAX, true), None); x.insert(2, Bytes::from_static(b"cde"), 3); assert_eq!(x.read(usize::MAX, true), None); } #[test] fn ordered_eager_discard() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"abc"), 3); assert_eq!(x.data.len(), 1); assert_eq!( x.read(usize::MAX, true), Some(Chunk::new(0, Bytes::from_static(b"abc"))) ); x.insert(0, Bytes::from_static(b"ab"), 2); assert_eq!(x.data.len(), 0); x.insert(2, Bytes::from_static(b"cd"), 2); assert_eq!( x.data.peek(), Some(&Buffer::new(3, Bytes::from_static(b"d"), 2)) ); } #[test] fn ordered_insert_unordered_read() { let mut x = Assembler::new(); x.insert(0, Bytes::from_static(b"abc"), 3); x.insert(0, Bytes::from_static(b"abc"), 3); x.ensure_ordering(false).unwrap(); assert_eq!( x.read(3, false), Some(Chunk::new(0, Bytes::from_static(b"abc"))) ); assert_eq!(x.read(3, false), None); } fn next_unordered(x: &mut Assembler) -> Chunk { x.read(usize::MAX, false).unwrap() } fn next(x: &mut Assembler, size: usize) -> Option { x.read(size, true).map(|chunk| chunk.bytes) } } quinn-proto-0.10.6/src/connection/cid_state.rs000064400000000000000000000203731046102023000174430ustar 00000000000000//! Maintain the state of local connection IDs use std::{ collections::VecDeque, time::{Duration, Instant}, }; use rustc_hash::FxHashSet; use tracing::{debug, trace}; use crate::{shared::IssuedCid, TransportError}; /// Local connection ID management pub(super) struct CidState { /// Timestamp when issued cids should be retired retire_timestamp: VecDeque, /// Number of local connection IDs that have been issued in NEW_CONNECTION_ID frames. 
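    ///
    /// Starts at 1, since the connection ID supplied during the handshake counts
    /// as issued (see `CidState::new`).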
    issued: u64,
    /// Sequence numbers of local connection IDs not yet retired by the peer
    active_seq: FxHashSet<u64>,
    /// Sequence number the peer has already retired all CIDs below at our request via `retire_prior_to`
    prev_retire_seq: u64,
    /// Sequence number to set in retire_prior_to field in NEW_CONNECTION_ID frame
    retire_seq: u64,
    /// cid length used to decode short packet
    cid_len: usize,
    /// cid lifetime
    cid_lifetime: Option<Duration>,
}

impl CidState {
    pub(crate) fn new(cid_len: usize, cid_lifetime: Option<Duration>, now: Instant) -> Self {
        let mut active_seq = FxHashSet::default();
        // Add sequence number of CID used in handshaking into tracking set
        active_seq.insert(0);
        let mut this = Self {
            retire_timestamp: VecDeque::new(),
            issued: 1, // One CID is already supplied during handshaking
            active_seq,
            prev_retire_seq: 0,
            retire_seq: 0,
            cid_len,
            cid_lifetime,
        };
        // Track lifetime of cid used in handshaking
        this.track_lifetime(0, now);
        this
    }

    /// Find the next timestamp when previously issued CID should be retired
    pub(crate) fn next_timeout(&mut self) -> Option<Instant> {
        self.retire_timestamp.front().map(|nc| {
            trace!("CID {} will expire at {:?}", nc.sequence, nc.timestamp);
            nc.timestamp
        })
    }

    /// Track the lifetime of issued cids in `retire_timestamp`
    fn track_lifetime(&mut self, new_cid_seq: u64, now: Instant) {
        let lifetime = match self.cid_lifetime {
            Some(lifetime) => lifetime,
            None => return,
        };
        let expire_timestamp = now.checked_add(lifetime);
        let expire_at = match expire_timestamp {
            Some(expire_at) => expire_at,
            None => return,
        };
        let last_record = self.retire_timestamp.back_mut();
        if let Some(last) = last_record {
            // Compare the timestamp with the last inserted record
            // Combine into a single batch if timestamp of current cid is same as the last record
            if expire_at == last.timestamp {
                debug_assert!(new_cid_seq > last.sequence);
                last.sequence = new_cid_seq;
                return;
            }
        }
        self.retire_timestamp.push_back(CidTimestamp {
            sequence: new_cid_seq,
            timestamp: expire_at,
        });
    }

    /// Update local CID state when a previously issued CID expires
    ///
    /// Returns whether a new NEW_CONNECTION_ID frame carrying an updated `retire_prior_to` must
    /// be sent, prompting the remote peer to respond with `RETIRE_CONNECTION_ID`
    pub(crate) fn on_cid_timeout(&mut self) -> bool {
        // Whether the peer hasn't retired all the CIDs we asked it to yet
        let unretired_ids_found =
            (self.prev_retire_seq..self.retire_seq).any(|seq| self.active_seq.contains(&seq));
        let current_retire_prior_to = self.retire_seq;
        let next_retire_sequence = self
            .retire_timestamp
            .pop_front()
            .map(|seq| seq.sequence + 1);
        // According to RFC:
        // Endpoints SHOULD NOT issue updates of the Retire Prior To field
        // before receiving RETIRE_CONNECTION_ID frames that retire all
        // connection IDs indicated by the previous Retire Prior To value.
        // https://tools.ietf.org/html/draft-ietf-quic-transport-29#section-5.1.2
        if !unretired_ids_found {
            // All Cids are retired, `prev_retire_cid_seq` can be assigned to `retire_cid_seq`
            self.prev_retire_seq = self.retire_seq;
            // Advance `retire_seq` if next cid that needs to be retired exists
            if let Some(next_retire_prior_to) = next_retire_sequence {
                self.retire_seq = next_retire_prior_to;
            }
        }
        // Check if retirement of all CIDs that reach their lifetime is still needed
        // According to RFC:
        // An endpoint MUST NOT
        // provide more connection IDs than the peer's limit. An endpoint MAY
        // send connection IDs that temporarily exceed a peer's limit if the
        // NEW_CONNECTION_ID frame also requires the retirement of any excess,
        // by including a sufficiently large value in the Retire Prior To field.
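        //
        // (Illustrative, assumed numbers: with `prev_retire_seq..retire_seq` = 2..5
        // and sequence 3 still in `active_seq`, the peer hasn't finished retiring;
        // `retire_seq` then stays put and the check below comes out false, deferring
        // the Retire Prior To update until the peer catches up.)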
        //
        // If yes (return true), a new CID must be pushed with an updated `retire_prior_to` field
        // to the remote peer.
        // If no (return false), it means CIDs that reach the end of their lifetime have already
        // been retired. Do not push a new CID, in order to avoid violating the above RFC.
        (current_retire_prior_to..self.retire_seq).any(|seq| self.active_seq.contains(&seq))
    }

    /// Update cid state when `NewIdentifiers` event is received
    pub(crate) fn new_cids(&mut self, ids: &[IssuedCid], now: Instant) {
        // `ids` could be empty once active_connection_id_limit is set to 1 by the peer
        let last_cid = match ids.last() {
            Some(cid) => cid,
            None => return,
        };
        self.issued += ids.len() as u64;
        // Record the timestamp of CID with the largest seq number
        let sequence = last_cid.sequence;
        ids.iter().for_each(|frame| {
            self.active_seq.insert(frame.sequence);
        });
        self.track_lifetime(sequence, now);
    }

    /// Update CidState for receipt of a `RETIRE_CONNECTION_ID` frame
    ///
    /// Returns whether a new CID can be issued, or an error if the frame was illegal.
    pub(crate) fn on_cid_retirement(
        &mut self,
        sequence: u64,
        limit: u64,
    ) -> Result<bool, TransportError> {
        if self.cid_len == 0 {
            return Err(TransportError::PROTOCOL_VIOLATION(
                "RETIRE_CONNECTION_ID when CIDs aren't in use",
            ));
        }
        if sequence > self.issued {
            debug!(
                sequence,
                "got RETIRE_CONNECTION_ID for unissued sequence number"
            );
            return Err(TransportError::PROTOCOL_VIOLATION(
                "RETIRE_CONNECTION_ID for unissued sequence number",
            ));
        }
        self.active_seq.remove(&sequence);
        // Consider a scenario where peer A has active remote cid 0,1,2.
        // Peer B first sends a NEW_CONNECTION_ID with cid 3 and retire_prior_to set to 1.
        // Peer A processes this NEW_CONNECTION_ID frame; it updates its remote cids to 1,2,3
        // and meanwhile sends a RETIRE_CONNECTION_ID to retire cid 0 to peer B.
        // If peer B doesn't check the cid limit here and sends a new cid again, peer A will then
        // face CONNECTION_ID_LIMIT_ERROR.
        Ok(limit > self.active_seq.len() as u64)
    }

    /// Length of local Connection IDs
    pub(crate) fn cid_len(&self) -> usize {
        self.cid_len
    }

    /// The value for `retire_prior_to` field in `NEW_CONNECTION_ID` frame
    pub(crate) fn retire_prior_to(&self) -> u64 {
        self.retire_seq
    }

    #[cfg(test)]
    pub(crate) fn active_seq(&self) -> (u64, u64) {
        let mut min = u64::MAX;
        let mut max = u64::MIN;
        for n in self.active_seq.iter() {
            if n < &min {
                min = *n;
            }
            if n > &max {
                max = *n;
            }
        }
        (min, max)
    }

    #[cfg(test)]
    pub(crate) fn assign_retire_seq(&mut self, v: u64) -> u64 {
        // Cannot retire more CIDs than what have been issued
        debug_assert!(v <= *self.active_seq.iter().max().unwrap() + 1);
        let n = v.checked_sub(self.retire_seq).unwrap();
        self.retire_seq = v;
        n
    }
}

/// Data structure that records when issued cids should be retired
#[derive(Copy, Clone, Eq, PartialEq)]
struct CidTimestamp {
    /// Highest cid sequence number created in a batch
    sequence: u64,
    /// Timestamp when cid needs to be retired
    timestamp: Instant,
}
quinn-proto-0.10.6/src/connection/datagrams.rs000064400000000000000000000134701046102023000174470ustar 00000000000000use std::collections::VecDeque;

use bytes::{Bytes, BytesMut};
use thiserror::Error;
use tracing::{debug, trace};

use super::Connection;
use crate::{
    frame::{Datagram, FrameStruct},
    packet::SpaceId,
    TransportError,
};

/// API to control datagram traffic
pub struct Datagrams<'a> {
    pub(super) conn: &'a mut Connection,
}

impl<'a> Datagrams<'a> {
    /// Queue an unreliable, unordered datagram for immediate transmission
    ///
    /// Returns `Err` iff a `len`-byte datagram cannot currently be sent
    pub fn send(&mut self, data: Bytes) -> Result<(), SendDatagramError> {
        if self.conn.config.datagram_receive_buffer_size.is_none() {
            return Err(SendDatagramError::Disabled);
        }
        let max = self
            .max_size()
            .ok_or(SendDatagramError::UnsupportedByPeer)?;
        while self.conn.datagrams.outgoing_total > self.conn.config.datagram_send_buffer_size {
            let prev = self
                .conn
                .datagrams
                .outgoing
                .pop_front()
                .expect("datagrams.outgoing_total desynchronized");
            trace!(len = prev.data.len(), "dropping outgoing datagram");
            self.conn.datagrams.outgoing_total -= prev.data.len();
        }
        if data.len() > max {
            return Err(SendDatagramError::TooLarge);
        }
        self.conn.datagrams.outgoing_total += data.len();
        self.conn.datagrams.outgoing.push_back(Datagram { data });
        Ok(())
    }

    /// Compute the maximum size of datagrams that may be passed to [`send`](Self::send)
    ///
    /// Returns `None` if datagrams are unsupported by the peer or disabled locally.
    ///
    /// This may change over the lifetime of a connection according to variation in the path MTU
    /// estimate. The peer can also enforce an arbitrarily small fixed limit, but if the peer's
    /// limit is large this is guaranteed to be a little over a kilobyte at minimum.
    ///
    /// Not necessarily the maximum size of received datagrams.
    pub fn max_size(&self) -> Option<usize> {
        let max_size = self.conn.path.current_mtu() as usize
            - 1 // flags byte
            - self.conn.rem_cids.active().len()
            - 4 // worst-case packet number size
            - self
                .conn
                .spaces[SpaceId::Data]
                .crypto
                .as_ref()
                .map_or_else(
                    || &self.conn.zero_rtt_crypto.as_ref().unwrap().packet,
                    |x| &x.packet.local,
                )
                .tag_len()
            - Datagram::SIZE_BOUND;
        let limit = self
            .conn
            .peer_params
            .max_datagram_frame_size?
            .into_inner()
            .saturating_sub(Datagram::SIZE_BOUND as u64);
        Some(limit.min(max_size as u64) as usize)
    }

    /// Receive an unreliable, unordered datagram
    pub fn recv(&mut self) -> Option<Bytes> {
        self.conn.datagrams.recv()
    }

    /// Bytes available in the outgoing datagram buffer
    ///
    /// When greater than zero, [`send`](Self::send)ing a datagram of at most this size is
    /// guaranteed not to cause older datagrams to be dropped.
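    ///
    /// A usage sketch (added example; assumes the crate-root `Datagrams` re-export):
    ///
    /// ```no_run
    /// # fn sketch(datagrams: &mut quinn_proto::Datagrams<'_>) {
    /// use bytes::Bytes;
    /// let payload = Bytes::from_static(b"ping");
    /// // Only queue when doing so cannot evict older queued datagrams.
    /// if datagrams.send_buffer_space() >= payload.len() {
    ///     let _ = datagrams.send(payload);
    /// }
    /// # }
    /// ```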
pub fn send_buffer_space(&self) -> usize { self.conn .config .datagram_send_buffer_size .saturating_sub(self.conn.datagrams.outgoing_total) } } #[derive(Default)] pub(super) struct DatagramState { /// Number of bytes of datagrams that have been received by the local transport but not /// delivered to the application pub(super) recv_buffered: usize, pub(super) incoming: VecDeque, pub(super) outgoing: VecDeque, pub(super) outgoing_total: usize, } impl DatagramState { pub(super) fn received( &mut self, datagram: Datagram, window: &Option, ) -> Result { let window = match window { None => { return Err(TransportError::PROTOCOL_VIOLATION( "unexpected DATAGRAM frame", )); } Some(x) => *x, }; if datagram.data.len() > window { return Err(TransportError::PROTOCOL_VIOLATION("oversized datagram")); } let was_empty = self.recv_buffered == 0; while datagram.data.len() + self.recv_buffered > window { debug!("dropping stale datagram"); self.recv(); } self.recv_buffered += datagram.data.len(); self.incoming.push_back(datagram); Ok(was_empty) } pub(super) fn write(&mut self, buf: &mut BytesMut, max_size: usize) -> bool { let datagram = match self.outgoing.pop_front() { Some(x) => x, None => return false, }; if buf.len() + datagram.size(true) > max_size { // Future work: we could be more clever about cramming small datagrams into // mostly-full packets when a larger one is queued first self.outgoing.push_front(datagram); return false; } self.outgoing_total -= datagram.data.len(); datagram.encode(true, buf); true } pub(super) fn recv(&mut self) -> Option { let x = self.incoming.pop_front()?.data; self.recv_buffered -= x.len(); Some(x) } } /// Errors that can arise when sending a datagram #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum SendDatagramError { /// The peer does not support receiving datagram frames #[error("datagrams not supported by peer")] UnsupportedByPeer, /// Datagram support is disabled locally #[error("datagram support disabled")] Disabled, /// The datagram is larger than the connection can currently accommodate /// /// Indicates that the path MTU minus overhead or the limit advertised by the peer has been /// exceeded. 
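    ///
    /// Callers can typically avoid this by checking [`Datagrams::max_size`] first
    /// and splitting payloads at the application layer (a suggestion, not part of
    /// the original docs).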
#[error("datagram too large")] TooLarge, } quinn-proto-0.10.6/src/connection/mod.rs000064400000000000000000004121611046102023000162630ustar 00000000000000use std::{ cmp, collections::VecDeque, convert::TryFrom, fmt, io, mem, net::{IpAddr, SocketAddr}, sync::Arc, time::{Duration, Instant}, }; use bytes::{Bytes, BytesMut}; use frame::StreamMetaVec; use rand::{rngs::StdRng, Rng, SeedableRng}; use thiserror::Error; use tracing::{debug, error, trace, trace_span, warn}; use crate::{ cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, HeaderKey, KeyPair, Keys, PacketKey}, frame, frame::{Close, Datagram, FrameStruct}, packet::{Header, LongType, Packet, PartialDecode, SpaceId}, range_set::ArrayRangeSet, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent, EndpointEventInner, }, token::ResetToken, transport_parameters::TransportParameters, Dir, EndpointConfig, Frame, Side, StreamId, Transmit, TransportError, TransportErrorCode, VarInt, MAX_STREAM_COUNT, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, TIMER_GRANULARITY, }; mod assembler; pub use assembler::Chunk; mod cid_state; use cid_state::CidState; mod datagrams; use datagrams::DatagramState; pub use datagrams::{Datagrams, SendDatagramError}; mod mtud; mod pacing; mod packet_builder; use packet_builder::PacketBuilder; mod paths; use paths::PathData; pub use paths::RttEstimator; mod send_buffer; mod spaces; #[cfg(fuzzing)] pub use spaces::Retransmits; #[cfg(not(fuzzing))] use spaces::Retransmits; use spaces::{PacketSpace, SendableFrames, SentPacket, ThinRetransmits}; mod stats; pub use stats::{ConnectionStats, FrameStats, PathStats, UdpStats}; mod streams; #[cfg(fuzzing)] pub use streams::StreamsState; #[cfg(not(fuzzing))] use streams::StreamsState; //pub(crate) use streams::{ByteSlice, BytesArray}; pub use streams::{ BytesSource, Chunks, FinishError, ReadError, ReadableError, RecvStream, SendStream, StreamEvent, Streams, UnknownStream, WriteError, Written, }; mod timer; use crate::congestion::Controller; use timer::{Timer, TimerTable}; /// Protocol state and logic for a single QUIC connection /// /// Objects of this type receive [`ConnectionEvent`]s and emit [`EndpointEvent`]s and application /// [`Event`]s to make progress. To handle timeouts, a `Connection` returns timer updates and /// expects timeouts through various methods. A number of simple getter methods are exposed /// to allow callers to inspect some of the connection state. /// /// `Connection` has roughly 4 types of methods: /// /// - A. Simple getters, taking `&self` /// - B. Handlers for incoming events from the network or system, named `handle_*`. /// - C. State machine mutators, for incoming commands from the application. For convenience we /// refer to this as "performing I/O" below, however as per the design of this library none of the /// functions actually perform system-level I/O. For example, [`read`](RecvStream::read) and /// [`write`](SendStream::write), but also things like [`reset`](SendStream::reset). /// - D. Polling functions for outgoing events or actions for the caller to /// take, named `poll_*`. /// /// The simplest way to use this API correctly is to call (B) and (C) whenever /// appropriate, then after each of those calls, as soon as feasible call all /// polling methods (D) and deal with their outputs appropriately, e.g. by /// passing it to the application or by making a system-level I/O call. 
You /// should call the polling functions in this order: /// /// 1. [`poll_transmit`](Self::poll_transmit) /// 2. [`poll_timeout`](Self::poll_timeout) /// 3. [`poll_endpoint_events`](Self::poll_endpoint_events) /// 4. [`poll`](Self::poll) /// /// Currently the only actual dependency is from (2) to (1), however additional /// dependencies may be added in future, so the above order is recommended. /// /// (A) may be called whenever desired. /// /// Care should be taken to ensure that the input events represent monotonically /// increasing time. Specifically, calling [`handle_timeout`](Self::handle_timeout) /// with events of the same [`Instant`] may be interleaved in any order with a /// call to [`handle_event`](Self::handle_event) at that same instant; however /// events or timeouts with different instants must not be interleaved. pub struct Connection { endpoint_config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>, config: Arc<TransportConfig>, rng: StdRng, crypto: Box<dyn crypto::Session>, /// The CID we initially chose, for use during the handshake handshake_cid: ConnectionId, /// The CID the peer initially chose, for use during the handshake rem_handshake_cid: ConnectionId, /// The "real" local IP address which was used to receive the initial packet. /// This is only populated for the server case, and if known local_ip: Option<IpAddr>, path: PathData, prev_path: Option<PathData>, state: State, side: Side, /// Whether or not 0-RTT was enabled during the handshake. Does not imply acceptance. zero_rtt_enabled: bool, /// Set if 0-RTT is supported, then cleared when no longer needed. zero_rtt_crypto: Option<ZeroRttCrypto>, key_phase: bool, /// Transport parameters set by the peer peer_params: TransportParameters, /// Source ConnectionId of the first packet received from the peer orig_rem_cid: ConnectionId, /// Destination ConnectionId sent by the client on the first Initial initial_dst_cid: ConnectionId, /// The value that the server included in the Source Connection ID field of a Retry packet, if /// one was received retry_src_cid: Option<ConnectionId>, /// Total number of outgoing packets that have been deemed lost lost_packets: u64, events: VecDeque<Event>, endpoint_events: VecDeque<EndpointEventInner>, /// Whether the spin bit is in use for this connection spin_enabled: bool, /// Outgoing spin bit state spin: bool, /// Packet number spaces: initial, handshake, 1-RTT spaces: [PacketSpace; 3], /// Highest usable packet number space highest_space: SpaceId, /// 1-RTT keys used prior to a key update prev_crypto: Option<PrevCrypto>, /// 1-RTT keys to be used for the next key update /// /// These are generated in advance to prevent timing attacks and/or DoS by third-party attackers /// spoofing key updates. next_crypto: Option<KeyPair<Box<dyn PacketKey>>>, accepted_0rtt: bool, /// Whether the idle timer should be reset the next time an ack-eliciting packet is transmitted. permit_idle_reset: bool, /// Negotiated idle timeout idle_timeout: Option<VarInt>, timers: TimerTable, /// Number of packets received which could not be authenticated authentication_failures: u64, /// Why the connection was lost, if it has been error: Option<ConnectionError>, /// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are /// discarded. retry_token: Bytes, // // Queued non-retransmittable 1-RTT data // path_response: Option<PathResponse>, close: bool, // // Loss Detection // /// The number of times a PTO has been sent without receiving an ack.
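///
/// Drives exponential PTO backoff (RFC 9002 §6.2); conceptually, as computed
/// in `pto_time_and_space`:
///
/// ```ignore
/// let backoff = 2u32.pow(pto_count.min(MAX_BACKOFF_EXPONENT));
/// let timeout = rtt.pto_base() * backoff; // plus max_ack_delay * backoff in Data space
/// ```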
pto_count: u32, // // Congestion Control // /// Summary statistics of packets that have been sent, but not yet acked or deemed lost in_flight: InFlight, /// Whether the most recently received packet had an ECN codepoint set receiving_ecn: bool, /// Number of packets authenticated total_authed_packets: u64, /// Whether the last `poll_transmit` call yielded no data because there was /// no outgoing application data. app_limited: bool, streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, // Attributes of CIDs generated by local peer local_cid_state: CidState, /// State of the unreliable datagram extension datagrams: DatagramState, /// Connection level statistics stats: ConnectionStats, /// QUIC version used for the connection. version: u32, } impl Connection { pub(crate) fn new( endpoint_config: Arc, server_config: Option>, config: Arc, init_cid: ConnectionId, loc_cid: ConnectionId, rem_cid: ConnectionId, remote: SocketAddr, local_ip: Option, crypto: Box, cid_gen: &dyn ConnectionIdGenerator, now: Instant, version: u32, allow_mtud: bool, ) -> Self { let side = if server_config.is_some() { Side::Server } else { Side::Client }; let initial_space = PacketSpace { crypto: Some(crypto.initial_keys(&init_cid, side)), ..PacketSpace::new(now) }; let state = State::Handshake(state::Handshake { rem_cid_set: side.is_server(), expected_token: Bytes::new(), client_hello: None, }); let mut rng = StdRng::from_entropy(); let path_validated = server_config.as_ref().map_or(true, |c| c.use_retry); let mut this = Self { endpoint_config, server_config, crypto, handshake_cid: loc_cid, rem_handshake_cid: rem_cid, local_cid_state: CidState::new(cid_gen.cid_len(), cid_gen.cid_lifetime(), now), path: PathData::new( remote, config.initial_rtt, config .congestion_controller_factory .build(now, config.get_initial_mtu()), config.get_initial_mtu(), config.min_mtu, None, match allow_mtud { true => config.mtu_discovery_config.clone(), false => None, }, now, path_validated, ), local_ip, prev_path: None, side, state, zero_rtt_enabled: false, zero_rtt_crypto: None, key_phase: false, peer_params: TransportParameters::default(), orig_rem_cid: rem_cid, initial_dst_cid: init_cid, retry_src_cid: None, lost_packets: 0, events: VecDeque::new(), endpoint_events: VecDeque::new(), spin_enabled: config.allow_spin && rng.gen_ratio(7, 8), spin: false, spaces: [initial_space, PacketSpace::new(now), PacketSpace::new(now)], highest_space: SpaceId::Initial, prev_crypto: None, next_crypto: None, accepted_0rtt: false, permit_idle_reset: true, idle_timeout: config.max_idle_timeout, timers: TimerTable::default(), authentication_failures: 0, error: None, retry_token: Bytes::new(), path_response: None, close: false, pto_count: 0, app_limited: false, in_flight: InFlight::new(), receiving_ecn: false, total_authed_packets: 0, streams: StreamsState::new( side, config.max_concurrent_uni_streams, config.max_concurrent_bidi_streams, config.send_window, config.receive_window, config.stream_receive_window, ), datagrams: DatagramState::default(), config, rem_cids: CidQueue::new(rem_cid), rng, stats: ConnectionStats::default(), version, }; if side.is_client() { // Kick off the connection this.write_crypto(); this.init_0rtt(); } this } /// Returns the next time at which `handle_timeout` should be called /// /// The value returned may change after: /// - the application performed some I/O on the connection /// - a call was made to `handle_event` /// - a call to `poll_transmit` returned `Some` /// - a call was made to 
`handle_timeout` #[must_use] pub fn poll_timeout(&mut self) -> Option<Instant> { self.timers.next_timeout() } /// Returns application-facing events /// /// Connections should be polled for events after: /// - a call was made to `handle_event` /// - a call was made to `handle_timeout` #[must_use] pub fn poll(&mut self) -> Option<Event> { if let Some(x) = self.events.pop_front() { return Some(x); } if let Some(event) = self.streams.poll() { return Some(Event::Stream(event)); } if let Some(err) = self.error.take() { return Some(Event::ConnectionLost { reason: err }); } None } /// Return endpoint-facing events #[must_use] pub fn poll_endpoint_events(&mut self) -> Option<EndpointEvent> { self.endpoint_events.pop_front().map(EndpointEvent) } /// Provide control over streams #[must_use] pub fn streams(&mut self) -> Streams<'_> { Streams { state: &mut self.streams, conn_state: &self.state, } } /// Provide control over the receive side of a stream #[must_use] pub fn recv_stream(&mut self, id: StreamId) -> RecvStream<'_> { assert!(id.dir() == Dir::Bi || id.initiator() != self.side); RecvStream { id, state: &mut self.streams, pending: &mut self.spaces[SpaceId::Data].pending, } } /// Provide control over the send side of a stream #[must_use] pub fn send_stream(&mut self, id: StreamId) -> SendStream<'_> { assert!(id.dir() == Dir::Bi || id.initiator() == self.side); SendStream { id, state: &mut self.streams, pending: &mut self.spaces[SpaceId::Data].pending, conn_state: &self.state, } } /// Returns packets to transmit /// /// Connections should be polled for transmit after: /// - the application performed some I/O on the connection /// - a call was made to `handle_event` /// - a call was made to `handle_timeout` /// /// `max_datagrams` specifies how many datagrams can be returned inside a /// single Transmit using GSO. This must be at least 1. #[must_use] pub fn poll_transmit(&mut self, now: Instant, max_datagrams: usize) -> Option<Transmit> { assert!(max_datagrams != 0); let max_datagrams = match self.config.enable_segmentation_offload { false => 1, true => max_datagrams.min(MAX_TRANSMIT_SEGMENTS), }; let mut num_datagrams = 0; // Send PATH_CHALLENGE for a previous path if necessary if let Some(ref mut prev_path) = self.prev_path { if prev_path.challenge_pending { prev_path.challenge_pending = false; let token = prev_path .challenge .expect("previous path challenge pending without token"); let destination = prev_path.remote; debug_assert_eq!( self.highest_space, SpaceId::Data, "PATH_CHALLENGE queued without 1-RTT keys" ); let mut buf = BytesMut::with_capacity(self.path.current_mtu() as usize); let buf_capacity = self.path.current_mtu() as usize; let mut builder = PacketBuilder::new( now, SpaceId::Data, &mut buf, buf_capacity, 0, false, self, self.version, )?; trace!("validating previous path with PATH_CHALLENGE {:08x}", token); buf.write(frame::Type::PATH_CHALLENGE); buf.write(token); self.stats.frame_tx.path_challenge += 1; // An endpoint MUST expand datagrams that contain a PATH_CHALLENGE frame // to at least the smallest allowed maximum datagram size of 1200 bytes, // unless the anti-amplification limit for the path does not permit // sending a datagram of this size builder.pad_to(MIN_INITIAL_SIZE); builder.finish(self, &mut buf); self.stats.udp_tx.datagrams += 1; self.stats.udp_tx.transmits += 1; self.stats.udp_tx.bytes += buf.len() as u64; return Some(Transmit { destination, contents: buf.freeze(), ecn: None, segment_size: None, src_ip: self.local_ip, }); } } // If we need to send a probe, make sure we have something to send.
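// (A probe is only useful if it is ack-eliciting; informally, for each space
// with outstanding loss probes and nothing retransmittable pending,
// `maybe_queue_probe` arranges for something to send, e.g. by re-queueing data
// from an unacknowledged in-flight packet or falling back to a PING.)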
for space in SpaceId::iter() { self.spaces[space].maybe_queue_probe(&self.streams); } // Check whether we need to send a close message let close = match self.state { State::Drained => { self.app_limited = true; return None; } State::Draining | State::Closed(_) => { // self.close is only reset once the associated packet had been // encoded successfully if !self.close { self.app_limited = true; return None; } true } _ => false, }; let mut buf = BytesMut::new(); // Reserving capacity can provide more capacity than we asked for. // However we are not allowed to write more than MTU size. Therefore // the maximum capacity is tracked separately. let mut buf_capacity = 0; let mut coalesce = true; let mut builder: Option = None; let mut sent_frames = None; let mut pad_datagram = false; let mut congestion_blocked = false; // Iterate over all spaces and find data to send let mut space_idx = 0; let spaces = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data]; // This loop will potentially spend multiple iterations in the same `SpaceId`, // so we cannot trivially rewrite it to take advantage of `SpaceId::iter()`. while space_idx < spaces.len() { let space_id = spaces[space_idx]; if close && space_id != self.highest_space { // We ignore data in this space, since the close message // has higher priority space_idx += 1; continue; } // Is there data or a close message to send in this space? let can_send = self.space_can_send(space_id); if can_send.is_empty() && !close { space_idx += 1; continue; } let mut ack_eliciting = !self.spaces[space_id].pending.is_empty(&self.streams) || self.spaces[space_id].ping_pending; if space_id == SpaceId::Data { ack_eliciting |= self.can_send_1rtt(); } // Can we append more data into the current buffer? // It is not safe to assume that `buf.len()` is the end of the data, // since the last packet might not have been finished. let buf_end = if let Some(builder) = &builder { buf.len().max(builder.min_size) + builder.tag_len } else { buf.len() }; if !coalesce || buf_capacity - buf_end < MIN_PACKET_SPACE { // We need to send 1 more datagram and extend the buffer for that. // Is 1 more datagram allowed? if buf_capacity >= self.path.current_mtu() as usize * max_datagrams { // No more datagrams allowed break; } // Anti-amplification is only based on `total_sent`, which gets // updated at the end of this method. Therefore we pass the amount // of bytes for datagrams that are already created, as well as 1 byte // for starting another datagram. If there is any anti-amplification // budget left, we always allow a full MTU to be sent // (see https://github.com/quinn-rs/quinn/issues/1082) if self.path.anti_amplification_blocked( self.path.current_mtu() as u64 * num_datagrams as u64 + 1, ) { trace!("blocked by anti-amplification"); break; } // Congestion control and pacing checks // Tail loss probes must not be blocked by congestion, or a deadlock could arise if ack_eliciting && self.spaces[space_id].loss_probes == 0 { // Assume the current packet will get padded to fill the full MTU let untracked_bytes = if let Some(builder) = &builder { buf_capacity - builder.partial_encode.start } else { 0 } as u64; debug_assert!(untracked_bytes <= self.path.current_mtu() as u64); let bytes_to_send = u64::from(self.path.current_mtu()) + untracked_bytes; if self.in_flight.bytes + bytes_to_send >= self.path.congestion.window() { space_idx += 1; congestion_blocked = true; // We continue instead of breaking here in order to avoid // blocking loss probes queued for higher spaces. 
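// (Informally: blocked when in_flight + one full-MTU datagram >= cwnd; loss
// probes skip this check entirely so that tail-loss recovery cannot deadlock.)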
continue; } // Check whether the next datagram is blocked by pacing let smoothed_rtt = self.path.rtt.get(); if let Some(delay) = self.path.pacing.delay( smoothed_rtt, bytes_to_send, self.path.current_mtu(), self.path.congestion.window(), now, ) { self.timers.set(Timer::Pacing, delay); congestion_blocked = true; // Loss probes should be subject to pacing, even though // they are not congestion controlled. break; } } // Finish current packet if let Some(mut builder) = builder.take() { // Pad the packet to make it suitable for sending with GSO // which will always send the maximum PDU. builder.pad_to(self.path.current_mtu()); builder.finish_and_track(now, self, sent_frames.take(), &mut buf); debug_assert_eq!(buf.len(), buf_capacity, "Packet must be padded"); } // Allocate space for another datagram buf_capacity += self.path.current_mtu() as usize; if buf.capacity() < buf_capacity { // We reserve the maximum space for sending `max_datagrams` upfront // to avoid any reallocations if more datagrams have to be appended later on. // Benchmarks have shown shown a 5-10% throughput improvement // compared to continuously resizing the datagram buffer. // While this will lead to over-allocation for small transmits // (e.g. purely containing ACKs), modern memory allocators // (e.g. mimalloc and jemalloc) will pool certain allocation sizes // and therefore this is still rather efficient. buf.reserve(max_datagrams * self.path.current_mtu() as usize - buf.capacity()); } num_datagrams += 1; coalesce = true; pad_datagram = false; } else { // We can append/coalesce the next packet into the current // datagram. // Finish current packet without adding extra padding if let Some(builder) = builder.take() { builder.finish_and_track(now, self, sent_frames.take(), &mut buf); } } debug_assert!(buf_capacity - buf.len() >= MIN_PACKET_SPACE); // // From here on, we've determined that a packet will definitely be sent. // if self.spaces[SpaceId::Initial].crypto.is_some() && space_id == SpaceId::Handshake && self.side.is_client() { // A client stops both sending and processing Initial packets when it // sends its first Handshake packet. self.discard_space(now, SpaceId::Initial); } if let Some(ref mut prev) = self.prev_crypto { prev.update_unacked = false; } debug_assert!( builder.is_none() && sent_frames.is_none(), "Previous packet must have been finished" ); // This should really be `builder.insert()`, but `Option::insert` // is not stable yet. Since we `debug_assert!(builder.is_none())` it // doesn't make any functional difference. let builder = builder.get_or_insert(PacketBuilder::new( now, space_id, &mut buf, buf_capacity, (num_datagrams - 1) * (self.path.current_mtu() as usize), ack_eliciting, self, self.version, )?); coalesce = coalesce && !builder.short_header; // https://tools.ietf.org/html/draft-ietf-quic-transport-34#section-14.1 pad_datagram |= space_id == SpaceId::Initial && (self.side.is_client() || ack_eliciting); if close { trace!("sending CONNECTION_CLOSE"); // Encode ACKs before the ConnectionClose message, to give the receiver // a better approximate on what data has been processed. This is // especially important with ack delay, since the peer might not // have gotten any other ACK for the data earlier on. if !self.spaces[space_id].pending_acks.ranges().is_empty() { Self::populate_acks( self.receiving_ecn, &mut SentFrames::default(), &mut self.spaces[space_id], &mut buf, &mut self.stats, ); } // Since there only 64 ACK frames there will always be enough space // to encode the ConnectionClose frame too. 
However we still have the // check here to prevent crashes if something changes. debug_assert!( buf.len() + frame::ConnectionClose::SIZE_BOUND < builder.max_size, "ACKs should leave space for ConnectionClose" ); if buf.len() + frame::ConnectionClose::SIZE_BOUND < builder.max_size { match self.state { State::Closed(state::Closed { ref reason }) => { if space_id == SpaceId::Data { reason.encode(&mut buf, builder.max_size) } else { frame::ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, frame_type: None, reason: Bytes::new(), } .encode(&mut buf, builder.max_size) } } State::Draining => frame::ConnectionClose { error_code: TransportErrorCode::NO_ERROR, frame_type: None, reason: Bytes::new(), } .encode(&mut buf, builder.max_size), _ => unreachable!( "tried to make a close packet when the connection wasn't closed" ), } } // Don't send another close packet self.close = false; // `CONNECTION_CLOSE` is the final packet break; } let sent = self.populate_packet(space_id, &mut buf, buf_capacity - builder.tag_len); // ACK-only packets should only be sent when explicitly allowed. If we write them due // to any other reason, there is a bug which leads to one component announcing write // readiness while not writing any data. This degrades performance. The condition is // only checked if the full MTU is available, so that lack of space in the datagram isn't // the reason for just writing ACKs. debug_assert!( !(sent.is_ack_only(&self.streams) && !can_send.acks && can_send.other && (buf_capacity - builder.datagram_start) == self.path.current_mtu() as usize), "SendableFrames was {can_send:?}, but only ACKs have been written" ); pad_datagram |= sent.requires_padding; if sent.largest_acked.is_some() { self.spaces[space_id].pending_acks.acks_sent(); } // Keep information about the packet around until it gets finalized sent_frames = Some(sent); // Don't increment space_idx. // We stay in the current space and check if there is more data to send. 
} // Finish the last packet if let Some(mut builder) = builder { if pad_datagram { builder.pad_to(MIN_INITIAL_SIZE); } let last_packet_number = builder.exact_number; builder.finish_and_track(now, self, sent_frames, &mut buf); self.path .congestion .on_sent(now, buf.len() as u64, last_packet_number); } self.app_limited = buf.is_empty() && !congestion_blocked; // Send MTU probe if necessary if buf.is_empty() && self.state.is_established() { let space_id = SpaceId::Data; let probe_size = match self .path .mtud .poll_transmit(now, self.spaces[space_id].next_packet_number) { Some(next_probe_size) => next_probe_size, None => return None, }; let buf_capacity = probe_size as usize; buf.reserve(buf_capacity); let mut builder = PacketBuilder::new( now, space_id, &mut buf, buf_capacity, 0, true, self, self.version, )?; // We implement MTU probes as ping packets padded up to the probe size buf.write(frame::Type::PING); builder.pad_to(probe_size); let sent_frames = SentFrames { non_retransmits: true, ..Default::default() }; builder.finish_and_track(now, self, Some(sent_frames), &mut buf); self.stats.frame_tx.ping += 1; self.stats.path.sent_plpmtud_probes += 1; num_datagrams = 1; trace!(?probe_size, "writing MTUD probe"); } if buf.is_empty() { return None; } trace!("sending {} bytes in {} datagrams", buf.len(), num_datagrams); self.path.total_sent = self.path.total_sent.saturating_add(buf.len() as u64); self.stats.udp_tx.datagrams += num_datagrams as u64; self.stats.udp_tx.bytes += buf.len() as u64; self.stats.udp_tx.transmits += 1; Some(Transmit { destination: self.path.remote, contents: buf.freeze(), ecn: if self.path.sending_ecn { Some(EcnCodepoint::Ect0) } else { None }, segment_size: match num_datagrams { 1 => None, _ => Some(self.path.current_mtu() as usize), }, src_ip: self.local_ip, }) } /// Indicate what types of frames are ready to send for the given space fn space_can_send(&self, space_id: SpaceId) -> SendableFrames { if self.spaces[space_id].crypto.is_some() { let can_send = self.spaces[space_id].can_send(&self.streams); if !can_send.is_empty() { return can_send; } } if space_id != SpaceId::Data { return SendableFrames::empty(); } if self.spaces[space_id].crypto.is_some() && self.can_send_1rtt() { return SendableFrames { other: true, acks: false, }; } if self.zero_rtt_crypto.is_some() && self.side.is_client() { let mut can_send = self.spaces[space_id].can_send(&self.streams); can_send.other |= self.can_send_1rtt(); if !can_send.is_empty() { return can_send; } } SendableFrames::empty() } /// Process `ConnectionEvent`s generated by the associated `Endpoint` /// /// Will execute protocol logic upon receipt of a connection event, in turn preparing signals /// (including application `Event`s, `EndpointEvent`s and outgoing datagrams) that should be /// extracted through the relevant methods. pub fn handle_event(&mut self, event: ConnectionEvent) { use self::ConnectionEventInner::*; match event.0 { Datagram { now, remote, ecn, first_decode, remaining, } => { // If this packet could initiate a migration and we're a client or a server that // forbids migration, drop the datagram. This could be relaxed to heuristically // permit NAT-rebinding-like migration. 
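// A sketch of the guard below (clients have no `server_config`, so they never
// accept a datagram from a new remote address here):
//
//     let migration_ok = server_config.as_ref().map_or(false, |c| c.migration);
//     if remote != path.remote && !migration_ok { /* drop the datagram */ }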
if remote != self.path.remote && self.server_config.as_ref().map_or(true, |x| !x.migration) { trace!("discarding packet from unrecognized peer {}", remote); return; } let was_anti_amplification_blocked = self.path.anti_amplification_blocked(1); self.stats.udp_rx.datagrams += 1; self.stats.udp_rx.bytes += first_decode.len() as u64; let data_len = first_decode.len(); self.handle_decode(now, remote, ecn, first_decode); // The current `path` might have changed inside `handle_decode`, // since the packet could have triggered a migration. Make sure // the data received is accounted for the most recent path by accessing // `path` after `handle_decode`. self.path.total_recvd = self.path.total_recvd.saturating_add(data_len as u64); if let Some(data) = remaining { self.stats.udp_rx.bytes += data.len() as u64; self.handle_coalesced(now, remote, ecn, data); } if was_anti_amplification_blocked { // A prior attempt to set the loss detection timer may have failed due to // anti-amplification, so ensure it's set now. Prevents a handshake deadlock if // the server's first flight is lost. self.set_loss_detection_timer(now); } } NewIdentifiers(ids, now) => { self.local_cid_state.new_cids(&ids, now); ids.into_iter().rev().for_each(|frame| { self.spaces[SpaceId::Data].pending.new_cids.push(frame); }); // Update Timer::PushNewCid if self .timers .get(Timer::PushNewCid) .map_or(true, |x| x <= now) { self.reset_cid_retirement(); } } } } /// Process timer expirations /// /// Executes protocol logic, potentially preparing signals (including application `Event`s, /// `EndpointEvent`s and outgoing datagrams) that should be extracted through the relevant /// methods. /// /// It is most efficient to call this immediately after the system clock reaches the latest /// `Instant` that was output by `poll_timeout`; however spurious extra calls will simply /// no-op and therefore are safe. pub fn handle_timeout(&mut self, now: Instant) { for &timer in &Timer::VALUES { if !self.timers.is_expired(timer, now) { continue; } self.timers.stop(timer); trace!(timer = ?timer, "timeout"); match timer { Timer::Close => { self.state = State::Drained; self.endpoint_events.push_back(EndpointEventInner::Drained); } Timer::Idle => { self.kill(ConnectionError::TimedOut); } Timer::KeepAlive => { trace!("sending keep-alive"); self.ping(); } Timer::LossDetection => { self.on_loss_detection_timeout(now); } Timer::KeyDiscard => { self.zero_rtt_crypto = None; self.prev_crypto = None; } Timer::PathValidation => { debug!("path validation failed"); if let Some(prev) = self.prev_path.take() { self.path = prev; } self.path.challenge = None; self.path.challenge_pending = false; } Timer::Pacing => trace!("pacing timer expired"), Timer::PushNewCid => { // Update `retire_prior_to` field in NEW_CONNECTION_ID frame let num_new_cid = self.local_cid_state.on_cid_timeout().into(); if !self.state.is_closed() { trace!( "push a new cid to peer RETIRE_PRIOR_TO field {}", self.local_cid_state.retire_prior_to() ); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, num_new_cid)); } } } } } /// Close a connection immediately /// /// This does not ensure delivery of outstanding data. It is the application's responsibility to /// call this only when all important communications have been completed, e.g. by calling /// [`SendStream::finish`] on outstanding streams and waiting for the corresponding /// [`StreamEvent::Finished`] event. /// /// If [`Streams::send_streams`] returns 0, all outstanding stream data has been /// delivered. 
There may still be data from the peer that has not been received. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished pub fn close(&mut self, now: Instant, error_code: VarInt, reason: Bytes) { self.close_inner( now, Close::Application(frame::ApplicationClose { error_code, reason }), ) } fn close_inner(&mut self, now: Instant, reason: Close) { let was_closed = self.state.is_closed(); if !was_closed { self.close_common(); self.set_close_timer(now); self.close = true; self.state = State::Closed(state::Closed { reason }); } } /// Control datagrams pub fn datagrams(&mut self) -> Datagrams<'_> { Datagrams { conn: self } } /// Returns connection statistics pub fn stats(&self) -> ConnectionStats { let mut stats = self.stats; stats.path.rtt = self.path.rtt.get(); stats.path.cwnd = self.path.congestion.window(); stats } /// Ping the remote endpoint /// /// Causes an ACK-eliciting packet to be transmitted. pub fn ping(&mut self) { self.spaces[self.highest_space].ping_pending = true; } #[doc(hidden)] pub fn initiate_key_update(&mut self) { self.update_keys(None, false); } /// Get a session reference pub fn crypto_session(&self) -> &dyn crypto::Session { &*self.crypto } /// Whether the connection is in the process of being established /// /// If this returns `false`, the connection may be either established or closed, signaled by the /// emission of a `Connected` or `ConnectionLost` message respectively. pub fn is_handshaking(&self) -> bool { self.state.is_handshake() } /// Whether the connection is closed /// /// Closed connections cannot transport any further data. A connection becomes closed when /// either peer application intentionally closes it, or when either transport layer detects an /// error such as a time-out or certificate validation failure. /// /// A `ConnectionLost` event is emitted with details when the connection becomes closed. pub fn is_closed(&self) -> bool { self.state.is_closed() } /// Whether there is no longer any need to keep the connection around /// /// Closed connections become drained after a brief timeout to absorb any remaining in-flight /// packets from the peer. All drained connections have been closed. pub fn is_drained(&self) -> bool { self.state.is_drained() } /// For clients, if the peer accepted the 0-RTT data packets /// /// The value is meaningless until after the handshake completes. pub fn accepted_0rtt(&self) -> bool { self.accepted_0rtt } /// Whether 0-RTT is/was possible during the handshake pub fn has_0rtt(&self) -> bool { self.zero_rtt_enabled } /// Whether there are any pending retransmits pub fn has_pending_retransmits(&self) -> bool { !self.spaces[SpaceId::Data].pending.is_empty(&self.streams) } /// Look up whether we're the client or server of this Connection pub fn side(&self) -> Side { self.side } /// The latest socket address for this connection's peer pub fn remote_address(&self) -> SocketAddr { self.path.remote } /// The local IP address which was used when the peer established /// the connection /// /// This can be different from the address the endpoint is bound to, in case /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`. /// /// This will return `None` for clients. /// /// Retrieving the local IP address is currently supported on the following /// platforms: /// - Linux /// /// On all non-supported platforms the local IP address will not be available, /// and the method will return `None`. 
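///
/// # Example
///
/// A sketch, assuming `conn` is a server-side [`Connection`]:
///
/// ```ignore
/// if let Some(ip) = conn.local_ip() {
///     println!("initial packet was received on local address {ip}");
/// }
/// ```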
pub fn local_ip(&self) -> Option { self.local_ip } /// Current best estimate of this connection's latency (round-trip-time) pub fn rtt(&self) -> Duration { self.path.rtt.get() } /// Current state of this connection's congestion controller, for debugging purposes pub fn congestion_state(&self) -> &dyn Controller { self.path.congestion.as_ref() } /// Modify the number of remotely initiated streams that may be concurrently open /// /// No streams may be opened by the peer unless fewer than `count` are already open. Large /// `count`s increase both minimum and worst-case memory consumption. pub fn set_max_concurrent_streams(&mut self, dir: Dir, count: VarInt) { self.streams.set_max_concurrent(dir, count); } /// Current number of remotely initiated streams that may be concurrently open /// /// If the target for this limit is reduced using [`set_max_concurrent_streams`](Self::set_max_concurrent_streams), /// it will not change immediately, even if fewer streams are open. Instead, it will /// decrement by one for each time a remotely initiated stream of matching directionality is closed. pub fn max_concurrent_streams(&self, dir: Dir) -> u64 { self.streams.max_concurrent(dir) } /// See [`TransportConfig::receive_window()`] pub fn set_receive_window(&mut self, receive_window: VarInt) { if self.streams.set_receive_window(receive_window) { self.spaces[SpaceId::Data].pending.max_data = true; } } fn on_ack_received( &mut self, now: Instant, space: SpaceId, ack: frame::Ack, ) -> Result<(), TransportError> { if ack.largest >= self.spaces[space].next_packet_number { return Err(TransportError::PROTOCOL_VIOLATION("unsent packet acked")); } let new_largest = { let space = &mut self.spaces[space]; if space .largest_acked_packet .map_or(true, |pn| ack.largest > pn) { space.largest_acked_packet = Some(ack.largest); if let Some(info) = space.sent_packets.get(&ack.largest) { // This should always succeed, but a misbehaving peer might ACK a packet we // haven't sent. At worst, that will result in us spuriously reducing the // congestion window. space.largest_acked_packet_sent = info.time_sent; } true } else { false } }; // Avoid DoS from unreasonably huge ack ranges by filtering out just the new acks. let mut newly_acked = ArrayRangeSet::new(); for range in ack.iter() { for (&pn, _) in self.spaces[space].sent_packets.range(range) { newly_acked.insert_one(pn); } } if newly_acked.is_empty() { return Ok(()); } let mut ack_eliciting_acked = false; for packet in newly_acked.elts() { if let Some(info) = self.spaces[space].sent_packets.remove(&packet) { if let Some(acked) = info.largest_acked { // Assume ACKs for all packets below the largest acknowledged in `packet` have // been received. This can cause the peer to spuriously retransmit if some of // our earlier ACKs were lost, but allows for simpler state tracking. 
See // discussion at // https://www.rfc-editor.org/rfc/rfc9000.html#name-limiting-ranges-by-tracking self.spaces[space].pending_acks.subtract_below(acked); } ack_eliciting_acked |= info.ack_eliciting; // Notify MTU discovery that a packet was acked, because it might be an MTU probe let mtu_updated = self.path.mtud.on_acked(space, packet, info.size); if mtu_updated { self.path .congestion .on_mtu_update(self.path.mtud.current_mtu()); } self.on_packet_acked(now, space, info); } } self.path.congestion.on_end_acks( now, self.in_flight.bytes, self.app_limited, self.spaces[space].largest_acked_packet, ); if new_largest && ack_eliciting_acked { let ack_delay = if space != SpaceId::Data { Duration::from_micros(0) } else { cmp::min( self.max_ack_delay(), Duration::from_micros(ack.delay << self.peer_params.ack_delay_exponent.0), ) }; let rtt = instant_saturating_sub(now, self.spaces[space].largest_acked_packet_sent); self.path.rtt.update(ack_delay, rtt); if self.path.first_packet_after_rtt_sample.is_none() { self.path.first_packet_after_rtt_sample = Some((space, self.spaces[space].next_packet_number)); } } // Must be called before crypto/pto_count are clobbered self.detect_lost_packets(now, space, true); if self.peer_completed_address_validation() { self.pto_count = 0; } // Explicit congestion notification if self.path.sending_ecn { if let Some(ecn) = ack.ecn { // We only examine ECN counters from ACKs that we are certain we received in transmit // order, allowing us to compute an increase in ECN counts to compare against the number // of newly acked packets that remains well-defined in the presence of arbitrary packet // reordering. if new_largest { let sent = self.spaces[space].largest_acked_packet_sent; self.process_ecn(now, space, newly_acked.len() as u64, ecn, sent); } } else { // We always start out sending ECN, so any ack that doesn't acknowledge it disables it. debug!("ECN not acknowledged by peer"); self.path.sending_ecn = false; } } self.set_loss_detection_timer(now); Ok(()) } /// Process a new ECN block from an in-order ACK fn process_ecn( &mut self, now: Instant, space: SpaceId, newly_acked: u64, ecn: frame::EcnCounts, largest_sent_time: Instant, ) { match self.spaces[space].detect_ecn(newly_acked, ecn) { Err(e) => { debug!("halting ECN due to verification failure: {}", e); self.path.sending_ecn = false; // Wipe out the existing value because it might be garbage and could interfere with // future attempts to use ECN on new paths. self.spaces[space].ecn_feedback = frame::EcnCounts::ZERO; } Ok(false) => {} Ok(true) => { self.stats.path.congestion_events += 1; self.path .congestion .on_congestion_event(now, largest_sent_time, false, 0); } } } // Not timing-aware, so it's safe to call this for inferred acks, such as arise from // high-latency handshakes fn on_packet_acked(&mut self, now: Instant, space: SpaceId, info: SentPacket) { self.remove_in_flight(space, &info); if info.ack_eliciting && self.path.challenge.is_none() { // Only pass ACKs to the congestion controller if we are not validating the current // path, so as to ignore any ACKs from older paths still coming in. 
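// (The controller's `on_ack` hook receives the packet's send time, its size
// in bytes, whether we were application-limited, and the RTT estimator; see
// the `Controller` trait in the `congestion` module.)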
self.path.congestion.on_ack( now, info.time_sent, info.size.into(), self.app_limited, &self.path.rtt, ); } // Update state for confirmed delivery of frames if let Some(retransmits) = info.retransmits.get() { for (id, _) in retransmits.reset_stream.iter() { self.streams.reset_acked(*id); } } for frame in info.stream_frames { self.streams.received_ack_of(frame); } } fn set_key_discard_timer(&mut self, now: Instant, space: SpaceId) { let start = if self.zero_rtt_crypto.is_some() { now } else { self.prev_crypto .as_ref() .expect("no previous keys") .end_packet .as_ref() .expect("update not acknowledged yet") .1 }; self.timers .set(Timer::KeyDiscard, start + self.pto(space) * 3); } fn on_loss_detection_timeout(&mut self, now: Instant) { if let Some((_, pn_space)) = self.loss_time_and_space() { // Time threshold loss detection self.detect_lost_packets(now, pn_space, false); self.set_loss_detection_timer(now); return; } let (_, space) = match self.pto_time_and_space(now) { Some(x) => x, None => { error!("PTO expired while unset"); return; } }; trace!( in_flight = self.in_flight.bytes, count = self.pto_count, ?space, "PTO fired" ); let count = match self.in_flight.ack_eliciting { // A PTO when we're not expecting any ACKs must be due to handshake anti-amplification // deadlock prevention 0 => { debug_assert!(!self.peer_completed_address_validation()); 1 } // Conventional loss probe _ => 2, }; self.spaces[space].loss_probes = self.spaces[space].loss_probes.saturating_add(count); self.pto_count = self.pto_count.saturating_add(1); self.set_loss_detection_timer(now); } fn detect_lost_packets(&mut self, now: Instant, pn_space: SpaceId, due_to_ack: bool) { let mut lost_packets = Vec::<u64>::new(); let mut lost_mtu_probe = None; let in_flight_mtu_probe = self.path.mtud.in_flight_mtu_probe(); let rtt = self.path.rtt.conservative(); let loss_delay = cmp::max(rtt.mul_f32(self.config.time_threshold), TIMER_GRANULARITY); // Packets sent before this time are deemed lost. let lost_send_time = now.checked_sub(loss_delay).unwrap(); let largest_acked_packet = self.spaces[pn_space].largest_acked_packet.unwrap(); let packet_threshold = self.config.packet_threshold as u64; let mut size_of_lost_packets = 0u64; // InPersistentCongestion: Determine if all packets in the time period before the newest // lost packet, including the edges, are marked lost. PTO computation must always // include max ACK delay, i.e. operate as if in Data space (see RFC 9002 §7.6.1).
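// The duration over which losses establish persistent congestion is
//
//     (smoothed_rtt + max(4 * rttvar, granularity) + max_ack_delay)
//         * persistent_congestion_threshold
//
// which is what `pto(SpaceId::Data) * persistent_congestion_threshold`
// computes below, since `pto_base()` is smoothed_rtt + max(4 * rttvar, granularity).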
let congestion_period = self.pto(SpaceId::Data) * self.config.persistent_congestion_threshold; let mut persistent_congestion_start: Option = None; let mut prev_packet = None; let mut in_persistent_congestion = false; let space = &mut self.spaces[pn_space]; space.loss_time = None; for (&packet, info) in space.sent_packets.range(0..largest_acked_packet) { if prev_packet != Some(packet.wrapping_sub(1)) { // An intervening packet was acknowledged persistent_congestion_start = None; } if info.time_sent <= lost_send_time || largest_acked_packet >= packet + packet_threshold { if Some(packet) == in_flight_mtu_probe { // Lost MTU probes are not included in `lost_packets`, because they should not // trigger a congestion control response lost_mtu_probe = in_flight_mtu_probe; } else { lost_packets.push(packet); size_of_lost_packets += info.size as u64; if info.ack_eliciting && due_to_ack { match persistent_congestion_start { // Two ACK-eliciting packets lost more than congestion_period apart, with no // ACKed packets in between Some(start) if info.time_sent - start > congestion_period => { in_persistent_congestion = true; } // Persistent congestion must start after the first RTT sample None if self .path .first_packet_after_rtt_sample .map_or(false, |x| x < (pn_space, packet)) => { persistent_congestion_start = Some(info.time_sent); } _ => {} } } } } else { let next_loss_time = info.time_sent + loss_delay; space.loss_time = Some( space .loss_time .map_or(next_loss_time, |x| cmp::min(x, next_loss_time)), ); persistent_congestion_start = None; } prev_packet = Some(packet); } // OnPacketsLost if let Some(largest_lost) = lost_packets.last().cloned() { let old_bytes_in_flight = self.in_flight.bytes; let largest_lost_sent = self.spaces[pn_space].sent_packets[&largest_lost].time_sent; self.lost_packets += lost_packets.len() as u64; self.stats.path.lost_packets += lost_packets.len() as u64; self.stats.path.lost_bytes += size_of_lost_packets; trace!( "packets lost: {:?}, bytes lost: {}", lost_packets, size_of_lost_packets ); for packet in &lost_packets { let info = self.spaces[pn_space].sent_packets.remove(packet).unwrap(); // safe: lost_packets is populated just above self.remove_in_flight(pn_space, &info); for frame in info.stream_frames { self.streams.retransmit(frame); } self.spaces[pn_space].pending |= info.retransmits; self.path.mtud.on_non_probe_lost(*packet, info.size); } if self.path.mtud.black_hole_detected(now) { self.stats.path.black_holes_detected += 1; } // Don't apply congestion penalty for lost ack-only packets let lost_ack_eliciting = old_bytes_in_flight != self.in_flight.bytes; if lost_ack_eliciting { self.stats.path.congestion_events += 1; self.path.congestion.on_congestion_event( now, largest_lost_sent, in_persistent_congestion, size_of_lost_packets, ); } } // Handle a lost MTU probe if let Some(packet) = lost_mtu_probe { let info = self.spaces[SpaceId::Data] .sent_packets .remove(&packet) .unwrap(); // safe: lost_mtu_probe is omitted from lost_packets, and therefore must not have been removed yet self.remove_in_flight(SpaceId::Data, &info); self.path.mtud.on_probe_lost(); self.stats.path.lost_plpmtud_probes += 1; } } fn loss_time_and_space(&self) -> Option<(Instant, SpaceId)> { SpaceId::iter() .filter_map(|id| Some((self.spaces[id].loss_time?, id))) .min_by_key(|&(time, _)| time) } fn pto_time_and_space(&self, now: Instant) -> Option<(Instant, SpaceId)> { let backoff = 2u32.pow(self.pto_count.min(MAX_BACKOFF_EXPONENT)); let mut duration = self.path.rtt.pto_base() * backoff; if 
self.in_flight.ack_eliciting == 0 { debug_assert!(!self.peer_completed_address_validation()); let space = match self.highest_space { SpaceId::Handshake => SpaceId::Handshake, _ => SpaceId::Initial, }; return Some((now + duration, space)); } let mut result = None; for space in SpaceId::iter() { if self.spaces[space].in_flight == 0 { continue; } if space == SpaceId::Data { // Skip ApplicationData until handshake completes. if self.is_handshaking() { return result; } // Include max_ack_delay and backoff for ApplicationData. duration += self.max_ack_delay() * backoff; } let last_ack_eliciting = match self.spaces[space].time_of_last_ack_eliciting_packet { Some(time) => time, None => continue, }; let pto = last_ack_eliciting + duration; if result.map_or(true, |(earliest_pto, _)| pto < earliest_pto) { result = Some((pto, space)); } } result } #[allow(clippy::suspicious_operation_groupings)] fn peer_completed_address_validation(&self) -> bool { if self.side.is_server() || self.state.is_closed() { return true; } // The server is guaranteed to have validated our address if any of our handshake or 1-RTT // packets are acknowledged or we've seen HANDSHAKE_DONE and discarded handshake keys. self.spaces[SpaceId::Handshake] .largest_acked_packet .is_some() || self.spaces[SpaceId::Data].largest_acked_packet.is_some() || (self.spaces[SpaceId::Data].crypto.is_some() && self.spaces[SpaceId::Handshake].crypto.is_none()) } fn set_loss_detection_timer(&mut self, now: Instant) { if let Some((loss_time, _)) = self.loss_time_and_space() { // Time threshold loss detection. self.timers.set(Timer::LossDetection, loss_time); return; } if self.path.anti_amplification_blocked(1) { // We wouldn't be able to send anything, so don't bother. self.timers.stop(Timer::LossDetection); return; } if self.in_flight.ack_eliciting == 0 && self.peer_completed_address_validation() { // There is nothing to detect lost, so no timer is set. However, the client needs to arm // the timer if the server might be blocked by the anti-amplification limit. self.timers.stop(Timer::LossDetection); return; } // Determine which PN space to arm PTO for. // Calculate PTO duration if let Some((timeout, _)) = self.pto_time_and_space(now) { self.timers.set(Timer::LossDetection, timeout); } else { self.timers.stop(Timer::LossDetection); } } /// Probe Timeout fn pto(&self, space: SpaceId) -> Duration { let max_ack_delay = match space { SpaceId::Initial | SpaceId::Handshake => Duration::new(0, 0), SpaceId::Data => self.max_ack_delay(), }; self.path.rtt.pto_base() + max_ack_delay } fn on_packet_authenticated( &mut self, now: Instant, space_id: SpaceId, ecn: Option, packet: Option, spin: bool, is_1rtt: bool, ) { self.total_authed_packets += 1; self.reset_keep_alive(now); self.reset_idle_timeout(now, space_id); self.permit_idle_reset = true; self.receiving_ecn |= ecn.is_some(); if let Some(x) = ecn { self.spaces[space_id].ecn_counters += x; } let packet = match packet { Some(x) => x, None => return, }; if self.side.is_server() { if self.spaces[SpaceId::Initial].crypto.is_some() && space_id == SpaceId::Handshake { // A server stops sending and processing Initial packets when it receives its first Handshake packet. 
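// (RFC 9001 §4.9.1: a server MUST discard Initial keys when it first
// successfully processes a Handshake packet.)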
self.discard_space(now, SpaceId::Initial); } if self.zero_rtt_crypto.is_some() && is_1rtt { // Discard 0-RTT keys soon after receiving a 1-RTT packet self.set_key_discard_timer(now, space_id) } } let space = &mut self.spaces[space_id]; space.pending_acks.insert_one(packet, now); if packet >= space.rx_packet { space.rx_packet = packet; // Update outgoing spin bit, inverting iff we're the client self.spin = self.side.is_client() ^ spin; } } fn reset_idle_timeout(&mut self, now: Instant, space: SpaceId) { let timeout = match self.idle_timeout { None => return, Some(x) => Duration::from_millis(x.0), }; if self.state.is_closed() { self.timers.stop(Timer::Idle); return; } let dt = cmp::max(timeout, 3 * self.pto(space)); self.timers.set(Timer::Idle, now + dt); } fn reset_keep_alive(&mut self, now: Instant) { let interval = match self.config.keep_alive_interval { Some(x) if self.state.is_established() => x, _ => return, }; self.timers.set(Timer::KeepAlive, now + interval); } fn reset_cid_retirement(&mut self) { if let Some(t) = self.local_cid_state.next_timeout() { self.timers.set(Timer::PushNewCid, t); } } /// Handle the already-decrypted first packet from the client /// /// Decrypting the first packet in the `Endpoint` allows stateless packet handling to be more /// efficient. pub(crate) fn handle_first_packet( &mut self, now: Instant, remote: SocketAddr, ecn: Option, packet_number: u64, packet: Packet, remaining: Option, ) -> Result<(), ConnectionError> { let span = trace_span!("first recv"); let _guard = span.enter(); debug_assert!(self.side.is_server()); let len = packet.header_data.len() + packet.payload.len(); self.path.total_recvd = len as u64; match self.state { State::Handshake(ref mut state) => match packet.header { Header::Initial { ref token, .. } => { state.expected_token = token.clone(); } _ => unreachable!("first packet must be an Initial packet"), }, _ => unreachable!("first packet must be delivered in Handshake state"), } self.on_packet_authenticated( now, SpaceId::Initial, ecn, Some(packet_number), false, false, ); self.process_decrypted_packet(now, remote, Some(packet_number), packet)?; if let Some(data) = remaining { self.handle_coalesced(now, remote, ecn, data); } Ok(()) } fn init_0rtt(&mut self) { let (header, packet) = match self.crypto.early_crypto() { Some(x) => x, None => return, }; if self.side.is_client() { match self.crypto.transport_parameters() { Ok(params) => { let params = params .expect("crypto layer didn't supply transport parameters with ticket"); // Certain values must not be cached let params = TransportParameters { initial_src_cid: None, original_dst_cid: None, preferred_address: None, retry_src_cid: None, stateless_reset_token: None, ack_delay_exponent: TransportParameters::default().ack_delay_exponent, max_ack_delay: TransportParameters::default().max_ack_delay, ..params }; self.set_peer_params(params); } Err(e) => { error!("session ticket has malformed transport parameters: {}", e); return; } } } trace!("0-RTT enabled"); self.zero_rtt_enabled = true; self.zero_rtt_crypto = Some(ZeroRttCrypto { header, packet }); } fn read_crypto( &mut self, space: SpaceId, crypto: &frame::Crypto, payload_len: usize, ) -> Result<(), TransportError> { let expected = if !self.state.is_handshake() { SpaceId::Data } else if self.highest_space == SpaceId::Initial { SpaceId::Initial } else { // On the server, self.highest_space can be Data after receiving the client's first // flight, but we expect Handshake CRYPTO until the handshake is complete. 
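// In short, the expected CRYPTO level is:
//   handshake complete         => Data
//   highest_space == Initial   => Initial
//   otherwise                  => Handshake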
SpaceId::Handshake }; // We can't decrypt Handshake packets when highest_space is Initial, CRYPTO frames in 0-RTT // packets are illegal, and we don't process 1-RTT packets until the handshake is // complete. Therefore, we will never see CRYPTO data from a later-than-expected space. debug_assert!(space <= expected, "received out-of-order CRYPTO data"); let end = crypto.offset + crypto.data.len() as u64; if space < expected && end > self.spaces[space].crypto_stream.bytes_read() { warn!( "received new {:?} CRYPTO data when expecting {:?}", space, expected ); return Err(TransportError::PROTOCOL_VIOLATION( "new data at unexpected encryption level", )); } let space = &mut self.spaces[space]; let max = end.saturating_sub(space.crypto_stream.bytes_read()); if max > self.config.crypto_buffer_size as u64 { return Err(TransportError::CRYPTO_BUFFER_EXCEEDED("")); } space .crypto_stream .insert(crypto.offset, crypto.data.clone(), payload_len); while let Some(chunk) = space.crypto_stream.read(usize::MAX, true) { trace!("consumed {} CRYPTO bytes", chunk.bytes.len()); if self.crypto.read_handshake(&chunk.bytes)? { self.events.push_back(Event::HandshakeDataReady); } } Ok(()) } fn write_crypto(&mut self) { loop { let space = self.highest_space; let mut outgoing = Vec::new(); if let Some(crypto) = self.crypto.write_handshake(&mut outgoing) { match space { SpaceId::Initial => { self.upgrade_crypto(SpaceId::Handshake, crypto); } SpaceId::Handshake => { self.upgrade_crypto(SpaceId::Data, crypto); } _ => unreachable!("got updated secrets during 1-RTT"), } } if outgoing.is_empty() { if space == self.highest_space { break; } else { // Keys updated, check for more data to send continue; } } let offset = self.spaces[space].crypto_offset; let outgoing = Bytes::from(outgoing); if let State::Handshake(ref mut state) = self.state { if space == SpaceId::Initial && offset == 0 && self.side.is_client() { state.client_hello = Some(outgoing.clone()); } } self.spaces[space].crypto_offset += outgoing.len() as u64; trace!("wrote {} {:?} CRYPTO bytes", outgoing.len(), space); self.spaces[space].pending.crypto.push_back(frame::Crypto { offset, data: outgoing, }); } } /// Switch to stronger cryptography during handshake fn upgrade_crypto(&mut self, space: SpaceId, crypto: Keys) { debug_assert!( self.spaces[space].crypto.is_none(), "already reached packet space {space:?}" ); trace!("{:?} keys ready", space); if space == SpaceId::Data { // Precompute the first key update self.next_crypto = Some( self.crypto .next_1rtt_keys() .expect("handshake should be complete"), ); } self.spaces[space].crypto = Some(crypto); debug_assert!(space as usize > self.highest_space as usize); self.highest_space = space; if space == SpaceId::Data && self.side.is_client() { // Discard 0-RTT keys because 1-RTT keys are available. 
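// (RFC 9001 §4.9.3: a client SHOULD discard 0-RTT keys as soon as it installs
// 1-RTT keys, as they have no use after that.)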
self.zero_rtt_crypto = None; } } fn discard_space(&mut self, now: Instant, space_id: SpaceId) { debug_assert!(space_id != SpaceId::Data); trace!("discarding {:?} keys", space_id); if space_id == SpaceId::Initial { // No longer needed self.retry_token = Bytes::new(); } let space = &mut self.spaces[space_id]; space.crypto = None; space.time_of_last_ack_eliciting_packet = None; space.loss_time = None; let sent_packets = mem::take(&mut space.sent_packets); for (_, packet) in sent_packets.into_iter() { self.remove_in_flight(space_id, &packet); } self.set_loss_detection_timer(now) } fn handle_coalesced( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodec>, data: BytesMut, ) { self.path.total_recvd = self.path.total_recvd.saturating_add(data.len() as u64); let mut remaining = Some(data); while let Some(data) = remaining { match PartialDecode::new( data, self.local_cid_state.cid_len(), &[self.version], self.endpoint_config.grease_quic_bit, ) { Ok((partial_decode, rest)) => { remaining = rest; self.handle_decode(now, remote, ecn, partial_decode); } Err(e) => { trace!("malformed header: {}", e); return; } } } } fn handle_decode( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodec>, partial_decode: PartialDecode, ) { let header_crypto = if partial_decode.is_0rtt() { if let Some(ref crypto) = self.zero_rtt_crypto { Some(&*crypto.header) } else { debug!("dropping unexpected 0-RTT packet"); return; } } else if let Some(space) = partial_decode.space() { if let Some(ref crypto) = self.spaces[space].crypto { Some(&*crypto.header.remote) } else { debug!( "discarding unexpected {:?} packet ({} bytes)", space, partial_decode.len(), ); return; } } else { // Unprotected packet None }; let packet = partial_decode.data(); let stateless_reset = packet.len() >= RESET_TOKEN_SIZE + 5 && self.peer_params.stateless_reset_token.as_deref() == Some(&packet[packet.len() - RESET_TOKEN_SIZE..]); match partial_decode.finish(header_crypto) { Ok(packet) => self.handle_packet(now, remote, ecn, Some(packet), stateless_reset), Err(_) if stateless_reset => self.handle_packet(now, remote, ecn, None, true), Err(e) => { trace!("unable to complete packet decoding: {}", e); } } } fn handle_packet( &mut self, now: Instant, remote: SocketAddr, ecn: Option<EcnCodec>, packet: Option<Packet>, stateless_reset: bool, ) { if let Some(ref packet) = packet { trace!( "got {:?} packet ({} bytes) from {} using id {}", packet.header.space(), packet.payload.len() + packet.header_data.len(), remote, packet.header.dst_cid(), ); } if self.is_handshaking() && remote != self.path.remote { debug!("discarding packet with unexpected remote during handshake"); return; } let was_closed = self.state.is_closed(); let was_drained = self.state.is_drained(); let decrypted = match packet { None => Err(None), Some(mut packet) => self .decrypt_packet(now, &mut packet) .map(move |number| (packet, number)), }; let result = match decrypted { _ if stateless_reset => { debug!("got stateless reset"); Err(ConnectionError::Reset) } Err(Some(e)) => { warn!("illegal packet: {}", e); Err(e.into()) } Err(None) => { debug!("failed to authenticate packet"); self.authentication_failures += 1; let integrity_limit = self.spaces[self.highest_space] .crypto .as_ref() .unwrap() .packet .local .integrity_limit(); if self.authentication_failures > integrity_limit { Err(TransportError::AEAD_LIMIT_REACHED("integrity limit violated").into()) } else { return; } } Ok((packet, number)) => { let span = match number { Some(pn) => trace_span!("recv", space = ?packet.header.space(), pn), None =>
trace_span!("recv", space = ?packet.header.space()), }; let _guard = span.enter(); let is_duplicate = |n| self.spaces[packet.header.space()].dedup.insert(n); if number.map_or(false, is_duplicate) { debug!("discarding possible duplicate packet"); return; } else if self.state.is_handshake() && packet.header.is_short() { // TODO: SHOULD buffer these to improve reordering tolerance. trace!("dropping short packet during handshake"); return; } else { if let Header::Initial { ref token, .. } = packet.header { if let State::Handshake(ref hs) = self.state { if self.side.is_server() && token != &hs.expected_token { // Clients must send the same retry token in every Initial. Initial // packets can be spoofed, so we discard rather than killing the // connection. warn!("discarding Initial with invalid retry token"); return; } } } if !self.state.is_closed() { let spin = match packet.header { Header::Short { spin, .. } => spin, _ => false, }; self.on_packet_authenticated( now, packet.header.space(), ecn, number, spin, packet.header.is_1rtt(), ); } self.process_decrypted_packet(now, remote, number, packet) } } }; // State transitions for error cases if let Err(conn_err) = result { self.error = Some(conn_err.clone()); self.state = match conn_err { ConnectionError::ApplicationClosed(reason) => State::closed(reason), ConnectionError::ConnectionClosed(reason) => State::closed(reason), ConnectionError::Reset | ConnectionError::TransportError(TransportError { code: TransportErrorCode::AEAD_LIMIT_REACHED, .. }) => State::Drained, ConnectionError::TimedOut => { unreachable!("timeouts aren't generated by packet processing"); } ConnectionError::TransportError(err) => { debug!("closing connection due to transport error: {}", err); State::closed(err) } ConnectionError::VersionMismatch => State::Draining, ConnectionError::LocallyClosed => { unreachable!("LocallyClosed isn't generated by packet processing") } }; } if !was_closed && self.state.is_closed() { self.close_common(); if !self.state.is_drained() { self.set_close_timer(now); } } if !was_drained && self.state.is_drained() { self.endpoint_events.push_back(EndpointEventInner::Drained); // Close timer may have been started previously, e.g. if we sent a close and got a // stateless reset in response self.timers.stop(Timer::Close); } // Transmit CONNECTION_CLOSE if necessary if let State::Closed(_) = self.state { self.close = remote == self.path.remote; } } fn process_decrypted_packet( &mut self, now: Instant, remote: SocketAddr, number: Option<u64>, packet: Packet, ) -> Result<(), ConnectionError> { let state = match self.state { State::Established => { match packet.header.space() { SpaceId::Data => { self.process_payload(now, remote, number.unwrap(), packet.payload.freeze())? } _ => self.process_early_payload(now, packet)?, } return Ok(()); } State::Closed(_) => { for result in frame::Iter::new(packet.payload.freeze()) { let frame = match result { Ok(frame) => frame, Err(err) => { debug!("frame decoding error: {err:?}"); continue; } }; if let Frame::Padding = frame { continue; }; self.stats.frame_rx.record(&frame); if let Frame::Close(_) = frame { trace!("draining"); self.state = State::Draining; break; } } return Ok(()); } State::Draining | State::Drained => return Ok(()), State::Handshake(ref mut state) => state, }; match packet.header { Header::Retry { src_cid: rem_cid, ..
} => { if self.side.is_server() { return Err(TransportError::PROTOCOL_VIOLATION("client sent Retry").into()); } if self.total_authed_packets > 1 || packet.payload.len() <= 16 // token + 16 byte tag || !self.crypto.is_valid_retry( &self.rem_cids.active(), &packet.header_data, &packet.payload, ) { trace!("discarding invalid Retry"); // - After the client has received and processed an Initial or Retry // packet from the server, it MUST discard any subsequent Retry // packets that it receives. // - A client MUST discard a Retry packet with a zero-length Retry Token // field. // - Clients MUST discard Retry packets that have a Retry Integrity Tag // that cannot be validated return Ok(()); } trace!("retrying with CID {}", rem_cid); let client_hello = state.client_hello.take().unwrap(); self.retry_src_cid = Some(rem_cid); self.rem_cids.update_initial_cid(rem_cid); self.rem_handshake_cid = rem_cid; let space = &mut self.spaces[SpaceId::Initial]; if let Some(info) = space.sent_packets.remove(&0) { self.on_packet_acked(now, SpaceId::Initial, info); }; self.discard_space(now, SpaceId::Initial); // Make sure we clean up after any retransmitted Initials self.spaces[SpaceId::Initial] = PacketSpace { crypto: Some(self.crypto.initial_keys(&rem_cid, self.side)), next_packet_number: self.spaces[SpaceId::Initial].next_packet_number, crypto_offset: client_hello.len() as u64, ..PacketSpace::new(now) }; self.spaces[SpaceId::Initial] .pending .crypto .push_back(frame::Crypto { offset: 0, data: client_hello, }); // Retransmit all 0-RTT data let zero_rtt = mem::take(&mut self.spaces[SpaceId::Data].sent_packets); for (_, info) in zero_rtt { self.remove_in_flight(SpaceId::Data, &info); self.spaces[SpaceId::Data].pending |= info.retransmits; } self.streams.retransmit_all_for_0rtt(); let token_len = packet.payload.len() - 16; self.retry_token = packet.payload.freeze().split_to(token_len); self.state = State::Handshake(state::Handshake { expected_token: Bytes::new(), rem_cid_set: false, client_hello: None, }); Ok(()) } Header::Long { ty: LongType::Handshake, src_cid: rem_cid, .. } => { if rem_cid != self.rem_handshake_cid { debug!( "discarding packet with mismatched remote CID: {} != {}", self.rem_handshake_cid, rem_cid ); return Ok(()); } self.path.validated = true; self.process_early_payload(now, packet)?; if self.state.is_closed() { return Ok(()); } if self.crypto.is_handshaking() { trace!("handshake ongoing"); return Ok(()); } if self.side.is_client() { // Client-only because server params were set from the client's Initial let params = self.crypto .transport_parameters()? 
.ok_or_else(|| TransportError { code: TransportErrorCode::crypto(0x6d), frame: None, reason: "transport parameters missing".into(), })?; if self.has_0rtt() { if !self.crypto.early_data_accepted().unwrap() { debug_assert!(self.side.is_client()); debug!("0-RTT rejected"); self.accepted_0rtt = false; self.streams.zero_rtt_rejected(); // Discard already-queued frames self.spaces[SpaceId::Data].pending = Retransmits::default(); // Discard 0-RTT packets let sent_packets = mem::take(&mut self.spaces[SpaceId::Data].sent_packets); for (_, packet) in sent_packets { self.remove_in_flight(SpaceId::Data, &packet); } } else { self.accepted_0rtt = true; params.validate_resumption_from(&self.peer_params)?; } } if let Some(token) = params.stateless_reset_token { self.endpoint_events .push_back(EndpointEventInner::ResetToken(self.path.remote, token)); } self.handle_peer_params(params)?; self.issue_cids(now); } else { // Server-only self.spaces[SpaceId::Data].pending.handshake_done = true; self.discard_space(now, SpaceId::Handshake); } self.events.push_back(Event::Connected); self.state = State::Established; trace!("established"); Ok(()) } Header::Initial { src_cid: rem_cid, .. } => { if !state.rem_cid_set { trace!("switching remote CID to {}", rem_cid); let mut state = state.clone(); self.rem_cids.update_initial_cid(rem_cid); self.rem_handshake_cid = rem_cid; self.orig_rem_cid = rem_cid; state.rem_cid_set = true; self.state = State::Handshake(state); } else if rem_cid != self.rem_handshake_cid { debug!( "discarding packet with mismatched remote CID: {} != {}", self.rem_handshake_cid, rem_cid ); return Ok(()); } let starting_space = self.highest_space; self.process_early_payload(now, packet)?; if self.side.is_server() && starting_space == SpaceId::Initial && self.highest_space != SpaceId::Initial { let params = self.crypto .transport_parameters()? .ok_or_else(|| TransportError { code: TransportErrorCode::crypto(0x6d), frame: None, reason: "transport parameters missing".into(), })?; self.handle_peer_params(params)?; self.issue_cids(now); self.init_0rtt(); } Ok(()) } Header::Long { ty: LongType::ZeroRtt, .. } => { self.process_payload(now, remote, number.unwrap(), packet.payload.freeze())?; Ok(()) } Header::VersionNegotiate { .. } => { if self.total_authed_packets > 1 { return Ok(()); } let supported = packet .payload .chunks(4) .any(|x| match <[u8; 4]>::try_from(x) { Ok(version) => self.version == u32::from_be_bytes(version), Err(_) => false, }); if supported { return Ok(()); } debug!("remote doesn't support our version"); Err(ConnectionError::VersionMismatch) } Header::Short { .. 
} => unreachable!( "short packets received during handshake are discarded in handle_packet" ), } } /// Process an Initial or Handshake packet payload fn process_early_payload( &mut self, now: Instant, packet: Packet, ) -> Result<(), TransportError> { debug_assert_ne!(packet.header.space(), SpaceId::Data); let payload_len = packet.payload.len(); let mut ack_eliciting = false; for result in frame::Iter::new(packet.payload.freeze()) { let frame = result?; let span = match frame { Frame::Padding => continue, _ => Some(trace_span!("frame", ty = %frame.ty())), }; self.stats.frame_rx.record(&frame); let _guard = span.as_ref().map(|x| x.enter()); ack_eliciting |= frame.is_ack_eliciting(); // Process frames match frame { Frame::Padding | Frame::Ping => {} Frame::Crypto(frame) => { self.read_crypto(packet.header.space(), &frame, payload_len)?; } Frame::Ack(ack) => { self.on_ack_received(now, packet.header.space(), ack)?; } Frame::Close(reason) => { self.error = Some(reason.into()); self.state = State::Draining; return Ok(()); } _ => { let mut err = TransportError::PROTOCOL_VIOLATION("illegal frame type in handshake"); err.frame = Some(frame.ty()); return Err(err); } } } self.spaces[packet.header.space()] .pending_acks .packet_received(ack_eliciting); self.write_crypto(); Ok(()) } fn process_payload( &mut self, now: Instant, remote: SocketAddr, number: u64, payload: Bytes, ) -> Result<(), TransportError> { let is_0rtt = self.spaces[SpaceId::Data].crypto.is_none(); let mut is_probing_packet = true; let mut close = None; let payload_len = payload.len(); let mut ack_eliciting = false; for result in frame::Iter::new(payload) { let frame = result?; let span = match frame { Frame::Padding => continue, _ => Some(trace_span!("frame", ty = %frame.ty())), }; self.stats.frame_rx.record(&frame); // Crypto, Stream and Datagram frames are special-cased in order not to pollute // the log with payload data match &frame { Frame::Crypto(f) => { trace!(offset = f.offset, len = f.data.len(), "got crypto frame"); } Frame::Stream(f) => { trace!(id = %f.id, offset = f.offset, len = f.data.len(), fin = f.fin, "got stream frame"); } Frame::Datagram(f) => { trace!(len = f.data.len(), "got datagram frame"); } f => { trace!("got frame {:?}", f); } } let _guard = span.as_ref().map(|x| x.enter()); if is_0rtt { match frame { Frame::Crypto(_) | Frame::Close(Close::Application(_)) => { return Err(TransportError::PROTOCOL_VIOLATION( "illegal frame type in 0-RTT", )); } _ => {} } } ack_eliciting |= frame.is_ack_eliciting(); // Check whether this could be a probing packet match frame { Frame::Padding | Frame::PathChallenge(_) | Frame::PathResponse(_) | Frame::NewConnectionId(_) => {} _ => { is_probing_packet = false; } } match frame { Frame::Crypto(frame) => { self.read_crypto(SpaceId::Data, &frame, payload_len)?; } Frame::Stream(frame) => { if self.streams.received(frame, payload_len)?.should_transmit() { self.spaces[SpaceId::Data].pending.max_data = true; } } Frame::Ack(ack) => { self.on_ack_received(now, SpaceId::Data, ack)?; } Frame::Padding | Frame::Ping => {} Frame::Close(reason) => { close = Some(reason); } Frame::PathChallenge(token) => { if self .path_response .as_ref() .map_or(true, |x| x.packet <= number) { self.path_response = Some(PathResponse { packet: number, token, }); } if remote == self.path.remote { // PATH_CHALLENGE on active path, possible off-path packet forwarding // attack. Send a non-probing packet to recover the active path.
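// See RFC 9000 §9.3.3 (Off-Path Packet Forwarding): sending a non-probing packet
// on the genuine path keeps it the path carrying the highest packet numbers, so
// the peer migrates back if an attacker was merely duplicating our packets.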
self.ping(); } } Frame::PathResponse(token) => { if self.path.challenge == Some(token) && remote == self.path.remote { trace!("new path validated"); self.timers.stop(Timer::PathValidation); self.path.challenge = None; self.path.validated = true; if let Some(ref mut prev_path) = self.prev_path { prev_path.challenge = None; prev_path.challenge_pending = false; } } else { debug!(token, "ignoring invalid PATH_RESPONSE"); } } Frame::MaxData(bytes) => { self.streams.received_max_data(bytes); } Frame::MaxStreamData { id, offset } => { self.streams.received_max_stream_data(id, offset)?; } Frame::MaxStreams { dir, count } => { self.streams.received_max_streams(dir, count)?; } Frame::ResetStream(frame) => { if self.streams.received_reset(frame)?.should_transmit() { self.spaces[SpaceId::Data].pending.max_data = true; } } Frame::DataBlocked { offset } => { debug!(offset, "peer claims to be blocked at connection level"); } Frame::StreamDataBlocked { id, offset } => { if id.initiator() == self.side && id.dir() == Dir::Uni { debug!("got STREAM_DATA_BLOCKED on send-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "STREAM_DATA_BLOCKED on send-only stream", )); } debug!( stream = %id, offset, "peer claims to be blocked at stream level" ); } Frame::StreamsBlocked { dir, limit } => { if limit > MAX_STREAM_COUNT { return Err(TransportError::FRAME_ENCODING_ERROR( "unrepresentable stream limit", )); } debug!( "peer claims to be blocked opening more than {} {} streams", limit, dir ); } Frame::StopSending(frame::StopSending { id, error_code }) => { if id.initiator() != self.side { if id.dir() == Dir::Uni { debug!("got STOP_SENDING on recv-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "STOP_SENDING on recv-only stream", )); } } else if self.streams.is_local_unopened(id) { return Err(TransportError::STREAM_STATE_ERROR( "STOP_SENDING on unopened stream", )); } self.streams.received_stop_sending(id, error_code); } Frame::RetireConnectionId { sequence } => { let allow_more_cids = self .local_cid_state .on_cid_retirement(sequence, self.peer_params.issue_cids_limit())?; self.endpoint_events .push_back(EndpointEventInner::RetireConnectionId( now, sequence, allow_more_cids, )); } Frame::NewConnectionId(frame) => { trace!( sequence = frame.sequence, id = %frame.id, retire_prior_to = frame.retire_prior_to, ); if self.rem_cids.active().is_empty() { return Err(TransportError::PROTOCOL_VIOLATION( "NEW_CONNECTION_ID when CIDs aren't in use", )); } if frame.retire_prior_to > frame.sequence { return Err(TransportError::PROTOCOL_VIOLATION( "NEW_CONNECTION_ID retiring unissued CIDs", )); } use crate::cid_queue::InsertError; match self.rem_cids.insert(frame) { Ok(None) => {} Ok(Some((retired, reset_token))) => { self.spaces[SpaceId::Data] .pending .retire_cids .extend(retired); self.set_reset_token(reset_token); } Err(InsertError::ExceedsLimit) => { return Err(TransportError::CONNECTION_ID_LIMIT_ERROR("")); } Err(InsertError::Retired) => { trace!("discarding already-retired"); // RETIRE_CONNECTION_ID might not have been previously sent if e.g. a // range of connection IDs larger than the active connection ID limit // was retired all at once via retire_prior_to. self.spaces[SpaceId::Data] .pending .retire_cids .push(frame.sequence); continue; } }; if self.side.is_server() && self.rem_cids.active_seq() == 0 { // We're a server still using the initial remote CID for the client, so // let's switch immediately to enable clientside stateless resets. 
self.update_rem_cid(); } } Frame::NewToken { token } => { if self.side.is_server() { return Err(TransportError::PROTOCOL_VIOLATION("client sent NEW_TOKEN")); } if token.is_empty() { return Err(TransportError::FRAME_ENCODING_ERROR("empty token")); } trace!("got new token"); // TODO: Cache, or perhaps forward to user? } Frame::Datagram(datagram) => { if self .datagrams .received(datagram, &self.config.datagram_receive_buffer_size)? { self.events.push_back(Event::DatagramReceived); } } Frame::HandshakeDone => { if self.side.is_server() { return Err(TransportError::PROTOCOL_VIOLATION( "client sent HANDSHAKE_DONE", )); } if self.spaces[SpaceId::Handshake].crypto.is_some() { self.discard_space(now, SpaceId::Handshake); } } } } self.spaces[SpaceId::Data] .pending_acks .packet_received(ack_eliciting); // Issue stream ID credit due to ACKs of outgoing finish/resets and incoming finish/resets // on stopped streams. Incoming finishes/resets on open streams are not handled here as they // are only freed, and hence only issue credit, once the application has been notified // during a read on the stream. let pending = &mut self.spaces[SpaceId::Data].pending; for dir in Dir::iter() { if self.streams.take_max_streams_dirty(dir) { pending.max_stream_id[dir as usize] = true; } } if let Some(reason) = close { self.error = Some(reason.into()); self.state = State::Draining; self.close = true; } if remote != self.path.remote && !is_probing_packet && number == self.spaces[SpaceId::Data].rx_packet { debug_assert!( self.server_config .as_ref() .expect("packets from unknown remote should be dropped by clients") .migration, "migration-initiating packets should have been dropped immediately" ); self.migrate(now, remote); // Break linkability, if possible self.update_rem_cid(); self.spin = false; } Ok(()) } fn migrate(&mut self, now: Instant, remote: SocketAddr) { trace!(%remote, "migration initiated"); // Reset rtt/congestion state for new path unless it looks like a NAT rebinding. // Note that the congestion window will not grow until validation terminates. Helps mitigate // amplification attacks performed by spoofing source addresses. let mut new_path = if remote.is_ipv4() && remote.ip() == self.path.remote.ip() { PathData::from_previous(remote, &self.path, now) } else { let peer_max_udp_payload_size = u16::try_from(self.peer_params.max_udp_payload_size.into_inner()) .unwrap_or(u16::MAX); PathData::new( remote, self.config.initial_rtt, self.config .congestion_controller_factory .build(now, self.config.get_initial_mtu()), self.config.get_initial_mtu(), self.config.min_mtu, Some(peer_max_udp_payload_size), self.config.mtu_discovery_config.clone(), now, false, ) }; new_path.challenge = Some(self.rng.gen()); new_path.challenge_pending = true; let prev_pto = self.pto(SpaceId::Data); let mut prev = mem::replace(&mut self.path, new_path); // Don't clobber the original path if the previous one hasn't been validated yet if prev.challenge.is_none() { prev.challenge = Some(self.rng.gen()); prev.challenge_pending = true; self.prev_path = Some(prev); } self.timers.set( Timer::PathValidation, now + 3 * cmp::max(self.pto(SpaceId::Data), prev_pto), ); } /// Switch to a previously unused remote connection ID, if possible fn update_rem_cid(&mut self) { let (reset_token, retired) = match self.rem_cids.next() { Some(x) => x, None => return, }; // Retire the current remote CID and any CIDs we had to skip. 
self.spaces[SpaceId::Data] .pending .retire_cids .extend(retired); self.set_reset_token(reset_token); } fn set_reset_token(&mut self, reset_token: ResetToken) { self.endpoint_events .push_back(EndpointEventInner::ResetToken( self.path.remote, reset_token, )); self.peer_params.stateless_reset_token = Some(reset_token); } /// Issue an initial set of connection IDs to the peer fn issue_cids(&mut self, now: Instant) { if self.local_cid_state.cid_len() == 0 { return; } // Subtract 1 to account for the CID we supplied while handshaking let n = self.peer_params.issue_cids_limit() - 1; self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, n)); } fn populate_packet( &mut self, space_id: SpaceId, buf: &mut BytesMut, max_size: usize, ) -> SentFrames { let mut sent = SentFrames::default(); let space = &mut self.spaces[space_id]; let is_0rtt = space_id == SpaceId::Data && space.crypto.is_none(); // HANDSHAKE_DONE if !is_0rtt && mem::replace(&mut space.pending.handshake_done, false) { buf.write(frame::Type::HANDSHAKE_DONE); sent.retransmits.get_or_create().handshake_done = true; // This is just a u8 counter and the frame is typically just sent once self.stats.frame_tx.handshake_done = self.stats.frame_tx.handshake_done.saturating_add(1); } // PING if mem::replace(&mut space.ping_pending, false) { trace!("PING"); buf.write(frame::Type::PING); sent.non_retransmits = true; self.stats.frame_tx.ping += 1; } // ACK if space.pending_acks.can_send() { debug_assert!(!space.pending_acks.ranges().is_empty()); Self::populate_acks(self.receiving_ecn, &mut sent, space, buf, &mut self.stats); } // PATH_CHALLENGE if buf.len() + 9 < max_size && space_id == SpaceId::Data { // Transmit challenges with every outgoing frame on an unvalidated path if let Some(token) = self.path.challenge { // But only send a packet solely for that purpose at most once self.path.challenge_pending = false; sent.non_retransmits = true; sent.requires_padding = true; trace!("PATH_CHALLENGE {:08x}", token); buf.write(frame::Type::PATH_CHALLENGE); buf.write(token); self.stats.frame_tx.path_challenge += 1; } } // PATH_RESPONSE if buf.len() + 9 < max_size && space_id == SpaceId::Data { if let Some(response) = self.path_response.take() { sent.non_retransmits = true; sent.requires_padding = true; trace!("PATH_RESPONSE {:08x}", response.token); buf.write(frame::Type::PATH_RESPONSE); buf.write(response.token); self.stats.frame_tx.path_response += 1; } } // CRYPTO while buf.len() + frame::Crypto::SIZE_BOUND < max_size && !is_0rtt { let mut frame = match space.pending.crypto.pop_front() { Some(x) => x, None => break, }; // Calculate the maximum amount of crypto data we can store in the buffer. // Since the offset is known, we can reserve the exact size required to encode it. // For the length we reserve 2 bytes, which allows encoding lengths up to 2^14 - 1, // more than fits into a normally sized QUIC frame.
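// Worked example with illustrative numbers (not from the original source): for
// max_size = 1200, buf.len() = 40, and offset = 70_000 (a 4-byte varint), the
// computation below yields 1200 - 40 - 1 - 4 - 2 = 1153 bytes of CRYPTO payload.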
let max_crypto_data_size = max_size - buf.len() - 1 // Frame Type - VarInt::size(unsafe { VarInt::from_u64_unchecked(frame.offset) }) - 2; // Maximum encoded length for frame size, given we send less than 2^14 bytes let len = frame .data .len() .min(2usize.pow(14) - 1) .min(max_crypto_data_size); let data = frame.data.split_to(len); let truncated = frame::Crypto { offset: frame.offset, data, }; trace!( "CRYPTO: off {} len {}", truncated.offset, truncated.data.len() ); truncated.encode(buf); self.stats.frame_tx.crypto += 1; sent.retransmits.get_or_create().crypto.push_back(truncated); if !frame.data.is_empty() { frame.offset += len as u64; space.pending.crypto.push_front(frame); } } if space_id == SpaceId::Data { self.streams.write_control_frames( buf, &mut space.pending, &mut sent.retransmits, &mut self.stats.frame_tx, max_size, ); } // NEW_CONNECTION_ID while buf.len() + 44 < max_size { let issued = match space.pending.new_cids.pop() { Some(x) => x, None => break, }; trace!( sequence = issued.sequence, id = %issued.id, "NEW_CONNECTION_ID" ); frame::NewConnectionId { sequence: issued.sequence, retire_prior_to: self.local_cid_state.retire_prior_to(), id: issued.id, reset_token: issued.reset_token, } .encode(buf); sent.retransmits.get_or_create().new_cids.push(issued); self.stats.frame_tx.new_connection_id += 1; } // RETIRE_CONNECTION_ID while buf.len() + frame::RETIRE_CONNECTION_ID_SIZE_BOUND < max_size { let seq = match space.pending.retire_cids.pop() { Some(x) => x, None => break, }; trace!(sequence = seq, "RETIRE_CONNECTION_ID"); buf.write(frame::Type::RETIRE_CONNECTION_ID); buf.write_var(seq); sent.retransmits.get_or_create().retire_cids.push(seq); self.stats.frame_tx.retire_connection_id += 1; } // DATAGRAM while buf.len() + Datagram::SIZE_BOUND < max_size && space_id == SpaceId::Data { match self.datagrams.write(buf, max_size) { true => { sent.non_retransmits = true; self.stats.frame_tx.datagram += 1; } false => break, } } // STREAM if space_id == SpaceId::Data { sent.stream_frames = self.streams.write_stream_frames(buf, max_size); self.stats.frame_tx.stream += sent.stream_frames.len() as u64; } sent } /// Write pending ACKs into a buffer /// /// This method assumes ACKs are pending, and should only be called if /// `!PendingAcks::ranges().is_empty()` returns `true`. 
fn populate_acks( receiving_ecn: bool, sent: &mut SentFrames, space: &mut PacketSpace, buf: &mut BytesMut, stats: &mut ConnectionStats, ) { debug_assert!(!space.pending_acks.ranges().is_empty()); // 0-RTT packets must never carry acks (which would have to be of handshake packets) debug_assert!(space.crypto.is_some(), "tried to send ACK in 0-RTT"); let ecn = if receiving_ecn { Some(&space.ecn_counters) } else { None }; sent.largest_acked = space.pending_acks.ranges().max(); let delay_micros = space.pending_acks.ack_delay().as_micros() as u64; // TODO: This should come from `TransportConfig` if that gets configurable let ack_delay_exp = TransportParameters::default().ack_delay_exponent; let delay = delay_micros >> ack_delay_exp.into_inner(); trace!("ACK {:?}, Delay = {}us", space.pending_acks.ranges(), delay); frame::Ack::encode(delay as _, space.pending_acks.ranges(), ecn, buf); stats.frame_tx.acks += 1; } fn close_common(&mut self) { trace!("connection closed"); for &timer in &Timer::VALUES { self.timers.stop(timer); } } fn set_close_timer(&mut self, now: Instant) { self.timers .set(Timer::Close, now + 3 * self.pto(self.highest_space)); } /// Handle transport parameters received from the peer fn handle_peer_params(&mut self, params: TransportParameters) -> Result<(), TransportError> { if Some(self.orig_rem_cid) != params.initial_src_cid || (self.side.is_client() && (Some(self.initial_dst_cid) != params.original_dst_cid || self.retry_src_cid != params.retry_src_cid)) { return Err(TransportError::TRANSPORT_PARAMETER_ERROR( "CID authentication failure", )); } self.set_peer_params(params); Ok(()) } fn set_peer_params(&mut self, params: TransportParameters) { self.streams.set_params(&params); self.idle_timeout = match (self.config.max_idle_timeout, params.max_idle_timeout) { (None, VarInt(0)) => None, (None, x) => Some(x), (Some(x), VarInt(0)) => Some(x), (Some(x), y) => Some(cmp::min(x, y)), }; if let Some(ref info) = params.preferred_address { self.rem_cids.insert(frame::NewConnectionId { sequence: 1, id: info.connection_id, reset_token: info.stateless_reset_token, retire_prior_to: 0, }).expect("preferred address CID is the first received, and hence is guaranteed to be legal"); } self.peer_params = params; self.path.mtud.on_peer_max_udp_payload_size_received( u16::try_from(self.peer_params.max_udp_payload_size.into_inner()).unwrap_or(u16::MAX), ); } fn decrypt_packet( &mut self, now: Instant, packet: &mut Packet, ) -> Result<Option<u64>, Option<TransportError>> { if !packet.header.is_protected() { // Unprotected packets also don't have packet numbers return Ok(None); } let space = packet.header.space(); let rx_packet = self.spaces[space].rx_packet; let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1); let key_phase = packet.header.key_phase(); let mut crypto_update = false; let crypto = if packet.header.is_0rtt() { &self.zero_rtt_crypto.as_ref().unwrap().packet } else if key_phase == self.key_phase || space != SpaceId::Data { &self.spaces[space].crypto.as_mut().unwrap().packet.remote } else if let Some(prev) = self.prev_crypto.as_ref().and_then(|crypto| { // If this packet comes prior to acknowledgment of the key update by the peer, if crypto.end_packet.map_or(true, |(pn, _)| number < pn) { // use the previous keys. Some(crypto) } else { // Otherwise, this must be a remotely-initiated key update, so fall through to the // final case.
None } }) { &prev.crypto.remote } else { // We're in the Data space with a key phase mismatch and either there is no locally // initiated key update or the locally initiated key update was acknowledged by a // lower-numbered packet. The key phase mismatch must therefore represent a new // remotely-initiated key update. crypto_update = true; &self.next_crypto.as_ref().unwrap().remote }; crypto .decrypt(number, &packet.header_data, &mut packet.payload) .map_err(|_| { trace!("decryption failed with packet number {}", number); None })?; if let Some(ref mut prev) = self.prev_crypto { if prev.end_packet.is_none() && key_phase == self.key_phase { // Outgoing key update newly acknowledged prev.end_packet = Some((number, now)); self.set_key_discard_timer(now, space); } } if !packet.reserved_bits_valid() { return Err(Some(TransportError::PROTOCOL_VIOLATION( "reserved bits set", ))); } if crypto_update { // Validate and commit incoming key update if number <= rx_packet || self .prev_crypto .as_ref() .map_or(false, |x| x.update_unacked) { return Err(Some(TransportError::KEY_UPDATE_ERROR(""))); } trace!("key update authenticated"); self.update_keys(Some((number, now)), true); self.set_key_discard_timer(now, space); } Ok(Some(number)) } fn update_keys(&mut self, end_packet: Option<(u64, Instant)>, remote: bool) { // Generate keys for the key phase after the one we're switching to, store them in // `next_crypto`, make the contents of `next_crypto` current, and move the current keys into // `prev_crypto`. let new = self .crypto .next_1rtt_keys() .expect("only called for `Data` packets"); let old = mem::replace( &mut self.spaces[SpaceId::Data] .crypto .as_mut() .unwrap() // safe because update_keys() can only be triggered by short packets .packet, mem::replace(self.next_crypto.as_mut().unwrap(), new), ); self.spaces[SpaceId::Data].sent_with_keys = 0; self.prev_crypto = Some(PrevCrypto { crypto: old, end_packet, update_unacked: remote, }); self.key_phase = !self.key_phase; } /// The number of bytes of packets containing retransmittable frames that have not been /// acknowledged or declared lost. #[cfg(test)] pub(crate) fn bytes_in_flight(&self) -> u64 { self.in_flight.bytes } /// Number of bytes worth of non-ack-only packets that may be sent #[cfg(test)] pub(crate) fn congestion_window(&self) -> u64 { self.path .congestion .window() .saturating_sub(self.in_flight.bytes) } /// Whether no timers except keep-alive, idle, and PushNewCid are running #[cfg(test)] pub(crate) fn is_idle(&self) -> bool { Timer::VALUES .iter() .filter(|&&t| t != Timer::KeepAlive && t != Timer::PushNewCid) .filter_map(|&t| Some((t, self.timers.get(t)?))) .min_by_key(|&(_, time)| time) .map_or(true, |(timer, _)| timer == Timer::Idle) } /// Total number of outgoing packets that have been deemed lost #[cfg(test)] pub(crate) fn lost_packets(&self) -> u64 { self.lost_packets } /// Whether explicit congestion notification is in use on outgoing packets.
#[cfg(test)] pub(crate) fn using_ecn(&self) -> bool { self.path.sending_ecn } /// The number of received bytes in the current path #[cfg(test)] pub(crate) fn total_recvd(&self) -> u64 { self.path.total_recvd } #[cfg(test)] pub(crate) fn active_local_cid_seq(&self) -> (u64, u64) { self.local_cid_state.active_seq() } /// Instruct the peer to replace previously issued CIDs by sending a NEW_CONNECTION_ID frame /// with updated `retire_prior_to` field set to `v` #[cfg(test)] pub(crate) fn rotate_local_cid(&mut self, v: u64, now: Instant) { let n = self.local_cid_state.assign_retire_seq(v); self.endpoint_events .push_back(EndpointEventInner::NeedIdentifiers(now, n)); } /// Check the current active remote CID sequence #[cfg(test)] pub(crate) fn active_rem_cid_seq(&self) -> u64 { self.rem_cids.active_seq() } /// Returns the detected maximum udp payload size for the current path #[cfg(test)] pub(crate) fn path_mtu(&self) -> u16 { self.path.current_mtu() } fn max_ack_delay(&self) -> Duration { Duration::from_micros(self.peer_params.max_ack_delay.0 * 1000) } /// Whether we have 1-RTT data to send /// /// See also `self.space(SpaceId::Data).can_send()` fn can_send_1rtt(&self) -> bool { self.streams.can_send_stream_data() || self.path.challenge_pending || self .prev_path .as_ref() .map_or(false, |x| x.challenge_pending) || self.path_response.is_some() || !self.datagrams.outgoing.is_empty() } /// Update counters to account for a packet becoming acknowledged, lost, or abandoned fn remove_in_flight(&mut self, space: SpaceId, packet: &SentPacket) { self.in_flight.bytes -= u64::from(packet.size); self.in_flight.ack_eliciting -= u64::from(packet.ack_eliciting); self.spaces[space].in_flight -= u64::from(packet.size); } /// Terminate the connection instantly, without sending a close packet fn kill(&mut self, reason: ConnectionError) { self.close_common(); self.error = Some(reason); self.state = State::Drained; self.endpoint_events.push_back(EndpointEventInner::Drained); } } impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Connection") .field("handshake_cid", &self.handshake_cid) .finish() } } /// Reasons why a connection might be lost #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum ConnectionError { /// The peer doesn't implement any supported version #[error("peer doesn't implement any supported version")] VersionMismatch, /// The peer violated the QUIC specification as understood by this implementation #[error(transparent)] TransportError(#[from] TransportError), /// The peer's QUIC stack aborted the connection automatically #[error("aborted by peer: {0}")] ConnectionClosed(frame::ConnectionClose), /// The peer closed the connection #[error("closed by peer: {0}")] ApplicationClosed(frame::ApplicationClose), /// The peer is unable to continue processing this connection, usually due to having restarted #[error("reset by peer")] Reset, /// Communication with the peer has lapsed for longer than the negotiated idle timeout /// /// If neither side is sending keep-alives, a connection will time out after a long enough idle /// period even if the peer is still reachable. See also [`TransportConfig::max_idle_timeout()`] /// and [`TransportConfig::keep_alive_interval()`]. 
#[error("timed out")] TimedOut, /// The local application closed the connection #[error("closed")] LocallyClosed, } impl From for ConnectionError { fn from(x: Close) -> Self { match x { Close::Connection(reason) => Self::ConnectionClosed(reason), Close::Application(reason) => Self::ApplicationClosed(reason), } } } // For compatibility with API consumers impl From for io::Error { fn from(x: ConnectionError) -> Self { use self::ConnectionError::*; let kind = match x { TimedOut => io::ErrorKind::TimedOut, Reset => io::ErrorKind::ConnectionReset, ApplicationClosed(_) | ConnectionClosed(_) => io::ErrorKind::ConnectionAborted, TransportError(_) | VersionMismatch | LocallyClosed => io::ErrorKind::Other, }; Self::new(kind, x) } } #[allow(unreachable_pub)] // fuzzing only #[derive(Clone)] pub enum State { Handshake(state::Handshake), Established, Closed(state::Closed), Draining, /// Waiting for application to call close so we can dispose of the resources Drained, } impl State { fn closed>(reason: R) -> Self { Self::Closed(state::Closed { reason: reason.into(), }) } fn is_handshake(&self) -> bool { matches!(*self, Self::Handshake(_)) } fn is_established(&self) -> bool { matches!(*self, Self::Established) } fn is_closed(&self) -> bool { matches!(*self, Self::Closed(_) | Self::Draining | Self::Drained) } fn is_drained(&self) -> bool { matches!(*self, Self::Drained) } } mod state { use super::*; #[allow(unreachable_pub)] // fuzzing only #[derive(Clone)] pub struct Handshake { /// Whether the remote CID has been set by the peer yet /// /// Always set for servers pub(super) rem_cid_set: bool, /// Stateless retry token received in the first Initial by a server. /// /// Must be present in every Initial. Always empty for clients. pub(super) expected_token: Bytes, /// First cryptographic message /// /// Only set for clients pub(super) client_hello: Option, } #[allow(unreachable_pub)] // fuzzing only #[derive(Clone)] pub struct Closed { pub(super) reason: Close, } } struct PrevCrypto { /// The keys used for the previous key phase, temporarily retained to decrypt packets sent by /// the peer prior to its own key update. crypto: KeyPair>, /// The incoming packet that ends the interval for which these keys are applicable, and the time /// of its receipt. /// /// Incoming packets should be decrypted using these keys iff this is `None` or their packet /// number is lower. `None` indicates that we have not yet received a packet using newer keys, /// which implies that the update was locally initiated. end_packet: Option<(u64, Instant)>, /// Whether the following key phase is from a remotely initiated update that we haven't acked update_unacked: bool, } struct InFlight { /// Sum of the sizes of all sent packets considered "in flight" by congestion control /// /// The size does not include IP or UDP overhead. Packets only containing ACK frames do not /// count towards this to ensure congestion control does not impede congestion feedback. bytes: u64, /// Number of packets in flight containing frames other than ACK and PADDING /// /// This can be 0 even when bytes is not 0 because PADDING frames cause a packet to be /// considered "in flight" by congestion control. However, if this is nonzero, bytes will always /// also be nonzero. 
ack_eliciting: u64, } impl InFlight { fn new() -> Self { Self { bytes: 0, ack_eliciting: 0, } } fn insert(&mut self, packet: &SentPacket) { self.bytes += u64::from(packet.size); self.ack_eliciting += u64::from(packet.ack_eliciting); } } /// Events of interest to the application #[derive(Debug)] pub enum Event { /// The connection's handshake data is ready HandshakeDataReady, /// The connection was successfully established Connected, /// The connection was lost /// /// Emitted if the peer closes the connection or an error is encountered. ConnectionLost { /// Reason that the connection was closed reason: ConnectionError, }, /// Stream events Stream(StreamEvent), /// One or more application datagrams have been received DatagramReceived, } struct PathResponse { /// The packet number the corresponding PATH_CHALLENGE was received in packet: u64, token: u64, } fn instant_saturating_sub(x: Instant, y: Instant) -> Duration { if x > y { x - y } else { Duration::new(0, 0) } } // Prevents overflow and improves behavior in extreme circumstances const MAX_BACKOFF_EXPONENT: u32 = 16; // Minimal remaining size to allow packet coalescing const MIN_PACKET_SPACE: usize = 40; /// The maximum number of datagrams that are sent in a single transmit /// /// This can be lower than the maximum platform capabilities, to avoid excessive /// memory allocations when calling `poll_transmit()`. Benchmarks have shown /// that numbers around 10 are a good compromise. const MAX_TRANSMIT_SEGMENTS: usize = 10; struct ZeroRttCrypto { header: Box<dyn HeaderKey>, packet: Box<dyn PacketKey>, } #[derive(Default)] struct SentFrames { retransmits: ThinRetransmits, largest_acked: Option<u64>, stream_frames: StreamMetaVec, /// Whether the packet contains non-retransmittable frames (like datagrams) non_retransmits: bool, requires_padding: bool, } impl SentFrames { /// Returns whether the packet contains only ACKs fn is_ack_only(&self, streams: &StreamsState) -> bool { self.largest_acked.is_some() && !self.non_retransmits && self.stream_frames.is_empty() && self.retransmits.is_empty(streams) } } quinn-proto-0.10.6/src/connection/mtud.rs000064400000000000000000000707321046102023000164570ustar 00000000000000use crate::{packet::SpaceId, MtuDiscoveryConfig, MAX_UDP_PAYLOAD}; use std::time::Instant; use tracing::trace; /// Implements Datagram Packetization Layer Path Maximum Transmission Unit Discovery /// /// See [`MtuDiscoveryConfig`] for details #[derive(Clone)] pub(crate) struct MtuDiscovery { /// Detected MTU for the path current_mtu: u16, /// The state of the MTU discovery, if enabled state: Option<EnabledMtuDiscovery>, /// The state of the black hole detector black_hole_detector: BlackHoleDetector, } impl MtuDiscovery { pub(crate) fn new( initial_plpmtu: u16, min_mtu: u16, peer_max_udp_payload_size: Option<u16>, config: MtuDiscoveryConfig, ) -> Self { debug_assert!( initial_plpmtu >= min_mtu, "initial_max_udp_payload_size must be at least {min_mtu}" ); let mut mtud = Self::with_state( initial_plpmtu, min_mtu, Some(EnabledMtuDiscovery::new(config)), ); // We might be migrating an existing connection to a new path, in which case the transport // parameters have already been transmitted, and we already know the value of // `peer_max_udp_payload_size` if let Some(peer_max_udp_payload_size) = peer_max_udp_payload_size { mtud.on_peer_max_udp_payload_size_received(peer_max_udp_payload_size); } mtud } /// MTU discovery will be disabled and the current MTU will be fixed to the provided value pub(crate) fn disabled(plpmtu: u16, min_mtu: u16) -> Self { Self::with_state(plpmtu, min_mtu, None) } fn with_state(current_mtu: u16, min_mtu: u16, state: Option<EnabledMtuDiscovery>) -> Self { Self { current_mtu, state, black_hole_detector: BlackHoleDetector::new(min_mtu), } } /// Returns the current MTU pub(crate) fn current_mtu(&self) -> u16 { self.current_mtu } /// Returns the number of bytes that should be sent as an MTU probe, if any pub(crate) fn poll_transmit(&mut self, now: Instant, next_packet_number: u64) -> Option<u16> { self.state .as_mut() .and_then(|state| state.poll_transmit(now, self.current_mtu, next_packet_number)) } /// Notifies the [`MtuDiscovery`] that the peer's `max_udp_payload_size` transport parameter has /// been received pub(crate) fn on_peer_max_udp_payload_size_received(&mut self, peer_max_udp_payload_size: u16) { self.current_mtu = self.current_mtu.min(peer_max_udp_payload_size); if let Some(state) = self.state.as_mut() { // MTUD is only active after the connection has been fully established, so it is // guaranteed we will receive the peer's transport parameters before we start probing debug_assert!(matches!(state.phase, Phase::Initial)); state.peer_max_udp_payload_size = peer_max_udp_payload_size; } } /// Notifies the [`MtuDiscovery`] that a packet has been ACKed /// /// Returns true if the packet was an MTU probe pub(crate) fn on_acked( &mut self, space: SpaceId, packet_number: u64, packet_bytes: u16, ) -> bool { // MTU probes are only sent in application data space if space != SpaceId::Data { return false; } // Update the state of the MTU search if let Some(new_mtu) = self .state .as_mut() .and_then(|state| state.on_probe_acked(packet_number)) { self.current_mtu = new_mtu; trace!(current_mtu = self.current_mtu, "new MTU detected"); self.black_hole_detector.on_probe_acked(); true } else { self.black_hole_detector.on_non_probe_acked( self.current_mtu, packet_number, packet_bytes, ); false } } /// Returns the packet number of the in-flight MTU probe, if any pub(crate) fn in_flight_mtu_probe(&self) -> Option<u64> { match &self.state { Some(EnabledMtuDiscovery { phase: Phase::Searching(search_state), .. }) => search_state.in_flight_probe, _ => None, } } /// Notifies the [`MtuDiscovery`] that the in-flight MTU probe was lost pub(crate) fn on_probe_lost(&mut self) { if let Some(state) = &mut self.state { state.on_probe_lost(); } } /// Notifies the [`MtuDiscovery`] that a non-probe packet was lost /// /// When done notifying of lost packets, [`MtuDiscovery::black_hole_detected`] must be called, to /// ensure the last loss burst is properly processed and to trigger black hole recovery logic if /// necessary. pub(crate) fn on_non_probe_lost(&mut self, packet_number: u64, packet_bytes: u16) { self.black_hole_detector .on_non_probe_lost(packet_number, packet_bytes); } /// Returns true if a black hole was detected /// /// Calling this function will close the previous loss burst. If a black hole is detected, the /// current MTU will be reset to `min_mtu`. pub(crate) fn black_hole_detected(&mut self, now: Instant) -> bool { if !self.black_hole_detector.black_hole_detected() { return false; } self.current_mtu = self.black_hole_detector.min_mtu; if let Some(state) = &mut self.state { state.on_black_hole_detected(now); } true } } /// Additional state for enabled MTU discovery #[derive(Debug, Clone)] struct EnabledMtuDiscovery { phase: Phase, peer_max_udp_payload_size: u16, config: MtuDiscoveryConfig, } impl EnabledMtuDiscovery { fn new(config: MtuDiscoveryConfig) -> Self { Self { phase: Phase::Initial, peer_max_udp_payload_size: MAX_UDP_PAYLOAD, config, } } /// Returns the number of bytes that should be sent as an MTU probe, if any fn poll_transmit( &mut self, now: Instant, current_mtu: u16, next_packet_number: u64, ) -> Option<u16> { if let Phase::Initial = &self.phase { // Start the first search self.phase = Phase::Searching(SearchState::new( current_mtu, self.peer_max_udp_payload_size, &self.config, )); } else if let Phase::Complete(next_mtud_activation) = &self.phase { if now < *next_mtud_activation { return None; } // Start a new search (we have reached the next activation time) self.phase = Phase::Searching(SearchState::new( current_mtu, self.peer_max_udp_payload_size, &self.config, )); } if let Phase::Searching(state) = &mut self.phase { // Nothing to do while there is a probe in flight if state.in_flight_probe.is_some() { return None; } // Retransmit lost probes, if any if 0 < state.lost_probe_count && state.lost_probe_count < MAX_PROBE_RETRANSMITS { state.in_flight_probe = Some(next_packet_number); return Some(state.last_probed_mtu); } let last_probe_succeeded = state.lost_probe_count == 0; // The probe is definitely lost (we reached the MAX_PROBE_RETRANSMITS threshold) if !last_probe_succeeded { state.lost_probe_count = 0; state.in_flight_probe = None; } if let Some(probe_udp_payload_size) = state.next_mtu_to_probe(last_probe_succeeded) { state.in_flight_probe = Some(next_packet_number); state.last_probed_mtu = probe_udp_payload_size; return Some(probe_udp_payload_size); } else { let next_mtud_activation = now + self.config.interval; self.phase = Phase::Complete(next_mtud_activation); return None; } } None } /// Called when a packet is acknowledged in [`SpaceId::Data`] /// /// Returns the new `current_mtu` if the packet number corresponds to the in-flight MTU probe fn on_probe_acked(&mut self, packet_number: u64) -> Option<u16> { match &mut self.phase { Phase::Searching(state) if state.in_flight_probe == Some(packet_number) => { state.in_flight_probe = None; state.lost_probe_count = 0; Some(state.last_probed_mtu) } _ => None, } } /// Called when the in-flight MTU probe was lost fn on_probe_lost(&mut self) { // We might no longer be searching, e.g. if a black hole was detected if let Phase::Searching(state) = &mut self.phase { state.in_flight_probe = None; state.lost_probe_count += 1; } } /// Called when a black hole is detected fn on_black_hole_detected(&mut self, now: Instant) { // Stop searching, if applicable, and reset the timer let next_mtud_activation = now + self.config.black_hole_cooldown; self.phase = Phase::Complete(next_mtud_activation); } } #[derive(Debug, Clone, Copy)] enum Phase { /// We haven't started probing yet Initial, /// We are currently searching for a higher PMTU Searching(SearchState), /// Searching has completed and will be triggered again at the provided instant Complete(Instant), } #[derive(Debug, Clone, Copy)] struct SearchState { /// The lower bound for the current binary search lower_bound: u16, /// The upper bound for the current binary search upper_bound: u16, /// The UDP payload size we last sent a probe for last_probed_mtu: u16, /// Packet number of an in-flight probe (if any) in_flight_probe: Option<u64>, /// Lost probes at the current probe size lost_probe_count: usize, } impl SearchState { /// Creates a new search state, with the specified lower bound (the upper bound is derived from /// the config and the peer's `max_udp_payload_size` transport parameter) fn new( mut lower_bound: u16, peer_max_udp_payload_size: u16, config: &MtuDiscoveryConfig, ) -> Self { lower_bound = lower_bound.min(peer_max_udp_payload_size); let upper_bound = config .upper_bound .clamp(lower_bound, peer_max_udp_payload_size); Self { in_flight_probe: None, lost_probe_count: 0, lower_bound, upper_bound, // During initialization, we consider the lower bound to have already been // successfully probed last_probed_mtu: lower_bound, } } /// Determines the next MTU to probe using binary search fn next_mtu_to_probe(&mut self, last_probe_succeeded: bool) -> Option<u16> { debug_assert_eq!(self.in_flight_probe, None); if last_probe_succeeded { self.lower_bound = self.last_probed_mtu; } else { self.upper_bound = self.last_probed_mtu - 1; } let next_mtu = (self.lower_bound as i32 + self.upper_bound as i32) / 2; // Binary search stopping condition if ((next_mtu - self.last_probed_mtu as i32).unsigned_abs() as u16) < BINARY_SEARCH_MINIMUM_CHANGE { // Special case: if the upper bound is far enough, we want to probe it as a last // step (otherwise we will never achieve the upper bound) if self.upper_bound.saturating_sub(self.last_probed_mtu) >= BINARY_SEARCH_MINIMUM_CHANGE { return Some(self.upper_bound); } return None; } Some(next_mtu as u16) } } #[derive(Clone)] struct BlackHoleDetector { /// Counts suspicious packet loss bursts since a packet with size equal to the current MTU was /// acknowledged (or since a black hole was detected) /// /// A packet loss burst is a group of contiguous packets that are deemed lost at the same time /// (see usages of [`MtuDiscovery::on_non_probe_lost`] for details on how loss detection is /// implemented) /// /// A packet loss burst is considered suspicious when it contains only suspicious packets and no /// MTU-sized packet has been acknowledged since the group's packets were sent suspicious_loss_bursts: u8, /// Indicates whether the current loss burst has any non-suspicious packets /// /// Non-suspicious packets are non-probe packets of size <= `min_mtu` loss_burst_has_non_suspicious_packets: bool, /// The largest suspicious packet that was lost in the current burst /// /// Suspicious packets are non-probe packets of size > `min_mtu` largest_suspicious_packet_lost: Option<u64>, /// The largest non-probe packet that was lost (used to keep track of loss bursts) largest_non_probe_lost: Option<u64>, /// The largest acked packet of size `current_mtu` largest_acked_mtu_sized_packet: Option<u64>, /// The UDP payload size guaranteed to be supported by the network min_mtu: u16, } impl BlackHoleDetector { fn new(min_mtu: u16) -> Self { Self { suspicious_loss_bursts: 0, largest_non_probe_lost: None, loss_burst_has_non_suspicious_packets: false, largest_suspicious_packet_lost: None, largest_acked_mtu_sized_packet: None, min_mtu, } } fn on_probe_acked(&mut self) { // We know for sure the path supports the current MTU self.suspicious_loss_bursts = 0; } fn on_non_probe_acked(&mut self, current_mtu: u16, packet_number: u64, packet_bytes: u16) { // Reset the black hole counter if a packet the size of the current MTU or larger // has been acknowledged if packet_bytes >= current_mtu && self .largest_acked_mtu_sized_packet .map_or(true, |pn| packet_number > pn) { self.suspicious_loss_bursts = 0; self.largest_acked_mtu_sized_packet = Some(packet_number); } } fn on_non_probe_lost(&mut self, packet_number: u64, packet_bytes: u16) { // A loss burst is a group of consecutive packets that are declared lost, so a distance // greater than 1 indicates a new burst let new_loss_burst = self .largest_non_probe_lost .map_or(true, |prev| packet_number - prev != 1); if new_loss_burst { self.finish_loss_burst(); } if packet_bytes <= self.min_mtu { self.loss_burst_has_non_suspicious_packets = true; } else { self.largest_suspicious_packet_lost = Some(packet_number); } self.largest_non_probe_lost = Some(packet_number); } fn black_hole_detected(&mut self) -> bool { self.finish_loss_burst(); if self.suspicious_loss_bursts <= BLACK_HOLE_THRESHOLD { return false; } self.suspicious_loss_bursts = 0; self.largest_acked_mtu_sized_packet = None; true } /// Marks the end of the current loss burst, checking whether it was suspicious fn finish_loss_burst(&mut self) { if self.last_burst_was_suspicious() { self.suspicious_loss_bursts = self.suspicious_loss_bursts.saturating_add(1); } self.loss_burst_has_non_suspicious_packets = false; self.largest_suspicious_packet_lost = None; self.largest_non_probe_lost = None; } /// Returns true if the burst was suspicious and should count towards black hole detection fn last_burst_was_suspicious(&self) -> bool { // Ignore burst if it contains any non-suspicious packets, because in that case packet loss // was likely caused by congestion (instead of a sudden decrease in the path's MTU) if self.loss_burst_has_non_suspicious_packets { return false; } // Ignore burst if we have received an ACK for a more recent MTU-sized packet, because that // proves the network still supports the current MTU let largest_acked = self.largest_acked_mtu_sized_packet.unwrap_or(0); if self .largest_suspicious_packet_lost .map_or(true, |largest_lost| largest_lost < largest_acked) { return false; } true } } // Corresponds to the RFC's `MAX_PROBES` constant (see // https://www.rfc-editor.org/rfc/rfc8899#section-5.1.2) const MAX_PROBE_RETRANSMITS: usize = 3; const BLACK_HOLE_THRESHOLD: u8 = 3; const BINARY_SEARCH_MINIMUM_CHANGE: u16 = 20; #[cfg(test)] mod tests { use super::*; use crate::packet::SpaceId; use crate::MAX_UDP_PAYLOAD; use assert_matches::assert_matches; use std::time::Duration; fn default_mtud() -> MtuDiscovery { let config = MtuDiscoveryConfig::default(); MtuDiscovery::new(1_200, 1_200, None, config) } fn completed(mtud: &MtuDiscovery) -> bool { matches!(mtud.state.as_ref().unwrap().phase, Phase::Complete(_)) } /// Drives mtud until it reaches `Phase::Complete` fn drive_to_completion( mtud: &mut MtuDiscovery, now: Instant, link_payload_size_limit: u16, ) -> Vec<u16> { let mut probed_sizes = Vec::new(); for probe_packet_number in 1..100 { let result = mtud.poll_transmit(now, probe_packet_number); if completed(mtud) { break; } // "Send" next probe assert!(result.is_some()); let probe_size = result.unwrap(); probed_sizes.push(probe_size); if probe_size <= link_payload_size_limit { mtud.on_acked(SpaceId::Data, probe_packet_number, probe_size); } else { mtud.on_probe_lost(); } } probed_sizes } #[test] fn black_hole_detector_ignores_burst_containing_non_suspicious_packet() { let mut mtud = default_mtud(); mtud.on_non_probe_lost(2, 1300); mtud.on_non_probe_lost(3, 1300); assert_eq!( mtud.black_hole_detector.largest_suspicious_packet_lost, Some(3) ); assert_eq!(mtud.black_hole_detector.suspicious_loss_bursts, 0); mtud.on_non_probe_lost(4, 800); assert!(!mtud.black_hole_detected(Instant::now())); assert_eq!( mtud.black_hole_detector.largest_suspicious_packet_lost, None ); assert_eq!(mtud.black_hole_detector.suspicious_loss_bursts, 0); } #[test] fn black_hole_detector_counts_burst_containing_only_suspicious_packets() { let mut mtud = default_mtud(); mtud.on_non_probe_lost(2, 1300); mtud.on_non_probe_lost(3, 1300); assert_eq!( mtud.black_hole_detector.largest_suspicious_packet_lost, Some(3) ); assert_eq!(mtud.black_hole_detector.suspicious_loss_bursts, 0); assert!(!mtud.black_hole_detected(Instant::now())); assert_eq!( mtud.black_hole_detector.largest_suspicious_packet_lost, None ); assert_eq!(mtud.black_hole_detector.suspicious_loss_bursts, 1); } #[test] fn black_hole_detector_ignores_empty_burst() { let mut mtud = default_mtud(); assert!(!mtud.black_hole_detected(Instant::now())); assert_eq!(mtud.black_hole_detector.suspicious_loss_bursts, 0); } #[test] fn mtu_discovery_disabled_does_nothing() { let mut mtud = MtuDiscovery::disabled(1_200, 1_200); let probe_size = mtud.poll_transmit(Instant::now(), 0); assert_eq!(probe_size, None); } #[test] fn mtu_discovery_disabled_lost_four_packet_bursts_triggers_black_hole_detection() { let mut mtud = MtuDiscovery::disabled(1_400, 1_250); let now = Instant::now(); for i in 0..4 { // The packets are never contiguous, so each one has its own burst mtud.on_non_probe_lost(i * 2, 1300); } assert!(mtud.black_hole_detected(now)); assert_eq!(mtud.current_mtu, 1250); assert_matches!(mtud.state, None); } #[test] fn mtu_discovery_lost_two_packet_bursts_does_not_trigger_black_hole_detection() { let mut mtud = default_mtud(); let now = Instant::now(); for i in 0..2 { mtud.on_non_probe_lost(i, 1300); assert!(!mtud.black_hole_detected(now)); } } #[test] fn mtu_discovery_lost_four_packet_bursts_triggers_black_hole_detection_and_resets_timer() { let mut mtud = default_mtud(); let now = Instant::now(); for i in 0..4 { // The packets are never contiguous, so each one has its own burst mtud.on_non_probe_lost(i * 2, 1300); } assert!(mtud.black_hole_detected(now)); assert_eq!(mtud.current_mtu, 1200); if let Phase::Complete(next_mtud_activation) = mtud.state.unwrap().phase { assert_eq!(next_mtud_activation, now + Duration::from_secs(60)); } else { panic!("Unexpected MTUD phase!"); } } #[test] fn mtu_discovery_after_complete_reactivates_when_interval_elapsed() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(9_000); let mut mtud = MtuDiscovery::new(1_200, 1_200, None, config); let now = Instant::now(); drive_to_completion(&mut mtud, now, 1_500); // Polling right after completion does not cause new
packets to be sent assert_eq!(mtud.poll_transmit(now, 42), None); assert!(completed(&mtud)); assert_eq!(mtud.current_mtu, 1_471); // Polling after the interval has passed does (taking the current mtu as lower bound) assert_eq!( mtud.poll_transmit(now + Duration::from_secs(600), 43), Some(5235) ); match mtud.state.unwrap().phase { Phase::Searching(state) => { assert_eq!(state.lower_bound, 1_471); assert_eq!(state.upper_bound, 9_000); } _ => { panic!("Unexpected MTUD phase!") } } } #[test] fn mtu_discovery_lost_three_probes_lowers_probe_size() { let mut mtud = default_mtud(); let mut probe_sizes = (0..4).map(|i| { let probe_size = mtud.poll_transmit(Instant::now(), i); assert!(probe_size.is_some(), "no probe returned for packet {i}"); mtud.on_probe_lost(); probe_size.unwrap() }); // After the first probe is lost, it gets retransmitted twice let first_probe_size = probe_sizes.next().unwrap(); for _ in 0..2 { assert_eq!(probe_sizes.next().unwrap(), first_probe_size) } // After the third probe is lost, we decrement our probe size let fourth_probe_size = probe_sizes.next().unwrap(); assert!(fourth_probe_size < first_probe_size); assert_eq!( fourth_probe_size, first_probe_size - (first_probe_size - 1_200) / 2 - 1 ); } #[test] fn mtu_discovery_with_peer_max_udp_payload_size_clamps_upper_bound() { let mut mtud = default_mtud(); mtud.on_peer_max_udp_payload_size_received(1300); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); assert_eq!(mtud.state.as_ref().unwrap().peer_max_udp_payload_size, 1300); assert_eq!(mtud.current_mtu, 1300); let expected_probed_sizes = &[1250, 1275, 1300]; assert_eq!(probed_sizes, expected_probed_sizes); assert!(completed(&mtud)); } #[test] fn mtu_discovery_with_previous_peer_max_udp_payload_size_clamps_upper_bound() { let mut mtud = MtuDiscovery::new(1500, 1_200, Some(1400), MtuDiscoveryConfig::default()); assert_eq!(mtud.current_mtu, 1400); assert_eq!(mtud.state.as_ref().unwrap().peer_max_udp_payload_size, 1400); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); assert_eq!(mtud.current_mtu, 1400); assert!(probed_sizes.is_empty()); assert!(completed(&mtud)); } #[test] #[should_panic] fn mtu_discovery_with_peer_max_udp_payload_size_after_search_panics() { let mut mtud = default_mtud(); drive_to_completion(&mut mtud, Instant::now(), 1500); mtud.on_peer_max_udp_payload_size_received(1300); } #[test] fn mtu_discovery_with_1500_limit() { let mut mtud = default_mtud(); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); let expected_probed_sizes = &[1326, 1389, 1420, 1452]; assert_eq!(probed_sizes, expected_probed_sizes); assert_eq!(mtud.current_mtu, 1452); assert!(completed(&mtud)); } #[test] fn mtu_discovery_with_1500_limit_and_10000_upper_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(10_000); let mut mtud = MtuDiscovery::new(1_200, 1_200, None, config); let probed_sizes = drive_to_completion(&mut mtud, Instant::now(), 1500); let expected_probed_sizes = &[ 5600, 5600, 5600, 3399, 3399, 3399, 2299, 2299, 2299, 1749, 1749, 1749, 1474, 1611, 1611, 1611, 1542, 1542, 1542, 1507, 1507, 1507, ]; assert_eq!(probed_sizes, expected_probed_sizes); assert_eq!(mtud.current_mtu, 1474); assert!(completed(&mtud)); } #[test] fn mtu_discovery_no_lost_probes_finds_maximum_udp_payload() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(MAX_UDP_PAYLOAD); let mut mtud = MtuDiscovery::new(1200, 1200, None, config); drive_to_completion(&mut mtud, Instant::now(), u16::MAX); 
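// 65527 is the largest possible UDP payload: 65_535 (the maximum UDP datagram
// length) minus the 8-byte UDP header.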
assert_eq!(mtud.current_mtu, 65527); assert!(completed(&mtud)); } #[test] fn mtu_discovery_lost_half_of_probes_finds_maximum_udp_payload() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(MAX_UDP_PAYLOAD); let mut mtud = MtuDiscovery::new(1200, 1200, None, config); let now = Instant::now(); let mut iterations = 0; for i in 1..100 { iterations += 1; let probe_packet_number = i * 2 - 1; let other_packet_number = i * 2; let result = mtud.poll_transmit(Instant::now(), probe_packet_number); if completed(&mtud) { break; } // "Send" next probe assert!(result.is_some()); assert!(mtud.in_flight_mtu_probe().is_some()); // Nothing else to send while the probe is in-flight assert_matches!(mtud.poll_transmit(now, other_packet_number), None); if i % 2 == 0 { // ACK probe and ensure it results in an increase of current_mtu let previous_max_size = mtud.current_mtu; mtud.on_acked(SpaceId::Data, probe_packet_number, result.unwrap()); println!( "ACK probe of {} bytes. Previous MTU = {previous_max_size}. New MTU = {}", result.unwrap(), mtud.current_mtu ); // assert!(mtud.current_mtu > previous_max_size); } else { mtud.on_probe_lost(); } } assert_eq!(iterations, 25); assert_eq!(mtud.current_mtu, 65527); assert!(completed(&mtud)); } #[test] fn search_state_lower_bound_higher_than_upper_bound_clamps_upper_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(1400); let state = SearchState::new(1500, u16::MAX, &config); assert_eq!(state.lower_bound, 1500); assert_eq!(state.upper_bound, 1500); } #[test] fn search_state_lower_bound_higher_than_peer_max_udp_payload_size_clamps_lower_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(9000); let state = SearchState::new(1500, 1300, &config); assert_eq!(state.lower_bound, 1300); assert_eq!(state.upper_bound, 1300); } #[test] fn search_state_upper_bound_higher_than_peer_max_udp_payload_size_clamps_upper_bound() { let mut config = MtuDiscoveryConfig::default(); config.upper_bound(9000); let state = SearchState::new(1200, 1450, &config); assert_eq!(state.lower_bound, 1200); assert_eq!(state.upper_bound, 1450); } } quinn-proto-0.10.6/src/connection/pacing.rs000064400000023500104610202300016740 0ustar 00000000000000//! Pacing of packet transmissions. use std::time::{Duration, Instant}; use tracing::warn; /// A simple token-bucket pacer /// /// The pacer's capacity is derived from a fraction of the congestion window /// which can be sent in regular intervals. /// Once the bucket is empty, further transmission is blocked. /// The bucket refills at a rate slightly faster /// than one congestion window per RTT, as recommended in /// the pacing section of the QUIC recovery specification (RFC 9002, Section 7.7). pub(super) struct Pacer { capacity: u64, last_window: u64, last_mtu: u16, tokens: u64, prev: Instant, } impl Pacer { /// Obtains a new [`Pacer`]. pub(super) fn new(smoothed_rtt: Duration, window: u64, mtu: u16, now: Instant) -> Self { let capacity = optimal_capacity(smoothed_rtt, window, mtu); Self { capacity, last_window: window, last_mtu: mtu, tokens: capacity, prev: now, } } /// Record that a packet has been transmitted. pub(super) fn on_transmit(&mut self, packet_length: u16) { self.tokens = self.tokens.saturating_sub(packet_length.into()) } /// Return how long we need to wait before sending `bytes_to_send` /// /// If we can send a packet right away, this returns `None`. Otherwise, returns `Some(d)`, /// where `d` is the time before this function should be called again. /// /// The 5/4 ratio used here comes from the suggestion that N = 1.25 in the draft IETF RFC for /// QUIC.
pub(super) fn delay( &mut self, smoothed_rtt: Duration, bytes_to_send: u64, mtu: u16, window: u64, now: Instant, ) -> Option<Instant> { debug_assert_ne!( window, 0, "zero-sized congestion control window is nonsense" ); if window != self.last_window || mtu != self.last_mtu { self.capacity = optimal_capacity(smoothed_rtt, window, mtu); // Clamp the tokens self.tokens = self.capacity.min(self.tokens); self.last_window = window; self.last_mtu = mtu; } // if we can already send a packet, there is no need for delay if self.tokens >= bytes_to_send { return None; } // we disable pacing for extremely large windows if window > u32::max_value().into() { return None; } let window = window as u32; let time_elapsed = now.checked_duration_since(self.prev).unwrap_or_else(|| { warn!("received a timestamp earlier than a previously recorded time, ignoring"); Default::default() }); if smoothed_rtt.as_nanos() == 0 { return None; } let elapsed_rtts = time_elapsed.as_secs_f64() / smoothed_rtt.as_secs_f64(); let new_tokens = window as f64 * 1.25 * elapsed_rtts; self.tokens = self .tokens .saturating_add(new_tokens as _) .min(self.capacity); self.prev = now; // if we can already send a packet, there is no need for delay if self.tokens >= bytes_to_send { return None; } let unscaled_delay = smoothed_rtt .checked_mul((bytes_to_send.max(self.capacity) - self.tokens) as _) .unwrap_or_else(|| Duration::new(u64::max_value(), 999_999_999)) / window; // divisions come before multiplications to prevent overflow // this is the time at which the pacing window becomes empty Some(self.prev + (unscaled_delay / 5) * 4) } } /// Calculates a pacer capacity for a certain window and RTT /// /// The goal is to emit a burst (of size `capacity`) in timer intervals /// which compromise between /// - ideally distributing datagrams over time /// - constantly waking up the connection to produce additional datagrams /// /// Too short burst intervals mean we will never meet them since the timer /// accuracy in user-space is not high enough. If we miss the interval by more /// than 25%, we will lose that part of the congestion window since no additional /// tokens for the extra-elapsed time can be stored. /// /// Too long burst intervals make pacing less effective. fn optimal_capacity(smoothed_rtt: Duration, window: u64, mtu: u16) -> u64 { let rtt = smoothed_rtt.as_nanos().max(1); let capacity = ((window as u128 * BURST_INTERVAL_NANOS) / rtt) as u64; // Small bursts are less efficient (no GSO), could increase latency and don't effectively // use the channel's buffer capacity. Large bursts might block the connection on sending. capacity.clamp(MIN_BURST_SIZE * mtu as u64, MAX_BURST_SIZE * mtu as u64) } /// The burst interval /// /// The capacity will be refilled in 4/5 of that time. /// 2ms is chosen here since framework timers might have 1ms precision. /// If kernel-level pacing is supported later, a longer interval might be /// more applicable. const BURST_INTERVAL_NANOS: u128 = 2_000_000; // 2ms /// Allows some usage of GSO, and doesn't slow down the handshake. const MIN_BURST_SIZE: u64 = 10; /// Creating 256 packets took 1ms in a benchmark, so larger bursts don't make sense.
const MAX_BURST_SIZE: u64 = 256; #[cfg(test)] mod tests { use super::*; #[test] fn does_not_panic_on_bad_instant() { let old_instant = Instant::now(); let new_instant = old_instant + Duration::from_micros(15); let rtt = Duration::from_micros(400); assert!(Pacer::new(rtt, 30000, 1500, new_instant) .delay(Duration::from_micros(0), 0, 1500, 1, old_instant) .is_none()); assert!(Pacer::new(rtt, 30000, 1500, new_instant) .delay(Duration::from_micros(0), 1600, 1500, 1, old_instant) .is_none()); assert!(Pacer::new(rtt, 30000, 1500, new_instant) .delay(Duration::from_micros(0), 1500, 1500, 3000, old_instant) .is_none()); } #[test] fn derives_initial_capacity() { let window = 2_000_000; let mtu = 1500; let rtt = Duration::from_millis(50); let now = Instant::now(); let pacer = Pacer::new(rtt, window, mtu, now); assert_eq!( pacer.capacity, (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, pacer.capacity); let pacer = Pacer::new(Duration::from_millis(0), window, mtu, now); assert_eq!(pacer.capacity, MAX_BURST_SIZE * mtu as u64); assert_eq!(pacer.tokens, pacer.capacity); let pacer = Pacer::new(rtt, 1, mtu, now); assert_eq!(pacer.capacity, MIN_BURST_SIZE * mtu as u64); assert_eq!(pacer.tokens, pacer.capacity); } #[test] fn adjusts_capacity() { let window = 2_000_000; let mtu = 1500; let rtt = Duration::from_millis(50); let now = Instant::now(); let mut pacer = Pacer::new(rtt, window, mtu, now); assert_eq!( pacer.capacity, (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, pacer.capacity); let initial_tokens = pacer.tokens; pacer.delay(rtt, mtu as u64, mtu, window * 2, now); assert_eq!( pacer.capacity, (2 * window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, initial_tokens); pacer.delay(rtt, mtu as u64, mtu, window / 2, now); assert_eq!( pacer.capacity, (window as u128 / 2 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, initial_tokens / 2); pacer.delay(rtt, mtu as u64, mtu * 2, window, now); assert_eq!( pacer.capacity, (window as u128 * BURST_INTERVAL_NANOS / rtt.as_nanos()) as u64 ); pacer.delay(rtt, mtu as u64, 20_000, window, now); assert_eq!(pacer.capacity, 20_000_u64 * MIN_BURST_SIZE); } #[test] fn computes_pause_correctly() { let window = 2_000_000u64; let mtu = 1000; let rtt = Duration::from_millis(50); let old_instant = Instant::now(); let mut pacer = Pacer::new(rtt, window, mtu, old_instant); let packet_capacity = pacer.capacity / mtu as u64; for _ in 0..packet_capacity { assert_eq!( pacer.delay(rtt, mtu as u64, mtu, window, old_instant), None, "When capacity is available packets should be sent immediately" ); pacer.on_transmit(mtu); } let pace_duration = Duration::from_nanos((BURST_INTERVAL_NANOS * 4 / 5) as u64); assert_eq!( pacer .delay(rtt, mtu as u64, mtu, window, old_instant) .expect("Send must be delayed") .duration_since(old_instant), pace_duration ); // Refill half of the tokens assert_eq!( pacer.delay( rtt, mtu as u64, mtu, window, old_instant + pace_duration / 2 ), None ); assert_eq!(pacer.tokens, pacer.capacity / 2); for _ in 0..packet_capacity / 2 { assert_eq!( pacer.delay(rtt, mtu as u64, mtu, window, old_instant), None, "When capacity is available packets should be sent immediately" ); pacer.on_transmit(mtu); } // Refill all capacity by waiting more than the expected duration assert_eq!( pacer.delay( rtt, mtu as u64, mtu, window, old_instant + pace_duration * 3 / 2 ), None ); assert_eq!(pacer.tokens, pacer.capacity); } } 
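// A worked example of the capacity formula above (an illustrative sketch, not part of the
// crate's test suite; the test name is invented): a 2 MB congestion window paced over a
// 50 ms smoothed RTT lets each 2 ms burst interval carry
// window * interval / rtt = 2_000_000 * 2 / 50 = 80_000 bytes, roughly 53 full 1500-byte
// datagrams, comfortably inside the clamp of MIN_BURST_SIZE * mtu = 15_000 to
// MAX_BURST_SIZE * mtu = 384_000 bytes.
#[cfg(test)]
#[test]
fn optimal_capacity_worked_example() {
    use std::time::Duration;
    let window: u64 = 2_000_000; // congestion window, in bytes
    let rtt = Duration::from_millis(50);
    let mtu: u16 = 1_500;
    // Mirror of the computation in `optimal_capacity` above
    let capacity = ((window as u128 * BURST_INTERVAL_NANOS) / rtt.as_nanos()) as u64;
    assert_eq!(capacity, 80_000);
    // Within the burst-size clamp, so no adjustment is applied
    assert!((MIN_BURST_SIZE * mtu as u64..=MAX_BURST_SIZE * mtu as u64).contains(&capacity));
}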
quinn-proto-0.10.6/src/connection/packet_builder.rs000064400000000000000000000222531046102023000204600ustar 00000000000000use std::time::Instant; use bytes::{Bytes, BytesMut}; use rand::Rng; use tracing::{trace, trace_span}; use super::{spaces::SentPacket, Connection, SentFrames}; use crate::{ frame::{self, Close}, packet::{Header, LongType, PacketNumber, PartialEncode, SpaceId, FIXED_BIT}, TransportError, TransportErrorCode, }; pub(super) struct PacketBuilder { pub(super) datagram_start: usize, pub(super) space: SpaceId, pub(super) partial_encode: PartialEncode, pub(super) ack_eliciting: bool, pub(super) exact_number: u64, pub(super) short_header: bool, pub(super) min_size: usize, pub(super) max_size: usize, pub(super) tag_len: usize, pub(super) span: tracing::Span, } impl PacketBuilder { /// Write a new packet header to `buffer` and determine the packet's properties /// /// Marks the connection drained and returns `None` if the confidentiality limit would be /// violated. pub(super) fn new( now: Instant, space_id: SpaceId, buffer: &mut BytesMut, buffer_capacity: usize, datagram_start: usize, ack_eliciting: bool, conn: &mut Connection, version: u32, ) -> Option { // Initiate key update if we're approaching the confidentiality limit let confidentiality_limit = conn.spaces[space_id] .crypto .as_ref() .map_or_else( || &conn.zero_rtt_crypto.as_ref().unwrap().packet, |keys| &keys.packet.local, ) .confidentiality_limit(); let sent_with_keys = conn.spaces[space_id].sent_with_keys; if space_id == SpaceId::Data { if sent_with_keys.saturating_add(KEY_UPDATE_MARGIN) >= confidentiality_limit { conn.initiate_key_update(); } } else if sent_with_keys.saturating_add(1) == confidentiality_limit { // We still have time to attempt a graceful close conn.close_inner( now, Close::Connection(frame::ConnectionClose { error_code: TransportErrorCode::AEAD_LIMIT_REACHED, frame_type: None, reason: Bytes::from_static(b"confidentiality limit reached"), }), ) } else if sent_with_keys > confidentiality_limit { // Confidentiality limited violated and there's nothing we can do conn.kill(TransportError::AEAD_LIMIT_REACHED("confidentiality limit reached").into()); return None; } let space = &mut conn.spaces[space_id]; space.loss_probes = space.loss_probes.saturating_sub(1); let exact_number = space.get_tx_number(); let span = trace_span!("send", space = ?space_id, pn = exact_number); span.with_subscriber(|(id, dispatch)| dispatch.enter(id)); let number = PacketNumber::new(exact_number, space.largest_acked_packet.unwrap_or(0)); let header = match space_id { SpaceId::Data if space.crypto.is_some() => Header::Short { dst_cid: conn.rem_cids.active(), number, spin: if conn.spin_enabled { conn.spin } else { conn.rng.gen() }, key_phase: conn.key_phase, }, SpaceId::Data => Header::Long { ty: LongType::ZeroRtt, src_cid: conn.handshake_cid, dst_cid: conn.rem_cids.active(), number, version, }, SpaceId::Handshake => Header::Long { ty: LongType::Handshake, src_cid: conn.handshake_cid, dst_cid: conn.rem_cids.active(), number, version, }, SpaceId::Initial => Header::Initial { src_cid: conn.handshake_cid, dst_cid: conn.rem_cids.active(), token: conn.retry_token.clone(), number, version, }, }; let partial_encode = header.encode(buffer); if conn.peer_params.grease_quic_bit && conn.rng.gen() { buffer[partial_encode.start] ^= FIXED_BIT; } let (sample_size, tag_len) = if let Some(ref crypto) = space.crypto { ( crypto.header.local.sample_size(), crypto.packet.local.tag_len(), ) } else if space_id == SpaceId::Data { let zero_rtt = 
conn.zero_rtt_crypto.as_ref().unwrap(); (zero_rtt.header.sample_size(), zero_rtt.packet.tag_len()) } else { unreachable!("tried to send {:?} packet without keys", space_id); }; // Each packet must be large enough for header protection sampling, i.e. the combined // lengths of the encoded packet number and protected payload must be at least 4 bytes // longer than the sample required for header protection. Further, each packet should be at // least tag_len + 6 bytes larger than the destination CID on incoming packets so that the // peer may send stateless resets that are indistinguishable from regular traffic. // pn_len + payload_len + tag_len >= sample_size + 4 // payload_len >= sample_size + 4 - pn_len - tag_len let min_size = Ord::max( buffer.len() + (sample_size + 4).saturating_sub(number.len() + tag_len), partial_encode.start + conn.rem_cids.active().len() + 6, ); let max_size = buffer_capacity - partial_encode.start - partial_encode.header_len - tag_len; Some(Self { datagram_start, space: space_id, partial_encode, exact_number, short_header: header.is_short(), min_size, max_size, span, tag_len, ack_eliciting, }) } pub(super) fn pad_to(&mut self, min_size: u16) { let prev = self.min_size; self.min_size = self.datagram_start + (min_size as usize) - self.tag_len; debug_assert!(self.min_size >= prev, "padding must not shrink datagram"); } pub(super) fn finish_and_track( self, now: Instant, conn: &mut Connection, sent: Option, buffer: &mut BytesMut, ) { let ack_eliciting = self.ack_eliciting; let exact_number = self.exact_number; let space_id = self.space; let (size, padded) = self.finish(conn, buffer); let sent = match sent { Some(sent) => sent, None => return, }; let size = match padded || ack_eliciting { true => size as u16, false => 0, }; let packet = SentPacket { largest_acked: sent.largest_acked, time_sent: now, size, ack_eliciting, retransmits: sent.retransmits, stream_frames: sent.stream_frames, }; conn.in_flight.insert(&packet); conn.spaces[space_id].sent(exact_number, packet); conn.stats.path.sent_packets += 1; conn.reset_keep_alive(now); if size != 0 { if ack_eliciting { conn.spaces[space_id].time_of_last_ack_eliciting_packet = Some(now); if conn.permit_idle_reset { conn.reset_idle_timeout(now, space_id); } conn.permit_idle_reset = false; } conn.set_loss_detection_timer(now); conn.path.pacing.on_transmit(size); } } /// Encrypt packet, returning the length of the packet and whether padding was added pub(super) fn finish(self, conn: &mut Connection, buffer: &mut BytesMut) -> (usize, bool) { let pad = buffer.len() < self.min_size; if pad { trace!("PADDING * {}", self.min_size - buffer.len()); buffer.resize(self.min_size, 0); } let space = &conn.spaces[self.space]; let (header_crypto, packet_crypto) = if let Some(ref crypto) = space.crypto { (&*crypto.header.local, &*crypto.packet.local) } else if self.space == SpaceId::Data { let zero_rtt = conn.zero_rtt_crypto.as_ref().unwrap(); (&*zero_rtt.header, &*zero_rtt.packet) } else { unreachable!("tried to send {:?} packet without keys", self.space); }; debug_assert_eq!( packet_crypto.tag_len(), self.tag_len, "Mismatching crypto tag len" ); buffer.resize(buffer.len() + packet_crypto.tag_len(), 0); let encode_start = self.partial_encode.start; let packet_buf = &mut buffer[encode_start..]; self.partial_encode.finish( packet_buf, header_crypto, Some((self.exact_number, packet_crypto)), ); self.span .with_subscriber(|(id, dispatch)| dispatch.exit(id)); (buffer.len() - encode_start, pad) } } /// Perform key updates this many packets before the 
AEAD confidentiality limit. /// /// Chosen arbitrarily, intended to be large enough to prevent spurious connection loss. const KEY_UPDATE_MARGIN: u64 = 10000; quinn-proto-0.10.6/src/connection/paths.rs000064400000000000000000000133531046102023000166230ustar 00000000000000use std::{cmp, net::SocketAddr, time::Duration, time::Instant}; use super::{mtud::MtuDiscovery, pacing::Pacer}; use crate::{config::MtuDiscoveryConfig, congestion, packet::SpaceId, TIMER_GRANULARITY}; /// Description of a particular network path pub(super) struct PathData { pub(super) remote: SocketAddr, pub(super) rtt: RttEstimator, /// Whether we're enabling ECN on outgoing packets pub(super) sending_ecn: bool, /// Congestion controller state pub(super) congestion: Box, /// Pacing state pub(super) pacing: Pacer, pub(super) challenge: Option, pub(super) challenge_pending: bool, /// Whether we're certain the peer can both send and receive on this address /// /// Initially equal to `use_stateless_retry` for servers, and becomes false again on every /// migration. Always true for clients. pub(super) validated: bool, /// Total size of all UDP datagrams sent on this path pub(super) total_sent: u64, /// Total size of all UDP datagrams received on this path pub(super) total_recvd: u64, /// The state of the MTU discovery process pub(super) mtud: MtuDiscovery, /// Packet number of the first packet sent after an RTT sample was collected on this path /// /// Used in persistent congestion determination. pub(super) first_packet_after_rtt_sample: Option<(SpaceId, u64)>, } impl PathData { pub(super) fn new( remote: SocketAddr, initial_rtt: Duration, congestion: Box, initial_mtu: u16, min_mtu: u16, peer_max_udp_payload_size: Option, mtud_config: Option, now: Instant, validated: bool, ) -> Self { Self { remote, rtt: RttEstimator::new(initial_rtt), sending_ecn: true, pacing: Pacer::new(initial_rtt, congestion.initial_window(), initial_mtu, now), congestion, challenge: None, challenge_pending: false, validated, total_sent: 0, total_recvd: 0, mtud: mtud_config.map_or(MtuDiscovery::disabled(initial_mtu, min_mtu), |config| { MtuDiscovery::new(initial_mtu, min_mtu, peer_max_udp_payload_size, config) }), first_packet_after_rtt_sample: None, } } pub(super) fn from_previous(remote: SocketAddr, prev: &Self, now: Instant) -> Self { let congestion = prev.congestion.clone_box(); let smoothed_rtt = prev.rtt.get(); Self { remote, rtt: prev.rtt, pacing: Pacer::new(smoothed_rtt, congestion.window(), prev.current_mtu(), now), sending_ecn: true, congestion, challenge: None, challenge_pending: false, validated: false, total_sent: 0, total_recvd: 0, mtud: prev.mtud.clone(), first_packet_after_rtt_sample: prev.first_packet_after_rtt_sample, } } /// Indicates whether we're a server that hasn't validated the peer's address and hasn't /// received enough data from the peer to permit sending `bytes_to_send` additional bytes pub(super) fn anti_amplification_blocked(&self, bytes_to_send: u64) -> bool { !self.validated && self.total_recvd * 3 < self.total_sent + bytes_to_send } /// Returns the path's current MTU pub(super) fn current_mtu(&self) -> u16 { self.mtud.current_mtu() } } /// RTT estimation for a particular network path #[derive(Copy, Clone)] pub struct RttEstimator { /// The most recent RTT measurement made when receiving an ack for a previously unacked packet latest: Duration, /// The smoothed RTT of the connection, computed as described in RFC6298 smoothed: Option, /// The RTT variance, computed as described in RFC6298 var: Duration, /// The minimum RTT 
seen in the connection, ignoring ack delay. min: Duration, } impl RttEstimator { fn new(initial_rtt: Duration) -> Self { Self { latest: initial_rtt, smoothed: None, var: initial_rtt / 2, min: initial_rtt, } } /// The current best RTT estimation. pub fn get(&self) -> Duration { self.smoothed.unwrap_or(self.latest) } /// Conservative estimate of RTT /// /// Takes the maximum of smoothed and latest RTT, as recommended /// in 6.1.2 of the recovery spec (draft 29). pub fn conservative(&self) -> Duration { self.get().max(self.latest) } /// Minimum RTT registered so far for this estimator. pub fn min(&self) -> Duration { self.min } // PTO computed as described in RFC9002#6.2.1 pub(crate) fn pto_base(&self) -> Duration { self.get() + cmp::max(4 * self.var, TIMER_GRANULARITY) } pub(crate) fn update(&mut self, ack_delay: Duration, rtt: Duration) { self.latest = rtt; // min_rtt ignores ack delay. self.min = cmp::min(self.min, self.latest); // Based on RFC6298. if let Some(smoothed) = self.smoothed { let adjusted_rtt = if self.min + ack_delay <= self.latest { self.latest - ack_delay } else { self.latest }; let var_sample = if smoothed > adjusted_rtt { smoothed - adjusted_rtt } else { adjusted_rtt - smoothed }; self.var = (3 * self.var + var_sample) / 4; self.smoothed = Some((7 * smoothed + adjusted_rtt) / 8); } else { self.smoothed = Some(self.latest); self.var = self.latest / 2; self.min = self.latest; } } } quinn-proto-0.10.6/src/connection/send_buffer.rs000064400000000000000000000327331046102023000177710ustar 00000000000000use std::{collections::VecDeque, ops::Range}; use bytes::{Buf, Bytes}; use crate::{range_set::RangeSet, VarInt}; /// Buffer of outgoing retransmittable stream data #[derive(Default, Debug)] pub(super) struct SendBuffer { /// Data queued by the application but not yet acknowledged. May or may not have been sent. unacked_segments: VecDeque, /// Total size of `unacked_segments` unacked_len: usize, /// The first offset that hasn't been written by the application, i.e. 
the offset past the end of `unacked` offset: u64, /// The first offset that hasn't been sent /// /// Always lies in (offset - unacked.len())..offset unsent: u64, /// Acknowledged ranges which couldn't be discarded yet as they don't include the earliest /// offset in `unacked` // TODO: Recover storage from these by compacting (#700) acks: RangeSet, /// Previously transmitted ranges deemed lost retransmits: RangeSet, } impl SendBuffer { /// Construct an empty buffer at the initial offset pub(super) fn new() -> Self { Self::default() } /// Append application data to the end of the stream pub(super) fn write(&mut self, data: Bytes) { self.unacked_len += data.len(); self.offset += data.len() as u64; self.unacked_segments.push_back(data); } /// Discard a range of acknowledged stream data pub(super) fn ack(&mut self, mut range: Range) { // Clamp the range to data which is still tracked let base_offset = self.offset - self.unacked_len as u64; range.start = base_offset.max(range.start); range.end = base_offset.max(range.end); self.acks.insert(range); while self.acks.min() == Some(self.offset - self.unacked_len as u64) { let prefix = self.acks.pop_min().unwrap(); let mut to_advance = (prefix.end - prefix.start) as usize; self.unacked_len -= to_advance; while to_advance > 0 { let front = self .unacked_segments .front_mut() .expect("Expected buffered data"); if front.len() <= to_advance { to_advance -= front.len(); self.unacked_segments.pop_front(); if self.unacked_segments.len() * 4 < self.unacked_segments.capacity() { self.unacked_segments.shrink_to_fit(); } } else { front.advance(to_advance); to_advance = 0; } } } } /// Compute the next range to transmit on this stream and update state to account for that /// transmission. /// /// `max_len` here includes the space which is available to transmit the /// offset and length of the data to send. The caller has to guarantee that /// there is at least enough space available to write maximum-sized metadata /// (8 byte offset + 8 byte length). /// /// The method returns a tuple: /// - The first return value indicates the range of data to send /// - The second return value indicates whether the length needs to be encoded /// in the STREAM frames metadata (`true`), or whether it can be omitted /// since the selected range will fill the whole packet. pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range, bool) { debug_assert!(max_len >= 8 + 8); let mut encode_length = false; if let Some(range) = self.retransmits.pop_min() { // Retransmit sent data // When the offset is known, we know how many bytes are required to encode it. // Offset 0 requires no space if range.start != 0 { max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) }); } if range.end - range.start < max_len as u64 { encode_length = true; max_len -= 8; } let end = range.end.min((max_len as u64).saturating_add(range.start)); if end != range.end { self.retransmits.insert(end..range.end); } return (range.start..end, encode_length); } // Transmit new data // When the offset is known, we know how many bytes are required to encode it. 
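// (QUIC variable-length integers occupy 1, 2, 4, or 8 bytes for values below 2^6, 2^14,
// 2^30, and 2^62 respectively, per RFC 9000, Section 16; the `reserves_encoded_offset`
// test below walks through these thresholds.)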
// Offset 0 requires no space if self.unsent != 0 { max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) }); } if self.offset - self.unsent < max_len as u64 { encode_length = true; max_len -= 8; } let end = self .offset .min((max_len as u64).saturating_add(self.unsent)); let result = self.unsent..end; self.unsent = end; (result, encode_length) } /// Returns data which is associated with a range /// /// This function can return a subset of the range, if the data is stored /// in noncontiguous fashion in the send buffer. In this case callers /// should call the function again with an incremented start offset to /// retrieve more data. pub(super) fn get(&self, offsets: Range) -> &[u8] { let base_offset = self.offset - self.unacked_len as u64; let mut segment_offset = base_offset; for segment in self.unacked_segments.iter() { if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() as u64 { let start = (offsets.start - segment_offset) as usize; let end = (offsets.end - segment_offset) as usize; return &segment[start..end.min(segment.len())]; } segment_offset += segment.len() as u64; } &[] } /// Queue a range of sent but unacknowledged data to be retransmitted pub(super) fn retransmit(&mut self, range: Range) { debug_assert!(range.end <= self.unsent, "unsent data can't be lost"); self.retransmits.insert(range); } pub(super) fn retransmit_all_for_0rtt(&mut self) { debug_assert_eq!(self.offset, self.unacked_len as u64); self.unsent = 0; } /// First stream offset unwritten by the application, i.e. the offset that the next write will /// begin at pub(super) fn offset(&self) -> u64 { self.offset } /// Whether all sent data has been acknowledged pub(super) fn is_fully_acked(&self) -> bool { self.unacked_len == 0 } /// Whether there's data to send /// /// There may be sent unacknowledged data even when this is false. 
pub(super) fn has_unsent_data(&self) -> bool { self.unsent != self.offset || !self.retransmits.is_empty() } /// Compute the amount of data that hasn't been acknowledged pub(super) fn unacked(&self) -> u64 { self.unacked_len as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>() } } #[cfg(test)] mod tests { use super::*; #[test] fn fragment_with_length() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world!"; buf.write(MSG.into()); // 0 byte offset => 19 bytes left => 13 byte data isn't enough // with 8 bytes reserved for length, 11 payload bytes will fit assert_eq!(buf.poll_transmit(19), (0..11, true)); assert_eq!( buf.poll_transmit(MSG.len() + 16 - 11), (11..MSG.len() as u64, true) ); assert_eq!( buf.poll_transmit(58), (MSG.len() as u64..MSG.len() as u64, true) ); } #[test] fn fragment_without_length() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world with some extra data!"; buf.write(MSG.into()); // 0 byte offset => 19 bytes left => can be filled by 34 bytes payload assert_eq!(buf.poll_transmit(19), (0..19, false)); assert_eq!( buf.poll_transmit(MSG.len() - 19 + 1), (19..MSG.len() as u64, false) ); assert_eq!( buf.poll_transmit(58), (MSG.len() as u64..MSG.len() as u64, true) ); } #[test] fn reserves_encoded_offset() { let mut buf = SendBuffer::new(); // Pretend we have more than 1 GB of data in the buffer let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]); for _ in 0..1025 { buf.write(chunk.clone()); } const SIZE1: u64 = 64; const SIZE2: u64 = 16 * 1024; const SIZE3: u64 = 1024 * 1024 * 1024; // Offset 0 requires no space assert_eq!(buf.poll_transmit(16), (0..16, false)); buf.retransmit(0..16); assert_eq!(buf.poll_transmit(16), (0..16, false)); let mut transmitted = 16u64; // Offset 16 requires 1 byte assert_eq!( buf.poll_transmit((SIZE1 - transmitted + 1) as usize), (transmitted..SIZE1, false) ); buf.retransmit(transmitted..SIZE1); assert_eq!( buf.poll_transmit((SIZE1 - transmitted + 1) as usize), (transmitted..SIZE1, false) ); transmitted = SIZE1; // Offset 64 requires 2 bytes assert_eq!( buf.poll_transmit((SIZE2 - transmitted + 2) as usize), (transmitted..SIZE2, false) ); buf.retransmit(transmitted..SIZE2); assert_eq!( buf.poll_transmit((SIZE2 - transmitted + 2) as usize), (transmitted..SIZE2, false) ); transmitted = SIZE2; // Offset 16384 requires 4 bytes assert_eq!( buf.poll_transmit((SIZE3 - transmitted + 4) as usize), (transmitted..SIZE3, false) ); buf.retransmit(transmitted..SIZE3); assert_eq!( buf.poll_transmit((SIZE3 - transmitted + 4) as usize), (transmitted..SIZE3, false) ); transmitted = SIZE3; // Offset 1GB requires 8 bytes assert_eq!( buf.poll_transmit(chunk.len() + 8), (transmitted..transmitted + chunk.len() as u64, false) ); buf.retransmit(transmitted..transmitted + chunk.len() as u64); assert_eq!( buf.poll_transmit(chunk.len() + 8), (transmitted..transmitted + chunk.len() as u64, false) ); } #[test] fn multiple_segments() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world!"; const MSG_LEN: u64 = MSG.len() as u64; const SEG1: &[u8] = b"He"; buf.write(SEG1.into()); const SEG2: &[u8] = b"llo,"; buf.write(SEG2.into()); const SEG3: &[u8] = b" w"; buf.write(SEG3.into()); const SEG4: &[u8] = b"o"; buf.write(SEG4.into()); const SEG5: &[u8] = b"rld!"; buf.write(SEG5.into()); assert_eq!(aggregate_unacked(&buf), MSG); assert_eq!(buf.poll_transmit(16), (0..8, true)); assert_eq!(buf.get(0..5), SEG1); assert_eq!(buf.get(2..8), SEG2); assert_eq!(buf.get(6..8), SEG3); assert_eq!(buf.poll_transmit(16), (8..MSG_LEN,
true)); assert_eq!(buf.get(8..MSG_LEN), SEG4); assert_eq!(buf.get(9..MSG_LEN), SEG5); assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true)); // Now drain the segments buf.ack(0..1); assert_eq!(aggregate_unacked(&buf), &MSG[1..]); buf.ack(0..3); assert_eq!(aggregate_unacked(&buf), &MSG[3..]); buf.ack(3..5); assert_eq!(aggregate_unacked(&buf), &MSG[5..]); buf.ack(7..9); assert_eq!(aggregate_unacked(&buf), &MSG[5..]); buf.ack(4..7); assert_eq!(aggregate_unacked(&buf), &MSG[9..]); buf.ack(0..MSG_LEN); assert_eq!(aggregate_unacked(&buf), &[]); } #[test] fn retransmit() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world with extra data!"; buf.write(MSG.into()); // Transmit two frames assert_eq!(buf.poll_transmit(16), (0..16, false)); assert_eq!(buf.poll_transmit(16), (16..23, true)); // Lose the first, but not the second buf.retransmit(0..16); // Ensure we only retransmit the lost frame, then continue sending fresh data assert_eq!(buf.poll_transmit(16), (0..16, false)); assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true)); // Lose the second frame buf.retransmit(16..23); assert_eq!(buf.poll_transmit(16), (16..23, true)); } #[test] fn ack() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world!"; buf.write(MSG.into()); assert_eq!(buf.poll_transmit(16), (0..8, true)); buf.ack(0..8); assert_eq!(aggregate_unacked(&buf), &MSG[8..]); } #[test] fn reordered_ack() { let mut buf = SendBuffer::new(); const MSG: &[u8] = b"Hello, world with extra data!"; buf.write(MSG.into()); assert_eq!(buf.poll_transmit(16), (0..16, false)); assert_eq!(buf.poll_transmit(16), (16..23, true)); buf.ack(16..23); assert_eq!(aggregate_unacked(&buf), MSG); buf.ack(0..16); assert_eq!(aggregate_unacked(&buf), &MSG[23..]); assert!(buf.acks.is_empty()); } fn aggregate_unacked(buf: &SendBuffer) -> Vec { let mut result = Vec::new(); for segment in buf.unacked_segments.iter() { result.extend_from_slice(&segment[..]); } result } } quinn-proto-0.10.6/src/connection/spaces.rs000064400000000000000000000457501046102023000167700ustar 00000000000000use std::{ cmp, collections::{BTreeMap, VecDeque}, mem, ops::{Index, IndexMut}, time::{Duration, Instant}, }; use rustc_hash::FxHashSet; use super::assembler::Assembler; use crate::{ connection::StreamsState, crypto::Keys, frame, packet::SpaceId, range_set::ArrayRangeSet, shared::IssuedCid, Dir, StreamId, VarInt, }; pub(super) struct PacketSpace { pub(super) crypto: Option, pub(super) dedup: Dedup, /// Highest received packet number pub(super) rx_packet: u64, /// Data to send pub(super) pending: Retransmits, /// Packet numbers to acknowledge pub(super) pending_acks: PendingAcks, /// The packet number of the next packet that will be sent, if any. pub(super) next_packet_number: u64, /// The largest packet number the remote peer acknowledged in an ACK frame. pub(super) largest_acked_packet: Option, pub(super) largest_acked_packet_sent: Instant, /// Transmitted but not acked // We use a BTreeMap here so we can efficiently query by range on ACK and for loss detection pub(super) sent_packets: BTreeMap, /// Number of explicit congestion notification codepoints seen on incoming packets pub(super) ecn_counters: frame::EcnCounts, /// Recent ECN counters sent by the peer in ACK frames /// /// Updated (and inspected) whenever we receive an ACK with a new highest acked packet /// number. 
Stored per-space to simplify verification, which would otherwise have difficulty /// distinguishing between ECN bleaching and counts having been updated by a near-simultaneous /// ACK already processed in another space. pub(super) ecn_feedback: frame::EcnCounts, /// Incoming cryptographic handshake stream pub(super) crypto_stream: Assembler, /// Current offset of outgoing cryptographic handshake stream pub(super) crypto_offset: u64, /// The time the most recently sent retransmittable packet was sent. pub(super) time_of_last_ack_eliciting_packet: Option<Instant>, /// The time at which the earliest sent packet in this space will be considered lost based on /// exceeding the reordering window in time. Only set for packets numbered prior to a packet /// that has been acknowledged. pub(super) loss_time: Option<Instant>, /// Number of tail loss probes to send pub(super) loss_probes: u32, pub(super) ping_pending: bool, /// Number of congestion control "in flight" bytes pub(super) in_flight: u64, /// Number of packets sent in the current key phase pub(super) sent_with_keys: u64, } impl PacketSpace { pub(super) fn new(now: Instant) -> Self { Self { crypto: None, dedup: Dedup::new(), rx_packet: 0, pending: Retransmits::default(), pending_acks: PendingAcks::default(), next_packet_number: 0, largest_acked_packet: None, largest_acked_packet_sent: now, sent_packets: BTreeMap::new(), ecn_counters: frame::EcnCounts::ZERO, ecn_feedback: frame::EcnCounts::ZERO, crypto_stream: Assembler::new(), crypto_offset: 0, time_of_last_ack_eliciting_packet: None, loss_time: None, loss_probes: 0, ping_pending: false, in_flight: 0, sent_with_keys: 0, } } /// Queue data for a tail loss probe (or anti-amplification deadlock prevention) packet /// /// Probes are sent similarly to normal packets when an expected ACK has not arrived. We never /// deem a packet lost until we receive an ACK that should have included it, but if a trailing /// run of packets (or their ACKs) is lost, this might not happen in a timely fashion. We send /// probe packets to force an ACK, and exempt them from congestion control to prevent a deadlock /// when the congestion window is filled with lost tail packets. /// /// We prefer to send new data, to make the most efficient use of bandwidth. If there's no data /// waiting to be sent, then we retransmit in-flight data to reduce odds of loss. If there's no /// in-flight data either, we're probably a client guarding against a handshake /// anti-amplification deadlock and we just make something up. pub(super) fn maybe_queue_probe(&mut self, streams: &StreamsState) { if self.loss_probes == 0 { return; } // Retransmit the data of the oldest in-flight packet if !self.pending.is_empty(streams) { // There's real data to send here, no need to make something up return; } for packet in self.sent_packets.values_mut() { if !packet.retransmits.is_empty(streams) { // Remove retransmitted data from the old packet so we don't end up retransmitting // it *again* even if the copy we're sending now gets acknowledged. self.pending |= mem::take(&mut packet.retransmits); return; } } // Nothing new to send and nothing to retransmit, so fall back on a ping. This should only // happen in rare cases during the handshake when the server becomes blocked by // anti-amplification.
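// (A PING frame is ack-eliciting but carries no application data (RFC 9000, Section 19.2),
// so it is the cheapest way to force the peer to respond.)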
self.ping_pending = true; } pub(super) fn get_tx_number(&mut self) -> u64 { // TODO: Handle packet number overflow gracefully assert!(self.next_packet_number < 2u64.pow(62)); let x = self.next_packet_number; self.next_packet_number += 1; self.sent_with_keys += 1; x } pub(super) fn can_send(&self, streams: &StreamsState) -> SendableFrames { let acks = self.pending_acks.can_send(); let other = !self.pending.is_empty(streams) || self.ping_pending; SendableFrames { acks, other } } /// Verifies sanity of an ECN block and returns whether congestion was encountered. pub(super) fn detect_ecn( &mut self, newly_acked: u64, ecn: frame::EcnCounts, ) -> Result { let ect0_increase = ecn .ect0 .checked_sub(self.ecn_feedback.ect0) .ok_or("peer ECT(0) count regression")?; let ect1_increase = ecn .ect1 .checked_sub(self.ecn_feedback.ect1) .ok_or("peer ECT(1) count regression")?; let ce_increase = ecn .ce .checked_sub(self.ecn_feedback.ce) .ok_or("peer CE count regression")?; let total_increase = ect0_increase + ect1_increase + ce_increase; if total_increase < newly_acked { return Err("ECN bleaching"); } if (ect0_increase + ce_increase) < newly_acked || ect1_increase != 0 { return Err("ECN corruption"); } // If total_increase > newly_acked (which happens when ACKs are lost), this is required by // the draft so that long-term drift does not occur. If =, then the only question is whether // to count CE packets as CE or ECT0. Recording them as CE is more consistent and keeps the // congestion check obvious. self.ecn_feedback = ecn; Ok(ce_increase != 0) } pub(super) fn sent(&mut self, number: u64, packet: SentPacket) { self.in_flight += u64::from(packet.size); self.sent_packets.insert(number, packet); } } impl Index for [PacketSpace; 3] { type Output = PacketSpace; fn index(&self, space: SpaceId) -> &PacketSpace { &self.as_ref()[space as usize] } } impl IndexMut for [PacketSpace; 3] { fn index_mut(&mut self, space: SpaceId) -> &mut PacketSpace { &mut self.as_mut()[space as usize] } } /// Represents one or more packets subject to retransmission #[derive(Debug, Clone)] pub(super) struct SentPacket { /// The time the packet was sent. pub(super) time_sent: Instant, /// The number of bytes sent in the packet, not including UDP or IP overhead, but including QUIC /// framing overhead. Zero if this packet is not counted towards congestion control, i.e. not an /// "in flight" packet. pub(super) size: u16, /// Whether an acknowledgement is expected directly in response to this packet. pub(super) ack_eliciting: bool, /// The largest packet number acknowledged by this packet pub(super) largest_acked: Option, /// Data which needs to be retransmitted in case the packet is lost. /// The data is boxed to minimize `SentPacket` size for the typical case of /// packets only containing ACKs and STREAM frames. pub(super) retransmits: ThinRetransmits, /// Metadata for stream frames in a packet /// /// The actual application data is stored with the stream state. 
pub(super) stream_frames: frame::StreamMetaVec, } /// Retransmittable data queue #[allow(unreachable_pub)] // fuzzing only #[derive(Debug, Default, Clone)] pub struct Retransmits { pub(super) max_data: bool, pub(super) max_stream_id: [bool; 2], pub(super) reset_stream: Vec<(StreamId, VarInt)>, pub(super) stop_sending: Vec, pub(super) max_stream_data: FxHashSet, pub(super) crypto: VecDeque, pub(super) new_cids: Vec, pub(super) retire_cids: Vec, pub(super) handshake_done: bool, } impl Retransmits { pub(super) fn is_empty(&self, streams: &StreamsState) -> bool { !self.max_data && !self.max_stream_id.into_iter().any(|x| x) && self.reset_stream.is_empty() && self.stop_sending.is_empty() && self .max_stream_data .iter() .all(|&id| !streams.can_send_flow_control(id)) && self.crypto.is_empty() && self.new_cids.is_empty() && self.retire_cids.is_empty() && !self.handshake_done } } impl ::std::ops::BitOrAssign for Retransmits { fn bitor_assign(&mut self, rhs: Self) { // We reduce in-stream head-of-line blocking by queueing retransmits before other data for // STREAM and CRYPTO frames. self.max_data |= rhs.max_data; for dir in Dir::iter() { self.max_stream_id[dir as usize] |= rhs.max_stream_id[dir as usize]; } self.reset_stream.extend_from_slice(&rhs.reset_stream); self.stop_sending.extend_from_slice(&rhs.stop_sending); self.max_stream_data.extend(&rhs.max_stream_data); for crypto in rhs.crypto.into_iter().rev() { self.crypto.push_front(crypto); } self.new_cids.extend(&rhs.new_cids); self.retire_cids.extend(rhs.retire_cids); self.handshake_done |= rhs.handshake_done; } } impl ::std::ops::BitOrAssign for Retransmits { fn bitor_assign(&mut self, rhs: ThinRetransmits) { if let Some(retransmits) = rhs.retransmits { self.bitor_assign(*retransmits) } } } impl ::std::iter::FromIterator for Retransmits { fn from_iter(iter: T) -> Self where T: IntoIterator, { let mut result = Self::default(); for packet in iter { result |= packet; } result } } /// A variant of `Retransmits` which only allocates storage when required #[derive(Debug, Default, Clone)] pub(super) struct ThinRetransmits { retransmits: Option>, } impl ThinRetransmits { /// Returns `true` if no retransmits are necessary pub(super) fn is_empty(&self, streams: &StreamsState) -> bool { match &self.retransmits { Some(retransmits) => retransmits.is_empty(streams), None => true, } } /// Returns a reference to the retransmits stored in this box pub(super) fn get(&self) -> Option<&Retransmits> { self.retransmits.as_deref() } /// Returns a mutable reference to the stored retransmits /// /// This function will allocate a backing storage if required. pub(super) fn get_or_create(&mut self) -> &mut Retransmits { if self.retransmits.is_none() { self.retransmits = Some(Box::default()); } self.retransmits.as_deref_mut().unwrap() } } /// RFC4303-style sliding window packet number deduplicator. /// /// A contiguous bitfield, where each bit corresponds to a packet number and the rightmost bit is /// always set. A set bit represents a packet that has been successfully authenticated. Bits left of /// the window are assumed to be set. /// /// ```text /// ...xxxxxxxxx 1 0 /// ^ ^ ^ /// window highest next /// ``` pub(super) struct Dedup { window: Window, /// Lowest packet number higher than all yet authenticated. next: u64, } /// Inner bitfield type. /// /// Because QUIC never reuses packet numbers, this only needs to be large enough to deal with /// packets that are reordered but still delivered in a timely manner. 
type Window = u128; /// Number of packets tracked by `Dedup`. const WINDOW_SIZE: u64 = 1 + mem::size_of::() as u64 * 8; impl Dedup { /// Construct an empty window positioned at the start. pub(super) fn new() -> Self { Self { window: 0, next: 0 } } /// Highest packet number authenticated. fn highest(&self) -> u64 { self.next - 1 } /// Record a newly authenticated packet number. /// /// Returns whether the packet might be a duplicate. pub(super) fn insert(&mut self, packet: u64) -> bool { if let Some(diff) = packet.checked_sub(self.next) { // Right of window self.window = (self.window << 1 | 1) .checked_shl(cmp::min(diff, u64::from(u32::max_value())) as u32) .unwrap_or(0); self.next = packet + 1; false } else if self.highest() - packet < WINDOW_SIZE { // Within window if let Some(bit) = (self.highest() - packet).checked_sub(1) { // < highest let mask = 1 << bit; let duplicate = self.window & mask != 0; self.window |= mask; duplicate } else { // == highest true } } else { // Left of window true } } } /// Indicates which data is available for sending #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub(super) struct SendableFrames { pub(super) acks: bool, pub(super) other: bool, } impl SendableFrames { /// Returns that no data is available for sending pub(super) fn empty() -> Self { Self { acks: false, other: false, } } /// Whether no data is sendable pub(super) fn is_empty(&self) -> bool { !self.acks && !self.other } } #[derive(Debug, Default)] pub(super) struct PendingAcks { permit_ack_only: bool, ranges: ArrayRangeSet, /// This value will be used for calculating ACK delay once it is implemented /// /// ACK delay will be the delay between when a packet arrived (`latest_incoming`) /// and between it will be allowed to be acknowledged (`can_send() == true`). latest_incoming: Option, ack_delay: Duration, } impl PendingAcks { /// Whether any ACK frames can be sent pub(super) fn can_send(&self) -> bool { self.permit_ack_only && !self.ranges.is_empty() } /// Returns the duration the acknowledgement of the latest incoming packet has been delayed pub(super) fn ack_delay(&self) -> Duration { self.ack_delay } /// Handle receipt of a new packet pub(super) fn packet_received(&mut self, ack_eliciting: bool) { self.permit_ack_only |= ack_eliciting; } /// Should be called whenever ACKs have been sent /// /// This will suppress sending further ACKs until additional ACK eliciting frames arrive pub(super) fn acks_sent(&mut self) { // If we sent any acks, don't immediately resend them. Setting this even if ack_only is // false needlessly prevents us from ACKing the next packet if it's ACK-only, but saves // the need for subtler logic to avoid double-transmitting acks all the time. 
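// (`permit_ack_only` is re-armed by `packet_received` the next time an ack-eliciting
// packet arrives, so ACK-only transmissions resume as soon as they are useful again.)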
// This reset needs to happen before we check whether more data // is available in this space - because otherwise it would return // `true` purely due to the ACKs self.permit_ack_only = false; } /// Insert one packet that needs to be acknowledged pub(super) fn insert_one(&mut self, packet: u64, now: Instant) { self.ranges.insert_one(packet); self.latest_incoming = Some(now); if self.ranges.len() > MAX_ACK_BLOCKS { self.ranges.pop_min(); } } /// Remove ACKs of packets numbered at or below `max` from the set of pending ACKs pub(super) fn subtract_below(&mut self, max: u64) { self.ranges.remove(0..(max + 1)); } /// Returns the set of currently pending ACK ranges pub(super) fn ranges(&self) -> &ArrayRangeSet { &self.ranges } } /// Ensures we can always fit all our ACKs in a single minimum-MTU packet with room to spare const MAX_ACK_BLOCKS: usize = 64; #[cfg(test)] mod test { use super::*; #[test] fn sanity() { let mut dedup = Dedup::new(); assert!(!dedup.insert(0)); assert_eq!(dedup.next, 1); assert_eq!(dedup.window, 0b1); assert!(dedup.insert(0)); assert_eq!(dedup.next, 1); assert_eq!(dedup.window, 0b1); assert!(!dedup.insert(1)); assert_eq!(dedup.next, 2); assert_eq!(dedup.window, 0b11); assert!(!dedup.insert(2)); assert_eq!(dedup.next, 3); assert_eq!(dedup.window, 0b111); assert!(!dedup.insert(4)); assert_eq!(dedup.next, 5); assert_eq!(dedup.window, 0b11110); assert!(!dedup.insert(7)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_0100); assert!(dedup.insert(4)); assert!(!dedup.insert(3)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_1100); assert!(!dedup.insert(6)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_1101); assert!(!dedup.insert(5)); assert_eq!(dedup.next, 8); assert_eq!(dedup.window, 0b1111_1111); } #[test] fn happypath() { let mut dedup = Dedup::new(); for i in 0..(2 * WINDOW_SIZE) { assert!(!dedup.insert(i)); for j in 0..=i { assert!(dedup.insert(j)); } } } #[test] fn jump() { let mut dedup = Dedup::new(); dedup.insert(2 * WINDOW_SIZE); assert!(dedup.insert(WINDOW_SIZE)); assert_eq!(dedup.next, 2 * WINDOW_SIZE + 1); assert_eq!(dedup.window, 0); assert!(!dedup.insert(WINDOW_SIZE + 1)); assert_eq!(dedup.next, 2 * WINDOW_SIZE + 1); assert_eq!(dedup.window, 1 << (WINDOW_SIZE - 2)); } #[test] fn sent_packet_size() { // The tracking state of sent packets should be minimal, and not grow // over time. assert!(std::mem::size_of::() <= 128); } } quinn-proto-0.10.6/src/connection/stats.rs000064400000000000000000000136341046102023000166440ustar 00000000000000//! Connection statistics use crate::{frame::Frame, Dir}; use std::time::Duration; /// Statistics about UDP datagrams transmitted or received on a connection #[derive(Default, Debug, Copy, Clone)] #[non_exhaustive] pub struct UdpStats { /// The amount of UDP datagrams observed pub datagrams: u64, /// The total amount of bytes which have been transferred inside UDP datagrams pub bytes: u64, /// The amount of transmit calls which have been performed /// /// This can mismatch the amount of datagrams in case GSO is utilized for /// transmitting data. 
pub transmits: u64, } /// Number of frames transmitted of each frame type #[derive(Default, Copy, Clone)] #[non_exhaustive] #[allow(missing_docs)] pub struct FrameStats { pub acks: u64, pub crypto: u64, pub connection_close: u64, pub data_blocked: u64, pub datagram: u64, pub handshake_done: u8, pub max_data: u64, pub max_stream_data: u64, pub max_streams_bidi: u64, pub max_streams_uni: u64, pub new_connection_id: u64, pub new_token: u64, pub path_challenge: u64, pub path_response: u64, pub ping: u64, pub reset_stream: u64, pub retire_connection_id: u64, pub stream_data_blocked: u64, pub streams_blocked_bidi: u64, pub streams_blocked_uni: u64, pub stop_sending: u64, pub stream: u64, } impl FrameStats { pub(crate) fn record(&mut self, frame: &Frame) { match frame { Frame::Padding => {} Frame::Ping => self.ping += 1, Frame::Ack(_) => self.acks += 1, Frame::ResetStream(_) => self.reset_stream += 1, Frame::StopSending(_) => self.stop_sending += 1, Frame::Crypto(_) => self.crypto += 1, Frame::Datagram(_) => self.datagram += 1, Frame::NewToken { .. } => self.new_token += 1, Frame::MaxData(_) => self.max_data += 1, Frame::MaxStreamData { .. } => self.max_stream_data += 1, Frame::MaxStreams { dir, .. } => { if *dir == Dir::Bi { self.max_streams_bidi += 1; } else { self.max_streams_uni += 1; } } Frame::DataBlocked { .. } => self.data_blocked += 1, Frame::Stream(_) => self.stream += 1, Frame::StreamDataBlocked { .. } => self.stream_data_blocked += 1, Frame::StreamsBlocked { dir, .. } => { if *dir == Dir::Bi { self.streams_blocked_bidi += 1; } else { self.streams_blocked_uni += 1; } } Frame::NewConnectionId(_) => self.new_connection_id += 1, Frame::RetireConnectionId { .. } => self.retire_connection_id += 1, Frame::PathChallenge(_) => self.path_challenge += 1, Frame::PathResponse(_) => self.path_response += 1, Frame::Close(_) => self.connection_close += 1, Frame::HandshakeDone => self.handshake_done += 1, } } } impl std::fmt::Debug for FrameStats { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("FrameStats") .field("ACK", &self.acks) .field("CONNECTION_CLOSE", &self.connection_close) .field("CRYPTO", &self.crypto) .field("DATA_BLOCKED", &self.data_blocked) .field("DATAGRAM", &self.datagram) .field("HANDSHAKE_DONE", &self.handshake_done) .field("MAX_DATA", &self.max_data) .field("MAX_STREAM_DATA", &self.max_stream_data) .field("MAX_STREAMS_BIDI", &self.max_streams_bidi) .field("MAX_STREAMS_UNI", &self.max_streams_uni) .field("NEW_CONNECTION_ID", &self.new_connection_id) .field("NEW_TOKEN", &self.new_token) .field("PATH_CHALLENGE", &self.path_challenge) .field("PATH_RESPONSE", &self.path_response) .field("PING", &self.ping) .field("RESET_STREAM", &self.reset_stream) .field("RETIRE_CONNECTION_ID", &self.retire_connection_id) .field("STREAM_DATA_BLOCKED", &self.stream_data_blocked) .field("STREAMS_BLOCKED_BIDI", &self.streams_blocked_bidi) .field("STREAMS_BLOCKED_UNI", &self.streams_blocked_uni) .field("STOP_SENDING", &self.stop_sending) .field("STREAM", &self.stream) .finish() } } /// Statistics related to a transmission path #[derive(Debug, Default, Copy, Clone)] #[non_exhaustive] pub struct PathStats { /// Current best estimate of this connection's latency (round-trip-time) pub rtt: Duration, /// Current congestion window of the connection pub cwnd: u64, /// Congestion events on the connection pub congestion_events: u64, /// The amount of packets lost on this path pub lost_packets: u64, /// The amount of bytes lost on this path pub lost_bytes: u64, /// The amount of 
packets sent on this path pub sent_packets: u64, /// The amount of PLPMTUD probe packets sent on this path (also counted by `sent_packets`) pub sent_plpmtud_probes: u64, /// The amount of PLPMTUD probe packets lost on this path (ignored by `lost_packets` and /// `lost_bytes`) pub lost_plpmtud_probes: u64, /// The number of times a black hole was detected in the path pub black_holes_detected: u64, } /// Connection statistics #[derive(Debug, Default, Copy, Clone)] #[non_exhaustive] pub struct ConnectionStats { /// Statistics about UDP datagrams transmitted on a connection pub udp_tx: UdpStats, /// Statistics about UDP datagrams received on a connection pub udp_rx: UdpStats, /// Statistics about frames transmitted on a connection pub frame_tx: FrameStats, /// Statistics about frames received on a connection pub frame_rx: FrameStats, /// Statistics related to the current transmission path pub path: PathStats, } quinn-proto-0.10.6/src/connection/streams/mod.rs000064400000000000000000000361651046102023000177470ustar 00000000000000use std::{ cell::RefCell, collections::{hash_map, BinaryHeap, VecDeque}, }; use bytes::Bytes; use thiserror::Error; use tracing::trace; use self::state::get_or_insert_recv; use super::spaces::{Retransmits, ThinRetransmits}; use crate::{connection::streams::state::get_or_insert_send, frame, Dir, StreamId, VarInt}; mod recv; use recv::Recv; pub use recv::{Chunks, ReadError, ReadableError}; mod send; pub(crate) use send::{ByteSlice, BytesArray}; pub use send::{BytesSource, FinishError, WriteError, Written}; use send::{Send, SendState}; mod state; #[allow(unreachable_pub)] // fuzzing only pub use state::StreamsState; /// Access to streams pub struct Streams<'a> { pub(super) state: &'a mut StreamsState, pub(super) conn_state: &'a super::State, } impl<'a> Streams<'a> { #[cfg(fuzzing)] pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self { Self { state, conn_state } } /// Open a single stream if possible /// /// Returns `None` if the streams in the given direction are currently exhausted. pub fn open(&mut self, dir: Dir) -> Option { if self.conn_state.is_closed() { return None; } // TODO: Queue STREAM_ID_BLOCKED if this fails if self.state.next[dir as usize] >= self.state.max[dir as usize] { return None; } self.state.next[dir as usize] += 1; let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1); self.state.insert(false, id); self.state.send_streams += 1; Some(id) } /// Accept a remotely initiated stream of a certain directionality, if possible /// /// Returns `None` if there are no new incoming streams for this connection. /// Has no impact on the data flow-control or stream concurrency limits. pub fn accept(&mut self, dir: Dir) -> Option { if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] { return None; } let x = self.state.next_reported_remote[dir as usize]; self.state.next_reported_remote[dir as usize] = x + 1; if dir == Dir::Bi { self.state.send_streams += 1; } Some(StreamId::new(!self.state.side, dir, x)) } #[cfg(fuzzing)] pub fn state(&mut self) -> &mut StreamsState { self.state } /// The number of streams that may have unacknowledged data. pub fn send_streams(&self) -> usize { self.state.send_streams } /// The number of remotely initiated open streams of a certain directionality. /// /// Includes remotely initiated streams, which have not been accepted via [`accept`](Self::accept). 
/// These streams count against the respective concurrency limit reported by /// [`Connection::max_concurrent_streams`](super::Connection::max_concurrent_streams). pub fn remote_open_streams(&self, dir: Dir) -> u64 { // total opened - total closed = total opened - ( total permitted - total permitted unclosed ) self.state.next_remote[dir as usize] - (self.state.max_remote[dir as usize] - self.state.allocated_remote_count[dir as usize]) } } /// Access to streams pub struct RecvStream<'a> { pub(super) id: StreamId, pub(super) state: &'a mut StreamsState, pub(super) pending: &'a mut Retransmits, } impl<'a> RecvStream<'a> { /// Read from the given recv stream /// /// `max_length` limits the maximum size of the returned `Bytes` value; passing `usize::MAX` /// will yield the best performance. `ordered` will make sure the returned chunk has an offset /// exactly equal to the previously returned offset plus the previously returned bytes' length. /// /// Yields `Ok(None)` if the stream was finished. Otherwise, yields a segment of data and its /// offset in the stream. If `ordered` is `false`, segments may be received in any order, and /// the `Chunk`'s `offset` field can be used to determine ordering in the caller. /// /// While most applications will prefer to consume stream data in order, unordered reads can /// improve performance when packet loss occurs and data cannot be retransmitted before the flow /// control window is filled. On any given stream, you can switch from ordered to unordered /// reads, but ordered reads on streams that have seen previous unordered reads will return /// `ReadableError::IllegalOrderedRead`. pub fn read(&mut self, ordered: bool) -> Result<Chunks<'_>, ReadableError> { Chunks::new(self.id, ordered, self.state, self.pending) } /// Stop accepting data on the given receive stream /// /// Discards unread data and notifies the peer to stop transmitting. Once stopped, further /// attempts to operate on a stream will yield `UnknownStream` errors. pub fn stop(&mut self, error_code: VarInt) -> Result<(), UnknownStream> { let mut entry = match self.state.recv.entry(self.id) { hash_map::Entry::Occupied(s) => s, hash_map::Entry::Vacant(_) => return Err(UnknownStream { _private: () }), }; let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut()); let (read_credits, stop_sending) = stream.stop()?; if stop_sending.should_transmit() { self.pending.stop_sending.push(frame::StopSending { id: self.id, error_code, }); } // We need to keep stopped streams around until they're finished or reset so we can update // connection-level flow control to account for discarded data. Otherwise, we can discard // state immediately. if !stream.receiving_unknown_size() { entry.remove(); self.state.stream_freed(self.id, StreamHalf::Recv); } if self.state.add_read_credits(read_credits).should_transmit() { self.pending.max_data = true; } Ok(()) } } /// Access to streams pub struct SendStream<'a> { pub(super) id: StreamId, pub(super) state: &'a mut StreamsState, pub(super) pending: &'a mut Retransmits, pub(super) conn_state: &'a super::State, } impl<'a> SendStream<'a> { #[cfg(fuzzing)] pub fn new( id: StreamId, state: &'a mut StreamsState, pending: &'a mut Retransmits, conn_state: &'a super::State, ) -> Self { Self { id, state, pending, conn_state, } } /// Send data on the given stream /// /// Returns the number of bytes successfully written.
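    // A minimal usage sketch (hypothetical caller; assumes `stream` is a
    // `SendStream` for an open stream):
    //
    //     match stream.write(b"hello") {
    //         Ok(n) => assert!(n <= 5), // may be short when limits apply
    //         Err(WriteError::Blocked) => { /* retry after StreamEvent::Writable */ }
    //         Err(e) => return Err(e),
    //     }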
pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> { Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes) } /// Send data on the given stream /// /// Returns the number of bytes and chunks successfully written. /// Note that this method might also write a partial chunk. In this case /// [`Written::chunks`] will not count this chunk as fully written. However, /// the chunk will be advanced and contain only non-written data after the call. pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> { self.write_source(&mut BytesArray::from_chunks(data)) } fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> { if self.conn_state.is_closed() { trace!(%self.id, "write blocked; connection draining"); return Err(WriteError::Blocked); } let limit = self.state.write_limit(); let max_send_data = self.state.initial_max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(WriteError::UnknownStream)?; if limit == 0 { trace!( stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent, "write blocked by connection-level flow control or send window" ); if !stream.connection_blocked { stream.connection_blocked = true; self.state.connection_blocked.push(self.id); } return Err(WriteError::Blocked); } let was_pending = stream.is_pending(); let written = stream.write(source, limit)?; self.state.data_sent += written.bytes as u64; self.state.unacked_data += written.bytes as u64; trace!(stream = %self.id, "wrote {} bytes", written.bytes); if !was_pending { push_pending(&mut self.state.pending, self.id, stream.priority); } Ok(written) } /// Check if this stream was stopped, get the reason if it was pub fn stopped(&mut self) -> Result<Option<VarInt>, UnknownStream> { match self.state.send.get(&self.id).as_ref() { Some(Some(s)) => Ok(s.stop_reason), Some(None) => Ok(None), None => Err(UnknownStream { _private: () }), } } /// Finish a send stream, signalling that no more data will be sent. /// /// If this fails, no [`StreamEvent::Finished`] will be generated. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished pub fn finish(&mut self) -> Result<(), FinishError> { let max_send_data = self.state.initial_max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(FinishError::UnknownStream)?; let was_pending = stream.is_pending(); stream.finish()?; if !was_pending { push_pending(&mut self.state.pending, self.id, stream.priority); } Ok(()) } /// Abandon transmitting data on a stream /// /// # Panics /// - when applied to a receive stream pub fn reset(&mut self, error_code: VarInt) -> Result<(), UnknownStream> { let max_send_data = self.state.initial_max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(UnknownStream { _private: () })?; if matches!(stream.state, SendState::ResetSent) { // Redundant reset call return Err(UnknownStream { _private: () }); } // Restore the portion of the send window consumed by the data that we aren't about to // send. We leave flow control alone because the peer's responsible for issuing additional // credit based on the final offset communicated in the RESET_STREAM frame we send.
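    // Worked example of the accounting below (numbers hypothetical): if this
    // stream has 4_000 bytes written but not yet acknowledged,
    // `stream.pending.unacked()` is 4_000, and subtracting it returns that
    // capacity to the connection-wide send window for use by other streams.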
self.state.unacked_data -= stream.pending.unacked(); stream.reset(); self.pending.reset_stream.push((self.id, error_code)); // Don't reopen an already-closed stream we haven't forgotten yet Ok(()) } /// Set the priority of a stream /// /// # Panics /// - when applied to a receive stream pub fn set_priority(&mut self, priority: i32) -> Result<(), UnknownStream> { let max_send_data = self.state.initial_max_send_data(self.id); let stream = self .state .send .get_mut(&self.id) .map(get_or_insert_send(max_send_data)) .ok_or(UnknownStream { _private: () })?; stream.priority = priority; Ok(()) } /// Get the priority of a stream /// /// # Panics /// - when applied to a receive stream pub fn priority(&self) -> Result<i32, UnknownStream> { let stream = self .state .send .get(&self.id) .ok_or(UnknownStream { _private: () })?; Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default()) } } fn push_pending(pending: &mut BinaryHeap<PendingLevel>, id: StreamId, priority: i32) { for level in pending.iter() { if priority == level.priority { level.queue.borrow_mut().push_back(id); return; } } // If there is only a single level and it's empty, repurpose it for the // required priority if pending.len() == 1 { if let Some(mut first) = pending.peek_mut() { let mut queue = first.queue.borrow_mut(); if queue.is_empty() { queue.push_back(id); drop(queue); first.priority = priority; return; } } } let mut queue = VecDeque::new(); queue.push_back(id); pending.push(PendingLevel { queue: RefCell::new(queue), priority, }); } struct PendingLevel { // RefCell is needed because BinaryHeap doesn't have an iter_mut() queue: RefCell<VecDeque<StreamId>>, priority: i32, } impl PartialEq for PendingLevel { fn eq(&self, other: &Self) -> bool { self.priority.eq(&other.priority) } } impl PartialOrd for PendingLevel { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) } } impl Eq for PendingLevel {} impl Ord for PendingLevel { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.priority.cmp(&other.priority) } } /// Application events about streams #[derive(Debug, PartialEq, Eq)] pub enum StreamEvent { /// One or more new streams have been opened and might be readable Opened { /// Directionality for which streams have been opened dir: Dir, }, /// A currently open stream likely has data or errors waiting to be read Readable { /// Which stream is now readable id: StreamId, }, /// A formerly write-blocked stream might be ready for a write or have been stopped /// /// Only generated for streams that are currently open. Writable { /// Which stream is now writable id: StreamId, }, /// A finished stream has been fully acknowledged or stopped Finished { /// Which stream has been finished id: StreamId, }, /// The peer asked us to stop sending on an outgoing stream Stopped { /// Which stream has been stopped id: StreamId, /// Error code supplied by the peer error_code: VarInt, }, /// At least one new stream of a certain directionality may be opened Available { /// Directionality for which streams are newly available dir: Dir, }, } /// Indicates whether a frame needs to be transmitted /// /// This type wraps around bool and uses the `#[must_use]` attribute in order /// to prevent accidental loss of the frame transmission requirement.
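// A sketch of the intended call pattern (mirroring call sites elsewhere in
// this module; `state` and `pending` are assumed to be a `StreamsState` and a
// `Retransmits`):
//
//     let transmit = state.add_read_credits(read_credits);
//     if transmit.should_transmit() {
//         pending.max_data = true;
//     }
//
// Silently dropping the returned value instead trips the `#[must_use]` lint
// declared below.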
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] #[must_use = "A frame might need to be enqueued"] pub struct ShouldTransmit(bool); impl ShouldTransmit { /// Returns whether a frame should be transmitted pub fn should_transmit(self) -> bool { self.0 } } /// Error indicating that a stream has not been opened or has already been finished or reset #[derive(Debug, Error, Clone, PartialEq, Eq)] #[error("unknown stream")] pub struct UnknownStream { _private: (), } #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum StreamHalf { Send, Recv, } quinn-proto-0.10.6/src/connection/streams/recv.rs000064400000000000000000000333551046102023000201250ustar 00000000000000use std::collections::hash_map::Entry; use std::mem; use thiserror::Error; use tracing::debug; use super::state::get_or_insert_recv; use super::{Retransmits, ShouldTransmit, StreamHalf, StreamId, StreamsState, UnknownStream}; use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead}; use crate::{frame, TransportError, VarInt}; #[derive(Debug, Default)] pub(super) struct Recv { state: RecvState, pub(super) assembler: Assembler, sent_max_stream_data: u64, pub(super) end: u64, pub(super) stopped: bool, } impl Recv { pub(super) fn new(initial_max_data: u64) -> Box { Box::new(Self { state: RecvState::default(), assembler: Assembler::new(), sent_max_stream_data: initial_max_data, end: 0, stopped: false, }) } /// Process a STREAM frame /// /// Return value is `(number_of_new_bytes_ingested, stream_is_closed)` pub(super) fn ingest( &mut self, frame: frame::Stream, payload_len: usize, received: u64, max_data: u64, ) -> Result<(u64, bool), TransportError> { let end = frame.offset + frame.data.len() as u64; if end >= 2u64.pow(62) { return Err(TransportError::FLOW_CONTROL_ERROR( "maximum stream offset too large", )); } if let Some(final_offset) = self.final_offset() { if end > final_offset || (frame.fin && end != final_offset) { debug!(end, final_offset, "final size error"); return Err(TransportError::FINAL_SIZE_ERROR("")); } } let new_bytes = self.credit_consumed_by(end, received, max_data)?; // Stopped streams don't need to wait for the actual data, they just need to know // how much there was. if frame.fin && !self.stopped { if let RecvState::Recv { ref mut size } = self.state { *size = Some(end); } } self.end = self.end.max(end); if !self.stopped { self.assembler.insert(frame.offset, frame.data, payload_len); } else { self.assembler.set_bytes_read(end); } Ok((new_bytes, frame.fin && self.stopped)) } pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), UnknownStream> { if self.stopped { return Err(UnknownStream { _private: () }); } self.stopped = true; self.assembler.clear(); // Issue flow control credit for unread data let read_credits = self.end - self.assembler.bytes_read(); // This may send a spurious STOP_SENDING if we've already received all data, but it's a bit // fiddly to distinguish that from the case where we've received a FIN but are missing some // data that the peer might still be trying to retransmit, in which case a STOP_SENDING is // still useful. Ok((read_credits, ShouldTransmit(self.is_receiving()))) } /// Returns the window that should be advertised in a `MAX_STREAM_DATA` frame /// /// The method returns a tuple which consists of the window that should be /// announced, as well as a boolean parameter which indicates if a new /// transmission of the value is recommended. If the boolean value is /// `false` the new window should only be transmitted if a previous transmission /// had failed. 
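    // Worked example of the heuristic below (numbers hypothetical): with a
    // 1 MiB `stream_receive_window`, a MAX_STREAM_DATA update is only
    // recommended once the announced window would advance at least 128 KiB
    // (1 MiB / 8) beyond `sent_max_stream_data`.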
pub(super) fn max_stream_data(&mut self, stream_receive_window: u64) -> (u64, ShouldTransmit) { let max_stream_data = self.assembler.bytes_read() + stream_receive_window; // Only announce a window update if it's significant enough // to make it worthwhile sending a MAX_STREAM_DATA frame. // We use a fraction of the configured stream receive window to make // the decision, which accommodates streams using bigger windows requiring // fewer updates. A fixed size would also work, but it would need to be // smaller than `stream_receive_window` in order to make sure the stream // does not get stuck. let diff = max_stream_data - self.sent_max_stream_data; let transmit = self.receiving_unknown_size() && diff >= (stream_receive_window / 8); (max_stream_data, ShouldTransmit(transmit)) } /// Records that a `MAX_STREAM_DATA` announcing a certain window was sent /// /// This will suppress enqueuing further `MAX_STREAM_DATA` frames unless /// either the previous transmission was not acknowledged or the window /// further increased. pub(super) fn record_sent_max_stream_data(&mut self, sent_value: u64) { if sent_value > self.sent_max_stream_data { self.sent_max_stream_data = sent_value; } } pub(super) fn receiving_unknown_size(&self) -> bool { matches!(self.state, RecvState::Recv { size: None }) } /// Whether data is still being accepted from the peer pub(super) fn is_receiving(&self) -> bool { matches!(self.state, RecvState::Recv { .. }) } fn final_offset(&self) -> Option<u64> { match self.state { RecvState::Recv { size } => size, RecvState::ResetRecvd { size, .. } => Some(size), } } /// Returns `false` iff the reset was redundant pub(super) fn reset( &mut self, error_code: VarInt, final_offset: VarInt, received: u64, max_data: u64, ) -> Result<bool, TransportError> { // Validate final_offset if let Some(offset) = self.final_offset() { if offset != final_offset.into_inner() { return Err(TransportError::FINAL_SIZE_ERROR("inconsistent value")); } } else if self.end > final_offset.into() { return Err(TransportError::FINAL_SIZE_ERROR( "lower than high water mark", )); } self.credit_consumed_by(final_offset.into(), received, max_data)?; if matches!(self.state, RecvState::ResetRecvd { .. }) { return Ok(false); } self.state = RecvState::ResetRecvd { size: final_offset.into(), error_code, }; // Nuke buffers so that future reads fail immediately, which ensures future reads don't // issue flow control credit redundant to that already issued. We could instead special-case // reset streams during read, but it's unclear if there's any benefit to retaining data for // reset streams.
self.assembler.clear(); Ok(true) } /// Compute the amount of flow control credit consumed, or return an error if more was consumed /// than issued fn credit_consumed_by( &self, offset: u64, received: u64, max_data: u64, ) -> Result { let prev_end = self.end; let new_bytes = offset.saturating_sub(prev_end); if offset > self.sent_max_stream_data || received + new_bytes > max_data { debug!( received, new_bytes, max_data, offset, stream_max_data = self.sent_max_stream_data, "flow control error" ); return Err(TransportError::FLOW_CONTROL_ERROR("")); } Ok(new_bytes) } } /// Chunks pub struct Chunks<'a> { id: StreamId, ordered: bool, streams: &'a mut StreamsState, pending: &'a mut Retransmits, state: ChunksState, read: u64, } impl<'a> Chunks<'a> { pub(super) fn new( id: StreamId, ordered: bool, streams: &'a mut StreamsState, pending: &'a mut Retransmits, ) -> Result { let mut entry = match streams.recv.entry(id) { Entry::Occupied(entry) => entry, Entry::Vacant(_) => return Err(ReadableError::UnknownStream), }; let mut recv = match get_or_insert_recv(streams.stream_receive_window)(entry.get_mut()).stopped { true => return Err(ReadableError::UnknownStream), false => entry.remove().unwrap(), // this can't fail due to the previous get_or_insert_with }; recv.assembler.ensure_ordering(ordered)?; Ok(Self { id, ordered, streams, pending, state: ChunksState::Readable(recv), read: 0, }) } /// Next /// /// Should call finalize() when done calling this. pub fn next(&mut self, max_length: usize) -> Result, ReadError> { let rs = match self.state { ChunksState::Readable(ref mut rs) => rs, ChunksState::Reset(error_code) => { return Err(ReadError::Reset(error_code)); } ChunksState::Finished => { return Ok(None); } ChunksState::Finalized => panic!("must not call next() after finalize()"), }; if let Some(chunk) = rs.assembler.read(max_length, self.ordered) { self.read += chunk.bytes.len() as u64; return Ok(Some(chunk)); } match rs.state { RecvState::ResetRecvd { error_code, .. } => { debug_assert_eq!(self.read, 0, "reset streams have empty buffers"); self.streams.stream_freed(self.id, StreamHalf::Recv); self.state = ChunksState::Reset(error_code); Err(ReadError::Reset(error_code)) } RecvState::Recv { size } => { if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end { self.streams.stream_freed(self.id, StreamHalf::Recv); self.state = ChunksState::Finished; Ok(None) } else { // We don't need a distinct `ChunksState` variant for a blocked stream because // retrying a read harmlessly re-traces our steps back to returning // `Err(Blocked)` again. The buffers can't refill and the stream's own state // can't change so long as this `Chunks` exists. Err(ReadError::Blocked) } } } } /// Finalize pub fn finalize(mut self) -> ShouldTransmit { self.finalize_inner(false) } fn finalize_inner(&mut self, drop: bool) -> ShouldTransmit { let state = mem::replace(&mut self.state, ChunksState::Finalized); debug_assert!( !drop || matches!(state, ChunksState::Finalized), "finalize must be called before drop" ); if let ChunksState::Finalized = state { // Noop on repeated calls return ShouldTransmit(false); } let mut should_transmit = false; // We issue additional stream ID credit after the application is notified that a previously // open stream has finished or been reset and we've therefore disposed of its state. 
if matches!(state, ChunksState::Finished | ChunksState::Reset(_)) && self.streams.side != self.id.initiator() { self.pending.max_stream_id[self.id.dir() as usize] = true; should_transmit = true; } // If the stream hasn't finished, we may need to issue stream-level flow control credit if let ChunksState::Readable(mut rs) = state { let (_, max_stream_data) = rs.max_stream_data(self.streams.stream_receive_window); should_transmit |= max_stream_data.0; if max_stream_data.0 { self.pending.max_stream_data.insert(self.id); } // Return the stream to storage for future use self.streams.recv.insert(self.id, Some(rs)); } // Issue connection-level flow control credit for any data we read regardless of state let max_data = self.streams.add_read_credits(self.read); self.pending.max_data |= max_data.0; should_transmit |= max_data.0; ShouldTransmit(should_transmit) } } impl<'a> Drop for Chunks<'a> { fn drop(&mut self) { let _ = self.finalize_inner(true); } } enum ChunksState { Readable(Box), Reset(VarInt), Finished, Finalized, } /// Errors triggered when reading from a recv stream #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum ReadError { /// No more data is currently available on this stream. /// /// If more data on this stream is received from the peer, an `Event::StreamReadable` will be /// generated for this stream, indicating that retrying the read might succeed. #[error("blocked")] Blocked, /// The peer abandoned transmitting data on this stream. /// /// Carries an application-defined error code. #[error("reset by peer: code {0}")] Reset(VarInt), } /// Errors triggered when opening a recv stream for reading #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum ReadableError { /// The stream has not been opened or was already stopped, finished, or reset #[error("unknown stream")] UnknownStream, /// Attempted an ordered read following an unordered read /// /// Performing an unordered read allows discontinuities to arise in the receive buffer of a /// stream which cannot be recovered, making further ordered reads impossible. 
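    // A sketch of how this arises (hypothetical caller holding a `RecvStream`):
    //
    //     let mut chunks = recv.read(false)?; // switches the stream to unordered reads
    //     // ...consume some chunks, then finalize or drop `chunks`...
    //     assert!(matches!(recv.read(true), Err(ReadableError::IllegalOrderedRead)));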
#[error("ordered read after unordered read")] IllegalOrderedRead, } impl From for ReadableError { fn from(_: IllegalOrderedRead) -> Self { Self::IllegalOrderedRead } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum RecvState { Recv { size: Option }, ResetRecvd { size: u64, error_code: VarInt }, } impl Default for RecvState { fn default() -> Self { Self::Recv { size: None } } } quinn-proto-0.10.6/src/connection/streams/send.rs000064400000000000000000000316441046102023000201160ustar 00000000000000use bytes::Bytes; use thiserror::Error; use crate::{connection::send_buffer::SendBuffer, frame, VarInt}; #[derive(Debug)] pub(super) struct Send { pub(super) max_data: u64, pub(super) state: SendState, pub(super) pending: SendBuffer, pub(super) priority: i32, /// Whether a frame containing a FIN bit must be transmitted, even if we don't have any new data pub(super) fin_pending: bool, /// Whether this stream is in the `connection_blocked` list of `Streams` pub(super) connection_blocked: bool, /// The reason the peer wants us to stop, if `STOP_SENDING` was received pub(super) stop_reason: Option, } impl Send { pub(super) fn new(max_data: VarInt) -> Box { Box::new(Self { max_data: max_data.into(), state: SendState::Ready, pending: SendBuffer::new(), priority: 0, fin_pending: false, connection_blocked: false, stop_reason: None, }) } /// Whether the stream has been reset pub(super) fn is_reset(&self) -> bool { matches!(self.state, SendState::ResetSent { .. }) } pub(super) fn finish(&mut self) -> Result<(), FinishError> { if let Some(error_code) = self.stop_reason { Err(FinishError::Stopped(error_code)) } else if self.state == SendState::Ready { self.state = SendState::DataSent { finish_acked: false, }; self.fin_pending = true; Ok(()) } else { Err(FinishError::UnknownStream) } } pub(super) fn write( &mut self, source: &mut S, limit: u64, ) -> Result { if !self.is_writable() { return Err(WriteError::UnknownStream); } if let Some(error_code) = self.stop_reason { return Err(WriteError::Stopped(error_code)); } let budget = self.max_data - self.pending.offset(); if budget == 0 { return Err(WriteError::Blocked); } let mut limit = limit.min(budget) as usize; let mut result = Written::default(); loop { let (chunk, chunks_consumed) = source.pop_chunk(limit); result.chunks += chunks_consumed; result.bytes += chunk.len(); if chunk.is_empty() { break; } limit -= chunk.len(); self.pending.write(chunk); } Ok(result) } /// Update stream state due to a reset sent by the local application pub(super) fn reset(&mut self) { use SendState::*; if let DataSent { .. 
} | Ready = self.state { self.state = ResetSent; } } /// Handle STOP_SENDING /// /// Returns true if the stream was stopped due to this frame, and false /// if it had been stopped before pub(super) fn try_stop(&mut self, error_code: VarInt) -> bool { if self.stop_reason.is_none() { self.stop_reason = Some(error_code); true } else { false } } /// Returns whether the stream has been finished and all data has been acknowledged by the peer pub(super) fn ack(&mut self, frame: frame::StreamMeta) -> bool { self.pending.ack(frame.offsets); match self.state { SendState::DataSent { ref mut finish_acked, } => { *finish_acked |= frame.fin; *finish_acked && self.pending.is_fully_acked() } _ => false, } } /// Handle increase to stream-level flow control limit /// /// Returns whether the stream was unblocked pub(super) fn increase_max_data(&mut self, offset: u64) -> bool { if offset <= self.max_data || self.state != SendState::Ready { return false; } let was_blocked = self.pending.offset() == self.max_data; self.max_data = offset; was_blocked } pub(super) fn offset(&self) -> u64 { self.pending.offset() } pub(super) fn is_pending(&self) -> bool { self.pending.has_unsent_data() || self.fin_pending } pub(super) fn is_writable(&self) -> bool { matches!(self.state, SendState::Ready) } } /// A [`BytesSource`] implementation for `&'a mut [Bytes]` /// /// The type allows to dequeue [`Bytes`] chunks from an array of chunks, up to /// a configured limit. pub(crate) struct BytesArray<'a> { /// The wrapped slice of `Bytes` chunks: &'a mut [Bytes], /// The amount of chunks consumed from this source consumed: usize, } impl<'a> BytesArray<'a> { pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self { Self { chunks, consumed: 0, } } } impl<'a> BytesSource for BytesArray<'a> { fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) { // The loop exists to skip empty chunks while still marking them as // consumed let mut chunks_consumed = 0; while self.consumed < self.chunks.len() { let chunk = &mut self.chunks[self.consumed]; if chunk.len() <= limit { let chunk = std::mem::take(chunk); self.consumed += 1; chunks_consumed += 1; if chunk.is_empty() { continue; } return (chunk, chunks_consumed); } else if limit > 0 { let chunk = chunk.split_to(limit); return (chunk, chunks_consumed); } else { break; } } (Bytes::new(), chunks_consumed) } } /// A [`BytesSource`] implementation for `&[u8]` /// /// The type allows to dequeue a single [`Bytes`] chunk, which will be lazily /// created from a reference. This allows to defer the allocation until it is /// known how much data needs to be copied. pub(crate) struct ByteSlice<'a> { /// The wrapped byte slice data: &'a [u8], } impl<'a> ByteSlice<'a> { pub(crate) fn from_slice(data: &'a [u8]) -> Self { Self { data } } } impl<'a> BytesSource for ByteSlice<'a> { fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) { let limit = limit.min(self.data.len()); if limit == 0 { return (Bytes::new(), 0); } let chunk = Bytes::from(self.data[..limit].to_owned()); self.data = &self.data[chunk.len()..]; let chunks_consumed = usize::from(self.data.is_empty()); (chunk, chunks_consumed) } } /// A source of one or more buffers which can be converted into `Bytes` buffers on demand /// /// The purpose of this data type is to defer conversion as long as possible, /// so that no heap allocation is required in case no data is writable. pub trait BytesSource { /// Returns the next chunk from the source of owned chunks. /// /// This method will consume parts of the source. 
/// Calling it will yield `Bytes` elements up to the configured `limit`. /// /// The method returns a tuple: /// - The first item is the yielded `Bytes` element. The element will be /// empty if the limit is zero or no more data is available. /// - The second item returns how many complete chunks inside the source had /// been consumed. This can be less than 1, if a chunk inside the /// source had been truncated in order to adhere to the limit. It can also /// be more than 1, if zero-length chunks had been skipped. fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize); } /// Indicates how many bytes and chunks were transferred in a write operation #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct Written { /// The number of bytes which were written pub bytes: usize, /// The number of full chunks which were written /// /// If a chunk was only partially written, it will not be counted by this field. pub chunks: usize, } /// Errors triggered while writing to a send stream #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum WriteError { /// The peer is not able to accept additional data, or the connection is congested. /// /// If the peer issues additional flow control credit, a [`StreamEvent::Writable`] event will /// be generated, indicating that retrying the write might succeed. /// /// [`StreamEvent::Writable`]: crate::StreamEvent::Writable #[error("unable to accept further writes")] Blocked, /// The peer is no longer accepting data on this stream, and it has been implicitly reset. The /// stream cannot be finished or further written to. /// /// Carries an application-defined error code. /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished #[error("stopped by peer: code {0}")] Stopped(VarInt), /// The stream has not been opened or has already been finished or reset #[error("unknown stream")] UnknownStream, } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(super) enum SendState { /// Sending new data Ready, /// Stream was finished; now sending retransmits only DataSent { finish_acked: bool }, /// Sent RESET ResetSent, } /// Reasons why attempting to finish a stream might fail #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum FinishError { /// The peer is no longer accepting data on this stream. No /// [`StreamEvent::Finished`] event will be emitted for this stream. /// /// Carries an application-defined error code.
/// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished #[error("stopped by peer: code {0}")] Stopped(VarInt), /// The stream has not been opened or was already finished or reset #[error("unknown stream")] UnknownStream, } #[cfg(test)] mod tests { use super::*; #[test] fn bytes_array() { let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned(); for limit in 0..full.len() { let mut chunks = [ Bytes::from_static(b""), Bytes::from_static(b"Hello "), Bytes::from_static(b"Wo"), Bytes::from_static(b""), Bytes::from_static(b"r"), Bytes::from_static(b"ld"), Bytes::from_static(b""), Bytes::from_static(b" 12345678"), Bytes::from_static(b"9 ABCDE"), Bytes::from_static(b"F"), Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"), ]; let num_chunks = chunks.len(); let last_chunk_len = chunks[chunks.len() - 1].len(); let mut array = BytesArray::from_chunks(&mut chunks); let mut buf = Vec::new(); let mut chunks_popped = 0; let mut chunks_consumed = 0; let mut remaining = limit; loop { let (chunk, consumed) = array.pop_chunk(remaining); chunks_consumed += consumed; if !chunk.is_empty() { buf.extend_from_slice(&chunk); remaining -= chunk.len(); chunks_popped += 1; } else { break; } } assert_eq!(&buf[..], &full[..limit]); if limit == full.len() { // Full consumption of the last chunk assert_eq!(chunks_consumed, num_chunks); // Since there are empty chunks, we consume more than there are popped assert_eq!(chunks_consumed, chunks_popped + 3); } else if limit > full.len() - last_chunk_len { // Partial consumption of the last chunk assert_eq!(chunks_consumed, num_chunks - 1); assert_eq!(chunks_consumed, chunks_popped + 2); } } } #[test] fn byte_slice() { let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned(); for limit in 0..full.len() { let mut array = ByteSlice::from_slice(&full[..]); let mut buf = Vec::new(); let mut chunks_popped = 0; let mut chunks_consumed = 0; let mut remaining = limit; loop { let (chunk, consumed) = array.pop_chunk(remaining); chunks_consumed += consumed; if !chunk.is_empty() { buf.extend_from_slice(&chunk); remaining -= chunk.len(); chunks_popped += 1; } else { break; } } assert_eq!(&buf[..], &full[..limit]); if limit != 0 { assert_eq!(chunks_popped, 1); } else { assert_eq!(chunks_popped, 0); } if limit == full.len() { assert_eq!(chunks_consumed, 1); } else { assert_eq!(chunks_consumed, 0); } } } } quinn-proto-0.10.6/src/connection/streams/state.rs000064400000000000000000001711211046102023000203000ustar 00000000000000use std::{ collections::{binary_heap::PeekMut, hash_map, BinaryHeap, VecDeque}, convert::TryFrom, mem, }; use bytes::{BufMut, BytesMut}; use rustc_hash::FxHashMap; use tracing::{debug, trace}; use super::{ push_pending, PendingLevel, Recv, Retransmits, Send, SendState, ShouldTransmit, StreamEvent, StreamHalf, ThinRetransmits, }; use crate::{ coding::BufMutExt, connection::stats::FrameStats, frame::{self, FrameStruct, StreamMetaVec}, transport_parameters::TransportParameters, Dir, Side, StreamId, TransportError, VarInt, MAX_STREAM_COUNT, }; #[allow(unreachable_pub)] // fuzzing only pub struct StreamsState { pub(super) side: Side, // Set of streams that are currently open, or could be immediately opened by the peer pub(super) send: FxHashMap>>, pub(super) recv: FxHashMap>>, pub(super) next: [u64; 2], /// Maximum number of locally-initiated streams that may be opened over the lifetime of the /// connection so far, per direction pub(super) max: [u64; 2], /// Maximum number of remotely-initiated streams that may be opened over the lifetime 
of the /// connection so far, per direction pub(super) max_remote: [u64; 2], /// Number of streams that we've given the peer permission to open and which aren't fully closed pub(super) allocated_remote_count: [u64; 2], /// Size of the desired stream flow control window. May be smaller than `allocated_remote_count` /// due to `set_max_concurrent` calls. max_concurrent_remote_count: [u64; 2], /// Whether `max_concurrent_remote_count` has ever changed flow_control_adjusted: bool, /// Lowest remotely-initiated stream index that the peer hasn't actually opened yet pub(super) next_remote: [u64; 2], /// Whether the remote endpoint has opened any streams the application doesn't know about yet, /// per directionality opened: [bool; 2], // Next to report to the application, once opened pub(super) next_reported_remote: [u64; 2], /// Number of outbound streams /// /// This differs from `self.send.len()` in that it does not include streams that the peer is /// permitted to open but which have not yet been opened. pub(super) send_streams: usize, /// Streams with outgoing data queued pub(super) pending: BinaryHeap<PendingLevel>, events: VecDeque<StreamEvent>, /// Streams blocked on connection-level flow control or stream window space /// /// Streams are only added to this list when a write fails. pub(super) connection_blocked: Vec<StreamId>, /// Connection-level flow control budget dictated by the peer pub(super) max_data: u64, /// The initial receive window receive_window: u64, /// Limit on incoming data, which is transmitted through `MAX_DATA` frames local_max_data: u64, /// The last value of `MAX_DATA` which had been queued for transmission in /// an outgoing `MAX_DATA` frame sent_max_data: VarInt, /// Sum of current offsets of all send streams. pub(super) data_sent: u64, /// Sum of end offsets of all receive streams. Includes gaps, so it's an upper bound.
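    // For example: receiving an empty STREAM frame at offset 4096 on a fresh
    // stream raises this by 4096 even though no payload bytes were delivered
    // (see `reset_after_empty_frame_flow_control` in the tests below).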
data_recvd: u64, /// Total quantity of unacknowledged outgoing data pub(super) unacked_data: u64, /// Configured upper bound for `unacked_data` pub(super) send_window: u64, /// Configured upper bound for how much unacked data the peer can send us per stream pub(super) stream_receive_window: u64, /// Whether the corresponding `max_remote` has increased max_streams_dirty: [bool; 2], // Pertinent state from the TransportParameters supplied by the peer initial_max_stream_data_uni: VarInt, initial_max_stream_data_bidi_local: VarInt, initial_max_stream_data_bidi_remote: VarInt, /// The shrink to be applied to local_max_data when receive_window is shrunk receive_window_shrink_debt: u64, } impl StreamsState { #[allow(unreachable_pub)] // fuzzing only pub fn new( side: Side, max_remote_uni: VarInt, max_remote_bi: VarInt, send_window: u64, receive_window: VarInt, stream_receive_window: VarInt, ) -> Self { let mut this = Self { side, send: FxHashMap::default(), recv: FxHashMap::default(), next: [0, 0], max: [0, 0], max_remote: [max_remote_bi.into(), max_remote_uni.into()], allocated_remote_count: [max_remote_bi.into(), max_remote_uni.into()], max_concurrent_remote_count: [max_remote_bi.into(), max_remote_uni.into()], flow_control_adjusted: false, next_remote: [0, 0], opened: [false, false], next_reported_remote: [0, 0], send_streams: 0, pending: BinaryHeap::new(), events: VecDeque::new(), connection_blocked: Vec::new(), max_data: 0, receive_window: receive_window.into(), local_max_data: receive_window.into(), sent_max_data: receive_window, data_sent: 0, data_recvd: 0, unacked_data: 0, send_window, stream_receive_window: stream_receive_window.into(), max_streams_dirty: [false, false], initial_max_stream_data_uni: 0u32.into(), initial_max_stream_data_bidi_local: 0u32.into(), initial_max_stream_data_bidi_remote: 0u32.into(), receive_window_shrink_debt: 0, }; for dir in Dir::iter() { for i in 0..this.max_remote[dir as usize] { this.insert(true, StreamId::new(!side, dir, i)); } } this } pub(crate) fn set_params(&mut self, params: &TransportParameters) { self.initial_max_stream_data_uni = params.initial_max_stream_data_uni; self.initial_max_stream_data_bidi_local = params.initial_max_stream_data_bidi_local; self.initial_max_stream_data_bidi_remote = params.initial_max_stream_data_bidi_remote; self.max[Dir::Bi as usize] = params.initial_max_streams_bidi.into(); self.max[Dir::Uni as usize] = params.initial_max_streams_uni.into(); self.received_max_data(params.initial_max_data); for i in 0..self.max_remote[Dir::Bi as usize] { let id = StreamId::new(!self.side, Dir::Bi, i); if let Some(s) = self.send.get_mut(&id).and_then(|s| s.as_mut()) { s.max_data = params.initial_max_stream_data_bidi_local.into(); } } } /// Ensure we have space for at least a full flow control window of remotely-initiated streams /// to be open, and notify the peer if the window has moved fn ensure_remote_streams(&mut self, dir: Dir) { let new_count = self.max_concurrent_remote_count[dir as usize] .saturating_sub(self.allocated_remote_count[dir as usize]); for i in 0..new_count { let id = StreamId::new(!self.side, dir, self.max_remote[dir as usize] + i); self.insert(true, id); } self.allocated_remote_count[dir as usize] += new_count; self.max_remote[dir as usize] += new_count; self.max_streams_dirty[dir as usize] = new_count != 0; } pub(crate) fn zero_rtt_rejected(&mut self) { // Revert to initial state for outgoing streams for dir in Dir::iter() { for i in 0..self.next[dir as usize] { // We don't bother calling `stream_freed` here because 
we explicitly reset affected // counters below. let id = StreamId::new(self.side, dir, i); self.send.remove(&id).unwrap(); if let Dir::Bi = dir { self.recv.remove(&id).unwrap(); } } self.next[dir as usize] = 0; // If 0-RTT was rejected, any flow control frames we sent were lost. if self.flow_control_adjusted { self.max_streams_dirty[dir as usize] = true; } } self.pending.clear(); self.send_streams = 0; self.data_sent = 0; self.connection_blocked.clear(); } /// Process incoming stream frame /// /// If successful, returns whether a `MAX_DATA` frame needs to be transmitted pub(crate) fn received( &mut self, frame: frame::Stream, payload_len: usize, ) -> Result { let id = frame.id; self.validate_receive_id(id).map_err(|e| { debug!("received illegal STREAM frame"); e })?; let rs = match self .recv .get_mut(&id) .map(get_or_insert_recv(self.stream_receive_window)) { Some(rs) => rs, None => { trace!("dropping frame for closed stream"); return Ok(ShouldTransmit(false)); } }; if !rs.is_receiving() { trace!("dropping frame for finished stream"); return Ok(ShouldTransmit(false)); } let (new_bytes, closed) = rs.ingest(frame, payload_len, self.data_recvd, self.local_max_data)?; self.data_recvd = self.data_recvd.saturating_add(new_bytes); if !rs.stopped { self.on_stream_frame(true, id); return Ok(ShouldTransmit(false)); } // Stopped streams become closed instantly on FIN, so check whether we need to clean up if closed { self.recv.remove(&id); self.stream_freed(id, StreamHalf::Recv); } // We don't buffer data on stopped streams, so issue flow control credit immediately Ok(self.add_read_credits(new_bytes)) } /// Process incoming RESET_STREAM frame /// /// If successful, returns whether a `MAX_DATA` frame needs to be transmitted #[allow(unreachable_pub)] // fuzzing only pub fn received_reset( &mut self, frame: frame::ResetStream, ) -> Result { let frame::ResetStream { id, error_code, final_offset, } = frame; self.validate_receive_id(id).map_err(|e| { debug!("received illegal RESET_STREAM frame"); e })?; let rs = match self .recv .get_mut(&id) .map(get_or_insert_recv(self.stream_receive_window)) { Some(stream) => stream, None => { trace!("received RESET_STREAM on closed stream"); return Ok(ShouldTransmit(false)); } }; // State transition if !rs.reset( error_code, final_offset, self.data_recvd, self.local_max_data, )? { // Redundant reset return Ok(ShouldTransmit(false)); } let bytes_read = rs.assembler.bytes_read(); let stopped = rs.stopped; let end = rs.end; if stopped { // Stopped streams should be disposed immediately on reset self.recv.remove(&id); self.stream_freed(id, StreamHalf::Recv); } self.on_stream_frame(!stopped, id); // Update flow control Ok(if bytes_read != final_offset.into_inner() { // bytes_read is always <= end, so this won't underflow. 
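    // Worked example (numbers hypothetical): with `final_offset` = 4096, a
    // highest-received offset (`end`) of 2048, and 1024 bytes already read,
    // `data_recvd` below grows by 4096 - 2048 = 2048, and 4096 - 1024 = 3072
    // bytes of connection-level read credit are issued for the data the
    // application will never consume.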
self.data_recvd = self .data_recvd .saturating_add(u64::from(final_offset) - end); self.add_read_credits(u64::from(final_offset) - bytes_read) } else { ShouldTransmit(false) }) } /// Process incoming `STOP_SENDING` frame #[allow(unreachable_pub)] // fuzzing only pub fn received_stop_sending(&mut self, id: StreamId, error_code: VarInt) { let max_send_data = self.initial_max_send_data(id); let stream = match self .send .get_mut(&id) .map(get_or_insert_send(max_send_data)) { Some(ss) => ss, None => return, }; if stream.try_stop(error_code) { self.events .push_back(StreamEvent::Stopped { id, error_code }); self.on_stream_frame(false, id); } } pub(crate) fn reset_acked(&mut self, id: StreamId) { match self.send.entry(id) { hash_map::Entry::Vacant(_) => {} hash_map::Entry::Occupied(e) => { if let Some(SendState::ResetSent) = e.get().as_ref().map(|s| s.state) { e.remove_entry(); self.stream_freed(id, StreamHalf::Send); } } } } /// Whether any stream data is queued, regardless of control frames pub(crate) fn can_send_stream_data(&self) -> bool { // Reset streams may linger in the pending stream list, but will never produce stream frames self.pending.iter().any(|level| { level.queue.borrow().iter().any(|id| { self.send .get(id) .and_then(|s| s.as_ref()) .map_or(false, |s| !s.is_reset()) }) }) } /// Whether MAX_STREAM_DATA frames could be sent for stream `id` pub(crate) fn can_send_flow_control(&self, id: StreamId) -> bool { self.recv .get(&id) .and_then(|s| s.as_ref()) .map_or(false, |s| s.receiving_unknown_size()) } pub(in crate::connection) fn write_control_frames( &mut self, buf: &mut BytesMut, pending: &mut Retransmits, retransmits: &mut ThinRetransmits, stats: &mut FrameStats, max_size: usize, ) { // RESET_STREAM while buf.len() + frame::ResetStream::SIZE_BOUND < max_size { let (id, error_code) = match pending.reset_stream.pop() { Some(x) => x, None => break, }; let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; trace!(stream = %id, "RESET_STREAM"); retransmits .get_or_create() .reset_stream .push((id, error_code)); frame::ResetStream { id, error_code, final_offset: VarInt::try_from(stream.offset()).expect("impossibly large offset"), } .encode(buf); stats.reset_stream += 1; } // STOP_SENDING while buf.len() + frame::StopSending::SIZE_BOUND < max_size { let frame = match pending.stop_sending.pop() { Some(x) => x, None => break, }; // We may need to transmit STOP_SENDING even for streams whose state we have discarded, // because we are able to discard local state for stopped streams immediately upon // receiving FIN, even if the peer still has arbitrarily large amounts of data to // (re)transmit due to loss or unconventional sending strategy. We could fine-tune this // a little by dropping the frame if we specifically know the stream's been reset by the // peer, but we discard that information as soon as the application consumes it, so it // can't be relied upon regardless. trace!(stream = %frame.id, "STOP_SENDING"); frame.encode(buf); retransmits.get_or_create().stop_sending.push(frame); stats.stop_sending += 1; } // MAX_DATA if pending.max_data && buf.len() + 9 < max_size { pending.max_data = false; // `local_max_data` can grow bigger than `VarInt`. // For transmission inside QUIC frames we need to clamp it to the // maximum allowed `VarInt` size. 
let max = VarInt::try_from(self.local_max_data).unwrap_or(VarInt::MAX); trace!(value = max.into_inner(), "MAX_DATA"); if max > self.sent_max_data { // Record that a `MAX_DATA` announcing a certain window was sent. This will // suppress enqueuing further `MAX_DATA` frames unless either the previous // transmission was not acknowledged or the window further increased. self.sent_max_data = max; } retransmits.get_or_create().max_data = true; buf.write(frame::Type::MAX_DATA); buf.write(max); stats.max_data += 1; } // MAX_STREAM_DATA while buf.len() + 17 < max_size { let id = match pending.max_stream_data.iter().next() { Some(x) => *x, None => break, }; pending.max_stream_data.remove(&id); let rs = match self.recv.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; if !rs.receiving_unknown_size() { continue; } retransmits.get_or_create().max_stream_data.insert(id); let (max, _) = rs.max_stream_data(self.stream_receive_window); rs.record_sent_max_stream_data(max); trace!(stream = %id, max = max, "MAX_STREAM_DATA"); buf.write(frame::Type::MAX_STREAM_DATA); buf.write(id); buf.write_var(max); stats.max_stream_data += 1; } // MAX_STREAMS for dir in Dir::iter() { if !pending.max_stream_id[dir as usize] || buf.len() + 9 >= max_size { continue; } pending.max_stream_id[dir as usize] = false; retransmits.get_or_create().max_stream_id[dir as usize] = true; self.max_streams_dirty[dir as usize] = false; trace!( value = self.max_remote[dir as usize], "MAX_STREAMS ({:?})", dir ); buf.write(match dir { Dir::Uni => frame::Type::MAX_STREAMS_UNI, Dir::Bi => frame::Type::MAX_STREAMS_BIDI, }); buf.write_var(self.max_remote[dir as usize]); match dir { Dir::Uni => stats.max_streams_uni += 1, Dir::Bi => stats.max_streams_bidi += 1, } } } pub(crate) fn write_stream_frames( &mut self, buf: &mut BytesMut, max_buf_size: usize, ) -> StreamMetaVec { let mut stream_frames = StreamMetaVec::new(); while buf.len() + frame::Stream::SIZE_BOUND < max_buf_size { if max_buf_size .checked_sub(buf.len() + frame::Stream::SIZE_BOUND) .is_none() { break; } let num_levels = self.pending.len(); let mut level = match self.pending.peek_mut() { Some(x) => x, None => break, }; // Popping data from the front of the queue, storing as much data // as possible in a single frame, and enqueuing the remaining data // at the end of the queue helps with fairness. // Other streams will have a chance to write data before we touch // this stream again. let id = match level.queue.get_mut().pop_front() { Some(x) => x, None => { debug_assert!( num_levels == 1, "An empty queue is only allowed for a single level" ); break; } }; let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(s) => s, // Stream was reset with pending data and the reset was acknowledged None => continue, }; // Reset streams aren't removed from the pending list and still exist while the peer // hasn't acknowledged the reset, but should not generate STREAM frames, so we need to // check for them explicitly. if stream.is_reset() { continue; } // Now that we know the `StreamId`, we can better account for how many bytes // are required to encode it. let max_buf_size = max_buf_size - buf.len() - 1 - VarInt::size(id.into()); let (offsets, encode_length) = stream.pending.poll_transmit(max_buf_size); let fin = offsets.end == stream.pending.offset() && matches!(stream.state, SendState::DataSent { ..
}); if fin { stream.fin_pending = false; } if stream.is_pending() { if level.priority == stream.priority { // Enqueue for the same level level.queue.get_mut().push_back(id); } else { // Enqueue for a different level. If the current level is empty, drop it if level.queue.borrow().is_empty() && num_levels != 1 { // We keep the last level around even in empty form so that // the next insert doesn't have to reallocate the queue PeekMut::pop(level); } else { drop(level); } push_pending(&mut self.pending, id, stream.priority); } } else if level.queue.borrow().is_empty() && num_levels != 1 { // We keep the last level around even in empty form so that // the next insert doesn't have to reallocate the queue PeekMut::pop(level); } let meta = frame::StreamMeta { id, offsets, fin }; trace!(id = %meta.id, off = meta.offsets.start, len = meta.offsets.end - meta.offsets.start, fin = meta.fin, "STREAM"); meta.encode(encode_length, buf); // The range might not be retrievable in a single `get` if it is // stored in noncontiguous fashion. Therefore this loop iterates // until the range is fully copied into the frame. let mut offsets = meta.offsets.clone(); while offsets.start != offsets.end { let data = stream.pending.get(offsets.clone()); offsets.start += data.len() as u64; buf.put_slice(data); } stream_frames.push(meta); } stream_frames } /// Notify the application that new streams were opened or a stream became readable. fn on_stream_frame(&mut self, notify_readable: bool, stream: StreamId) { if stream.initiator() == self.side { // Notifying about the opening of locally-initiated streams would be redundant. if notify_readable { self.events.push_back(StreamEvent::Readable { id: stream }); } return; } let next = &mut self.next_remote[stream.dir() as usize]; if stream.index() >= *next { *next = stream.index() + 1; self.opened[stream.dir() as usize] = true; } else if notify_readable { self.events.push_back(StreamEvent::Readable { id: stream }); } } pub(crate) fn received_ack_of(&mut self, frame: frame::StreamMeta) { let mut entry = match self.send.entry(frame.id) { hash_map::Entry::Vacant(_) => return, hash_map::Entry::Occupied(e) => e, }; let stream = match entry.get_mut().as_mut() { Some(s) => s, None => { // Because we only call this after sending data on this stream, // this closure should be unreachable. If we did somehow screw that up, // then we might hit an underflow below with unpredictable effects down // the line. Best to short-circuit. 
return; } }; if stream.is_reset() { // We account for outstanding data on reset streams at time of reset return; } let id = frame.id; self.unacked_data -= frame.offsets.end - frame.offsets.start; if !stream.ack(frame) { // The stream is unfinished or may still need retransmits return; } entry.remove_entry(); self.stream_freed(id, StreamHalf::Send); self.events.push_back(StreamEvent::Finished { id }); } pub(crate) fn retransmit(&mut self, frame: frame::StreamMeta) { let stream = match self.send.get_mut(&frame.id).and_then(|s| s.as_mut()) { // Loss of data on a closed stream is a noop None => return, Some(x) => x, }; if !stream.is_pending() { push_pending(&mut self.pending, frame.id, stream.priority); } stream.fin_pending |= frame.fin; stream.pending.retransmit(frame.offsets); } pub(crate) fn retransmit_all_for_0rtt(&mut self) { for dir in Dir::iter() { for index in 0..self.next[dir as usize] { let id = StreamId::new(Side::Client, dir, index); let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(stream) => stream, None => continue, }; if stream.pending.is_fully_acked() && !stream.fin_pending { // Stream data can't be acked in 0-RTT, so we must not have sent anything on // this stream continue; } if !stream.is_pending() { push_pending(&mut self.pending, id, stream.priority); } stream.pending.retransmit_all_for_0rtt(); } } } pub(crate) fn received_max_streams( &mut self, dir: Dir, count: u64, ) -> Result<(), TransportError> { if count > MAX_STREAM_COUNT { return Err(TransportError::FRAME_ENCODING_ERROR( "unrepresentable stream limit", )); } let current = &mut self.max[dir as usize]; if count > *current { *current = count; self.events.push_back(StreamEvent::Available { dir }); } Ok(()) } /// Handle increase to connection-level flow control limit pub(crate) fn received_max_data(&mut self, n: VarInt) { self.max_data = self.max_data.max(n.into()); } pub(crate) fn received_max_stream_data( &mut self, id: StreamId, offset: u64, ) -> Result<(), TransportError> { if id.initiator() != self.side && id.dir() == Dir::Uni { debug!("got MAX_STREAM_DATA on recv-only {}", id); return Err(TransportError::STREAM_STATE_ERROR( "MAX_STREAM_DATA on recv-only stream", )); } let write_limit = self.write_limit(); let max_send_data = self.initial_max_send_data(id); if let Some(ss) = self .send .get_mut(&id) .map(get_or_insert_send(max_send_data)) { if ss.increase_max_data(offset) { if write_limit > 0 { self.events.push_back(StreamEvent::Writable { id }); } else if !ss.connection_blocked { // The stream is still blocked on the connection flow control // window. In order to get unblocked when the window relaxes // it needs to be in the connection blocked list. 
ss.connection_blocked = true; self.connection_blocked.push(id); } } } else if id.initiator() == self.side && self.is_local_unopened(id) { debug!("got MAX_STREAM_DATA on unopened {}", id); return Err(TransportError::STREAM_STATE_ERROR( "MAX_STREAM_DATA on unopened stream", )); } self.on_stream_frame(false, id); Ok(()) } /// Returns the maximum amount of data that may be written on the connection pub(crate) fn write_limit(&self) -> u64 { (self.max_data - self.data_sent).min(self.send_window - self.unacked_data) } /// Yield stream events pub(crate) fn poll(&mut self) -> Option<StreamEvent> { if let Some(dir) = Dir::iter().find(|&i| mem::replace(&mut self.opened[i as usize], false)) { return Some(StreamEvent::Opened { dir }); } if self.write_limit() > 0 { while let Some(id) = self.connection_blocked.pop() { let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { None => continue, Some(s) => s, }; debug_assert!(stream.connection_blocked); stream.connection_blocked = false; // If it's no longer sensible to write to a stream (even to detect an error) then don't // report it. if stream.is_writable() && stream.max_data > stream.offset() { return Some(StreamEvent::Writable { id }); } } } self.events.pop_front() } pub(crate) fn take_max_streams_dirty(&mut self, dir: Dir) -> bool { mem::replace(&mut self.max_streams_dirty[dir as usize], false) } /// Check for errors entailed by the peer's use of `id` as a send stream fn validate_receive_id(&mut self, id: StreamId) -> Result<(), TransportError> { if self.side == id.initiator() { match id.dir() { Dir::Uni => { return Err(TransportError::STREAM_STATE_ERROR( "illegal operation on send-only stream", )); } Dir::Bi if id.index() >= self.next[Dir::Bi as usize] => { return Err(TransportError::STREAM_STATE_ERROR( "operation on unopened stream", )); } Dir::Bi => {} }; } else { let limit = self.max_remote[id.dir() as usize]; if id.index() >= limit { return Err(TransportError::STREAM_LIMIT_ERROR("")); } } Ok(()) } /// Whether a locally initiated stream has never been opened pub(crate) fn is_local_unopened(&self, id: StreamId) -> bool { id.index() >= self.next[id.dir() as usize] } pub(crate) fn set_max_concurrent(&mut self, dir: Dir, count: VarInt) { self.flow_control_adjusted = true; self.max_concurrent_remote_count[dir as usize] = count.into(); self.ensure_remote_streams(dir); } pub(crate) fn max_concurrent(&self, dir: Dir) -> u64 { self.allocated_remote_count[dir as usize] } /// Set the receive_window and return whether the receive_window has been /// expanded or shrunk: true if expanded, false if shrunk.
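    // Worked example (numbers hypothetical): shrinking the window from 1 MiB
    // to 768 KiB records 256 KiB of `receive_window_shrink_debt`; the next
    // 256 KiB of credits passed to `add_read_credits` are swallowed by that
    // debt before `local_max_data` resumes growing.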
pub(crate) fn set_receive_window(&mut self, receive_window: VarInt) -> bool { let receive_window = receive_window.into(); let mut expanded = false; if receive_window > self.receive_window { self.local_max_data = self .local_max_data .saturating_add(receive_window - self.receive_window); expanded = true; } else { let diff = self.receive_window - receive_window; self.receive_window_shrink_debt = self.receive_window_shrink_debt.saturating_add(diff); } self.receive_window = receive_window; expanded } pub(super) fn insert(&mut self, remote: bool, id: StreamId) { let bi = id.dir() == Dir::Bi; // bidirectional OR (unidirectional AND NOT remote) if bi || !remote { assert!(self.send.insert(id, None).is_none()); } // bidirectional OR (unidirectional AND remote) if bi || remote { assert!(self.recv.insert(id, None).is_none()); } } /// Adds credits to the connection flow control window /// /// Returns whether a `MAX_DATA` frame should be enqueued as soon as possible. /// This will only be the case if the window update is significant /// enough. As soon as a window update with a `MAX_DATA` frame has been /// queued, the [`Recv::record_sent_max_stream_data`] function should be called to /// suppress sending further updates until the window increases significantly /// again. pub(super) fn add_read_credits(&mut self, credits: u64) -> ShouldTransmit { if credits > self.receive_window_shrink_debt { let net_credits = credits - self.receive_window_shrink_debt; self.local_max_data = self.local_max_data.saturating_add(net_credits); self.receive_window_shrink_debt = 0; } else { self.receive_window_shrink_debt -= credits; } if self.local_max_data > VarInt::MAX.into_inner() { return ShouldTransmit(false); } // Only announce a window update if it's significant enough // to make it worthwhile sending a MAX_DATA frame. // We use a fraction of the configured connection receive window to make // the decision, to accommodate connections using bigger windows requiring // fewer updates. let diff = self.local_max_data - self.sent_max_data.into_inner(); ShouldTransmit(diff >= (self.receive_window / 8)) } /// Update counters for removal of a stream pub(super) fn stream_freed(&mut self, id: StreamId, half: StreamHalf) { if id.initiator() != self.side { let fully_free = id.dir() == Dir::Uni || match half { StreamHalf::Send => !self.recv.contains_key(&id), StreamHalf::Recv => !self.send.contains_key(&id), }; if fully_free { self.allocated_remote_count[id.dir() as usize] -= 1; self.ensure_remote_streams(id.dir()); } } if half == StreamHalf::Send { self.send_streams -= 1; } } pub(super) fn initial_max_send_data(&self, id: StreamId) -> VarInt { let remote = self.side != id.initiator(); match id.dir() { Dir::Uni => self.initial_max_stream_data_uni, // Remote/local appear reversed here because the transport parameters are named from // the perspective of the peer.
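            // For example: on the server, a client-initiated bidirectional stream
            // has `remote == true` and is governed by the client's
            // `initial_max_stream_data_bidi_local`, since that stream is locally
            // initiated from the client's point of view.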
Dir::Bi if remote => self.initial_max_stream_data_bidi_local, Dir::Bi => self.initial_max_stream_data_bidi_remote, } } }
#[inline] pub(super) fn get_or_insert_send( max_data: VarInt, ) -> impl Fn(&mut Option<Box<Send>>) -> &mut Box<Send> { move |opt| opt.get_or_insert_with(|| Send::new(max_data)) }
#[inline] pub(super) fn get_or_insert_recv( initial_max_data: u64, ) -> impl Fn(&mut Option<Box<Recv>>) -> &mut Box<Recv> { move |opt| opt.get_or_insert_with(|| Recv::new(initial_max_data)) }
#[cfg(test)] mod tests { use super::*; use crate::{ connection::State as ConnState, connection::Streams, ReadableError, RecvStream, SendStream, TransportErrorCode, WriteError, }; use bytes::{Bytes, BytesMut}; fn make(side: Side) -> StreamsState { StreamsState::new( side, 128u32.into(), 128u32.into(), 1024 * 1024, (1024 * 1024u32).into(), (1024 * 1024u32).into(), ) }
#[test] fn trivial_flow_control() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; const MESSAGE_SIZE: usize = 2048; assert_eq!( client .received( frame::Stream { id, offset: 0, fin: true, data: Bytes::from_static(&[0; MESSAGE_SIZE]), }, 2048 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 2048); assert_eq!(client.local_max_data - initial_max, 0); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; let mut chunks = recv.read(true).unwrap(); assert_eq!( chunks.next(MESSAGE_SIZE).unwrap().unwrap().bytes.len(), MESSAGE_SIZE ); assert!(chunks.next(0).unwrap().is_none()); let should_transmit = chunks.finalize(); assert!(should_transmit.0); assert!(pending.max_stream_id[Dir::Uni as usize]); assert_eq!(client.local_max_data - initial_max, MESSAGE_SIZE as u64); }
#[test] fn reset_flow_control() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; assert_eq!( client .received( frame::Stream { id, offset: 0, fin: false, data: Bytes::from_static(&[0; 2048]), }, 2048 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 2048); assert_eq!(client.local_max_data - initial_max, 0); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; let mut chunks = recv.read(true).unwrap(); chunks.next(1024).unwrap(); let _ = chunks.finalize(); assert_eq!(client.local_max_data - initial_max, 1024); assert_eq!(client
.received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); assert_eq!(client.local_max_data - initial_max, 4096); } #[test] fn duplicate_reset_flow_control() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 4096u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.data_recvd, 4096); } #[test] fn recv_stopped() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; assert_eq!( client .received( frame::Stream { id, offset: 0, fin: false, data: Bytes::from_static(&[0; 32]), }, 32 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.local_max_data, initial_max); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; recv.stop(0u32.into()).unwrap(); assert_eq!(recv.pending.stop_sending.len(), 1); assert!(!recv.pending.max_data); assert!(recv.stop(0u32.into()).is_err()); assert_eq!(recv.read(true).err(), Some(ReadableError::UnknownStream)); assert_eq!(recv.read(false).err(), Some(ReadableError::UnknownStream)); assert_eq!(client.local_max_data - initial_max, 32); assert_eq!( client .received( frame::Stream { id, offset: 32, fin: true, data: Bytes::from_static(&[0; 16]), }, 16 ) .unwrap(), ShouldTransmit(false) ); assert_eq!(client.local_max_data - initial_max, 48); assert!(!client.recv.contains_key(&id)); } #[test] fn stopped_reset() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); // Server opens stream assert_eq!( client .received( frame::Stream { id, offset: 0, fin: false, data: Bytes::from_static(&[0; 32]) }, 32 ) .unwrap(), ShouldTransmit(false) ); let mut pending = Retransmits::default(); let mut recv = RecvStream { id, state: &mut client, pending: &mut pending, }; recv.stop(0u32.into()).unwrap(); assert_eq!(pending.stop_sending.len(), 1); assert!(!pending.max_data); // Server complies let prev_max = client.max_remote[Dir::Uni as usize]; assert_eq!( client .received_reset(frame::ResetStream { id, error_code: 0u32.into(), final_offset: 32u32.into(), }) .unwrap(), ShouldTransmit(false) ); assert!(!client.recv.contains_key(&id), "stream state is freed"); assert!( client.max_streams_dirty[Dir::Uni as usize], "stream credit is issued" ); assert_eq!(client.max_remote[Dir::Uni as usize], prev_max + 1); } #[test] fn send_stopped() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_uni: 1u32.into(), initial_max_data: 42u32.into(), initial_max_stream_data_uni: 42u32.into(), ..Default::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let id = Streams { state: &mut server, conn_state: &state, } .open(Dir::Uni) .unwrap(); let mut stream = SendStream { id, state: &mut server, pending: &mut pending, conn_state: &state, }; let error_code = 0u32.into(); stream.state.received_stop_sending(id, error_code); assert!(stream .state .events .contains(&StreamEvent::Stopped { id, error_code })); stream.state.events.clear(); assert_eq!(stream.write(&[]), Err(WriteError::Stopped(error_code))); 
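// Resetting the stopped stream frees its send state, so the write below reports
// UnknownStream rather than Stopped.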
stream.reset(0u32.into()).unwrap(); assert_eq!(stream.write(&[]), Err(WriteError::UnknownStream)); // A duplicate frame is a no-op stream.state.received_stop_sending(id, error_code); assert!(stream.state.events.is_empty()); } #[test] fn final_offset_flow_control() { let mut client = make(Side::Client); assert_eq!( client .received_reset(frame::ResetStream { id: StreamId::new(Side::Server, Dir::Uni, 0), error_code: 0u32.into(), final_offset: VarInt::MAX, }) .unwrap_err() .code, TransportErrorCode::FLOW_CONTROL_ERROR ); } #[test] fn stream_priority() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_bidi: 3u32.into(), initial_max_data: 10u32.into(), initial_max_stream_data_bidi_remote: 10u32.into(), ..Default::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; let id_high = streams.open(Dir::Bi).unwrap(); let id_mid = streams.open(Dir::Bi).unwrap(); let id_low = streams.open(Dir::Bi).unwrap(); let mut mid = SendStream { id: id_mid, state: &mut server, pending: &mut pending, conn_state: &state, }; mid.write(b"mid").unwrap(); let mut low = SendStream { id: id_low, state: &mut server, pending: &mut pending, conn_state: &state, }; low.set_priority(-1).unwrap(); low.write(b"low").unwrap(); let mut high = SendStream { id: id_high, state: &mut server, pending: &mut pending, conn_state: &state, }; high.set_priority(1).unwrap(); high.write(b"high").unwrap(); let mut buf = BytesMut::with_capacity(40); let meta = server.write_stream_frames(&mut buf, 40); assert_eq!(meta[0].id, id_high); assert_eq!(meta[1].id, id_mid); assert_eq!(meta[2].id, id_low); assert!(!server.can_send_stream_data()); assert_eq!(server.pending.len(), 1); } #[test] fn requeue_stream_priority() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_bidi: 3u32.into(), initial_max_data: 1000u32.into(), initial_max_stream_data_bidi_remote: 1000u32.into(), ..Default::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; let id_high = streams.open(Dir::Bi).unwrap(); let id_mid = streams.open(Dir::Bi).unwrap(); let mut mid = SendStream { id: id_mid, state: &mut server, pending: &mut pending, conn_state: &state, }; assert_eq!(mid.write(b"mid").unwrap(), 3); assert_eq!(server.pending.len(), 1); let mut high = SendStream { id: id_high, state: &mut server, pending: &mut pending, conn_state: &state, }; high.set_priority(1).unwrap(); assert_eq!(high.write(&[0; 200]).unwrap(), 200); assert_eq!(server.pending.len(), 2); // Requeue the high priority stream to lowest priority. The initial send // still uses high priority since it's queued that way. After that it will // switch to low priority let mut high = SendStream { id: id_high, state: &mut server, pending: &mut pending, conn_state: &state, }; high.set_priority(-1).unwrap(); let mut buf = BytesMut::with_capacity(1000); let meta = server.write_stream_frames(&mut buf, 40); assert_eq!(meta.len(), 1); assert_eq!(meta[0].id, id_high); // After requeuing we should end up with 2 priorities - not 3 assert_eq!(server.pending.len(), 2); // Send the remaining data. 
The initial mid priority one should go first now let meta = server.write_stream_frames(&mut buf, 1000); assert_eq!(meta.len(), 2); assert_eq!(meta[0].id, id_mid); assert_eq!(meta[1].id, id_high); assert!(!server.can_send_stream_data()); assert_eq!(server.pending.len(), 1); } #[test] fn stop_finished() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); // Server finishes stream let _ = client .received( frame::Stream { id, offset: 0, fin: true, data: Bytes::from_static(&[0; 32]), }, 32, ) .unwrap(); let mut pending = Retransmits::default(); let mut stream = RecvStream { id, state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); assert!(client.recv.get_mut(&id).is_none(), "stream is freed"); } // Verify that a stream that's been reset doesn't cause the appearance of pending data #[test] fn reset_stream_cannot_send() { let mut server = make(Side::Server); server.set_params(&TransportParameters { initial_max_streams_uni: 1u32.into(), initial_max_data: 42u32.into(), initial_max_stream_data_uni: 42u32.into(), ..Default::default() }); let (mut pending, state) = (Retransmits::default(), ConnState::Established); let mut streams = Streams { state: &mut server, conn_state: &state, }; let id = streams.open(Dir::Uni).unwrap(); let mut stream = SendStream { id, state: &mut server, pending: &mut pending, conn_state: &state, }; stream.write(b"hello").unwrap(); stream.reset(0u32.into()).unwrap(); assert_eq!(pending.reset_stream, &[(id, 0u32.into())]); assert!(!server.can_send_stream_data()); } #[test] fn stream_limit_fixed() { let mut client = make(Side::Client); // Open streams 0-127 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 127), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); // Try to open stream 128, exceeding limit assert_eq!( client .received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ) .unwrap_err() .code, TransportErrorCode::STREAM_LIMIT_ERROR ); // Free stream 127 let mut pending = Retransmits::default(); let mut stream = RecvStream { id: StreamId::new(Side::Server, Dir::Uni, 127), state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); assert!(client.max_streams_dirty[Dir::Uni as usize]); // Open stream 128 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); } #[test] fn stream_limit_grows() { let mut client = make(Side::Client); // Open streams 0-127 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 127), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); // Try to open stream 128, exceeding limit assert_eq!( client .received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ) .unwrap_err() .code, TransportErrorCode::STREAM_LIMIT_ERROR ); // Relax limit by one client.set_max_concurrent(Dir::Uni, 129u32.into()); assert!(client.max_streams_dirty[Dir::Uni as usize]); // Open stream 128 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); } #[test] fn stream_limit_shrinks() { let mut client = make(Side::Client); // Open streams 0-127 
assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 127), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); // Tighten limit by one client.set_max_concurrent(Dir::Uni, 127u32.into()); // Free stream 127 let mut pending = Retransmits::default(); let mut stream = RecvStream { id: StreamId::new(Side::Server, Dir::Uni, 127), state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); assert!(!client.max_streams_dirty[Dir::Uni as usize]); // Try to open stream 128, still exceeding limit assert_eq!( client .received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ) .unwrap_err() .code, TransportErrorCode::STREAM_LIMIT_ERROR ); // Free stream 126 assert_eq!( client.received_reset(frame::ResetStream { id: StreamId::new(Side::Server, Dir::Uni, 126), error_code: 0u32.into(), final_offset: 0u32.into(), }), Ok(ShouldTransmit(false)) ); let mut pending = Retransmits::default(); let mut stream = RecvStream { id: StreamId::new(Side::Server, Dir::Uni, 126), state: &mut client, pending: &mut pending, }; stream.stop(0u32.into()).unwrap(); assert!(client.max_streams_dirty[Dir::Uni as usize]); // Open stream 128 assert_eq!( client.received( frame::Stream { id: StreamId::new(Side::Server, Dir::Uni, 128), offset: 0, fin: true, data: Bytes::from_static(&[]), }, 0 ), Ok(ShouldTransmit(false)) ); }
#[test] fn remote_stream_capacity() { let mut client = make(Side::Client); for _ in 0..2 { client.set_max_concurrent(Dir::Uni, 200u32.into()); client.set_max_concurrent(Dir::Bi, 201u32.into()); assert_eq!(client.recv.len(), 200 + 201); assert_eq!(client.max_remote[Dir::Uni as usize], 200); assert_eq!(client.max_remote[Dir::Bi as usize], 201); } }
#[test] fn expand_receive_window() { let mut server = make(Side::Server); let new_receive_window = 2 * server.receive_window as u32; let expanded = server.set_receive_window(new_receive_window.into()); assert!(expanded); assert_eq!(server.receive_window, new_receive_window as u64); assert_eq!(server.local_max_data, new_receive_window as u64); assert_eq!(server.receive_window_shrink_debt, 0); let prev_local_max_data = server.local_max_data; // credit, expecting all of them added to local_max_data let credits = 1024u64; let should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, prev_local_max_data + credits); assert!(should_transmit.should_transmit()); }
#[test] fn shrink_receive_window() { let mut server = make(Side::Server); let new_receive_window = server.receive_window as u32 / 2; let prev_local_max_data = server.local_max_data; // shrink the receive_window; local_max_data is not expected to change let shrink_diff = server.receive_window - new_receive_window as u64; let expanded = server.set_receive_window(new_receive_window.into()); assert!(!expanded); assert_eq!(server.receive_window, new_receive_window as u64); assert_eq!(server.local_max_data, prev_local_max_data); assert_eq!(server.receive_window_shrink_debt, shrink_diff); let prev_local_max_data = server.local_max_data; // credit twice, local_max_data does not change as it is absorbed by receive_window_shrink_debt let credits = 1024u64; for _ in 0..2 { let expected_receive_window_shrink_debt = server.receive_window_shrink_debt - credits; let should_transmit = server.add_read_credits(credits); assert_eq!( server.receive_window_shrink_debt,
expected_receive_window_shrink_debt ); assert_eq!(server.local_max_data, prev_local_max_data); assert!(!should_transmit.should_transmit()); } // credit again which exceeds all remaining expected_receive_window_shrink_debt let credits = 1024 * 512; let prev_local_max_data = server.local_max_data; let expected_local_max_data = server.local_max_data + (credits - server.receive_window_shrink_debt); let _should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, expected_local_max_data); assert!(server.local_max_data > prev_local_max_data); // credit again, all should be added to local_max_data let credits = 1024 * 512; let expected_local_max_data = server.local_max_data + credits; let should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, expected_local_max_data); assert!(should_transmit.should_transmit()); } } quinn-proto-0.10.6/src/connection/timer.rs000064400000000000000000000035201046102023000166170ustar 00000000000000use std::time::Instant;
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) enum Timer { /// When to send an ack-eliciting probe packet or declare unacked packets lost LossDetection = 0, /// When to close the connection after no activity Idle = 1, /// When the close timer expires, the connection has been gracefully terminated. Close = 2, /// When keys are discarded because they should not be needed anymore KeyDiscard = 3, /// When to give up on validating a new path to the peer PathValidation = 4, /// When to send a `PING` frame to keep the connection alive KeepAlive = 5, /// When pacing will allow us to send a packet Pacing = 6, /// When to invalidate old CID and proactively push new one via NEW_CONNECTION_ID frame PushNewCid = 7, }
impl Timer { pub(crate) const VALUES: [Self; 8] = [ Self::LossDetection, Self::Idle, Self::Close, Self::KeyDiscard, Self::PathValidation, Self::KeepAlive, Self::Pacing, Self::PushNewCid, ]; }
/// A table of data associated with each distinct kind of `Timer`
#[derive(Debug, Copy, Clone, Default)] pub(crate) struct TimerTable { data: [Option<Instant>; 8], }
impl TimerTable { pub(super) fn set(&mut self, timer: Timer, time: Instant) { self.data[timer as usize] = Some(time); } pub(super) fn get(&self, timer: Timer) -> Option<Instant> { self.data[timer as usize] } pub(super) fn stop(&mut self, timer: Timer) { self.data[timer as usize] = None; } pub(super) fn next_timeout(&self) -> Option<Instant> { self.data.iter().filter_map(|&x| x).min() } pub(super) fn is_expired(&self, timer: Timer, after: Instant) -> bool { self.data[timer as usize].map_or(false, |x| x <= after) } } quinn-proto-0.10.6/src/constant_time.rs000064400000000000000000000012601046102023000162060ustar 00000000000000// This function is non-inline to prevent the optimizer from looking inside it.
#[inline(never)] fn constant_time_ne(a: &[u8], b: &[u8]) -> u8 { assert!(a.len() == b.len()); // These useless slices make the optimizer elide the bounds checks. // See the comment in clone_from_slice() added on Rust commit 6a7bc47. let len = a.len(); let a = &a[..len]; let b = &b[..len]; let mut tmp = 0; for i in 0..len { tmp |= a[i] ^ b[i]; } tmp // The compare with 0 must happen outside this function. }
/// Compares byte strings in constant time.
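///
/// A usage sketch (`ignore`d because `eq` is crate-private and not visible to doctests):
///
/// ```ignore
/// // Equal-length inputs are compared without a data-dependent early exit.
/// assert!(eq(b"0123456789abcdef", b"0123456789abcdef"));
/// assert!(!eq(b"0123456789abcdef", b"0123456789abcdeF"));
/// // A length mismatch returns false before any byte comparison.
/// assert!(!eq(b"short", b"longer"));
/// ```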
pub(crate) fn eq(a: &[u8], b: &[u8]) -> bool { a.len() == b.len() && constant_time_ne(a, b) == 0 } quinn-proto-0.10.6/src/crypto/ring.rs000064400000000000000000000032251046102023000156210ustar 00000000000000use ring::{aead, hkdf, hmac}; use crate::crypto::{self, CryptoError}; impl crypto::HmacKey for hmac::Key { fn sign(&self, data: &[u8], out: &mut [u8]) { out.copy_from_slice(hmac::sign(self, data).as_ref()); } fn signature_len(&self) -> usize { 32 } fn verify(&self, data: &[u8], signature: &[u8]) -> Result<(), CryptoError> { Ok(hmac::verify(self, data, signature)?) } } impl crypto::HandshakeTokenKey for hkdf::Prk { fn aead_from_hkdf(&self, random_bytes: &[u8]) -> Box { let mut key_buffer = [0u8; 32]; let info = [random_bytes]; let okm = self.expand(&info, hkdf::HKDF_SHA256).unwrap(); okm.fill(&mut key_buffer).unwrap(); let key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_buffer).unwrap(); Box::new(aead::LessSafeKey::new(key)) } } impl crypto::AeadKey for aead::LessSafeKey { fn seal(&self, data: &mut Vec, additional_data: &[u8]) -> Result<(), CryptoError> { let aad = ring::aead::Aad::from(additional_data); let zero_nonce = ring::aead::Nonce::assume_unique_for_key([0u8; 12]); Ok(self.seal_in_place_append_tag(zero_nonce, aad, data)?) } fn open<'a>( &self, data: &'a mut [u8], additional_data: &[u8], ) -> Result<&'a mut [u8], CryptoError> { let aad = ring::aead::Aad::from(additional_data); let zero_nonce = ring::aead::Nonce::assume_unique_for_key([0u8; 12]); Ok(self.open_in_place(zero_nonce, aad, data)?) } } impl From for CryptoError { fn from(_: ring::error::Unspecified) -> Self { Self } } quinn-proto-0.10.6/src/crypto/rustls.rs000064400000000000000000000337171046102023000162270ustar 00000000000000use std::{any::Any, convert::TryInto, io, str, sync::Arc}; use bytes::BytesMut; use ring::aead; pub use rustls::Error; use rustls::{ self, quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Version}, }; use crate::{ crypto::{ self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion, }, transport_parameters::TransportParameters, ConnectError, ConnectionId, Side, TransportError, TransportErrorCode, }; impl From for rustls::Side { fn from(s: Side) -> Self { match s { Side::Client => Self::Client, Side::Server => Self::Server, } } } /// A rustls TLS session pub struct TlsSession { version: Version, got_handshake_data: bool, next_secrets: Option, inner: Connection, } impl TlsSession { fn side(&self) -> Side { match self.inner { Connection::Client(_) => Side::Client, Connection::Server(_) => Side::Server, } } } impl crypto::Session for TlsSession { fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys { initial_keys(self.version, dst_cid, side) } fn handshake_data(&self) -> Option> { if !self.got_handshake_data { return None; } Some(Box::new(HandshakeData { protocol: self.inner.alpn_protocol().map(|x| x.into()), server_name: match self.inner { Connection::Client(_) => None, Connection::Server(ref session) => session.server_name().map(|x| x.into()), }, })) } fn peer_identity(&self) -> Option> { self.inner .peer_certificates() .map(|v| -> Box { Box::new(v.to_vec()) }) } fn early_crypto(&self) -> Option<(Box, Box)> { let keys = self.inner.zero_rtt_keys()?; Some((Box::new(keys.header), Box::new(keys.packet))) } fn early_data_accepted(&self) -> Option { match self.inner { Connection::Client(ref session) => Some(session.is_early_data_accepted()), _ => None, } } fn is_handshaking(&self) -> bool { self.inner.is_handshaking() } fn 
read_handshake(&mut self, buf: &[u8]) -> Result { self.inner.read_hs(buf).map_err(|e| { if let Some(alert) = self.inner.alert() { TransportError { code: TransportErrorCode::crypto(alert.get_u8()), frame: None, reason: e.to_string(), } } else { TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}")) } })?; if !self.got_handshake_data { // Hack around the lack of an explicit signal from rustls to reflect ClientHello being // ready on incoming connections, or ALPN negotiation completing on outgoing // connections. let have_server_name = match self.inner { Connection::Client(_) => false, Connection::Server(ref session) => session.server_name().is_some(), }; if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() { self.got_handshake_data = true; return Ok(true); } } Ok(false) } fn transport_parameters(&self) -> Result, TransportError> { match self.inner.quic_transport_parameters() { None => Ok(None), Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) { Ok(params) => Ok(Some(params)), Err(e) => Err(e.into()), }, } } fn write_handshake(&mut self, buf: &mut Vec) -> Option { let keys = match self.inner.write_hs(buf)? { KeyChange::Handshake { keys } => keys, KeyChange::OneRtt { keys, next } => { self.next_secrets = Some(next); keys } }; Some(Keys { header: KeyPair { local: Box::new(keys.local.header), remote: Box::new(keys.remote.header), }, packet: KeyPair { local: Box::new(keys.local.packet), remote: Box::new(keys.remote.packet), }, }) } fn next_1rtt_keys(&mut self) -> Option>> { let secrets = self.next_secrets.as_mut()?; let keys = secrets.next_packet_keys(); Some(KeyPair { local: Box::new(keys.local), remote: Box::new(keys.remote), }) } fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool { let tag_start = match payload.len().checked_sub(16) { Some(x) => x, None => return false, }; let mut pseudo_packet = Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1); pseudo_packet.push(orig_dst_cid.len() as u8); pseudo_packet.extend_from_slice(orig_dst_cid); pseudo_packet.extend_from_slice(header); let tag_start = tag_start + pseudo_packet.len(); pseudo_packet.extend_from_slice(payload); let (nonce, key) = match self.version { Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1), Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT), _ => unreachable!(), }; let nonce = aead::Nonce::assume_unique_for_key(nonce); let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap()); let (aad, tag) = pseudo_packet.split_at_mut(tag_start); key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok() } fn export_keying_material( &self, output: &mut [u8], label: &[u8], context: &[u8], ) -> Result<(), ExportKeyingMaterialError> { self.inner .export_keying_material(output, label, Some(context)) .map_err(|_| ExportKeyingMaterialError)?; Ok(()) } } const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [ 0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1, ]; const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [ 0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c, ]; const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [ 0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e, ]; const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [ 0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb, ]; impl crypto::HeaderKey for HeaderProtectionKey { fn 
decrypt(&self, pn_offset: usize, packet: &mut [u8]) { let (header, sample) = packet.split_at_mut(pn_offset + 4); let (first, rest) = header.split_at_mut(1); let pn_end = Ord::min(pn_offset + 3, rest.len()); self.decrypt_in_place( &sample[..self.sample_size()], &mut first[0], &mut rest[pn_offset - 1..pn_end], ) .unwrap(); } fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) { let (header, sample) = packet.split_at_mut(pn_offset + 4); let (first, rest) = header.split_at_mut(1); let pn_end = Ord::min(pn_offset + 3, rest.len()); self.encrypt_in_place( &sample[..self.sample_size()], &mut first[0], &mut rest[pn_offset - 1..pn_end], ) .unwrap(); } fn sample_size(&self) -> usize { self.sample_len() } } /// Authentication data for (rustls) TLS session pub struct HandshakeData { /// The negotiated application protocol, if ALPN is in use /// /// Guaranteed to be set if a nonempty list of protocols was specified for this connection. pub protocol: Option>, /// The server name specified by the client, if any /// /// Always `None` for outgoing connections pub server_name: Option, } impl crypto::ClientConfig for rustls::ClientConfig { fn start_session( self: Arc, version: u32, server_name: &str, params: &TransportParameters, ) -> Result, ConnectError> { let version = interpret_version(version)?; Ok(Box::new(TlsSession { version, got_handshake_data: false, next_secrets: None, inner: rustls::quic::Connection::Client( rustls::quic::ClientConnection::new( self, version, server_name .try_into() .map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?, to_vec(params), ) .unwrap(), ), })) } } impl crypto::ServerConfig for rustls::ServerConfig { fn start_session( self: Arc, version: u32, params: &TransportParameters, ) -> Box { let version = interpret_version(version).unwrap(); Box::new(TlsSession { version, got_handshake_data: false, next_secrets: None, inner: rustls::quic::Connection::Server( rustls::quic::ServerConnection::new(self, version, to_vec(params)).unwrap(), ), }) } fn initial_keys( &self, version: u32, dst_cid: &ConnectionId, side: Side, ) -> Result { let version = interpret_version(version)?; Ok(initial_keys(version, dst_cid, side)) } fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] { let version = interpret_version(version).unwrap(); let (nonce, key) = match version { Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1), Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT), _ => unreachable!(), }; let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1); pseudo_packet.push(orig_dst_cid.len() as u8); pseudo_packet.extend_from_slice(orig_dst_cid); pseudo_packet.extend_from_slice(packet); let nonce = aead::Nonce::assume_unique_for_key(nonce); let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap()); let tag = key .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut []) .unwrap(); let mut result = [0; 16]; result.copy_from_slice(tag.as_ref()); result } } fn to_vec(params: &TransportParameters) -> Vec { let mut bytes = Vec::new(); params.write(&mut bytes); bytes } pub(crate) fn initial_keys(version: Version, dst_cid: &ConnectionId, side: Side) -> Keys { let keys = rustls::quic::Keys::initial(version, dst_cid, side.into()); Keys { header: KeyPair { local: Box::new(keys.local.header), remote: Box::new(keys.remote.header), }, packet: KeyPair { local: Box::new(keys.local.packet), remote: Box::new(keys.remote.packet), }, } } impl 
crypto::PacketKey for PacketKey { fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) { let (header, payload_tag) = buf.split_at_mut(header_len); let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len()); let tag = self.encrypt_in_place(packet, &*header, payload).unwrap(); tag_storage.copy_from_slice(tag.as_ref()); } fn decrypt( &self, packet: u64, header: &[u8], payload: &mut BytesMut, ) -> Result<(), CryptoError> { let plain = self .decrypt_in_place(packet, header, payload.as_mut()) .map_err(|_| CryptoError)?; let plain_len = plain.len(); payload.truncate(plain_len); Ok(()) } fn tag_len(&self) -> usize { self.tag_len() } fn confidentiality_limit(&self) -> u64 { self.confidentiality_limit() } fn integrity_limit(&self) -> u64 { self.integrity_limit() } } /// Initialize a sane QUIC-compatible TLS client configuration /// /// QUIC requires that TLS 1.3 be enabled. Advanced users can use any [`rustls::ClientConfig`] that /// satisfies this requirement. pub(crate) fn client_config(roots: rustls::RootCertStore) -> rustls::ClientConfig { let mut cfg = rustls::ClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap() .with_root_certificates(roots) .with_no_client_auth(); cfg.enable_early_data = true; cfg } /// Initialize a sane QUIC-compatible TLS server configuration /// /// QUIC requires that TLS 1.3 be enabled, and that the maximum early data size is either 0 or /// `u32::MAX`. Advanced users can use any [`rustls::ServerConfig`] that satisfies these /// requirements. pub(crate) fn server_config( cert_chain: Vec, key: rustls::PrivateKey, ) -> Result { let mut cfg = rustls::ServerConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap() .with_no_client_auth() .with_single_cert(cert_chain, key)?; cfg.max_early_data_size = u32::MAX; Ok(cfg) } fn interpret_version(version: u32) -> Result { match version { 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft), 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1), _ => Err(UnsupportedVersion), } } quinn-proto-0.10.6/src/crypto.rs000064400000000000000000000177051046102023000146720ustar 00000000000000//! Traits and implementations for the QUIC cryptography protocol //! //! The protocol logic in Quinn is contained in types that abstract over the actual //! cryptographic protocol used. This module contains the traits used for this //! abstraction layer as well as a single implementation of these traits that uses //! *ring* and rustls to implement the TLS protocol support. //! //! Note that usage of any protocol (version) other than TLS 1.3 does not conform to any //! published versions of the specification, and will not be supported in QUIC v1. 
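//!
//! As a rough sketch of how a TLS backend plugs in (assuming the default `tls-rustls`
//! feature and the crate-level `ClientConfig::new` constructor; `my_rustls_client_config()`
//! is a hypothetical helper producing a TLS 1.3 capable `rustls::ClientConfig`):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! // `rustls::ClientConfig` implements `crypto::ClientConfig` (see `crypto::rustls`),
//! // so it can be handed to the endpoint-level client configuration as-is.
//! let tls = Arc::new(my_rustls_client_config());
//! let client_config = quinn_proto::ClientConfig::new(tls);
//! ```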
use std::{any::Any, str, sync::Arc}; use bytes::BytesMut; use crate::{ shared::ConnectionId, transport_parameters::TransportParameters, ConnectError, Side, TransportError, }; /// Cryptography interface based on *ring* #[cfg(feature = "ring")] pub(crate) mod ring; /// TLS interface based on rustls #[cfg(feature = "rustls")] pub mod rustls; /// A cryptographic session (commonly TLS) pub trait Session: Send + 'static { /// Create the initial set of keys given the client's initial destination ConnectionId fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys; /// Get data negotiated during the handshake, if available /// /// Returns `None` until the connection emits `HandshakeDataReady`. fn handshake_data(&self) -> Option>; /// Get the peer's identity, if available fn peer_identity(&self) -> Option>; /// Get the 0-RTT keys if available (clients only) /// /// On the client side, this method can be used to see if 0-RTT key material is available /// to start sending data before the protocol handshake has completed. /// /// Returns `None` if the key material is not available. This might happen if you have /// not connected to this server before. fn early_crypto(&self) -> Option<(Box, Box)>; /// If the 0-RTT-encrypted data has been accepted by the peer fn early_data_accepted(&self) -> Option; /// Returns `true` until the connection is fully established. fn is_handshaking(&self) -> bool; /// Read bytes of handshake data /// /// This should be called with the contents of `CRYPTO` frames. If it returns `Ok`, the /// caller should call `write_handshake()` to check if the crypto protocol has anything /// to send to the peer. This method will only return `true` the first time that /// handshake data is available. Future calls will always return false. /// /// On success, returns `true` iff `self.handshake_data()` has been populated. fn read_handshake(&mut self, buf: &[u8]) -> Result; /// The peer's QUIC transport parameters /// /// These are only available after the first flight from the peer has been received. fn transport_parameters(&self) -> Result, TransportError>; /// Writes handshake bytes into the given buffer and optionally returns the negotiated keys /// /// When the handshake proceeds to the next phase, this method will return a new set of /// keys to encrypt data with. fn write_handshake(&mut self, buf: &mut Vec) -> Option; /// Compute keys for the next key update fn next_1rtt_keys(&mut self) -> Option>>; /// Verify the integrity of a retry packet fn is_valid_retry(&self, orig_dst_cid: &ConnectionId, header: &[u8], payload: &[u8]) -> bool; /// Fill `output` with `output.len()` bytes of keying material derived /// from the [Session]'s secrets, using `label` and `context` for domain /// separation. /// /// This function will fail, returning [ExportKeyingMaterialError], /// if the requested output length is too large. 
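///
/// A minimal sketch (hypothetical `session` value, hence `ignore`):
///
/// ```ignore
/// let mut output = [0u8; 32];
/// session
///     .export_keying_material(&mut output, b"my exporter label", b"my context")
///     .expect("32 bytes is well within the exporter limit");
/// ```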
fn export_keying_material( &self, output: &mut [u8], label: &[u8], context: &[u8], ) -> Result<(), ExportKeyingMaterialError>; } /// A pair of keys for bidirectional communication pub struct KeyPair { /// Key for encrypting data pub local: T, /// Key for decrypting data pub remote: T, } /// A complete set of keys for a certain packet space pub struct Keys { /// Header protection keys pub header: KeyPair>, /// Packet protection keys pub packet: KeyPair>, } /// Client-side configuration for the crypto protocol pub trait ClientConfig: Send + Sync { /// Start a client session with this configuration fn start_session( self: Arc, version: u32, server_name: &str, params: &TransportParameters, ) -> Result, ConnectError>; } /// Server-side configuration for the crypto protocol pub trait ServerConfig: Send + Sync { /// Create the initial set of keys given the client's initial destination ConnectionId fn initial_keys( &self, version: u32, dst_cid: &ConnectionId, side: Side, ) -> Result; /// Generate the integrity tag for a retry packet /// /// Never called if `initial_keys` rejected `version`. fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16]; /// Start a server session with this configuration /// /// Never called if `initial_keys` rejected `version`. fn start_session( self: Arc, version: u32, params: &TransportParameters, ) -> Box; } /// Keys used to protect packet payloads pub trait PacketKey: Send { /// Encrypt the packet payload with the given packet number fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize); /// Decrypt the packet payload with the given packet number fn decrypt( &self, packet: u64, header: &[u8], payload: &mut BytesMut, ) -> Result<(), CryptoError>; /// The length of the AEAD tag appended to packets on encryption fn tag_len(&self) -> usize; /// Maximum number of packets that may be sent using a single key fn confidentiality_limit(&self) -> u64; /// Maximum number of incoming packets that may fail decryption before the connection must be /// abandoned fn integrity_limit(&self) -> u64; } /// Keys used to protect packet headers pub trait HeaderKey: Send { /// Decrypt the given packet's header fn decrypt(&self, pn_offset: usize, packet: &mut [u8]); /// Encrypt the given packet's header fn encrypt(&self, pn_offset: usize, packet: &mut [u8]); /// The sample size used for this key's algorithm fn sample_size(&self) -> usize; } /// A key for signing with HMAC-based algorithms pub trait HmacKey: Send + Sync { /// Method for signing a message fn sign(&self, data: &[u8], signature_out: &mut [u8]); /// Length of `sign`'s output fn signature_len(&self) -> usize; /// Method for verifying a message fn verify(&self, data: &[u8], signature: &[u8]) -> Result<(), CryptoError>; } /// Error returned by [Session::export_keying_material]. /// /// This error occurs if the requested output length is too large. 
#[derive(Debug, PartialEq, Eq)] pub struct ExportKeyingMaterialError; /// A pseudo random key for HKDF pub trait HandshakeTokenKey: Send + Sync { /// Derive AEAD using hkdf fn aead_from_hkdf(&self, random_bytes: &[u8]) -> Box; } /// A key for sealing data with AEAD-based algorithms pub trait AeadKey { /// Method for sealing message `data` fn seal(&self, data: &mut Vec, additional_data: &[u8]) -> Result<(), CryptoError>; /// Method for opening a sealed message `data` fn open<'a>( &self, data: &'a mut [u8], additional_data: &[u8], ) -> Result<&'a mut [u8], CryptoError>; } /// Generic crypto errors #[derive(Debug)] pub struct CryptoError; /// Error indicating that the specified QUIC version is not supported #[derive(Debug)] pub struct UnsupportedVersion; impl From for ConnectError { fn from(_: UnsupportedVersion) -> Self { Self::UnsupportedVersion } } quinn-proto-0.10.6/src/endpoint.rs000064400000000000000000001033241046102023000151630ustar 00000000000000use std::{ collections::{HashMap, VecDeque}, convert::TryFrom, fmt, iter, net::{IpAddr, SocketAddr}, ops::{Index, IndexMut}, sync::Arc, time::{Instant, SystemTime}, }; use bytes::{BufMut, Bytes, BytesMut}; use rand::{rngs::StdRng, Rng, RngCore, SeedableRng}; use rustc_hash::FxHashMap; use slab::Slab; use thiserror::Error; use tracing::{debug, trace, warn}; use crate::{ cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, coding::BufMutExt, config::{ClientConfig, EndpointConfig, ServerConfig}, connection::{Connection, ConnectionError}, crypto::{self, Keys, UnsupportedVersion}, frame, packet::{Header, Packet, PacketDecodeError, PacketNumber, PartialDecode}, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent, EndpointEventInner, IssuedCid, }, transport_parameters::TransportParameters, ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, }; /// The main entry point to the library /// /// This object performs no I/O whatsoever. Instead, it generates a stream of packets to send via /// `poll_transmit`, and consumes incoming packets and connection-generated events via `handle` and /// `handle_event`. pub struct Endpoint { rng: StdRng, transmits: VecDeque, /// Identifies connections based on the initial DCID the peer utilized /// /// Uses a standard `HashMap` to protect against hash collision attacks. connection_ids_initial: HashMap, /// Identifies connections based on locally created CIDs /// /// Uses a cheaper hash function since keys are locally created connection_ids: FxHashMap, /// Identifies connections with zero-length CIDs /// /// Uses a standard `HashMap` to protect against hash collision attacks. connection_remotes: HashMap, /// Reset tokens provided by the peer for the CID each connection is currently sending to /// /// Incoming stateless resets do not have correct CIDs, so we need this to identify the correct /// recipient, if any. connection_reset_tokens: ResetTokenTable, connections: Slab, local_cid_generator: Box, config: Arc, server_config: Option>, /// Whether the underlying UDP socket promises not to fragment packets allow_mtud: bool, /// The contents length for packets in the transmits queue transmit_queue_contents_len: usize, /// The socket buffer aggregated contents length /// `transmit_queue_contents_len` + `socket_buffer_fill` represents the total contents length /// of outstanding outgoing packets. 
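/// For example (illustrative numbers): with 60 MB still queued in `transmits` and 50 MB
/// reported in the socket buffer, the 110 MB total exceeds
/// `MAX_TRANSMIT_QUEUE_CONTENTS_LEN` (100 MB), so endpoint-generated stateless packets are
/// suppressed until the backlog drains.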
socket_buffer_fill: usize, }
/// The maximum total content length of packets in the outgoing transmit queue. Transmit packets
/// generated from the endpoint (retry, initial close, stateless reset and version negotiation)
/// can be dropped when this limit is exceeded.
/// Chosen to represent 100 MB of data.
const MAX_TRANSMIT_QUEUE_CONTENTS_LEN: usize = 100_000_000;
impl Endpoint { /// Create a new endpoint /// /// `allow_mtud` enables path MTU detection when requested by `Connection` configuration for /// better performance. This requires that outgoing packets are never fragmented, which can be /// achieved via e.g. the `IPV6_DONTFRAG` socket option. pub fn new( config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>, allow_mtud: bool, ) -> Self { Self { rng: StdRng::from_entropy(), transmits: VecDeque::new(), connection_ids_initial: HashMap::default(), connection_ids: FxHashMap::default(), connection_remotes: HashMap::default(), connection_reset_tokens: ResetTokenTable::default(), connections: Slab::new(), local_cid_generator: (config.connection_id_generator_factory.as_ref())(), config, server_config, allow_mtud, transmit_queue_contents_len: 0, socket_buffer_fill: 0, } }
/// Get the next packet to transmit
#[must_use] pub fn poll_transmit(&mut self) -> Option<Transmit> { let t = self.transmits.pop_front(); self.decrement_transmit_queue_contents_len(t.as_ref().map_or(0, |t| t.contents.len())); t }
/// Replace the server configuration, affecting new incoming connections only
pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) { self.server_config = server_config; }
/// Process `EndpointEvent`s emitted from related `Connection`s
///
/// In turn, processing this event may return a `ConnectionEvent` for the same `Connection`.
pub fn handle_event( &mut self, ch: ConnectionHandle, event: EndpointEvent, ) -> Option<ConnectionEvent> { use EndpointEventInner::*; match event.0 { NeedIdentifiers(now, n) => { return Some(self.send_new_identifiers(now, ch, n)); } ResetToken(remote, token) => { if let Some(old) = self.connections[ch].reset_token.replace((remote, token)) { self.connection_reset_tokens.remove(old.0, old.1); } if self.connection_reset_tokens.insert(remote, token, ch) { warn!("duplicate reset token"); } } RetireConnectionId(now, seq, allow_more_cids) => { if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) { trace!("peer retired CID {}: {}", seq, cid); self.connection_ids.remove(&cid); if allow_more_cids { return Some(self.send_new_identifiers(now, ch, 1)); } } } Drained => { let conn = self.connections.remove(ch.0); if conn.init_cid.len() > 0 { self.connection_ids_initial.remove(&conn.init_cid); } for cid in conn.loc_cids.values() { self.connection_ids.remove(cid); } self.connection_remotes.remove(&conn.addresses); if let Some((remote, token)) = conn.reset_token { self.connection_reset_tokens.remove(remote, token); } } } None }
/// Process an incoming UDP datagram
pub fn handle( &mut self, now: Instant, remote: SocketAddr, local_ip: Option<IpAddr>, ecn: Option<EcnCodepoint>, data: BytesMut, ) -> Option<(ConnectionHandle, DatagramEvent)> { let datagram_len = data.len(); let (first_decode, remaining) = match PartialDecode::new( data, self.local_cid_generator.cid_len(), &self.config.supported_versions, self.config.grease_quic_bit, ) { Ok(x) => x, Err(PacketDecodeError::UnsupportedVersion { src_cid, dst_cid, version, }) => { if self.server_config.is_none() { debug!("dropping packet with unsupported version"); return None; } if self.stateless_packets_supressed() { return None; } trace!("sending version negotiation"); // Negotiate versions let mut buf = BytesMut::new(); Header::VersionNegotiate { random: self.rng.gen::<u8>() | 0x40, src_cid: dst_cid, dst_cid: src_cid, } .encode(&mut buf); // Grease with a reserved version if version != 0x0a1a_2a3a { buf.write::<u32>(0x0a1a_2a3a); } else { buf.write::<u32>(0x0a1a_2a4a); } for &version in &self.config.supported_versions { buf.write(version); } self.increment_transmit_queue_contents_len(buf.len()); self.transmits.push_back(Transmit { destination: remote, ecn: None, contents: buf.freeze(), segment_size: None, src_ip: local_ip, }); return None; } Err(e) => { trace!("malformed header: {}", e); return None; } };
// // Handle packet on existing connection, if any //
let addresses = FourTuple { remote, local_ip }; let dst_cid = first_decode.dst_cid(); let known_ch = (self.local_cid_generator.cid_len() > 0) .then(|| self.connection_ids.get(&dst_cid)) .flatten() .or_else(|| { if first_decode.is_initial() || first_decode.is_0rtt() { self.connection_ids_initial.get(&dst_cid) } else { None } }) .or_else(|| { if self.local_cid_generator.cid_len() == 0 { self.connection_remotes.get(&addresses) } else { None } }) .or_else(|| { let data = first_decode.data(); if data.len() < RESET_TOKEN_SIZE { return None; } self.connection_reset_tokens .get(addresses.remote, &data[data.len() - RESET_TOKEN_SIZE..]) }) .cloned(); if let Some(ch) = known_ch { return Some(( ch, DatagramEvent::ConnectionEvent(ConnectionEvent(ConnectionEventInner::Datagram { now, remote: addresses.remote, ecn, first_decode, remaining, })), )); }
// // Potentially create a new connection //
let server_config = match &self.server_config { Some(config) => config, None => { debug!("packet for unrecognized connection {}", dst_cid); self.stateless_reset(datagram_len, addresses, &dst_cid); return None; } }; if let Some(version) = first_decode.initial_version() { if datagram_len < MIN_INITIAL_SIZE as usize { debug!("ignoring short initial for connection {}", dst_cid); return None; } let crypto = match server_config .crypto .initial_keys(version, &dst_cid, Side::Server) { Ok(keys) => keys, Err(UnsupportedVersion) => { // This probably indicates that the user set supported_versions incorrectly in // `EndpointConfig`. debug!( "ignoring initial packet version {:#x} unsupported by cryptographic layer", version ); return None; } }; return match first_decode.finish(Some(&*crypto.header.remote)) { Ok(packet) => self .handle_first_packet(now, addresses, ecn, packet, remaining, &crypto) .map(|(ch, conn)| (ch, DatagramEvent::NewConnection(conn))), Err(e) => { trace!("unable to decode initial packet: {}", e); None } }; } else if first_decode.has_long_header() { debug!( "ignoring non-initial packet for unknown connection {}", dst_cid ); return None; }
// // If we got this far, we're a server receiving a seemingly valid packet for an unknown // connection. Send a stateless reset. //
if !dst_cid.is_empty() { self.stateless_reset(datagram_len, addresses, &dst_cid); } else { trace!("dropping unrecognized short packet without ID"); } None }
fn stateless_reset( &mut self, inciting_dgram_len: usize, addresses: FourTuple, dst_cid: &ConnectionId, ) { if self.stateless_packets_supressed() { return; } /// Minimum amount of padding for the stateless reset to look like a short-header packet const MIN_PADDING_LEN: usize = 5;
// Prevent amplification attacks and reset loops by ensuring we pad to at most 1 byte
// smaller than the inciting packet.
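// Worked example (illustrative numbers): a 100-byte inciting datagram leaves
// headroom = 100 - RESET_TOKEN_SIZE (16) = 84 > MIN_PADDING_LEN, so max_padding_len
// is 83 and the reset we send is at most 83 + 16 = 99 bytes, strictly smaller than
// the datagram that elicited it.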
let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) { Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1, _ => { debug!("ignoring unexpected {} byte packet: not larger than minimum stateless reset size", inciting_dgram_len); return; } }; debug!( "sending stateless reset for {} to {}", dst_cid, addresses.remote ); let mut buf = BytesMut::new(); // Resets with at least this much padding can't possibly be distinguished from real packets const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE; let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN { max_padding_len } else { self.rng.gen_range(IDEAL_MIN_PADDING_LEN..max_padding_len) }; buf.reserve(padding_len + RESET_TOKEN_SIZE); buf.resize(padding_len, 0); self.rng.fill_bytes(&mut buf[0..padding_len]); buf[0] = 0b0100_0000 | buf[0] >> 2; buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid)); debug_assert!(buf.len() < inciting_dgram_len); self.increment_transmit_queue_contents_len(buf.len()); self.transmits.push_back(Transmit { destination: addresses.remote, ecn: None, contents: buf.freeze(), segment_size: None, src_ip: addresses.local_ip, }); }
/// Initiate a connection
pub fn connect( &mut self, config: ClientConfig, remote: SocketAddr, server_name: &str, ) -> Result<(ConnectionHandle, Connection), ConnectError> { if self.is_full() { return Err(ConnectError::TooManyConnections); } if remote.port() == 0 || remote.ip().is_unspecified() { return Err(ConnectError::InvalidRemoteAddress(remote)); } if !self.config.supported_versions.contains(&config.version) { return Err(ConnectError::UnsupportedVersion); } let remote_id = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); trace!(initial_dcid = %remote_id); let loc_cid = self.new_cid(); let params = TransportParameters::new( &config.transport, &self.config, self.local_cid_generator.as_ref(), loc_cid, None, ); let tls = config .crypto .start_session(config.version, server_name, &params)?; let (ch, conn) = self.add_connection( config.version, remote_id, loc_cid, remote_id, FourTuple { remote, local_ip: None, }, Instant::now(), tls, None, config.transport, ); Ok((ch, conn)) }
fn send_new_identifiers( &mut self, now: Instant, ch: ConnectionHandle, num: u64, ) -> ConnectionEvent { let mut ids = vec![]; for _ in 0..num { let id = self.new_cid(); self.connection_ids.insert(id, ch); let meta = &mut self.connections[ch]; meta.cids_issued += 1; let sequence = meta.cids_issued; meta.loc_cids.insert(sequence, id); ids.push(IssuedCid { sequence, id, reset_token: ResetToken::new(&*self.config.reset_key, &id), }); } ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now)) }
fn new_cid(&mut self) -> ConnectionId { loop { let cid = self.local_cid_generator.generate_cid(); if !self.connection_ids.contains_key(&cid) { break cid; } assert!(self.local_cid_generator.cid_len() > 0); } }
/// Whether endpoint-generated stateless packets should be suppressed to bound the memory used
/// by the outgoing queue. Without this limit, a flood of initial packets could otherwise grow
/// the queue faster than a rate-limited sender is able to drain it.
fn stateless_packets_supressed(&self) -> bool { self.transmit_queue_contents_len .saturating_add(self.socket_buffer_fill) >= MAX_TRANSMIT_QUEUE_CONTENTS_LEN }
/// Increment the contents length in the transmit queue.
fn increment_transmit_queue_contents_len(&mut self, contents_len: usize) { self.transmit_queue_contents_len = self .transmit_queue_contents_len .saturating_add(contents_len); } /// Decrement the contents length in the transmit queue. fn decrement_transmit_queue_contents_len(&mut self, contents_len: usize) { self.transmit_queue_contents_len = self .transmit_queue_contents_len .saturating_sub(contents_len); } /// Set the `socket_buffer_fill` to the input `len` pub fn set_socket_buffer_fill(&mut self, len: usize) { self.socket_buffer_fill = len; } fn handle_first_packet( &mut self, now: Instant, addresses: FourTuple, ecn: Option, mut packet: Packet, rest: Option, crypto: &Keys, ) -> Option<(ConnectionHandle, Connection)> { let (src_cid, dst_cid, token, packet_number, version) = match packet.header { Header::Initial { src_cid, dst_cid, ref token, number, version, .. } => (src_cid, dst_cid, token.clone(), number, version), _ => panic!("non-initial packet in handle_first_packet()"), }; let packet_number = packet_number.expand(0); if crypto .packet .remote .decrypt(packet_number, &packet.header_data, &mut packet.payload) .is_err() { debug!(packet_number, "failed to authenticate initial packet"); return None; }; if !packet.reserved_bits_valid() { debug!("dropping connection attempt with invalid reserved bits"); return None; } let loc_cid = self.new_cid(); let server_config = self.server_config.as_ref().unwrap(); if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full() { debug!("refusing connection"); self.initial_close( version, addresses, crypto, &src_cid, &loc_cid, TransportError::CONNECTION_REFUSED(""), ); return None; } if dst_cid.len() < 8 && (!server_config.use_retry || dst_cid.len() != self.local_cid_generator.cid_len()) { debug!( "rejecting connection due to invalid DCID length {}", dst_cid.len() ); self.initial_close( version, addresses, crypto, &src_cid, &loc_cid, TransportError::PROTOCOL_VIOLATION("invalid destination CID length"), ); return None; } let (retry_src_cid, orig_dst_cid) = if server_config.use_retry { if token.is_empty() { if self.stateless_packets_supressed() { return None; } // First Initial let mut random_bytes = vec![0u8; RetryToken::RANDOM_BYTES_LEN]; self.rng.fill_bytes(&mut random_bytes); let token = RetryToken { orig_dst_cid: dst_cid, issued: SystemTime::now(), random_bytes: &random_bytes, } .encode(&*server_config.token_key, &addresses.remote, &loc_cid); let header = Header::Retry { src_cid: loc_cid, dst_cid: src_cid, version, }; let mut buf = BytesMut::new(); let encode = header.encode(&mut buf); buf.put_slice(&token); buf.extend_from_slice(&server_config.crypto.retry_tag(version, &dst_cid, &buf)); encode.finish(&mut buf, &*crypto.header.local, None); self.increment_transmit_queue_contents_len(buf.len()); self.transmits.push_back(Transmit { destination: addresses.remote, ecn: None, contents: buf.freeze(), segment_size: None, src_ip: addresses.local_ip, }); return None; } match RetryToken::from_bytes( &*server_config.token_key, &addresses.remote, &dst_cid, &token, ) { Ok(token) if token.issued + server_config.retry_token_lifetime > SystemTime::now() => { (Some(dst_cid), token.orig_dst_cid) } _ => { debug!("rejecting invalid stateless retry token"); self.initial_close( version, addresses, crypto, &src_cid, &loc_cid, TransportError::INVALID_TOKEN(""), ); return None; } } } else { (None, dst_cid) }; let server_config = server_config.clone(); let mut params = TransportParameters::new( &server_config.transport, &self.config, 
            self.local_cid_generator.as_ref(),
            loc_cid,
            Some(&server_config),
        );
        params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid));
        params.original_dst_cid = Some(orig_dst_cid);
        params.retry_src_cid = retry_src_cid;

        let tls = server_config.crypto.clone().start_session(version, &params);
        let transport_config = server_config.transport.clone();
        let (ch, mut conn) = self.add_connection(
            version,
            dst_cid,
            loc_cid,
            src_cid,
            addresses,
            now,
            tls,
            Some(server_config),
            transport_config,
        );
        if dst_cid.len() != 0 {
            self.connection_ids_initial.insert(dst_cid, ch);
        }
        match conn.handle_first_packet(now, addresses.remote, ecn, packet_number, packet, rest) {
            Ok(()) => {
                trace!(id = ch.0, icid = %dst_cid, "connection incoming");
                Some((ch, conn))
            }
            Err(e) => {
                debug!("handshake failed: {}", e);
                self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
                if let ConnectionError::TransportError(e) = e {
                    self.initial_close(version, addresses, crypto, &src_cid, &loc_cid, e);
                }
                None
            }
        }
    }

    fn add_connection(
        &mut self,
        version: u32,
        init_cid: ConnectionId,
        loc_cid: ConnectionId,
        rem_cid: ConnectionId,
        addresses: FourTuple,
        now: Instant,
        tls: Box<dyn crypto::Session>,
        server_config: Option<Arc<ServerConfig>>,
        transport_config: Arc<TransportConfig>,
    ) -> (ConnectionHandle, Connection) {
        let conn = Connection::new(
            self.config.clone(),
            server_config,
            transport_config,
            init_cid,
            loc_cid,
            rem_cid,
            addresses.remote,
            addresses.local_ip,
            tls,
            self.local_cid_generator.as_ref(),
            now,
            version,
            self.allow_mtud,
        );

        let id = self.connections.insert(ConnectionMeta {
            init_cid,
            cids_issued: 0,
            loc_cids: iter::once((0, loc_cid)).collect(),
            addresses,
            reset_token: None,
        });
        let ch = ConnectionHandle(id);

        match self.local_cid_generator.cid_len() {
            0 => self.connection_remotes.insert(addresses, ch),
            _ => self.connection_ids.insert(loc_cid, ch),
        };
        (ch, conn)
    }

    fn initial_close(
        &mut self,
        version: u32,
        addresses: FourTuple,
        crypto: &Keys,
        remote_id: &ConnectionId,
        local_id: &ConnectionId,
        reason: TransportError,
    ) {
        if self.stateless_packets_supressed() {
            return;
        }

        let number = PacketNumber::U8(0);
        let header = Header::Initial {
            dst_cid: *remote_id,
            src_cid: *local_id,
            number,
            token: Bytes::new(),
            version,
        };

        let mut buf = BytesMut::new();
        let partial_encode = header.encode(&mut buf);
        let max_len =
            INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len();
        frame::Close::from(reason).encode(&mut buf, max_len);
        buf.resize(buf.len() + crypto.packet.local.tag_len(), 0);
        partial_encode.finish(
            &mut buf,
            &*crypto.header.local,
            Some((0, &*crypto.packet.local)),
        );
        self.increment_transmit_queue_contents_len(buf.len());
        self.transmits.push_back(Transmit {
            destination: addresses.remote,
            ecn: None,
            contents: buf.freeze(),
            segment_size: None,
            src_ip: addresses.local_ip,
        })
    }

    /// Reject new incoming connections without affecting existing connections
    ///
    /// Convenience short-hand for using
    /// [`set_server_config`](Self::set_server_config) to update
    /// [`concurrent_connections`](ServerConfig::concurrent_connections) to
    /// zero.
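    ///
    /// A minimal usage sketch (assuming a server-side `endpoint: Endpoint` being drained
    /// before shutdown):
    ///
    /// ```ignore
    /// endpoint.reject_new_connections(); // established connections continue undisturbed
    /// ```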
pub fn reject_new_connections(&mut self) { if let Some(config) = self.server_config.as_mut() { Arc::make_mut(config).concurrent_connections(0); } } /// Access the configuration used by this endpoint pub fn config(&self) -> &EndpointConfig { &self.config } #[cfg(test)] pub(crate) fn known_connections(&self) -> usize { let x = self.connections.len(); debug_assert_eq!(x, self.connection_ids_initial.len()); // Not all connections have known reset tokens debug_assert!(x >= self.connection_reset_tokens.0.len()); // Not all connections have unique remotes, and 0-length CIDs might not be in use. debug_assert!(x >= self.connection_remotes.len()); x } #[cfg(test)] pub(crate) fn known_cids(&self) -> usize { self.connection_ids.len() } /// Whether we've used up 3/4 of the available CID space /// /// We leave some space unused so that `new_cid` can be relied upon to finish quickly. We don't /// bother to check when CID longer than 4 bytes are used because 2^40 connections is a lot. fn is_full(&self) -> bool { self.local_cid_generator.cid_len() <= 4 && self.local_cid_generator.cid_len() != 0 && (2usize.pow(self.local_cid_generator.cid_len() as u32 * 8) - self.connection_ids.len()) < 2usize.pow(self.local_cid_generator.cid_len() as u32 * 8 - 2) } } impl fmt::Debug for Endpoint { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Endpoint") .field("rng", &self.rng) .field("transmits", &self.transmits) .field("connection_ids_initial", &self.connection_ids_initial) .field("connection_ids", &self.connection_ids) .field("connection_remotes", &self.connection_remotes) .field("connection_reset_tokens", &self.connection_reset_tokens) .field("connections", &self.connections) .field("config", &self.config) .field("server_config", &self.server_config) .finish() } } #[derive(Debug)] pub(crate) struct ConnectionMeta { init_cid: ConnectionId, /// Number of local connection IDs that have been issued in NEW_CONNECTION_ID frames. cids_issued: u64, loc_cids: FxHashMap, /// Remote/local addresses the connection began with /// /// Only needed to support connections with zero-length CIDs, which cannot migrate, so we don't /// bother keeping it up to date. addresses: FourTuple, /// Reset token provided by the peer for the CID we're currently sending to, and the address /// being sent to reset_token: Option<(SocketAddr, ResetToken)>, } /// Internal identifier for a `Connection` currently associated with an endpoint #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct ConnectionHandle(pub usize); impl From for usize { fn from(x: ConnectionHandle) -> Self { x.0 } } impl Index for Slab { type Output = ConnectionMeta; fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta { &self[ch.0] } } impl IndexMut for Slab { fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta { &mut self[ch.0] } } /// Event resulting from processing a single datagram #[allow(clippy::large_enum_variant)] // Not passed around extensively pub enum DatagramEvent { /// The datagram is redirected to its `Connection` ConnectionEvent(ConnectionEvent), /// The datagram has resulted in starting a new `Connection` NewConnection(Connection), } /// Errors in the parameters being used to create a new connection /// /// These arise before any I/O has been performed. #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum ConnectError { /// The endpoint can no longer create new connections /// /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled. 
#[error("endpoint stopping")] EndpointStopping, /// The number of active connections on the local endpoint is at the limit /// /// Try using longer connection IDs. #[error("too many connections")] TooManyConnections, /// The domain name supplied was malformed #[error("invalid DNS name: {0}")] InvalidDnsName(String), /// The remote [`SocketAddr`] supplied was malformed /// /// Examples include attempting to connect to port 0, or using an inappropriate address family. #[error("invalid remote address: {0}")] InvalidRemoteAddress(SocketAddr), /// No default client configuration was set up /// /// Use `Endpoint::connect_with` to specify a client configuration. #[error("no default client config")] NoDefaultClientConfig, /// The local endpoint does not support the QUIC version specified in the client configuration #[error("unsupported QUIC version")] UnsupportedVersion, } /// Reset Tokens which are associated with peer socket addresses /// /// The standard `HashMap` is used since both `SocketAddr` and `ResetToken` are /// peer generated and might be usable for hash collision attacks. #[derive(Default, Debug)] struct ResetTokenTable(HashMap>); impl ResetTokenTable { fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool { self.0 .entry(remote) .or_default() .insert(token, ch) .is_some() } fn remove(&mut self, remote: SocketAddr, token: ResetToken) { use std::collections::hash_map::Entry; match self.0.entry(remote) { Entry::Vacant(_) => {} Entry::Occupied(mut e) => { e.get_mut().remove(&token); if e.get().is_empty() { e.remove_entry(); } } } } fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> { let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?); self.0.get(&remote)?.get(&token) } } /// Identifies a connection by the combination of remote and local addresses /// /// Including the local ensures good behavior when the host has multiple IP addresses on the same /// subnet and zero-length connection IDs are in use. #[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)] struct FourTuple { remote: SocketAddr, // A single socket can only listen on a single port, so no need to store it explicitly local_ip: Option, } quinn-proto-0.10.6/src/frame.rs000064400000000000000000000651561046102023000144470ustar 00000000000000use std::{ fmt::{self, Write}, io, mem, ops::{Range, RangeInclusive}, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; use tinyvec::TinyVec; use crate::{ coding::{self, BufExt, BufMutExt, UnexpectedEnd}, range_set::ArrayRangeSet, shared::{ConnectionId, EcnCodepoint}, Dir, ResetToken, StreamId, TransportError, TransportErrorCode, VarInt, MAX_CID_SIZE, RESET_TOKEN_SIZE, }; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; #[derive(Copy, Clone, Eq, PartialEq)] pub struct Type(u64); impl Type { fn stream(self) -> Option { if STREAM_TYS.contains(&self.0) { Some(StreamInfo(self.0 as u8)) } else { None } } fn datagram(self) -> Option { if DATAGRAM_TYS.contains(&self.0) { Some(DatagramInfo(self.0 as u8)) } else { None } } } impl coding::Codec for Type { fn decode(buf: &mut B) -> coding::Result { Ok(Self(buf.get_var()?)) } fn encode(&self, buf: &mut B) { buf.write_var(self.0); } } pub(crate) trait FrameStruct { /// Smallest number of bytes this type of frame is guaranteed to fit within. const SIZE_BOUND: usize; } macro_rules! 
frame_types { {$($name:ident = $val:expr,)*} => { impl Type { $(pub const $name: Type = Type($val);)* } impl fmt::Debug for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str(stringify!($name)),)* _ => write!(f, "Type({:02x})", self.0) } } } impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str(stringify!($name)),)* x if STREAM_TYS.contains(&x) => f.write_str("STREAM"), x if DATAGRAM_TYS.contains(&x) => f.write_str("DATAGRAM"), _ => write!(f, "", self.0), } } } } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] struct StreamInfo(u8); impl StreamInfo { fn fin(self) -> bool { self.0 & 0x01 != 0 } fn len(self) -> bool { self.0 & 0x02 != 0 } fn off(self) -> bool { self.0 & 0x04 != 0 } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] struct DatagramInfo(u8); impl DatagramInfo { fn len(self) -> bool { self.0 & 0x01 != 0 } } frame_types! { PADDING = 0x00, PING = 0x01, ACK = 0x02, ACK_ECN = 0x03, RESET_STREAM = 0x04, STOP_SENDING = 0x05, CRYPTO = 0x06, NEW_TOKEN = 0x07, // STREAM MAX_DATA = 0x10, MAX_STREAM_DATA = 0x11, MAX_STREAMS_BIDI = 0x12, MAX_STREAMS_UNI = 0x13, DATA_BLOCKED = 0x14, STREAM_DATA_BLOCKED = 0x15, STREAMS_BLOCKED_BIDI = 0x16, STREAMS_BLOCKED_UNI = 0x17, NEW_CONNECTION_ID = 0x18, RETIRE_CONNECTION_ID = 0x19, PATH_CHALLENGE = 0x1a, PATH_RESPONSE = 0x1b, CONNECTION_CLOSE = 0x1c, APPLICATION_CLOSE = 0x1d, HANDSHAKE_DONE = 0x1e, // DATAGRAM } const STREAM_TYS: RangeInclusive = RangeInclusive::new(0x08, 0x0f); const DATAGRAM_TYS: RangeInclusive = RangeInclusive::new(0x30, 0x31); #[derive(Debug)] pub(crate) enum Frame { Padding, Ping, Ack(Ack), ResetStream(ResetStream), StopSending(StopSending), Crypto(Crypto), NewToken { token: Bytes }, Stream(Stream), MaxData(VarInt), MaxStreamData { id: StreamId, offset: u64 }, MaxStreams { dir: Dir, count: u64 }, DataBlocked { offset: u64 }, StreamDataBlocked { id: StreamId, offset: u64 }, StreamsBlocked { dir: Dir, limit: u64 }, NewConnectionId(NewConnectionId), RetireConnectionId { sequence: u64 }, PathChallenge(u64), PathResponse(u64), Close(Close), Datagram(Datagram), HandshakeDone, } impl Frame { pub(crate) fn ty(&self) -> Type { use self::Frame::*; match *self { Padding => Type::PADDING, ResetStream(_) => Type::RESET_STREAM, Close(self::Close::Connection(_)) => Type::CONNECTION_CLOSE, Close(self::Close::Application(_)) => Type::APPLICATION_CLOSE, MaxData(_) => Type::MAX_DATA, MaxStreamData { .. } => Type::MAX_STREAM_DATA, MaxStreams { dir: Dir::Bi, .. } => Type::MAX_STREAMS_BIDI, MaxStreams { dir: Dir::Uni, .. } => Type::MAX_STREAMS_UNI, Ping => Type::PING, DataBlocked { .. } => Type::DATA_BLOCKED, StreamDataBlocked { .. } => Type::STREAM_DATA_BLOCKED, StreamsBlocked { dir: Dir::Bi, .. } => Type::STREAMS_BLOCKED_BIDI, StreamsBlocked { dir: Dir::Uni, .. } => Type::STREAMS_BLOCKED_UNI, StopSending { .. } => Type::STOP_SENDING, RetireConnectionId { .. } => Type::RETIRE_CONNECTION_ID, Ack(_) => Type::ACK, Stream(ref x) => { let mut ty = *STREAM_TYS.start(); if x.fin { ty |= 0x01; } if x.offset != 0 { ty |= 0x04; } Type(ty) } PathChallenge(_) => Type::PATH_CHALLENGE, PathResponse(_) => Type::PATH_RESPONSE, NewConnectionId { .. } => Type::NEW_CONNECTION_ID, Crypto(_) => Type::CRYPTO, NewToken { .. 
} => Type::NEW_TOKEN, Datagram(_) => Type(*DATAGRAM_TYS.start()), HandshakeDone => Type::HANDSHAKE_DONE, } } pub(crate) fn is_ack_eliciting(&self) -> bool { !matches!(*self, Self::Ack(_) | Self::Padding | Self::Close(_)) } } #[derive(Clone, Debug)] pub enum Close { Connection(ConnectionClose), Application(ApplicationClose), } impl Close { pub(crate) fn encode(&self, out: &mut W, max_len: usize) { match *self { Self::Connection(ref x) => x.encode(out, max_len), Self::Application(ref x) => x.encode(out, max_len), } } } impl From for Close { fn from(x: TransportError) -> Self { Self::Connection(x.into()) } } impl From for Close { fn from(x: ConnectionClose) -> Self { Self::Connection(x) } } impl From for Close { fn from(x: ApplicationClose) -> Self { Self::Application(x) } } /// Reason given by the transport for closing the connection #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConnectionClose { /// Class of error as encoded in the specification pub error_code: TransportErrorCode, /// Type of frame that caused the close pub frame_type: Option, /// Human-readable reason for the close pub reason: Bytes, } impl fmt::Display for ConnectionClose { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.error_code.fmt(f)?; if !self.reason.as_ref().is_empty() { f.write_str(": ")?; f.write_str(&String::from_utf8_lossy(&self.reason))?; } Ok(()) } } impl From for ConnectionClose { fn from(x: TransportError) -> Self { Self { error_code: x.code, frame_type: x.frame, reason: x.reason.into(), } } } impl FrameStruct for ConnectionClose { const SIZE_BOUND: usize = 1 + 8 + 8 + 8; } impl ConnectionClose { pub(crate) fn encode(&self, out: &mut W, max_len: usize) { out.write(Type::CONNECTION_CLOSE); // 1 byte out.write(self.error_code); // <= 8 bytes let ty = self.frame_type.map_or(0, |x| x.0); out.write_var(ty); // <= 8 bytes let max_len = max_len - 3 - VarInt::from_u64(ty).unwrap().size() - VarInt::from_u64(self.reason.len() as u64).unwrap().size(); let actual_len = self.reason.len().min(max_len); out.write_var(actual_len as u64); // <= 8 bytes out.put_slice(&self.reason[0..actual_len]); // whatever's left } } /// Reason given by an application for closing the connection #[derive(Debug, Clone, PartialEq, Eq)] pub struct ApplicationClose { /// Application-specific reason code pub error_code: VarInt, /// Human-readable reason for the close pub reason: Bytes, } impl fmt::Display for ApplicationClose { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if !self.reason.as_ref().is_empty() { f.write_str(&String::from_utf8_lossy(&self.reason))?; f.write_str(" (code ")?; self.error_code.fmt(f)?; f.write_str(")")?; } else { self.error_code.fmt(f)?; } Ok(()) } } impl FrameStruct for ApplicationClose { const SIZE_BOUND: usize = 1 + 8 + 8; } impl ApplicationClose { pub(crate) fn encode(&self, out: &mut W, max_len: usize) { out.write(Type::APPLICATION_CLOSE); // 1 byte out.write(self.error_code); // <= 8 bytes let max_len = max_len - 3 - VarInt::from_u64(self.reason.len() as u64).unwrap().size(); let actual_len = self.reason.len().min(max_len); out.write_var(actual_len as u64); // <= 8 bytes out.put_slice(&self.reason[0..actual_len]); // whatever's left } } #[derive(Clone, Eq, PartialEq)] pub struct Ack { pub largest: u64, pub delay: u64, pub additional: Bytes, pub ecn: Option, } impl fmt::Debug for Ack { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut ranges = "[".to_string(); let mut first = true; for range in self.iter() { if !first { ranges.push(','); } write!(ranges, 
"{range:?}").unwrap(); first = false; } ranges.push(']'); f.debug_struct("Ack") .field("largest", &self.largest) .field("delay", &self.delay) .field("ecn", &self.ecn) .field("ranges", &ranges) .finish() } } impl<'a> IntoIterator for &'a Ack { type Item = RangeInclusive; type IntoIter = AckIter<'a>; fn into_iter(self) -> AckIter<'a> { AckIter::new(self.largest, &self.additional[..]) } } impl Ack { pub fn encode( delay: u64, ranges: &ArrayRangeSet, ecn: Option<&EcnCounts>, buf: &mut W, ) { let mut rest = ranges.iter().rev(); let first = rest.next().unwrap(); let largest = first.end - 1; let first_size = first.end - first.start; buf.write(if ecn.is_some() { Type::ACK_ECN } else { Type::ACK }); buf.write_var(largest); buf.write_var(delay); buf.write_var(ranges.len() as u64 - 1); buf.write_var(first_size - 1); let mut prev = first.start; for block in rest { let size = block.end - block.start; buf.write_var(prev - block.end - 1); buf.write_var(size - 1); prev = block.start; } if let Some(x) = ecn { x.encode(buf) } } pub fn iter(&self) -> AckIter<'_> { self.into_iter() } } #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct EcnCounts { pub ect0: u64, pub ect1: u64, pub ce: u64, } impl std::ops::AddAssign for EcnCounts { fn add_assign(&mut self, rhs: EcnCodepoint) { match rhs { EcnCodepoint::Ect0 => { self.ect0 += 1; } EcnCodepoint::Ect1 => { self.ect1 += 1; } EcnCodepoint::Ce => { self.ce += 1; } } } } impl EcnCounts { pub const ZERO: Self = Self { ect0: 0, ect1: 0, ce: 0, }; pub fn encode(&self, out: &mut W) { out.write_var(self.ect0); out.write_var(self.ect1); out.write_var(self.ce); } } #[derive(Debug, Clone)] pub(crate) struct Stream { pub(crate) id: StreamId, pub(crate) offset: u64, pub(crate) fin: bool, pub(crate) data: Bytes, } impl FrameStruct for Stream { const SIZE_BOUND: usize = 1 + 8 + 8 + 8; } /// Metadata from a stream frame #[derive(Debug, Clone)] pub(crate) struct StreamMeta { pub(crate) id: StreamId, pub(crate) offsets: Range, pub(crate) fin: bool, } // This manual implementation exists because `Default` is not implemented for `StreamId` impl Default for StreamMeta { fn default() -> Self { Self { id: StreamId(0), offsets: 0..0, fin: false, } } } impl StreamMeta { pub(crate) fn encode(&self, length: bool, out: &mut W) { let mut ty = *STREAM_TYS.start(); if self.offsets.start != 0 { ty |= 0x04; } if length { ty |= 0x02; } if self.fin { ty |= 0x01; } out.write_var(ty); // 1 byte out.write(self.id); // <=8 bytes if self.offsets.start != 0 { out.write_var(self.offsets.start); // <=8 bytes } if length { out.write_var(self.offsets.end - self.offsets.start); // <=8 bytes } } } /// A vector of [`StreamMeta`] with optimization for the single element case pub(crate) type StreamMetaVec = TinyVec<[StreamMeta; 1]>; #[derive(Debug, Clone)] pub(crate) struct Crypto { pub(crate) offset: u64, pub(crate) data: Bytes, } impl Crypto { pub(crate) const SIZE_BOUND: usize = 17; pub(crate) fn encode(&self, out: &mut W) { out.write(Type::CRYPTO); out.write_var(self.offset); out.write_var(self.data.len() as u64); out.put_slice(&self.data); } } pub(crate) struct Iter { // TODO: ditch io::Cursor after bytes 0.5 bytes: io::Cursor, last_ty: Option, } enum IterErr { UnexpectedEnd, InvalidFrameId, Malformed, } impl IterErr { fn reason(&self) -> &'static str { use self::IterErr::*; match *self { UnexpectedEnd => "unexpected end", InvalidFrameId => "invalid frame ID", Malformed => "malformed", } } } impl From for IterErr { fn from(_: UnexpectedEnd) -> Self { Self::UnexpectedEnd } } impl Iter { pub(crate) fn 
new(payload: Bytes) -> Self { Self { bytes: io::Cursor::new(payload), last_ty: None, } } fn take_len(&mut self) -> Result { let len = self.bytes.get_var()?; if len > self.bytes.remaining() as u64 { return Err(UnexpectedEnd); } let start = self.bytes.position() as usize; self.bytes.advance(len as usize); Ok(self.bytes.get_ref().slice(start..(start + len as usize))) } fn try_next(&mut self) -> Result { let ty = self.bytes.get::()?; self.last_ty = Some(ty); Ok(match ty { Type::PADDING => Frame::Padding, Type::RESET_STREAM => Frame::ResetStream(ResetStream { id: self.bytes.get()?, error_code: self.bytes.get()?, final_offset: self.bytes.get()?, }), Type::CONNECTION_CLOSE => Frame::Close(Close::Connection(ConnectionClose { error_code: self.bytes.get()?, frame_type: { let x = self.bytes.get_var()?; if x == 0 { None } else { Some(Type(x)) } }, reason: self.take_len()?, })), Type::APPLICATION_CLOSE => Frame::Close(Close::Application(ApplicationClose { error_code: self.bytes.get()?, reason: self.take_len()?, })), Type::MAX_DATA => Frame::MaxData(self.bytes.get()?), Type::MAX_STREAM_DATA => Frame::MaxStreamData { id: self.bytes.get()?, offset: self.bytes.get_var()?, }, Type::MAX_STREAMS_BIDI => Frame::MaxStreams { dir: Dir::Bi, count: self.bytes.get_var()?, }, Type::MAX_STREAMS_UNI => Frame::MaxStreams { dir: Dir::Uni, count: self.bytes.get_var()?, }, Type::PING => Frame::Ping, Type::DATA_BLOCKED => Frame::DataBlocked { offset: self.bytes.get_var()?, }, Type::STREAM_DATA_BLOCKED => Frame::StreamDataBlocked { id: self.bytes.get()?, offset: self.bytes.get_var()?, }, Type::STREAMS_BLOCKED_BIDI => Frame::StreamsBlocked { dir: Dir::Bi, limit: self.bytes.get_var()?, }, Type::STREAMS_BLOCKED_UNI => Frame::StreamsBlocked { dir: Dir::Uni, limit: self.bytes.get_var()?, }, Type::STOP_SENDING => Frame::StopSending(StopSending { id: self.bytes.get()?, error_code: self.bytes.get()?, }), Type::RETIRE_CONNECTION_ID => Frame::RetireConnectionId { sequence: self.bytes.get_var()?, }, Type::ACK | Type::ACK_ECN => { let largest = self.bytes.get_var()?; let delay = self.bytes.get_var()?; let extra_blocks = self.bytes.get_var()? as usize; let start = self.bytes.position() as usize; scan_ack_blocks(&mut self.bytes, largest, extra_blocks)?; let end = self.bytes.position() as usize; Frame::Ack(Ack { delay, largest, additional: self.bytes.get_ref().slice(start..end), ecn: if ty != Type::ACK_ECN { None } else { Some(EcnCounts { ect0: self.bytes.get_var()?, ect1: self.bytes.get_var()?, ce: self.bytes.get_var()?, }) }, }) } Type::PATH_CHALLENGE => Frame::PathChallenge(self.bytes.get()?), Type::PATH_RESPONSE => Frame::PathResponse(self.bytes.get()?), Type::NEW_CONNECTION_ID => { let sequence = self.bytes.get_var()?; let retire_prior_to = self.bytes.get_var()?; if retire_prior_to > sequence { return Err(IterErr::Malformed); } let length = self.bytes.get::()? 
as usize; if length > MAX_CID_SIZE || length == 0 { return Err(IterErr::Malformed); } if length > self.bytes.remaining() { return Err(IterErr::UnexpectedEnd); } let mut stage = [0; MAX_CID_SIZE]; self.bytes.copy_to_slice(&mut stage[0..length]); let id = ConnectionId::new(&stage[..length]); if self.bytes.remaining() < 16 { return Err(IterErr::UnexpectedEnd); } let mut reset_token = [0; RESET_TOKEN_SIZE]; self.bytes.copy_to_slice(&mut reset_token); Frame::NewConnectionId(NewConnectionId { sequence, retire_prior_to, id, reset_token: reset_token.into(), }) } Type::CRYPTO => Frame::Crypto(Crypto { offset: self.bytes.get_var()?, data: self.take_len()?, }), Type::NEW_TOKEN => Frame::NewToken { token: self.take_len()?, }, Type::HANDSHAKE_DONE => Frame::HandshakeDone, _ => { if let Some(s) = ty.stream() { Frame::Stream(Stream { id: self.bytes.get()?, offset: if s.off() { self.bytes.get_var()? } else { 0 }, fin: s.fin(), data: if s.len() { self.take_len()? } else { self.take_remaining() }, }) } else if let Some(d) = ty.datagram() { Frame::Datagram(Datagram { data: if d.len() { self.take_len()? } else { self.take_remaining() }, }) } else { return Err(IterErr::InvalidFrameId); } } }) } fn take_remaining(&mut self) -> Bytes { let mut x = mem::replace(self.bytes.get_mut(), Bytes::new()); x.advance(self.bytes.position() as usize); self.bytes.set_position(0); x } } impl Iterator for Iter { type Item = Result; fn next(&mut self) -> Option { if !self.bytes.has_remaining() { return None; } match self.try_next() { Ok(x) => Some(Ok(x)), Err(e) => { // Corrupt frame, skip it and everything that follows self.bytes = io::Cursor::new(Bytes::new()); Some(Err(InvalidFrame { ty: self.last_ty, reason: e.reason(), })) } } } } #[derive(Debug)] pub(crate) struct InvalidFrame { pub(crate) ty: Option, pub(crate) reason: &'static str, } impl From for TransportError { fn from(err: InvalidFrame) -> Self { let mut te = Self::FRAME_ENCODING_ERROR(err.reason); te.frame = err.ty; te } } fn scan_ack_blocks(buf: &mut io::Cursor, largest: u64, n: usize) -> Result<(), IterErr> { let first_block = buf.get_var()?; let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?; for _ in 0..n { let gap = buf.get_var()?; smallest = smallest.checked_sub(gap + 2).ok_or(IterErr::Malformed)?; let block = buf.get_var()?; smallest = smallest.checked_sub(block).ok_or(IterErr::Malformed)?; } Ok(()) } #[derive(Debug, Clone)] pub struct AckIter<'a> { largest: u64, data: io::Cursor<&'a [u8]>, } impl<'a> AckIter<'a> { fn new(largest: u64, payload: &'a [u8]) -> Self { let data = io::Cursor::new(payload); Self { largest, data } } } impl<'a> Iterator for AckIter<'a> { type Item = RangeInclusive; fn next(&mut self) -> Option> { if !self.data.has_remaining() { return None; } let block = self.data.get_var().unwrap(); let largest = self.largest; if let Ok(gap) = self.data.get_var() { self.largest -= block + gap + 2; } Some(largest - block..=largest) } } #[allow(unreachable_pub)] // fuzzing only #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone)] pub struct ResetStream { pub(crate) id: StreamId, pub(crate) error_code: VarInt, pub(crate) final_offset: VarInt, } impl FrameStruct for ResetStream { const SIZE_BOUND: usize = 1 + 8 + 8 + 8; } impl ResetStream { pub(crate) fn encode(&self, out: &mut W) { out.write(Type::RESET_STREAM); // 1 byte out.write(self.id); // <= 8 bytes out.write(self.error_code); // <= 8 bytes out.write(self.final_offset); // <= 8 bytes } } #[derive(Debug, Copy, Clone)] pub(crate) struct 
StopSending {
    pub(crate) id: StreamId,
    pub(crate) error_code: VarInt,
}

impl FrameStruct for StopSending {
    const SIZE_BOUND: usize = 1 + 8 + 8;
}

impl StopSending {
    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
        out.write(Type::STOP_SENDING); // 1 byte
        out.write(self.id); // <= 8 bytes
        out.write(self.error_code) // <= 8 bytes
    }
}

#[derive(Debug, Copy, Clone)]
pub(crate) struct NewConnectionId {
    pub(crate) sequence: u64,
    pub(crate) retire_prior_to: u64,
    pub(crate) id: ConnectionId,
    pub(crate) reset_token: ResetToken,
}

impl NewConnectionId {
    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
        out.write(Type::NEW_CONNECTION_ID);
        out.write_var(self.sequence);
        out.write_var(self.retire_prior_to);
        out.write(self.id.len() as u8);
        out.put_slice(&self.id);
        out.put_slice(&self.reset_token);
    }
}

/// Smallest number of bytes this type of frame is guaranteed to fit within.
pub(crate) const RETIRE_CONNECTION_ID_SIZE_BOUND: usize = 9;

/// An unreliable datagram
#[derive(Debug, Clone)]
pub struct Datagram {
    /// Payload
    pub data: Bytes,
}

impl FrameStruct for Datagram {
    const SIZE_BOUND: usize = 1 + 8;
}

impl Datagram {
    pub(crate) fn encode(&self, length: bool, out: &mut BytesMut) {
        out.write(Type(*DATAGRAM_TYS.start() | u64::from(length))); // 1 byte
        if length {
            // Safe to unwrap because we check length sanity before queueing datagrams
            out.write(VarInt::from_u64(self.data.len() as u64).unwrap()); // <= 8 bytes
        }
        out.extend_from_slice(&self.data);
    }

    pub(crate) fn size(&self, length: bool) -> usize {
        1 + if length {
            VarInt::from_u64(self.data.len() as u64).unwrap().size()
        } else {
            0
        } + self.data.len()
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[allow(clippy::range_plus_one)]
    fn ack_coding() {
        const PACKETS: &[u64] = &[1, 2, 3, 5, 10, 11, 14];
        let mut ranges = ArrayRangeSet::new();
        for &packet in PACKETS {
            ranges.insert(packet..packet + 1);
        }
        let mut buf = Vec::new();
        const ECN: EcnCounts = EcnCounts {
            ect0: 42,
            ect1: 24,
            ce: 12,
        };
        Ack::encode(42, &ranges, Some(&ECN), &mut buf);
        let frames = Iter::new(Bytes::from(buf))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(frames.len(), 1);
        match frames[0] {
            Frame::Ack(ref ack) => {
                let mut packets = ack.iter().flatten().collect::<Vec<_>>();
                packets.sort_unstable();
                assert_eq!(&packets[..], PACKETS);
                assert_eq!(ack.ecn, Some(ECN));
            }
            ref x => panic!("incorrect frame {x:?}"),
        }
    }
}
quinn-proto-0.10.6/src/lib.rs000064400000000000000000000212121046102023000141060ustar 00000000000000
//! Low-level protocol logic for the QUIC protocol
//!
//! quinn-proto contains a fully deterministic implementation of QUIC protocol logic. It contains
//! no networking code and does not get any relevant timestamps from the operating system. Most
//! users may want to use the futures-based quinn API instead.
//!
//! The quinn-proto API might be of interest if you want to use it from a C or C++ project
//! through C bindings or if you want to use a different event loop than the one tokio provides.
//!
//! The most important types are `Endpoint`, which conceptually represents the protocol state for
//! a single socket and mostly manages configuration and dispatches incoming datagrams to the
//! related `Connection`. `Connection` types contain the bulk of the protocol logic related to
//! managing a single connection and all the related state (such as streams).
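//!
//! As a rough sketch (not a complete program), initiating an outgoing connection starts with
//! [`Endpoint::connect`]; the `client_config` and `remote` values below are assumed to be
//! supplied by the embedding application:
//!
//! ```ignore
//! // Purely computational: returns the new Connection state machine without doing any I/O.
//! let (handle, connection) = endpoint.connect(client_config, remote, "example.com")?;
//! ```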
#![cfg_attr(not(fuzzing), warn(missing_docs))]
#![cfg_attr(test, allow(dead_code))]
// Fixes welcome:
#![warn(unreachable_pub)]
#![allow(clippy::cognitive_complexity)]
#![allow(clippy::too_many_arguments)]
#![warn(clippy::use_self)]

use std::{
    fmt,
    net::{IpAddr, SocketAddr},
    ops,
    time::Duration,
};

mod cid_queue;
#[doc(hidden)]
pub mod coding;
mod constant_time;
mod packet;
mod range_set;
#[cfg(all(test, feature = "rustls"))]
mod tests;
pub mod transport_parameters;
mod varint;

use bytes::Bytes;
pub use varint::{VarInt, VarIntBoundsExceeded};

mod connection;
pub use crate::connection::{
    BytesSource, Chunk, Chunks, Connection, ConnectionError, ConnectionStats, Datagrams, Event,
    FinishError, FrameStats, PathStats, ReadError, ReadableError, RecvStream, RttEstimator,
    SendDatagramError, SendStream, StreamEvent, Streams, UdpStats, UnknownStream, WriteError,
    Written,
};

mod config;
pub use config::{
    ClientConfig, ConfigError, EndpointConfig, IdleTimeout, MtuDiscoveryConfig, ServerConfig,
    TransportConfig,
};

pub mod crypto;

mod frame;
use crate::frame::Frame;
pub use crate::frame::{ApplicationClose, ConnectionClose, Datagram};

mod endpoint;
pub use crate::endpoint::{ConnectError, ConnectionHandle, DatagramEvent, Endpoint};

mod shared;
pub use crate::shared::{ConnectionEvent, ConnectionId, EcnCodepoint, EndpointEvent};

mod transport_error;
pub use crate::transport_error::{Code as TransportErrorCode, Error as TransportError};

pub mod congestion;

mod cid_generator;
pub use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};

mod token;
use token::{ResetToken, RetryToken};

#[cfg(feature = "arbitrary")]
use arbitrary::Arbitrary;

#[doc(hidden)]
#[cfg(fuzzing)]
pub mod fuzzing {
    pub use crate::connection::{Retransmits, State as ConnectionState, StreamsState};
    pub use crate::frame::ResetStream;
    pub use crate::packet::PartialDecode;
    pub use crate::transport_parameters::TransportParameters;
    use crate::MAX_CID_SIZE;
    use arbitrary::{Arbitrary, Result, Unstructured};
    pub use bytes::{BufMut, BytesMut};

    impl<'arbitrary> Arbitrary<'arbitrary> for TransportParameters {
        fn arbitrary(u: &mut Unstructured<'arbitrary>) -> Result<Self> {
            Ok(Self {
                initial_max_streams_bidi: u.arbitrary()?,
                initial_max_streams_uni: u.arbitrary()?,
                ack_delay_exponent: u.arbitrary()?,
                max_udp_payload_size: u.arbitrary()?,
                ..Self::default()
            })
        }
    }

    #[derive(Debug)]
    pub struct PacketParams {
        pub local_cid_len: usize,
        pub buf: BytesMut,
        pub grease_quic_bit: bool,
    }

    impl<'arbitrary> Arbitrary<'arbitrary> for PacketParams {
        fn arbitrary(u: &mut Unstructured<'arbitrary>) -> Result<Self> {
            let local_cid_len: usize = u.int_in_range(0..=MAX_CID_SIZE)?;
            let bytes: Vec<u8> = Vec::arbitrary(u)?;
            let mut buf = BytesMut::new();
            buf.put_slice(&bytes[..]);
            Ok(Self {
                local_cid_len,
                buf,
                grease_quic_bit: bool::arbitrary(u)?,
            })
        }
    }
}

/// The QUIC protocol versions implemented.
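///
/// QUIC version 1 (0x00000001) first, followed by draft versions 29 through 34.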
pub const DEFAULT_SUPPORTED_VERSIONS: &[u32] = &[ 0x00000001, 0xff00_001d, 0xff00_001e, 0xff00_001f, 0xff00_0020, 0xff00_0021, 0xff00_0022, ]; /// Whether an endpoint was the initiator of a connection #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum Side { /// The initiator of a connection Client = 0, /// The acceptor of a connection Server = 1, } impl Side { #[inline] /// Shorthand for `self == Side::Client` pub fn is_client(self) -> bool { self == Self::Client } #[inline] /// Shorthand for `self == Side::Server` pub fn is_server(self) -> bool { self == Self::Server } } impl ops::Not for Side { type Output = Self; fn not(self) -> Self { match self { Self::Client => Self::Server, Self::Server => Self::Client, } } } /// Whether a stream communicates data in both directions or only from the initiator #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum Dir { /// Data flows in both directions Bi = 0, /// Data flows only from the stream's initiator Uni = 1, } impl Dir { fn iter() -> impl Iterator { [Self::Bi, Self::Uni].iter().cloned() } } impl fmt::Display for Dir { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::Dir::*; f.pad(match *self { Bi => "bidirectional", Uni => "unidirectional", }) } } /// Identifier for a stream within a particular connection #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct StreamId(#[doc(hidden)] pub u64); impl fmt::Display for StreamId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let initiator = match self.initiator() { Side::Client => "client", Side::Server => "server", }; let dir = match self.dir() { Dir::Uni => "uni", Dir::Bi => "bi", }; write!( f, "{} {}directional stream {}", initiator, dir, self.index() ) } } impl StreamId { /// Create a new StreamId pub fn new(initiator: Side, dir: Dir, index: u64) -> Self { Self(index << 2 | (dir as u64) << 1 | initiator as u64) } /// Which side of a connection initiated the stream pub fn initiator(self) -> Side { if self.0 & 0x1 == 0 { Side::Client } else { Side::Server } } /// Which directions data flows in pub fn dir(self) -> Dir { if self.0 & 0x2 == 0 { Dir::Bi } else { Dir::Uni } } /// Distinguishes streams of the same initiator and directionality pub fn index(self) -> u64 { self.0 >> 2 } } impl From for VarInt { fn from(x: StreamId) -> Self { unsafe { Self::from_u64_unchecked(x.0) } } } impl From for StreamId { fn from(v: VarInt) -> Self { Self(v.0) } } impl coding::Codec for StreamId { fn decode(buf: &mut B) -> coding::Result { VarInt::decode(buf).map(|x| Self(x.into_inner())) } fn encode(&self, buf: &mut B) { VarInt::from_u64(self.0).unwrap().encode(buf); } } /// An outgoing packet #[derive(Debug)] pub struct Transmit { /// The socket this datagram should be sent to pub destination: SocketAddr, /// Explicit congestion notification bits to set on the packet pub ecn: Option, /// Contents of the datagram pub contents: Bytes, /// The segment size if this transmission contains multiple datagrams. 
/// This is `None` if the transmit only contains a single datagram pub segment_size: Option, /// Optional source IP address for the datagram pub src_ip: Option, } // // Useful internal constants // /// The maximum number of CIDs we bother to issue per connection const LOC_CID_COUNT: u64 = 8; const RESET_TOKEN_SIZE: usize = 16; const MAX_CID_SIZE: usize = 20; const MIN_INITIAL_SIZE: u16 = 1200; /// const INITIAL_MTU: u16 = 1200; const MAX_UDP_PAYLOAD: u16 = 65527; const TIMER_GRANULARITY: Duration = Duration::from_millis(1); /// Maximum number of streams that can be uniquely identified by a stream ID const MAX_STREAM_COUNT: u64 = 1 << 60; quinn-proto-0.10.6/src/packet.rs000064400000000000000000000647551046102023000146300ustar 00000000000000use std::{cmp::Ordering, io, ops::Range, str}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use thiserror::Error; use crate::{ coding::{self, BufExt, BufMutExt}, crypto, ConnectionId, }; // Due to packet number encryption, it is impossible to fully decode a header // (which includes a variable-length packet number) without crypto context. // The crypto context (represented by the `Crypto` type in Quinn) is usually // part of the `Connection`, or can be derived from the destination CID for // Initial packets. // // To cope with this, we decode the invariant header (which should be stable // across QUIC versions), which gives us the destination CID and allows us // to inspect the version and packet type (which depends on the version). // This information allows us to fully decode and decrypt the packet. #[allow(unreachable_pub)] // fuzzing only #[derive(Debug)] pub struct PartialDecode { plain_header: PlainHeader, buf: io::Cursor, } #[allow(clippy::len_without_is_empty)] impl PartialDecode { #[allow(unreachable_pub)] // fuzzing only pub fn new( bytes: BytesMut, local_cid_len: usize, supported_versions: &[u32], grease_quic_bit: bool, ) -> Result<(Self, Option), PacketDecodeError> { let mut buf = io::Cursor::new(bytes); let plain_header = PlainHeader::decode(&mut buf, local_cid_len, supported_versions, grease_quic_bit)?; let dgram_len = buf.get_ref().len(); let packet_len = plain_header .payload_len() .map(|len| (buf.position() + len) as usize) .unwrap_or(dgram_len); match dgram_len.cmp(&packet_len) { Ordering::Equal => Ok((Self { plain_header, buf }, None)), Ordering::Less => Err(PacketDecodeError::InvalidHeader( "packet too short to contain payload length", )), Ordering::Greater => { let rest = Some(buf.get_mut().split_off(packet_len)); Ok((Self { plain_header, buf }, rest)) } } } /// The underlying partially-decoded packet data pub(crate) fn data(&self) -> &[u8] { self.buf.get_ref() } pub(crate) fn initial_version(&self) -> Option { match self.plain_header { PlainHeader::Initial { version, .. } => Some(version), _ => None, } } pub(crate) fn has_long_header(&self) -> bool { !matches!(self.plain_header, PlainHeader::Short { .. }) } pub(crate) fn is_initial(&self) -> bool { self.space() == Some(SpaceId::Initial) } pub(crate) fn space(&self) -> Option { use self::PlainHeader::*; match self.plain_header { Initial { .. } => Some(SpaceId::Initial), Long { ty: LongType::Handshake, .. } => Some(SpaceId::Handshake), Long { ty: LongType::ZeroRtt, .. } => Some(SpaceId::Data), Short { .. } => Some(SpaceId::Data), _ => None, } } pub(crate) fn is_0rtt(&self) -> bool { match self.plain_header { PlainHeader::Long { ty, .. 
} => ty == LongType::ZeroRtt, _ => false, } } pub(crate) fn dst_cid(&self) -> ConnectionId { self.plain_header.dst_cid() } /// Length of QUIC packet being decoded #[allow(unreachable_pub)] // fuzzing only pub fn len(&self) -> usize { self.buf.get_ref().len() } pub(crate) fn finish( self, header_crypto: Option<&dyn crypto::HeaderKey>, ) -> Result { use self::PlainHeader::*; let Self { plain_header, mut buf, } = self; if let Initial { dst_cid, src_cid, token_pos, version, .. } = plain_header { let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?; let header_len = buf.position() as usize; let mut bytes = buf.into_inner(); let header_data = bytes.split_to(header_len).freeze(); let token = header_data.slice(token_pos.start..token_pos.end); return Ok(Packet { header: Header::Initial { dst_cid, src_cid, token, number, version, }, header_data, payload: bytes, }); } let header = match plain_header { Long { ty, dst_cid, src_cid, version, .. } => Header::Long { ty, dst_cid, src_cid, number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?, version, }, Retry { dst_cid, src_cid, version, } => Header::Retry { dst_cid, src_cid, version, }, Short { spin, dst_cid, .. } => { let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?; let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0; Header::Short { spin, key_phase, dst_cid, number, } } VersionNegotiate { random, dst_cid, src_cid, } => Header::VersionNegotiate { random, dst_cid, src_cid, }, Initial { .. } => unreachable!(), }; let header_len = buf.position() as usize; let mut bytes = buf.into_inner(); Ok(Packet { header, header_data: bytes.split_to(header_len).freeze(), payload: bytes, }) } fn decrypt_header( buf: &mut io::Cursor, header_crypto: &dyn crypto::HeaderKey, ) -> Result { let packet_length = buf.get_ref().len(); let pn_offset = buf.position() as usize; if packet_length < pn_offset + 4 + header_crypto.sample_size() { return Err(PacketDecodeError::InvalidHeader( "packet too short to extract header protection sample", )); } header_crypto.decrypt(pn_offset, buf.get_mut()); let len = PacketNumber::decode_len(buf.get_ref()[0]); PacketNumber::decode(len, buf) } } pub(crate) struct Packet { pub(crate) header: Header, pub(crate) header_data: Bytes, pub(crate) payload: BytesMut, } impl Packet { pub(crate) fn reserved_bits_valid(&self) -> bool { let mask = match self.header { Header::Short { .. 
} => SHORT_RESERVED_BITS, _ => LONG_RESERVED_BITS, }; self.header_data[0] & mask == 0 } } #[derive(Debug, Clone)] pub(crate) enum Header { Initial { dst_cid: ConnectionId, src_cid: ConnectionId, token: Bytes, number: PacketNumber, version: u32, }, Long { ty: LongType, dst_cid: ConnectionId, src_cid: ConnectionId, number: PacketNumber, version: u32, }, Retry { dst_cid: ConnectionId, src_cid: ConnectionId, version: u32, }, Short { spin: bool, key_phase: bool, dst_cid: ConnectionId, number: PacketNumber, }, VersionNegotiate { random: u8, src_cid: ConnectionId, dst_cid: ConnectionId, }, } impl Header { pub(crate) fn encode(&self, w: &mut BytesMut) -> PartialEncode { use self::Header::*; let start = w.len(); match *self { Initial { ref dst_cid, ref src_cid, ref token, number, version, } => { w.write(u8::from(LongHeaderType::Initial) | number.tag()); w.write(version); dst_cid.encode_long(w); src_cid.encode_long(w); w.write_var(token.len() as u64); w.put_slice(token); w.write::(0); // Placeholder for payload length; see `set_payload_length` number.encode(w); PartialEncode { start, header_len: w.len() - start, pn: Some((number.len(), true)), } } Long { ty, ref dst_cid, ref src_cid, number, version, } => { w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag()); w.write(version); dst_cid.encode_long(w); src_cid.encode_long(w); w.write::(0); // Placeholder for payload length; see `set_payload_length` number.encode(w); PartialEncode { start, header_len: w.len() - start, pn: Some((number.len(), true)), } } Retry { ref dst_cid, ref src_cid, version, } => { w.write(u8::from(LongHeaderType::Retry)); w.write(version); dst_cid.encode_long(w); src_cid.encode_long(w); PartialEncode { start, header_len: w.len() - start, pn: None, } } Short { spin, key_phase, ref dst_cid, number, } => { w.write( FIXED_BIT | if key_phase { KEY_PHASE_BIT } else { 0 } | if spin { SPIN_BIT } else { 0 } | number.tag(), ); w.put_slice(dst_cid); number.encode(w); PartialEncode { start, header_len: w.len() - start, pn: Some((number.len(), false)), } } VersionNegotiate { ref random, ref dst_cid, ref src_cid, } => { w.write(0x80u8 | random); w.write::(0); dst_cid.encode_long(w); src_cid.encode_long(w); PartialEncode { start, header_len: w.len() - start, pn: None, } } } } /// Whether the packet is encrypted on the wire pub(crate) fn is_protected(&self) -> bool { !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. }) } pub(crate) fn number(&self) -> Option { use self::Header::*; Some(match *self { Initial { number, .. } => number, Long { number, .. } => number, Short { number, .. } => number, _ => { return None; } }) } pub(crate) fn space(&self) -> SpaceId { use self::Header::*; match *self { Short { .. } => SpaceId::Data, Long { ty: LongType::ZeroRtt, .. } => SpaceId::Data, Long { ty: LongType::Handshake, .. } => SpaceId::Handshake, _ => SpaceId::Initial, } } pub(crate) fn key_phase(&self) -> bool { match *self { Self::Short { key_phase, .. } => key_phase, _ => false, } } pub(crate) fn is_short(&self) -> bool { matches!(*self, Self::Short { .. }) } pub(crate) fn is_1rtt(&self) -> bool { self.is_short() } pub(crate) fn is_0rtt(&self) -> bool { matches!( *self, Self::Long { ty: LongType::ZeroRtt, .. } ) } pub(crate) fn dst_cid(&self) -> &ConnectionId { use self::Header::*; match *self { Initial { ref dst_cid, .. } => dst_cid, Long { ref dst_cid, .. } => dst_cid, Retry { ref dst_cid, .. } => dst_cid, Short { ref dst_cid, .. } => dst_cid, VersionNegotiate { ref dst_cid, .. 
} => dst_cid, } } } pub(crate) struct PartialEncode { pub(crate) start: usize, pub(crate) header_len: usize, // Packet number length, payload length needed pn: Option<(usize, bool)>, } impl PartialEncode { pub(crate) fn finish( self, buf: &mut [u8], header_crypto: &dyn crypto::HeaderKey, crypto: Option<(u64, &dyn crypto::PacketKey)>, ) { let Self { header_len, pn, .. } = self; let (pn_len, write_len) = match pn { Some((pn_len, write_len)) => (pn_len, write_len), None => return, }; let pn_pos = header_len - pn_len; if write_len { let len = buf.len() - header_len + pn_len; assert!(len < 2usize.pow(14)); // Fits in reserved space let mut slice = &mut buf[pn_pos - 2..pn_pos]; slice.put_u16(len as u16 | 0b01 << 14); } if let Some((number, crypto)) = crypto { crypto.encrypt(number, buf, header_len); } debug_assert!( pn_pos + 4 + header_crypto.sample_size() <= buf.len(), "packet must be padded to at least {} bytes for header protection sampling", pn_pos + 4 + header_crypto.sample_size() ); header_crypto.encrypt(pn_pos, buf); } } #[derive(Debug)] pub(crate) enum PlainHeader { Initial { dst_cid: ConnectionId, src_cid: ConnectionId, token_pos: Range, len: u64, version: u32, }, Long { ty: LongType, dst_cid: ConnectionId, src_cid: ConnectionId, len: u64, version: u32, }, Retry { dst_cid: ConnectionId, src_cid: ConnectionId, version: u32, }, Short { spin: bool, dst_cid: ConnectionId, }, VersionNegotiate { random: u8, dst_cid: ConnectionId, src_cid: ConnectionId, }, } impl PlainHeader { fn dst_cid(&self) -> ConnectionId { use self::PlainHeader::*; match self { Initial { dst_cid, .. } => *dst_cid, Long { dst_cid, .. } => *dst_cid, Retry { dst_cid, .. } => *dst_cid, Short { dst_cid, .. } => *dst_cid, VersionNegotiate { dst_cid, .. } => *dst_cid, } } fn payload_len(&self) -> Option { use self::PlainHeader::*; match self { Initial { len, .. } | Long { len, .. } => Some(*len), _ => None, } } fn decode( buf: &mut io::Cursor, local_cid_len: usize, supported_versions: &[u32], grease_quic_bit: bool, ) -> Result { let first = buf.get::()?; if !grease_quic_bit && first & FIXED_BIT == 0 { return Err(PacketDecodeError::InvalidHeader("fixed bit unset")); } if first & LONG_HEADER_FORM == 0 { let spin = first & SPIN_BIT != 0; if buf.remaining() < local_cid_len { return Err(PacketDecodeError::InvalidHeader("cid out of bounds")); } Ok(Self::Short { spin, dst_cid: ConnectionId::from_buf(buf, local_cid_len), }) } else { let version = buf.get::()?; let dst_cid = ConnectionId::decode_long(buf) .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?; let src_cid = ConnectionId::decode_long(buf) .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?; // TODO: Support long CIDs for compatibility with future QUIC versions if version == 0 { let random = first & !LONG_HEADER_FORM; return Ok(Self::VersionNegotiate { random, dst_cid, src_cid, }); } if !supported_versions.contains(&version) { return Err(PacketDecodeError::UnsupportedVersion { src_cid, dst_cid, version, }); } match LongHeaderType::from_byte(first)? { LongHeaderType::Initial => { let token_len = buf.get_var()? 
as usize; let token_start = buf.position() as usize; if token_len > buf.remaining() { return Err(PacketDecodeError::InvalidHeader("token out of bounds")); } buf.advance(token_len); let len = buf.get_var()?; Ok(Self::Initial { dst_cid, src_cid, token_pos: token_start..token_start + token_len, len, version, }) } LongHeaderType::Retry => Ok(Self::Retry { dst_cid, src_cid, version, }), LongHeaderType::Standard(ty) => Ok(Self::Long { ty, dst_cid, src_cid, len: buf.get_var()?, version, }), } } } } // An encoded packet number #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) enum PacketNumber { U8(u8), U16(u16), U24(u32), U32(u32), } impl PacketNumber { pub(crate) fn new(n: u64, largest_acked: u64) -> Self { let range = (n - largest_acked) * 2; if range < 1 << 8 { Self::U8(n as u8) } else if range < 1 << 16 { Self::U16(n as u16) } else if range < 1 << 24 { Self::U24(n as u32) } else if range < 1 << 32 { Self::U32(n as u32) } else { panic!("packet number too large to encode") } } pub(crate) fn len(self) -> usize { use self::PacketNumber::*; match self { U8(_) => 1, U16(_) => 2, U24(_) => 3, U32(_) => 4, } } pub(crate) fn encode(self, w: &mut W) { use self::PacketNumber::*; match self { U8(x) => w.write(x), U16(x) => w.write(x), U24(x) => w.put_uint(u64::from(x), 3), U32(x) => w.write(x), } } pub(crate) fn decode(len: usize, r: &mut R) -> Result { use self::PacketNumber::*; let pn = match len { 1 => U8(r.get()?), 2 => U16(r.get()?), 3 => U24(r.get_uint(3) as u32), 4 => U32(r.get()?), _ => unreachable!(), }; Ok(pn) } pub(crate) fn decode_len(tag: u8) -> usize { 1 + (tag & 0x03) as usize } fn tag(self) -> u8 { use self::PacketNumber::*; match self { U8(_) => 0b00, U16(_) => 0b01, U24(_) => 0b10, U32(_) => 0b11, } } pub(crate) fn expand(self, expected: u64) -> u64 { // From Appendix A use self::PacketNumber::*; let truncated = match self { U8(x) => u64::from(x), U16(x) => u64::from(x), U24(x) => u64::from(x), U32(x) => u64::from(x), }; let nbits = self.len() * 8; let win = 1 << nbits; let hwin = win / 2; let mask = win - 1; // The incoming packet number should be greater than expected - hwin and less than or equal // to expected + hwin // // This means we can't just strip the trailing bits from expected and add the truncated // because that might yield a value outside the window. // // The following code calculates a candidate value and makes sure it's within the packet // number window. 
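        // Worked example (matching RFC 9000, Appendix A): with expected = 0xa82f30eb and a
        // 16-bit truncated value of 0x9b32, mask = 0xffff and hwin = 0x8000, giving
        // candidate = (0xa82f30eb & !0xffff) | 0x9b32 = 0xa82f9b32; that already lies
        // within expected ± hwin, so it is returned unadjusted below.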
let candidate = (expected & !mask) | truncated; if expected.checked_sub(hwin).map_or(false, |x| candidate <= x) { candidate + win } else if candidate > expected + hwin && candidate > win { candidate - win } else { candidate } } } /// Long packet type including non-uniform cases #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongHeaderType { Initial, Retry, Standard(LongType), } impl LongHeaderType { fn from_byte(b: u8) -> Result { use self::{LongHeaderType::*, LongType::*}; debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet"); Ok(match (b & 0x30) >> 4 { 0x0 => Initial, 0x1 => Standard(ZeroRtt), 0x2 => Standard(Handshake), 0x3 => Retry, _ => unreachable!(), }) } } impl From for u8 { fn from(ty: LongHeaderType) -> Self { use self::{LongHeaderType::*, LongType::*}; match ty { Initial => LONG_HEADER_FORM | FIXED_BIT, Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4), Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4), Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4), } } } /// Long packet types with uniform header structure #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum LongType { Handshake, ZeroRtt, } #[allow(unreachable_pub)] // fuzzing only #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum PacketDecodeError { #[error("unsupported version {version:x}")] UnsupportedVersion { src_cid: ConnectionId, dst_cid: ConnectionId, version: u32, }, #[error("invalid header: {0}")] InvalidHeader(&'static str), } impl From for PacketDecodeError { fn from(_: coding::UnexpectedEnd) -> Self { Self::InvalidHeader("unexpected end of packet") } } pub(crate) const LONG_HEADER_FORM: u8 = 0x80; pub(crate) const FIXED_BIT: u8 = 0x40; pub(crate) const SPIN_BIT: u8 = 0x20; const SHORT_RESERVED_BITS: u8 = 0x18; const LONG_RESERVED_BITS: u8 = 0x0c; const KEY_PHASE_BIT: u8 = 0x04; /// Packet number space identifiers #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum SpaceId { /// Unprotected packets, used to bootstrap the handshake Initial = 0, Handshake = 1, /// Application data space, used for 0-RTT and post-handshake/1-RTT packets Data = 2, } impl SpaceId { pub fn iter() -> impl Iterator { [Self::Initial, Self::Handshake, Self::Data].iter().cloned() } } #[cfg(test)] mod tests { use super::*; use crate::DEFAULT_SUPPORTED_VERSIONS; use hex_literal::hex; use std::io; fn check_pn(typed: PacketNumber, encoded: &[u8]) { let mut buf = Vec::new(); typed.encode(&mut buf); assert_eq!(&buf[..], encoded); let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap(); assert_eq!(typed, decoded); } #[test] fn roundtrip_packet_numbers() { check_pn(PacketNumber::U8(0x7f), &hex!("7f")); check_pn(PacketNumber::U16(0x80), &hex!("0080")); check_pn(PacketNumber::U16(0x3fff), &hex!("3fff")); check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000")); check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff")); } #[test] fn pn_encode() { check_pn(PacketNumber::new(0x10, 0), &hex!("10")); check_pn(PacketNumber::new(0x100, 0), &hex!("0100")); check_pn(PacketNumber::new(0x10000, 0), &hex!("010000")); } #[test] fn pn_expand_roundtrip() { for expected in 0..1024 { for actual in expected..1024 { assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected)); } } } #[cfg(feature = "rustls")] #[test] fn header_encoding() { use crate::{crypto::rustls::initial_keys, Side}; use rustls::quic::Version; let dcid = ConnectionId::new(&hex!("06b858ec6f80452b")); let client = initial_keys(Version::V1, &dcid, 
Side::Client); let mut buf = BytesMut::new(); let header = Header::Initial { number: PacketNumber::U8(0), src_cid: ConnectionId::new(&[]), dst_cid: dcid, token: Bytes::new(), version: DEFAULT_SUPPORTED_VERSIONS[0], }; let encode = header.encode(&mut buf); let header_len = buf.len(); buf.resize(header_len + 16 + client.packet.local.tag_len(), 0); encode.finish( &mut buf, &*client.header.local, Some((0, &*client.packet.local)), ); for byte in &buf { print!("{byte:02x}"); } println!(); assert_eq!( buf[..], hex!( "c8000000010806b858ec6f80452b00004021be 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1" )[..] ); let server = initial_keys(Version::V1, &dcid, Side::Server); let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec(); let decode = PartialDecode::new(buf, 0, &supported_versions, false) .unwrap() .0; let mut packet = decode.finish(Some(&*server.header.remote)).unwrap(); assert_eq!( packet.header_data[..], hex!("c0000000010806b858ec6f80452b0000402100")[..] ); server .packet .remote .decrypt(0, &packet.header_data, &mut packet.payload) .unwrap(); assert_eq!(packet.payload[..], [0; 16]); match packet.header { Header::Initial { number: PacketNumber::U8(0), .. } => {} _ => { panic!("unexpected header {:?}", packet.header); } } } } quinn-proto-0.10.6/src/range_set/array_range_set.rs000064400000000000000000000147041046102023000204620ustar 00000000000000use std::ops::Range; use tinyvec::TinyVec; /// A set of u64 values optimized for long runs and random insert/delete/contains /// /// `ArrayRangeSet` uses an array representation, where each array entry represents /// a range. /// /// The array-based RangeSet provides 2 benefits: /// - There exists an inline representation, which avoids the need of heap /// allocating ACK ranges for SentFrames for small ranges. /// - Iterating over ranges should usually be faster since there is only /// a single cache-friendly contiguous range. /// /// `ArrayRangeSet` is especially useful for tracking ACK ranges where the amount /// of ranges is usually very low (since ACK numbers are in consecutive fashion /// unless reordering or packet loss occur). #[derive(Debug, Default)] pub struct ArrayRangeSet(TinyVec<[Range; ARRAY_RANGE_SET_INLINE_CAPACITY]>); /// The capacity of elements directly stored in [`ArrayRangeSet`] /// /// An inline capacity of 2 is chosen to keep `SentFrame` below 128 bytes. const ARRAY_RANGE_SET_INLINE_CAPACITY: usize = 2; impl Clone for ArrayRangeSet { fn clone(&self) -> Self { // tinyvec keeps the heap representation after clones. // We rather prefer the inline representation for clones if possible, // since clones (e.g. 
impl Clone for ArrayRangeSet {
    fn clone(&self) -> Self {
        // tinyvec keeps the heap representation after clones.
        // We rather prefer the inline representation for clones if possible,
        // since clones (e.g. for storage in `SentFrames`) are rarely mutated
        if self.0.is_inline() || self.0.len() > ARRAY_RANGE_SET_INLINE_CAPACITY {
            return Self(self.0.clone());
        }

        let mut vec = TinyVec::new();
        vec.extend_from_slice(self.0.as_slice());
        Self(vec)
    }
}

impl ArrayRangeSet {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn iter(&self) -> impl DoubleEndedIterator<Item = Range<u64>> + '_ {
        self.0.iter().cloned()
    }

    pub fn elts(&self) -> impl Iterator<Item = u64> + '_ {
        self.iter().flatten()
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn contains(&self, x: u64) -> bool {
        for range in self.0.iter() {
            if range.start > x {
                // We only get here if there was no prior range that contained x
                return false;
            } else if range.contains(&x) {
                return true;
            }
        }
        false
    }

    pub fn subtract(&mut self, other: &Self) {
        // TODO: This can potentially be made more efficient, since we know
        // individual ranges are not overlapping, and the next range must start
        // after the last one finished
        for range in &other.0 {
            self.remove(range.clone());
        }
    }
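    // Expository example added in editing (not upstream code): overlapping inserts are
    // coalesced into a single entry, and `subtract` removes another set wholesale. This
    // associated function is hypothetical and exists only to illustrate the semantics.
    #[allow(dead_code)]
    fn usage_example() {
        let mut set = ArrayRangeSet::new();
        set.insert(0..4);
        set.insert(2..8); // overlaps the 0..4 entry, extending it to 0..8
        assert_eq!(set.len(), 1);
        let mut other = ArrayRangeSet::new();
        other.insert(0..2);
        set.subtract(&other); // trims the set down to 2..8
        assert_eq!(set.min(), Some(2));
    }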
    pub fn insert_one(&mut self, x: u64) -> bool {
        self.insert(x..x + 1)
    }

    pub fn insert(&mut self, x: Range<u64>) -> bool {
        let mut result = false;

        if x.is_empty() {
            // Don't try to deal with ranges where x.end <= x.start
            return false;
        }

        let mut idx = 0;
        while idx != self.0.len() {
            let range = &mut self.0[idx];

            if range.start > x.end {
                // The range is fully before this range and therefore not extensible.
                // Add a new range to the left
                self.0.insert(idx, x);
                return true;
            } else if range.start > x.start {
                // The new range starts before this range but overlaps.
                // Extend the current range to the left
                // Note that we don't have to merge a potential left range, since
                // this case would have been captured by merging the right range
                // in the previous loop iteration
                result = true;
                range.start = x.start;
            }

            // At this point we have handled all parts of the new range which
            // are in front of the current range. Now we handle everything from
            // the start of the current range
            if x.end <= range.end {
                // Fully contained
                return result;
            } else if x.start <= range.end {
                // Extend the current range to the end of the new range.
                // Since it's not contained it must be bigger
                range.end = x.end;

                // Merge all follow-up ranges which overlap
                while idx != self.0.len() - 1 {
                    let curr = self.0[idx].clone();
                    let next = self.0[idx + 1].clone();
                    if curr.end >= next.start {
                        self.0[idx].end = next.end.max(curr.end);
                        self.0.remove(idx + 1);
                    } else {
                        break;
                    }
                }
                return true;
            }

            idx += 1;
        }

        // Insert a range at the end
        self.0.push(x);
        true
    }

    pub fn remove(&mut self, x: Range<u64>) -> bool {
        let mut result = false;

        if x.is_empty() {
            // Don't try to deal with ranges where x.end <= x.start
            return false;
        }

        let mut idx = 0;
        while idx != self.0.len() && x.start != x.end {
            let range = self.0[idx].clone();

            if x.end <= range.start {
                // The range is fully before this range
                return result;
            } else if x.start >= range.end {
                // The range is fully after this range
                idx += 1;
                continue;
            }

            // The range overlaps with this range
            result = true;
            let left = range.start..x.start;
            let right = x.end..range.end;

            if left.is_empty() && right.is_empty() {
                self.0.remove(idx);
            } else if left.is_empty() {
                self.0[idx] = right;
                idx += 1;
            } else if right.is_empty() {
                self.0[idx] = left;
                idx += 1;
            } else {
                self.0[idx] = right;
                self.0.insert(idx, left);
                idx += 2;
            }
        }

        result
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    pub fn pop_min(&mut self) -> Option<Range<u64>> {
        if !self.0.is_empty() {
            Some(self.0.remove(0))
        } else {
            None
        }
    }

    pub fn min(&self) -> Option<u64> {
        self.iter().next().map(|x| x.start)
    }

    pub fn max(&self) -> Option<u64> {
        self.iter().next_back().map(|x| x.end - 1)
    }
}
quinn-proto-0.10.6/src/range_set/btree_range_set.rs000064400000000000000000000256151046102023000204470ustar 00000000000000use std::{
    cmp,
    cmp::Ordering,
    collections::{btree_map, BTreeMap},
    ops::{
        Bound::{Excluded, Included},
        Range,
    },
};

/// A set of u64 values optimized for long runs and random insert/delete/contains
#[derive(Debug, Default, Clone)]
pub struct RangeSet(BTreeMap<u64, u64>);

impl RangeSet {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn contains(&self, x: u64) -> bool {
        self.pred(x).map_or(false, |(_, end)| end > x)
    }

    pub fn insert_one(&mut self, x: u64) -> bool {
        if let Some((start, end)) = self.pred(x) {
            match end.cmp(&x) {
                // Wholly contained
                Ordering::Greater => {
                    return false;
                }
                Ordering::Equal => {
                    // Extend existing
                    self.0.remove(&start);
                    let mut new_end = x + 1;
                    if let Some((next_start, next_end)) = self.succ(x) {
                        if next_start == new_end {
                            self.0.remove(&next_start);
                            new_end = next_end;
                        }
                    }
                    self.0.insert(start, new_end);
                    return true;
                }
                _ => {}
            }
        }
        let mut new_end = x + 1;
        if let Some((next_start, next_end)) = self.succ(x) {
            if next_start == new_end {
                self.0.remove(&next_start);
                new_end = next_end;
            }
        }
        self.0.insert(x, new_end);
        true
    }

    pub fn insert(&mut self, mut x: Range<u64>) -> bool {
        if x.is_empty() {
            return false;
        }
        if let Some((start, end)) = self.pred(x.start) {
            if end >= x.end {
                // Wholly contained
                return false;
            } else if end >= x.start {
                // Extend overlapping predecessor
                self.0.remove(&start);
                x.start = start;
            }
        }
        while let Some((next_start, next_end)) = self.succ(x.start) {
            if next_start > x.end {
                break;
            }
            // Overlaps with successor
            self.0.remove(&next_start);
            x.end = cmp::max(next_end, x.end);
        }
        self.0.insert(x.start, x.end);
        true
    }
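    // Expository example added in editing (not upstream code): `pred`/`succ` below are
    // the two primitives everything else builds on; `pred(x)` is the last range starting
    // at or before `x`, and `succ(x)` is the first range starting strictly after it.
    #[allow(dead_code)]
    fn pred_succ_example() {
        let mut set = RangeSet::new();
        set.insert(0..2);
        set.insert(4..6);
        assert_eq!(set.pred(4), Some((4, 6)));
        assert_eq!(set.pred(3), Some((0, 2)));
        assert_eq!(set.succ(0), Some((4, 6)));
    }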
    /// Find closest range to `x` that begins at or before it
    fn pred(&self, x: u64) -> Option<(u64, u64)> {
        self.0
            .range((Included(0), Included(x)))
            .next_back()
            .map(|(&x, &y)| (x, y))
    }

    /// Find the closest range to `x` that begins after it
    fn succ(&self, x: u64) -> Option<(u64, u64)> {
        self.0
            .range((Excluded(x), Included(u64::max_value())))
            .next()
            .map(|(&x, &y)| (x, y))
    }

    pub fn remove(&mut self, x: Range<u64>) -> bool {
        if x.is_empty() {
            return false;
        }

        let before = match self.pred(x.start) {
            Some((start, end)) if end > x.start => {
                self.0.remove(&start);
                if start < x.start {
                    self.0.insert(start, x.start);
                }
                if end > x.end {
                    self.0.insert(x.end, end);
                }
                // Short-circuit if we cannot possibly overlap with another range
                if end >= x.end {
                    return true;
                }
                true
            }
            Some(_) | None => false,
        };
        let mut after = false;
        while let Some((start, end)) = self.succ(x.start) {
            if start >= x.end {
                break;
            }
            after = true;
            self.0.remove(&start);
            if end > x.end {
                self.0.insert(x.end, end);
                break;
            }
        }
        before || after
    }

    /// Add a range to the set, returning the intersection of current ranges with the new one
    pub fn replace(&mut self, mut range: Range<u64>) -> Replace<'_> {
        let pred = if let Some((prev_start, prev_end)) = self
            .pred(range.start)
            .filter(|&(_, end)| end >= range.start)
        {
            self.0.remove(&prev_start);
            let replaced_start = range.start;
            range.start = range.start.min(prev_start);
            let replaced_end = range.end.min(prev_end);
            range.end = range.end.max(prev_end);
            if replaced_start != replaced_end {
                Some(replaced_start..replaced_end)
            } else {
                None
            }
        } else {
            None
        };
        Replace {
            set: self,
            range,
            pred,
        }
    }

    pub fn add(&mut self, other: &Self) {
        for (&start, &end) in &other.0 {
            self.insert(start..end);
        }
    }

    pub fn subtract(&mut self, other: &Self) {
        for (&start, &end) in &other.0 {
            self.remove(start..end);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    pub fn min(&self) -> Option<u64> {
        self.iter().next().map(|x| x.start)
    }

    pub fn max(&self) -> Option<u64> {
        self.iter().next_back().map(|x| x.end - 1)
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn iter(&self) -> Iter<'_> {
        Iter(self.0.iter())
    }

    pub fn elts(&self) -> EltIter<'_> {
        EltIter {
            inner: self.0.iter(),
            next: 0,
            end: 0,
        }
    }

    pub fn peek_min(&self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.iter().next()?;
        Some(start..end)
    }

    pub fn pop_min(&mut self) -> Option<Range<u64>> {
        let result = self.peek_min()?;
        self.0.remove(&result.start);
        Some(result)
    }
}

pub struct Iter<'a>(btree_map::Iter<'a, u64, u64>);

impl<'a> Iterator for Iter<'a> {
    type Item = Range<u64>;
    fn next(&mut self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.next()?;
        Some(start..end)
    }
}

impl<'a> DoubleEndedIterator for Iter<'a> {
    fn next_back(&mut self) -> Option<Range<u64>> {
        let (&start, &end) = self.0.next_back()?;
        Some(start..end)
    }
}

impl<'a> IntoIterator for &'a RangeSet {
    type Item = Range<u64>;
    type IntoIter = Iter<'a>;
    fn into_iter(self) -> Iter<'a> {
        self.iter()
    }
}

pub struct EltIter<'a> {
    inner: btree_map::Iter<'a, u64, u64>,
    next: u64,
    end: u64,
}

impl<'a> Iterator for EltIter<'a> {
    type Item = u64;
    fn next(&mut self) -> Option<u64> {
        if self.next == self.end {
            let (&start, &end) = self.inner.next()?;
            self.next = start;
            self.end = end;
        }
        let x = self.next;
        self.next += 1;
        Some(x)
    }
}

impl<'a> DoubleEndedIterator for EltIter<'a> {
    fn next_back(&mut self) -> Option<u64> {
        if self.next == self.end {
            let (&start, &end) = self.inner.next_back()?;
            self.next = start;
            self.end = end;
        }
        self.end -= 1;
        Some(self.end)
    }
}

/// Iterator returned by `RangeSet::replace`
pub struct Replace<'a> {
    set: &'a mut RangeSet,
    /// Portion of the intersection arising from a range beginning at or before the newly inserted
    /// range
    pred: Option<Range<u64>>,
    /// Union of the input range and all ranges that have been visited by the iterator so far
    range: Range<u64>,
}
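// Expository example added in editing (not upstream code): `elts()` lazily flattens the
// stored `(start, end)` pairs into the individual values they cover, in ascending order.
#[allow(dead_code)]
fn elts_example() {
    let mut set = RangeSet::new();
    set.insert(1..3);
    set.insert(5..6);
    assert_eq!(set.elts().collect::<Vec<_>>(), vec![1, 2, 5]);
}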
impl Iterator for Replace<'_> {
    type Item = Range<u64>;
    fn next(&mut self) -> Option<Range<u64>> {
        if let Some(pred) = self.pred.take() {
            // If a range starting before the inserted range overlapped with it, return the
            // corresponding overlap first
            return Some(pred);
        }

        let (next_start, next_end) = self.set.succ(self.range.start)?;
        if next_start > self.range.end {
            // If the next successor range starts after the current range ends, there can be no more
            // overlaps. This is sound even when `self.range.end` is increased because `RangeSet` is
            // guaranteed not to contain pairs of ranges that could be simplified.
            return None;
        }

        // Remove the redundant range...
        self.set.0.remove(&next_start);
        // ...and handle the case where the redundant range ends later than the new range.
        let replaced_end = self.range.end.min(next_end);
        self.range.end = self.range.end.max(next_end);
        if next_start == replaced_end {
            // If the redundant range started exactly where the new range ended, there was no
            // overlap with it or any later range.
            None
        } else {
            Some(next_start..replaced_end)
        }
    }
}

impl Drop for Replace<'_> {
    fn drop(&mut self) {
        // Ensure we drain all remaining overlapping ranges
        for _ in &mut *self {}
        // Insert the final aggregate range
        self.set.0.insert(self.range.start, self.range.end);
    }
}

/// This module contains tests which only apply for this `RangeSet` implementation
///
/// Tests which apply for all implementations can be found in the `tests.rs` module
#[cfg(test)]
mod tests {
    #![allow(clippy::single_range_in_vec_init)] // https://github.com/rust-lang/rust-clippy/issues/11086
    use super::*;

    #[test]
    fn replace_contained() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(1..5).collect::<Vec<_>>(), &[2..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..5);
    }

    #[test]
    fn replace_contains() {
        let mut set = RangeSet::new();
        set.insert(1..5);
        assert_eq!(set.replace(2..4).collect::<Vec<_>>(), &[2..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..5);
    }

    #[test]
    fn replace_pred() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(3..5).collect::<Vec<_>>(), &[3..4]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 2..5);
    }

    #[test]
    fn replace_succ() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(1..3).collect::<Vec<_>>(), &[2..3]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 1..4);
    }

    #[test]
    fn replace_exact_pred() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(4..6).collect::<Vec<_>>(), &[]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 2..6);
    }

    #[test]
    fn replace_exact_succ() {
        let mut set = RangeSet::new();
        set.insert(2..4);
        assert_eq!(set.replace(0..2).collect::<Vec<_>>(), &[]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.peek_min().unwrap(), 0..4);
    }
}
quinn-proto-0.10.6/src/range_set/mod.rs000064400000000000000000000002341046102023000160650ustar 00000000000000mod array_range_set;
mod btree_range_set;
#[cfg(test)]
mod tests;

pub(crate) use array_range_set::ArrayRangeSet;
pub(crate) use btree_range_set::RangeSet;
quinn-proto-0.10.6/src/range_set/tests.rs000064400000000000000000000201771046102023000164570ustar 00000000000000use std::ops::Range;

use super::*;
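// Expository addition from editing (not part of the upstream suite): the same
// macro-instantiation pattern as `common_set_tests!` below, in miniature, stamping out
// one identical smoke test per implementation so both range sets stay behaviorally in
// sync. The macro and module names here are hypothetical.
macro_rules! smoke_test_for {
    ($mod_name:ident, $set_type:ident) => {
        mod $mod_name {
            use super::*;

            #[test]
            fn new_set_is_empty() {
                assert!($set_type::new().is_empty());
                assert_eq!($set_type::new().min(), None);
            }
        }
    };
}
smoke_test_for!(smoke_range_set, RangeSet);
smoke_test_for!(smoke_array_range_set, ArrayRangeSet);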
macro_rules! common_set_tests {
    ($set_name:ident, $set_type:ident) => {
        mod $set_name {
            use super::*;

            #[test]
            fn merge_and_split() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(2..4));
                assert!(!set.insert(1..3));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3]);
                assert!(!set.contains(4));
                assert!(set.remove(2..3));
                assert_eq!(set.len(), 2);
                assert!(!set.contains(2));
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 3]);
            }

            #[test]
            fn double_merge_exact() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(2..4));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3, 4, 5]);
            }

            #[test]
            fn single_merge_low() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(2..3));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 4, 5]);
            }

            #[test]
            fn single_merge_high() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(3..4));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 3, 4, 5]);
            }

            #[test]
            fn double_merge_wide() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert_eq!(set.len(), 2);
                assert!(set.insert(1..5));
                assert_eq!(set.len(), 1);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 1, 2, 3, 4, 5]);
            }

            #[test]
            fn double_remove() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(set.insert(4..6));
                assert!(set.remove(1..5));
                assert_eq!(set.len(), 2);
                assert_eq!(&set.elts().collect::<Vec<_>>()[..], [0, 5]);
            }

            #[test]
            fn insert_multiple() {
                let mut set = $set_type::new();
                assert!(set.insert(0..1));
                assert!(set.insert(2..3));
                assert!(set.insert(4..5));
                assert!(set.insert(0..5));
                assert_eq!(set.len(), 1);
            }

            #[test]
            fn remove_multiple() {
                let mut set = $set_type::new();
                assert!(set.insert(0..1));
                assert!(set.insert(2..3));
                assert!(set.insert(4..5));
                assert!(set.remove(0..5));
                assert!(set.is_empty());
            }

            #[test]
            fn double_insert() {
                let mut set = $set_type::new();
                assert!(set.insert(0..2));
                assert!(!set.insert(0..2));
                assert!(set.insert(2..4));
                assert!(!set.insert(2..4));
                assert!(!set.insert(0..4));
                assert!(!set.insert(1..2));
                assert!(!set.insert(1..3));
                assert!(!set.insert(1..4));
                assert_eq!(set.len(), 1);
            }

            #[test]
            fn skip_empty_ranges() {
                let mut set = $set_type::new();
                assert!(!set.insert(2..2));
                assert_eq!(set.len(), 0);
                assert!(!set.insert(4..4));
                assert_eq!(set.len(), 0);
                assert!(!set.insert(0..0));
                assert_eq!(set.len(), 0);
            }

            #[test]
            fn compare_insert_to_reference() {
                const MAX_RANGE: u64 = 50;
                for start in 0..=MAX_RANGE {
                    for end in 0..=MAX_RANGE {
                        println!("insert({}..{})", start, end);
                        let (mut set, mut reference) = create_initial_sets(MAX_RANGE);
                        assert_eq!(set.insert(start..end), reference.insert(start..end));
                        assert_sets_equal(&set, &reference);
                    }
                }
            }

            #[test]
            fn compare_remove_to_reference() {
                const MAX_RANGE: u64 = 50;
                for start in 0..=MAX_RANGE {
                    for end in 0..=MAX_RANGE {
                        println!("remove({}..{})", start, end);
                        let (mut set, mut reference) = create_initial_sets(MAX_RANGE);
                        assert_eq!(set.remove(start..end), reference.remove(start..end));
                        assert_sets_equal(&set, &reference);
                    }
                }
            }
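            // Expository test added in editing (not upstream): `pop_min` drains ranges
            // in ascending order, which is how callers consume accumulated ACK ranges.
            #[test]
            fn pop_min_in_order() {
                let mut set = $set_type::new();
                assert!(set.insert(4..6));
                assert!(set.insert(0..2));
                assert_eq!(set.pop_min(), Some(0..2));
                assert_eq!(set.pop_min(), Some(4..6));
                assert_eq!(set.pop_min(), None);
            }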
            fn create_initial_sets(max_range: u64) -> ($set_type, RefRangeSet) {
                let mut set = $set_type::new();
                let mut reference = RefRangeSet::new(max_range as usize);
                assert_sets_equal(&set, &reference);

                assert_eq!(set.insert(2..6), reference.insert(2..6));
                assert_eq!(set.insert(10..14), reference.insert(10..14));
                assert_eq!(set.insert(14..14), reference.insert(14..14));
                assert_eq!(set.insert(18..19), reference.insert(18..19));
                assert_eq!(set.insert(20..21), reference.insert(20..21));
                assert_eq!(set.insert(22..24), reference.insert(22..24));
                assert_eq!(set.insert(26..30), reference.insert(26..30));
                assert_eq!(set.insert(34..38), reference.insert(34..38));
                assert_eq!(set.insert(42..44), reference.insert(42..44));
                assert_sets_equal(&set, &reference);

                (set, reference)
            }

            fn assert_sets_equal(set: &$set_type, reference: &RefRangeSet) {
                assert_eq!(set.len(), reference.len());
                assert_eq!(set.is_empty(), reference.is_empty());
                assert_eq!(set.elts().collect::<Vec<_>>()[..], reference.elts()[..]);
            }
        }
    };
}

common_set_tests!(range_set, RangeSet);
common_set_tests!(array_range_set, ArrayRangeSet);

/// A very simple reference implementation of a RangeSet
struct RefRangeSet {
    data: Vec<bool>,
}

impl RefRangeSet {
    fn new(capacity: usize) -> Self {
        Self {
            data: vec![false; capacity],
        }
    }

    fn len(&self) -> usize {
        let mut last = false;
        let mut count = 0;
        for v in self.data.iter() {
            if !last && *v {
                count += 1;
            }
            last = *v;
        }
        count
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    fn insert(&mut self, x: Range<u64>) -> bool {
        let mut result = false;
        assert!(x.end <= self.data.len() as u64);
        for i in x {
            let i = i as usize;
            if !self.data[i] {
                result = true;
                self.data[i] = true;
            }
        }
        result
    }

    fn remove(&mut self, x: Range<u64>) -> bool {
        let mut result = false;
        assert!(x.end <= self.data.len() as u64);
        for i in x {
            let i = i as usize;
            if self.data[i] {
                result = true;
                self.data[i] = false;
            }
        }
        result
    }

    fn elts(&self) -> Vec<u64> {
        self.data
            .iter()
            .enumerate()
            .filter_map(|(i, e)| if *e { Some(i as u64) } else { None })
            .collect()
    }
}
quinn-proto-0.10.6/src/shared.rs000064400000000000000000000115351046102023000146130ustar 00000000000000use std::{fmt, net::SocketAddr, time::Instant};

use bytes::{Buf, BufMut, BytesMut};

use crate::{coding::BufExt, packet::PartialDecode, ResetToken, MAX_CID_SIZE};

/// Events sent from an Endpoint to a Connection
#[derive(Debug)]
pub struct ConnectionEvent(pub(crate) ConnectionEventInner);

#[derive(Debug)]
pub(crate) enum ConnectionEventInner {
    /// A datagram has been received for the Connection
    Datagram {
        now: Instant,
        remote: SocketAddr,
        ecn: Option<EcnCodepoint>,
        first_decode: PartialDecode,
        remaining: Option<BytesMut>,
    },
    /// New connection identifiers have been issued for the Connection
    NewIdentifiers(Vec<IssuedCid>, Instant),
}
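// Expository sketch added in editing (not upstream code): quinn-proto is a sans-I/O
// state machine, so a driver owns the routing of these events between an `Endpoint` and
// its `Connection`s. A hypothetical helper showing the intended teardown check:
#[allow(dead_code)]
fn connection_state_can_be_freed(event: &EndpointEvent) -> bool {
    // Once a connection reports itself drained it will emit nothing further, so the
    // driver may drop its timers, channels, and CID routes for that connection.
    event.is_drained()
}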
/// Events sent from a Connection to an Endpoint
#[derive(Debug)]
pub struct EndpointEvent(pub(crate) EndpointEventInner);

impl EndpointEvent {
    /// Construct an event indicating that a `Connection` will no longer emit events
    ///
    /// Useful for notifying an `Endpoint` that a `Connection` has been destroyed outside of the
    /// usual state machine flow, e.g. when being dropped by the user.
    pub fn drained() -> Self {
        Self(EndpointEventInner::Drained)
    }

    /// Determine whether this is the last event a `Connection` will emit
    ///
    /// Useful for determining when connection-related event loop state can be freed.
    pub fn is_drained(&self) -> bool {
        self.0 == EndpointEventInner::Drained
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum EndpointEventInner {
    /// The connection has been drained
    Drained,
    /// The reset token and/or address eligible for generating resets has been updated
    ResetToken(SocketAddr, ResetToken),
    /// The connection needs connection identifiers
    NeedIdentifiers(Instant, u64),
    /// Stop routing connection ID for this sequence number to the connection
    /// When `bool == true`, a new connection ID will be issued to peer
    RetireConnectionId(Instant, u64, bool),
}

/// Protocol-level identifier for a connection.
///
/// Mainly useful for identifying this connection's packets on the wire with tools like Wireshark.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ConnectionId {
    /// length of CID
    len: u8,
    /// CID in byte array
    bytes: [u8; MAX_CID_SIZE],
}

impl ConnectionId {
    /// Construct cid from byte array
    pub fn new(bytes: &[u8]) -> Self {
        debug_assert!(bytes.len() <= MAX_CID_SIZE);
        let mut res = Self {
            len: bytes.len() as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        res.bytes[..bytes.len()].copy_from_slice(bytes);
        res
    }

    /// Constructs cid by reading `len` bytes from a `Buf`
    ///
    /// Callers need to assure that `buf.remaining() >= len`
    pub(crate) fn from_buf(buf: &mut impl Buf, len: usize) -> Self {
        debug_assert!(len <= MAX_CID_SIZE);
        let mut res = Self {
            len: len as u8,
            bytes: [0; MAX_CID_SIZE],
        };
        buf.copy_to_slice(&mut res[..len]);
        res
    }

    /// Decode from long header format
    pub(crate) fn decode_long(buf: &mut impl Buf) -> Option<Self> {
        let len = buf.get::<u8>().ok()? as usize;
        match len > MAX_CID_SIZE || buf.remaining() < len {
            false => Some(Self::from_buf(buf, len)),
            true => None,
        }
    }

    /// Encode in long header format
    pub(crate) fn encode_long(&self, buf: &mut impl BufMut) {
        buf.put_u8(self.len() as u8);
        buf.put_slice(self);
    }
}

impl ::std::ops::Deref for ConnectionId {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &self.bytes[0..self.len as usize]
    }
}

impl ::std::ops::DerefMut for ConnectionId {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut self.bytes[0..self.len as usize]
    }
}

impl fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.bytes[0..self.len as usize].fmt(f)
    }
}

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for byte in self.iter() {
            write!(f, "{byte:02x}")?;
        }
        Ok(())
    }
}

/// Explicit congestion notification codepoint
#[repr(u8)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum EcnCodepoint {
    #[doc(hidden)]
    Ect0 = 0b10,
    #[doc(hidden)]
    Ect1 = 0b01,
    #[doc(hidden)]
    Ce = 0b11,
}

impl EcnCodepoint {
    /// Create new object from the given bits
    pub fn from_bits(x: u8) -> Option<Self> {
        use self::EcnCodepoint::*;
        Some(match x & 0b11 {
            0b10 => Ect0,
            0b01 => Ect1,
            0b11 => Ce,
            _ => {
                return None;
            }
        })
    }
}

#[derive(Debug, Copy, Clone)]
pub(crate) struct IssuedCid {
    pub(crate) sequence: u64,
    pub(crate) id: ConnectionId,
    pub(crate) reset_token: ResetToken,
}
quinn-proto-0.10.6/src/tests/mod.rs000064400000000000000000002204251046102023000152660ustar 00000000000000use std::{
    convert::TryInto,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
    sync::Arc,
    time::{Duration, Instant},
};

use assert_matches::assert_matches;
use bytes::Bytes;
use hex_literal::hex;
use rand::RngCore;
use ring::hmac;
use rustls::AlertDescription;
use tracing::info;

use super::*;
use crate::{
    cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
    frame::FrameStruct,
};

mod util;
use util::*;
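// Expository test added in editing (not part of the upstream suite): the long-header
// encoding of a connection ID is a one-byte length prefix followed by the CID bytes, so
// an `encode_long`/`decode_long` roundtrip is lossless.
#[test]
fn cid_long_header_roundtrip_example() {
    use crate::shared::ConnectionId;
    let cid = ConnectionId::new(&[0xab; 8]);
    let mut buf = Vec::new();
    cid.encode_long(&mut buf);
    assert_eq!(buf.len(), 1 + 8);
    let decoded = ConnectionId::decode_long(&mut std::io::Cursor::new(&buf[..])).unwrap();
    assert_eq!(decoded, cid);
}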
#[test]
fn version_negotiate_server() {
    let _guard = subscribe();
    let client_addr = "[::2]:7890".parse().unwrap();
    let mut server = Endpoint::new(Default::default(), Some(Arc::new(server_config())), true);
    let now = Instant::now();
    let event = server.handle(
        now,
        client_addr,
        None,
        None,
        // Long-header packet with reserved version number
        hex!("80 0a1a2a3a 04 00000000 04 00000000 00")[..].into(),
    );
    assert!(event.is_none());

    let io = server.poll_transmit();
    assert!(io.is_some());
    if let Some(Transmit { contents, .. }) = io {
        assert_ne!(contents[0] & 0x80, 0);
        assert_eq!(&contents[1..15], hex!("00000000 04 00000000 04 00000000"));
        assert!(contents[15..].chunks(4).any(|x| {
            DEFAULT_SUPPORTED_VERSIONS.contains(&u32::from_be_bytes(x.try_into().unwrap()))
        }));
    }
    assert_matches!(server.poll_transmit(), None);
}

#[test]
fn version_negotiate_client() {
    let _guard = subscribe();
    let server_addr = "[::2]:7890".parse().unwrap();
    let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> =
        || Box::new(RandomConnectionIdGenerator::new(0));
    let mut client = Endpoint::new(
        Arc::new(EndpointConfig {
            connection_id_generator_factory: Arc::new(cid_generator_factory),
            ..Default::default()
        }),
        None,
        true,
    );
    let (_, mut client_ch) = client
        .connect(client_config(), server_addr, "localhost")
        .unwrap();
    let now = Instant::now();
    let opt_event = client.handle(
        now,
        server_addr,
        None,
        None,
        // Version negotiation packet for reserved version
        hex!(
            "80 00000000 04 00000000 04 00000000
             0a1a2a3a"
        )[..]
        .into(),
    );
    if let Some((_, DatagramEvent::ConnectionEvent(event))) = opt_event {
        client_ch.handle_event(event);
    }
    assert_matches!(
        client_ch.poll(),
        Some(Event::ConnectionLost {
            reason: ConnectionError::VersionMismatch,
        })
    );
}

#[test]
fn lifecycle() {
    let _guard = subscribe();
    let mut pair = Pair::default();
    let (client_ch, server_ch) = pair.connect();
    assert_matches!(pair.client_conn_mut(client_ch).poll(), None);
    assert!(pair.client_conn_mut(client_ch).using_ecn());
    assert!(pair.server_conn_mut(server_ch).using_ecn());

    const REASON: &[u8] = b"whee";
    info!("closing");
    pair.client.connections.get_mut(&client_ch).unwrap().close(
        pair.time,
        VarInt(42),
        REASON.into(),
    );
    pair.drive();
    assert_matches!(pair.server_conn_mut(server_ch).poll(),
                    Some(Event::ConnectionLost { reason: ConnectionError::ApplicationClosed(
                        ApplicationClose { error_code: VarInt(42), ref reason }
                    )}) if reason == REASON);
    assert_matches!(pair.client_conn_mut(client_ch).poll(), None);
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
    assert_eq!(pair.server.known_connections(), 0);
    assert_eq!(pair.server.known_cids(), 0);
}

#[test]
fn draft_version_compat() {
    let _guard = subscribe();
    let mut client_config = client_config();
    client_config.version(0xff00_0020);
    let mut pair = Pair::default();
    let (client_ch, server_ch) = pair.connect_with(client_config);
    assert_matches!(pair.client_conn_mut(client_ch).poll(), None);
    assert!(pair.client_conn_mut(client_ch).using_ecn());
    assert!(pair.server_conn_mut(server_ch).using_ecn());

    const REASON: &[u8] = b"whee";
    info!("closing");
    pair.client.connections.get_mut(&client_ch).unwrap().close(
        pair.time,
        VarInt(42),
        REASON.into(),
    );
    pair.drive();
    assert_matches!(pair.server_conn_mut(server_ch).poll(),
                    Some(Event::ConnectionLost { reason: ConnectionError::ApplicationClosed(
                        ApplicationClose { error_code: VarInt(42), ref reason }
                    )}) if reason == REASON);
    assert_matches!(pair.client_conn_mut(client_ch).poll(), None);
    assert_eq!(pair.client.known_connections(), 0);
    assert_eq!(pair.client.known_cids(), 0);
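    // Editing note (not upstream): 0xff00_0020 above is draft 32; draft versions are
    // encoded as 0xff00_0000 plus the draft number, so this test exercises the
    // draft-compatibility path rather than the finalized QUIC v1 version codepoint.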
assert_eq!(pair.server.known_connections(), 0); assert_eq!(pair.server.known_cids(), 0); } #[test] fn stateless_retry() { let _guard = subscribe(); let mut pair = Pair::new( Default::default(), ServerConfig { use_retry: true, ..server_config() }, ); pair.connect(); } #[test] fn server_stateless_reset() { let _guard = subscribe(); let mut reset_key = vec![0; 64]; let mut rng = rand::thread_rng(); rng.fill_bytes(&mut reset_key); let reset_key = hmac::Key::new(hmac::HMAC_SHA256, &reset_key); let endpoint_config = Arc::new(EndpointConfig::new(Arc::new(reset_key))); let mut pair = Pair::new(endpoint_config.clone(), server_config()); let (client_ch, _) = pair.connect(); pair.drive(); // Flush any post-handshake frames pair.server.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true); // Force the server to generate the smallest possible stateless reset pair.client.connections.get_mut(&client_ch).unwrap().ping(); info!("resetting"); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::Reset }) ); } #[test] fn client_stateless_reset() { let _guard = subscribe(); let mut reset_key = vec![0; 64]; let mut rng = rand::thread_rng(); rng.fill_bytes(&mut reset_key); let reset_key = hmac::Key::new(hmac::HMAC_SHA256, &reset_key); let endpoint_config = Arc::new(EndpointConfig::new(Arc::new(reset_key))); let mut pair = Pair::new(endpoint_config.clone(), server_config()); let (_, server_ch) = pair.connect(); pair.client.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true); // Send something big enough to allow room for a smaller stateless reset. pair.server.connections.get_mut(&server_ch).unwrap().close( pair.time, VarInt(42), (&[0xab; 128][..]).into(), ); info!("resetting"); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::Reset }) ); } #[test] fn export_keying_material() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); const LABEL: &[u8] = b"test_label"; const CONTEXT: &[u8] = b"test_context"; // client keying material let mut client_buf = [0u8; 64]; pair.client_conn_mut(client_ch) .crypto_session() .export_keying_material(&mut client_buf, LABEL, CONTEXT) .unwrap(); // server keying material let mut server_buf = [0u8; 64]; pair.server_conn_mut(server_ch) .crypto_session() .export_keying_material(&mut server_buf, LABEL, CONTEXT) .unwrap(); assert_eq!(&client_buf[..], &server_buf[..]); } #[test] fn finish_stream_simple() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); assert_eq!(pair.client_streams(client_ch).send_streams(), 1); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!(pair.client_streams(client_ch).send_streams(), 0); assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 0); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); // Receive-only streams do not get `StreamFinished` events assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 0); 
assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] fn reset_stream() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); info!("resetting stream"); const ERROR: VarInt = VarInt(42); pair.client_send(client_ch, s).reset(ERROR).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Err(ReadError::Reset(ERROR))); let _ = chunks.finalize(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); } #[test] fn stop_stream() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); info!("stopping stream"); const ERROR: VarInt = VarInt(42); pair.server_recv(server_ch, s).stop(ERROR).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!( pair.client_send(client_ch, s).write(b"foo"), Err(WriteError::Stopped(ERROR)) ); assert_matches!( pair.client_send(client_ch, s).finish(), Err(FinishError::Stopped(ERROR)) ); } #[test] fn reject_self_signed_server_cert() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); let client_ch = pair.begin_connect(client_config_with_certs(vec![])); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TransportError(ref error)}) if error.code == TransportErrorCode::crypto(AlertDescription::UnknownCA.get_u8())); } #[test] fn reject_missing_client_cert() { let _guard = subscribe(); let key = rustls::PrivateKey(CERTIFICATE.serialize_private_key_der()); let cert = util::CERTIFICATE.serialize_der().unwrap(); let config = rustls::ServerConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(&[&rustls::version::TLS13]) .unwrap() .with_client_cert_verifier(Arc::new(rustls::server::AllowAnyAuthenticatedClient::new( rustls::RootCertStore::empty(), ))) .with_single_cert(vec![rustls::Certificate(cert)], key) .unwrap(); let mut pair = Pair::new( Default::default(), ServerConfig::with_crypto(Arc::new(config)), ); info!("connecting"); let client_ch = pair.begin_connect(client_config()); pair.drive(); // The client completes the connection, but finds it immediately closed assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), 
Some(Event::Connected) ); assert_matches!(pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ref close)}) if close.error_code == TransportErrorCode::crypto(AlertDescription::CertificateRequired.get_u8())); // The server never completes the connection let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TransportError(ref error)}) if error.code == TransportErrorCode::crypto(AlertDescription::CertificateRequired.get_u8())); } #[test] fn congestion() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); const TARGET: u64 = 2048; assert!(pair.client_conn_mut(client_ch).congestion_window() > TARGET); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); // Send data without receiving ACKs until the congestion state falls below target while pair.client_conn_mut(client_ch).congestion_window() > TARGET { let n = pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); assert_eq!(n, 1024); pair.drive_client(); } // Ensure that the congestion state recovers after receiving the ACKs pair.drive(); assert!(pair.client_conn_mut(client_ch).congestion_window() >= TARGET); pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); } #[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527 #[test] fn high_latency_handshake() { let _guard = subscribe(); let mut pair = Pair::default(); pair.latency = Duration::from_micros(200 * 1000); let (client_ch, server_ch) = pair.connect(); assert_eq!(pair.client_conn_mut(client_ch).bytes_in_flight(), 0); assert_eq!(pair.server_conn_mut(server_ch).bytes_in_flight(), 0); assert!(pair.client_conn_mut(client_ch).using_ecn()); assert!(pair.server_conn_mut(server_ch).using_ecn()); } #[test] fn zero_rtt_happypath() { let _guard = subscribe(); let mut pair = Pair::new( Default::default(), ServerConfig { use_retry: true, ..server_config() }, ); let config = client_config(); // Establish normal connection let client_ch = pair.begin_connect(config.clone()); pair.drive(); pair.server.assert_accept(); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), [][..].into()); pair.drive(); pair.client.addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); info!("resuming session"); let client_ch = pair.begin_connect(config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"Hello, 0-RTT!"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected) ); assert!(pair.client_conn_mut(client_ch).accepted_0rtt()); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); // We don't currently preserve stream event order wrt. 
connection events assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); } #[test] fn zero_rtt_rejection() { let _guard = subscribe(); let mut server_crypto = server_crypto(); server_crypto.alpn_protocols = vec!["foo".into(), "bar".into()]; let server_config = ServerConfig::with_crypto(Arc::new(server_crypto)); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let mut client_crypto = client_crypto(); client_crypto.alpn_protocols = vec!["foo".into()]; let client_config = ClientConfig::new(Arc::new(client_crypto.clone())); // Establish normal connection let client_ch = pair.begin_connect(client_config); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), [][..].into()); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { .. }) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); pair.client.connections.clear(); pair.server.connections.clear(); // Changing protocols invalidates 0-RTT client_crypto.alpn_protocols = vec!["bar".into()]; let client_config = ClientConfig::new(Arc::new(client_crypto)); info!("resuming session"); let client_ch = pair.begin_connect(client_config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"Hello, 0-RTT!"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); assert!(!pair.client_conn_mut(client_ch).accepted_0rtt()); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let s2 = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); assert_eq!(s, s2); let mut recv = pair.server_recv(server_ch, s2); let mut chunks = recv.read(false).unwrap(); assert_eq!(chunks.next(usize::MAX), Err(ReadError::Blocked)); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); } #[test] fn alpn_success() { let _guard = subscribe(); let mut server_crypto = server_crypto(); server_crypto.alpn_protocols = vec!["foo".into(), "bar".into(), "baz".into()]; let server_config = ServerConfig::with_crypto(Arc::new(server_crypto)); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let mut client_crypto = client_crypto(); client_crypto.alpn_protocols = vec!["bar".into(), "quux".into(), "corge".into()]; let client_config = ClientConfig::new(Arc::new(client_crypto)); // Establish normal connection let client_ch = pair.begin_connect(client_config); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), 
        Some(Event::HandshakeDataReady)
    );
    assert_matches!(
        pair.server_conn_mut(server_ch).poll(),
        Some(Event::Connected)
    );
    let hd = pair
        .client_conn_mut(client_ch)
        .crypto_session()
        .handshake_data()
        .unwrap()
        .downcast::<crypto::rustls::HandshakeData>()
        .unwrap();
    assert_eq!(hd.protocol.unwrap(), &b"bar"[..]);
}

#[test]
fn server_alpn_unset() {
    let _guard = subscribe();
    let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config());
    let mut client_crypto = client_crypto();
    client_crypto.alpn_protocols = vec!["foo".into()];
    let client_config = ClientConfig::new(Arc::new(client_crypto));

    let client_ch = pair.begin_connect(client_config);
    pair.drive();
    assert_matches!(
        pair.client_conn_mut(client_ch).poll(),
        Some(Event::ConnectionLost {
            reason: ConnectionError::ConnectionClosed(err)
        }) if err.error_code == TransportErrorCode::crypto(0x78)
    );
}

#[test]
fn client_alpn_unset() {
    let _guard = subscribe();
    let mut server_crypto = server_crypto();
    server_crypto.alpn_protocols = vec!["foo".into(), "bar".into(), "baz".into()];
    let server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
    let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config);

    let client_ch = pair.begin_connect(client_config());
    pair.drive();
    assert_matches!(
        pair.client_conn_mut(client_ch).poll(),
        Some(Event::ConnectionLost {
            reason: ConnectionError::ConnectionClosed(err)
        }) if err.error_code == TransportErrorCode::crypto(0x78)
    );
}

#[test]
fn alpn_mismatch() {
    let _guard = subscribe();
    let mut server_crypto = server_crypto();
    server_crypto.alpn_protocols = vec!["foo".into(), "bar".into(), "baz".into()];
    let server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
    let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config);
    let mut client_crypto = client_crypto();
    client_crypto.alpn_protocols = vec!["quux".into(), "corge".into()];
    let client_config = ClientConfig::new(Arc::new(client_crypto));

    let client_ch = pair.begin_connect(client_config);
    pair.drive();
    assert_matches!(
        pair.client_conn_mut(client_ch).poll(),
        Some(Event::ConnectionLost {
            reason: ConnectionError::ConnectionClosed(err)
        }) if err.error_code == TransportErrorCode::crypto(0x78)
    );
}

#[test]
fn stream_id_limit() {
    let _guard = subscribe();
    let server = ServerConfig {
        transport: Arc::new(TransportConfig {
            max_concurrent_uni_streams: 1u32.into(),
            ..TransportConfig::default()
        }),
        ..server_config()
    };
    let mut pair = Pair::new(Default::default(), server);
    let (client_ch, server_ch) = pair.connect();
    let s = pair
        .client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .streams()
        .open(Dir::Uni)
        .expect("couldn't open first stream");
    assert_eq!(
        pair.client_streams(client_ch).open(Dir::Uni),
        None,
        "only one stream is permitted at a time"
    );
    // Generate some activity to allow the server to see the stream
    const MSG: &[u8] = b"hello";
    pair.client_send(client_ch, s).write(MSG).unwrap();
    pair.client_send(client_ch, s).finish().unwrap();
    pair.drive();
    assert_matches!(
        pair.client_conn_mut(client_ch).poll(),
        Some(Event::Stream(StreamEvent::Finished { id })) if id == s
    );
    assert_eq!(
        pair.client_streams(client_ch).open(Dir::Uni),
        None,
        "server does not immediately grant additional credit"
    );
    assert_matches!(
        pair.server_conn_mut(server_ch).poll(),
        Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni }))
    );
    assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
    let mut recv = pair.server_recv(server_ch, s);
    let mut chunks = recv.read(false).unwrap();
    assert_matches!(
        chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_eq!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); // Server will only send MAX_STREAM_ID now that the application's been notified pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Available { dir: Dir::Uni })) ); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); // Try opening the second stream again, now that we've made room let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Uni) .expect("didn't get stream id budget"); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); // Make sure the server actually processes data on the newly-available stream assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] fn key_update_simple() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Bi) .expect("couldn't open first stream"); const MSG1: &[u8] = b"hello1"; pair.client_send(client_ch, s).write(MSG1).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Bi })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Bi), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG1 ); let _ = chunks.finalize(); info!("initiating key update"); pair.client_conn_mut(client_ch).initiate_key_update(); const MSG2: &[u8] = b"hello2"; pair.client_send(client_ch, s).write(MSG2).unwrap(); pair.drive(); assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Readable { id })) if id == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 6 && chunk.bytes == MSG2 ); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0); } #[test] fn key_update_reordered() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair .client .connections .get_mut(&client_ch) .unwrap() .streams() .open(Dir::Bi) .expect("couldn't open first stream"); const MSG1: &[u8] = b"1"; pair.client_send(client_ch, s).write(MSG1).unwrap(); pair.client.drive(pair.time, pair.server.addr); assert!(!pair.client.outbound.is_empty()); pair.client.delay_outbound(); pair.client_conn_mut(client_ch).initiate_key_update(); info!("updated keys"); const MSG2: &[u8] = b"two"; pair.client_send(client_ch, s).write(MSG2).unwrap(); pair.client.drive(pair.time, pair.server.addr); pair.client.finish_delay(); pair.drive(); 
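    // Editing note (not upstream): the first packet was written with the old keys but
    // delivered after the key update, so the server must keep the previous key
    // generation alive and select the right keys via the header's key-phase bit. The
    // assertions below confirm both messages arrive in order with no packets lost.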
assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Bi })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Bi), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); let buf1 = chunks.next(usize::MAX).unwrap().unwrap(); assert_matches!(&*buf1.bytes, MSG1); let buf2 = chunks.next(usize::MAX).unwrap().unwrap(); assert_eq!(buf2.bytes, MSG2); let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0); } #[test] fn initial_retransmit() { let _guard = subscribe(); let mut pair = Pair::default(); let client_ch = pair.begin_connect(client_config()); pair.client.drive(pair.time, pair.server.addr); pair.client.outbound.clear(); // Drop initial pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); } #[test] fn instant_close_1() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); let client_ch = pair.begin_connect(client_config()); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(0), Bytes::new()); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, .. }), }) ); } #[test] fn instant_close_2() { let _guard = subscribe(); let mut pair = Pair::default(); info!("connecting"); let client_ch = pair.begin_connect(client_config()); // Unlike `instant_close`, the server sees a valid Initial packet first. pair.drive_client(); pair.client .connections .get_mut(&client_ch) .unwrap() .close(pair.time, VarInt(42), Bytes::new()); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); let server_ch = pair.server.assert_accept(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(ConnectionClose { error_code: TransportErrorCode::APPLICATION_ERROR, .. 
}), }) ); } #[test] fn idle_timeout() { let _guard = subscribe(); const IDLE_TIMEOUT: u64 = 100; let server = ServerConfig { transport: Arc::new(TransportConfig { max_idle_timeout: Some(VarInt(IDLE_TIMEOUT)), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); pair.client_conn_mut(client_ch).ping(); let start = pair.time; while !pair.client_conn_mut(client_ch).is_closed() || !pair.server_conn_mut(server_ch).is_closed() { if !pair.step() { if let Some(t) = min_opt(pair.client.next_wakeup(), pair.server.next_wakeup()) { pair.time = t; } } pair.client.inbound.clear(); // Simulate total S->C packet loss } assert!(pair.time - start < Duration::from_millis(2 * IDLE_TIMEOUT)); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TimedOut, }) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::TimedOut, }) ); } #[test] fn connection_close_sends_acks() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _server_ch) = pair.connect(); let client_acks = pair.client_conn_mut(client_ch).stats().frame_rx.acks; pair.client_conn_mut(client_ch).ping(); pair.drive_client(); let time = pair.time; pair.server_conn_mut(client_ch) .close(time, VarInt(42), Bytes::new()); pair.drive(); let client_acks_2 = pair.client_conn_mut(client_ch).stats().frame_rx.acks; assert!( client_acks_2 > client_acks, "Connection close should send pending ACKs" ); } #[test] fn concurrent_connections_full() { let _guard = subscribe(); let mut pair = Pair::new( Default::default(), ServerConfig { concurrent_connections: 0, ..server_config() }, ); let client_ch = pair.begin_connect(client_config()); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::ConnectionLost { reason: ConnectionError::ConnectionClosed(frame::ConnectionClose { error_code: TransportErrorCode::CONNECTION_REFUSED, .. }), }) ); assert_eq!(pair.server.connections.len(), 0); assert_eq!(pair.server.known_connections(), 0); assert_eq!(pair.server.known_cids(), 0); } #[test] fn server_hs_retransmit() { let _guard = subscribe(); let mut pair = Pair::default(); let client_ch = pair.begin_connect(client_config()); pair.step(); assert!(!pair.client.inbound.is_empty()); // Initial + Handshakes pair.client.inbound.clear(); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. 
}) ); } #[test] fn migration() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); pair.client.addr = SocketAddr::new( Ipv4Addr::new(127, 0, 0, 1).into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); pair.client_conn_mut(client_ch).ping(); // Assert that just receiving the ping message is accounted into the servers // anti-amplification budget pair.drive_client(); pair.drive_server(); assert_ne!(pair.server_conn_mut(server_ch).total_recvd(), 0); pair.drive(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_eq!( pair.server_conn_mut(server_ch).remote_address(), pair.client.addr ); } fn test_flow_control(config: TransportConfig, window_size: usize) { let _guard = subscribe(); let mut pair = Pair::new( Default::default(), ServerConfig { transport: Arc::new(config), ..server_config() }, ); let (client_ch, server_ch) = pair.connect(); let msg = vec![0xAB; window_size + 10]; // Stream reset before read let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); info!("writing"); assert_eq!(pair.client_send(client_ch, s).write(&msg), Ok(window_size)); assert_eq!( pair.client_send(client_ch, s).write(&msg[window_size..]), Err(WriteError::Blocked) ); pair.drive(); info!("resetting"); pair.client_send(client_ch, s).reset(VarInt(42)).unwrap(); pair.drive(); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); assert_eq!( chunks.next(usize::MAX).err(), Some(ReadError::Reset(VarInt(42))) ); let _ = chunks.finalize(); // Happy path info!("writing"); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); assert_eq!(pair.client_send(client_ch, s).write(&msg), Ok(window_size)); assert_eq!( pair.client_send(client_ch, s).write(&msg[window_size..]), Err(WriteError::Blocked) ); pair.drive(); let mut cursor = 0; let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); loop { match chunks.next(usize::MAX) { Ok(Some(chunk)) => { cursor += chunk.bytes.len(); } Ok(None) => { panic!("end of stream"); } Err(ReadError::Blocked) => { break; } Err(e) => { panic!("{}", e); } } } let _ = chunks.finalize(); info!("finished reading"); assert_eq!(cursor, window_size); pair.drive(); info!("writing"); assert_eq!(pair.client_send(client_ch, s).write(&msg), Ok(window_size)); assert_eq!( pair.client_send(client_ch, s).write(&msg[window_size..]), Err(WriteError::Blocked) ); pair.drive(); let mut cursor = 0; let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(true).unwrap(); loop { match chunks.next(usize::MAX) { Ok(Some(chunk)) => { cursor += chunk.bytes.len(); } Ok(None) => { panic!("end of stream"); } Err(ReadError::Blocked) => { break; } Err(e) => { panic!("{}", e); } } } assert_eq!(cursor, window_size); let _ = chunks.finalize(); info!("finished reading"); } #[test] fn stream_flow_control() { test_flow_control( TransportConfig { stream_receive_window: 2000u32.into(), ..TransportConfig::default() }, 2000, ); } #[test] fn conn_flow_control() { test_flow_control( TransportConfig { receive_window: 2000u32.into(), ..TransportConfig::default() }, 2000, ); } #[test] fn stop_opens_bidi() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); assert_eq!(pair.client_streams(client_ch).send_streams(), 0); let s = pair.client_streams(client_ch).open(Dir::Bi).unwrap(); assert_eq!(pair.client_streams(client_ch).send_streams(), 1); const ERROR: VarInt = VarInt(42); pair.client .connections .get_mut(&server_ch) 
        .unwrap()
        .recv_stream(s)
        .stop(ERROR)
        .unwrap();
    pair.drive();
    assert_matches!(
        pair.server_conn_mut(server_ch).poll(),
        Some(Event::Stream(StreamEvent::Opened { dir: Dir::Bi }))
    );
    assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 0);
    assert_matches!(pair.server_streams(server_ch).accept(Dir::Bi), Some(stream) if stream == s);
    assert_eq!(pair.server_conn_mut(client_ch).streams().send_streams(), 1);
    let mut recv = pair.server_recv(server_ch, s);
    let mut chunks = recv.read(false).unwrap();
    assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked));
    let _ = chunks.finalize();
    assert_matches!(
        pair.server_send(server_ch, s).write(b"foo"),
        Err(WriteError::Stopped(ERROR))
    );
    assert_matches!(
        pair.server_conn_mut(server_ch).poll(),
        Some(Event::Stream(StreamEvent::Stopped { id: _, error_code: ERROR }))
    );
    assert_matches!(pair.server_conn_mut(server_ch).poll(), None);
}

#[test]
fn implicit_open() {
    let _guard = subscribe();
    let mut pair = Pair::default();
    let (client_ch, server_ch) = pair.connect();
    let s1 = pair.client_streams(client_ch).open(Dir::Uni).unwrap();
    let s2 = pair.client_streams(client_ch).open(Dir::Uni).unwrap();
    pair.client_send(client_ch, s2).write(b"hello").unwrap();
    pair.drive();
    assert_matches!(
        pair.server_conn_mut(server_ch).poll(),
        Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni }))
    );
    assert_eq!(pair.server_streams(server_ch).accept(Dir::Uni), Some(s1));
    assert_eq!(pair.server_streams(server_ch).accept(Dir::Uni), Some(s2));
    assert_eq!(pair.server_streams(server_ch).accept(Dir::Uni), None);
}

#[test]
fn zero_length_cid() {
    let _guard = subscribe();
    let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> =
        || Box::new(RandomConnectionIdGenerator::new(0));
    let mut pair = Pair::new(
        Arc::new(EndpointConfig {
            connection_id_generator_factory: Arc::new(cid_generator_factory),
            ..EndpointConfig::default()
        }),
        server_config(),
    );
    let (client_ch, server_ch) = pair.connect();
    // Ensure we can reconnect after a previous connection is cleaned up
    info!("closing");
    pair.client
        .connections
        .get_mut(&client_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.drive();
    pair.server
        .connections
        .get_mut(&server_ch)
        .unwrap()
        .close(pair.time, VarInt(42), Bytes::new());
    pair.connect();
}

#[test]
fn keep_alive() {
    let _guard = subscribe();
    const IDLE_TIMEOUT: u64 = 10;
    let server = ServerConfig {
        transport: Arc::new(TransportConfig {
            keep_alive_interval: Some(Duration::from_millis(IDLE_TIMEOUT / 2)),
            max_idle_timeout: Some(VarInt(IDLE_TIMEOUT)),
            ..TransportConfig::default()
        }),
        ..server_config()
    };
    let mut pair = Pair::new(Default::default(), server);
    let (client_ch, server_ch) = pair.connect();
    // Run a good while longer than the idle timeout
    let end = pair.time + Duration::from_millis(20 * IDLE_TIMEOUT);
    while pair.time < end {
        if !pair.step() {
            if let Some(time) = min_opt(pair.client.next_wakeup(), pair.server.next_wakeup()) {
                pair.time = time;
            }
        }
        assert!(!pair.client_conn_mut(client_ch).is_closed());
        assert!(!pair.server_conn_mut(server_ch).is_closed());
    }
}

#[test]
fn cid_rotation() {
    let _guard = subscribe();
    const CID_TIMEOUT: Duration = Duration::from_secs(2);
    let cid_generator_factory: fn() -> Box<dyn ConnectionIdGenerator> =
        || Box::new(*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT));
    // Only test cid rotation on server side to have a clear output trace
    let server = Endpoint::new(
        Arc::new(EndpointConfig {
            connection_id_generator_factory: Arc::new(cid_generator_factory),
            ..EndpointConfig::default()
        }),
        Some(Arc::new(server_config())),
        true,
    );
    let client =
Endpoint::new(Arc::new(EndpointConfig::default()), None, true); let mut pair = Pair::new_from_endpoint(client, server); let (_, server_ch) = pair.connect(); let mut round: u64 = 1; let mut stop = pair.time; let end = pair.time + 5 * CID_TIMEOUT; use crate::cid_queue::CidQueue; use crate::LOC_CID_COUNT; let mut active_cid_num = CidQueue::LEN as u64 + 1; active_cid_num = active_cid_num.min(LOC_CID_COUNT); let mut left_bound = 0; let mut right_bound = active_cid_num - 1; while pair.time < end { stop += CID_TIMEOUT; // Run a while until PushNewCID timer fires while pair.time < stop { if !pair.step() { if let Some(time) = min_opt(pair.client.next_wakeup(), pair.server.next_wakeup()) { pair.time = time; } } } info!( "Checking active cid sequence range before {:?} seconds", round * CID_TIMEOUT.as_secs() ); let _bound = (left_bound, right_bound); assert_matches!( pair.server_conn_mut(server_ch).active_local_cid_seq(), _bound ); round += 1; left_bound += active_cid_num; right_bound += active_cid_num; pair.drive_server(); } } #[test] fn cid_retirement() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); // Server retires current active remote CIDs pair.server_conn_mut(server_ch) .rotate_local_cid(1, Instant::now()); pair.drive(); // Any unexpected behavior may trigger TransportError::CONNECTION_ID_LIMIT_ERROR assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); assert_matches!(pair.client_conn_mut(client_ch).active_rem_cid_seq(), 1); use crate::cid_queue::CidQueue; use crate::LOC_CID_COUNT; let mut active_cid_num = CidQueue::LEN as u64; active_cid_num = active_cid_num.min(LOC_CID_COUNT); let next_retire_prior_to = active_cid_num + 1; pair.client_conn_mut(client_ch).ping(); // Server retires all valid remote CIDs pair.server_conn_mut(server_ch) .rotate_local_cid(next_retire_prior_to, Instant::now()); pair.drive(); assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(server_ch).is_closed()); assert_matches!( pair.client_conn_mut(client_ch).active_rem_cid_seq(), _next_retire_prior_to ); } #[test] fn finish_stream_flow_control_reordered() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive_client(); // Send stream data pair.server.drive(pair.time, pair.client.addr); // Receive // Issue flow control credit let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); let _ = chunks.finalize(); pair.server.drive(pair.time, pair.client.addr); pair.server.delay_outbound(); // Delay it pair.client_send(client_ch, s).finish().unwrap(); pair.drive_client(); // Send FIN pair.server.drive(pair.time, pair.client.addr); // Acknowledge pair.server.finish_delay(); // Add flow control packets after pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = 
pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] fn handshake_1rtt_handling() { let _guard = subscribe(); let mut pair = Pair::default(); let client_ch = pair.begin_connect(client_config()); pair.drive_client(); pair.drive_server(); let server_ch = pair.server.assert_accept(); // Server now has 1-RTT keys, but remains in Handshake state until the TLS CFIN has // authenticated the client. Delay the final client handshake flight so that doesn't happen yet. pair.client.drive(pair.time, pair.server.addr); pair.client.delay_outbound(); // Send some 1-RTT data which will be received first. let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.client_send(client_ch, s).finish().unwrap(); pair.client.drive(pair.time, pair.server.addr); // Add the handshake flight back on. pair.client.finish_delay(); pair.drive(); assert!(pair.client_conn_mut(client_ch).lost_packets() != 0); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); let _ = chunks.finalize(); } #[test] fn stop_before_finish() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); info!("stopping stream"); const ERROR: VarInt = VarInt(42); pair.server_recv(server_ch, s).stop(ERROR).unwrap(); pair.drive(); assert_matches!( pair.client_send(client_ch, s).finish(), Err(FinishError::Stopped(ERROR)) ); } #[test] fn stop_during_finish() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive(); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); info!("stopping and finishing stream"); const ERROR: VarInt = VarInt(42); pair.server_recv(server_ch, s).stop(ERROR).unwrap(); pair.drive_server(); pair.client_send(client_ch, s).finish().unwrap(); pair.drive_client(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Stopped { id, error_code: ERROR })) if id == s ); } // Ensure we can recover from loss of tail packets when the congestion window is full #[test] fn congested_tail_loss() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); const TARGET: u64 = 2048; assert!(pair.client_conn_mut(client_ch).congestion_window() > TARGET); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); // Send data without receiving ACKs until the congestion state falls below target while pair.client_conn_mut(client_ch).congestion_window() > TARGET { let n = pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); assert_eq!(n, 1024); pair.drive_client(); } assert!(!pair.server.inbound.is_empty()); pair.server.inbound.clear(); // Ensure that the congestion state recovers after retransmits occur and are ACKed info!("recovering"); pair.drive(); assert!(pair.client_conn_mut(client_ch).congestion_window() > TARGET); pair.client_send(client_ch, s).write(&[42; 1024]).unwrap(); } #[test] 
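// A hedged usage sketch (not part of the original suite) of the unreliable
// datagram API exercised by the next tests, using this module's helpers:
//
//     let mut pair = Pair::default();
//     let (client_ch, server_ch) = pair.connect();
//     // `max_size()` is `Some(_)` only when the peer advertised support
//     assert!(pair.client_datagrams(client_ch).max_size().is_some());
//     pair.client_datagrams(client_ch).send(Bytes::from_static(b"ping")).unwrap();
//     pair.drive();
//     assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), &b"ping"[..]);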
fn datagram_send_recv() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.client_datagrams(client_ch).max_size(), Some(x) if x > 0); const DATA: &[u8] = b"whee"; pair.client_datagrams(client_ch).send(DATA.into()).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::DatagramReceived) ); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA); assert_matches!(pair.server_datagrams(server_ch).recv(), None); } #[test] fn datagram_recv_buffer_overflow() { let _guard = subscribe(); const WINDOW: usize = 100; let server = ServerConfig { transport: Arc::new(TransportConfig { datagram_receive_buffer_size: Some(WINDOW), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_eq!( pair.client_conn_mut(client_ch).datagrams().max_size(), Some(WINDOW - Datagram::SIZE_BOUND) ); const DATA1: &[u8] = &[0xAB; (WINDOW / 3) + 1]; const DATA2: &[u8] = &[0xBC; (WINDOW / 3) + 1]; const DATA3: &[u8] = &[0xCD; (WINDOW / 3) + 1]; pair.client_datagrams(client_ch).send(DATA1.into()).unwrap(); pair.client_datagrams(client_ch).send(DATA2.into()).unwrap(); pair.client_datagrams(client_ch).send(DATA3.into()).unwrap(); pair.drive(); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::DatagramReceived) ); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA2); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA3); assert_matches!(pair.server_datagrams(server_ch).recv(), None); pair.client_datagrams(client_ch).send(DATA1.into()).unwrap(); pair.drive(); assert_eq!(pair.server_datagrams(server_ch).recv().unwrap(), DATA1); assert_matches!(pair.server_datagrams(server_ch).recv(), None); } #[test] fn datagram_unsupported() { let _guard = subscribe(); let server = ServerConfig { transport: Arc::new(TransportConfig { datagram_receive_buffer_size: None, ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.client_datagrams(client_ch).max_size(), None); match pair.client_datagrams(client_ch).send(Bytes::new()) { Err(SendDatagramError::UnsupportedByPeer) => {} Err(e) => panic!("unexpected error: {e}"), Ok(_) => panic!("unexpected success"), } } #[test] fn large_initial() { let _guard = subscribe(); let mut server_crypto = server_crypto(); server_crypto.alpn_protocols = vec![vec![0, 0, 0, 42]]; let server_config = ServerConfig::with_crypto(Arc::new(server_crypto)); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let mut client_crypto = client_crypto(); let protocols = (0..1000u32) .map(|x| x.to_be_bytes().to_vec()) .collect::>(); client_crypto.alpn_protocols = protocols; let cfg = ClientConfig::new(Arc::new(client_crypto)); let client_ch = pair.begin_connect(cfg); pair.drive(); let server_ch = pair.server.assert_accept(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. 
}) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Connected { .. }) ); } #[test] /// Ensure that we don't yield a finish event before the actual FIN is acked so the peer isn't left /// hanging fn finish_acked() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); info!("client sends data to server"); pair.drive_client(); // send data to server info!("server acknowledges data"); pair.drive_server(); // process data and send data ack // Receive data assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked)); let _ = chunks.finalize(); // Finish before receiving data ack pair.client_send(client_ch, s).finish().unwrap(); // Send FIN, receive data ack info!("client receives ACK, sends FIN"); pair.drive_client(); // Check for premature finish from data ack assert_matches!(pair.client_conn_mut(client_ch).poll(), None); // Process FIN ack info!("server ACKs FIN"); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } #[test] /// Ensure that we don't yield a finish event while there's still unacknowledged data fn finish_retransmit() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); const MSG: &[u8] = b"hello"; pair.client_send(client_ch, s).write(MSG).unwrap(); pair.drive_client(); // send data to server pair.server.inbound.clear(); // Lose it // Send FIN pair.client_send(client_ch, s).finish().unwrap(); pair.drive_client(); // Process FIN pair.drive_server(); // Receive FIN ack, but no data ack pair.drive_client(); // Check for premature finish from FIN ack assert_matches!(pair.client_conn_mut(client_ch).poll(), None); // Recover pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); assert_matches!( pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_streams(server_ch).accept(Dir::Uni), Some(stream) if stream == s); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } /// Ensures that exchanging data on a client-initiated bidirectional stream works past the initial /// stream window. 
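// The test below leans on the server limiting the client to one concurrent
// bidirectional stream, so each request/response round must wait for the
// previous stream's credit to be recycled. A minimal sketch of that knob
// (the same struct-update syntax used throughout this module):
//
//     let server = ServerConfig {
//         transport: Arc::new(TransportConfig {
//             max_concurrent_bidi_streams: 1u32.into(),
//             ..TransportConfig::default()
//         }),
//         ..server_config()
//     };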
#[test] fn repeated_request_response() { let _guard = subscribe(); let server = ServerConfig { transport: Arc::new(TransportConfig { max_concurrent_bidi_streams: 1u32.into(), ..TransportConfig::default() }), ..server_config() }; let mut pair = Pair::new(Default::default(), server); let (client_ch, server_ch) = pair.connect(); const REQUEST: &[u8] = b"hello"; const RESPONSE: &[u8] = b"world"; for _ in 0..3 { let s = pair.client_streams(client_ch).open(Dir::Bi).unwrap(); pair.client_send(client_ch, s).write(REQUEST).unwrap(); pair.client_send(client_ch, s).finish().unwrap(); pair.drive(); assert_eq!(pair.server_streams(server_ch).accept(Dir::Bi), Some(s)); let mut recv = pair.server_recv(server_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == REQUEST ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); pair.server_send(server_ch, s).write(RESPONSE).unwrap(); pair.server_send(server_ch, s).finish().unwrap(); pair.drive(); let mut recv = pair.client_recv(client_ch, s); let mut chunks = recv.read(false).unwrap(); assert_matches!( chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == RESPONSE ); assert_matches!(chunks.next(usize::MAX), Ok(None)); let _ = chunks.finalize(); } } /// Ensures that the client sends an anti-deadlock probe after receiving an incomplete first flight from the server #[test] fn handshake_anti_deadlock_probe() { let _guard = subscribe(); let (cert, key) = big_cert_and_key(); let server = server_config_with_cert(cert.clone(), key); let client = client_config_with_certs(vec![cert]); let mut pair = Pair::new(Default::default(), server); let client_ch = pair.begin_connect(client); // Client sends initial pair.drive_client(); // Server sends first flight, gets blocked on anti-amplification pair.drive_server(); // Client acks... pair.drive_client(); // ...but it's lost, so the server doesn't get anti-amplification credit from it pair.server.inbound.clear(); // Client sends an anti-deadlock probe, and the handshake completes as usual. pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); } /// Ensures that the server can respond with 3 initial packets during the handshake /// before the anti-amplification limit kicks in when MTUs are similar. #[test] fn server_can_send_3_initial_packets() { let _guard = subscribe(); let (cert, key) = big_cert_and_key(); let server = server_config_with_cert(cert.clone(), key); let client = client_config_with_certs(vec![cert]); let mut pair = Pair::new(Default::default(), server); let client_ch = pair.begin_connect(client); // Client sends initial pair.drive_client(); // Server sends first flight, gets blocked on anti-amplification pair.drive_server(); // Server should have queued 3 packets at this time assert_eq!(pair.client.inbound.len(), 3); pair.drive(); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( pair.client_conn_mut(client_ch).poll(), Some(Event::Connected { ..
}) ); } /// Generate a big fat certificate that can't fit inside the initial anti-amplification limit fn big_cert_and_key() -> (rustls::Certificate, rustls::PrivateKey) { let cert = rcgen::generate_simple_self_signed( Some("localhost".into()) .into_iter() .chain((0..1000).map(|x| format!("foo_{x}"))) .collect::<Vec<_>>(), ) .unwrap(); let key = rustls::PrivateKey(cert.serialize_private_key_der()); let cert = rustls::Certificate(cert.serialize_der().unwrap()); (cert, key) } #[test] fn malformed_token_len() { let _guard = subscribe(); let client_addr = "[::2]:7890".parse().unwrap(); let mut server = Endpoint::new(Default::default(), Some(Arc::new(server_config())), true); server.handle( Instant::now(), client_addr, None, None, hex!("8900 0000 0101 0000 1b1b 841b 0000 0000 3f00")[..].into(), ); } #[test] /// This is mostly a sanity check to ensure our testing code is correctly dropping packets above the /// pmtu fn connect_too_low_mtu() { let _guard = subscribe(); let mut pair = Pair::default(); // The maximum payload size is lower than 1200, so no packets will get through! pair.mtu = 1000; pair.begin_connect(client_config()); pair.drive(); pair.server.assert_no_accept() } #[test] fn connect_lost_mtu_probes_do_not_trigger_congestion_control() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); pair.drive(); let client_stats = pair.client_conn_mut(client_ch).stats(); let server_stats = pair.server_conn_mut(server_ch).stats(); // Sanity check (all MTU probes should have been lost) assert_eq!(client_stats.path.sent_plpmtud_probes, 9); assert_eq!(client_stats.path.lost_plpmtud_probes, 9); assert_eq!(server_stats.path.sent_plpmtud_probes, 9); assert_eq!(server_stats.path.lost_plpmtud_probes, 9); // No congestion events assert_eq!(client_stats.path.congestion_events, 0); assert_eq!(server_stats.path.congestion_events, 0); } #[test] fn connect_detects_mtu() { let _guard = subscribe(); let max_udp_payload_and_expected_mtu = &[(1200, 1200), (1400, 1389), (1500, 1452)]; for &(pair_max_udp, expected_mtu) in max_udp_payload_and_expected_mtu { println!("Trying {pair_max_udp}"); let mut pair = Pair::default(); pair.mtu = pair_max_udp; let (client_ch, server_ch) = pair.connect(); pair.drive(); assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), expected_mtu); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), expected_mtu); } } #[test] fn migrate_detects_new_mtu_and_respects_original_peer_max_udp_payload_size() { let _guard = subscribe(); let client_max_udp_payload_size: u16 = 1400; // Set up a client with a max payload size of 1400 (and use the defaults for the server) let server_endpoint_config = EndpointConfig::default(); let server = Endpoint::new( Arc::new(server_endpoint_config), Some(Arc::new(server_config())), true, ); let client_endpoint_config = EndpointConfig { max_udp_payload_size: VarInt::from(client_max_udp_payload_size), ..EndpointConfig::default() }; let client = Endpoint::new(Arc::new(client_endpoint_config), None, true); let mut pair = Pair::new_from_endpoint(client, server); pair.mtu = 1300; // Connect let (client_ch, server_ch) = pair.connect(); pair.drive(); // Sanity check: MTUD ran to completion (the numbers differ because binary search stops when // changes are smaller than 20, otherwise both endpoints would converge at the same MTU of 1300) assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1293); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1300); // Migrate client to a different port (and simulate a higher path
MTU) pair.mtu = 1500; pair.client.addr = SocketAddr::new( Ipv4Addr::new(127, 0, 0, 1).into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); pair.client_conn_mut(client_ch).ping(); pair.drive(); // Sanity check: the server saw that the client address was updated assert_eq!( pair.server_conn_mut(server_ch).remote_address(), pair.client.addr ); // MTU detection has successfully run after migrating assert_eq!( pair.server_conn_mut(server_ch).path_mtu(), client_max_udp_payload_size ); // Sanity check: the client keeps the old MTU, because migration is triggered by incoming // packets from a different address assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1293); } #[test] fn connect_runs_mtud_again_after_600_seconds() { let _guard = subscribe(); let mut server_config = server_config(); let mut client_config = client_config(); // Note: we use an infinite idle timeout to ensure we can wait 600 seconds without the // connection closing Arc::get_mut(&mut server_config.transport) .unwrap() .max_idle_timeout(None); Arc::get_mut(&mut client_config.transport) .unwrap() .max_idle_timeout(None); let mut pair = Pair::new(Default::default(), server_config); pair.mtu = 1400; let (client_ch, server_ch) = pair.connect_with(client_config); pair.drive(); // Sanity check: the mtu has been discovered let client_conn = pair.client_conn_mut(client_ch); assert_eq!(client_conn.path_mtu(), 1389); assert_eq!(client_conn.stats().path.sent_plpmtud_probes, 5); assert_eq!(client_conn.stats().path.lost_plpmtud_probes, 3); let server_conn = pair.server_conn_mut(server_ch); assert_eq!(server_conn.path_mtu(), 1389); assert_eq!(server_conn.stats().path.sent_plpmtud_probes, 5); assert_eq!(server_conn.stats().path.lost_plpmtud_probes, 3); // Sanity check: the mtu does not change after the fact, even though the link now supports a // higher udp payload size pair.mtu = 1500; pair.drive(); assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1389); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1389); // The MTU changes after 600 seconds, because now MTUD runs for the second time pair.time += Duration::from_secs(600); pair.drive(); assert!(!pair.client_conn_mut(client_ch).is_closed()); assert!(!pair.server_conn_mut(client_ch).is_closed()); assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1452); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1452); } #[test] fn packet_loss_and_retry_too_low_mtu() { let _guard = subscribe(); let mut pair = Pair::default(); let (client_ch, server_ch) = pair.connect(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(b"hello").unwrap(); pair.drive(); // Nothing will get past this mtu pair.mtu = 10; pair.client_send(client_ch, s).write(b" world").unwrap(); pair.drive_client(); // The packet was dropped assert!(pair.client.outbound.is_empty()); assert!(pair.server.inbound.is_empty()); // Restore the default mtu, so future packets are properly transmitted pair.mtu = DEFAULT_MTU; // The lost packet is resent pair.drive(); assert!(pair.client.outbound.is_empty()); let recv = pair.server_recv(server_ch, s); let buf = stream_chunks(recv); assert_eq!(buf, b"hello world".as_slice()); } #[test] fn blackhole_after_mtu_change_repairs_itself() { let _guard = subscribe(); let mut pair = Pair::default(); pair.mtu = 1500; let (client_ch, server_ch) = pair.connect(); pair.drive(); // Sanity check assert_eq!(pair.client_conn_mut(client_ch).path_mtu(), 1452); assert_eq!(pair.server_conn_mut(server_ch).path_mtu(), 1452); // Back to the 
base MTU pair.mtu = 1200; // The payload will be sent in a single packet, because the detected MTU was 1452, but it will // be dropped because the link no longer supports that packet size! let payload = vec![42; 1300]; let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(&payload).unwrap(); let out_of_bounds = pair.drive_bounded(); if out_of_bounds { panic!("Connections never reached an idle state"); } let recv = pair.server_recv(server_ch, s); let buf = stream_chunks(recv); // The whole packet arrived in the end assert_eq!(buf.len(), 1300); // Sanity checks (black hole detected after 3 lost packets) let client_stats = pair.client_conn_mut(client_ch).stats(); assert!(client_stats.path.lost_packets >= 3); assert!(client_stats.path.congestion_events >= 3); assert_eq!(client_stats.path.black_holes_detected, 1); } #[test] fn packet_splitting_with_default_mtu() { let _guard = subscribe(); // The payload needs to be split in 2 in order to be sent, because it is higher than the max MTU let payload = vec![42; 1300]; let mut pair = Pair::default(); let (client_ch, _) = pair.connect(); pair.drive(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(&payload).unwrap(); pair.client.drive(pair.time, pair.server.addr); assert_eq!(pair.client.outbound.len(), 2); pair.drive_client(); assert_eq!(pair.server.inbound.len(), 2); } #[test] fn packet_splitting_not_necessary_after_higher_mtu_discovered() { let _guard = subscribe(); let payload = vec![42; 1300]; let mut pair = Pair::default(); pair.mtu = 1500; let (client_ch, _) = pair.connect(); pair.drive(); let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); pair.client_send(client_ch, s).write(&payload).unwrap(); pair.client.drive(pair.time, pair.server.addr); assert_eq!(pair.client.outbound.len(), 1); pair.drive_client(); assert_eq!(pair.server.inbound.len(), 1); } fn stream_chunks(mut recv: RecvStream) -> Vec<u8> { let mut buf = Vec::new(); let mut chunks = recv.read(true).unwrap(); while let Ok(Some(chunk)) = chunks.next(usize::MAX) { buf.extend(chunk.bytes); } let _ = chunks.finalize(); buf } #[test] fn reject_new_connections() { let _guard = subscribe(); let mut pair = Pair::default(); pair.server.reject_new_connections(); // The server should now reject incoming connections.
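// Hedged sketch of a stronger check that could be made on the refused
// connection created below (modulo any earlier events still queued):
//
//     assert_matches!(
//         pair.client_conn_mut(client_ch).poll(),
//         Some(Event::ConnectionLost { .. })
//     );
//
// The assertion below settles for `is_closed()`.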
let client_ch = pair.begin_connect(client_config()); pair.drive(); pair.server.assert_no_accept(); assert!(pair.client.connections.get(&client_ch).unwrap().is_closed()); } quinn-proto-0.10.6/src/tests/util.rs000064400000000000000000000406421046102023000154650ustar 00000000000000use std::{ cmp, collections::{HashMap, VecDeque}, env, io::{self, Write}, mem, net::{Ipv6Addr, SocketAddr, UdpSocket}, ops::RangeFrom, str, sync::{Arc, Mutex}, time::{Duration, Instant}, }; use assert_matches::assert_matches; use bytes::BytesMut; use lazy_static::lazy_static; use rustls::{Certificate, KeyLogFile, PrivateKey}; use tracing::{info_span, trace}; use super::*; pub(super) const DEFAULT_MTU: usize = 1200; pub(super) struct Pair { pub(super) server: TestEndpoint, pub(super) client: TestEndpoint, pub(super) time: Instant, /// Simulates the maximum size allowed for UDP payloads by the link (packets exceeding this size will be dropped) pub(super) mtu: usize, // One-way pub(super) latency: Duration, /// Number of spin bit flips pub(super) spins: u64, last_spin: bool, } impl Pair { pub(super) fn new(endpoint_config: Arc, server_config: ServerConfig) -> Self { let server = Endpoint::new(endpoint_config.clone(), Some(Arc::new(server_config)), true); let client = Endpoint::new(endpoint_config, None, true); Self::new_from_endpoint(client, server) } pub(super) fn new_from_endpoint(client: Endpoint, server: Endpoint) -> Self { let server_addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), SERVER_PORTS.lock().unwrap().next().unwrap(), ); let client_addr = SocketAddr::new( Ipv6Addr::LOCALHOST.into(), CLIENT_PORTS.lock().unwrap().next().unwrap(), ); Self { server: TestEndpoint::new(server, server_addr), client: TestEndpoint::new(client, client_addr), time: Instant::now(), mtu: DEFAULT_MTU, latency: Duration::new(0, 0), spins: 0, last_spin: false, } } /// Returns whether the connection is not idle pub(super) fn step(&mut self) -> bool { self.drive_client(); self.drive_server(); if self.client.is_idle() && self.server.is_idle() { return false; } let client_t = self.client.next_wakeup(); let server_t = self.server.next_wakeup(); match min_opt(client_t, server_t) { Some(t) if Some(t) == client_t => { if t != self.time { self.time = self.time.max(t); trace!("advancing to {:?} for client", self.time); } true } Some(t) if Some(t) == server_t => { if t != self.time { self.time = self.time.max(t); trace!("advancing to {:?} for server", self.time); } true } Some(_) => unreachable!(), None => false, } } /// Advance time until both connections are idle pub(super) fn drive(&mut self) { while self.step() {} } /// Advance time until both connections are idle, or after 100 steps have been executed /// /// Returns true if the amount of steps exceeds the bounds, because the connections never became /// idle pub(super) fn drive_bounded(&mut self) -> bool { for _ in 0..100 { if !self.step() { return false; } } true } pub(super) fn drive_client(&mut self) { let span = info_span!("client"); let _guard = span.enter(); self.client.drive(self.time, self.server.addr); for x in self.client.outbound.drain(..) 
{ if packet_size(&x) > self.mtu { info!( packet_size = packet_size(&x), "dropping packet (max size exceeded)" ); continue; } if x.contents[0] & packet::LONG_HEADER_FORM == 0 { let spin = x.contents[0] & packet::SPIN_BIT != 0; self.spins += (spin == self.last_spin) as u64; self.last_spin = spin; } if let Some(ref socket) = self.client.socket { socket.send_to(&x.contents, x.destination).unwrap(); } if self.server.addr == x.destination { self.server.inbound.push_back(( self.time + self.latency, x.ecn, x.contents.as_ref().into(), )); } } } pub(super) fn drive_server(&mut self) { let span = info_span!("server"); let _guard = span.enter(); self.server.drive(self.time, self.client.addr); for x in self.server.outbound.drain(..) { if packet_size(&x) > self.mtu { info!( packet_size = packet_size(&x), "dropping packet (max size exceeded)" ); continue; } if let Some(ref socket) = self.server.socket { socket.send_to(&x.contents, x.destination).unwrap(); } if self.client.addr == x.destination { self.client.inbound.push_back(( self.time + self.latency, x.ecn, x.contents.as_ref().into(), )); } } } pub(super) fn connect(&mut self) -> (ConnectionHandle, ConnectionHandle) { self.connect_with(client_config()) } pub(super) fn connect_with( &mut self, config: ClientConfig, ) -> (ConnectionHandle, ConnectionHandle) { info!("connecting"); let client_ch = self.begin_connect(config); self.drive(); let server_ch = self.server.assert_accept(); self.finish_connect(client_ch, server_ch); (client_ch, server_ch) } /// Just start connecting the client pub(super) fn begin_connect(&mut self, config: ClientConfig) -> ConnectionHandle { let span = info_span!("client"); let _guard = span.enter(); let (client_ch, client_conn) = self .client .connect(config, self.server.addr, "localhost") .unwrap(); self.client.connections.insert(client_ch, client_conn); client_ch } fn finish_connect(&mut self, client_ch: ConnectionHandle, server_ch: ConnectionHandle) { assert_matches!( self.client_conn_mut(client_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( self.client_conn_mut(client_ch).poll(), Some(Event::Connected { .. }) ); assert_matches!( self.server_conn_mut(server_ch).poll(), Some(Event::HandshakeDataReady) ); assert_matches!( self.server_conn_mut(server_ch).poll(), Some(Event::Connected { .. 
}) ); } pub(super) fn client_conn_mut(&mut self, ch: ConnectionHandle) -> &mut Connection { self.client.connections.get_mut(&ch).unwrap() } pub(super) fn client_streams(&mut self, ch: ConnectionHandle) -> Streams<'_> { self.client_conn_mut(ch).streams() } pub(super) fn client_send(&mut self, ch: ConnectionHandle, s: StreamId) -> SendStream<'_> { self.client_conn_mut(ch).send_stream(s) } pub(super) fn client_recv(&mut self, ch: ConnectionHandle, s: StreamId) -> RecvStream<'_> { self.client_conn_mut(ch).recv_stream(s) } pub(super) fn client_datagrams(&mut self, ch: ConnectionHandle) -> Datagrams<'_> { self.client_conn_mut(ch).datagrams() } pub(super) fn server_conn_mut(&mut self, ch: ConnectionHandle) -> &mut Connection { self.server.connections.get_mut(&ch).unwrap() } pub(super) fn server_streams(&mut self, ch: ConnectionHandle) -> Streams<'_> { self.server_conn_mut(ch).streams() } pub(super) fn server_send(&mut self, ch: ConnectionHandle, s: StreamId) -> SendStream<'_> { self.server_conn_mut(ch).send_stream(s) } pub(super) fn server_recv(&mut self, ch: ConnectionHandle, s: StreamId) -> RecvStream<'_> { self.server_conn_mut(ch).recv_stream(s) } pub(super) fn server_datagrams(&mut self, ch: ConnectionHandle) -> Datagrams<'_> { self.server_conn_mut(ch).datagrams() } } impl Default for Pair { fn default() -> Self { Self::new(Default::default(), server_config()) } } pub(super) struct TestEndpoint { pub(super) endpoint: Endpoint, pub(super) addr: SocketAddr, socket: Option<UdpSocket>, timeout: Option<Instant>, pub(super) outbound: VecDeque<Transmit>, delayed: VecDeque<Transmit>, pub(super) inbound: VecDeque<(Instant, Option<EcnCodepoint>, BytesMut)>, accepted: Option<ConnectionHandle>, pub(super) connections: HashMap<ConnectionHandle, Connection>, conn_events: HashMap<ConnectionHandle, VecDeque<ConnectionEvent>>, } impl TestEndpoint { fn new(endpoint: Endpoint, addr: SocketAddr) -> Self { let socket = if env::var_os("SSLKEYLOGFILE").is_some() { let socket = UdpSocket::bind(addr).expect("failed to bind UDP socket"); socket .set_read_timeout(Some(Duration::new(0, 10_000_000))) .unwrap(); Some(socket) } else { None }; Self { endpoint, addr, socket, timeout: None, outbound: VecDeque::new(), delayed: VecDeque::new(), inbound: VecDeque::new(), accepted: None, connections: HashMap::default(), conn_events: HashMap::default(), } } pub(super) fn drive(&mut self, now: Instant, remote: SocketAddr) { if let Some(ref socket) = self.socket { loop { let mut buf = [0; 8192]; if socket.recv_from(&mut buf).is_err() { break; } } } while self.inbound.front().map_or(false, |x| x.0 <= now) { let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap(); if let Some((ch, event)) = self.endpoint.handle(recv_time, remote, None, ecn, packet) { match event { DatagramEvent::NewConnection(conn) => { self.connections.insert(ch, conn); self.accepted = Some(ch); } DatagramEvent::ConnectionEvent(event) => { self.conn_events.entry(ch).or_default().push_back(event); } } } } while let Some(x) = self.poll_transmit() { self.outbound.extend(split_transmit(x)); } loop { let mut endpoint_events: Vec<(ConnectionHandle, EndpointEvent)> = vec![]; for (ch, conn) in self.connections.iter_mut() { if self.timeout.map_or(false, |x| x <= now) { self.timeout = None; conn.handle_timeout(now); } for (_, mut events) in self.conn_events.drain() { for event in events.drain(..)
{ conn.handle_event(event); } } while let Some(event) = conn.poll_endpoint_events() { endpoint_events.push((*ch, event)); } while let Some(x) = conn.poll_transmit(now, MAX_DATAGRAMS) { self.outbound.extend(split_transmit(x)); } self.timeout = conn.poll_timeout(); } if endpoint_events.is_empty() { break; } for (ch, event) in endpoint_events { if let Some(event) = self.handle_event(ch, event) { if let Some(conn) = self.connections.get_mut(&ch) { conn.handle_event(event); } } } } } pub(super) fn next_wakeup(&self) -> Option { let next_inbound = self.inbound.front().map(|x| x.0); min_opt(self.timeout, next_inbound) } fn is_idle(&self) -> bool { self.connections.values().all(|x| x.is_idle()) } pub(super) fn delay_outbound(&mut self) { assert!(self.delayed.is_empty()); mem::swap(&mut self.delayed, &mut self.outbound); } pub(super) fn finish_delay(&mut self) { self.outbound.extend(self.delayed.drain(..)); } pub(super) fn assert_accept(&mut self) -> ConnectionHandle { self.accepted.take().expect("server didn't connect") } pub(super) fn assert_no_accept(&self) { assert!(self.accepted.is_none(), "server did unexpectedly connect") } } impl ::std::ops::Deref for TestEndpoint { type Target = Endpoint; fn deref(&self) -> &Endpoint { &self.endpoint } } impl ::std::ops::DerefMut for TestEndpoint { fn deref_mut(&mut self) -> &mut Endpoint { &mut self.endpoint } } pub(super) fn subscribe() -> tracing::subscriber::DefaultGuard { let sub = tracing_subscriber::FmtSubscriber::builder() .with_max_level(tracing::Level::TRACE) .with_writer(|| TestWriter) .finish(); tracing::subscriber::set_default(sub) } struct TestWriter; impl Write for TestWriter { fn write(&mut self, buf: &[u8]) -> io::Result { print!( "{}", str::from_utf8(buf).expect("tried to log invalid UTF-8") ); Ok(buf.len()) } fn flush(&mut self) -> io::Result<()> { io::stdout().flush() } } pub(super) fn server_config() -> ServerConfig { ServerConfig::with_crypto(Arc::new(server_crypto())) } pub(super) fn server_config_with_cert(cert: Certificate, key: PrivateKey) -> ServerConfig { ServerConfig::with_crypto(Arc::new(server_crypto_with_cert(cert, key))) } pub(super) fn server_crypto() -> rustls::ServerConfig { let cert = Certificate(CERTIFICATE.serialize_der().unwrap()); let key = PrivateKey(CERTIFICATE.serialize_private_key_der()); server_crypto_with_cert(cert, key) } pub(super) fn server_crypto_with_cert(cert: Certificate, key: PrivateKey) -> rustls::ServerConfig { crate::crypto::rustls::server_config(vec![cert], key).unwrap() } pub(super) fn client_config() -> ClientConfig { ClientConfig::new(Arc::new(client_crypto())) } pub(super) fn client_config_with_certs(certs: Vec) -> ClientConfig { ClientConfig::new(Arc::new(client_crypto_with_certs(certs))) } pub(super) fn client_crypto() -> rustls::ClientConfig { let cert = rustls::Certificate(CERTIFICATE.serialize_der().unwrap()); client_crypto_with_certs(vec![cert]) } pub(super) fn client_crypto_with_certs(certs: Vec) -> rustls::ClientConfig { let mut roots = rustls::RootCertStore::empty(); for cert in certs { roots.add(&cert).unwrap(); } let mut config = crate::crypto::rustls::client_config(roots); config.key_log = Arc::new(KeyLogFile::new()); config } pub(super) fn min_opt(x: Option, y: Option) -> Option { match (x, y) { (Some(x), Some(y)) => Some(cmp::min(x, y)), (Some(x), _) => Some(x), (_, Some(y)) => Some(y), _ => None, } } /// The maximum of datagrams TestEndpoint will produce via `poll_transmit` const MAX_DATAGRAMS: usize = 10; fn split_transmit(mut transmit: Transmit) -> Vec { let segment_size = 
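// A `Transmit` may pack several datagrams back-to-back (GSO style), with
// `segment_size` giving the stride; unpack them so the simulated link can
// drop or deliver each packet independently. E.g. 2400 bytes of `contents`
// with `segment_size == Some(1200)` become two 1200-byte transmits, each
// with `segment_size: None`.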
match transmit.segment_size { Some(segment_size) => segment_size, _ => return vec![transmit], }; let mut transmits = Vec::new(); while !transmit.contents.is_empty() { let end = segment_size.min(transmit.contents.len()); let contents = transmit.contents.split_to(end); transmits.push(Transmit { destination: transmit.destination, ecn: transmit.ecn, contents, segment_size: None, src_ip: transmit.src_ip, }); } transmits } fn packet_size(transmit: &Transmit) -> usize { if transmit.segment_size.is_some() { panic!("This transmit is meant to be split into multiple packets!"); } transmit.contents.len() } lazy_static! { pub static ref SERVER_PORTS: Mutex> = Mutex::new(4433..); pub static ref CLIENT_PORTS: Mutex> = Mutex::new(44433..); pub(crate) static ref CERTIFICATE: rcgen::Certificate = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); } quinn-proto-0.10.6/src/token.rs000064400000000000000000000165611046102023000144710ustar 00000000000000use std::{ fmt, io, net::{IpAddr, SocketAddr}, time::{Duration, SystemTime, UNIX_EPOCH}, }; use bytes::BufMut; use crate::{ coding::{BufExt, BufMutExt}, crypto::{CryptoError, HandshakeTokenKey, HmacKey}, shared::ConnectionId, RESET_TOKEN_SIZE, }; pub(crate) struct RetryToken<'a> { /// The destination connection ID set in the very first packet from the client pub(crate) orig_dst_cid: ConnectionId, /// The time at which this token was issued pub(crate) issued: SystemTime, /// Random bytes for deriving AEAD key pub(crate) random_bytes: &'a [u8], } impl<'a> RetryToken<'a> { pub(crate) fn encode( &self, key: &dyn HandshakeTokenKey, address: &SocketAddr, retry_src_cid: &ConnectionId, ) -> Vec { let aead_key = key.aead_from_hkdf(self.random_bytes); let mut buf = Vec::new(); self.orig_dst_cid.encode_long(&mut buf); buf.write::( self.issued .duration_since(UNIX_EPOCH) .map(|x| x.as_secs()) .unwrap_or(0), ); let mut additional_data = [0u8; Self::MAX_ADDITIONAL_DATA_SIZE]; let additional_data = Self::put_additional_data(address, retry_src_cid, &mut additional_data); aead_key.seal(&mut buf, additional_data).unwrap(); let mut token = Vec::new(); token.put_slice(self.random_bytes); token.put_slice(&buf); token } pub(crate) fn from_bytes( key: &dyn HandshakeTokenKey, address: &SocketAddr, retry_src_cid: &ConnectionId, raw_token_bytes: &'a [u8], ) -> Result { if raw_token_bytes.len() < Self::RANDOM_BYTES_LEN { // Invalid length return Err(CryptoError); } let random_bytes = &raw_token_bytes[..Self::RANDOM_BYTES_LEN]; let aead_key = key.aead_from_hkdf(random_bytes); let mut sealed_token = raw_token_bytes[Self::RANDOM_BYTES_LEN..].to_vec(); let mut additional_data = [0u8; Self::MAX_ADDITIONAL_DATA_SIZE]; let additional_data = Self::put_additional_data(address, retry_src_cid, &mut additional_data); let data = aead_key.open(&mut sealed_token, additional_data)?; let mut reader = io::Cursor::new(data); let orig_dst_cid = ConnectionId::decode_long(&mut reader).ok_or(CryptoError)?; let issued = UNIX_EPOCH + Duration::new(reader.get::().map_err(|_| CryptoError)?, 0); Ok(Self { orig_dst_cid, issued, random_bytes, }) } fn put_additional_data<'b>( address: &SocketAddr, retry_src_cid: &ConnectionId, additional_data: &'b mut [u8], ) -> &'b [u8] { let mut cursor = &mut *additional_data; match address.ip() { IpAddr::V4(x) => cursor.put_slice(&x.octets()), IpAddr::V6(x) => cursor.put_slice(&x.octets()), } cursor.write(address.port()); retry_src_cid.encode_long(&mut cursor); let size = Self::MAX_ADDITIONAL_DATA_SIZE - cursor.len(); &additional_data[..size] } const 
MAX_ADDITIONAL_DATA_SIZE: usize = 39; // max(ipv4, ipv6) + port + retry_src_cid pub(crate) const RANDOM_BYTES_LEN: usize = 32; } /// Stateless reset token /// /// Used for an endpoint to securely communicate that it has lost state for a connection. #[allow(clippy::derived_hash_with_manual_eq)] // Custom PartialEq impl matches derived semantics #[derive(Debug, Copy, Clone, Hash)] pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]); impl ResetToken { pub(crate) fn new(key: &dyn HmacKey, id: &ConnectionId) -> Self { let mut signature = vec![0; key.signature_len()]; key.sign(id, &mut signature); // TODO: Server ID?? let mut result = [0; RESET_TOKEN_SIZE]; result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]); result.into() } } impl PartialEq for ResetToken { fn eq(&self, other: &Self) -> bool { crate::constant_time::eq(&self.0, &other.0) } } impl Eq for ResetToken {} impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken { fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self { Self(x) } } impl std::ops::Deref for ResetToken { type Target = [u8]; fn deref(&self) -> &[u8] { &self.0 } } impl fmt::Display for ResetToken { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for byte in self.iter() { write!(f, "{byte:02x}")?; } Ok(()) } } #[cfg(test)] mod test { #[cfg(feature = "ring")] #[test] fn token_sanity() { use super::*; use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}; use crate::MAX_CID_SIZE; use rand::RngCore; use std::{ net::Ipv6Addr, time::{Duration, UNIX_EPOCH}, }; let rng = &mut rand::thread_rng(); let mut master_key = [0; 64]; rng.fill_bytes(&mut master_key); let mut random_bytes = [0; 32]; rng.fill_bytes(&mut random_bytes); let prk = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, &[]).extract(&master_key); let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433); let retry_src_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); let token = RetryToken { orig_dst_cid: RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(), issued: UNIX_EPOCH + Duration::new(42, 0), // Fractional seconds would be lost random_bytes: &random_bytes, }; let encoded = token.encode(&prk, &addr, &retry_src_cid); let decoded = RetryToken::from_bytes(&prk, &addr, &retry_src_cid, &encoded) .expect("token didn't validate"); assert_eq!(token.orig_dst_cid, decoded.orig_dst_cid); assert_eq!(token.issued, decoded.issued); } #[cfg(feature = "ring")] #[test] fn invalid_token_returns_err() { use super::*; use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}; use crate::MAX_CID_SIZE; use rand::RngCore; use std::net::Ipv6Addr; let rng = &mut rand::thread_rng(); let mut master_key = [0; 64]; rng.fill_bytes(&mut master_key); let mut random_bytes = [0; 32]; rng.fill_bytes(&mut random_bytes); let prk = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, &[]).extract(&master_key); let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433); let retry_src_cid = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(); let mut invalid_token = Vec::new(); invalid_token.put_slice(&random_bytes); let mut random_data = [0; 32]; rand::thread_rng().fill_bytes(&mut random_data); invalid_token.put_slice(&random_data); // Assert: garbage sealed data with valid random bytes returns err assert!(RetryToken::from_bytes(&prk, &addr, &retry_src_cid, &invalid_token).is_err()); let mut invalid_token = [0; 31]; rand::thread_rng().fill_bytes(&mut invalid_token); // Assert: completely invalid retry token
returns error assert!(RetryToken::from_bytes(&prk, &addr, &retry_src_cid, &invalid_token).is_err()); } } quinn-proto-0.10.6/src/transport_error.rs000064400000000000000000000113761046102023000166150ustar 00000000000000use std::fmt; use bytes::{Buf, BufMut}; use crate::{ coding::{self, BufExt, BufMutExt}, frame, }; /// Transport-level errors occur when a peer violates the protocol specification #[derive(Debug, Clone, Eq, PartialEq)] pub struct Error { /// Type of error pub code: Code, /// Frame type that triggered the error pub frame: Option, /// Human-readable explanation of the reason pub reason: String, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.code.fmt(f)?; if let Some(frame) = self.frame { write!(f, " in {frame}")?; } if !self.reason.is_empty() { write!(f, ": {}", self.reason)?; } Ok(()) } } impl std::error::Error for Error {} impl From for Error { fn from(x: Code) -> Self { Self { code: x, frame: None, reason: "".to_string(), } } } /// Transport-level error code #[derive(Copy, Clone, Eq, PartialEq)] pub struct Code(u64); impl Code { /// Create QUIC error code from TLS alert code pub fn crypto(code: u8) -> Self { Self(0x100 | u64::from(code)) } } impl coding::Codec for Code { fn decode(buf: &mut B) -> coding::Result { Ok(Self(buf.get_var()?)) } fn encode(&self, buf: &mut B) { buf.write_var(self.0) } } impl From for u64 { fn from(x: Code) -> Self { x.0 } } macro_rules! errors { {$($name:ident($val:expr) $desc:expr;)*} => { #[allow(non_snake_case, unused)] impl Error { $( pub(crate) fn $name(reason: T) -> Self where T: Into { Self { code: Code::$name, frame: None, reason: reason.into(), } } )* } impl Code { $(#[doc = $desc] pub const $name: Self = Code($val);)* } impl fmt::Debug for Code { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str(stringify!($name)),)* x if (0x100..0x200).contains(&x) => write!(f, "Code::crypto({:02x})", self.0 as u8), _ => write!(f, "Code({:x})", self.0), } } } impl fmt::Display for Code { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { $($val => f.write_str($desc),)* // We're trying to be abstract over the crypto protocol, so human-readable descriptions here is tricky. _ if self.0 >= 0x100 && self.0 < 0x200 => write!(f, "the cryptographic handshake failed: error {}", self.0 & 0xFF), _ => f.write_str("unknown error"), } } } } } errors! 
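// For a single entry such as `KEY_UPDATE_ERROR(0xE) "key update error"`, the
// macro above expands to (sketch): a constructor `Error::KEY_UPDATE_ERROR(reason)`,
// an associated constant `Code::KEY_UPDATE_ERROR = Code(0xE)`, and matching
// `Debug`/`Display` arms for that code.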
{ NO_ERROR(0x0) "the connection is being closed abruptly in the absence of any error"; INTERNAL_ERROR(0x1) "the endpoint encountered an internal error and cannot continue with the connection"; CONNECTION_REFUSED(0x2) "the server refused to accept a new connection"; FLOW_CONTROL_ERROR(0x3) "received more data than permitted in advertised data limits"; STREAM_LIMIT_ERROR(0x4) "received a frame for a stream identifier that exceeded the advertised stream limit for the corresponding stream type"; STREAM_STATE_ERROR(0x5) "received a frame for a stream that was not in a state that permitted that frame"; FINAL_SIZE_ERROR(0x6) "received a STREAM frame or a RESET_STREAM frame containing a different final size to the one already established"; FRAME_ENCODING_ERROR(0x7) "received a frame that was badly formatted"; TRANSPORT_PARAMETER_ERROR(0x8) "received transport parameters that were badly formatted, included an invalid value, omitted a mandatory transport parameter, included a forbidden transport parameter, or were otherwise in error"; CONNECTION_ID_LIMIT_ERROR(0x9) "the number of connection IDs provided by the peer exceeds the advertised active_connection_id_limit"; PROTOCOL_VIOLATION(0xA) "detected an error with protocol compliance that was not covered by more specific error codes"; INVALID_TOKEN(0xB) "received an invalid Retry Token in a client Initial"; APPLICATION_ERROR(0xC) "the application or application protocol caused the connection to be closed during the handshake"; CRYPTO_BUFFER_EXCEEDED(0xD) "received more data in CRYPTO frames than can be buffered"; KEY_UPDATE_ERROR(0xE) "key update error"; AEAD_LIMIT_REACHED(0xF) "the endpoint has reached the confidentiality or integrity limit for the AEAD algorithm"; NO_VIABLE_PATH(0x10) "no viable network path exists"; } quinn-proto-0.10.6/src/transport_parameters.rs000064400000000000000000000445621046102023000176300ustar 00000000000000//! QUIC connection transport parameters //! //! The `TransportParameters` type is used to represent the transport parameters //! negotiated by peers while establishing a QUIC connection. This process //! happens as part of the establishment of the TLS session. As such, the types //! contained in this module should generally only be referred to by custom //! implementations of the `crypto::Session` trait. use std::{ convert::TryFrom, net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}, }; use bytes::{Buf, BufMut}; use thiserror::Error; use crate::{ cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::{BufExt, BufMutExt, UnexpectedEnd}, config::{EndpointConfig, ServerConfig, TransportConfig}, shared::ConnectionId, ResetToken, Side, TransportError, VarInt, LOC_CID_COUNT, MAX_CID_SIZE, MAX_STREAM_COUNT, RESET_TOKEN_SIZE, }; // Apply a given macro to a list of all the transport parameters having integer types, along with // their codes and default values. Using this helps us avoid error-prone duplication of the // contained information across decoding, encoding, and the `Default` impl. Whenever we want to do // something with transport parameters, we'll handle the bulk of cases by writing a macro that // takes a list of arguments in this form, then passing it to this macro. macro_rules! apply_params { ($macro:ident) => { $macro!
{ // #[doc] name (id) = default, /// Milliseconds, disabled if zero max_idle_timeout(0x0001) = 0, /// Limits the size of UDP payloads that the endpoint is willing to receive max_udp_payload_size(0x0003) = 65527, /// Initial value for the maximum amount of data that can be sent on the connection initial_max_data(0x0004) = 0, /// Initial flow control limit for locally-initiated bidirectional streams initial_max_stream_data_bidi_local(0x0005) = 0, /// Initial flow control limit for peer-initiated bidirectional streams initial_max_stream_data_bidi_remote(0x0006) = 0, /// Initial flow control limit for unidirectional streams initial_max_stream_data_uni(0x0007) = 0, /// Initial maximum number of bidirectional streams the peer may initiate initial_max_streams_bidi(0x0008) = 0, /// Initial maximum number of unidirectional streams the peer may initiate initial_max_streams_uni(0x0009) = 0, /// Exponent used to decode the ACK Delay field in the ACK frame ack_delay_exponent(0x000a) = 3, /// Maximum amount of time in milliseconds by which the endpoint will delay sending /// acknowledgments max_ack_delay(0x000b) = 25, /// Maximum number of connection IDs from the peer that an endpoint is willing to store active_connection_id_limit(0x000e) = 2, } }; } macro_rules! make_struct { {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => { /// Transport parameters used to negotiate connection-level preferences between peers #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct TransportParameters { $($(#[$doc])* pub(crate) $name : VarInt,)* /// Does the endpoint support active connection migration pub(crate) disable_active_migration: bool, /// Maximum size for datagram frames pub(crate) max_datagram_frame_size: Option, /// The value that the endpoint included in the Source Connection ID field of the first /// Initial packet it sends for the connection pub(crate) initial_src_cid: Option, /// The endpoint is willing to receive QUIC packets containing any value for the fixed /// bit pub(crate) grease_quic_bit: bool, // Server-only /// The value of the Destination Connection ID field from the first Initial packet sent /// by the client pub(crate) original_dst_cid: Option, /// The value that the server included in the Source Connection ID field of a Retry /// packet pub(crate) retry_src_cid: Option, /// Token used by the client to verify a stateless reset from the server pub(crate) stateless_reset_token: Option, /// The server's preferred address for communication after handshake completion pub(crate) preferred_address: Option, } impl Default for TransportParameters { /// Standard defaults, used if the peer does not supply a given parameter. 
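// Expanded by `apply_params!(make_struct)`; for the integer parameters the
// generated body is equivalent to writing out (sketch):
//
//     max_idle_timeout: VarInt::from_u32(0),
//     max_udp_payload_size: VarInt::from_u32(65527),
//     ack_delay_exponent: VarInt::from_u32(3),
//     max_ack_delay: VarInt::from_u32(25),
//     active_connection_id_limit: VarInt::from_u32(2),
//
// and so on, with the non-integer parameters filled in explicitly below.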
fn default() -> Self { Self { $($name: VarInt::from_u32($default),)* disable_active_migration: false, max_datagram_frame_size: None, initial_src_cid: None, grease_quic_bit: false, original_dst_cid: None, retry_src_cid: None, stateless_reset_token: None, preferred_address: None, } } } } } apply_params!(make_struct); impl TransportParameters { pub(crate) fn new( config: &TransportConfig, endpoint_config: &EndpointConfig, cid_gen: &dyn ConnectionIdGenerator, initial_src_cid: ConnectionId, server_config: Option<&ServerConfig>, ) -> Self { Self { initial_src_cid: Some(initial_src_cid), initial_max_streams_bidi: config.max_concurrent_bidi_streams, initial_max_streams_uni: config.max_concurrent_uni_streams, initial_max_data: config.receive_window, initial_max_stream_data_bidi_local: config.stream_receive_window, initial_max_stream_data_bidi_remote: config.stream_receive_window, initial_max_stream_data_uni: config.stream_receive_window, max_udp_payload_size: endpoint_config.max_udp_payload_size, max_idle_timeout: config.max_idle_timeout.unwrap_or(VarInt(0)), disable_active_migration: server_config.map_or(false, |c| !c.migration), active_connection_id_limit: if cid_gen.cid_len() == 0 { 2 // i.e. default, i.e. unsent } else { CidQueue::LEN as u32 } .into(), max_datagram_frame_size: config .datagram_receive_buffer_size .map(|x| (x.min(u16::max_value().into()) as u16).into()), grease_quic_bit: endpoint_config.grease_quic_bit, ..Self::default() } } /// Check that these parameters are legal when resuming from /// certain cached parameters pub(crate) fn validate_resumption_from(&self, cached: &Self) -> Result<(), TransportError> { if cached.active_connection_id_limit > self.active_connection_id_limit || cached.initial_max_data > self.initial_max_data || cached.initial_max_stream_data_bidi_local > self.initial_max_stream_data_bidi_local || cached.initial_max_stream_data_bidi_remote > self.initial_max_stream_data_bidi_remote || cached.initial_max_stream_data_uni > self.initial_max_stream_data_uni || cached.initial_max_streams_bidi > self.initial_max_streams_bidi || cached.initial_max_streams_uni > self.initial_max_streams_uni || cached.max_datagram_frame_size > self.max_datagram_frame_size || cached.grease_quic_bit && !self.grease_quic_bit { return Err(TransportError::PROTOCOL_VIOLATION( "0-RTT accepted with incompatible transport parameters", )); } Ok(()) } /// Maximum number of CIDs to issue to this peer /// /// Consider both a) the active_connection_id_limit from the other end; and /// b) LOC_CID_COUNT used locally pub(crate) fn issue_cids_limit(&self) -> u64 { self.active_connection_id_limit.0.min(LOC_CID_COUNT) } } /// A server's preferred address /// /// This is communicated as a transport parameter during TLS session establishment. 
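// Wire layout, per the `write`/`read` impls below (RFC 9000 §18.2):
//
//     [ IPv4 address (4) ][ IPv4 port (2) ]
//     [ IPv6 address (16) ][ IPv6 port (2) ]
//     [ CID length (1) ][ CID (0..=MAX_CID_SIZE bytes) ][ reset token (16) ]
//
// An all-zero address and port for a family decodes as "no address offered"
// for that family; at least one family must be present.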
/// A server's preferred address
///
/// This is communicated as a transport parameter during TLS session establishment.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) struct PreferredAddress {
    pub(crate) address_v4: Option<SocketAddrV4>,
    pub(crate) address_v6: Option<SocketAddrV6>,
    pub(crate) connection_id: ConnectionId,
    pub(crate) stateless_reset_token: ResetToken,
}

impl PreferredAddress {
    fn wire_size(&self) -> u16 {
        4 + 2 + 16 + 2 + 1 + self.connection_id.len() as u16 + 16
    }

    fn write<W: BufMut>(&self, w: &mut W) {
        w.write(self.address_v4.map_or(Ipv4Addr::UNSPECIFIED, |x| *x.ip()));
        w.write::<u16>(self.address_v4.map_or(0, |x| x.port()));
        w.write(self.address_v6.map_or(Ipv6Addr::UNSPECIFIED, |x| *x.ip()));
        w.write::<u16>(self.address_v6.map_or(0, |x| x.port()));
        w.write::<u8>(self.connection_id.len() as u8);
        w.put_slice(&self.connection_id);
        w.put_slice(&self.stateless_reset_token);
    }

    fn read<R: Buf>(r: &mut R) -> Result<Self, Error> {
        let ip_v4 = r.get::<Ipv4Addr>()?;
        let port_v4 = r.get::<u16>()?;
        let ip_v6 = r.get::<Ipv6Addr>()?;
        let port_v6 = r.get::<u16>()?;
        let cid_len = r.get::<u8>()?;
        if r.remaining() < cid_len as usize || cid_len > MAX_CID_SIZE as u8 {
            return Err(Error::Malformed);
        }
        let mut stage = [0; MAX_CID_SIZE];
        r.copy_to_slice(&mut stage[0..cid_len as usize]);
        let cid = ConnectionId::new(&stage[0..cid_len as usize]);
        if r.remaining() < 16 {
            return Err(Error::Malformed);
        }
        let mut token = [0; RESET_TOKEN_SIZE];
        r.copy_to_slice(&mut token);
        let address_v4 = if ip_v4.is_unspecified() && port_v4 == 0 {
            None
        } else {
            Some(SocketAddrV4::new(ip_v4, port_v4))
        };
        let address_v6 = if ip_v6.is_unspecified() && port_v6 == 0 {
            None
        } else {
            Some(SocketAddrV6::new(ip_v6, port_v6, 0, 0))
        };
        if address_v4.is_none() && address_v6.is_none() {
            return Err(Error::IllegalValue);
        }
        Ok(Self {
            address_v4,
            address_v6,
            connection_id: cid,
            stateless_reset_token: token.into(),
        })
    }
}

/// Errors encountered while decoding `TransportParameters`
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
pub enum Error {
    /// Parameters that are semantically invalid
    #[error("parameter had illegal value")]
    IllegalValue,
    /// Catch-all error for problems while decoding transport parameters
    #[error("parameters were malformed")]
    Malformed,
}

impl From<Error> for TransportError {
    fn from(e: Error) -> Self {
        match e {
            Error::IllegalValue => Self::TRANSPORT_PARAMETER_ERROR("illegal value"),
            Error::Malformed => Self::TRANSPORT_PARAMETER_ERROR("malformed"),
        }
    }
}

impl From<UnexpectedEnd> for Error {
    fn from(_: UnexpectedEnd) -> Self {
        Self::Malformed
    }
}
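// Wire layout consumed by `PreferredAddress::read` above, per RFC 9000 §18.2
// (byte counts match `wire_size`):
//
//     IPv4 address              4 bytes
//     IPv4 port                 2 bytes
//     IPv6 address             16 bytes
//     IPv6 port                 2 bytes
//     CID length                1 byte
//     connection ID             0..=MAX_CID_SIZE bytes
//     stateless reset token    16 bytes
//
// An all-zero address/port pair is treated as "not provided"; at least one of
// the two addresses must be present, else the value is rejected as illegal.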
impl TransportParameters {
    /// Encode `TransportParameters` into buffer
    pub fn write<W: BufMut>(&self, w: &mut W) {
        macro_rules! write_params {
            {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => {
                $(
                    if self.$name.0 != $default {
                        w.write_var($code);
                        w.write(VarInt::try_from(self.$name.size()).unwrap());
                        w.write(self.$name);
                    }
                )*
            }
        }
        apply_params!(write_params);

        // Add a reserved parameter to keep people on their toes
        w.write_var(31 * 5 + 27);
        w.write_var(0);

        if let Some(ref x) = self.stateless_reset_token {
            w.write_var(0x02);
            w.write_var(16);
            w.put_slice(x);
        }

        if self.disable_active_migration {
            w.write_var(0x0c);
            w.write_var(0);
        }

        if let Some(x) = self.max_datagram_frame_size {
            w.write_var(0x20);
            w.write_var(x.size() as u64);
            w.write(x);
        }

        if let Some(ref x) = self.preferred_address {
            w.write_var(0x000d);
            w.write_var(x.wire_size() as u64);
            x.write(w);
        }

        for &(tag, cid) in &[
            (0x00, &self.original_dst_cid),
            (0x0f, &self.initial_src_cid),
            (0x10, &self.retry_src_cid),
        ] {
            if let Some(ref cid) = *cid {
                w.write_var(tag);
                w.write_var(cid.len() as u64);
                w.put_slice(cid);
            }
        }

        if self.grease_quic_bit {
            w.write_var(0x2ab2);
            w.write_var(0);
        }
    }
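    // Sketch of the encoding produced above (illustrative, not from the
    // original source): every parameter is a (varint id, varint length, value)
    // triple. For example, with `max_ack_delay` raised to 30 ms the writer
    // emits
    //
    //     0x0b 0x01 0x1e    // id = 0x000b, length = 1, value = 30
    //
    // while parameters still at their defaults are skipped entirely. The
    // reserved id `31 * 5 + 27` (0xb6) carries an empty value and exercises
    // peers' obligation to ignore unknown parameters (RFC 9000 §18.1).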
    /// Decode `TransportParameters` from buffer
    pub fn read<R: Buf>(side: Side, r: &mut R) -> Result<Self, Error> {
        // Initialize to protocol-specified defaults
        let mut params = Self::default();

        // State to check for duplicate transport parameters.
        macro_rules! param_state {
            {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => {{
                struct ParamState {
                    $($name: bool,)*
                }

                ParamState {
                    $($name: false,)*
                }
            }}
        }
        let mut got = apply_params!(param_state);

        while r.has_remaining() {
            let id = r.get_var()?;
            let len = r.get_var()?;
            if (r.remaining() as u64) < len {
                return Err(Error::Malformed);
            }
            let len = len as usize;

            match id {
                0x00 => decode_cid(len, &mut params.original_dst_cid, r)?,
                0x02 => {
                    if len != 16 || params.stateless_reset_token.is_some() {
                        return Err(Error::Malformed);
                    }
                    let mut tok = [0; RESET_TOKEN_SIZE];
                    r.copy_to_slice(&mut tok);
                    params.stateless_reset_token = Some(tok.into());
                }
                0x0c => {
                    if len != 0 || params.disable_active_migration {
                        return Err(Error::Malformed);
                    }
                    params.disable_active_migration = true;
                }
                0x0d => {
                    if params.preferred_address.is_some() {
                        return Err(Error::Malformed);
                    }
                    params.preferred_address = Some(PreferredAddress::read(&mut r.take(len))?);
                }
                0x0f => decode_cid(len, &mut params.initial_src_cid, r)?,
                0x10 => decode_cid(len, &mut params.retry_src_cid, r)?,
                0x20 => {
                    if len > 8 || params.max_datagram_frame_size.is_some() {
                        return Err(Error::Malformed);
                    }
                    params.max_datagram_frame_size = Some(r.get().unwrap());
                }
                0x2ab2 => match len {
                    0 => params.grease_quic_bit = true,
                    _ => return Err(Error::Malformed),
                },
                _ => {
                    macro_rules! parse {
                        {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => {
                            match id {
                                $($code => {
                                    let value = r.get::<VarInt>()?;
                                    if len != value.size() || got.$name { return Err(Error::Malformed); }
                                    params.$name = value.into();
                                    got.$name = true;
                                })*
                                _ => r.advance(len as usize),
                            }
                        }
                    }
                    apply_params!(parse);
                }
            }
        }

        // Semantic validation
        if params.ack_delay_exponent.0 > 20
            || params.max_ack_delay.0 >= 1 << 14
            || params.active_connection_id_limit.0 < 2
            || params.max_udp_payload_size.0 < 1200
            || params.initial_max_streams_bidi.0 > MAX_STREAM_COUNT
            || params.initial_max_streams_uni.0 > MAX_STREAM_COUNT
            || (side.is_server()
                && (params.stateless_reset_token.is_some() || params.preferred_address.is_some()))
        {
            return Err(Error::IllegalValue);
        }

        Ok(params)
    }
}

fn decode_cid(len: usize, value: &mut Option<ConnectionId>, r: &mut impl Buf) -> Result<(), Error> {
    if len > MAX_CID_SIZE || value.is_some() || r.remaining() < len {
        return Err(Error::Malformed);
    }

    *value = Some(ConnectionId::from_buf(r, len));
    Ok(())
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn coding() {
        let mut buf = Vec::new();
        let params = TransportParameters {
            initial_src_cid: Some(ConnectionId::new(&[])),
            original_dst_cid: Some(ConnectionId::new(&[])),
            initial_max_streams_bidi: 16u32.into(),
            initial_max_streams_uni: 16u32.into(),
            ack_delay_exponent: 2u32.into(),
            max_udp_payload_size: 1200u32.into(),
            preferred_address: Some(PreferredAddress {
                address_v4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 42)),
                address_v6: None,
                connection_id: ConnectionId::new(&[]),
                stateless_reset_token: [0xab; RESET_TOKEN_SIZE].into(),
            }),
            grease_quic_bit: true,
            ..TransportParameters::default()
        };
        params.write(&mut buf);
        assert_eq!(
            TransportParameters::read(Side::Client, &mut buf.as_slice()).unwrap(),
            params
        );
    }

    #[test]
    fn resumption_params_validation() {
        let high_limit = TransportParameters {
            initial_max_streams_uni: 32u32.into(),
            ..Default::default()
        };
        let low_limit = TransportParameters {
            initial_max_streams_uni: 16u32.into(),
            ..Default::default()
        };
        high_limit.validate_resumption_from(&low_limit).unwrap();
        low_limit.validate_resumption_from(&high_limit).unwrap_err();
    }
}

quinn-proto-0.10.6/src/varint.rs

use std::{convert::TryInto, fmt};

use bytes::{Buf, BufMut};
use thiserror::Error;

use crate::coding::{self, Codec, UnexpectedEnd};

#[cfg(feature = "arbitrary")]
use arbitrary::Arbitrary;

/// An integer less than 2^62
///
/// Values of this type are suitable for encoding as a QUIC variable-length integer.
// It would be neat if we could express to Rust that the top two bits are available for use as enum
// discriminants
#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct VarInt(pub(crate) u64);

impl VarInt {
    /// The largest representable value
    pub const MAX: Self = Self((1 << 62) - 1);
    /// The largest encoded value length
    pub const MAX_SIZE: usize = 8;

    /// Construct a `VarInt` infallibly
    pub const fn from_u32(x: u32) -> Self {
        Self(x as u64)
    }

    /// Succeeds iff `x` < 2^62
    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
        if x < 2u64.pow(62) {
            Ok(Self(x))
        } else {
            Err(VarIntBoundsExceeded)
        }
    }

    /// Create a VarInt without ensuring it's in range
    ///
    /// # Safety
    ///
    /// `x` must be less than 2^62.
    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
        Self(x)
    }

    /// Extract the integer value
    pub const fn into_inner(self) -> u64 {
        self.0
    }

    /// Compute the number of bytes needed to encode this value
    pub(crate) fn size(self) -> usize {
        let x = self.0;
        if x < 2u64.pow(6) {
            1
        } else if x < 2u64.pow(14) {
            2
        } else if x < 2u64.pow(30) {
            4
        } else if x < 2u64.pow(62) {
            8
        } else {
            unreachable!("malformed VarInt");
        }
    }
}
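// Size classes used by `size` above and `decode`/`encode` below (RFC 9000
// §16); this table is an illustrative summary, not from the original source:
//
//     2-bit prefix | total length | value range
//     0b00         | 1 byte       | 0..=63
//     0b01         | 2 bytes      | 0..=16_383
//     0b10         | 4 bytes      | 0..=1_073_741_823
//     0b11         | 8 bytes      | 0..=4_611_686_018_427_387_903
//
// The prefix occupies the top two bits of the first byte, so e.g. 16_384
// encodes as the four bytes 0x80 0x00 0x40 0x00.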
impl From<VarInt> for u64 {
    fn from(x: VarInt) -> Self {
        x.0
    }
}

impl From<u8> for VarInt {
    fn from(x: u8) -> Self {
        Self(x.into())
    }
}

impl From<u16> for VarInt {
    fn from(x: u16) -> Self {
        Self(x.into())
    }
}

impl From<u32> for VarInt {
    fn from(x: u32) -> Self {
        Self(x.into())
    }
}

impl std::convert::TryFrom<u64> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
        Self::from_u64(x)
    }
}

impl std::convert::TryFrom<u128> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
    }
}

impl std::convert::TryFrom<usize> for VarInt {
    type Error = VarIntBoundsExceeded;
    /// Succeeds iff `x` < 2^62
    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
        Self::try_from(x as u64)
    }
}

impl fmt::Debug for VarInt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

impl fmt::Display for VarInt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

#[cfg(feature = "arbitrary")]
impl<'arbitrary> Arbitrary<'arbitrary> for VarInt {
    fn arbitrary(u: &mut arbitrary::Unstructured<'arbitrary>) -> arbitrary::Result<Self> {
        Ok(Self(u.int_in_range(0..=Self::MAX.0)?))
    }
}

/// Error returned when constructing a `VarInt` from a value >= 2^62
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
#[error("value too large for varint encoding")]
pub struct VarIntBoundsExceeded;

impl Codec for VarInt {
    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
        if !r.has_remaining() {
            return Err(UnexpectedEnd);
        }
        let mut buf = [0; 8];
        buf[0] = r.get_u8();
        let tag = buf[0] >> 6;
        buf[0] &= 0b0011_1111;
        let x = match tag {
            0b00 => u64::from(buf[0]),
            0b01 => {
                if r.remaining() < 1 {
                    return Err(UnexpectedEnd);
                }
                r.copy_to_slice(&mut buf[1..2]);
                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
            }
            0b10 => {
                if r.remaining() < 3 {
                    return Err(UnexpectedEnd);
                }
                r.copy_to_slice(&mut buf[1..4]);
                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
            }
            0b11 => {
                if r.remaining() < 7 {
                    return Err(UnexpectedEnd);
                }
                r.copy_to_slice(&mut buf[1..8]);
                u64::from_be_bytes(buf)
            }
            _ => unreachable!(),
        };
        Ok(Self(x))
    }

    fn encode<B: BufMut>(&self, w: &mut B) {
        let x = self.0;
        if x < 2u64.pow(6) {
            w.put_u8(x as u8);
        } else if x < 2u64.pow(14) {
            w.put_u16(0b01 << 14 | x as u16);
        } else if x < 2u64.pow(30) {
            w.put_u32(0b10 << 30 | x as u32);
        } else if x < 2u64.pow(62) {
            w.put_u64(0b11 << 62 | x);
        } else {
            unreachable!("malformed VarInt")
        }
    }
}
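// Illustrative round-trip test, not part of the original crate; it exercises
// the `encode`/`decode` pair above across all four varint length classes and
// checks that `size` agrees with the bytes actually produced.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_length_classes() {
        // Boundary values on each side of the 1/2/4/8-byte thresholds.
        for &value in &[
            0u64,
            63,
            64,
            16_383,
            16_384,
            (1 << 30) - 1,
            1 << 30,
            VarInt::MAX.0,
        ] {
            let v = VarInt::from_u64(value).unwrap();
            let mut buf = Vec::new();
            v.encode(&mut buf);
            // The predicted encoded length must match what was written.
            assert_eq!(buf.len(), v.size());
            // Decoding the bytes must yield the original value.
            let decoded = VarInt::decode(&mut buf.as_slice()).unwrap();
            assert_eq!(decoded, v);
        }
    }
}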