tokio-util-0.6.9/.cargo_vcs_info.json0000644000000001120000000000100131420ustar { "git": { "sha1": "d1a400912e82505c18c6c0c1f05cda06f334e201" } } tokio-util-0.6.9/CHANGELOG.md000064400000000000000000000117330072674642500136060ustar 00000000000000# 0.6.9 (October 29, 2021) ### Added - codec: implement `Clone` for `LengthDelimitedCodec` ([#4089]) - io: add `SyncIoBridge` ([#4146]) ### Fixed - time: update deadline on removal in `DelayQueue` ([#4178]) - codec: Update stream impl for Framed to return None after Err ([#4166]) [#4089]: https://github.com/tokio-rs/tokio/pull/4089 [#4146]: https://github.com/tokio-rs/tokio/pull/4146 [#4166]: https://github.com/tokio-rs/tokio/pull/4166 [#4178]: https://github.com/tokio-rs/tokio/pull/4178 # 0.6.8 (September 3, 2021) ### Added - sync: add drop guard for `CancellationToken` ([#3839]) - compact: added `AsyncSeek` compat ([#4078]) - time: expose `Key` used in `DelayQueue`'s `Expired` ([#4081]) - io: add `with_capacity` to `ReaderStream` ([#4086]) ### Fixed - codec: remove unnecessary `doc(cfg(...))` ([#3989]) [#3839]: https://github.com/tokio-rs/tokio/pull/3839 [#4078]: https://github.com/tokio-rs/tokio/pull/4078 [#4081]: https://github.com/tokio-rs/tokio/pull/4081 [#4086]: https://github.com/tokio-rs/tokio/pull/4086 [#3989]: https://github.com/tokio-rs/tokio/pull/3989 # 0.6.7 (May 14, 2021) ### Added - udp: make `UdpFramed` take `Borrow` ([#3451]) - compat: implement `AsRawFd`/`AsRawHandle` for `Compat` ([#3765]) [#3451]: https://github.com/tokio-rs/tokio/pull/3451 [#3765]: https://github.com/tokio-rs/tokio/pull/3765 # 0.6.6 (April 12, 2021) ### Added - util: makes `Framed` and `FramedStream` resumable after eof ([#3272]) - util: add `PollSemaphore::{add_permits, available_permits}` ([#3683]) ### Fixed - chore: avoid allocation if `PollSemaphore` is unused ([#3634]) [#3272]: https://github.com/tokio-rs/tokio/pull/3272 [#3634]: https://github.com/tokio-rs/tokio/pull/3634 [#3683]: https://github.com/tokio-rs/tokio/pull/3683 # 0.6.5 (March 20, 2021) ### Fixed - util: annotate time module as requiring `time` feature ([#3606]) [#3606]: https://github.com/tokio-rs/tokio/pull/3606 # 0.6.4 (March 9, 2021) ### Added - codec: `AnyDelimiter` codec ([#3406]) - sync: add pollable `mpsc::Sender` ([#3490]) ### Fixed - codec: `LinesCodec` should only return `MaxLineLengthExceeded` once per line ([#3556]) - sync: fuse PollSemaphore ([#3578]) [#3406]: https://github.com/tokio-rs/tokio/pull/3406 [#3490]: https://github.com/tokio-rs/tokio/pull/3490 [#3556]: https://github.com/tokio-rs/tokio/pull/3556 [#3578]: https://github.com/tokio-rs/tokio/pull/3578 # 0.6.3 (January 31, 2021) ### Added - sync: add `ReusableBoxFuture` utility ([#3464]) ### Changed - sync: use `ReusableBoxFuture` for `PollSemaphore` ([#3463]) - deps: remove `async-stream` dependency ([#3463]) - deps: remove `tokio-stream` dependency ([#3487]) # 0.6.2 (January 21, 2021) ### Added - sync: add pollable `Semaphore` ([#3444]) ### Fixed - time: fix panics on updating `DelayQueue` entries ([#3270]) # 0.6.1 (January 12, 2021) ### Added - codec: `get_ref()`, `get_mut()`, `get_pin_mut()` and `into_inner()` for `Framed`, `FramedRead`, `FramedWrite` and `StreamReader` ([#3364]). - codec: `write_buffer()` and `write_buffer_mut()` for `Framed` and `FramedWrite` ([#3387]). # 0.6.0 (December 23, 2020) ### Changed - depend on `tokio` 1.0. ### Added - rt: add constructors to `TokioContext` (#3221). # 0.5.1 (December 3, 2020) ### Added - io: `poll_read_buf` util fn (#2972). - io: `poll_write_buf` util fn with vectored write support (#3156). # 0.5.0 (October 30, 2020) ### Changed - io: update `bytes` to 0.6 (#3071). # 0.4.0 (October 15, 2020) ### Added - sync: `CancellationToken` for coordinating task cancellation (#2747). - rt: `TokioContext` sets the Tokio runtime for the duration of a future (#2791) - io: `StreamReader`/`ReaderStream` map between `AsyncRead` values and `Stream` of bytes (#2788). - time: `DelayQueue` to manage many delays (#2897). # 0.3.1 (March 18, 2020) ### Fixed - Adjust minimum-supported Tokio version to v0.2.5 to account for an internal dependency on features in that version of Tokio. ([#2326]) # 0.3.0 (March 4, 2020) ### Changed - **Breaking Change**: Change `Encoder` trait to take a generic `Item` parameter, which allows codec writers to pass references into `Framed` and `FramedWrite` types. ([#1746]) ### Added - Add futures-io/tokio::io compatibility layer. ([#2117]) - Add `Framed::with_capacity`. ([#2215]) ### Fixed - Use advance over split_to when data is not needed. ([#2198]) # 0.2.0 (November 26, 2019) - Initial release [#3487]: https://github.com/tokio-rs/tokio/pull/3487 [#3464]: https://github.com/tokio-rs/tokio/pull/3464 [#3463]: https://github.com/tokio-rs/tokio/pull/3463 [#3444]: https://github.com/tokio-rs/tokio/pull/3444 [#3387]: https://github.com/tokio-rs/tokio/pull/3387 [#3364]: https://github.com/tokio-rs/tokio/pull/3364 [#3270]: https://github.com/tokio-rs/tokio/pull/3270 [#2326]: https://github.com/tokio-rs/tokio/pull/2326 [#2215]: https://github.com/tokio-rs/tokio/pull/2215 [#2198]: https://github.com/tokio-rs/tokio/pull/2198 [#2117]: https://github.com/tokio-rs/tokio/pull/2117 [#1746]: https://github.com/tokio-rs/tokio/pull/1746 tokio-util-0.6.9/Cargo.toml0000644000000036370000000000100111570ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2018" name = "tokio-util" version = "0.6.9" authors = ["Tokio Contributors "] description = "Additional utilities for working with Tokio.\n" homepage = "https://tokio.rs" documentation = "https://docs.rs/tokio-util/0.6.9/tokio_util" categories = ["asynchronous"] license = "MIT" repository = "https://github.com/tokio-rs/tokio" [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies.bytes] version = "1.0.0" [dependencies.futures-core] version = "0.3.0" [dependencies.futures-io] version = "0.3.0" optional = true [dependencies.futures-sink] version = "0.3.0" [dependencies.futures-util] version = "0.3.0" optional = true [dependencies.log] version = "0.4" [dependencies.pin-project-lite] version = "0.2.0" [dependencies.slab] version = "0.4.1" optional = true [dependencies.tokio] version = "1.0.0" features = ["sync"] [dev-dependencies.async-stream] version = "0.3.0" [dev-dependencies.futures] version = "0.3.0" [dev-dependencies.futures-test] version = "0.3.5" [dev-dependencies.tokio] version = "1.0.0" features = ["full"] [dev-dependencies.tokio-stream] version = "0.1" [dev-dependencies.tokio-test] version = "0.4.0" [features] __docs_rs = ["futures-util"] codec = [] compat = ["futures-io"] default = [] full = ["codec", "compat", "io-util", "time", "net", "rt"] io = [] io-util = ["io", "tokio/rt", "tokio/io-util"] net = ["tokio/net"] rt = ["tokio/rt"] time = ["tokio/time", "slab"] tokio-util-0.6.9/Cargo.toml.orig000064400000000000000000000030600072674642500146560ustar 00000000000000[package] name = "tokio-util" # When releasing to crates.io: # - Remove path dependencies # - Update doc url # - Cargo.toml # - Update CHANGELOG.md. # - Create "tokio-util-0.6.x" git tag. version = "0.6.9" edition = "2018" authors = ["Tokio Contributors "] license = "MIT" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" documentation = "https://docs.rs/tokio-util/0.6.9/tokio_util" description = """ Additional utilities for working with Tokio. """ categories = ["asynchronous"] [features] # No features on by default default = [] # Shorthand for enabling everything full = ["codec", "compat", "io-util", "time", "net", "rt"] net = ["tokio/net"] compat = ["futures-io",] codec = [] time = ["tokio/time","slab"] io = [] io-util = ["io", "tokio/rt", "tokio/io-util"] rt = ["tokio/rt"] __docs_rs = ["futures-util"] [dependencies] tokio = { version = "1.0.0", path = "../tokio", features = ["sync"] } bytes = "1.0.0" futures-core = "0.3.0" futures-sink = "0.3.0" futures-io = { version = "0.3.0", optional = true } futures-util = { version = "0.3.0", optional = true } log = "0.4" pin-project-lite = "0.2.0" slab = { version = "0.4.1", optional = true } # Backs `DelayQueue` [dev-dependencies] tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } tokio-test = { version = "0.4.0", path = "../tokio-test" } tokio-stream = { version = "0.1", path = "../tokio-stream" } async-stream = "0.3.0" futures = "0.3.0" futures-test = "0.3.5" [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] tokio-util-0.6.9/LICENSE000064400000000000000000000020460072674642500127770ustar 00000000000000Copyright (c) 2021 Tokio Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. tokio-util-0.6.9/README.md000064400000000000000000000005000072674642500132420ustar 00000000000000# tokio-util Utilities for working with Tokio. ## License This project is licensed under the [MIT license](LICENSE). ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in Tokio by you, shall be licensed as MIT, without any additional terms or conditions. tokio-util-0.6.9/src/cfg.rs000064400000000000000000000026760072674642500136770ustar 00000000000000macro_rules! cfg_codec { ($($item:item)*) => { $( #[cfg(feature = "codec")] #[cfg_attr(docsrs, doc(cfg(feature = "codec")))] $item )* } } macro_rules! cfg_compat { ($($item:item)*) => { $( #[cfg(feature = "compat")] #[cfg_attr(docsrs, doc(cfg(feature = "compat")))] $item )* } } macro_rules! cfg_net { ($($item:item)*) => { $( #[cfg(all(feature = "net", feature = "codec"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))] $item )* } } macro_rules! cfg_io { ($($item:item)*) => { $( #[cfg(feature = "io")] #[cfg_attr(docsrs, doc(cfg(feature = "io")))] $item )* } } cfg_io! { macro_rules! cfg_io_util { ($($item:item)*) => { $( #[cfg(feature = "io-util")] #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] $item )* } } } macro_rules! cfg_rt { ($($item:item)*) => { $( #[cfg(feature = "rt")] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] $item )* } } macro_rules! cfg_time { ($($item:item)*) => { $( #[cfg(feature = "time")] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] $item )* } } tokio-util-0.6.9/src/codec/any_delimiter_codec.rs000064400000000000000000000226520072674642500201730ustar 00000000000000use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::{cmp, fmt, io, str, usize}; const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r"; const DEFAULT_SEQUENCE_WRITER: &[u8] = b","; /// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string. /// /// [`Decoder`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder /// /// # Example /// Decode string of bytes containing various different delimiters. /// /// [`BytesMut`]: bytes::BytesMut /// [`Error`]: std::io::Error /// /// ``` /// use tokio_util::codec::{AnyDelimiterCodec, Decoder}; /// use bytes::{BufMut, BytesMut}; /// /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() -> Result<(), std::io::Error> { /// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec()); /// let buf = &mut BytesMut::new(); /// buf.reserve(200); /// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); /// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); /// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); /// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); /// assert_eq!("", codec.decode(buf).unwrap().unwrap()); /// assert_eq!(None, codec.decode(buf).unwrap()); /// # Ok(()) /// # } /// ``` /// #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct AnyDelimiterCodec { // Stored index of the next index to examine for the delimiter character. // This is used to optimize searching. // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`, // because that is the next index to examine. // The next time `decode` is called with `abcde}`, the method will // only look at `de}` before returning. next_index: usize, /// The maximum length for a given chunk. If `usize::MAX`, chunks will be /// read until a delimiter character is reached. max_length: usize, /// Are we currently discarding the remainder of a chunk which was over /// the length limit? is_discarding: bool, /// The bytes that are using for search during decode seek_delimiters: Vec, /// The bytes that are using for encoding sequence_writer: Vec, } impl AnyDelimiterCodec { /// Returns a `AnyDelimiterCodec` for splitting up data into chunks. /// /// # Note /// /// The returned `AnyDelimiterCodec` will not have an upper bound on the length /// of a buffered chunk. See the documentation for [`new_with_max_length`] /// for information on why this could be a potential security risk. /// /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length() pub fn new(seek_delimiters: Vec, sequence_writer: Vec) -> AnyDelimiterCodec { AnyDelimiterCodec { next_index: 0, max_length: usize::MAX, is_discarding: false, seek_delimiters, sequence_writer, } } /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit. /// /// If this is set, calls to `AnyDelimiterCodec::decode` will return a /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls /// will discard up to `limit` bytes from that chunk until a delimiter /// character is reached, returning `None` until the delimiter over the limit /// has been fully discarded. After that point, calls to `decode` will /// function as normal. /// /// # Note /// /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which /// will be exposed to untrusted input. Otherwise, the size of the buffer /// that holds the chunk currently being read is unbounded. An attacker could /// exploit this unbounded buffer by sending an unbounded amount of input /// without any delimiter characters, causing unbounded memory consumption. /// /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError pub fn new_with_max_length( seek_delimiters: Vec, sequence_writer: Vec, max_length: usize, ) -> Self { AnyDelimiterCodec { max_length, ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer) } } /// Returns the maximum chunk length when decoding. /// /// ``` /// use std::usize; /// use tokio_util::codec::AnyDelimiterCodec; /// /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec()); /// assert_eq!(codec.max_length(), usize::MAX); /// ``` /// ``` /// use tokio_util::codec::AnyDelimiterCodec; /// /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256); /// assert_eq!(codec.max_length(), 256); /// ``` pub fn max_length(&self) -> usize { self.max_length } } impl Decoder for AnyDelimiterCodec { type Item = Bytes; type Error = AnyDelimiterCodecError; fn decode(&mut self, buf: &mut BytesMut) -> Result, AnyDelimiterCodecError> { loop { // Determine how far into the buffer we'll search for a delimiter. If // there's no max_length set, we'll read to the end of the buffer. let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| { self.seek_delimiters .iter() .any(|delimiter| *b == *delimiter) }); match (self.is_discarding, new_chunk_offset) { (true, Some(offset)) => { // If we found a new chunk, discard up to that offset and // then stop discarding. On the next iteration, we'll try // to read a chunk normally. buf.advance(offset + self.next_index + 1); self.is_discarding = false; self.next_index = 0; } (true, None) => { // Otherwise, we didn't find a new chunk, so we'll discard // everything we read. On the next iteration, we'll continue // discarding up to max_len bytes unless we find a new chunk. buf.advance(read_to); self.next_index = 0; if buf.is_empty() { return Ok(None); } } (false, Some(offset)) => { // Found a chunk! let new_chunk_index = offset + self.next_index; self.next_index = 0; let mut chunk = buf.split_to(new_chunk_index + 1); chunk.truncate(chunk.len() - 1); let chunk = chunk.freeze(); return Ok(Some(chunk)); } (false, None) if buf.len() > self.max_length => { // Reached the maximum length without finding a // new chunk, return an error and start discarding on the // next call. self.is_discarding = true; return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded); } (false, None) => { // We didn't find a chunk or reach the length limit, so the next // call will resume searching at the current offset. self.next_index = read_to; return Ok(None); } } } } fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, AnyDelimiterCodecError> { Ok(match self.decode(buf)? { Some(frame) => Some(frame), None => { // return remaining data, if any if buf.is_empty() { None } else { let chunk = buf.split_to(buf.len()); self.next_index = 0; Some(chunk.freeze()) } } }) } } impl Encoder for AnyDelimiterCodec where T: AsRef, { type Error = AnyDelimiterCodecError; fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> { let chunk = chunk.as_ref(); buf.reserve(chunk.len() + 1); buf.put(chunk.as_bytes()); buf.put(self.sequence_writer.as_ref()); Ok(()) } } impl Default for AnyDelimiterCodec { fn default() -> Self { Self::new( DEFAULT_SEEK_DELIMITERS.to_vec(), DEFAULT_SEQUENCE_WRITER.to_vec(), ) } } /// An error occurred while encoding or decoding a chunk. #[derive(Debug)] pub enum AnyDelimiterCodecError { /// The maximum chunk length was exceeded. MaxChunkLengthExceeded, /// An IO error occurred. Io(io::Error), } impl fmt::Display for AnyDelimiterCodecError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { AnyDelimiterCodecError::MaxChunkLengthExceeded => { write!(f, "max chunk length exceeded") } AnyDelimiterCodecError::Io(e) => write!(f, "{}", e), } } } impl From for AnyDelimiterCodecError { fn from(e: io::Error) -> AnyDelimiterCodecError { AnyDelimiterCodecError::Io(e) } } impl std::error::Error for AnyDelimiterCodecError {} tokio-util-0.6.9/src/codec/bytes_codec.rs000064400000000000000000000037300072674642500164700ustar 00000000000000use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use bytes::{BufMut, Bytes, BytesMut}; use std::io; /// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around. /// /// [`Decoder`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder /// /// # Example /// /// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`. /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`BytesMut`]: bytes::BytesMut /// [`Error`]: std::io::Error /// /// ``` /// # mod hidden { /// # #[allow(unused_imports)] /// use tokio::fs::File; /// # } /// use tokio::io::AsyncRead; /// use tokio_util::codec::{FramedRead, BytesCodec}; /// /// # enum File {} /// # impl File { /// # async fn open(_name: &str) -> Result { /// # use std::io::Cursor; /// # Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5])) /// # } /// # } /// # /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() -> Result<(), std::io::Error> { /// let my_async_read = File::open("filename.txt").await?; /// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new()); /// # Ok(()) /// # } /// ``` /// #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] pub struct BytesCodec(()); impl BytesCodec { /// Creates a new `BytesCodec` for shipping around raw bytes. pub fn new() -> BytesCodec { BytesCodec(()) } } impl Decoder for BytesCodec { type Item = BytesMut; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { if !buf.is_empty() { let len = buf.len(); Ok(Some(buf.split_to(len))) } else { Ok(None) } } } impl Encoder for BytesCodec { type Error = io::Error; fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { buf.reserve(data.len()); buf.put(data); Ok(()) } } tokio-util-0.6.9/src/codec/decoder.rs000064400000000000000000000176360072674642500156240ustar 00000000000000use crate::codec::Framed; use tokio::io::{AsyncRead, AsyncWrite}; use bytes::BytesMut; use std::io; /// Decoding of frames via buffers. /// /// This trait is used when constructing an instance of [`Framed`] or /// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has /// already been buffered in `src` and decodes the data into a stream of /// `Self::Item` frames. /// /// Implementations are able to track state on `self`, which enables /// implementing stateful streaming parsers. In many cases, though, this type /// will simply be a unit struct (e.g. `struct HttpDecoder`). /// /// For some underlying data-sources, namely files and FIFOs, /// it's possible to temporarily read 0 bytes by reaching EOF. /// /// In these cases `decode_eof` will be called until it signals /// fullfillment of all closing frames by returning `Ok(None)`. /// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`] /// will not invoke `decode` or `decode_eof` again, until data can be read /// during a retry. /// /// It is up to the Decoder to keep track of a restart after an EOF, /// and to decide how to handle such an event by, for example, /// allowing frames to cross EOF boundaries, re-emitting opening frames, or /// resetting the entire internal state. /// /// [`Framed`]: crate::codec::Framed /// [`FramedRead`]: crate::codec::FramedRead pub trait Decoder { /// The type of decoded frames. type Item; /// The type of unrecoverable frame decoding errors. /// /// If an individual message is ill-formed but can be ignored without /// interfering with the processing of future messages, it may be more /// useful to report the failure as an `Item`. /// /// `From` is required in the interest of making `Error` suitable /// for returning directly from a [`FramedRead`], and to enable the default /// implementation of `decode_eof` to yield an `io::Error` when the decoder /// fails to consume all available data. /// /// Note that implementors of this trait can simply indicate `type Error = /// io::Error` to use I/O errors as this type. /// /// [`FramedRead`]: crate::codec::FramedRead type Error: From; /// Attempts to decode a frame from the provided buffer of bytes. /// /// This method is called by [`FramedRead`] whenever bytes are ready to be /// parsed. The provided buffer of bytes is what's been read so far, and /// this instance of `Decode` can determine whether an entire frame is in /// the buffer and is ready to be returned. /// /// If an entire frame is available, then this instance will remove those /// bytes from the buffer provided and return them as a decoded /// frame. Note that removing bytes from the provided buffer doesn't always /// necessarily copy the bytes, so this should be an efficient operation in /// most circumstances. /// /// If the bytes look valid, but a frame isn't fully available yet, then /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that /// it needs to read some more bytes before calling this method again. /// /// Note that the bytes provided may be empty. If a previous call to /// `decode` consumed all the bytes in the buffer then `decode` will be /// called again until it returns `Ok(None)`, indicating that more bytes need to /// be read. /// /// Finally, if the bytes in the buffer are malformed then an error is /// returned indicating why. This informs [`Framed`] that the stream is now /// corrupt and should be terminated. /// /// [`Framed`]: crate::codec::Framed /// [`FramedRead`]: crate::codec::FramedRead /// /// # Buffer management /// /// Before returning from the function, implementations should ensure that /// the buffer has appropriate capacity in anticipation of future calls to /// `decode`. Failing to do so leads to inefficiency. /// /// For example, if frames have a fixed length, or if the length of the /// current frame is known from a header, a possible buffer management /// strategy is: /// /// ```no_run /// # use std::io; /// # /// # use bytes::BytesMut; /// # use tokio_util::codec::Decoder; /// # /// # struct MyCodec; /// # /// impl Decoder for MyCodec { /// // ... /// # type Item = BytesMut; /// # type Error = io::Error; /// /// fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { /// // ... /// /// // Reserve enough to complete decoding of the current frame. /// let current_frame_len: usize = 1000; // Example. /// // And to start decoding the next frame. /// let next_frame_header_len: usize = 10; // Example. /// src.reserve(current_frame_len + next_frame_header_len); /// /// return Ok(None); /// } /// } /// ``` /// /// An optimal buffer management strategy minimizes reallocations and /// over-allocations. fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error>; /// A default method available to be called when there are no more bytes /// available to be read from the underlying I/O. /// /// This method defaults to calling `decode` and returns an error if /// `Ok(None)` is returned while there is unconsumed data in `buf`. /// Typically this doesn't need to be implemented unless the framing /// protocol differs near the end of the stream, or if you need to construct /// frames _across_ eof boundaries on sources that can be resumed. /// /// Note that the `buf` argument may be empty. If a previous call to /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be /// called again until it returns `None`, indicating that there are no more /// frames to yield. This behavior enables returning finalization frames /// that may not be based on inbound data. /// /// Once `None` has been returned, `decode_eof` won't be called again until /// an attempt to resume the stream has been made, where the underlying stream /// actually returned more data. fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { match self.decode(buf)? { Some(frame) => Ok(Some(frame)), None => { if buf.is_empty() { Ok(None) } else { Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) } } } } /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this /// `Io` object, using `Decode` and `Encode` to read and write the raw data. /// /// Raw I/O objects work with byte sequences, but higher-level code usually /// wants to batch these into meaningful chunks, called "frames". This /// method layers framing on top of an I/O object, by using the `Codec` /// traits to handle encoding and decoding of messages frames. Note that /// the incoming and outgoing frame types may be distinct. /// /// This function returns a *single* object that is both `Stream` and /// `Sink`; grouping this into a single object is often useful for layering /// things like gzip or TLS, which require both read and write access to the /// underlying object. /// /// If you want to work more directly with the streams and sink, consider /// calling `split` on the [`Framed`] returned by this method, which will /// break them into separate objects, allowing them to interact more easily. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`Framed`]: crate::codec::Framed fn framed(self, io: T) -> Framed where Self: Sized, { Framed::new(io, self) } } tokio-util-0.6.9/src/codec/encoder.rs000064400000000000000000000016020072674642500156200ustar 00000000000000use bytes::BytesMut; use std::io; /// Trait of helper objects to write out messages as bytes, for use with /// [`FramedWrite`]. /// /// [`FramedWrite`]: crate::codec::FramedWrite pub trait Encoder { /// The type of encoding errors. /// /// [`FramedWrite`] requires `Encoder`s errors to implement `From` /// in the interest letting it return `Error`s directly. /// /// [`FramedWrite`]: crate::codec::FramedWrite type Error: From; /// Encodes a frame into the buffer provided. /// /// This method will encode `item` into the byte buffer provided by `dst`. /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and /// will be written out when possible. /// /// [`FramedWrite`]: crate::codec::FramedWrite fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>; } tokio-util-0.6.9/src/codec/framed.rs000064400000000000000000000304720072674642500154460ustar 00000000000000use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; use futures_core::Stream; use tokio::io::{AsyncRead, AsyncWrite}; use bytes::BytesMut; use futures_sink::Sink; use pin_project_lite::pin_project; use std::fmt; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using /// the `Encoder` and `Decoder` traits to encode and decode frames. /// /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or /// by using the `new` function seen below. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`AsyncRead`]: tokio::io::AsyncRead /// [`Decoder::framed`]: crate::codec::Decoder::framed() pub struct Framed { #[pin] inner: FramedImpl } } impl Framed where T: AsyncRead + AsyncWrite, { /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. /// /// Raw I/O objects work with byte sequences, but higher-level code usually /// wants to batch these into meaningful chunks, called "frames". This /// method layers framing on top of an I/O object, by using the codec /// traits to handle encoding and decoding of messages frames. Note that /// the incoming and outgoing frame types may be distinct. /// /// This function returns a *single* object that is both [`Stream`] and /// [`Sink`]; grouping this into a single object is often useful for layering /// things like gzip or TLS, which require both read and write access to the /// underlying object. /// /// If you want to work more directly with the streams and sink, consider /// calling [`split`] on the `Framed` returned by this method, which will /// break them into separate objects, allowing them to interact more easily. /// /// Note that, for some byte sources, the stream can be resumed after an EOF /// by reading from it, even after it has returned `None`. Repeated attempts /// to do so, without new data available, continue to return `None` without /// creating more (closing) frames. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`Decode`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn new(inner: T, codec: U) -> Framed { Framed { inner: FramedImpl { inner, codec, state: Default::default(), }, } } /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data, /// with a specific read buffer initial capacity. /// /// Raw I/O objects work with byte sequences, but higher-level code usually /// wants to batch these into meaningful chunks, called "frames". This /// method layers framing on top of an I/O object, by using the codec /// traits to handle encoding and decoding of messages frames. Note that /// the incoming and outgoing frame types may be distinct. /// /// This function returns a *single* object that is both [`Stream`] and /// [`Sink`]; grouping this into a single object is often useful for layering /// things like gzip or TLS, which require both read and write access to the /// underlying object. /// /// If you want to work more directly with the streams and sink, consider /// calling [`split`] on the `Framed` returned by this method, which will /// break them into separate objects, allowing them to interact more easily. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`Decode`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed { Framed { inner: FramedImpl { inner, codec, state: RWFrames { read: ReadFrame { eof: false, is_readable: false, buffer: BytesMut::with_capacity(capacity), has_errored: false, }, write: WriteFrame::default(), }, }, } } } impl Framed { /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. /// /// Raw I/O objects work with byte sequences, but higher-level code usually /// wants to batch these into meaningful chunks, called "frames". This /// method layers framing on top of an I/O object, by using the `Codec` /// traits to handle encoding and decoding of messages frames. Note that /// the incoming and outgoing frame types may be distinct. /// /// This function returns a *single* object that is both [`Stream`] and /// [`Sink`]; grouping this into a single object is often useful for layering /// things like gzip or TLS, which require both read and write access to the /// underlying object. /// /// This objects takes a stream and a readbuffer and a writebuffer. These field /// can be obtained from an existing `Framed` with the [`into_parts`] method. /// /// If you want to work more directly with the streams and sink, consider /// calling [`split`] on the `Framed` returned by this method, which will /// break them into separate objects, allowing them to interact more easily. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`Decoder`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder /// [`into_parts`]: crate::codec::Framed::into_parts() /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split pub fn from_parts(parts: FramedParts) -> Framed { Framed { inner: FramedImpl { inner: parts.io, codec: parts.codec, state: RWFrames { read: parts.read_buf.into(), write: parts.write_buf.into(), }, }, } } /// Returns a reference to the underlying I/O stream wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { &mut self.inner.inner } /// Returns a pinned mutable reference to the underlying I/O stream wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { self.project().inner.project().inner } /// Returns a reference to the underlying codec wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec(&self) -> &U { &self.inner.codec } /// Returns a mutable reference to the underlying codec wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec_mut(&mut self) -> &mut U { &mut self.inner.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { &self.inner.state.read.buffer } /// Returns a mutable reference to the read buffer. pub fn read_buffer_mut(&mut self) -> &mut BytesMut { &mut self.inner.state.read.buffer } /// Returns a reference to the write buffer. pub fn write_buffer(&self) -> &BytesMut { &self.inner.state.write.buffer } /// Returns a mutable reference to the write buffer. pub fn write_buffer_mut(&mut self) -> &mut BytesMut { &mut self.inner.state.write.buffer } /// Consumes the `Framed`, returning its underlying I/O stream. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { self.inner.inner } /// Consumes the `Framed`, returning its underlying I/O stream, the buffer /// with unprocessed data, and the codec. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_parts(self) -> FramedParts { FramedParts { io: self.inner.inner, codec: self.inner.codec, read_buf: self.inner.state.read.buffer, write_buf: self.inner.state.write.buffer, _priv: (), } } } // This impl just defers to the underlying FramedImpl impl Stream for Framed where T: AsyncRead, U: Decoder, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_next(cx) } } // This impl just defers to the underlying FramedImpl impl Sink for Framed where T: AsyncWrite, U: Encoder, U::Error: From, { type Error = U::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { self.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_close(cx) } } impl fmt::Debug for Framed where T: fmt::Debug, U: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Framed") .field("io", self.get_ref()) .field("codec", self.codec()) .finish() } } /// `FramedParts` contains an export of the data of a Framed transport. /// It can be used to construct a new [`Framed`] with a different codec. /// It contains all current buffers and the inner transport. /// /// [`Framed`]: crate::codec::Framed #[derive(Debug)] #[allow(clippy::manual_non_exhaustive)] pub struct FramedParts { /// The inner transport used to read bytes to and write bytes to pub io: T, /// The codec pub codec: U, /// The buffer with read but unprocessed data. pub read_buf: BytesMut, /// A buffer with unprocessed data which are not written yet. pub write_buf: BytesMut, /// This private field allows us to add additional fields in the future in a /// backwards compatible way. _priv: (), } impl FramedParts { /// Create a new, default, `FramedParts` pub fn new(io: T, codec: U) -> FramedParts where U: Encoder, { FramedParts { io, codec, read_buf: BytesMut::new(), write_buf: BytesMut::new(), _priv: (), } } } tokio-util-0.6.9/src/codec/framed_impl.rs000064400000000000000000000305640072674642500164710ustar 00000000000000use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use futures_core::Stream; use tokio::io::{AsyncRead, AsyncWrite}; use bytes::BytesMut; use futures_core::ready; use futures_sink::Sink; use log::trace; use pin_project_lite::pin_project; use std::borrow::{Borrow, BorrowMut}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { #[derive(Debug)] pub(crate) struct FramedImpl { #[pin] pub(crate) inner: T, pub(crate) state: State, pub(crate) codec: U, } } const INITIAL_CAPACITY: usize = 8 * 1024; const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; #[derive(Debug)] pub(crate) struct ReadFrame { pub(crate) eof: bool, pub(crate) is_readable: bool, pub(crate) buffer: BytesMut, pub(crate) has_errored: bool, } pub(crate) struct WriteFrame { pub(crate) buffer: BytesMut, } #[derive(Default)] pub(crate) struct RWFrames { pub(crate) read: ReadFrame, pub(crate) write: WriteFrame, } impl Default for ReadFrame { fn default() -> Self { Self { eof: false, is_readable: false, buffer: BytesMut::with_capacity(INITIAL_CAPACITY), has_errored: false, } } } impl Default for WriteFrame { fn default() -> Self { Self { buffer: BytesMut::with_capacity(INITIAL_CAPACITY), } } } impl From for ReadFrame { fn from(mut buffer: BytesMut) -> Self { let size = buffer.capacity(); if size < INITIAL_CAPACITY { buffer.reserve(INITIAL_CAPACITY - size); } Self { buffer, is_readable: size > 0, eof: false, has_errored: false, } } } impl From for WriteFrame { fn from(mut buffer: BytesMut) -> Self { let size = buffer.capacity(); if size < INITIAL_CAPACITY { buffer.reserve(INITIAL_CAPACITY - size); } Self { buffer } } } impl Borrow for RWFrames { fn borrow(&self) -> &ReadFrame { &self.read } } impl BorrowMut for RWFrames { fn borrow_mut(&mut self) -> &mut ReadFrame { &mut self.read } } impl Borrow for RWFrames { fn borrow(&self) -> &WriteFrame { &self.write } } impl BorrowMut for RWFrames { fn borrow_mut(&mut self) -> &mut WriteFrame { &mut self.write } } impl Stream for FramedImpl where T: AsyncRead, U: Decoder, R: BorrowMut, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use crate::util::poll_read_buf; let mut pinned = self.project(); let state: &mut ReadFrame = pinned.state.borrow_mut(); // The following loops implements a state machine with each state corresponding // to a combination of the `is_readable` and `eof` flags. States persist across // loop entries and most state transitions occur with a return. // // The initial state is `reading`. // // | state | eof | is_readable | has_errored | // |---------|-------|-------------|-------------| // | reading | false | false | false | // | framing | false | true | false | // | pausing | true | true | false | // | paused | true | false | false | // | errored | | | true | // `decode_eof` returns Err // ┌────────────────────────────────────────────────────────┐ // `decode_eof` returns │ │ // `Ok(Some)` │ │ // ┌─────┐ │ `decode_eof` returns After returning │ // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ // Pending read │ │ │ │ │ │ // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ // │ │ │ ┌──────┐ │ Pending │ │ // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ // └──┬─▲────┘ └─────┬──┬┘ │ │ // │ │ │ │ `decode` returns Err │ │ // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ // │ read returns Err │ // └────────────────────────────────────────────────────────────────────────────────────────────┘ loop { // Return `None` if we have encountered an error from the underlying decoder // See: https://github.com/tokio-rs/tokio/issues/3976 if state.has_errored { // preparing has_errored -> paused trace!("Returning None and setting paused"); state.is_readable = false; state.has_errored = false; return Poll::Ready(None); } // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", // i.e. it _might_ contain data consumable as a frame or closing frame. // Both signal that there is no such data by returning `None`. // // If `decode` couldn't read a frame and the upstream source has returned eof, // `decode_eof` will attempt to decode the remaining bytes as closing frames. // // If the underlying AsyncRead is resumable, we may continue after an EOF, // but must finish emitting all of it's associated `decode_eof` frames. // Furthermore, we don't want to emit any `decode_eof` frames on retried // reads after an EOF unless we've actually read more data. if state.is_readable { // pausing or framing if state.eof { // pausing let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { trace!("Got an error, going to errored state"); state.has_errored = true; err })?; if frame.is_none() { state.is_readable = false; // prepare pausing -> paused } // implicit pausing -> pausing or pausing -> paused return Poll::Ready(frame.map(Ok)); } // framing trace!("attempting to decode a frame"); if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { trace!("Got an error, going to errored state"); state.has_errored = true; op })? { trace!("frame decoded from buffer"); // implicit framing -> framing return Poll::Ready(Some(Ok(frame))); } // framing -> reading state.is_readable = false; } // reading or paused // If we can't build a frame yet, try to read more data and try again. // Make sure we've got room for at least one byte to read to ensure // that we don't get a spurious 0 that looks like EOF. state.buffer.reserve(1); let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( |err| { trace!("Got an error, going to errored state"); state.has_errored = true; err }, )? { Poll::Ready(ct) => ct, // implicit reading -> reading or implicit paused -> paused Poll::Pending => return Poll::Pending, }; if bytect == 0 { if state.eof { // We're already at an EOF, and since we've reached this path // we're also not readable. This implies that we've already finished // our `decode_eof` handling, so we can simply return `None`. // implicit paused -> paused return Poll::Ready(None); } // prepare reading -> paused state.eof = true; } else { // prepare paused -> framing or noop reading -> framing state.eof = false; } // paused -> framing or reading -> framing or reading -> pausing state.is_readable = true; } } } impl Sink for FramedImpl where T: AsyncWrite, U: Encoder, U::Error: From, W: BorrowMut, { type Error = U::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { self.as_mut().poll_flush(cx) } else { Poll::Ready(Ok(())) } } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { let pinned = self.project(); pinned .codec .encode(item, &mut pinned.state.borrow_mut().buffer)?; Ok(()) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use crate::util::poll_write_buf; trace!("flushing framed transport"); let mut pinned = self.project(); while !pinned.state.borrow_mut().buffer.is_empty() { let WriteFrame { buffer } = pinned.state.borrow_mut(); trace!("writing; remaining={}", buffer.len()); let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; if n == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, "failed to \ write frame to transport", ) .into())); } } // Try flushing the underlying IO ready!(pinned.inner.poll_flush(cx))?; trace!("framed transport flushed"); Poll::Ready(Ok(())) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; ready!(self.project().inner.poll_shutdown(cx))?; Poll::Ready(Ok(())) } } tokio-util-0.6.9/src/codec/framed_read.rs000064400000000000000000000120420072674642500164320ustar 00000000000000use crate::codec::framed_impl::{FramedImpl, ReadFrame}; use crate::codec::Decoder; use futures_core::Stream; use tokio::io::AsyncRead; use bytes::BytesMut; use futures_sink::Sink; use pin_project_lite::pin_project; use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// A [`Stream`] of messages decoded from an [`AsyncRead`]. /// /// [`Stream`]: futures_core::Stream /// [`AsyncRead`]: tokio::io::AsyncRead pub struct FramedRead { #[pin] inner: FramedImpl, } } // ===== impl FramedRead ===== impl FramedRead where T: AsyncRead, D: Decoder, { /// Creates a new `FramedRead` with the given `decoder`. pub fn new(inner: T, decoder: D) -> FramedRead { FramedRead { inner: FramedImpl { inner, codec: decoder, state: Default::default(), }, } } /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity` /// initial size. pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead { FramedRead { inner: FramedImpl { inner, codec: decoder, state: ReadFrame { eof: false, is_readable: false, buffer: BytesMut::with_capacity(capacity), has_errored: false, }, }, } } } impl FramedRead { /// Returns a reference to the underlying I/O stream wrapped by /// `FramedRead`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by /// `FramedRead`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { &mut self.inner.inner } /// Returns a pinned mutable reference to the underlying I/O stream wrapped by /// `FramedRead`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { self.project().inner.project().inner } /// Consumes the `FramedRead`, returning its underlying I/O stream. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { self.inner.inner } /// Returns a reference to the underlying decoder. pub fn decoder(&self) -> &D { &self.inner.codec } /// Returns a mutable reference to the underlying decoder. pub fn decoder_mut(&mut self) -> &mut D { &mut self.inner.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { &self.inner.state.buffer } /// Returns a mutable reference to the read buffer. pub fn read_buffer_mut(&mut self) -> &mut BytesMut { &mut self.inner.state.buffer } } // This impl just defers to the underlying FramedImpl impl Stream for FramedRead where T: AsyncRead, D: Decoder, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_next(cx) } } // This impl just defers to the underlying T: Sink impl Sink for FramedRead where T: Sink, { type Error = T::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { self.project().inner.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.project().inner.poll_close(cx) } } impl fmt::Debug for FramedRead where T: fmt::Debug, D: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedRead") .field("inner", &self.get_ref()) .field("decoder", &self.decoder()) .field("eof", &self.inner.state.eof) .field("is_readable", &self.inner.state.is_readable) .field("buffer", &self.read_buffer()) .finish() } } tokio-util-0.6.9/src/codec/framed_write.rs000064400000000000000000000104530072674642500166550ustar 00000000000000use crate::codec::encoder::Encoder; use crate::codec::framed_impl::{FramedImpl, WriteFrame}; use futures_core::Stream; use tokio::io::AsyncWrite; use bytes::BytesMut; use futures_sink::Sink; use pin_project_lite::pin_project; use std::fmt; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// A [`Sink`] of frames encoded to an `AsyncWrite`. /// /// [`Sink`]: futures_sink::Sink pub struct FramedWrite { #[pin] inner: FramedImpl, } } impl FramedWrite where T: AsyncWrite, { /// Creates a new `FramedWrite` with the given `encoder`. pub fn new(inner: T, encoder: E) -> FramedWrite { FramedWrite { inner: FramedImpl { inner, codec: encoder, state: WriteFrame::default(), }, } } } impl FramedWrite { /// Returns a reference to the underlying I/O stream wrapped by /// `FramedWrite`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_ref(&self) -> &T { &self.inner.inner } /// Returns a mutable reference to the underlying I/O stream wrapped by /// `FramedWrite`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_mut(&mut self) -> &mut T { &mut self.inner.inner } /// Returns a pinned mutable reference to the underlying I/O stream wrapped by /// `FramedWrite`. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { self.project().inner.project().inner } /// Consumes the `FramedWrite`, returning its underlying I/O stream. /// /// Note that care should be taken to not tamper with the underlying stream /// of data coming in as it may corrupt the stream of frames otherwise /// being worked with. pub fn into_inner(self) -> T { self.inner.inner } /// Returns a reference to the underlying encoder. pub fn encoder(&self) -> &E { &self.inner.codec } /// Returns a mutable reference to the underlying encoder. pub fn encoder_mut(&mut self) -> &mut E { &mut self.inner.codec } /// Returns a reference to the write buffer. pub fn write_buffer(&self) -> &BytesMut { &self.inner.state.buffer } /// Returns a mutable reference to the write buffer. pub fn write_buffer_mut(&mut self) -> &mut BytesMut { &mut self.inner.state.buffer } } // This impl just defers to the underlying FramedImpl impl Sink for FramedWrite where T: AsyncWrite, E: Encoder, E::Error: From, { type Error = E::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_ready(cx) } fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { self.project().inner.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_flush(cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_close(cx) } } // This impl just defers to the underlying T: Stream impl Stream for FramedWrite where T: Stream, { type Item = T::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.project().inner.poll_next(cx) } } impl fmt::Debug for FramedWrite where T: fmt::Debug, U: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FramedWrite") .field("inner", &self.get_ref()) .field("encoder", &self.encoder()) .field("buffer", &self.inner.state.buffer) .finish() } } tokio-util-0.6.9/src/codec/length_delimited.rs000064400000000000000000000735000072674642500175100ustar 00000000000000//! Frame a stream of bytes based on a length prefix //! //! Many protocols delimit their frames by prefacing frame data with a //! frame head that specifies the length of the frame. The //! `length_delimited` module provides utilities for handling the length //! based framing. This allows the consumer to work with entire frames //! without having to worry about buffering or other framing logic. //! //! # Getting started //! //! If implementing a protocol from scratch, using length delimited framing //! is an easy way to get started. [`LengthDelimitedCodec::new()`] will //! return a length delimited codec using default configuration values. //! This can then be used to construct a framer to adapt a full-duplex //! byte stream into a stream of frames. //! //! ``` //! use tokio::io::{AsyncRead, AsyncWrite}; //! use tokio_util::codec::{Framed, LengthDelimitedCodec}; //! //! fn bind_transport(io: T) //! -> Framed //! { //! Framed::new(io, LengthDelimitedCodec::new()) //! } //! # pub fn main() {} //! ``` //! //! The returned transport implements `Sink + Stream` for `BytesMut`. It //! encodes the frame with a big-endian `u32` header denoting the frame //! payload length: //! //! ```text //! +----------+--------------------------------+ //! | len: u32 | frame payload | //! +----------+--------------------------------+ //! ``` //! //! Specifically, given the following: //! //! ``` //! use tokio::io::{AsyncRead, AsyncWrite}; //! use tokio_util::codec::{Framed, LengthDelimitedCodec}; //! //! use futures::SinkExt; //! use bytes::Bytes; //! //! async fn write_frame(io: T) -> Result<(), Box> //! where //! T: AsyncRead + AsyncWrite + Unpin, //! { //! let mut transport = Framed::new(io, LengthDelimitedCodec::new()); //! let frame = Bytes::from("hello world"); //! //! transport.send(frame).await?; //! Ok(()) //! } //! ``` //! //! The encoded frame will look like this: //! //! ```text //! +---- len: u32 ----+---- data ----+ //! | \x00\x00\x00\x0b | hello world | //! +------------------+--------------+ //! ``` //! //! # Decoding //! //! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`], //! such that each yielded [`BytesMut`] value contains the contents of an //! entire frame. There are many configuration parameters enabling //! [`FramedRead`] to handle a wide range of protocols. Here are some //! examples that will cover the various options at a high level. //! //! ## Example 1 //! //! The following will parse a `u16` length field at offset 0, including the //! frame head in the yielded `BytesMut`. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value //! .length_field_length(2) //! .length_adjustment(0) // default value //! .num_skip(0) // Do not strip frame header //! .new_read(io); //! # } //! # pub fn main() {} //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT DECODED //! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ //! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world | //! +----------+---------------+ +----------+---------------+ //! ``` //! //! The value of the length field is 11 (`\x0B`) which represents the length //! of the payload, `hello world`. By default, [`FramedRead`] assumes that //! the length field represents the number of bytes that **follows** the //! length field. Thus, the entire frame has a length of 13: 2 bytes for the //! frame head + 11 bytes for the payload. //! //! ## Example 2 //! //! The following will parse a `u16` length field at offset 0, omitting the //! frame head in the yielded `BytesMut`. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value //! .length_field_length(2) //! .length_adjustment(0) // default value //! // `num_skip` is not needed, the default is to skip //! .new_read(io); //! # } //! # pub fn main() {} //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT DECODED //! +-- len ---+--- Payload ---+ +--- Payload ---+ //! | \x00\x0B | Hello world | --> | Hello world | //! +----------+---------------+ +---------------+ //! ``` //! //! This is similar to the first example, the only difference is that the //! frame head is **not** included in the yielded `BytesMut` value. //! //! ## Example 3 //! //! The following will parse a `u16` length field at offset 0, including the //! frame head in the yielded `BytesMut`. In this case, the length field //! **includes** the frame head length. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value //! .length_field_length(2) //! .length_adjustment(-2) // size of head //! .num_skip(0) //! .new_read(io); //! # } //! # pub fn main() {} //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT DECODED //! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ //! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world | //! +----------+---------------+ +----------+---------------+ //! ``` //! //! In most cases, the length field represents the length of the payload //! only, as shown in the previous examples. However, in some protocols the //! length field represents the length of the whole frame, including the //! head. In such cases, we specify a negative `length_adjustment` to adjust //! the value provided in the frame head to represent the payload length. //! //! ## Example 4 //! //! The following will parse a 3 byte length field at offset 0 in a 5 byte //! frame head, including the frame head in the yielded `BytesMut`. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value //! .length_field_length(3) //! .length_adjustment(2) // remaining head //! .num_skip(0) //! .new_read(io); //! # } //! # pub fn main() {} //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT //! +---- len -----+- head -+--- Payload ---+ //! | \x00\x00\x0B | \xCAFE | Hello world | //! +--------------+--------+---------------+ //! //! DECODED //! +---- len -----+- head -+--- Payload ---+ //! | \x00\x00\x0B | \xCAFE | Hello world | //! +--------------+--------+---------------+ //! ``` //! //! A more advanced example that shows a case where there is extra frame //! head data between the length field and the payload. In such cases, it is //! usually desirable to include the frame head as part of the yielded //! `BytesMut`. This lets consumers of the length delimited framer to //! process the frame head as needed. //! //! The positive `length_adjustment` value lets `FramedRead` factor in the //! additional head into the frame length calculation. //! //! ## Example 5 //! //! The following will parse a `u16` length field at offset 1 of a 4 byte //! frame head. The first byte and the length field will be omitted from the //! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be //! included. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(1) // length of hdr1 //! .length_field_length(2) //! .length_adjustment(1) // length of hdr2 //! .num_skip(3) // length of hdr1 + LEN //! .new_read(io); //! # } //! # pub fn main() {} //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT //! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ //! | \xCA | \x00\x0B | \xFE | Hello world | //! +--------+----------+--------+---------------+ //! //! DECODED //! +- hdr2 -+--- Payload ---+ //! | \xFE | Hello world | //! +--------+---------------+ //! ``` //! //! The length field is situated in the middle of the frame head. In this //! case, the first byte in the frame head could be a version or some other //! identifier that is not needed for processing. On the other hand, the //! second half of the head is needed. //! //! `length_field_offset` indicates how many bytes to skip before starting //! to read the length field. `length_adjustment` is the number of bytes to //! skip starting at the end of the length field. In this case, it is the //! second half of the head. //! //! ## Example 6 //! //! The following will parse a `u16` length field at offset 1 of a 4 byte //! frame head. The first byte and the length field will be omitted from the //! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be //! included. In this case, the length field **includes** the frame head //! length. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(1) // length of hdr1 //! .length_field_length(2) //! .length_adjustment(-3) // length of hdr1 + LEN, negative //! .num_skip(3) //! .new_read(io); //! # } //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT //! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ //! | \xCA | \x00\x0F | \xFE | Hello world | //! +--------+----------+--------+---------------+ //! //! DECODED //! +- hdr2 -+--- Payload ---+ //! | \xFE | Hello world | //! +--------+---------------+ //! ``` //! //! Similar to the example above, the difference is that the length field //! represents the length of the entire frame instead of just the payload. //! The length of `hdr1` and `len` must be counted in `length_adjustment`. //! Note that the length of `hdr2` does **not** need to be explicitly set //! anywhere because it already is factored into the total frame length that //! is read from the byte stream. //! //! ## Example 7 //! //! The following will parse a 3 byte length field at offset 0 in a 4 byte //! frame head, excluding the 4th byte from the yielded `BytesMut`. //! //! ``` //! # use tokio::io::AsyncRead; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value //! .length_field_length(3) //! .length_adjustment(0) // default value //! .num_skip(4) // skip the first 4 bytes //! .new_read(io); //! # } //! # pub fn main() {} //! ``` //! //! The following frame will be decoded as such: //! //! ```text //! INPUT DECODED //! +------- len ------+--- Payload ---+ +--- Payload ---+ //! | \x00\x00\x0B\xFF | Hello world | => | Hello world | //! +------------------+---------------+ +---------------+ //! ``` //! //! A simple example where there are unused bytes between the length field //! and the payload. //! //! # Encoding //! //! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`], //! such that each submitted [`BytesMut`] is prefaced by a length field. //! There are fewer configuration options than [`FramedRead`]. Given //! protocols that have more complex frame heads, an encoder should probably //! be written by hand using [`Encoder`]. //! //! Here is a simple example, given a `FramedWrite` with the following //! configuration: //! //! ``` //! # use tokio::io::AsyncWrite; //! # use tokio_util::codec::LengthDelimitedCodec; //! # fn write_frame(io: T) { //! # let _ = //! LengthDelimitedCodec::builder() //! .length_field_length(2) //! .new_write(io); //! # } //! # pub fn main() {} //! ``` //! //! A payload of `hello world` will be encoded as: //! //! ```text //! +- len: u16 -+---- data ----+ //! | \x00\x0b | hello world | //! +------------+--------------+ //! ``` //! //! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new //! [`FramedRead`]: struct@FramedRead //! [`FramedWrite`]: struct@FramedWrite //! [`AsyncRead`]: trait@tokio::io::AsyncRead //! [`AsyncWrite`]: trait@tokio::io::AsyncWrite //! [`Encoder`]: trait@Encoder //! [`BytesMut`]: bytes::BytesMut use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; use tokio::io::{AsyncRead, AsyncWrite}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::error::Error as StdError; use std::io::{self, Cursor}; use std::{cmp, fmt}; /// Configure length delimited `LengthDelimitedCodec`s. /// /// `Builder` enables constructing configured length delimited codecs. Note /// that not all configuration settings apply to both encoding and decoding. See /// the documentation for specific methods for more detail. #[derive(Debug, Clone, Copy)] pub struct Builder { // Maximum frame length max_frame_len: usize, // Number of bytes representing the field length length_field_len: usize, // Number of bytes in the header before the length field length_field_offset: usize, // Adjust the length specified in the header field by this amount length_adjustment: isize, // Total number of bytes to skip before reading the payload, if not set, // `length_field_len + length_field_offset` num_skip: Option, // Length field byte order (little or big endian) length_field_is_big_endian: bool, } /// An error when the number of bytes read is more than max frame length. pub struct LengthDelimitedCodecError { _priv: (), } /// A codec for frames delimited by a frame head specifying their lengths. /// /// This allows the consumer to work with entire frames without having to worry /// about buffering or other framing logic. /// /// See [module level] documentation for more detail. /// /// [module level]: index.html #[derive(Debug, Clone)] pub struct LengthDelimitedCodec { // Configuration values builder: Builder, // Read state state: DecodeState, } #[derive(Debug, Clone, Copy)] enum DecodeState { Head, Data(usize), } // ===== impl LengthDelimitedCodec ====== impl LengthDelimitedCodec { /// Creates a new `LengthDelimitedCodec` with the default configuration values. pub fn new() -> Self { Self { builder: Builder::new(), state: DecodeState::Head, } } /// Creates a new length delimited codec builder with default configuration /// values. pub fn builder() -> Builder { Builder::new() } /// Returns the current max frame setting /// /// This is the largest size this codec will accept from the wire. Larger /// frames will be rejected. pub fn max_frame_length(&self) -> usize { self.builder.max_frame_len } /// Updates the max frame setting. /// /// The change takes effect the next time a frame is decoded. In other /// words, if a frame is currently in process of being decoded with a frame /// size greater than `val` but less than the max frame length in effect /// before calling this function, then the frame will be allowed. pub fn set_max_frame_length(&mut self, val: usize) { self.builder.max_frame_length(val); } fn decode_head(&mut self, src: &mut BytesMut) -> io::Result> { let head_len = self.builder.num_head_bytes(); let field_len = self.builder.length_field_len; if src.len() < head_len { // Not enough data return Ok(None); } let n = { let mut src = Cursor::new(&mut *src); // Skip the required bytes src.advance(self.builder.length_field_offset); // match endianness let n = if self.builder.length_field_is_big_endian { src.get_uint(field_len) } else { src.get_uint_le(field_len) }; if n > self.builder.max_frame_len as u64 { return Err(io::Error::new( io::ErrorKind::InvalidData, LengthDelimitedCodecError { _priv: () }, )); } // The check above ensures there is no overflow let n = n as usize; // Adjust `n` with bounds checking let n = if self.builder.length_adjustment < 0 { n.checked_sub(-self.builder.length_adjustment as usize) } else { n.checked_add(self.builder.length_adjustment as usize) }; // Error handling match n { Some(n) => n, None => { return Err(io::Error::new( io::ErrorKind::InvalidInput, "provided length would overflow after adjustment", )); } } }; let num_skip = self.builder.get_num_skip(); if num_skip > 0 { src.advance(num_skip); } // Ensure that the buffer has enough space to read the incoming // payload src.reserve(n); Ok(Some(n)) } fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option { // At this point, the buffer has already had the required capacity // reserved. All there is to do is read. if src.len() < n { return None; } Some(src.split_to(n)) } } impl Decoder for LengthDelimitedCodec { type Item = BytesMut; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> io::Result> { let n = match self.state { DecodeState::Head => match self.decode_head(src)? { Some(n) => { self.state = DecodeState::Data(n); n } None => return Ok(None), }, DecodeState::Data(n) => n, }; match self.decode_data(n, src) { Some(data) => { // Update the decode state self.state = DecodeState::Head; // Make sure the buffer has enough space to read the next head src.reserve(self.builder.num_head_bytes()); Ok(Some(data)) } None => Ok(None), } } } impl Encoder for LengthDelimitedCodec { type Error = io::Error; fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> { let n = data.len(); if n > self.builder.max_frame_len { return Err(io::Error::new( io::ErrorKind::InvalidInput, LengthDelimitedCodecError { _priv: () }, )); } // Adjust `n` with bounds checking let n = if self.builder.length_adjustment < 0 { n.checked_add(-self.builder.length_adjustment as usize) } else { n.checked_sub(self.builder.length_adjustment as usize) }; let n = n.ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "provided length would overflow after adjustment", ) })?; // Reserve capacity in the destination buffer to fit the frame and // length field (plus adjustment). dst.reserve(self.builder.length_field_len + n); if self.builder.length_field_is_big_endian { dst.put_uint(n as u64, self.builder.length_field_len); } else { dst.put_uint_le(n as u64, self.builder.length_field_len); } // Write the frame to the buffer dst.extend_from_slice(&data[..]); Ok(()) } } impl Default for LengthDelimitedCodec { fn default() -> Self { Self::new() } } // ===== impl Builder ===== impl Builder { /// Creates a new length delimited codec builder with default configuration /// values. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_offset(0) /// .length_field_length(2) /// .length_adjustment(0) /// .num_skip(0) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn new() -> Builder { Builder { // Default max frame length of 8MB max_frame_len: 8 * 1_024 * 1_024, // Default byte length of 4 length_field_len: 4, // Default to the header field being at the start of the header. length_field_offset: 0, length_adjustment: 0, // Total number of bytes to skip before reading the payload, if not set, // `length_field_len + length_field_offset` num_skip: None, // Default to reading the length field in network (big) endian. length_field_is_big_endian: true, } } /// Read the length field as a big endian integer /// /// This is the default setting. /// /// This configuration option applies to both encoding and decoding. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .big_endian() /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn big_endian(&mut self) -> &mut Self { self.length_field_is_big_endian = true; self } /// Read the length field as a little endian integer /// /// The default setting is big endian. /// /// This configuration option applies to both encoding and decoding. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .little_endian() /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn little_endian(&mut self) -> &mut Self { self.length_field_is_big_endian = false; self } /// Read the length field as a native endian integer /// /// The default setting is big endian. /// /// This configuration option applies to both encoding and decoding. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .native_endian() /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn native_endian(&mut self) -> &mut Self { if cfg!(target_endian = "big") { self.big_endian() } else { self.little_endian() } } /// Sets the max frame length /// /// This configuration option applies to both encoding and decoding. The /// default value is 8MB. /// /// When decoding, the length field read from the byte stream is checked /// against this setting **before** any adjustments are applied. When /// encoding, the length of the submitted payload is checked against this /// setting. /// /// When frames exceed the max length, an `io::Error` with the custom value /// of the `LengthDelimitedCodecError` type will be returned. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .max_frame_length(8 * 1024) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn max_frame_length(&mut self, val: usize) -> &mut Self { self.max_frame_len = val; self } /// Sets the number of bytes used to represent the length field /// /// The default value is `4`. The max value is `8`. /// /// This configuration option applies to both encoding and decoding. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_length(4) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn length_field_length(&mut self, val: usize) -> &mut Self { assert!(val > 0 && val <= 8, "invalid length field length"); self.length_field_len = val; self } /// Sets the number of bytes in the header before the length field /// /// This configuration option only applies to decoding. /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_offset(1) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn length_field_offset(&mut self, val: usize) -> &mut Self { self.length_field_offset = val; self } /// Delta between the payload length specified in the header and the real /// payload length /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_adjustment(-2) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn length_adjustment(&mut self, val: isize) -> &mut Self { self.length_adjustment = val; self } /// Sets the number of bytes to skip before reading the payload /// /// Default value is `length_field_len + length_field_offset` /// /// This configuration option only applies to decoding /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .num_skip(4) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn num_skip(&mut self, val: usize) -> &mut Self { self.num_skip = Some(val); self } /// Create a configured length delimited `LengthDelimitedCodec` /// /// # Examples /// /// ``` /// use tokio_util::codec::LengthDelimitedCodec; /// # pub fn main() { /// LengthDelimitedCodec::builder() /// .length_field_offset(0) /// .length_field_length(2) /// .length_adjustment(0) /// .num_skip(0) /// .new_codec(); /// # } /// ``` pub fn new_codec(&self) -> LengthDelimitedCodec { LengthDelimitedCodec { builder: *self, state: DecodeState::Head, } } /// Create a configured length delimited `FramedRead` /// /// # Examples /// /// ``` /// # use tokio::io::AsyncRead; /// use tokio_util::codec::LengthDelimitedCodec; /// /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_offset(0) /// .length_field_length(2) /// .length_adjustment(0) /// .num_skip(0) /// .new_read(io); /// # } /// # pub fn main() {} /// ``` pub fn new_read(&self, upstream: T) -> FramedRead where T: AsyncRead, { FramedRead::new(upstream, self.new_codec()) } /// Create a configured length delimited `FramedWrite` /// /// # Examples /// /// ``` /// # use tokio::io::AsyncWrite; /// # use tokio_util::codec::LengthDelimitedCodec; /// # fn write_frame(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_length(2) /// .new_write(io); /// # } /// # pub fn main() {} /// ``` pub fn new_write(&self, inner: T) -> FramedWrite where T: AsyncWrite, { FramedWrite::new(inner, self.new_codec()) } /// Create a configured length delimited `Framed` /// /// # Examples /// /// ``` /// # use tokio::io::{AsyncRead, AsyncWrite}; /// # use tokio_util::codec::LengthDelimitedCodec; /// # fn write_frame(io: T) { /// # let _ = /// LengthDelimitedCodec::builder() /// .length_field_length(2) /// .new_framed(io); /// # } /// # pub fn main() {} /// ``` pub fn new_framed(&self, inner: T) -> Framed where T: AsyncRead + AsyncWrite, { Framed::new(inner, self.new_codec()) } fn num_head_bytes(&self) -> usize { let num = self.length_field_offset + self.length_field_len; cmp::max(num, self.num_skip.unwrap_or(0)) } fn get_num_skip(&self) -> usize { self.num_skip .unwrap_or(self.length_field_offset + self.length_field_len) } } impl Default for Builder { fn default() -> Self { Self::new() } } // ===== impl LengthDelimitedCodecError ===== impl fmt::Debug for LengthDelimitedCodecError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LengthDelimitedCodecError").finish() } } impl fmt::Display for LengthDelimitedCodecError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("frame size too big") } } impl StdError for LengthDelimitedCodecError {} tokio-util-0.6.9/src/codec/lines_codec.rs000064400000000000000000000176440072674642500164650ustar 00000000000000use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use bytes::{Buf, BufMut, BytesMut}; use std::{cmp, fmt, io, str, usize}; /// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines. /// /// [`Decoder`]: crate::codec::Decoder /// [`Encoder`]: crate::codec::Encoder #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct LinesCodec { // Stored index of the next index to examine for a `\n` character. // This is used to optimize searching. // For example, if `decode` was called with `abc`, it would hold `3`, // because that is the next index to examine. // The next time `decode` is called with `abcde\n`, the method will // only look at `de\n` before returning. next_index: usize, /// The maximum length for a given line. If `usize::MAX`, lines will be /// read until a `\n` character is reached. max_length: usize, /// Are we currently discarding the remainder of a line which was over /// the length limit? is_discarding: bool, } impl LinesCodec { /// Returns a `LinesCodec` for splitting up data into lines. /// /// # Note /// /// The returned `LinesCodec` will not have an upper bound on the length /// of a buffered line. See the documentation for [`new_with_max_length`] /// for information on why this could be a potential security risk. /// /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length() pub fn new() -> LinesCodec { LinesCodec { next_index: 0, max_length: usize::MAX, is_discarding: false, } } /// Returns a `LinesCodec` with a maximum line length limit. /// /// If this is set, calls to `LinesCodec::decode` will return a /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls /// will discard up to `limit` bytes from that line until a newline /// character is reached, returning `None` until the line over the limit /// has been fully discarded. After that point, calls to `decode` will /// function as normal. /// /// # Note /// /// Setting a length limit is highly recommended for any `LinesCodec` which /// will be exposed to untrusted input. Otherwise, the size of the buffer /// that holds the line currently being read is unbounded. An attacker could /// exploit this unbounded buffer by sending an unbounded amount of input /// without any `\n` characters, causing unbounded memory consumption. /// /// [`LinesCodecError`]: crate::codec::LinesCodecError pub fn new_with_max_length(max_length: usize) -> Self { LinesCodec { max_length, ..LinesCodec::new() } } /// Returns the maximum line length when decoding. /// /// ``` /// use std::usize; /// use tokio_util::codec::LinesCodec; /// /// let codec = LinesCodec::new(); /// assert_eq!(codec.max_length(), usize::MAX); /// ``` /// ``` /// use tokio_util::codec::LinesCodec; /// /// let codec = LinesCodec::new_with_max_length(256); /// assert_eq!(codec.max_length(), 256); /// ``` pub fn max_length(&self) -> usize { self.max_length } } fn utf8(buf: &[u8]) -> Result<&str, io::Error> { str::from_utf8(buf) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8")) } fn without_carriage_return(s: &[u8]) -> &[u8] { if let Some(&b'\r') = s.last() { &s[..s.len() - 1] } else { s } } impl Decoder for LinesCodec { type Item = String; type Error = LinesCodecError; fn decode(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { loop { // Determine how far into the buffer we'll search for a newline. If // there's no max_length set, we'll read to the end of the buffer. let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); let newline_offset = buf[self.next_index..read_to] .iter() .position(|b| *b == b'\n'); match (self.is_discarding, newline_offset) { (true, Some(offset)) => { // If we found a newline, discard up to that offset and // then stop discarding. On the next iteration, we'll try // to read a line normally. buf.advance(offset + self.next_index + 1); self.is_discarding = false; self.next_index = 0; } (true, None) => { // Otherwise, we didn't find a newline, so we'll discard // everything we read. On the next iteration, we'll continue // discarding up to max_len bytes unless we find a newline. buf.advance(read_to); self.next_index = 0; if buf.is_empty() { return Ok(None); } } (false, Some(offset)) => { // Found a line! let newline_index = offset + self.next_index; self.next_index = 0; let line = buf.split_to(newline_index + 1); let line = &line[..line.len() - 1]; let line = without_carriage_return(line); let line = utf8(line)?; return Ok(Some(line.to_string())); } (false, None) if buf.len() > self.max_length => { // Reached the maximum length without finding a // newline, return an error and start discarding on the // next call. self.is_discarding = true; return Err(LinesCodecError::MaxLineLengthExceeded); } (false, None) => { // We didn't find a line or reach the length limit, so the next // call will resume searching at the current offset. self.next_index = read_to; return Ok(None); } } } } fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { Ok(match self.decode(buf)? { Some(frame) => Some(frame), None => { // No terminating newline - return remaining data, if any if buf.is_empty() || buf == &b"\r"[..] { None } else { let line = buf.split_to(buf.len()); let line = without_carriage_return(&line); let line = utf8(line)?; self.next_index = 0; Some(line.to_string()) } } }) } } impl Encoder for LinesCodec where T: AsRef, { type Error = LinesCodecError; fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> { let line = line.as_ref(); buf.reserve(line.len() + 1); buf.put(line.as_bytes()); buf.put_u8(b'\n'); Ok(()) } } impl Default for LinesCodec { fn default() -> Self { Self::new() } } /// An error occurred while encoding or decoding a line. #[derive(Debug)] pub enum LinesCodecError { /// The maximum line length was exceeded. MaxLineLengthExceeded, /// An IO error occurred. Io(io::Error), } impl fmt::Display for LinesCodecError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"), LinesCodecError::Io(e) => write!(f, "{}", e), } } } impl From for LinesCodecError { fn from(e: io::Error) -> LinesCodecError { LinesCodecError::Io(e) } } impl std::error::Error for LinesCodecError {} tokio-util-0.6.9/src/codec/mod.rs000064400000000000000000000253270072674642500147720ustar 00000000000000//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink //! //! Raw I/O objects work with byte sequences, but higher-level code usually //! wants to batch these into meaningful chunks, called "frames". //! //! This module contains adapters to go from streams of bytes, [`AsyncRead`] and //! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. //! Framed streams are also known as transports. //! //! # The Decoder trait //! //! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an //! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify //! how sequences of bytes are turned into a sequence of frames, and to //! determine where the boundaries between frames are. The job of the //! `FramedRead` is to repeatedly switch between reading more data from the IO //! resource, and asking the decoder whether we have received enough data to //! decode another frame of data. //! //! The main method on the `Decoder` trait is the [`decode`] method. This method //! takes as argument the data that has been read so far, and when it is called, //! it will be in one of the following situations: //! //! 1. The buffer contains less than a full frame. //! 2. The buffer contains exactly a full frame. //! 3. The buffer contains more than a full frame. //! //! In the first situation, the decoder should return `Ok(None)`. //! //! In the second situation, the decoder should clear the provided buffer and //! return `Ok(Some(the_decoded_frame))`. //! //! In the third situation, the decoder should use a method such as [`split_to`] //! or [`advance`] to modify the buffer such that the frame is removed from the //! buffer, but any data in the buffer after that frame should still remain in //! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in //! this case. //! //! Finally the decoder may return an error if the data is invalid in some way. //! The decoder should _not_ return an error just because it has yet to receive //! a full frame. //! //! It is guaranteed that, from one call to `decode` to another, the provided //! buffer will contain the exact same data as before, except that if more data //! has arrived through the IO resource, that data will have been appended to //! the buffer. This means that reading frames from a `FramedRead` is //! essentially equivalent to the following loop: //! //! ```no_run //! use tokio::io::AsyncReadExt; //! # // This uses async_stream to create an example that compiles. //! # fn foo() -> impl futures_core::Stream> { async_stream::try_stream! { //! # use tokio_util::codec::Decoder; //! # let mut decoder = tokio_util::codec::BytesCodec::new(); //! # let io_resource = &mut &[0u8, 1, 2, 3][..]; //! //! let mut buf = bytes::BytesMut::new(); //! loop { //! // The read_buf call will append to buf rather than overwrite existing data. //! let len = io_resource.read_buf(&mut buf).await?; //! //! if len == 0 { //! while let Some(frame) = decoder.decode_eof(&mut buf)? { //! yield frame; //! } //! break; //! } //! //! while let Some(frame) = decoder.decode(&mut buf)? { //! yield frame; //! } //! } //! # }} //! ``` //! The example above uses `yield` whenever the `Stream` produces an item. //! //! ## Example decoder //! //! As an example, consider a protocol that can be used to send strings where //! each frame is a four byte integer that contains the length of the frame, //! followed by that many bytes of string data. The decoder fails with an error //! if the string data is not valid utf-8 or too long. //! //! Such a decoder can be written like this: //! ``` //! use tokio_util::codec::Decoder; //! use bytes::{BytesMut, Buf}; //! //! struct MyStringDecoder {} //! //! const MAX: usize = 8 * 1024 * 1024; //! //! impl Decoder for MyStringDecoder { //! type Item = String; //! type Error = std::io::Error; //! //! fn decode( //! &mut self, //! src: &mut BytesMut //! ) -> Result, Self::Error> { //! if src.len() < 4 { //! // Not enough data to read length marker. //! return Ok(None); //! } //! //! // Read length marker. //! let mut length_bytes = [0u8; 4]; //! length_bytes.copy_from_slice(&src[..4]); //! let length = u32::from_le_bytes(length_bytes) as usize; //! //! // Check that the length is not too large to avoid a denial of //! // service attack where the server runs out of memory. //! if length > MAX { //! return Err(std::io::Error::new( //! std::io::ErrorKind::InvalidData, //! format!("Frame of length {} is too large.", length) //! )); //! } //! //! if src.len() < 4 + length { //! // The full string has not yet arrived. //! // //! // We reserve more space in the buffer. This is not strictly //! // necessary, but is a good idea performance-wise. //! src.reserve(4 + length - src.len()); //! //! // We inform the Framed that we need more bytes to form the next //! // frame. //! return Ok(None); //! } //! //! // Use advance to modify src such that it no longer contains //! // this frame. //! let data = src[4..4 + length].to_vec(); //! src.advance(4 + length); //! //! // Convert the data to a string, or fail if it is not valid utf-8. //! match String::from_utf8(data) { //! Ok(string) => Ok(Some(string)), //! Err(utf8_error) => { //! Err(std::io::Error::new( //! std::io::ErrorKind::InvalidData, //! utf8_error.utf8_error(), //! )) //! }, //! } //! } //! } //! ``` //! //! # The Encoder trait //! //! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn //! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to //! specify how frames are turned into a sequences of bytes. The job of the //! `FramedWrite` is to take the resulting sequence of bytes and write it to the //! IO resource. //! //! The main method on the `Encoder` trait is the [`encode`] method. This method //! takes an item that is being written, and a buffer to write the item to. The //! buffer may already contain data, and in this case, the encoder should append //! the new frame the to buffer rather than overwrite the existing data. //! //! It is guaranteed that, from one call to `encode` to another, the provided //! buffer will contain the exact same data as before, except that some of the //! data may have been removed from the front of the buffer. Writing to a //! `FramedWrite` is essentially equivalent to the following loop: //! //! ```no_run //! use tokio::io::AsyncWriteExt; //! use bytes::Buf; // for advance //! # use tokio_util::codec::Encoder; //! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() } //! # async fn no_more_frames() { } //! # #[tokio::main] async fn main() -> std::io::Result<()> { //! # let mut io_resource = tokio::io::sink(); //! # let mut encoder = tokio_util::codec::BytesCodec::new(); //! //! const MAX: usize = 8192; //! //! let mut buf = bytes::BytesMut::new(); //! loop { //! tokio::select! { //! num_written = io_resource.write(&buf), if !buf.is_empty() => { //! buf.advance(num_written?); //! }, //! frame = next_frame(), if buf.len() < MAX => { //! encoder.encode(frame, &mut buf)?; //! }, //! _ = no_more_frames() => { //! io_resource.write_all(&buf).await?; //! io_resource.shutdown().await?; //! return Ok(()); //! }, //! } //! } //! # } //! ``` //! Here the `next_frame` method corresponds to any frames you write to the //! `FramedWrite`. The `no_more_frames` method corresponds to closing the //! `FramedWrite` with [`SinkExt::close`]. //! //! ## Example encoder //! //! As an example, consider a protocol that can be used to send strings where //! each frame is a four byte integer that contains the length of the frame, //! followed by that many bytes of string data. The encoder will fail if the //! string is too long. //! //! Such an encoder can be written like this: //! ``` //! use tokio_util::codec::Encoder; //! use bytes::BytesMut; //! //! struct MyStringEncoder {} //! //! const MAX: usize = 8 * 1024 * 1024; //! //! impl Encoder for MyStringEncoder { //! type Error = std::io::Error; //! //! fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> { //! // Don't send a string if it is longer than the other end will //! // accept. //! if item.len() > MAX { //! return Err(std::io::Error::new( //! std::io::ErrorKind::InvalidData, //! format!("Frame of length {} is too large.", item.len()) //! )); //! } //! //! // Convert the length into a byte array. //! // The cast to u32 cannot overflow due to the length check above. //! let len_slice = u32::to_le_bytes(item.len() as u32); //! //! // Reserve space in the buffer. //! dst.reserve(4 + item.len()); //! //! // Write the length and string to the buffer. //! dst.extend_from_slice(&len_slice); //! dst.extend_from_slice(item.as_bytes()); //! Ok(()) //! } //! } //! ``` //! //! [`AsyncRead`]: tokio::io::AsyncRead //! [`AsyncWrite`]: tokio::io::AsyncWrite //! [`Stream`]: futures_core::Stream //! [`Sink`]: futures_sink::Sink //! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close //! [`FramedRead`]: struct@crate::codec::FramedRead //! [`FramedWrite`]: struct@crate::codec::FramedWrite //! [`Framed`]: struct@crate::codec::Framed //! [`Decoder`]: trait@crate::codec::Decoder //! [`decode`]: fn@crate::codec::Decoder::decode //! [`encode`]: fn@crate::codec::Encoder::encode //! [`split_to`]: fn@bytes::BytesMut::split_to //! [`advance`]: fn@bytes::Buf::advance mod bytes_codec; pub use self::bytes_codec::BytesCodec; mod decoder; pub use self::decoder::Decoder; mod encoder; pub use self::encoder::Encoder; mod framed_impl; #[allow(unused_imports)] pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; mod framed; pub use self::framed::{Framed, FramedParts}; mod framed_read; pub use self::framed_read::FramedRead; mod framed_write; pub use self::framed_write::FramedWrite; pub mod length_delimited; pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError}; mod lines_codec; pub use self::lines_codec::{LinesCodec, LinesCodecError}; mod any_delimiter_codec; pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError}; tokio-util-0.6.9/src/compat.rs000064400000000000000000000176110072674642500144160ustar 00000000000000//! Compatibility between the `tokio::io` and `futures-io` versions of the //! `AsyncRead` and `AsyncWrite` traits. use futures_core::ready; use pin_project_lite::pin_project; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// A compatibility layer that allows conversion between the /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits. #[derive(Copy, Clone, Debug)] pub struct Compat { #[pin] inner: T, seek_pos: Option, } } /// Extension trait that allows converting a type implementing /// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`. pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead { /// Wraps `self` with a compatibility layer that implements /// `tokio_io::AsyncRead`. fn compat(self) -> Compat where Self: Sized, { Compat::new(self) } } impl FuturesAsyncReadCompatExt for T {} /// Extension trait that allows converting a type implementing /// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`. pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite { /// Wraps `self` with a compatibility layer that implements /// `tokio::io::AsyncWrite`. fn compat_write(self) -> Compat where Self: Sized, { Compat::new(self) } } impl FuturesAsyncWriteCompatExt for T {} /// Extension trait that allows converting a type implementing /// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`. pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead { /// Wraps `self` with a compatibility layer that implements /// `futures_io::AsyncRead`. fn compat(self) -> Compat where Self: Sized, { Compat::new(self) } } impl TokioAsyncReadCompatExt for T {} /// Extension trait that allows converting a type implementing /// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`. pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite { /// Wraps `self` with a compatibility layer that implements /// `futures_io::AsyncWrite`. fn compat_write(self) -> Compat where Self: Sized, { Compat::new(self) } } impl TokioAsyncWriteCompatExt for T {} // === impl Compat === impl Compat { fn new(inner: T) -> Self { Self { inner, seek_pos: None, } } /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object /// contained within. pub fn get_ref(&self) -> &T { &self.inner } /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object /// contained within. pub fn get_mut(&mut self) -> &mut T { &mut self.inner } /// Returns the wrapped item. pub fn into_inner(self) -> T { self.inner } } impl tokio::io::AsyncRead for Compat where T: futures_io::AsyncRead, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { // We can't trust the inner type to not peak at the bytes, // so we must defensively initialize the buffer. let slice = buf.initialize_unfilled(); let n = ready!(futures_io::AsyncRead::poll_read( self.project().inner, cx, slice ))?; buf.advance(n); Poll::Ready(Ok(())) } } impl futures_io::AsyncRead for Compat where T: tokio::io::AsyncRead, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, slice: &mut [u8], ) -> Poll> { let mut buf = tokio::io::ReadBuf::new(slice); ready!(tokio::io::AsyncRead::poll_read( self.project().inner, cx, &mut buf ))?; Poll::Ready(Ok(buf.filled().len())) } } impl tokio::io::AsyncBufRead for Compat where T: futures_io::AsyncBufRead, { fn poll_fill_buf<'a>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll> { futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) } fn consume(self: Pin<&mut Self>, amt: usize) { futures_io::AsyncBufRead::consume(self.project().inner, amt) } } impl futures_io::AsyncBufRead for Compat where T: tokio::io::AsyncBufRead, { fn poll_fill_buf<'a>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll> { tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) } fn consume(self: Pin<&mut Self>, amt: usize) { tokio::io::AsyncBufRead::consume(self.project().inner, amt) } } impl tokio::io::AsyncWrite for Compat where T: futures_io::AsyncWrite, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { futures_io::AsyncWrite::poll_flush(self.project().inner, cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { futures_io::AsyncWrite::poll_close(self.project().inner, cx) } } impl futures_io::AsyncWrite for Compat where T: tokio::io::AsyncWrite, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) } } impl futures_io::AsyncSeek for Compat { fn poll_seek( mut self: Pin<&mut Self>, cx: &mut Context<'_>, pos: io::SeekFrom, ) -> Poll> { if self.seek_pos != Some(pos) { self.as_mut().project().inner.start_seek(pos)?; *self.as_mut().project().seek_pos = Some(pos); } let res = ready!(self.as_mut().project().inner.poll_complete(cx)); *self.as_mut().project().seek_pos = None; Poll::Ready(res.map(|p| p as u64)) } } impl tokio::io::AsyncSeek for Compat { fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> { *self.as_mut().project().seek_pos = Some(pos); Ok(()) } fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let pos = match self.seek_pos { None => { // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek. // We don't have to guarantee that the value returned by // poll_complete called without start_seek is correct, // so we'll return 0. return Poll::Ready(Ok(0)); } Some(pos) => pos, }; let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos)); *self.as_mut().project().seek_pos = None; Poll::Ready(res.map(|p| p as u64)) } } #[cfg(unix)] impl std::os::unix::io::AsRawFd for Compat { fn as_raw_fd(&self) -> std::os::unix::io::RawFd { self.inner.as_raw_fd() } } #[cfg(windows)] impl std::os::windows::io::AsRawHandle for Compat { fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { self.inner.as_raw_handle() } } tokio-util-0.6.9/src/context.rs000064400000000000000000000143410072674642500146140ustar 00000000000000//! Tokio context aware futures utilities. //! //! This module includes utilities around integrating tokio with other runtimes //! by allowing the context to be attached to futures. This allows spawning //! futures on other executors while still using tokio to drive them. This //! can be useful if you need to use a tokio based library in an executor/runtime //! that does not provide a tokio context. use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tokio::runtime::{Handle, Runtime}; pin_project! { /// `TokioContext` allows running futures that must be inside Tokio's /// context on a non-Tokio runtime. /// /// It contains a [`Handle`] to the runtime. A handle to the runtime can be /// obtain by calling the [`Runtime::handle()`] method. /// /// Note that the `TokioContext` wrapper only works if the `Runtime` it is /// connected to has not yet been destroyed. You must keep the `Runtime` /// alive until the future has finished executing. /// /// **Warning:** If `TokioContext` is used together with a [current thread] /// runtime, that runtime must be inside a call to `block_on` for the /// wrapped future to work. For this reason, it is recommended to use a /// [multi thread] runtime, even if you configure it to only spawn one /// worker thread. /// /// # Examples /// /// This example creates two runtimes, but only [enables time] on one of /// them. It then uses the context of the runtime with the timer enabled to /// execute a [`sleep`] future on the runtime with timing disabled. /// ``` /// use tokio::time::{sleep, Duration}; /// use tokio_util::context::RuntimeExt; /// /// // This runtime has timers enabled. /// let rt = tokio::runtime::Builder::new_multi_thread() /// .enable_all() /// .build() /// .unwrap(); /// /// // This runtime has timers disabled. /// let rt2 = tokio::runtime::Builder::new_multi_thread() /// .build() /// .unwrap(); /// /// // Wrap the sleep future in the context of rt. /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); /// /// // Execute the future on rt2. /// rt2.block_on(fut); /// ``` /// /// [`Handle`]: struct@tokio::runtime::Handle /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle /// [`RuntimeExt`]: trait@crate::context::RuntimeExt /// [`new_static`]: fn@Self::new_static /// [`sleep`]: fn@tokio::time::sleep /// [current thread]: fn@tokio::runtime::Builder::new_current_thread /// [enables time]: fn@tokio::runtime::Builder::enable_time /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread pub struct TokioContext { #[pin] inner: F, handle: Handle, } } impl TokioContext { /// Associate the provided future with the context of the runtime behind /// the provided `Handle`. /// /// This constructor uses a `'static` lifetime to opt-out of checking that /// the runtime still exists. /// /// # Examples /// /// This is the same as the example above, but uses the `new` constructor /// rather than [`RuntimeExt::wrap`]. /// /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap /// /// ``` /// use tokio::time::{sleep, Duration}; /// use tokio_util::context::TokioContext; /// /// // This runtime has timers enabled. /// let rt = tokio::runtime::Builder::new_multi_thread() /// .enable_all() /// .build() /// .unwrap(); /// /// // This runtime has timers disabled. /// let rt2 = tokio::runtime::Builder::new_multi_thread() /// .build() /// .unwrap(); /// /// let fut = TokioContext::new( /// async { sleep(Duration::from_millis(2)).await }, /// rt.handle().clone(), /// ); /// /// // Execute the future on rt2. /// rt2.block_on(fut); /// ``` pub fn new(future: F, handle: Handle) -> TokioContext { TokioContext { inner: future, handle, } } /// Obtain a reference to the handle inside this `TokioContext`. pub fn handle(&self) -> &Handle { &self.handle } /// Remove the association between the Tokio runtime and the wrapped future. pub fn into_inner(self) -> F { self.inner } } impl Future for TokioContext { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let me = self.project(); let handle = me.handle; let fut = me.inner; let _enter = handle.enter(); fut.poll(cx) } } /// Extension trait that simplifies bundling a `Handle` with a `Future`. pub trait RuntimeExt { /// Create a [`TokioContext`] that wraps the provided future and runs it in /// this runtime's context. /// /// # Examples /// /// This example creates two runtimes, but only [enables time] on one of /// them. It then uses the context of the runtime with the timer enabled to /// execute a [`sleep`] future on the runtime with timing disabled. /// /// ``` /// use tokio::time::{sleep, Duration}; /// use tokio_util::context::RuntimeExt; /// /// // This runtime has timers enabled. /// let rt = tokio::runtime::Builder::new_multi_thread() /// .enable_all() /// .build() /// .unwrap(); /// /// // This runtime has timers disabled. /// let rt2 = tokio::runtime::Builder::new_multi_thread() /// .build() /// .unwrap(); /// /// // Wrap the sleep future in the context of rt. /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); /// /// // Execute the future on rt2. /// rt2.block_on(fut); /// ``` /// /// [`TokioContext`]: struct@crate::context::TokioContext /// [`sleep`]: fn@tokio::time::sleep /// [enables time]: fn@tokio::runtime::Builder::enable_time fn wrap(&self, fut: F) -> TokioContext; } impl RuntimeExt for Runtime { fn wrap(&self, fut: F) -> TokioContext { TokioContext { inner: fut, handle: self.handle().clone(), } } } tokio-util-0.6.9/src/either.rs000064400000000000000000000122440072674642500144100ustar 00000000000000//! Module defining an Either type. use std::{ future::Future, io::SeekFrom, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result}; /// Combines two different futures, streams, or sinks having the same associated types into a single type. /// /// This type implements common asynchronous traits such as [`Future`] and those in Tokio. /// /// [`Future`]: std::future::Future /// /// # Example /// /// The following code will not work: /// /// ```compile_fail /// # fn some_condition() -> bool { true } /// # async fn some_async_function() -> u32 { 10 } /// # async fn other_async_function() -> u32 { 20 } /// #[tokio::main] /// async fn main() { /// let result = if some_condition() { /// some_async_function() /// } else { /// other_async_function() // <- Will print: "`if` and `else` have incompatible types" /// }; /// /// println!("Result is {}", result.await); /// } /// ``` /// // This is because although the output types for both futures is the same, the exact future // types are different, but the compiler must be able to choose a single type for the // `result` variable. /// /// When the output type is the same, we can wrap each future in `Either` to avoid the /// issue: /// /// ``` /// use tokio_util::either::Either; /// # fn some_condition() -> bool { true } /// # async fn some_async_function() -> u32 { 10 } /// # async fn other_async_function() -> u32 { 20 } /// /// #[tokio::main] /// async fn main() { /// let result = if some_condition() { /// Either::Left(some_async_function()) /// } else { /// Either::Right(other_async_function()) /// }; /// /// let value = result.await; /// println!("Result is {}", value); /// # assert_eq!(value, 10); /// } /// ``` #[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense. #[derive(Debug, Clone)] pub enum Either { Left(L), Right(R), } /// A small helper macro which reduces amount of boilerplate in the actual trait method implementation. /// It takes an invocation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either /// enum variant held in `self`. macro_rules! delegate_call { ($self:ident.$method:ident($($args:ident),+)) => { unsafe { match $self.get_unchecked_mut() { Self::Left(l) => Pin::new_unchecked(l).$method($($args),+), Self::Right(r) => Pin::new_unchecked(r).$method($($args),+), } } } } impl Future for Either where L: Future, R: Future, { type Output = O; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { delegate_call!(self.poll(cx)) } } impl AsyncRead for Either where L: AsyncRead, R: AsyncRead, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { delegate_call!(self.poll_read(cx, buf)) } } impl AsyncBufRead for Either where L: AsyncBufRead, R: AsyncBufRead, { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { delegate_call!(self.poll_fill_buf(cx)) } fn consume(self: Pin<&mut Self>, amt: usize) { delegate_call!(self.consume(amt)) } } impl AsyncSeek for Either where L: AsyncSeek, R: AsyncSeek, { fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> { delegate_call!(self.start_seek(position)) } fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { delegate_call!(self.poll_complete(cx)) } } impl AsyncWrite for Either where L: AsyncWrite, R: AsyncWrite, { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { delegate_call!(self.poll_write(cx, buf)) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { delegate_call!(self.poll_flush(cx)) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { delegate_call!(self.poll_shutdown(cx)) } } impl futures_core::stream::Stream for Either where L: futures_core::stream::Stream, R: futures_core::stream::Stream, { type Item = L::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { delegate_call!(self.poll_next(cx)) } } #[cfg(test)] mod tests { use super::*; use tokio::io::{repeat, AsyncReadExt, Repeat}; use tokio_stream::{once, Once, StreamExt}; #[tokio::test] async fn either_is_stream() { let mut either: Either, Once> = Either::Left(once(1)); assert_eq!(Some(1u32), either.next().await); } #[tokio::test] async fn either_is_async_read() { let mut buffer = [0; 3]; let mut either: Either = Either::Right(repeat(0b101)); either.read_exact(&mut buffer).await.unwrap(); assert_eq!(buffer, [0b101, 0b101, 0b101]); } } tokio-util-0.6.9/src/io/mod.rs000064400000000000000000000014320072674642500143130ustar 00000000000000//! Helpers for IO related tasks. //! //! The stream types are often used in combination with hyper or reqwest, as they //! allow converting between a hyper [`Body`] and [`AsyncRead`]. //! //! The [`SyncIoBridge`] type converts from the world of async I/O //! to synchronous I/O; this may often come up when using synchronous APIs //! inside [`tokio::task::spawn_blocking`]. //! //! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html //! [`AsyncRead`]: tokio::io::AsyncRead mod read_buf; mod reader_stream; mod stream_reader; cfg_io_util! { mod sync_bridge; pub use self::sync_bridge::SyncIoBridge; } pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; pub use crate::util::{poll_read_buf, poll_write_buf}; tokio-util-0.6.9/src/io/read_buf.rs000064400000000000000000000031400072674642500153010ustar 00000000000000use bytes::BufMut; use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::AsyncRead; /// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. /// /// [`BufMut`]: bytes::BufMut /// /// # Example /// /// ``` /// use bytes::{Bytes, BytesMut}; /// use tokio_stream as stream; /// use tokio::io::Result; /// use tokio_util::io::{StreamReader, read_buf}; /// # #[tokio::main] /// # async fn main() -> std::io::Result<()> { /// /// // Create a reader from an iterator. This particular reader will always be /// // ready. /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); /// /// let mut buf = BytesMut::new(); /// let mut reads = 0; /// /// loop { /// reads += 1; /// let n = read_buf(&mut read, &mut buf).await?; /// /// if n == 0 { /// break; /// } /// } /// /// // one or more reads might be necessary. /// assert!(reads >= 1); /// assert_eq!(&buf[..], &[0, 1, 2, 3]); /// # Ok(()) /// # } /// ``` pub async fn read_buf(read: &mut R, buf: &mut B) -> io::Result where R: AsyncRead + Unpin, B: BufMut, { return ReadBufFn(read, buf).await; struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B); impl<'a, R, B> Future for ReadBufFn<'a, R, B> where R: AsyncRead + Unpin, B: BufMut, { type Output = io::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = &mut *self; crate::util::poll_read_buf(Pin::new(this.0), cx, this.1) } } } tokio-util-0.6.9/src/io/reader_stream.rs000064400000000000000000000067160072674642500163630ustar 00000000000000use bytes::{Bytes, BytesMut}; use futures_core::stream::Stream; use pin_project_lite::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::AsyncRead; const DEFAULT_CAPACITY: usize = 4096; pin_project! { /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks. /// /// This stream is fused. It performs the inverse operation of /// [`StreamReader`]. /// /// # Example /// /// ``` /// # #[tokio::main] /// # async fn main() -> std::io::Result<()> { /// use tokio_stream::StreamExt; /// use tokio_util::io::ReaderStream; /// /// // Create a stream of data. /// let data = b"hello, world!"; /// let mut stream = ReaderStream::new(&data[..]); /// /// // Read all of the chunks into a vector. /// let mut stream_contents = Vec::new(); /// while let Some(chunk) = stream.next().await { /// stream_contents.extend_from_slice(&chunk?); /// } /// /// // Once the chunks are concatenated, we should have the /// // original data. /// assert_eq!(stream_contents, data); /// # Ok(()) /// # } /// ``` /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`StreamReader`]: crate::io::StreamReader /// [`Stream`]: futures_core::Stream #[derive(Debug)] pub struct ReaderStream { // Reader itself. // // This value is `None` if the stream has terminated. #[pin] reader: Option, // Working buffer, used to optimize allocations. buf: BytesMut, capacity: usize, } } impl ReaderStream { /// Convert an [`AsyncRead`] into a [`Stream`] with item type /// `Result`. /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`Stream`]: futures_core::Stream pub fn new(reader: R) -> Self { ReaderStream { reader: Some(reader), buf: BytesMut::new(), capacity: DEFAULT_CAPACITY, } } /// Convert an [`AsyncRead`] into a [`Stream`] with item type /// `Result`, /// with a specific read buffer initial capacity. /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`Stream`]: futures_core::Stream pub fn with_capacity(reader: R, capacity: usize) -> Self { ReaderStream { reader: Some(reader), buf: BytesMut::with_capacity(capacity), capacity, } } } impl Stream for ReaderStream { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use crate::util::poll_read_buf; let mut this = self.as_mut().project(); let reader = match this.reader.as_pin_mut() { Some(r) => r, None => return Poll::Ready(None), }; if this.buf.capacity() == 0 { this.buf.reserve(*this.capacity); } match poll_read_buf(reader, cx, &mut this.buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(err)) => { self.project().reader.set(None); Poll::Ready(Some(Err(err))) } Poll::Ready(Ok(0)) => { self.project().reader.set(None); Poll::Ready(None) } Poll::Ready(Ok(_)) => { let chunk = this.buf.split(); Poll::Ready(Some(Ok(chunk.freeze()))) } } } } tokio-util-0.6.9/src/io/stream_reader.rs000064400000000000000000000130670072674642500163600ustar 00000000000000use bytes::Buf; use futures_core::stream::Stream; use pin_project_lite::pin_project; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; pin_project! { /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. /// /// This type performs the inverse operation of [`ReaderStream`]. /// /// # Example /// /// ``` /// use bytes::Bytes; /// use tokio::io::{AsyncReadExt, Result}; /// use tokio_util::io::StreamReader; /// # #[tokio::main] /// # async fn main() -> std::io::Result<()> { /// /// // Create a stream from an iterator. /// let stream = tokio_stream::iter(vec![ /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), /// ]); /// /// // Convert it to an AsyncRead. /// let mut read = StreamReader::new(stream); /// /// // Read five bytes from the stream. /// let mut buf = [0; 5]; /// read.read_exact(&mut buf).await?; /// assert_eq!(buf, [0, 1, 2, 3, 4]); /// /// // Read the rest of the current chunk. /// assert_eq!(read.read(&mut buf).await?, 3); /// assert_eq!(&buf[..3], [5, 6, 7]); /// /// // Read the next chunk. /// assert_eq!(read.read(&mut buf).await?, 4); /// assert_eq!(&buf[..4], [8, 9, 10, 11]); /// /// // We have now reached the end. /// assert_eq!(read.read(&mut buf).await?, 0); /// /// # Ok(()) /// # } /// ``` /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`Stream`]: futures_core::Stream /// [`ReaderStream`]: crate::io::ReaderStream #[derive(Debug)] pub struct StreamReader { #[pin] inner: S, chunk: Option, } } impl StreamReader where S: Stream>, B: Buf, E: Into, { /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). /// /// The item should be a [`Result`] with the ok variant being something that /// implements the [`Buf`] trait (e.g. `Vec` or `Bytes`). The error /// should be convertible into an [io error]. /// /// [`Result`]: std::result::Result /// [`Buf`]: bytes::Buf /// [io error]: std::io::Error pub fn new(stream: S) -> Self { Self { inner: stream, chunk: None, } } /// Do we have a chunk and is it non-empty? fn has_chunk(self: Pin<&mut Self>) -> bool { if let Some(chunk) = self.project().chunk { chunk.remaining() > 0 } else { false } } } impl StreamReader { /// Gets a reference to the underlying stream. /// /// It is inadvisable to directly read from the underlying stream. pub fn get_ref(&self) -> &S { &self.inner } /// Gets a mutable reference to the underlying stream. /// /// It is inadvisable to directly read from the underlying stream. pub fn get_mut(&mut self) -> &mut S { &mut self.inner } /// Gets a pinned mutable reference to the underlying stream. /// /// It is inadvisable to directly read from the underlying stream. pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { self.project().inner } /// Consumes this `BufWriter`, returning the underlying stream. /// /// Note that any leftover data in the internal buffer is lost. pub fn into_inner(self) -> S { self.inner } } impl AsyncRead for StreamReader where S: Stream>, B: Buf, E: Into, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { if buf.remaining() == 0 { return Poll::Ready(Ok(())); } let inner_buf = match self.as_mut().poll_fill_buf(cx) { Poll::Ready(Ok(buf)) => buf, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, }; let len = std::cmp::min(inner_buf.len(), buf.remaining()); buf.put_slice(&inner_buf[..len]); self.consume(len); Poll::Ready(Ok(())) } } impl AsyncBufRead for StreamReader where S: Stream>, B: Buf, E: Into, { fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { if self.as_mut().has_chunk() { // This unwrap is very sad, but it can't be avoided. let buf = self.project().chunk.as_ref().unwrap().chunk(); return Poll::Ready(Ok(buf)); } else { match self.as_mut().project().inner.poll_next(cx) { Poll::Ready(Some(Ok(chunk))) => { // Go around the loop in case the chunk is empty. *self.as_mut().project().chunk = Some(chunk); } Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), Poll::Ready(None) => return Poll::Ready(Ok(&[])), Poll::Pending => return Poll::Pending, } } } } fn consume(self: Pin<&mut Self>, amt: usize) { if amt > 0 { self.project() .chunk .as_mut() .expect("No chunk present") .advance(amt); } } } tokio-util-0.6.9/src/io/sync_bridge.rs000064400000000000000000000075140072674642500160330ustar 00000000000000use std::io::{Read, Write}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. #[derive(Debug)] pub struct SyncIoBridge { src: T, rt: tokio::runtime::Handle, } impl Read for SyncIoBridge { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { let src = &mut self.src; self.rt.block_on(AsyncReadExt::read(src, buf)) } fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { let src = &mut self.src; self.rt.block_on(src.read_to_end(buf)) } fn read_to_string(&mut self, buf: &mut String) -> std::io::Result { let src = &mut self.src; self.rt.block_on(src.read_to_string(buf)) } fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { let src = &mut self.src; // The AsyncRead trait returns the count, synchronous doesn't. let _n = self.rt.block_on(src.read_exact(buf))?; Ok(()) } } impl Write for SyncIoBridge { fn write(&mut self, buf: &[u8]) -> std::io::Result { let src = &mut self.src; self.rt.block_on(src.write(buf)) } fn flush(&mut self) -> std::io::Result<()> { let src = &mut self.src; self.rt.block_on(src.flush()) } fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { let src = &mut self.src; self.rt.block_on(src.write_all(buf)) } fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { let src = &mut self.src; self.rt.block_on(src.write_vectored(bufs)) } } // Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time // of this writing still unstable, we expose this as part of a standalone method. impl SyncIoBridge { /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes. /// /// See [`tokio::io::AsyncWrite::is_write_vectored`]. pub fn is_write_vectored(&self) -> bool { self.src.is_write_vectored() } } impl SyncIoBridge { /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. /// /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`]. /// It is hence OK to move this struct into a separate thread outside the runtime, as created /// by e.g. [`tokio::task::spawn_blocking`]. /// /// Stated even more strongly: to make use of this bridge, you *must* move /// it into a separate thread outside the runtime. The synchronous I/O will use the /// underlying handle to block on the backing asynchronous source, via /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that /// function, an attempt to `block_on` from an asynchronous execution context /// will panic. /// /// # Wrapping `!Unpin` types /// /// Use e.g. `SyncIoBridge::new(Box::pin(src))`. /// /// # Panic /// /// This will panic if called outside the context of a Tokio runtime. pub fn new(src: T) -> Self { Self::new_with_handle(src, tokio::runtime::Handle::current()) } /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. /// /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may /// be initially invoked outside of an asynchronous context. pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self { Self { src, rt } } } tokio-util-0.6.9/src/lib.rs000064400000000000000000000124030072674642500136730ustar 00000000000000#![allow(clippy::needless_doctest_main)] #![warn( missing_debug_implementations, missing_docs, rust_2018_idioms, unreachable_pub )] #![cfg_attr(docsrs, deny(rustdoc::broken_intra_doc_links))] #![doc(test( no_crate_inject, attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) ))] #![cfg_attr(docsrs, feature(doc_cfg))] //! Utilities for working with Tokio. //! //! This crate is not versioned in lockstep with the core //! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's //! semantic versioning policy, especially with regard to breaking changes. //! //! [`tokio`]: https://docs.rs/tokio #[macro_use] mod cfg; mod loom; cfg_codec! { pub mod codec; } cfg_net! { pub mod udp; } cfg_compat! { pub mod compat; } cfg_io! { pub mod io; } cfg_rt! { pub mod context; } cfg_time! { pub mod time; } pub mod sync; pub mod either; #[cfg(any(feature = "io", feature = "codec"))] mod util { use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use bytes::{Buf, BufMut}; use futures_core::ready; use std::io::{self, IoSlice}; use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. /// /// [`BufMut`]: bytes::Buf /// /// # Example /// /// ``` /// use bytes::{Bytes, BytesMut}; /// use tokio_stream as stream; /// use tokio::io::Result; /// use tokio_util::io::{StreamReader, poll_read_buf}; /// use futures::future::poll_fn; /// use std::pin::Pin; /// # #[tokio::main] /// # async fn main() -> std::io::Result<()> { /// /// // Create a reader from an iterator. This particular reader will always be /// // ready. /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); /// /// let mut buf = BytesMut::new(); /// let mut reads = 0; /// /// loop { /// reads += 1; /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; /// /// if n == 0 { /// break; /// } /// } /// /// // one or more reads might be necessary. /// assert!(reads >= 1); /// assert_eq!(&buf[..], &[0, 1, 2, 3]); /// # Ok(()) /// # } /// ``` #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] pub fn poll_read_buf( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll> { if !buf.has_remaining_mut() { return Poll::Ready(Ok(0)); } let n = { let dst = buf.chunk_mut(); let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; let mut buf = ReadBuf::uninit(dst); let ptr = buf.filled().as_ptr(); ready!(io.poll_read(cx, &mut buf)?); // Ensure the pointer does not change from under us assert_eq!(ptr, buf.filled().as_ptr()); buf.filled().len() }; // Safety: This is guaranteed to be the number of initialized (and read) // bytes due to the invariants provided by `ReadBuf::filled`. unsafe { buf.advance_mut(n); } Poll::Ready(Ok(n)) } /// Try to write data from an implementer of the [`Buf`] trait to an /// [`AsyncWrite`], advancing the buffer's internal cursor. /// /// This function will use [vectored writes] when the [`AsyncWrite`] supports /// vectored writes. /// /// # Examples /// /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements /// [`Buf`]: /// /// ```no_run /// use tokio_util::io::poll_write_buf; /// use tokio::io; /// use tokio::fs::File; /// /// use bytes::Buf; /// use std::io::Cursor; /// use std::pin::Pin; /// use futures::future::poll_fn; /// /// #[tokio::main] /// async fn main() -> io::Result<()> { /// let mut file = File::create("foo.txt").await?; /// let mut buf = Cursor::new(b"data to write"); /// /// // Loop until the entire contents of the buffer are written to /// // the file. /// while buf.has_remaining() { /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; /// } /// /// Ok(()) /// } /// ``` /// /// [`Buf`]: bytes::Buf /// [`AsyncWrite`]: tokio::io::AsyncWrite /// [`File`]: tokio::fs::File /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] pub fn poll_write_buf( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll> { const MAX_BUFS: usize = 64; if !buf.has_remaining() { return Poll::Ready(Ok(0)); } let n = if io.is_write_vectored() { let mut slices = [IoSlice::new(&[]); MAX_BUFS]; let cnt = buf.chunks_vectored(&mut slices); ready!(io.poll_write_vectored(cx, &slices[..cnt]))? } else { ready!(io.poll_write(cx, buf.chunk()))? }; buf.advance(n); Poll::Ready(Ok(n)) } } tokio-util-0.6.9/src/loom.rs000064400000000000000000000000320072674642500140660ustar 00000000000000pub(crate) use std::sync; tokio-util-0.6.9/src/sync/cancellation_token/guard.rs000064400000000000000000000014120072674642500210350ustar 00000000000000use crate::sync::CancellationToken; /// A wrapper for cancellation token which automatically cancels /// it on drop. It is created using `drop_guard` method on the `CancellationToken`. #[derive(Debug)] pub struct DropGuard { pub(super) inner: Option, } impl DropGuard { /// Returns stored cancellation token and removes this drop guard instance /// (i.e. it will no longer cancel token). Other guards for this token /// are not affected. pub fn disarm(mut self) -> CancellationToken { self.inner .take() .expect("`inner` can be only None in a destructor") } } impl Drop for DropGuard { fn drop(&mut self) { if let Some(inner) = &self.inner { inner.cancel(); } } } tokio-util-0.6.9/src/sync/cancellation_token.rs000064400000000000000000001011330072674642500177340ustar 00000000000000//! An asynchronously awaitable `CancellationToken`. //! The token allows to signal a cancellation request to one or more tasks. pub(crate) mod guard; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Mutex; use crate::sync::intrusive_double_linked_list::{LinkedList, ListNode}; use core::future::Future; use core::pin::Pin; use core::ptr::NonNull; use core::sync::atomic::Ordering; use core::task::{Context, Poll, Waker}; use guard::DropGuard; /// A token which can be used to signal a cancellation request to one or more /// tasks. /// /// Tasks can call [`CancellationToken::cancelled()`] in order to /// obtain a Future which will be resolved when cancellation is requested. /// /// Cancellation can be requested through the [`CancellationToken::cancel`] method. /// /// # Examples /// /// ```ignore /// use tokio::select; /// use tokio::scope::CancellationToken; /// /// #[tokio::main] /// async fn main() { /// let token = CancellationToken::new(); /// let cloned_token = token.clone(); /// /// let join_handle = tokio::spawn(async move { /// // Wait for either cancellation or a very long time /// select! { /// _ = cloned_token.cancelled() => { /// // The token was cancelled /// 5 /// } /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { /// 99 /// } /// } /// }); /// /// tokio::spawn(async move { /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; /// token.cancel(); /// }); /// /// assert_eq!(5, join_handle.await.unwrap()); /// } /// ``` pub struct CancellationToken { inner: NonNull, } // Safety: The CancellationToken is thread-safe and can be moved between threads, // since all methods are internally synchronized. unsafe impl Send for CancellationToken {} unsafe impl Sync for CancellationToken {} /// A Future that is resolved once the corresponding [`CancellationToken`] /// was cancelled #[must_use = "futures do nothing unless polled"] pub struct WaitForCancellationFuture<'a> { /// The CancellationToken that is associated with this WaitForCancellationFuture cancellation_token: Option<&'a CancellationToken>, /// Node for waiting at the cancellation_token wait_node: ListNode, /// Whether this future was registered at the token yet as a waiter is_registered: bool, } // Safety: Futures can be sent between threads as long as the underlying // cancellation_token is thread-safe (Sync), // which allows to poll/register/unregister from a different thread. unsafe impl<'a> Send for WaitForCancellationFuture<'a> {} // ===== impl CancellationToken ===== impl core::fmt::Debug for CancellationToken { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("CancellationToken") .field("is_cancelled", &self.is_cancelled()) .finish() } } impl Clone for CancellationToken { fn clone(&self) -> Self { // Safety: The state inside a `CancellationToken` is always valid, since // is reference counted let inner = self.state(); // Tokens are cloned by increasing their refcount let current_state = inner.snapshot(); inner.increment_refcount(current_state); CancellationToken { inner: self.inner } } } impl Drop for CancellationToken { fn drop(&mut self) { let token_state_pointer = self.inner; // Safety: The state inside a `CancellationToken` is always valid, since // is reference counted let inner = unsafe { &mut *self.inner.as_ptr() }; let mut current_state = inner.snapshot(); // We need to safe the parent, since the state might be released by the // next call let parent = inner.parent; // Drop our own refcount current_state = inner.decrement_refcount(current_state); // If this was the last reference, unregister from the parent if current_state.refcount == 0 { if let Some(mut parent) = parent { // Safety: Since we still retain a reference on the parent, it must be valid. let parent = unsafe { parent.as_mut() }; parent.unregister_child(token_state_pointer, current_state); } } } } impl Default for CancellationToken { fn default() -> CancellationToken { CancellationToken::new() } } impl CancellationToken { /// Creates a new CancellationToken in the non-cancelled state. pub fn new() -> CancellationToken { let state = Box::new(CancellationTokenState::new( None, StateSnapshot { cancel_state: CancellationState::NotCancelled, has_parent_ref: false, refcount: 1, }, )); // Safety: We just created the Box. The pointer is guaranteed to be // not null CancellationToken { inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) }, } } /// Returns a reference to the utilized `CancellationTokenState`. fn state(&self) -> &CancellationTokenState { // Safety: The state inside a `CancellationToken` is always valid, since // is reference counted unsafe { &*self.inner.as_ptr() } } /// Creates a `CancellationToken` which will get cancelled whenever the /// current token gets cancelled. /// /// If the current token is already cancelled, the child token will get /// returned in cancelled state. /// /// # Examples /// /// ```ignore /// use tokio::select; /// use tokio::scope::CancellationToken; /// /// #[tokio::main] /// async fn main() { /// let token = CancellationToken::new(); /// let child_token = token.child_token(); /// /// let join_handle = tokio::spawn(async move { /// // Wait for either cancellation or a very long time /// select! { /// _ = child_token.cancelled() => { /// // The token was cancelled /// 5 /// } /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { /// 99 /// } /// } /// }); /// /// tokio::spawn(async move { /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; /// token.cancel(); /// }); /// /// assert_eq!(5, join_handle.await.unwrap()); /// } /// ``` pub fn child_token(&self) -> CancellationToken { let inner = self.state(); // Increment the refcount of this token. It will be referenced by the // child, independent of whether the child is immediately cancelled or // not. let _current_state = inner.increment_refcount(inner.snapshot()); let mut unpacked_child_state = StateSnapshot { has_parent_ref: true, refcount: 1, cancel_state: CancellationState::NotCancelled, }; let mut child_token_state = Box::new(CancellationTokenState::new( Some(self.inner), unpacked_child_state, )); { let mut guard = inner.synchronized.lock().unwrap(); if guard.is_cancelled { // This task was already cancelled. In this case we should not // insert the child into the list, since it would never get removed // from the list. (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true; unpacked_child_state.cancel_state = CancellationState::Cancelled; // Since it's not in the list, the parent doesn't need to retain // a reference to it. unpacked_child_state.has_parent_ref = false; child_token_state .state .store(unpacked_child_state.pack(), Ordering::SeqCst); } else { if let Some(mut first_child) = guard.first_child { child_token_state.from_parent.next_peer = Some(first_child); // Safety: We manipulate other child task inside the Mutex // and retain a parent reference on it. The child token can't // get invalidated while the Mutex is held. unsafe { first_child.as_mut().from_parent.prev_peer = Some((&mut *child_token_state).into()) }; } guard.first_child = Some((&mut *child_token_state).into()); } }; let child_token_ptr = Box::into_raw(child_token_state); // Safety: We just created the pointer from a `Box` CancellationToken { inner: unsafe { NonNull::new_unchecked(child_token_ptr) }, } } /// Cancel the [`CancellationToken`] and all child tokens which had been /// derived from it. /// /// This will wake up all tasks which are waiting for cancellation. pub fn cancel(&self) { self.state().cancel(); } /// Returns `true` if the `CancellationToken` had been cancelled pub fn is_cancelled(&self) -> bool { self.state().is_cancelled() } /// Returns a `Future` that gets fulfilled when cancellation is requested. pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { WaitForCancellationFuture { cancellation_token: Some(self), wait_node: ListNode::new(WaitQueueEntry::new()), is_registered: false, } } /// Creates a `DropGuard` for this token. /// /// Returned guard will cancel this token (and all its children) on drop /// unless disarmed. pub fn drop_guard(self) -> DropGuard { DropGuard { inner: Some(self) } } unsafe fn register( &self, wait_node: &mut ListNode, cx: &mut Context<'_>, ) -> Poll<()> { self.state().register(wait_node, cx) } fn check_for_cancellation( &self, wait_node: &mut ListNode, cx: &mut Context<'_>, ) -> Poll<()> { self.state().check_for_cancellation(wait_node, cx) } fn unregister(&self, wait_node: &mut ListNode) { self.state().unregister(wait_node) } } // ===== impl WaitForCancellationFuture ===== impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("WaitForCancellationFuture").finish() } } impl<'a> Future for WaitForCancellationFuture<'a> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { // Safety: We do not move anything out of `WaitForCancellationFuture` let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; let cancellation_token = mut_self .cancellation_token .expect("polled WaitForCancellationFuture after completion"); let poll_res = if !mut_self.is_registered { // Safety: The `ListNode` is pinned through the Future, // and we will unregister it in `WaitForCancellationFuture::drop` // before the Future is dropped and the memory reference is invalidated. unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) } } else { cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx) }; if let Poll::Ready(()) = poll_res { // The cancellation_token was signalled mut_self.cancellation_token = None; // A signalled Token means the Waker won't be enqueued anymore mut_self.is_registered = false; mut_self.wait_node.task = None; } else { // This `Future` and its stored `Waker` stay registered at the // `CancellationToken` mut_self.is_registered = true; } poll_res } } impl<'a> Drop for WaitForCancellationFuture<'a> { fn drop(&mut self) { // If this WaitForCancellationFuture has been polled and it was added to the // wait queue at the cancellation_token, it must be removed before dropping. // Otherwise the cancellation_token would access invalid memory. if let Some(token) = self.cancellation_token { if self.is_registered { token.unregister(&mut self.wait_node); } } } } /// Tracks how the future had interacted with the [`CancellationToken`] #[derive(Copy, Clone, Debug, PartialEq, Eq)] enum PollState { /// The task has never interacted with the [`CancellationToken`]. New, /// The task was added to the wait queue at the [`CancellationToken`]. Waiting, /// The task has been polled to completion. Done, } /// Tracks the WaitForCancellationFuture waiting state. /// Access to this struct is synchronized through the mutex in the CancellationToken. struct WaitQueueEntry { /// The task handle of the waiting task task: Option, // Current polling state. This state is only updated inside the Mutex of // the CancellationToken. state: PollState, } impl WaitQueueEntry { /// Creates a new WaitQueueEntry fn new() -> WaitQueueEntry { WaitQueueEntry { task: None, state: PollState::New, } } } struct SynchronizedState { waiters: LinkedList, first_child: Option>, is_cancelled: bool, } impl SynchronizedState { fn new() -> Self { Self { waiters: LinkedList::new(), first_child: None, is_cancelled: false, } } } /// Information embedded in child tokens which is synchronized through the Mutex /// in their parent. struct SynchronizedThroughParent { next_peer: Option>, prev_peer: Option>, } /// Possible states of a `CancellationToken` #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum CancellationState { NotCancelled = 0, Cancelling = 1, Cancelled = 2, } impl CancellationState { fn pack(self) -> usize { self as usize } fn unpack(value: usize) -> Self { match value { 0 => CancellationState::NotCancelled, 1 => CancellationState::Cancelling, 2 => CancellationState::Cancelled, _ => unreachable!("Invalid value"), } } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] struct StateSnapshot { /// The amount of references to this particular CancellationToken. /// `CancellationToken` structs hold these references to a `CancellationTokenState`. /// Also the state is referenced by the state of each child. refcount: usize, /// Whether the state is still referenced by it's parent and can therefore /// not be freed. has_parent_ref: bool, /// Whether the token is cancelled cancel_state: CancellationState, } impl StateSnapshot { /// Packs the snapshot into a `usize` fn pack(self) -> usize { self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack() } /// Unpacks the snapshot from a `usize` fn unpack(value: usize) -> Self { let refcount = value >> 3; let has_parent_ref = value & 4 != 0; let cancel_state = CancellationState::unpack(value & 0x03); StateSnapshot { refcount, has_parent_ref, cancel_state, } } /// Whether this `CancellationTokenState` is still referenced by any /// `CancellationToken`. fn has_refs(&self) -> bool { self.refcount != 0 || self.has_parent_ref } } /// The maximum permitted amount of references to a CancellationToken. This /// is derived from the intent to never use more than 32bit in the `Snapshot`. const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3; /// Internal state of the `CancellationToken` pair above struct CancellationTokenState { state: AtomicUsize, parent: Option>, from_parent: SynchronizedThroughParent, synchronized: Mutex, } impl CancellationTokenState { fn new( parent: Option>, state: StateSnapshot, ) -> CancellationTokenState { CancellationTokenState { parent, from_parent: SynchronizedThroughParent { prev_peer: None, next_peer: None, }, state: AtomicUsize::new(state.pack()), synchronized: Mutex::new(SynchronizedState::new()), } } /// Returns a snapshot of the current atomic state of the token fn snapshot(&self) -> StateSnapshot { StateSnapshot::unpack(self.state.load(Ordering::SeqCst)) } fn atomic_update_state(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot where F: Fn(StateSnapshot) -> StateSnapshot, { let mut current_packed_state = current_state.pack(); loop { let next_state = func(current_state); match self.state.compare_exchange( current_packed_state, next_state.pack(), Ordering::SeqCst, Ordering::SeqCst, ) { Ok(_) => { return next_state; } Err(actual) => { current_packed_state = actual; current_state = StateSnapshot::unpack(actual); } } } } fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { self.atomic_update_state(current_state, |mut state: StateSnapshot| { if state.refcount >= MAX_REFS as usize { eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded"); std::process::abort(); } state.refcount += 1; state }) } fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { state.refcount -= 1; state }); // Drop the State if it is not referenced anymore if !current_state.has_refs() { // Safety: `CancellationTokenState` is always stored in refcounted // Boxes let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; } current_state } fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot { let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { state.has_parent_ref = false; state }); // Drop the State if it is not referenced anymore if !current_state.has_refs() { // Safety: `CancellationTokenState` is always stored in refcounted // Boxes let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; } current_state } /// Unregisters a child from the parent token. /// The child tokens state is not exactly known at this point in time. /// If the parent token is cancelled, the child token gets removed from the /// parents list, and might therefore already have been freed. If the parent /// token is not cancelled, the child token is still valid. fn unregister_child( &mut self, mut child_state: NonNull, current_child_state: StateSnapshot, ) { let removed_child = { // Remove the child toke from the parents linked list let mut guard = self.synchronized.lock().unwrap(); if !guard.is_cancelled { // Safety: Since the token was not cancelled, the child must // still be in the list and valid. let mut child_state = unsafe { child_state.as_mut() }; debug_assert!(child_state.snapshot().has_parent_ref); if guard.first_child == Some(child_state.into()) { guard.first_child = child_state.from_parent.next_peer; } // Safety: If peers wouldn't be valid anymore, they would try // to remove themselves from the list. This would require locking // the Mutex that we currently own. unsafe { if let Some(mut prev_peer) = child_state.from_parent.prev_peer { prev_peer.as_mut().from_parent.next_peer = child_state.from_parent.next_peer; } if let Some(mut next_peer) = child_state.from_parent.next_peer { next_peer.as_mut().from_parent.prev_peer = child_state.from_parent.prev_peer; } } child_state.from_parent.prev_peer = None; child_state.from_parent.next_peer = None; // The child is no longer referenced by the parent, since we were able // to remove its reference from the parents list. true } else { // Do not touch the linked list anymore. If the parent is cancelled // it will move all childs outside of the Mutex and manipulate // the pointers there. Manipulating the pointers here too could // lead to races. Therefore leave them just as as and let the // parent deal with it. The parent will make sure to retain a // reference to this state as long as it manipulates the list // pointers. Therefore the pointers are not dangling. false } }; if removed_child { // If the token removed itself from the parents list, it can reset // the parent ref status. If it is isn't able to do so, because the // parent removed it from the list, there is no need to do this. // The parent ref acts as as another reference count. Therefore // removing this reference can free the object. // Safety: The token was in the list. This means the parent wasn't // cancelled before, and the token must still be alive. unsafe { child_state.as_mut().remove_parent_ref(current_child_state) }; } // Decrement the refcount on the parent and free it if necessary self.decrement_refcount(self.snapshot()); } fn cancel(&self) { // Move the state of the CancellationToken from `NotCancelled` to `Cancelling` let mut current_state = self.snapshot(); let state_after_cancellation = loop { if current_state.cancel_state != CancellationState::NotCancelled { // Another task already initiated the cancellation return; } let mut next_state = current_state; next_state.cancel_state = CancellationState::Cancelling; match self.state.compare_exchange( current_state.pack(), next_state.pack(), Ordering::SeqCst, Ordering::SeqCst, ) { Ok(_) => break next_state, Err(actual) => current_state = StateSnapshot::unpack(actual), } }; // This task cancelled the token // Take the task list out of the Token // We do not want to cancel child token inside this lock. If one of the // child tasks would have additional child tokens, we would recursively // take locks. // Doing this action has an impact if the child token is dropped concurrently: // It will try to deregister itself from the parent task, but can not find // itself in the task list anymore. Therefore it needs to assume the parent // has extracted the list and will process it. It may not modify the list. // This is OK from a memory safety perspective, since the parent still // retains a reference to the child task until it finished iterating over // it. let mut first_child = { let mut guard = self.synchronized.lock().unwrap(); // Save the cancellation also inside the Mutex // This allows child tokens which want to detach themselves to detect // that this is no longer required since the parent cleared the list. guard.is_cancelled = true; // Wakeup all waiters // This happens inside the lock to make cancellation reliable // If we would access waiters outside of the lock, the pointers // may no longer be valid. // Typically this shouldn't be an issue, since waking a task should // only move it from the blocked into the ready state and not have // further side effects. // Use a reverse iterator, so that the oldest waiter gets // scheduled first guard.waiters.reverse_drain(|waiter| { // We are not allowed to move the `Waker` out of the list node. // The `Future` relies on the fact that the old `Waker` stays there // as long as the `Future` has not completed in order to perform // the `will_wake()` check. // Therefore `wake_by_ref` is used instead of `wake()` if let Some(handle) = &mut waiter.task { handle.wake_by_ref(); } // Mark the waiter to have been removed from the list. waiter.state = PollState::Done; }); guard.first_child.take() }; while let Some(mut child) = first_child { // Safety: We know this is a valid pointer since it is in our child pointer // list. It can't have been freed in between, since we retain a a reference // to each child. let mut_child = unsafe { child.as_mut() }; // Get the next child and clean up list pointers first_child = mut_child.from_parent.next_peer; mut_child.from_parent.prev_peer = None; mut_child.from_parent.next_peer = None; // Cancel the child task mut_child.cancel(); // Drop the parent reference. This `CancellationToken` is not interested // in interacting with the child anymore. // This is ONLY allowed once we promised not to touch the state anymore // after this interaction. mut_child.remove_parent_ref(mut_child.snapshot()); } // The cancellation has completed // At this point in time tasks which registered a wait node can be sure // that this wait node already had been dequeued from the list without // needing to inspect the list. self.atomic_update_state(state_after_cancellation, |mut state| { state.cancel_state = CancellationState::Cancelled; state }); } /// Returns `true` if the `CancellationToken` had been cancelled fn is_cancelled(&self) -> bool { let current_state = self.snapshot(); current_state.cancel_state != CancellationState::NotCancelled } /// Registers a waiting task at the `CancellationToken`. /// Safety: This method is only safe as long as the waiting waiting task /// will properly unregister the wait node before it gets moved. unsafe fn register( &self, wait_node: &mut ListNode, cx: &mut Context<'_>, ) -> Poll<()> { debug_assert_eq!(PollState::New, wait_node.state); let current_state = self.snapshot(); // Perform an optimistic cancellation check before. This is not strictly // necessary since we also check for cancellation in the Mutex, but // reduces the necessary work to be performed for tasks which already // had been cancelled. if current_state.cancel_state != CancellationState::NotCancelled { return Poll::Ready(()); } // So far the token is not cancelled. However it could be cancelled before // we get the chance to store the `Waker`. Therefore we need to check // for cancellation again inside the mutex. let mut guard = self.synchronized.lock().unwrap(); if guard.is_cancelled { // Cancellation was signalled wait_node.state = PollState::Done; Poll::Ready(()) } else { // Added the task to the wait queue wait_node.task = Some(cx.waker().clone()); wait_node.state = PollState::Waiting; guard.waiters.add_front(wait_node); Poll::Pending } } fn check_for_cancellation( &self, wait_node: &mut ListNode, cx: &mut Context<'_>, ) -> Poll<()> { debug_assert!( wait_node.task.is_some(), "Method can only be called after task had been registered" ); let current_state = self.snapshot(); if current_state.cancel_state != CancellationState::NotCancelled { // If the cancellation had been fully completed we know that our `Waker` // is no longer registered at the `CancellationToken`. // Otherwise the cancel call may or may not yet have iterated // through the waiters list and removed the wait nodes. // If it hasn't yet, we need to remove it. Otherwise an attempt to // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture` // getting dropped before the cancellation had interacted with it. if current_state.cancel_state != CancellationState::Cancelled { self.unregister(wait_node); } Poll::Ready(()) } else { // Check if we need to swap the `Waker`. This will make the check more // expensive, since the `Waker` is synchronized through the Mutex. // If we don't need to perform a `Waker` update, an atomic check for // cancellation is sufficient. let need_waker_update = wait_node .task .as_ref() .map(|waker| waker.will_wake(cx.waker())) .unwrap_or(true); if need_waker_update { let guard = self.synchronized.lock().unwrap(); if guard.is_cancelled { // Cancellation was signalled. Since this cancellation signal // is set inside the Mutex, the old waiter must already have // been removed from the waiting list debug_assert_eq!(PollState::Done, wait_node.state); wait_node.task = None; Poll::Ready(()) } else { // The WaitForCancellationFuture is already in the queue. // The CancellationToken can't have been cancelled, // since this would change the is_cancelled flag inside the mutex. // Therefore we just have to update the Waker. A follow-up // cancellation will always use the new waker. wait_node.task = Some(cx.waker().clone()); Poll::Pending } } else { // Do nothing. If the token gets cancelled, this task will get // woken again and can fetch the cancellation. Poll::Pending } } } fn unregister(&self, wait_node: &mut ListNode) { debug_assert!( wait_node.task.is_some(), "waiter can not be active without task" ); let mut guard = self.synchronized.lock().unwrap(); // WaitForCancellationFuture only needs to get removed if it has been added to // the wait queue of the CancellationToken. // This has happened in the PollState::Waiting case. if let PollState::Waiting = wait_node.state { // Safety: Due to the state, we know that the node must be part // of the waiter list if !unsafe { guard.waiters.remove(wait_node) } { // Panic if the address isn't found. This can only happen if the contract was // violated, e.g. the WaitQueueEntry got moved after the initial poll. panic!("Future could not be removed from wait queue"); } wait_node.state = PollState::Done; } wait_node.task = None; } } tokio-util-0.6.9/src/sync/intrusive_double_linked_list.rs000064400000000000000000000714440072674642500220560ustar 00000000000000//! An intrusive double linked list of data #![allow(dead_code, unreachable_pub)] use core::{ marker::PhantomPinned, ops::{Deref, DerefMut}, ptr::NonNull, }; /// A node which carries data of type `T` and is stored in an intrusive list #[derive(Debug)] pub struct ListNode { /// The previous node in the list. `None` if there is no previous node. prev: Option>>, /// The next node in the list. `None` if there is no previous node. next: Option>>, /// The data which is associated to this list item data: T, /// Prevents `ListNode`s from being `Unpin`. They may never be moved, since /// the list semantics require addresses to be stable. _pin: PhantomPinned, } impl ListNode { /// Creates a new node with the associated data pub fn new(data: T) -> ListNode { Self { prev: None, next: None, data, _pin: PhantomPinned, } } } impl Deref for ListNode { type Target = T; fn deref(&self) -> &T { &self.data } } impl DerefMut for ListNode { fn deref_mut(&mut self) -> &mut T { &mut self.data } } /// An intrusive linked list of nodes, where each node carries associated data /// of type `T`. #[derive(Debug)] pub struct LinkedList { head: Option>>, tail: Option>>, } impl LinkedList { /// Creates an empty linked list pub fn new() -> Self { LinkedList:: { head: None, tail: None, } } /// Adds a node at the front of the linked list. /// Safety: This function is only safe as long as `node` is guaranteed to /// get removed from the list before it gets moved or dropped. /// In addition to this `node` may not be added to another other list before /// it is removed from the current one. pub unsafe fn add_front(&mut self, node: &mut ListNode) { node.next = self.head; node.prev = None; if let Some(mut head) = self.head { head.as_mut().prev = Some(node.into()) }; self.head = Some(node.into()); if self.tail.is_none() { self.tail = Some(node.into()); } } /// Inserts a node into the list in a way that the list keeps being sorted. /// Safety: This function is only safe as long as `node` is guaranteed to /// get removed from the list before it gets moved or dropped. /// In addition to this `node` may not be added to another other list before /// it is removed from the current one. pub unsafe fn add_sorted(&mut self, node: &mut ListNode) where T: PartialOrd, { if self.head.is_none() { // First node in the list self.head = Some(node.into()); self.tail = Some(node.into()); return; } let mut prev: Option>> = None; let mut current = self.head; while let Some(mut current_node) = current { if node.data < current_node.as_ref().data { // Need to insert before the current node current_node.as_mut().prev = Some(node.into()); match prev { Some(mut prev) => { prev.as_mut().next = Some(node.into()); } None => { // We are inserting at the beginning of the list self.head = Some(node.into()); } } node.next = current; node.prev = prev; return; } prev = current; current = current_node.as_ref().next; } // We looped through the whole list and the nodes data is bigger or equal // than everything we found up to now. // Insert at the end. Since we checked before that the list isn't empty, // tail always has a value. node.prev = self.tail; node.next = None; self.tail.as_mut().unwrap().as_mut().next = Some(node.into()); self.tail = Some(node.into()); } /// Returns the first node in the linked list without removing it from the list /// The function is only safe as long as valid pointers are stored inside /// the linked list. /// The returned pointer is only guaranteed to be valid as long as the list /// is not mutated pub fn peek_first(&self) -> Option<&mut ListNode> { // Safety: When the node was inserted it was promised that it is alive // until it gets removed from the list. // The returned node has a pointer which constrains it to the lifetime // of the list. This is ok, since the Node is supposed to outlive // its insertion in the list. unsafe { self.head .map(|mut node| &mut *(node.as_mut() as *mut ListNode)) } } /// Returns the last node in the linked list without removing it from the list /// The function is only safe as long as valid pointers are stored inside /// the linked list. /// The returned pointer is only guaranteed to be valid as long as the list /// is not mutated pub fn peek_last(&self) -> Option<&mut ListNode> { // Safety: When the node was inserted it was promised that it is alive // until it gets removed from the list. // The returned node has a pointer which constrains it to the lifetime // of the list. This is ok, since the Node is supposed to outlive // its insertion in the list. unsafe { self.tail .map(|mut node| &mut *(node.as_mut() as *mut ListNode)) } } /// Removes the first node from the linked list pub fn remove_first(&mut self) -> Option<&mut ListNode> { #![allow(clippy::debug_assert_with_mut_call)] // Safety: When the node was inserted it was promised that it is alive // until it gets removed from the list unsafe { let mut head = self.head?; self.head = head.as_mut().next; let first_ref = head.as_mut(); match first_ref.next { None => { // This was the only node in the list debug_assert_eq!(Some(first_ref.into()), self.tail); self.tail = None; } Some(mut next) => { next.as_mut().prev = None; } } first_ref.prev = None; first_ref.next = None; Some(&mut *(first_ref as *mut ListNode)) } } /// Removes the last node from the linked list and returns it pub fn remove_last(&mut self) -> Option<&mut ListNode> { #![allow(clippy::debug_assert_with_mut_call)] // Safety: When the node was inserted it was promised that it is alive // until it gets removed from the list unsafe { let mut tail = self.tail?; self.tail = tail.as_mut().prev; let last_ref = tail.as_mut(); match last_ref.prev { None => { // This was the last node in the list debug_assert_eq!(Some(last_ref.into()), self.head); self.head = None; } Some(mut prev) => { prev.as_mut().next = None; } } last_ref.prev = None; last_ref.next = None; Some(&mut *(last_ref as *mut ListNode)) } } /// Returns whether the linked list does not contain any node pub fn is_empty(&self) -> bool { if self.head.is_some() { return false; } debug_assert!(self.tail.is_none()); true } /// Removes the given `node` from the linked list. /// Returns whether the `node` was removed. /// It is also only safe if it is known that the `node` is either part of this /// list, or of no list at all. If `node` is part of another list, the /// behavior is undefined. pub unsafe fn remove(&mut self, node: &mut ListNode) -> bool { #![allow(clippy::debug_assert_with_mut_call)] match node.prev { None => { // This might be the first node in the list. If it is not, the // node is not in the list at all. Since our precondition is that // the node must either be in this list or in no list, we check that // the node is really in no list. if self.head != Some(node.into()) { debug_assert!(node.next.is_none()); return false; } self.head = node.next; } Some(mut prev) => { debug_assert_eq!(prev.as_ref().next, Some(node.into())); prev.as_mut().next = node.next; } } match node.next { None => { // This must be the last node in our list. Otherwise the list // is inconsistent. debug_assert_eq!(self.tail, Some(node.into())); self.tail = node.prev; } Some(mut next) => { debug_assert_eq!(next.as_mut().prev, Some(node.into())); next.as_mut().prev = node.prev; } } node.next = None; node.prev = None; true } /// Drains the list iby calling a callback on each list node /// /// The method does not return an iterator since stopping or deferring /// draining the list is not permitted. If the method would push nodes to /// an iterator we could not guarantee that the nodes do not get utilized /// after having been removed from the list anymore. pub fn drain(&mut self, mut func: F) where F: FnMut(&mut ListNode), { let mut current = self.head; self.head = None; self.tail = None; while let Some(mut node) = current { // Safety: The nodes have not been removed from the list yet and must // therefore contain valid data. The nodes can also not be added to // the list again during iteration, since the list is mutably borrowed. unsafe { let node_ref = node.as_mut(); current = node_ref.next; node_ref.next = None; node_ref.prev = None; // Note: We do not reset the pointers from the next element in the // list to the current one since we will iterate over the whole // list anyway, and therefore clean up all pointers. func(node_ref); } } } /// Drains the list in reverse order by calling a callback on each list node /// /// The method does not return an iterator since stopping or deferring /// draining the list is not permitted. If the method would push nodes to /// an iterator we could not guarantee that the nodes do not get utilized /// after having been removed from the list anymore. pub fn reverse_drain(&mut self, mut func: F) where F: FnMut(&mut ListNode), { let mut current = self.tail; self.head = None; self.tail = None; while let Some(mut node) = current { // Safety: The nodes have not been removed from the list yet and must // therefore contain valid data. The nodes can also not be added to // the list again during iteration, since the list is mutably borrowed. unsafe { let node_ref = node.as_mut(); current = node_ref.prev; node_ref.next = None; node_ref.prev = None; // Note: We do not reset the pointers from the next element in the // list to the current one since we will iterate over the whole // list anyway, and therefore clean up all pointers. func(node_ref); } } } } #[cfg(all(test, feature = "std"))] // Tests make use of Vec at the moment mod tests { use super::*; fn collect_list(mut list: LinkedList) -> Vec { let mut result = Vec::new(); list.drain(|node| { result.push(**node); }); result } fn collect_reverse_list(mut list: LinkedList) -> Vec { let mut result = Vec::new(); list.reverse_drain(|node| { result.push(**node); }); result } unsafe fn add_nodes(list: &mut LinkedList, nodes: &mut [&mut ListNode]) { for node in nodes.iter_mut() { list.add_front(node); } } unsafe fn assert_clean(node: &mut ListNode) { assert!(node.next.is_none()); assert!(node.prev.is_none()); } #[test] fn insert_and_iterate() { unsafe { let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); let mut setup = |list: &mut LinkedList| { assert_eq!(true, list.is_empty()); list.add_front(&mut c); assert_eq!(31, **list.peek_first().unwrap()); assert_eq!(false, list.is_empty()); list.add_front(&mut b); assert_eq!(7, **list.peek_first().unwrap()); list.add_front(&mut a); assert_eq!(5, **list.peek_first().unwrap()); }; let mut list = LinkedList::new(); setup(&mut list); let items: Vec = collect_list(list); assert_eq!([5, 7, 31].to_vec(), items); let mut list = LinkedList::new(); setup(&mut list); let items: Vec = collect_reverse_list(list); assert_eq!([31, 7, 5].to_vec(), items); } } #[test] fn add_sorted() { unsafe { let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); let mut d = ListNode::new(99); let mut list = LinkedList::new(); list.add_sorted(&mut a); let items: Vec = collect_list(list); assert_eq!([5].to_vec(), items); let mut list = LinkedList::new(); list.add_sorted(&mut a); let items: Vec = collect_reverse_list(list); assert_eq!([5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); list.add_sorted(&mut a); let items: Vec = collect_list(list); assert_eq!([5, 7, 31, 99].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); list.add_sorted(&mut a); let items: Vec = collect_reverse_list(list); assert_eq!([99, 31, 7, 5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); list.add_sorted(&mut b); let items: Vec = collect_list(list); assert_eq!([5, 7, 31, 99].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); list.add_sorted(&mut b); let items: Vec = collect_reverse_list(list); assert_eq!([99, 31, 7, 5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); list.add_sorted(&mut c); let items: Vec = collect_list(list); assert_eq!([5, 7, 31, 99].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); list.add_sorted(&mut c); let items: Vec = collect_reverse_list(list); assert_eq!([99, 31, 7, 5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); list.add_sorted(&mut d); let items: Vec = collect_list(list); assert_eq!([5, 7, 31, 99].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); list.add_sorted(&mut d); let items: Vec = collect_reverse_list(list); assert_eq!([99, 31, 7, 5].to_vec(), items); } } #[test] fn drain_and_collect() { unsafe { let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); let taken_items: Vec = collect_list(list); assert_eq!([5, 7, 31].to_vec(), taken_items); } } #[test] fn peek_last() { unsafe { let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); let last = list.peek_last(); assert_eq!(31, **last.unwrap()); list.remove_last(); let last = list.peek_last(); assert_eq!(7, **last.unwrap()); list.remove_last(); let last = list.peek_last(); assert_eq!(5, **last.unwrap()); list.remove_last(); let last = list.peek_last(); assert!(last.is_none()); } } #[test] fn remove_first() { unsafe { // We iterate forward and backwards through the manipulated lists // to make sure pointers in both directions are still ok. let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); let removed = list.remove_first().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_list(list); assert_eq!([7, 31].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); let removed = list.remove_first().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_reverse_list(list); assert_eq!([31, 7].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); let removed = list.remove_first().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_list(list); assert_eq!([7].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); let removed = list.remove_first().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_reverse_list(list); assert_eq!([7].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut a]); let removed = list.remove_first().unwrap(); assert_clean(removed); assert!(list.is_empty()); let items: Vec = collect_list(list); assert!(items.is_empty()); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut a]); let removed = list.remove_first().unwrap(); assert_clean(removed); assert!(list.is_empty()); let items: Vec = collect_reverse_list(list); assert!(items.is_empty()); } } #[test] fn remove_last() { unsafe { // We iterate forward and backwards through the manipulated lists // to make sure pointers in both directions are still ok. let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); let removed = list.remove_last().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_list(list); assert_eq!([5, 7].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); let removed = list.remove_last().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_reverse_list(list); assert_eq!([7, 5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); let removed = list.remove_last().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_list(list); assert_eq!([5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); let removed = list.remove_last().unwrap(); assert_clean(removed); assert!(!list.is_empty()); let items: Vec = collect_reverse_list(list); assert_eq!([5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut a]); let removed = list.remove_last().unwrap(); assert_clean(removed); assert!(list.is_empty()); let items: Vec = collect_list(list); assert!(items.is_empty()); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut a]); let removed = list.remove_last().unwrap(); assert_clean(removed); assert!(list.is_empty()); let items: Vec = collect_reverse_list(list); assert!(items.is_empty()); } } #[test] fn remove_by_address() { unsafe { let mut a = ListNode::new(5); let mut b = ListNode::new(7); let mut c = ListNode::new(31); { // Remove first let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); assert_eq!(true, list.remove(&mut a)); assert_clean((&mut a).into()); // a should be no longer there and can't be removed twice assert_eq!(false, list.remove(&mut a)); assert_eq!(Some((&mut b).into()), list.head); assert_eq!(Some((&mut c).into()), b.next); assert_eq!(Some((&mut b).into()), c.prev); let items: Vec = collect_list(list); assert_eq!([7, 31].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); assert_eq!(true, list.remove(&mut a)); assert_clean((&mut a).into()); // a should be no longer there and can't be removed twice assert_eq!(false, list.remove(&mut a)); assert_eq!(Some((&mut c).into()), b.next); assert_eq!(Some((&mut b).into()), c.prev); let items: Vec = collect_reverse_list(list); assert_eq!([31, 7].to_vec(), items); } { // Remove middle let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); assert_eq!(true, list.remove(&mut b)); assert_clean((&mut b).into()); assert_eq!(Some((&mut c).into()), a.next); assert_eq!(Some((&mut a).into()), c.prev); let items: Vec = collect_list(list); assert_eq!([5, 31].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); assert_eq!(true, list.remove(&mut b)); assert_clean((&mut b).into()); assert_eq!(Some((&mut c).into()), a.next); assert_eq!(Some((&mut a).into()), c.prev); let items: Vec = collect_reverse_list(list); assert_eq!([31, 5].to_vec(), items); } { // Remove last let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); assert_eq!(true, list.remove(&mut c)); assert_clean((&mut c).into()); assert!(b.next.is_none()); assert_eq!(Some((&mut b).into()), list.tail); let items: Vec = collect_list(list); assert_eq!([5, 7].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); assert_eq!(true, list.remove(&mut c)); assert_clean((&mut c).into()); assert!(b.next.is_none()); assert_eq!(Some((&mut b).into()), list.tail); let items: Vec = collect_reverse_list(list); assert_eq!([7, 5].to_vec(), items); } { // Remove first of two let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); assert_eq!(true, list.remove(&mut a)); assert_clean((&mut a).into()); // a should be no longer there and can't be removed twice assert_eq!(false, list.remove(&mut a)); assert_eq!(Some((&mut b).into()), list.head); assert_eq!(Some((&mut b).into()), list.tail); assert!(b.next.is_none()); assert!(b.prev.is_none()); let items: Vec = collect_list(list); assert_eq!([7].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); assert_eq!(true, list.remove(&mut a)); assert_clean((&mut a).into()); // a should be no longer there and can't be removed twice assert_eq!(false, list.remove(&mut a)); assert_eq!(Some((&mut b).into()), list.head); assert_eq!(Some((&mut b).into()), list.tail); assert!(b.next.is_none()); assert!(b.prev.is_none()); let items: Vec = collect_reverse_list(list); assert_eq!([7].to_vec(), items); } { // Remove last of two let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); assert_eq!(true, list.remove(&mut b)); assert_clean((&mut b).into()); assert_eq!(Some((&mut a).into()), list.head); assert_eq!(Some((&mut a).into()), list.tail); assert!(a.next.is_none()); assert!(a.prev.is_none()); let items: Vec = collect_list(list); assert_eq!([5].to_vec(), items); let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut b, &mut a]); assert_eq!(true, list.remove(&mut b)); assert_clean((&mut b).into()); assert_eq!(Some((&mut a).into()), list.head); assert_eq!(Some((&mut a).into()), list.tail); assert!(a.next.is_none()); assert!(a.prev.is_none()); let items: Vec = collect_reverse_list(list); assert_eq!([5].to_vec(), items); } { // Remove last item let mut list = LinkedList::new(); add_nodes(&mut list, &mut [&mut a]); assert_eq!(true, list.remove(&mut a)); assert_clean((&mut a).into()); assert!(list.head.is_none()); assert!(list.tail.is_none()); let items: Vec = collect_list(list); assert!(items.is_empty()); } { // Remove missing let mut list = LinkedList::new(); list.add_front(&mut b); list.add_front(&mut a); assert_eq!(false, list.remove(&mut c)); } } } } tokio-util-0.6.9/src/sync/mod.rs000064400000000000000000000005260072674642500146630ustar 00000000000000//! Synchronization primitives mod cancellation_token; pub use cancellation_token::{guard::DropGuard, CancellationToken, WaitForCancellationFuture}; mod intrusive_double_linked_list; mod mpsc; pub use mpsc::PollSender; mod poll_semaphore; pub use poll_semaphore::PollSemaphore; mod reusable_box; pub use reusable_box::ReusableBoxFuture; tokio-util-0.6.9/src/sync/mpsc.rs000064400000000000000000000175110072674642500150500ustar 00000000000000use futures_core::ready; use futures_sink::Sink; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::mpsc::{error::SendError, Sender}; use super::ReusableBoxFuture; // This implementation was chosen over something based on permits because to get a // `tokio::sync::mpsc::Permit` out of the `inner` future, you must transmute the // lifetime on the permit to `'static`. /// A wrapper around [`mpsc::Sender`] that can be polled. /// /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender #[derive(Debug)] pub struct PollSender { /// is none if closed sender: Option>>, is_sending: bool, inner: ReusableBoxFuture>>, } // By reusing the same async fn for both Some and None, we make sure every // future passed to ReusableBoxFuture has the same underlying type, and hence // the same size and alignment. async fn make_future(data: Option<(Arc>, T)>) -> Result<(), SendError> { match data { Some((sender, value)) => sender.send(value).await, None => unreachable!( "This future should not be pollable, as is_sending should be set to false." ), } } impl PollSender { /// Create a new `PollSender`. pub fn new(sender: Sender) -> Self { Self { sender: Some(Arc::new(sender)), is_sending: false, inner: ReusableBoxFuture::new(make_future(None)), } } /// Start sending a new item. /// /// This method panics if a send is currently in progress. To ensure that no /// send is in progress, call `poll_send_done` first until it returns /// `Poll::Ready`. /// /// If this method returns an error, that indicates that the channel is /// closed. Note that this method is not guaranteed to return an error if /// the channel is closed, but in that case the error would be reported by /// the first call to `poll_send_done`. pub fn start_send(&mut self, value: T) -> Result<(), SendError> { if self.is_sending { panic!("start_send called while not ready."); } match self.sender.clone() { Some(sender) => { self.inner.set(make_future(Some((sender, value)))); self.is_sending = true; Ok(()) } None => Err(SendError(value)), } } /// If a send is in progress, poll for its completion. If no send is in progress, /// this method returns `Poll::Ready(Ok(()))`. /// /// This method can return the following values: /// /// - `Poll::Ready(Ok(()))` if the in-progress send has been completed, or there is /// no send in progress (even if the channel is closed). /// - `Poll::Ready(Err(err))` if the in-progress send failed because the channel has /// been closed. /// - `Poll::Pending` if a send is in progress, but it could not complete now. /// /// When this method returns `Poll::Pending`, the current task is scheduled /// to receive a wakeup when the message is sent, or when the entire channel /// is closed (but not if just this sender is closed by /// `close_this_sender`). Note that on multiple calls to `poll_send_done`, /// only the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. /// /// If this method returns `Poll::Ready`, then `start_send` is guaranteed to /// not panic. pub fn poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll>> { if !self.is_sending { return Poll::Ready(Ok(())); } let result = self.inner.poll(cx); if result.is_ready() { self.is_sending = false; } if let Poll::Ready(Err(_)) = &result { self.sender = None; } result } /// Check whether the channel is ready to send more messages now. /// /// If this method returns `true`, then `start_send` is guaranteed to not /// panic. /// /// If the channel is closed, this method returns `true`. pub fn is_ready(&self) -> bool { !self.is_sending } /// Check whether the channel has been closed. pub fn is_closed(&self) -> bool { match &self.sender { Some(sender) => sender.is_closed(), None => true, } } /// Clone the underlying `Sender`. /// /// If this method returns `None`, then the channel is closed. (But it is /// not guaranteed to return `None` if the channel is closed.) pub fn clone_inner(&self) -> Option> { self.sender.as_ref().map(|sender| (&**sender).clone()) } /// Access the underlying `Sender`. /// /// If this method returns `None`, then the channel is closed. (But it is /// not guaranteed to return `None` if the channel is closed.) pub fn inner_ref(&self) -> Option<&Sender> { self.sender.as_deref() } // This operation is supported because it is required by the Sink trait. /// Close this sender. No more messages can be sent from this sender. /// /// Note that this only closes the channel from the view-point of this /// sender. The channel remains open until all senders have gone away, or /// until the [`Receiver`] closes the channel. /// /// If there is a send in progress when this method is called, that send is /// unaffected by this operation, and `poll_send_done` can still be called /// to complete that send. /// /// [`Receiver`]: tokio::sync::mpsc::Receiver pub fn close_this_sender(&mut self) { self.sender = None; } /// Abort the current in-progress send, if any. /// /// Returns `true` if a send was aborted. pub fn abort_send(&mut self) -> bool { if self.is_sending { self.inner.set(make_future(None)); self.is_sending = false; true } else { false } } } impl Clone for PollSender { /// Clones this `PollSender`. The resulting clone will not have any /// in-progress send operations, even if the current `PollSender` does. fn clone(&self) -> PollSender { Self { sender: self.sender.clone(), is_sending: false, inner: ReusableBoxFuture::new(async { unreachable!() }), } } } impl Sink for PollSender { type Error = SendError; /// This is equivalent to calling [`poll_send_done`]. /// /// [`poll_send_done`]: PollSender::poll_send_done fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).poll_send_done(cx) } /// This is equivalent to calling [`poll_send_done`]. /// /// [`poll_send_done`]: PollSender::poll_send_done fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).poll_send_done(cx) } /// This is equivalent to calling [`start_send`]. /// /// [`start_send`]: PollSender::start_send fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { Pin::into_inner(self).start_send(item) } /// This method will first flush the `PollSender`, and then close it by /// calling [`close_this_sender`]. /// /// If a send fails while flushing because the [`Receiver`] has gone away, /// then this function returns an error. The channel is still successfully /// closed in this situation. /// /// [`close_this_sender`]: PollSender::close_this_sender /// [`Receiver`]: tokio::sync::mpsc::Receiver fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; Pin::into_inner(self).close_this_sender(); Poll::Ready(Ok(())) } } tokio-util-0.6.9/src/sync/poll_semaphore.rs000064400000000000000000000104620072674642500171150ustar 00000000000000use futures_core::{ready, Stream}; use std::fmt; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use super::ReusableBoxFuture; /// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. /// /// [`Semaphore`]: tokio::sync::Semaphore pub struct PollSemaphore { semaphore: Arc, permit_fut: Option>>, } impl PollSemaphore { /// Create a new `PollSemaphore`. pub fn new(semaphore: Arc) -> Self { Self { semaphore, permit_fut: None, } } /// Closes the semaphore. pub fn close(&self) { self.semaphore.close() } /// Obtain a clone of the inner semaphore. pub fn clone_inner(&self) -> Arc { self.semaphore.clone() } /// Get back the inner semaphore. pub fn into_inner(self) -> Arc { self.semaphore } /// Poll to acquire a permit from the semaphore. /// /// This can return the following values: /// /// - `Poll::Pending` if a permit is not currently available. /// - `Poll::Ready(Some(permit))` if a permit was acquired. /// - `Poll::Ready(None)` if the semaphore has been closed. /// /// When this method returns `Poll::Pending`, the current task is scheduled /// to receive a wakeup when a permit becomes available, or when the /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only /// the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { let permit_future = match self.permit_fut.as_mut() { Some(fut) => fut, None => { // avoid allocations completely if we can grab a permit immediately match Arc::clone(&self.semaphore).try_acquire_owned() { Ok(permit) => return Poll::Ready(Some(permit)), Err(TryAcquireError::Closed) => return Poll::Ready(None), Err(TryAcquireError::NoPermits) => {} } let next_fut = Arc::clone(&self.semaphore).acquire_owned(); self.permit_fut .get_or_insert(ReusableBoxFuture::new(next_fut)) } }; let result = ready!(permit_future.poll(cx)); let next_fut = Arc::clone(&self.semaphore).acquire_owned(); permit_future.set(next_fut); match result { Ok(permit) => Poll::Ready(Some(permit)), Err(_closed) => { self.permit_fut = None; Poll::Ready(None) } } } /// Returns the current number of available permits. /// /// This is equivalent to the [`Semaphore::available_permits`] method on the /// `tokio::sync::Semaphore` type. /// /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits pub fn available_permits(&self) -> usize { self.semaphore.available_permits() } /// Adds `n` new permits to the semaphore. /// /// The maximum number of permits is `usize::MAX >> 3`, and this function /// will panic if the limit is exceeded. /// /// This is equivalent to the [`Semaphore::add_permits`] method on the /// `tokio::sync::Semaphore` type. /// /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits pub fn add_permits(&self, n: usize) { self.semaphore.add_permits(n); } } impl Stream for PollSemaphore { type Item = OwnedSemaphorePermit; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).poll_acquire(cx) } } impl Clone for PollSemaphore { fn clone(&self) -> PollSemaphore { PollSemaphore::new(self.clone_inner()) } } impl fmt::Debug for PollSemaphore { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PollSemaphore") .field("semaphore", &self.semaphore) .finish() } } impl AsRef for PollSemaphore { fn as_ref(&self) -> &Semaphore { &*self.semaphore } } tokio-util-0.6.9/src/sync/reusable_box.rs000064400000000000000000000114060072674642500165550ustar 00000000000000use std::alloc::Layout; use std::future::Future; use std::panic::AssertUnwindSafe; use std::pin::Pin; use std::ptr::{self, NonNull}; use std::task::{Context, Poll}; use std::{fmt, panic}; /// A reusable `Pin + Send>>`. /// /// This type lets you replace the future stored in the box without /// reallocating when the size and alignment permits this. pub struct ReusableBoxFuture { boxed: NonNull + Send>, } impl ReusableBoxFuture { /// Create a new `ReusableBoxFuture` containing the provided future. pub fn new(future: F) -> Self where F: Future + Send + 'static, { let boxed: Box + Send> = Box::new(future); let boxed = Box::into_raw(boxed); // SAFETY: Box::into_raw does not return null pointers. let boxed = unsafe { NonNull::new_unchecked(boxed) }; Self { boxed } } /// Replace the future currently stored in this box. /// /// This reallocates if and only if the layout of the provided future is /// different from the layout of the currently stored future. pub fn set(&mut self, future: F) where F: Future + Send + 'static, { if let Err(future) = self.try_set(future) { *self = Self::new(future); } } /// Replace the future currently stored in this box. /// /// This function never reallocates, but returns an error if the provided /// future has a different size or alignment from the currently stored /// future. pub fn try_set(&mut self, future: F) -> Result<(), F> where F: Future + Send + 'static, { // SAFETY: The pointer is not dangling. let self_layout = { let dyn_future: &(dyn Future + Send) = unsafe { self.boxed.as_ref() }; Layout::for_value(dyn_future) }; if Layout::new::() == self_layout { // SAFETY: We just checked that the layout of F is correct. unsafe { self.set_same_layout(future); } Ok(()) } else { Err(future) } } /// Set the current future. /// /// # Safety /// /// This function requires that the layout of the provided future is the /// same as `self.layout`. unsafe fn set_same_layout(&mut self, future: F) where F: Future + Send + 'static, { // Drop the existing future, catching any panics. let result = panic::catch_unwind(AssertUnwindSafe(|| { ptr::drop_in_place(self.boxed.as_ptr()); })); // Overwrite the future behind the pointer. This is safe because the // allocation was allocated with the same size and alignment as the type F. let self_ptr: *mut F = self.boxed.as_ptr() as *mut F; ptr::write(self_ptr, future); // Update the vtable of self.boxed. The pointer is not null because we // just got it from self.boxed, which is not null. self.boxed = NonNull::new_unchecked(self_ptr); // If the old future's destructor panicked, resume unwinding. match result { Ok(()) => {} Err(payload) => { panic::resume_unwind(payload); } } } /// Get a pinned reference to the underlying future. pub fn get_pin(&mut self) -> Pin<&mut (dyn Future + Send)> { // SAFETY: The user of this box cannot move the box, and we do not move it // either. unsafe { Pin::new_unchecked(self.boxed.as_mut()) } } /// Poll the future stored inside this box. pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll { self.get_pin().poll(cx) } } impl Future for ReusableBoxFuture { type Output = T; /// Poll the future stored inside this box. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::into_inner(self).get_pin().poll(cx) } } // The future stored inside ReusableBoxFuture must be Send. unsafe impl Send for ReusableBoxFuture {} // The only method called on self.boxed is poll, which takes &mut self, so this // struct being Sync does not permit any invalid access to the Future, even if // the future is not Sync. unsafe impl Sync for ReusableBoxFuture {} // Just like a Pin> is always Unpin, so is this type. impl Unpin for ReusableBoxFuture {} impl Drop for ReusableBoxFuture { fn drop(&mut self) { unsafe { drop(Box::from_raw(self.boxed.as_ptr())); } } } impl fmt::Debug for ReusableBoxFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ReusableBoxFuture").finish() } } tokio-util-0.6.9/src/sync/tests/loom_cancellation_token.rs000064400000000000000000000065440072674642500221360ustar 00000000000000use crate::sync::CancellationToken; use loom::{future::block_on, thread}; use tokio_test::assert_ok; #[test] fn cancel_token() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let th1 = thread::spawn(move || { block_on(async { token1.cancelled().await; }); }); let th2 = thread::spawn(move || { token.cancel(); }); assert_ok!(th1.join()); assert_ok!(th2.join()); }); } #[test] fn cancel_with_child() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let token2 = token.clone(); let child_token = token.child_token(); let th1 = thread::spawn(move || { block_on(async { token1.cancelled().await; }); }); let th2 = thread::spawn(move || { token2.cancel(); }); let th3 = thread::spawn(move || { block_on(async { child_token.cancelled().await; }); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn drop_token_no_child() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let token2 = token.clone(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { drop(token2); }); let th3 = thread::spawn(move || { drop(token); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn drop_token_with_childs() { loom::model(|| { let token1 = CancellationToken::new(); let child_token1 = token1.child_token(); let child_token2 = token1.child_token(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { drop(child_token1); }); let th3 = thread::spawn(move || { drop(child_token2); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn drop_and_cancel_token() { loom::model(|| { let token1 = CancellationToken::new(); let token2 = token1.clone(); let child_token = token1.child_token(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { token2.cancel(); }); let th3 = thread::spawn(move || { drop(child_token); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn cancel_parent_and_child() { loom::model(|| { let token1 = CancellationToken::new(); let token2 = token1.clone(); let child_token = token1.child_token(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { token2.cancel(); }); let th3 = thread::spawn(move || { child_token.cancel(); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } tokio-util-0.6.9/src/sync/tests/mod.rs000064400000000000000000000000010072674642500160110ustar 00000000000000 tokio-util-0.6.9/src/time/delay_queue.rs000064400000000000000000000707350072674642500164010ustar 00000000000000//! A queue of delayed elements. //! //! See [`DelayQueue`] for more details. //! //! [`DelayQueue`]: struct@DelayQueue use crate::time::wheel::{self, Wheel}; use futures_core::ready; use tokio::time::{error::Error, sleep_until, Duration, Instant, Sleep}; use slab::Slab; use std::cmp; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{self, Poll, Waker}; /// A queue of delayed elements. /// /// Once an element is inserted into the `DelayQueue`, it is yielded once the /// specified deadline has been reached. /// /// # Usage /// /// Elements are inserted into `DelayQueue` using the [`insert`] or /// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is /// returned. The key is used to remove the entry or to change the deadline at /// which it should be yielded back. /// /// Once delays have been configured, the `DelayQueue` is used via its /// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its /// deadline, it is returned. If not, `Poll::Pending` is returned indicating that the /// current task will be notified once the deadline has been reached. /// /// # `Stream` implementation /// /// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have /// expired, no items are returned. In this case, `Poll::Pending` is returned and the /// current task is registered to be notified once the next item's delay has /// expired. /// /// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll` /// returns `Poll::Ready(None)`. This indicates that the stream has reached an end. /// However, if a new item is inserted *after*, `poll` will once again start /// returning items or `Poll::Pending`. /// /// Items are returned ordered by their expirations. Items that are configured /// to expire first will be returned first. There are no ordering guarantees /// for items configured to expire at the same instant. Also note that delays are /// rounded to the closest millisecond. /// /// # Implementation /// /// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally /// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same /// performance and scalability benefits. /// /// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, /// and allows reuse of the memory allocated for expired entires. /// /// Capacity can be checked using [`capacity`] and allocated preemptively by using /// the [`reserve`] method. /// /// # Usage /// /// Using `DelayQueue` to manage cache entries. /// /// ```rust,no_run /// use tokio::time::error::Error; /// use tokio_util::time::{DelayQueue, delay_queue}; /// /// use futures::ready; /// use std::collections::HashMap; /// use std::task::{Context, Poll}; /// use std::time::Duration; /// # type CacheKey = String; /// # type Value = String; /// /// struct Cache { /// entries: HashMap, /// expirations: DelayQueue, /// } /// /// const TTL_SECS: u64 = 30; /// /// impl Cache { /// fn insert(&mut self, key: CacheKey, value: Value) { /// let delay = self.expirations /// .insert(key.clone(), Duration::from_secs(TTL_SECS)); /// /// self.entries.insert(key, (value, delay)); /// } /// /// fn get(&self, key: &CacheKey) -> Option<&Value> { /// self.entries.get(key) /// .map(|&(ref v, _)| v) /// } /// /// fn remove(&mut self, key: &CacheKey) { /// if let Some((_, cache_key)) = self.entries.remove(key) { /// self.expirations.remove(&cache_key); /// } /// } /// /// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll> { /// while let Some(res) = ready!(self.expirations.poll_expired(cx)) { /// let entry = res?; /// self.entries.remove(entry.get_ref()); /// } /// /// Poll::Ready(Ok(())) /// } /// } /// ``` /// /// [`insert`]: method@Self::insert /// [`insert_at`]: method@Self::insert_at /// [`Key`]: struct@Key /// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html /// [`poll_expired`]: method@Self::poll_expired /// [`Stream::poll_expired`]: method@Self::poll_expired /// [`DelayQueue`]: struct@DelayQueue /// [`sleep`]: fn@tokio::time::sleep /// [`slab`]: slab /// [`capacity`]: method@Self::capacity /// [`reserve`]: method@Self::reserve #[derive(Debug)] pub struct DelayQueue { /// Stores data associated with entries slab: Slab>, /// Lookup structure tracking all delays in the queue wheel: Wheel>, /// Delays that were inserted when already expired. These cannot be stored /// in the wheel expired: Stack, /// Delay expiring when the *first* item in the queue expires delay: Option>>, /// Wheel polling state wheel_now: u64, /// Instant at which the timer starts start: Instant, /// Waker that is invoked when we potentially need to reset the timer. /// Because we lazily create the timer when the first entry is created, we /// need to awaken any poller that polled us before that point. waker: Option, } /// An entry in `DelayQueue` that has expired and been removed. /// /// Values are returned by [`DelayQueue::poll_expired`]. /// /// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired #[derive(Debug)] pub struct Expired { /// The data stored in the queue data: T, /// The expiration time deadline: Instant, /// The key associated with the entry key: Key, } /// Token to a value stored in a `DelayQueue`. /// /// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`] /// documentation for more details. /// /// [`DelayQueue`]: struct@DelayQueue /// [`DelayQueue::insert`]: method@DelayQueue::insert #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub struct Key { index: usize, } #[derive(Debug)] struct Stack { /// Head of the stack head: Option, _p: PhantomData T>, } #[derive(Debug)] struct Data { /// The data being stored in the queue and will be returned at the requested /// instant. inner: T, /// The instant at which the item is returned. when: u64, /// Set to true when stored in the `expired` queue expired: bool, /// Next entry in the stack next: Option, /// Previous entry in the stack prev: Option, } /// Maximum number of entries the queue can handle const MAX_ENTRIES: usize = (1 << 30) - 1; impl DelayQueue { /// Creates a new, empty, `DelayQueue`. /// /// The queue will not allocate storage until items are inserted into it. /// /// # Examples /// /// ```rust /// # use tokio_util::time::DelayQueue; /// let delay_queue: DelayQueue = DelayQueue::new(); /// ``` pub fn new() -> DelayQueue { DelayQueue::with_capacity(0) } /// Creates a new, empty, `DelayQueue` with the specified capacity. /// /// The queue will be able to hold at least `capacity` elements without /// reallocating. If `capacity` is 0, the queue will not allocate for /// storage. /// /// # Examples /// /// ```rust /// # use tokio_util::time::DelayQueue; /// # use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::with_capacity(10); /// /// // These insertions are done without further allocation /// for i in 0..10 { /// delay_queue.insert(i, Duration::from_secs(i)); /// } /// /// // This will make the queue allocate additional storage /// delay_queue.insert(11, Duration::from_secs(11)); /// # } /// ``` pub fn with_capacity(capacity: usize) -> DelayQueue { DelayQueue { wheel: Wheel::new(), slab: Slab::with_capacity(capacity), expired: Stack::default(), delay: None, wheel_now: 0, start: Instant::now(), waker: None, } } /// Inserts `value` into the queue set to expire at a specific instant in /// time. /// /// This function is identical to `insert`, but takes an `Instant` instead /// of a `Duration`. /// /// `value` is stored in the queue until `when` is reached. At which point, /// `value` will be returned from [`poll_expired`]. If `when` has already been /// reached, then `value` is immediately made available to poll. /// /// The return value represents the insertion and is used as an argument to /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once /// `value` is removed from the queue either by calling [`poll_expired`] after /// `when` is reached or by calling [`remove`]. At this point, the caller /// must take care to not use the returned [`Key`] again as it may reference /// a different item in the queue. /// /// See [type] level documentation for more details. /// /// # Panics /// /// This function panics if `when` is too far in the future. /// /// # Examples /// /// Basic usage /// /// ```rust /// use tokio::time::{Duration, Instant}; /// use tokio_util::time::DelayQueue; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// let key = delay_queue.insert_at( /// "foo", Instant::now() + Duration::from_secs(5)); /// /// // Remove the entry /// let item = delay_queue.remove(&key); /// assert_eq!(*item.get_ref(), "foo"); /// # } /// ``` /// /// [`poll_expired`]: method@Self::poll_expired /// [`remove`]: method@Self::remove /// [`reset`]: method@Self::reset /// [`Key`]: struct@Key /// [type]: # pub fn insert_at(&mut self, value: T, when: Instant) -> Key { assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); // Normalize the deadline. Values cannot be set to expire in the past. let when = self.normalize_deadline(when); // Insert the value in the store let key = self.slab.insert(Data { inner: value, when, expired: false, next: None, prev: None, }); self.insert_idx(when, key); // Set a new delay if the current's deadline is later than the one of the new item let should_set_delay = if let Some(ref delay) = self.delay { let current_exp = self.normalize_deadline(delay.deadline()); current_exp > when } else { true }; if should_set_delay { if let Some(waker) = self.waker.take() { waker.wake(); } let delay_time = self.start + Duration::from_millis(when); if let Some(ref mut delay) = &mut self.delay { delay.as_mut().reset(delay_time); } else { self.delay = Some(Box::pin(sleep_until(delay_time))); } } Key::new(key) } /// Attempts to pull out the next value of the delay queue, registering the /// current task for wakeup if the value is not yet available, and returning /// `None` if the queue is exhausted. pub fn poll_expired( &mut self, cx: &mut task::Context<'_>, ) -> Poll, Error>>> { if !self .waker .as_ref() .map(|w| w.will_wake(cx.waker())) .unwrap_or(false) { self.waker = Some(cx.waker().clone()); } let item = ready!(self.poll_idx(cx)); Poll::Ready(item.map(|result| { result.map(|idx| { let data = self.slab.remove(idx); debug_assert!(data.next.is_none()); debug_assert!(data.prev.is_none()); Expired { key: Key::new(idx), data: data.inner, deadline: self.start + Duration::from_millis(data.when), } }) })) } /// Inserts `value` into the queue set to expire after the requested duration /// elapses. /// /// This function is identical to `insert_at`, but takes a `Duration` /// instead of an `Instant`. /// /// `value` is stored in the queue until `timeout` duration has /// elapsed after `insert` was called. At that point, `value` will /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of /// zero, then `value` is immediately made available to poll. /// /// The return value represents the insertion and is used as an /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a /// token and is reused once `value` is removed from the queue /// either by calling [`poll_expired`] after `timeout` has elapsed /// or by calling [`remove`]. At this point, the caller must not /// use the returned [`Key`] again as it may reference a different /// item in the queue. /// /// See [type] level documentation for more details. /// /// # Panics /// /// This function panics if `timeout` is greater than the maximum /// duration supported by the timer in the current `Runtime`. /// /// # Examples /// /// Basic usage /// /// ```rust /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// let key = delay_queue.insert("foo", Duration::from_secs(5)); /// /// // Remove the entry /// let item = delay_queue.remove(&key); /// assert_eq!(*item.get_ref(), "foo"); /// # } /// ``` /// /// [`poll_expired`]: method@Self::poll_expired /// [`remove`]: method@Self::remove /// [`reset`]: method@Self::reset /// [`Key`]: struct@Key /// [type]: # pub fn insert(&mut self, value: T, timeout: Duration) -> Key { self.insert_at(value, Instant::now() + timeout) } fn insert_idx(&mut self, when: u64, key: usize) { use self::wheel::{InsertError, Stack}; // Register the deadline with the timer wheel match self.wheel.insert(when, key, &mut self.slab) { Ok(_) => {} Err((_, InsertError::Elapsed)) => { self.slab[key].expired = true; // The delay is already expired, store it in the expired queue self.expired.push(key, &mut self.slab); } Err((_, err)) => panic!("invalid deadline; err={:?}", err), } } /// Removes the key from the expired queue or the timer wheel /// depending on its expiration status. /// /// # Panics /// /// Panics if the key is not contained in the expired queue or the wheel. fn remove_key(&mut self, key: &Key) { use crate::time::wheel::Stack; // Special case the `expired` queue if self.slab[key.index].expired { self.expired.remove(&key.index, &mut self.slab); } else { self.wheel.remove(&key.index, &mut self.slab); } } /// Removes the item associated with `key` from the queue. /// /// There must be an item associated with `key`. The function returns the /// removed item as well as the `Instant` at which it will the delay will /// have expired. /// /// # Panics /// /// The function panics if `key` is not contained by the queue. /// /// # Examples /// /// Basic usage /// /// ```rust /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// let key = delay_queue.insert("foo", Duration::from_secs(5)); /// /// // Remove the entry /// let item = delay_queue.remove(&key); /// assert_eq!(*item.get_ref(), "foo"); /// # } /// ``` pub fn remove(&mut self, key: &Key) -> Expired { let prev_deadline = self.next_deadline(); self.remove_key(key); let data = self.slab.remove(key.index); let next_deadline = self.next_deadline(); if prev_deadline != next_deadline { match (next_deadline, &mut self.delay) { (None, _) => self.delay = None, (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline), (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))), } } Expired { key: Key::new(key.index), data: data.inner, deadline: self.start + Duration::from_millis(data.when), } } /// Sets the delay of the item associated with `key` to expire at `when`. /// /// This function is identical to `reset` but takes an `Instant` instead of /// a `Duration`. /// /// The item remains in the queue but the delay is set to expire at `when`. /// If `when` is in the past, then the item is immediately made available to /// the caller. /// /// # Panics /// /// This function panics if `when` is too far in the future or if `key` is /// not contained by the queue. /// /// # Examples /// /// Basic usage /// /// ```rust /// use tokio::time::{Duration, Instant}; /// use tokio_util::time::DelayQueue; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// let key = delay_queue.insert("foo", Duration::from_secs(5)); /// /// // "foo" is scheduled to be returned in 5 seconds /// /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10)); /// /// // "foo" is now scheduled to be returned in 10 seconds /// # } /// ``` pub fn reset_at(&mut self, key: &Key, when: Instant) { self.remove_key(key); // Normalize the deadline. Values cannot be set to expire in the past. let when = self.normalize_deadline(when); self.slab[key.index].when = when; self.slab[key.index].expired = false; self.insert_idx(when, key.index); let next_deadline = self.next_deadline(); if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) { // This should awaken us if necessary (ie, if already expired) delay.as_mut().reset(deadline); } } /// Returns the next time to poll as determined by the wheel fn next_deadline(&mut self) -> Option { self.wheel .poll_at() .map(|poll_at| self.start + Duration::from_millis(poll_at)) } /// Sets the delay of the item associated with `key` to expire after /// `timeout`. /// /// This function is identical to `reset_at` but takes a `Duration` instead /// of an `Instant`. /// /// The item remains in the queue but the delay is set to expire after /// `timeout`. If `timeout` is zero, then the item is immediately made /// available to the caller. /// /// # Panics /// /// This function panics if `timeout` is greater than the maximum supported /// duration or if `key` is not contained by the queue. /// /// # Examples /// /// Basic usage /// /// ```rust /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// let key = delay_queue.insert("foo", Duration::from_secs(5)); /// /// // "foo" is scheduled to be returned in 5 seconds /// /// delay_queue.reset(&key, Duration::from_secs(10)); /// /// // "foo"is now scheduled to be returned in 10 seconds /// # } /// ``` pub fn reset(&mut self, key: &Key, timeout: Duration) { self.reset_at(key, Instant::now() + timeout); } /// Clears the queue, removing all items. /// /// After calling `clear`, [`poll_expired`] will return `Ok(Ready(None))`. /// /// Note that this method has no effect on the allocated capacity. /// /// [`poll_expired`]: method@Self::poll_expired /// /// # Examples /// /// ```rust /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// /// delay_queue.insert("foo", Duration::from_secs(5)); /// /// assert!(!delay_queue.is_empty()); /// /// delay_queue.clear(); /// /// assert!(delay_queue.is_empty()); /// # } /// ``` pub fn clear(&mut self) { self.slab.clear(); self.expired = Stack::default(); self.wheel = Wheel::new(); self.delay = None; } /// Returns the number of elements the queue can hold without reallocating. /// /// # Examples /// /// ```rust /// use tokio_util::time::DelayQueue; /// /// let delay_queue: DelayQueue = DelayQueue::with_capacity(10); /// assert_eq!(delay_queue.capacity(), 10); /// ``` pub fn capacity(&self) -> usize { self.slab.capacity() } /// Returns the number of elements currently in the queue. /// /// # Examples /// /// ```rust /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue: DelayQueue = DelayQueue::with_capacity(10); /// assert_eq!(delay_queue.len(), 0); /// delay_queue.insert(3, Duration::from_secs(5)); /// assert_eq!(delay_queue.len(), 1); /// # } /// ``` pub fn len(&self) -> usize { self.slab.len() } /// Reserves capacity for at least `additional` more items to be queued /// without allocating. /// /// `reserve` does nothing if the queue already has sufficient capacity for /// `additional` more values. If more capacity is required, a new segment of /// memory will be allocated and all existing values will be copied into it. /// As such, if the queue is already very large, a call to `reserve` can end /// up being expensive. /// /// The queue may reserve more than `additional` extra space in order to /// avoid frequent reallocations. /// /// # Panics /// /// Panics if the new capacity exceeds the maximum number of entries the /// queue can contain. /// /// # Examples /// /// ``` /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// /// delay_queue.insert("hello", Duration::from_secs(10)); /// delay_queue.reserve(10); /// /// assert!(delay_queue.capacity() >= 11); /// # } /// ``` pub fn reserve(&mut self, additional: usize) { self.slab.reserve(additional); } /// Returns `true` if there are no items in the queue. /// /// Note that this function returns `false` even if all items have not yet /// expired and a call to `poll` will return `Poll::Pending`. /// /// # Examples /// /// ``` /// use tokio_util::time::DelayQueue; /// use std::time::Duration; /// /// # #[tokio::main] /// # async fn main() { /// let mut delay_queue = DelayQueue::new(); /// assert!(delay_queue.is_empty()); /// /// delay_queue.insert("hello", Duration::from_secs(5)); /// assert!(!delay_queue.is_empty()); /// # } /// ``` pub fn is_empty(&self) -> bool { self.slab.is_empty() } /// Polls the queue, returning the index of the next slot in the slab that /// should be returned. /// /// A slot should be returned when the associated deadline has been reached. fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll>> { use self::wheel::Stack; let expired = self.expired.pop(&mut self.slab); if expired.is_some() { return Poll::Ready(expired.map(Ok)); } loop { if let Some(ref mut delay) = self.delay { if !delay.is_elapsed() { ready!(Pin::new(&mut *delay).poll(cx)); } let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down); self.wheel_now = now; } // We poll the wheel to get the next value out before finding the next deadline. let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab); self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when))); if let Some(idx) = wheel_idx { return Poll::Ready(Some(Ok(idx))); } if self.delay.is_none() { return Poll::Ready(None); } } } fn normalize_deadline(&self, when: Instant) -> u64 { let when = if when < self.start { 0 } else { crate::time::ms(when - self.start, crate::time::Round::Up) }; cmp::max(when, self.wheel.elapsed()) } } // We never put `T` in a `Pin`... impl Unpin for DelayQueue {} impl Default for DelayQueue { fn default() -> DelayQueue { DelayQueue::new() } } impl futures_core::Stream for DelayQueue { // DelayQueue seems much more specific, where a user may care that it // has reached capacity, so return those errors instead of panicking. type Item = Result, Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { DelayQueue::poll_expired(self.get_mut(), cx) } } impl wheel::Stack for Stack { type Owned = usize; type Borrowed = usize; type Store = Slab>; fn is_empty(&self) -> bool { self.head.is_none() } fn push(&mut self, item: Self::Owned, store: &mut Self::Store) { // Ensure the entry is not already in a stack. debug_assert!(store[item].next.is_none()); debug_assert!(store[item].prev.is_none()); // Remove the old head entry let old = self.head.take(); if let Some(idx) = old { store[idx].prev = Some(item); } store[item].next = old; self.head = Some(item) } fn pop(&mut self, store: &mut Self::Store) -> Option { if let Some(idx) = self.head { self.head = store[idx].next; if let Some(idx) = self.head { store[idx].prev = None; } store[idx].next = None; debug_assert!(store[idx].prev.is_none()); Some(idx) } else { None } } fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { assert!(store.contains(*item)); // Ensure that the entry is in fact contained by the stack debug_assert!({ // This walks the full linked list even if an entry is found. let mut next = self.head; let mut contains = false; while let Some(idx) = next { if idx == *item { debug_assert!(!contains); contains = true; } next = store[idx].next; } contains }); if let Some(next) = store[*item].next { store[next].prev = store[*item].prev; } if let Some(prev) = store[*item].prev { store[prev].next = store[*item].next; } else { self.head = store[*item].next; } store[*item].next = None; store[*item].prev = None; } fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 { store[*item].when } } impl Default for Stack { fn default() -> Stack { Stack { head: None, _p: PhantomData, } } } impl Key { pub(crate) fn new(index: usize) -> Key { Key { index } } } impl Expired { /// Returns a reference to the inner value. pub fn get_ref(&self) -> &T { &self.data } /// Returns a mutable reference to the inner value. pub fn get_mut(&mut self) -> &mut T { &mut self.data } /// Consumes `self` and returns the inner value. pub fn into_inner(self) -> T { self.data } /// Returns the deadline that the expiration was set to. pub fn deadline(&self) -> Instant { self.deadline } /// Returns the key that the expiration is indexed by. pub fn key(&self) -> Key { self.key } } tokio-util-0.6.9/src/time/mod.rs000064400000000000000000000022110072674642500146360ustar 00000000000000//! Additional utilities for tracking time. //! //! This module provides additional utilities for executing code after a set period //! of time. Currently there is only one: //! //! * `DelayQueue`: A queue where items are returned once the requested delay //! has expired. //! //! This type must be used from within the context of the `Runtime`. use std::time::Duration; mod wheel; pub mod delay_queue; #[doc(inline)] pub use delay_queue::DelayQueue; // ===== Internal utils ===== enum Round { Up, Down, } /// Convert a `Duration` to milliseconds, rounding up and saturating at /// `u64::MAX`. /// /// The saturating is fine because `u64::MAX` milliseconds are still many /// million years. #[inline] fn ms(duration: Duration, round: Round) -> u64 { const NANOS_PER_MILLI: u32 = 1_000_000; const MILLIS_PER_SEC: u64 = 1_000; // Round up. let millis = match round { Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI, Round::Down => duration.subsec_millis(), }; duration .as_secs() .saturating_mul(MILLIS_PER_SEC) .saturating_add(u64::from(millis)) } tokio-util-0.6.9/src/time/wheel/level.rs000064400000000000000000000151300072674642500162760ustar 00000000000000use crate::time::wheel::Stack; use std::fmt; /// Wheel for a single level in the timer. This wheel contains 64 slots. pub(crate) struct Level { level: usize, /// Bit field tracking which slots currently contain entries. /// /// Using a bit field to track slots that contain entries allows avoiding a /// scan to find entries. This field is updated when entries are added or /// removed from a slot. /// /// The least-significant bit represents slot zero. occupied: u64, /// Slots slot: [T; LEVEL_MULT], } /// Indicates when a slot must be processed next. #[derive(Debug)] pub(crate) struct Expiration { /// The level containing the slot. pub(crate) level: usize, /// The slot index. pub(crate) slot: usize, /// The instant at which the slot needs to be processed. pub(crate) deadline: u64, } /// Level multiplier. /// /// Being a power of 2 is very important. const LEVEL_MULT: usize = 64; impl Level { pub(crate) fn new(level: usize) -> Level { // Rust's derived implementations for arrays require that the value // contained by the array be `Copy`. So, here we have to manually // initialize every single slot. macro_rules! s { () => { T::default() }; } Level { level, occupied: 0, slot: [ // It does not look like the necessary traits are // derived for [T; 64]. s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), ], } } /// Finds the slot that needs to be processed next and returns the slot and /// `Instant` at which this slot must be processed. pub(crate) fn next_expiration(&self, now: u64) -> Option { // Use the `occupied` bit field to get the index of the next slot that // needs to be processed. let slot = match self.next_occupied_slot(now) { Some(slot) => slot, None => return None, }; // From the slot index, calculate the `Instant` at which it needs to be // processed. This value *must* be in the future with respect to `now`. let level_range = level_range(self.level); let slot_range = slot_range(self.level); // TODO: This can probably be simplified w/ power of 2 math let level_start = now - (now % level_range); let deadline = level_start + slot as u64 * slot_range; debug_assert!( deadline >= now, "deadline={}; now={}; level={}; slot={}; occupied={:b}", deadline, now, self.level, slot, self.occupied ); Some(Expiration { level: self.level, slot, deadline, }) } fn next_occupied_slot(&self, now: u64) -> Option { if self.occupied == 0 { return None; } // Get the slot for now using Maths let now_slot = (now / slot_range(self.level)) as usize; let occupied = self.occupied.rotate_right(now_slot as u32); let zeros = occupied.trailing_zeros() as usize; let slot = (zeros + now_slot) % 64; Some(slot) } pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) { let slot = slot_for(when, self.level); self.slot[slot].push(item, store); self.occupied |= occupied_bit(slot); } pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) { let slot = slot_for(when, self.level); self.slot[slot].remove(item, store); if self.slot[slot].is_empty() { // The bit is currently set debug_assert!(self.occupied & occupied_bit(slot) != 0); // Unset the bit self.occupied ^= occupied_bit(slot); } } pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option { let ret = self.slot[slot].pop(store); if ret.is_some() && self.slot[slot].is_empty() { // The bit is currently set debug_assert!(self.occupied & occupied_bit(slot) != 0); self.occupied ^= occupied_bit(slot); } ret } } impl fmt::Debug for Level { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Level") .field("occupied", &self.occupied) .finish() } } fn occupied_bit(slot: usize) -> u64 { 1 << slot } fn slot_range(level: usize) -> u64 { LEVEL_MULT.pow(level as u32) as u64 } fn level_range(level: usize) -> u64 { LEVEL_MULT as u64 * slot_range(level) } /// Convert a duration (milliseconds) and a level to a slot position fn slot_for(duration: u64, level: usize) -> usize { ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize } #[cfg(all(test, not(loom)))] mod test { use super::*; #[test] fn test_slot_for() { for pos in 0..64 { assert_eq!(pos as usize, slot_for(pos, 0)); } for level in 1..5 { for pos in level..64 { let a = pos * 64_usize.pow(level as u32); assert_eq!(pos as usize, slot_for(a as u64, level)); } } } } tokio-util-0.6.9/src/time/wheel/mod.rs000064400000000000000000000220610072674642500157470ustar 00000000000000mod level; pub(crate) use self::level::Expiration; use self::level::Level; mod stack; pub(crate) use self::stack::Stack; use std::borrow::Borrow; use std::usize; /// Timing wheel implementation. /// /// This type provides the hashed timing wheel implementation that backs `Timer` /// and `DelayQueue`. /// /// The structure is generic over `T: Stack`. This allows handling timeout data /// being stored on the heap or in a slab. In order to support the latter case, /// the slab must be passed into each function allowing the implementation to /// lookup timer entries. /// /// See `Timer` documentation for some implementation notes. #[derive(Debug)] pub(crate) struct Wheel { /// The number of milliseconds elapsed since the wheel started. elapsed: u64, /// Timer wheel. /// /// Levels: /// /// * 1 ms slots / 64 ms range /// * 64 ms slots / ~ 4 sec range /// * ~ 4 sec slots / ~ 4 min range /// * ~ 4 min slots / ~ 4 hr range /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range levels: Vec>, } /// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots /// each, the timer is able to track time up to 2 years into the future with a /// precision of 1 millisecond. const NUM_LEVELS: usize = 6; /// The maximum duration of a delay const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; #[derive(Debug)] pub(crate) enum InsertError { Elapsed, Invalid, } impl Wheel where T: Stack, { /// Create a new timing wheel pub(crate) fn new() -> Wheel { let levels = (0..NUM_LEVELS).map(Level::new).collect(); Wheel { elapsed: 0, levels } } /// Return the number of milliseconds that have elapsed since the timing /// wheel's creation. pub(crate) fn elapsed(&self) -> u64 { self.elapsed } /// Insert an entry into the timing wheel. /// /// # Arguments /// /// * `when`: is the instant at which the entry should be fired. It is /// represented as the number of milliseconds since the creation /// of the timing wheel. /// /// * `item`: The item to insert into the wheel. /// /// * `store`: The slab or `()` when using heap storage. /// /// # Return /// /// Returns `Ok` when the item is successfully inserted, `Err` otherwise. /// /// `Err(Elapsed)` indicates that `when` represents an instant that has /// already passed. In this case, the caller should fire the timeout /// immediately. /// /// `Err(Invalid)` indicates an invalid `when` argument as been supplied. pub(crate) fn insert( &mut self, when: u64, item: T::Owned, store: &mut T::Store, ) -> Result<(), (T::Owned, InsertError)> { if when <= self.elapsed { return Err((item, InsertError::Elapsed)); } else if when - self.elapsed > MAX_DURATION { return Err((item, InsertError::Invalid)); } // Get the level at which the entry should be stored let level = self.level_for(when); self.levels[level].add_entry(when, item, store); debug_assert!({ self.levels[level] .next_expiration(self.elapsed) .map(|e| e.deadline >= self.elapsed) .unwrap_or(true) }); Ok(()) } /// Remove `item` from the timing wheel. pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) { let when = T::when(item, store); assert!( self.elapsed <= when, "elapsed={}; when={}", self.elapsed, when ); let level = self.level_for(when); self.levels[level].remove_entry(when, item, store); } /// Instant at which to poll pub(crate) fn poll_at(&self) -> Option { self.next_expiration().map(|expiration| expiration.deadline) } /// Advances the timer up to the instant represented by `now`. pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option { loop { let expiration = self.next_expiration().and_then(|expiration| { if expiration.deadline > now { None } else { Some(expiration) } }); match expiration { Some(ref expiration) => { if let Some(item) = self.poll_expiration(expiration, store) { return Some(item); } self.set_elapsed(expiration.deadline); } None => { // in this case the poll did not indicate an expiration // _and_ we were not able to find a next expiration in // the current list of timers. advance to the poll's // current time and do nothing else. self.set_elapsed(now); return None; } } } } /// Returns the instant at which the next timeout expires. fn next_expiration(&self) -> Option { // Check all levels for level in 0..NUM_LEVELS { if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) { // There cannot be any expirations at a higher level that happen // before this one. debug_assert!(self.no_expirations_before(level + 1, expiration.deadline)); return Some(expiration); } } None } /// Used for debug assertions fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { let mut res = true; for l2 in start_level..NUM_LEVELS { if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) { if e2.deadline < before { res = false; } } } res } /// iteratively find entries that are between the wheel's current /// time and the expiration time. for each in that population either /// return it for notification (in the case of the last level) or tier /// it down to the next level (in all other cases). pub(crate) fn poll_expiration( &mut self, expiration: &Expiration, store: &mut T::Store, ) -> Option { while let Some(item) = self.pop_entry(expiration, store) { if expiration.level == 0 { debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline); return Some(item); } else { let when = T::when(item.borrow(), store); let next_level = expiration.level - 1; self.levels[next_level].add_entry(when, item, store); } } None } fn set_elapsed(&mut self, when: u64) { assert!( self.elapsed <= when, "elapsed={:?}; when={:?}", self.elapsed, when ); if when > self.elapsed { self.elapsed = when; } } fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option { self.levels[expiration.level].pop_entry_slot(expiration.slot, store) } fn level_for(&self, when: u64) -> usize { level_for(self.elapsed, when) } } fn level_for(elapsed: u64, when: u64) -> usize { const SLOT_MASK: u64 = (1 << 6) - 1; // Mask in the trailing bits ignored by the level calculation in order to cap // the possible leading zeros let masked = elapsed ^ when | SLOT_MASK; let leading_zeros = masked.leading_zeros() as usize; let significant = 63 - leading_zeros; significant / 6 } #[cfg(all(test, not(loom)))] mod test { use super::*; #[test] fn test_level_for() { for pos in 0..64 { assert_eq!( 0, level_for(0, pos), "level_for({}) -- binary = {:b}", pos, pos ); } for level in 1..5 { for pos in level..64 { let a = pos * 64_usize.pow(level as u32); assert_eq!( level, level_for(0, a as u64), "level_for({}) -- binary = {:b}", a, a ); if pos > level { let a = a - 1; assert_eq!( level, level_for(0, a as u64), "level_for({}) -- binary = {:b}", a, a ); } if pos < 64 { let a = a + 1; assert_eq!( level, level_for(0, a as u64), "level_for({}) -- binary = {:b}", a, a ); } } } } } tokio-util-0.6.9/src/time/wheel/stack.rs000064400000000000000000000014060072674642500162750ustar 00000000000000use std::borrow::Borrow; /// Abstracts the stack operations needed to track timeouts. pub(crate) trait Stack: Default { /// Type of the item stored in the stack type Owned: Borrow; /// Borrowed item type Borrowed; /// Item storage, this allows a slab to be used instead of just the heap type Store; /// Returns `true` if the stack is empty fn is_empty(&self) -> bool; /// Push an item onto the stack fn push(&mut self, item: Self::Owned, store: &mut Self::Store); /// Pop an item from the stack fn pop(&mut self, store: &mut Self::Store) -> Option; fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store); fn when(item: &Self::Borrowed, store: &Self::Store) -> u64; } tokio-util-0.6.9/src/udp/frame.rs000064400000000000000000000170620072674642500150150ustar 00000000000000use crate::codec::{Decoder, Encoder}; use futures_core::Stream; use tokio::{io::ReadBuf, net::UdpSocket}; use bytes::{BufMut, BytesMut}; use futures_core::ready; use futures_sink::Sink; use std::pin::Pin; use std::task::{Context, Poll}; use std::{ borrow::Borrow, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, }; use std::{io, mem::MaybeUninit}; /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using /// the `Encoder` and `Decoder` traits to encode and decode frames. /// /// Raw UDP sockets work with datagrams, but higher-level code usually wants to /// batch these into meaningful chunks, called "frames". This method layers /// framing on top of this socket by using the `Encoder` and `Decoder` traits to /// handle encoding and decoding of messages frames. Note that the incoming and /// outgoing frame types may be distinct. /// /// This function returns a *single* object that is both [`Stream`] and [`Sink`]; /// grouping this into a single object is often useful for layering things which /// require both read and write access to the underlying object. /// /// If you want to work more directly with the streams and sink, consider /// calling [`split`] on the `UdpFramed` returned by this method, which will break /// them into separate objects, allowing them to interact more easily. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split #[must_use = "sinks do nothing unless polled"] #[derive(Debug)] pub struct UdpFramed { socket: T, codec: C, rd: BytesMut, wr: BytesMut, out_addr: SocketAddr, flushed: bool, is_readable: bool, current_addr: Option, } const INITIAL_RD_CAPACITY: usize = 64 * 1024; const INITIAL_WR_CAPACITY: usize = 8 * 1024; impl Unpin for UdpFramed {} impl Stream for UdpFramed where T: Borrow, C: Decoder, { type Item = Result<(C::Item, SocketAddr), C::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let pin = self.get_mut(); pin.rd.reserve(INITIAL_RD_CAPACITY); loop { // Are there still bytes left in the read buffer to decode? if pin.is_readable { if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { let current_addr = pin .current_addr .expect("will always be set before this line is called"); return Poll::Ready(Some(Ok((frame, current_addr)))); } // if this line has been reached then decode has returned `None`. pin.is_readable = false; pin.rd.clear(); } // We're out of data. Try and fetch more data to decode let addr = unsafe { // Convert `&mut [MaybeUnit]` to `&mut [u8]` because we will be // writing to it via `poll_recv_from` and therefore initializing the memory. let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit]); let mut read = ReadBuf::uninit(buf); let ptr = read.filled().as_ptr(); let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); assert_eq!(ptr, read.filled().as_ptr()); let addr = res?; pin.rd.advance_mut(read.filled().len()); addr }; pin.current_addr = Some(addr); pin.is_readable = true; } } } impl Sink<(I, SocketAddr)> for UdpFramed where T: Borrow, C: Encoder, { type Error = C::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.flushed { match self.poll_flush(cx)? { Poll::Ready(()) => {} Poll::Pending => return Poll::Pending, } } Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { let (frame, out_addr) = item; let pin = self.get_mut(); pin.codec.encode(frame, &mut pin.wr)?; pin.out_addr = out_addr; pin.flushed = false; Ok(()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.flushed { return Poll::Ready(Ok(())); } let Self { ref socket, ref mut out_addr, ref mut wr, .. } = *self; let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; let wrote_all = n == self.wr.len(); self.wr.clear(); self.flushed = true; let res = if wrote_all { Ok(()) } else { Err(io::Error::new( io::ErrorKind::Other, "failed to write entire datagram to socket", ) .into()) }; Poll::Ready(res) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.poll_flush(cx))?; Poll::Ready(Ok(())) } } impl UdpFramed where T: Borrow, { /// Create a new `UdpFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. pub fn new(socket: T, codec: C) -> UdpFramed { Self { socket, codec, out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), flushed: true, is_readable: false, current_addr: None, } } /// Returns a reference to the underlying I/O stream wrapped by `Framed`. /// /// # Note /// /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. pub fn get_ref(&self) -> &T { &self.socket } /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. /// /// # Note /// /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. pub fn get_mut(&mut self) -> &mut T { &mut self.socket } /// Returns a reference to the underlying codec wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec(&self) -> &C { &self.codec } /// Returns a mutable reference to the underlying codec wrapped by /// `UdpFramed`. /// /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec_mut(&mut self) -> &mut C { &mut self.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { &self.rd } /// Returns a mutable reference to the read buffer. pub fn read_buffer_mut(&mut self) -> &mut BytesMut { &mut self.rd } /// Consumes the `Framed`, returning its underlying I/O stream. pub fn into_inner(self) -> T { self.socket } } tokio-util-0.6.9/src/udp/mod.rs000064400000000000000000000000660072674642500144760ustar 00000000000000//! UDP framing mod frame; pub use frame::UdpFramed; tokio-util-0.6.9/tests/_require_full.rs000064400000000000000000000001360072674642500163350ustar 00000000000000#![cfg(not(feature = "full"))] compile_error!("run tokio-util tests with `--features full`"); tokio-util-0.6.9/tests/codecs.rs000064400000000000000000000322250072674642500147440ustar 00000000000000#![warn(rust_2018_idioms)] use tokio_util::codec::{AnyDelimiterCodec, BytesCodec, Decoder, Encoder, LinesCodec}; use bytes::{BufMut, Bytes, BytesMut}; #[test] fn bytes_decoder() { let mut codec = BytesCodec::new(); let buf = &mut BytesMut::new(); buf.put_slice(b"abc"); assert_eq!("abc", codec.decode(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"a"); assert_eq!("a", codec.decode(buf).unwrap().unwrap()); } #[test] fn bytes_encoder() { let mut codec = BytesCodec::new(); // Default capacity of BytesMut #[cfg(target_pointer_width = "64")] const INLINE_CAP: usize = 4 * 8 - 1; #[cfg(target_pointer_width = "32")] const INLINE_CAP: usize = 4 * 4 - 1; let mut buf = BytesMut::new(); codec .encode(Bytes::from_static(&[0; INLINE_CAP + 1]), &mut buf) .unwrap(); // Default capacity of Framed Read const INITIAL_CAPACITY: usize = 8 * 1024; let mut buf = BytesMut::with_capacity(INITIAL_CAPACITY); codec .encode(Bytes::from_static(&[0; INITIAL_CAPACITY + 1]), &mut buf) .unwrap(); } #[test] fn lines_decoder() { let mut codec = LinesCodec::new(); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line 1\nline 2\r\nline 3\n\r\n\r"); assert_eq!("line 1", codec.decode(buf).unwrap().unwrap()); assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); assert_eq!("line 3", codec.decode(buf).unwrap().unwrap()); assert_eq!("", codec.decode(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); } #[test] fn lines_decoder_max_length() { const MAX_LENGTH: usize = 6; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r"); assert!(codec.decode(buf).is_err()); let line = codec.decode(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("line 2", line); assert!(codec.decode(buf).is_err()); let line = codec.decode(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("line 4", line); let line = codec.decode(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("", line); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); let line = codec.decode_eof(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("\rk", line); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); // Line that's one character too long. This could cause an out of bounds // error if we peek at the next characters using slice indexing. buf.put_slice(b"aaabbbc"); assert!(codec.decode(buf).is_err()); } #[test] fn lines_decoder_max_length_underrun() { const MAX_LENGTH: usize = 6; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"ong\n"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"line 2"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"\n"); assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); } #[test] fn lines_decoder_max_length_bursts() { const MAX_LENGTH: usize = 10; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\n"); assert!(codec.decode(buf).is_err()); } #[test] fn lines_decoder_max_length_big_burst() { const MAX_LENGTH: usize = 10; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too long!\n"); assert!(codec.decode(buf).is_err()); } #[test] fn lines_decoder_max_length_newline_between_decodes() { const MAX_LENGTH: usize = 5; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"hello"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"\nworld"); assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); } // Regression test for [infinite loop bug](https://github.com/tokio-rs/tokio/issues/1483) #[test] fn lines_decoder_discard_repeat() { const MAX_LENGTH: usize = 1; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"aa"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"a"); assert_eq!(None, codec.decode(buf).unwrap()); } // Regression test for [subsequent calls to LinesCodec decode does not return the desired results bug](https://github.com/tokio-rs/tokio/issues/3555) #[test] fn lines_decoder_max_length_underrun_twice() { const MAX_LENGTH: usize = 11; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too very l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\nshort\n"); assert_eq!("short", codec.decode(buf).unwrap().unwrap()); } #[test] fn lines_encoder() { let mut codec = LinesCodec::new(); let mut buf = BytesMut::new(); codec.encode("line 1", &mut buf).unwrap(); assert_eq!("line 1\n", buf); codec.encode("line 2", &mut buf).unwrap(); assert_eq!("line 1\nline 2\n", buf); } #[test] fn any_delimiters_decoder_any_character() { let mut codec = AnyDelimiterCodec::new(b",;\n\r".to_vec(), b",".to_vec()); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); assert_eq!("", codec.decode(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!("k", codec.decode_eof(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); } #[test] fn any_delimiters_decoder_max_length() { const MAX_LENGTH: usize = 7; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk 1 is too long\nchunk 2\nchunk 3\r\nchunk 4\n\r\n"); assert!(codec.decode(buf).is_err()); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("chunk 2", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("chunk 3", chunk); // \r\n cause empty chunk let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("chunk 4", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("", chunk); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); let chunk = codec.decode_eof(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("k", chunk); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); // Delimiter that's one character too long. This could cause an out of bounds // error if we peek at the next characters using slice indexing. buf.put_slice(b"aaabbbcc"); assert!(codec.decode(buf).is_err()); } #[test] fn any_delimiter_decoder_max_length_underrun() { const MAX_LENGTH: usize = 7; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"ong\n"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"chunk 2"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b","); assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); } #[test] fn any_delimiter_decoder_max_length_underrun_twice() { const MAX_LENGTH: usize = 11; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too very l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\nshort\n"); assert_eq!("short", codec.decode(buf).unwrap().unwrap()); } #[test] fn any_delimiter_decoder_max_length_bursts() { const MAX_LENGTH: usize = 11; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\n"); assert!(codec.decode(buf).is_err()); } #[test] fn any_delimiter_decoder_max_length_big_burst() { const MAX_LENGTH: usize = 11; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too long!\n"); assert!(codec.decode(buf).is_err()); } #[test] fn any_delimiter_decoder_max_length_delimiter_between_decodes() { const MAX_LENGTH: usize = 5; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"hello"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b",world"); assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); } #[test] fn any_delimiter_decoder_discard_repeat() { const MAX_LENGTH: usize = 1; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"aa"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"a"); assert_eq!(None, codec.decode(buf).unwrap()); } #[test] fn any_delimiter_encoder() { let mut codec = AnyDelimiterCodec::new(b",".to_vec(), b";--;".to_vec()); let mut buf = BytesMut::new(); codec.encode("chunk 1", &mut buf).unwrap(); assert_eq!("chunk 1;--;", buf); codec.encode("chunk 2", &mut buf).unwrap(); assert_eq!("chunk 1;--;chunk 2;--;", buf); } tokio-util-0.6.9/tests/context.rs000064400000000000000000000012500072674642500151620ustar 00000000000000#![cfg(feature = "rt")] #![warn(rust_2018_idioms)] use tokio::runtime::Builder; use tokio::time::*; use tokio_util::context::RuntimeExt; #[test] fn tokio_context_with_another_runtime() { let rt1 = Builder::new_multi_thread() .worker_threads(1) // no timer! .build() .unwrap(); let rt2 = Builder::new_multi_thread() .worker_threads(1) .enable_all() .build() .unwrap(); // Without the `HandleExt.wrap()` there would be a panic because there is // no timer running, since it would be referencing runtime r1. let _ = rt1.block_on(rt2.wrap(async move { sleep(Duration::from_millis(2)).await })); } tokio-util-0.6.9/tests/framed.rs000064400000000000000000000045530072674642500147450ustar 00000000000000#![warn(rust_2018_idioms)] use tokio_stream::StreamExt; use tokio_test::assert_ok; use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; use bytes::{Buf, BufMut, BytesMut}; use std::io::{self, Read}; use std::pin::Pin; use std::task::{Context, Poll}; const INITIAL_CAPACITY: usize = 8 * 1024; /// Encode and decode u32 values. struct U32Codec; impl Decoder for U32Codec { type Item = u32; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { if buf.len() < 4 { return Ok(None); } let n = buf.split_to(4).get_u32(); Ok(Some(n)) } } impl Encoder for U32Codec { type Error = io::Error; fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { // Reserve space dst.reserve(4); dst.put_u32(item); Ok(()) } } /// This value should never be used struct DontReadIntoThis; impl Read for DontReadIntoThis { fn read(&mut self, _: &mut [u8]) -> io::Result { Err(io::Error::new( io::ErrorKind::Other, "Read into something you weren't supposed to.", )) } } impl tokio::io::AsyncRead for DontReadIntoThis { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { unreachable!() } } #[tokio::test] async fn can_read_from_existing_buf() { let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); let mut framed = Framed::from_parts(parts); let num = assert_ok!(framed.next().await.unwrap()); assert_eq!(num, 42); } #[test] fn external_buf_grows_to_init() { let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); let framed = Framed::from_parts(parts); let FramedParts { read_buf, .. } = framed.into_parts(); assert_eq!(read_buf.capacity(), INITIAL_CAPACITY); } #[test] fn external_buf_does_not_shrink() { let mut parts = FramedParts::new(DontReadIntoThis, U32Codec); parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); let framed = Framed::from_parts(parts); let FramedParts { read_buf, .. } = framed.into_parts(); assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); } tokio-util-0.6.9/tests/framed_read.rs000064400000000000000000000201130072674642500157260ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::io::{AsyncRead, ReadBuf}; use tokio_test::assert_ready; use tokio_test::task; use tokio_util::codec::{Decoder, FramedRead}; use bytes::{Buf, BytesMut}; use futures::Stream; use std::collections::VecDeque; use std::io; use std::pin::Pin; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; macro_rules! mock { ($($x:expr,)*) => {{ let mut v = VecDeque::new(); v.extend(vec![$($x),*]); Mock { calls: v } }}; } macro_rules! assert_read { ($e:expr, $n:expr) => {{ let val = assert_ready!($e); assert_eq!(val.unwrap().unwrap(), $n); }}; } macro_rules! pin { ($id:ident) => { Pin::new(&mut $id) }; } struct U32Decoder; impl Decoder for U32Decoder { type Item = u32; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { if buf.len() < 4 { return Ok(None); } let n = buf.split_to(4).get_u32(); Ok(Some(n)) } } #[test] fn read_multi_frame_in_packet() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert_read!(pin!(framed).poll_next(cx), 2); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_multi_frame_across_packets() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00\x00\x00".to_vec()), Ok(b"\x00\x00\x00\x01".to_vec()), Ok(b"\x00\x00\x00\x02".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert_read!(pin!(framed).poll_next(cx), 2); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_not_ready() { let mut task = task::spawn(()); let mock = mock! { Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Ok(b"\x00\x00\x00\x00".to_vec()), Ok(b"\x00\x00\x00\x01".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert!(pin!(framed).poll_next(cx).is_pending()); assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_partial_then_not_ready() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert!(pin!(framed).poll_next(cx).is_pending()); assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert_read!(pin!(framed).poll_next(cx), 2); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_err() { let mut task = task::spawn(()); let mock = mock! { Err(io::Error::new(io::ErrorKind::Other, "")), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_eq!( io::ErrorKind::Other, assert_ready!(pin!(framed).poll_next(cx)) .unwrap() .unwrap_err() .kind() ) }); } #[test] fn read_partial_then_err() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::Other, "")), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_eq!( io::ErrorKind::Other, assert_ready!(pin!(framed).poll_next(cx)) .unwrap() .unwrap_err() .kind() ) }); } #[test] fn read_partial_would_block_then_err() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Err(io::Error::new(io::ErrorKind::Other, "")), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert!(pin!(framed).poll_next(cx).is_pending()); assert_eq!( io::ErrorKind::Other, assert_ready!(pin!(framed).poll_next(cx)) .unwrap() .unwrap_err() .kind() ) }); } #[test] fn huge_size() { let mut task = task::spawn(()); let data = &[0; 32 * 1024][..]; let mut framed = FramedRead::new(data, BigDecoder); task.enter(|cx, _| { assert_read!(pin!(framed).poll_next(cx), 0); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); struct BigDecoder; impl Decoder for BigDecoder { type Item = u32; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { if buf.len() < 32 * 1024 { return Ok(None); } buf.advance(32 * 1024); Ok(Some(0)) } } } #[test] fn data_remaining_is_error() { let mut task = task::spawn(()); let slice = &[0; 5][..]; let mut framed = FramedRead::new(slice, U32Decoder); task.enter(|cx, _| { assert_read!(pin!(framed).poll_next(cx), 0); assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err()); }); } #[test] fn multi_frames_on_eof() { let mut task = task::spawn(()); struct MyDecoder(Vec); impl Decoder for MyDecoder { type Item = u32; type Error = io::Error; fn decode(&mut self, _buf: &mut BytesMut) -> io::Result> { unreachable!(); } fn decode_eof(&mut self, _buf: &mut BytesMut) -> io::Result> { if self.0.is_empty() { return Ok(None); } Ok(Some(self.0.remove(0))) } } let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3])); task.enter(|cx, _| { assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert_read!(pin!(framed).poll_next(cx), 2); assert_read!(pin!(framed).poll_next(cx), 3); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_eof_then_resume() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00\x00\x01".to_vec()), Ok(b"".to_vec()), Ok(b"\x00\x00\x00\x02".to_vec()), Ok(b"".to_vec()), Ok(b"\x00\x00\x00\x03".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_read!(pin!(framed).poll_next(cx), 1); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); assert_read!(pin!(framed).poll_next(cx), 2); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); assert_read!(pin!(framed).poll_next(cx), 3); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } // ===== Mock ====== struct Mock { calls: VecDeque>>, } impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { use io::ErrorKind::WouldBlock; match self.calls.pop_front() { Some(Ok(data)) => { debug_assert!(buf.remaining() >= data.len()); buf.put_slice(&data); Ready(Ok(())) } Some(Err(ref e)) if e.kind() == WouldBlock => Pending, Some(Err(e)) => Ready(Err(e)), None => Ready(Ok(())), } } } tokio-util-0.6.9/tests/framed_stream.rs000064400000000000000000000020270072674642500163120ustar 00000000000000use futures_core::stream::Stream; use std::{io, pin::Pin}; use tokio_test::{assert_ready, io::Builder, task}; use tokio_util::codec::{BytesCodec, FramedRead}; macro_rules! pin { ($id:ident) => { Pin::new(&mut $id) }; } macro_rules! assert_read { ($e:expr, $n:expr) => {{ let val = assert_ready!($e); assert_eq!(val.unwrap().unwrap(), $n); }}; } #[tokio::test] async fn return_none_after_error() { let mut io = FramedRead::new( Builder::new() .read(b"abcdef") .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out")) .read(b"more data") .build(), BytesCodec::new(), ); let mut task = task::spawn(()); task.enter(|cx, _| { assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec()); assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err()); assert!(assert_ready!(pin!(io).poll_next(cx)).is_none()); assert_read!(pin!(io).poll_next(cx), b"more data".to_vec()); }) } tokio-util-0.6.9/tests/framed_write.rs000064400000000000000000000116140072674642500161530ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::io::AsyncWrite; use tokio_test::{assert_ready, task}; use tokio_util::codec::{Encoder, FramedWrite}; use bytes::{BufMut, BytesMut}; use futures_sink::Sink; use std::collections::VecDeque; use std::io::{self, Write}; use std::pin::Pin; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; macro_rules! mock { ($($x:expr,)*) => {{ let mut v = VecDeque::new(); v.extend(vec![$($x),*]); Mock { calls: v } }}; } macro_rules! pin { ($id:ident) => { Pin::new(&mut $id) }; } struct U32Encoder; impl Encoder for U32Encoder { type Error = io::Error; fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { // Reserve space dst.reserve(4); dst.put_u32(item); Ok(()) } } #[test] fn write_multi_frame_in_packet() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; let mut framed = FramedWrite::new(mock, U32Encoder); task.enter(|cx, _| { assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(0).is_ok()); assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(1).is_ok()); assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(2).is_ok()); // Nothing written yet assert_eq!(1, framed.get_ref().calls.len()); // Flush the writes assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); assert_eq!(0, framed.get_ref().calls.len()); }); } #[test] fn write_hits_backpressure() { const ITER: usize = 2 * 1024; let mut mock = mock! { // Block the `ITER`th write Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), Ok(b"".to_vec()), }; for i in 0..=ITER { let mut b = BytesMut::with_capacity(4); b.put_u32(i as u32); // Append to the end match mock.calls.back_mut().unwrap() { Ok(ref mut data) => { // Write in 2kb chunks if data.len() < ITER { data.extend_from_slice(&b[..]); continue; } // else fall through and create a new buffer } _ => unreachable!(), } // Push a new new chunk mock.calls.push_back(Ok(b[..].to_vec())); } // 1 'wouldblock', 4 * 2KB buffers, 1 b-byte buffer assert_eq!(mock.calls.len(), 6); let mut task = task::spawn(()); let mut framed = FramedWrite::new(mock, U32Encoder); task.enter(|cx, _| { // Send 8KB. This fills up FramedWrite2 buffer for i in 0..ITER { assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(i as u32).is_ok()); } // Now we poll_ready which forces a flush. The mock pops the front message // and decides to block. assert!(pin!(framed).poll_ready(cx).is_pending()); // We poll again, forcing another flush, which this time succeeds // The whole 8KB buffer is flushed assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); // Send more data. This matches the final message expected by the mock assert!(pin!(framed).start_send(ITER as u32).is_ok()); // Flush the rest of the buffer assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); // Ensure the mock is empty assert_eq!(0, framed.get_ref().calls.len()); }) } // // ===== Mock ====== struct Mock { calls: VecDeque>>, } impl Write for Mock { fn write(&mut self, src: &[u8]) -> io::Result { match self.calls.pop_front() { Some(Ok(data)) => { assert!(src.len() >= data.len()); assert_eq!(&data[..], &src[..data.len()]); Ok(data.len()) } Some(Err(e)) => Err(e), None => panic!("unexpected write; {:?}", src), } } fn flush(&mut self) -> io::Result<()> { Ok(()) } } impl AsyncWrite for Mock { fn poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match Pin::get_mut(self).write(buf) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, other => Ready(other), } } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { match Pin::get_mut(self).flush() { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, other => Ready(other), } } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { unimplemented!() } } tokio-util-0.6.9/tests/io_reader_stream.rs000064400000000000000000000034040072674642500170050ustar 00000000000000#![warn(rust_2018_idioms)] use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, ReadBuf}; use tokio_stream::StreamExt; /// produces at most `remaining` zeros, that returns error. /// each time it reads at most 31 byte. struct Reader { remaining: usize, } impl AsyncRead for Reader { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let this = Pin::into_inner(self); assert_ne!(buf.remaining(), 0); if this.remaining > 0 { let n = std::cmp::min(this.remaining, buf.remaining()); let n = std::cmp::min(n, 31); for x in &mut buf.initialize_unfilled_to(n)[..n] { *x = 0; } buf.advance(n); this.remaining -= n; Poll::Ready(Ok(())) } else { Poll::Ready(Err(std::io::Error::from_raw_os_error(22))) } } } #[tokio::test] async fn correct_behavior_on_errors() { let reader = Reader { remaining: 8000 }; let mut stream = tokio_util::io::ReaderStream::new(reader); let mut zeros_received = 0; let mut had_error = false; loop { let item = stream.next().await.unwrap(); println!("{:?}", item); match item { Ok(bytes) => { let bytes = &*bytes; for byte in bytes { assert_eq!(*byte, 0); zeros_received += 1; } } Err(_) => { assert!(!had_error); had_error = true; break; } } } assert!(had_error); assert_eq!(zeros_received, 8000); assert!(stream.next().await.is_none()); } tokio-util-0.6.9/tests/io_stream_reader.rs000064400000000000000000000016510072674642500170070ustar 00000000000000#![warn(rust_2018_idioms)] use bytes::Bytes; use tokio::io::AsyncReadExt; use tokio_stream::iter; use tokio_util::io::StreamReader; #[tokio::test] async fn test_stream_reader() -> std::io::Result<()> { let stream = iter(vec![ std::io::Result::Ok(Bytes::from_static(&[])), Ok(Bytes::from_static(&[0, 1, 2, 3])), Ok(Bytes::from_static(&[])), Ok(Bytes::from_static(&[4, 5, 6, 7])), Ok(Bytes::from_static(&[])), Ok(Bytes::from_static(&[8, 9, 10, 11])), Ok(Bytes::from_static(&[])), ]); let mut read = StreamReader::new(stream); let mut buf = [0; 5]; read.read_exact(&mut buf).await?; assert_eq!(buf, [0, 1, 2, 3, 4]); assert_eq!(read.read(&mut buf).await?, 3); assert_eq!(&buf[..3], [5, 6, 7]); assert_eq!(read.read(&mut buf).await?, 4); assert_eq!(&buf[..4], [8, 9, 10, 11]); assert_eq!(read.read(&mut buf).await?, 0); Ok(()) } tokio-util-0.6.9/tests/io_sync_bridge.rs000064400000000000000000000023270072674642500164630ustar 00000000000000#![cfg(feature = "io-util")] use std::error::Error; use std::io::{Cursor, Read, Result as IoResult}; use tokio::io::AsyncRead; use tokio_util::io::SyncIoBridge; async fn test_reader_len( r: impl AsyncRead + Unpin + Send + 'static, expected_len: usize, ) -> IoResult<()> { let mut r = SyncIoBridge::new(r); let res = tokio::task::spawn_blocking(move || { let mut buf = Vec::new(); r.read_to_end(&mut buf)?; Ok::<_, std::io::Error>(buf) }) .await?; assert_eq!(res?.len(), expected_len); Ok(()) } #[tokio::test] async fn test_async_read_to_sync() -> Result<(), Box> { test_reader_len(tokio::io::empty(), 0).await?; let buf = b"hello world"; test_reader_len(Cursor::new(buf), buf.len()).await?; Ok(()) } #[tokio::test] async fn test_async_write_to_sync() -> Result<(), Box> { let mut dest = Vec::new(); let src = b"hello world"; let dest = tokio::task::spawn_blocking(move || -> Result<_, String> { let mut w = SyncIoBridge::new(Cursor::new(&mut dest)); std::io::copy(&mut Cursor::new(src), &mut w).map_err(|e| e.to_string())?; Ok(dest) }) .await??; assert_eq!(dest.as_slice(), src); Ok(()) } tokio-util-0.6.9/tests/length_delimited.rs000064400000000000000000000465460072674642500170200ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_test::task; use tokio_test::{ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, }; use tokio_util::codec::*; use bytes::{BufMut, Bytes, BytesMut}; use futures::{pin_mut, Sink, Stream}; use std::collections::VecDeque; use std::io; use std::pin::Pin; use std::task::Poll::*; use std::task::{Context, Poll}; macro_rules! mock { ($($x:expr,)*) => {{ let mut v = VecDeque::new(); v.extend(vec![$($x),*]); Mock { calls: v } }}; } macro_rules! assert_next_eq { ($io:ident, $expect:expr) => {{ task::spawn(()).enter(|cx, _| { let res = assert_ready!($io.as_mut().poll_next(cx)); match res { Some(Ok(v)) => assert_eq!(v, $expect.as_ref()), Some(Err(e)) => panic!("error = {:?}", e), None => panic!("none"), } }); }}; } macro_rules! assert_next_pending { ($io:ident) => {{ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { Ready(Some(Ok(v))) => panic!("value = {:?}", v), Ready(Some(Err(e))) => panic!("error = {:?}", e), Ready(None) => panic!("done"), Pending => {} }); }}; } macro_rules! assert_next_err { ($io:ident) => {{ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { Ready(Some(Ok(v))) => panic!("value = {:?}", v), Ready(Some(Err(_))) => {} Ready(None) => panic!("done"), Pending => panic!("pending"), }); }}; } macro_rules! assert_done { ($io:ident) => {{ task::spawn(()).enter(|cx, _| { let res = assert_ready!($io.as_mut().poll_next(cx)); match res { Some(Ok(v)) => panic!("value = {:?}", v), Some(Err(e)) => panic!("error = {:?}", e), None => {} } }); }}; } #[test] fn read_empty_io_yields_nothing() { let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new())); pin_mut!(io); assert_done!(io); } #[test] fn read_single_frame_one_packet() { let io = FramedRead::new( mock! { data(b"\x00\x00\x00\x09abcdefghi"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_frame_one_packet_little_endian() { let io = length_delimited::Builder::new() .little_endian() .new_read(mock! { data(b"\x09\x00\x00\x00abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_frame_one_packet_native_endian() { let d = if cfg!(target_endian = "big") { b"\x00\x00\x00\x09abcdefghi" } else { b"\x09\x00\x00\x00abcdefghi" }; let io = length_delimited::Builder::new() .native_endian() .new_read(mock! { data(d), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_multi_frame_one_packet() { let mut d: Vec = vec![]; d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi"); d.extend_from_slice(b"\x00\x00\x00\x03123"); d.extend_from_slice(b"\x00\x00\x00\x0bhello world"); let io = FramedRead::new( mock! { data(&d), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_done!(io); } #[test] fn read_single_frame_multi_packet() { let io = FramedRead::new( mock! { data(b"\x00\x00"), data(b"\x00\x09abc"), data(b"defghi"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_multi_frame_multi_packet() { let io = FramedRead::new( mock! { data(b"\x00\x00"), data(b"\x00\x09abc"), data(b"defghi"), data(b"\x00\x00\x00\x0312"), data(b"3\x00\x00\x00\x0bhello world"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_done!(io); } #[test] fn read_single_frame_multi_packet_wait() { let io = FramedRead::new( mock! { data(b"\x00\x00"), Pending, data(b"\x00\x09abc"), Pending, data(b"defghi"), Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_pending!(io); assert_done!(io); } #[test] fn read_multi_frame_multi_packet_wait() { let io = FramedRead::new( mock! { data(b"\x00\x00"), Pending, data(b"\x00\x09abc"), Pending, data(b"defghi"), Pending, data(b"\x00\x00\x00\x0312"), Pending, data(b"3\x00\x00\x00\x0bhello world"), Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_pending!(io); assert_next_pending!(io); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_next_pending!(io); assert_done!(io); } #[test] fn read_incomplete_head() { let io = FramedRead::new( mock! { data(b"\x00\x00"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_err!(io); } #[test] fn read_incomplete_head_multi() { let io = FramedRead::new( mock! { Pending, data(b"\x00"), Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_err!(io); } #[test] fn read_incomplete_payload() { let io = FramedRead::new( mock! { data(b"\x00\x00\x00\x09ab"), Pending, data(b"cd"), Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_err!(io); } #[test] fn read_max_frame_len() { let io = length_delimited::Builder::new() .max_frame_length(5) .new_read(mock! { data(b"\x00\x00\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_err!(io); } #[test] fn read_update_max_frame_len_at_rest() { let io = length_delimited::Builder::new().new_read(mock! { data(b"\x00\x00\x00\x09abcdefghi"), data(b"\x00\x00\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); io.decoder_mut().set_max_frame_length(5); assert_next_err!(io); } #[test] fn read_update_max_frame_len_in_flight() { let io = length_delimited::Builder::new().new_read(mock! { data(b"\x00\x00\x00\x09abcd"), Pending, data(b"efghi"), data(b"\x00\x00\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_pending!(io); io.decoder_mut().set_max_frame_length(5); assert_next_eq!(io, b"abcdefghi"); assert_next_err!(io); } #[test] fn read_one_byte_length_field() { let io = length_delimited::Builder::new() .length_field_length(1) .new_read(mock! { data(b"\x09abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_header_offset() { let io = length_delimited::Builder::new() .length_field_length(2) .length_field_offset(4) .new_read(mock! { data(b"zzzz\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_multi_frame_one_packet_skip_none_adjusted() { let mut d: Vec = vec![]; d.extend_from_slice(b"xx\x00\x09abcdefghi"); d.extend_from_slice(b"yy\x00\x03123"); d.extend_from_slice(b"zz\x00\x0bhello world"); let io = length_delimited::Builder::new() .length_field_length(2) .length_field_offset(2) .num_skip(0) .length_adjustment(4) .new_read(mock! { data(&d), }); pin_mut!(io); assert_next_eq!(io, b"xx\x00\x09abcdefghi"); assert_next_eq!(io, b"yy\x00\x03123"); assert_next_eq!(io, b"zz\x00\x0bhello world"); assert_done!(io); } #[test] fn read_single_frame_length_adjusted() { let mut d: Vec = vec![]; d.extend_from_slice(b"\x00\x00\x0b\x0cHello world"); let io = length_delimited::Builder::new() .length_field_offset(0) .length_field_length(3) .length_adjustment(0) .num_skip(4) .new_read(mock! { data(&d), }); pin_mut!(io); assert_next_eq!(io, b"Hello world"); assert_done!(io); } #[test] fn read_single_multi_frame_one_packet_length_includes_head() { let mut d: Vec = vec![]; d.extend_from_slice(b"\x00\x0babcdefghi"); d.extend_from_slice(b"\x00\x05123"); d.extend_from_slice(b"\x00\x0dhello world"); let io = length_delimited::Builder::new() .length_field_length(2) .length_adjustment(-2) .new_read(mock! { data(&d), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_done!(io); } #[test] fn write_single_frame_length_adjusted() { let io = length_delimited::Builder::new() .length_adjustment(-2) .new_write(mock! { data(b"\x00\x00\x00\x0b"), data(b"abcdefghi"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_nothing_yields_nothing() { let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new()); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.poll_flush(cx)); }); } #[test] fn write_single_frame_one_packet() { let io = FramedWrite::new( mock! { data(b"\x00\x00\x00\x09"), data(b"abcdefghi"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_multi_frame_one_packet() { let io = FramedWrite::new( mock! { data(b"\x00\x00\x00\x09"), data(b"abcdefghi"), data(b"\x00\x00\x00\x03"), data(b"123"), data(b"\x00\x00\x00\x0b"), data(b"hello world"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("123"))); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_multi_frame_multi_packet() { let io = FramedWrite::new( mock! { data(b"\x00\x00\x00\x09"), data(b"abcdefghi"), flush(), data(b"\x00\x00\x00\x03"), data(b"123"), flush(), data(b"\x00\x00\x00\x0b"), data(b"hello world"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("123"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_frame_would_block() { let io = FramedWrite::new( mock! { Pending, data(b"\x00\x00"), Pending, data(b"\x00\x09"), data(b"abcdefghi"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_pending!(io.as_mut().poll_flush(cx)); assert_pending!(io.as_mut().poll_flush(cx)); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_frame_little_endian() { let io = length_delimited::Builder::new() .little_endian() .new_write(mock! { data(b"\x09\x00\x00\x00"), data(b"abcdefghi"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_frame_with_short_length_field() { let io = length_delimited::Builder::new() .length_field_length(1) .new_write(mock! { data(b"\x09"), data(b"abcdefghi"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_max_frame_len() { let io = length_delimited::Builder::new() .max_frame_length(5) .new_write(mock! {}); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_update_max_frame_len_at_rest() { let io = length_delimited::Builder::new().new_write(mock! { data(b"\x00\x00\x00\x06"), data(b"abcdef"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); io.encoder_mut().set_max_frame_length(5); assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_update_max_frame_len_in_flight() { let io = length_delimited::Builder::new().new_write(mock! { data(b"\x00\x00\x00\x06"), data(b"ab"), Pending, data(b"cdef"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); assert_pending!(io.as_mut().poll_flush(cx)); io.encoder_mut().set_max_frame_length(5); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_zero() { let io = length_delimited::Builder::new().new_write(mock! {}); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdef"))); assert_ready_err!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn encode_overflow() { // Test reproducing tokio-rs/tokio#681. let mut codec = length_delimited::Builder::new().new_codec(); let mut buf = BytesMut::with_capacity(1024); // Put some data into the buffer without resizing it to hold more. let some_as = std::iter::repeat(b'a').take(1024).collect::>(); buf.put_slice(&some_as[..]); // Trying to encode the length header should resize the buffer if it won't fit. codec.encode(Bytes::from("hello"), &mut buf).unwrap(); } // ===== Test utils ===== struct Mock { calls: VecDeque>>, } enum Op { Data(Vec), Flush, } use self::Op::*; impl AsyncRead for Mock { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, dst: &mut ReadBuf<'_>, ) -> Poll> { match self.calls.pop_front() { Some(Ready(Ok(Op::Data(data)))) => { debug_assert!(dst.remaining() >= data.len()); dst.put_slice(&data); Ready(Ok(())) } Some(Ready(Ok(_))) => panic!(), Some(Ready(Err(e))) => Ready(Err(e)), Some(Pending) => Pending, None => Ready(Ok(())), } } } impl AsyncWrite for Mock { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, src: &[u8], ) -> Poll> { match self.calls.pop_front() { Some(Ready(Ok(Op::Data(data)))) => { let len = data.len(); assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src); assert_eq!(&data[..], &src[..len]); Ready(Ok(len)) } Some(Ready(Ok(_))) => panic!(), Some(Ready(Err(e))) => Ready(Err(e)), Some(Pending) => Pending, None => Ready(Ok(0)), } } fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { match self.calls.pop_front() { Some(Ready(Ok(Op::Flush))) => Ready(Ok(())), Some(Ready(Ok(_))) => panic!(), Some(Ready(Err(e))) => Ready(Err(e)), Some(Pending) => Pending, None => Ready(Ok(())), } } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Ready(Ok(())) } } impl<'a> From<&'a [u8]> for Op { fn from(src: &'a [u8]) -> Op { Op::Data(src.into()) } } impl From> for Op { fn from(src: Vec) -> Op { Op::Data(src) } } fn data(bytes: &[u8]) -> Poll> { Ready(Ok(bytes.into())) } fn flush() -> Poll> { Ready(Ok(Flush)) } tokio-util-0.6.9/tests/mpsc.rs000064400000000000000000000052510072674642500144450ustar 00000000000000use futures::future::poll_fn; use tokio::sync::mpsc::channel; use tokio_test::task::spawn; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok}; use tokio_util::sync::PollSender; #[tokio::test] async fn test_simple() { let (send, mut recv) = channel(3); let mut send = PollSender::new(send); for i in 1..=3i32 { send.start_send(i).unwrap(); assert_ready_ok!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll()); } send.start_send(4).unwrap(); let mut fourth_send = spawn(poll_fn(|cx| send.poll_send_done(cx))); assert_pending!(fourth_send.poll()); assert_eq!(recv.recv().await.unwrap(), 1); assert!(fourth_send.is_woken()); assert_ready_ok!(fourth_send.poll()); drop(recv); // Here, start_send is not guaranteed to fail, but if it doesn't the first // call to poll_send_done should. if send.start_send(5).is_ok() { assert_ready_err!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll()); } } #[tokio::test] async fn test_abort() { let (send, mut recv) = channel(3); let mut send = PollSender::new(send); let send2 = send.clone_inner().unwrap(); for i in 1..=3i32 { send.start_send(i).unwrap(); assert_ready_ok!(spawn(poll_fn(|cx| send.poll_send_done(cx))).poll()); } send.start_send(4).unwrap(); { let mut fourth_send = spawn(poll_fn(|cx| send.poll_send_done(cx))); assert_pending!(fourth_send.poll()); assert_eq!(recv.recv().await.unwrap(), 1); assert!(fourth_send.is_woken()); } let mut send2_send = spawn(send2.send(5)); assert_pending!(send2_send.poll()); send.abort_send(); assert!(send2_send.is_woken()); assert_ready_ok!(send2_send.poll()); assert_eq!(recv.recv().await.unwrap(), 2); assert_eq!(recv.recv().await.unwrap(), 3); assert_eq!(recv.recv().await.unwrap(), 5); } #[tokio::test] async fn close_sender_last() { let (send, mut recv) = channel::(3); let mut send = PollSender::new(send); let mut recv_task = spawn(recv.recv()); assert_pending!(recv_task.poll()); send.close_this_sender(); assert!(recv_task.is_woken()); assert!(assert_ready!(recv_task.poll()).is_none()); } #[tokio::test] async fn close_sender_not_last() { let (send, mut recv) = channel::(3); let send2 = send.clone(); let mut send = PollSender::new(send); let mut recv_task = spawn(recv.recv()); assert_pending!(recv_task.poll()); send.close_this_sender(); assert!(!recv_task.is_woken()); assert_pending!(recv_task.poll()); drop(send2); assert!(recv_task.is_woken()); assert!(assert_ready!(recv_task.poll()).is_none()); } tokio-util-0.6.9/tests/poll_semaphore.rs000064400000000000000000000017760072674642500165240ustar 00000000000000use std::future::Future; use std::sync::Arc; use std::task::Poll; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio_util::sync::PollSemaphore; type SemRet = Option; fn semaphore_poll( sem: &mut PollSemaphore, ) -> tokio_test::task::Spawn + '_> { let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx)); tokio_test::task::spawn(fut) } #[tokio::test] async fn it_works() { let sem = Arc::new(Semaphore::new(1)); let mut poll_sem = PollSemaphore::new(sem.clone()); let permit = sem.acquire().await.unwrap(); let mut poll = semaphore_poll(&mut poll_sem); assert!(poll.poll().is_pending()); drop(permit); assert!(matches!(poll.poll(), Poll::Ready(Some(_)))); drop(poll); sem.close(); assert!(semaphore_poll(&mut poll_sem).await.is_none()); // Check that it is fused. assert!(semaphore_poll(&mut poll_sem).await.is_none()); assert!(semaphore_poll(&mut poll_sem).await.is_none()); } tokio-util-0.6.9/tests/reusable_box.rs000064400000000000000000000036500072674642500161560ustar 00000000000000use futures::future::FutureExt; use std::alloc::Layout; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use tokio_util::sync::ReusableBoxFuture; #[test] fn test_different_futures() { let fut = async move { 10 }; // Not zero sized! assert_eq!(Layout::for_value(&fut).size(), 1); let mut b = ReusableBoxFuture::new(fut); assert_eq!(b.get_pin().now_or_never(), Some(10)); b.try_set(async move { 20 }) .unwrap_or_else(|_| panic!("incorrect size")); assert_eq!(b.get_pin().now_or_never(), Some(20)); b.try_set(async move { 30 }) .unwrap_or_else(|_| panic!("incorrect size")); assert_eq!(b.get_pin().now_or_never(), Some(30)); } #[test] fn test_different_sizes() { let fut1 = async move { 10 }; let val = [0u32; 1000]; let fut2 = async move { val[0] }; let fut3 = ZeroSizedFuture {}; assert_eq!(Layout::for_value(&fut1).size(), 1); assert_eq!(Layout::for_value(&fut2).size(), 4004); assert_eq!(Layout::for_value(&fut3).size(), 0); let mut b = ReusableBoxFuture::new(fut1); assert_eq!(b.get_pin().now_or_never(), Some(10)); b.set(fut2); assert_eq!(b.get_pin().now_or_never(), Some(0)); b.set(fut3); assert_eq!(b.get_pin().now_or_never(), Some(5)); } struct ZeroSizedFuture {} impl Future for ZeroSizedFuture { type Output = u32; fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { Poll::Ready(5) } } #[test] fn test_zero_sized() { let fut = ZeroSizedFuture {}; // Zero sized! assert_eq!(Layout::for_value(&fut).size(), 0); let mut b = ReusableBoxFuture::new(fut); assert_eq!(b.get_pin().now_or_never(), Some(5)); assert_eq!(b.get_pin().now_or_never(), Some(5)); b.try_set(ZeroSizedFuture {}) .unwrap_or_else(|_| panic!("incorrect size")); assert_eq!(b.get_pin().now_or_never(), Some(5)); assert_eq!(b.get_pin().now_or_never(), Some(5)); } tokio-util-0.6.9/tests/sync_cancellation_token.rs000064400000000000000000000127760072674642500204050ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::pin; use tokio_util::sync::CancellationToken; use core::future::Future; use core::task::{Context, Poll}; use futures_test::task::new_count_waker; #[test] fn cancel_token() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); assert!(!token.is_cancelled()); let wait_fut = token.cancelled(); pin!(wait_fut); assert_eq!( Poll::Pending, wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); let wait_fut_2 = token.cancelled(); pin!(wait_fut_2); token.cancel(); assert_eq!(wake_counter, 1); assert!(token.is_cancelled()); assert_eq!( Poll::Ready(()), wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn cancel_child_token_through_parent() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); let child_token = token.child_token(); assert!(!child_token.is_cancelled()); let child_fut = child_token.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); token.cancel(); assert_eq!(wake_counter, 2); assert!(token.is_cancelled()); assert!(child_token.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn cancel_child_token_without_parent() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); let child_token_1 = token.child_token(); let child_fut = child_token_1.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); child_token_1.cancel(); assert_eq!(wake_counter, 1); assert!(!token.is_cancelled()); assert!(child_token_1.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); let child_token_2 = token.child_token(); let child_fut_2 = child_token_2.cancelled(); pin!(child_fut_2); assert_eq!( Poll::Pending, child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); token.cancel(); assert_eq!(wake_counter, 3); assert!(token.is_cancelled()); assert!(child_token_2.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn create_child_token_after_parent_was_cancelled() { for drop_child_first in [true, false].iter().cloned() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); token.cancel(); let child_token = token.child_token(); assert!(child_token.is_cancelled()); { let child_fut = child_token.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); drop(child_fut); drop(parent_fut); } if drop_child_first { drop(child_token); drop(token); } else { drop(token); drop(child_token); } } } #[test] fn drop_multiple_child_tokens() { for drop_first_child_first in &[true, false] { let token = CancellationToken::new(); let mut child_tokens = [None, None, None]; for child in &mut child_tokens { *child = Some(token.child_token()); } assert!(!token.is_cancelled()); assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); for i in 0..child_tokens.len() { if *drop_first_child_first { child_tokens[i] = None; } else { child_tokens[child_tokens.len() - 1 - i] = None; } assert!(!token.is_cancelled()); } drop(token); } } #[test] fn drop_parent_before_child_tokens() { let token = CancellationToken::new(); let child1 = token.child_token(); let child2 = token.child_token(); drop(token); assert!(!child1.is_cancelled()); drop(child1); drop(child2); } tokio-util-0.6.9/tests/time_delay_queue.rs000064400000000000000000000355240072674642500170310ustar 00000000000000#![allow(clippy::blacklisted_name)] #![warn(rust_2018_idioms)] #![cfg(feature = "full")] use tokio::time::{self, sleep, sleep_until, Duration, Instant}; use tokio_test::{assert_ok, assert_pending, assert_ready, task}; use tokio_util::time::DelayQueue; macro_rules! poll { ($queue:ident) => { $queue.enter(|cx, mut queue| queue.poll_expired(cx)) }; } macro_rules! assert_ready_ok { ($e:expr) => {{ assert_ok!(match assert_ready!($e) { Some(v) => v, None => panic!("None"), }) }}; } #[tokio::test] async fn single_immediate_delay() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let _key = queue.insert_at("foo", Instant::now()); // Advance time by 1ms to handle thee rounding sleep(ms(1)).await; assert_ready_ok!(poll!(queue)); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()) } #[tokio::test] async fn multi_immediate_delays() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let _k = queue.insert_at("1", Instant::now()); let _k = queue.insert_at("2", Instant::now()); let _k = queue.insert_at("3", Instant::now()); sleep(ms(1)).await; let mut res = vec![]; while res.len() < 3 { let entry = assert_ready_ok!(poll!(queue)); res.push(entry.into_inner()); } let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); res.sort_unstable(); assert_eq!("1", res[0]); assert_eq!("2", res[1]); assert_eq!("3", res[2]); } #[tokio::test] async fn single_short_delay() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let _key = queue.insert_at("foo", Instant::now() + ms(5)); assert_pending!(poll!(queue)); sleep(ms(1)).await; assert!(!queue.is_woken()); sleep(ms(5)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)); assert_eq!(*entry.get_ref(), "foo"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); } #[tokio::test] async fn multi_delay_at_start() { time::pause(); let long = 262_144 + 9 * 4096; let delays = &[1000, 2, 234, long, 60, 10]; let mut queue = task::spawn(DelayQueue::new()); // Setup the delays for &i in delays { let _key = queue.insert_at(i, Instant::now() + ms(i)); } assert_pending!(poll!(queue)); assert!(!queue.is_woken()); let start = Instant::now(); for elapsed in 0..1200 { let elapsed = elapsed + 1; tokio::time::sleep_until(start + ms(elapsed)).await; if delays.contains(&elapsed) { assert!(queue.is_woken()); assert_ready!(poll!(queue)); assert_pending!(poll!(queue)); } else if queue.is_woken() { let cascade = &[192, 960]; assert!( cascade.contains(&elapsed), "elapsed={} dt={:?}", elapsed, Instant::now() - start ); assert_pending!(poll!(queue)); } } } #[tokio::test] async fn insert_in_past_fires_immediately() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); sleep(ms(10)).await; queue.insert_at("foo", now); assert_ready!(poll!(queue)); } #[tokio::test] async fn remove_entry() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let key = queue.insert_at("foo", Instant::now() + ms(5)); assert_pending!(poll!(queue)); let entry = queue.remove(&key); assert_eq!(entry.into_inner(), "foo"); sleep(ms(10)).await; let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); } #[tokio::test] async fn reset_entry() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now + ms(5)); assert_pending!(poll!(queue)); sleep(ms(1)).await; queue.reset_at(&key, now + ms(10)); assert_pending!(poll!(queue)); sleep(ms(7)).await; assert!(!queue.is_woken()); assert_pending!(poll!(queue)); sleep(ms(3)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)); assert_eq!(*entry.get_ref(), "foo"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()) } // Reproduces tokio-rs/tokio#849. #[tokio::test] async fn reset_much_later() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); sleep(ms(1)).await; let key = queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep(ms(3)).await; queue.reset_at(&key, now + ms(10)); sleep(ms(20)).await; assert!(queue.is_woken()); } // Reproduces tokio-rs/tokio#849. #[tokio::test] async fn reset_twice() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); sleep(ms(1)).await; let key = queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep(ms(3)).await; queue.reset_at(&key, now + ms(50)); sleep(ms(20)).await; queue.reset_at(&key, now + ms(40)); sleep(ms(20)).await; assert!(queue.is_woken()); } /// Regression test: Given an entry inserted with a deadline in the past, so /// that it is placed directly on the expired queue, reset the entry to a /// deadline in the future. Validate that this leaves the entry and queue in an /// internally consistent state by running an additional reset on the entry /// before polling it to completion. #[tokio::test] async fn repeatedly_reset_entry_inserted_as_expired() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now - ms(100)); queue.reset_at(&key, now + ms(100)); queue.reset_at(&key, now + ms(50)); assert_pending!(poll!(queue)); time::sleep_until(now + ms(60)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); } #[tokio::test] async fn remove_expired_item() { time::pause(); let mut queue = DelayQueue::new(); let now = Instant::now(); sleep(ms(10)).await; let key = queue.insert_at("foo", now); let entry = queue.remove(&key); assert_eq!(entry.into_inner(), "foo"); } /// Regression test: it should be possible to remove entries which fall in the /// 0th slot of the internal timer wheel — that is, entries whose expiration /// (a) falls at the beginning of one of the wheel's hierarchical levels and (b) /// is equal to the wheel's current elapsed time. #[tokio::test] async fn remove_at_timer_wheel_threshold() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key1 = queue.insert_at("foo", now + ms(64)); let key2 = queue.insert_at("bar", now + ms(64)); sleep(ms(80)).await; let entry = assert_ready_ok!(poll!(queue)).into_inner(); match entry { "foo" => { let entry = queue.remove(&key2).into_inner(); assert_eq!(entry, "bar"); } "bar" => { let entry = queue.remove(&key1).into_inner(); assert_eq!(entry, "foo"); } other => panic!("other: {:?}", other), } } #[tokio::test] async fn expires_before_last_insert() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("foo", now + ms(10_000)); // Delay should be set to 8.192s here. assert_pending!(poll!(queue)); // Delay should be set to the delay of the new item here queue.insert_at("bar", now + ms(600)); assert_pending!(poll!(queue)); sleep(ms(600)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "bar"); } #[tokio::test] async fn multi_reset() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let one = queue.insert_at("one", now + ms(200)); let two = queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&one, now + ms(300)); queue.reset_at(&two, now + ms(350)); queue.reset_at(&one, now + ms(400)); sleep(ms(310)).await; assert_pending!(poll!(queue)); sleep(ms(50)).await; let entry = assert_ready_ok!(poll!(queue)); assert_eq!(*entry.get_ref(), "two"); assert_pending!(poll!(queue)); sleep(ms(50)).await; let entry = assert_ready_ok!(poll!(queue)); assert_eq!(*entry.get_ref(), "one"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()) } #[tokio::test] async fn expire_first_key_when_reset_to_expire_earlier() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let one = queue.insert_at("one", now + ms(200)); queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&one, now + ms(100)); sleep(ms(100)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "one"); } #[tokio::test] async fn expire_second_key_when_reset_to_expire_earlier() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("one", now + ms(200)); let two = queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&two, now + ms(100)); sleep(ms(100)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "two"); } #[tokio::test] async fn reset_first_expiring_item_to_expire_later() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let one = queue.insert_at("one", now + ms(200)); let _two = queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&one, now + ms(300)); sleep(ms(250)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "two"); } #[tokio::test] async fn insert_before_first_after_poll() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let _one = queue.insert_at("one", now + ms(200)); assert_pending!(poll!(queue)); let _two = queue.insert_at("two", now + ms(100)); sleep(ms(99)).await; assert_pending!(poll!(queue)); sleep(ms(1)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "two"); } #[tokio::test] async fn insert_after_ready_poll() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("1", now + ms(100)); queue.insert_at("2", now + ms(100)); queue.insert_at("3", now + ms(100)); assert_pending!(poll!(queue)); sleep(ms(100)).await; assert!(queue.is_woken()); let mut res = vec![]; while res.len() < 3 { let entry = assert_ready_ok!(poll!(queue)); res.push(entry.into_inner()); queue.insert_at("foo", now + ms(500)); } res.sort_unstable(); assert_eq!("1", res[0]); assert_eq!("2", res[1]); assert_eq!("3", res[2]); } #[tokio::test] async fn reset_later_after_slot_starts() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let foo = queue.insert_at("foo", now + ms(100)); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(80)).await; assert!(!queue.is_woken()); // At this point the queue hasn't been polled, so `elapsed` on the wheel // for the queue is still at 0 and hence the 1ms resolution slots cover // [0-64). Resetting the time on the entry to 120 causes it to get put in // the [64-128) slot. As the queue knows that the first entry is within // that slot, but doesn't know when, it must wake immediately to advance // the wheel. queue.reset_at(&foo, now + ms(120)); assert!(queue.is_woken()); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(119)).await; assert!(!queue.is_woken()); sleep(ms(1)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); } #[tokio::test] async fn reset_inserted_expired() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now - ms(100)); // this causes the panic described in #2473 queue.reset_at(&key, now + ms(100)); assert_eq!(1, queue.len()); sleep(ms(200)).await; let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); assert_eq!(queue.len(), 0); } #[tokio::test] async fn reset_earlier_after_slot_starts() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let foo = queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(80)).await; assert!(!queue.is_woken()); // At this point the queue hasn't been polled, so `elapsed` on the wheel // for the queue is still at 0 and hence the 1ms resolution slots cover // [0-64). Resetting the time on the entry to 120 causes it to get put in // the [64-128) slot. As the queue knows that the first entry is within // that slot, but doesn't know when, it must wake immediately to advance // the wheel. queue.reset_at(&foo, now + ms(120)); assert!(queue.is_woken()); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(119)).await; assert!(!queue.is_woken()); sleep(ms(1)).await; assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); } #[tokio::test] async fn insert_in_past_after_poll_fires_immediately() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep(ms(80)).await; assert!(!queue.is_woken()); queue.insert_at("bar", now + ms(40)); assert!(queue.is_woken()); let entry = assert_ready_ok!(poll!(queue)).into_inner(); assert_eq!(entry, "bar"); } #[tokio::test] async fn delay_queue_poll_expired_when_empty() { let mut delay_queue = task::spawn(DelayQueue::new()); let key = delay_queue.insert(0, std::time::Duration::from_secs(10)); assert_pending!(poll!(delay_queue)); delay_queue.remove(&key); assert!(assert_ready!(poll!(delay_queue)).is_none()); } fn ms(n: u64) -> Duration { Duration::from_millis(n) } tokio-util-0.6.9/tests/udp.rs000064400000000000000000000075130072674642500142760ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::net::UdpSocket; use tokio_stream::StreamExt; use tokio_util::codec::{Decoder, Encoder, LinesCodec}; use tokio_util::udp::UdpFramed; use bytes::{BufMut, BytesMut}; use futures::future::try_join; use futures::future::FutureExt; use futures::sink::SinkExt; use std::io; use std::sync::Arc; #[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] #[tokio::test] async fn send_framed_byte_codec() -> std::io::Result<()> { let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; let a_addr = a_soc.local_addr()?; let b_addr = b_soc.local_addr()?; // test sending & receiving bytes { let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, ByteCodec); let msg = b"4567"; let send = a.send((msg, b_addr)); let recv = b.next().map(|e| e.unwrap()); let (_, received) = try_join(send, recv).await.unwrap(); let (data, addr) = received; assert_eq!(msg, &*data); assert_eq!(a_addr, addr); a_soc = a.into_inner(); b_soc = b.into_inner(); } #[cfg(not(any(target_os = "macos", target_os = "ios")))] // test sending & receiving an empty message { let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, ByteCodec); let msg = b""; let send = a.send((msg, b_addr)); let recv = b.next().map(|e| e.unwrap()); let (_, received) = try_join(send, recv).await.unwrap(); let (data, addr) = received; assert_eq!(msg, &*data); assert_eq!(a_addr, addr); } Ok(()) } pub struct ByteCodec; impl Decoder for ByteCodec { type Item = Vec; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> Result>, io::Error> { let len = buf.len(); Ok(Some(buf.split_to(len).to_vec())) } } impl Encoder<&[u8]> for ByteCodec { type Error = io::Error; fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { buf.reserve(data.len()); buf.put_slice(data); Ok(()) } } #[tokio::test] async fn send_framed_lines_codec() -> std::io::Result<()> { let a_soc = UdpSocket::bind("127.0.0.1:0").await?; let b_soc = UdpSocket::bind("127.0.0.1:0").await?; let a_addr = a_soc.local_addr()?; let b_addr = b_soc.local_addr()?; let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, LinesCodec::new()); let msg = b"1\r\n2\r\n3\r\n".to_vec(); a.send((&msg, b_addr)).await?; assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); Ok(()) } #[tokio::test] async fn framed_half() -> std::io::Result<()> { let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); let b_soc = a_soc.clone(); let a_addr = a_soc.local_addr()?; let b_addr = b_soc.local_addr()?; let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, LinesCodec::new()); let msg = b"1\r\n2\r\n3\r\n".to_vec(); a.send((&msg, b_addr)).await?; let msg = b"4\r\n5\r\n6\r\n".to_vec(); a.send((&msg, b_addr)).await?; assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr)); Ok(()) }