tokio-util-0.7.10/.cargo_vcs_info.json0000644000000001500000000000100132150ustar
{
  "git": {
    "sha1": "503fad79087ed5791c7a018e07621689ea5e4676"
  },
  "path_in_vcs": "tokio-util"
}
tokio-util-0.7.10/CHANGELOG.md000064400000000000000000000301631046102023000136250ustar 00000000000000
# 0.7.10 (October 24th, 2023)

### Added

- task: add `TaskTracker` ([#6033])
- task: add `JoinMap::keys` ([#6046])
- io: implement `Seek` for `SyncIoBridge` ([#6058])

### Changed

- deps: update hashbrown to 0.14 ([#6102])

[#6033]: https://github.com/tokio-rs/tokio/pull/6033
[#6046]: https://github.com/tokio-rs/tokio/pull/6046
[#6058]: https://github.com/tokio-rs/tokio/pull/6058
[#6102]: https://github.com/tokio-rs/tokio/pull/6102

# 0.7.9 (September 20th, 2023)

### Added

- io: add passthrough `AsyncRead`/`AsyncWrite` to `InspectWriter`/`InspectReader` ([#5739])
- task: add spawn blocking methods to `JoinMap` ([#5797])
- io: pass through traits for `StreamReader` and `SinkWriter` ([#5941])
- io: add `SyncIoBridge::into_inner` ([#5971])

### Fixed

- sync: handle possibly dangling reference safely ([#5812])
- util: fix broken intra-doc link ([#5849])
- compat: fix clippy warnings ([#5891])

### Documented

- codec: Specify the line ending of `LinesCodec` ([#5982])

[#5739]: https://github.com/tokio-rs/tokio/pull/5739
[#5797]: https://github.com/tokio-rs/tokio/pull/5797
[#5941]: https://github.com/tokio-rs/tokio/pull/5941
[#5971]: https://github.com/tokio-rs/tokio/pull/5971
[#5812]: https://github.com/tokio-rs/tokio/pull/5812
[#5849]: https://github.com/tokio-rs/tokio/pull/5849
[#5891]: https://github.com/tokio-rs/tokio/pull/5891
[#5982]: https://github.com/tokio-rs/tokio/pull/5982

# 0.7.8 (April 25th, 2023)

This release bumps the MSRV of tokio-util to 1.56.

### Added

- time: add `DelayQueue::peek` ([#5569])

### Changed

This release contains one performance improvement:

- sync: try to lock the parent first in `CancellationToken` ([#5561])

### Fixed

- time: fix panic in `DelayQueue` ([#5630])

### Documented

- sync: improve `CancellationToken` doc on child tokens ([#5632])

[#5561]: https://github.com/tokio-rs/tokio/pull/5561
[#5569]: https://github.com/tokio-rs/tokio/pull/5569
[#5630]: https://github.com/tokio-rs/tokio/pull/5630
[#5632]: https://github.com/tokio-rs/tokio/pull/5632

# 0.7.7 (February 12, 2023)

This release reverts the removal of the `Encoder` bound on the `FramedParts` constructor from [#5280] since it turned out to be a breaking change. ([#5450])

[#5450]: https://github.com/tokio-rs/tokio/pull/5450

# 0.7.6 (February 10, 2023)

This release fixes a compilation failure in 0.7.5 when it is used together with Tokio version 1.21 and unstable features are enabled. ([#5445])

[#5445]: https://github.com/tokio-rs/tokio/pull/5445

# 0.7.5 (February 9, 2023)

This release fixes an accidental breaking change where `UnwindSafe` was accidentally removed from `CancellationToken`.

### Added

- codec: add `Framed::backpressure_boundary` ([#5124])
- io: add `InspectReader` and `InspectWriter` ([#5033])
- io: add `tokio_util::io::{CopyToBytes, SinkWriter}` ([#5070], [#5436])
- io: impl `std::io::BufRead` on `SyncIoBridge` ([#5265])
- sync: add `PollSemaphore::poll_acquire_many` ([#5137])
- sync: add owned future for `CancellationToken` ([#5153])
- time: add `DelayQueue::try_remove` ([#5052])

### Fixed

- codec: fix `LengthDelimitedCodec` buffer over-reservation ([#4997])
- sync: impl `UnwindSafe` on `CancellationToken` ([#5438])
- util: remove `Encoder` bound on `FramedParts` constructor ([#5280])

### Documented

- io: add lines example for `StreamReader` ([#5145])

[#4997]: https://github.com/tokio-rs/tokio/pull/4997
[#5033]: https://github.com/tokio-rs/tokio/pull/5033
[#5052]: https://github.com/tokio-rs/tokio/pull/5052
[#5070]: https://github.com/tokio-rs/tokio/pull/5070
[#5124]: https://github.com/tokio-rs/tokio/pull/5124
[#5137]: https://github.com/tokio-rs/tokio/pull/5137
[#5145]: https://github.com/tokio-rs/tokio/pull/5145
[#5153]: https://github.com/tokio-rs/tokio/pull/5153
[#5265]: https://github.com/tokio-rs/tokio/pull/5265
[#5280]: https://github.com/tokio-rs/tokio/pull/5280
[#5436]: https://github.com/tokio-rs/tokio/pull/5436
[#5438]: https://github.com/tokio-rs/tokio/pull/5438

# 0.7.4 (September 8, 2022)

### Added

- io: add `SyncIoBridge::shutdown()` ([#4938])
- task: improve `LocalPoolHandle` ([#4680])

### Fixed

- util: add `track_caller` to public APIs ([#4785])

### Unstable

- task: fix compilation errors in `JoinMap` with Tokio v1.21.0 ([#4755])
- task: remove the unstable, deprecated `JoinMap::join_one` ([#4920])

[#4680]: https://github.com/tokio-rs/tokio/pull/4680
[#4755]: https://github.com/tokio-rs/tokio/pull/4755
[#4785]: https://github.com/tokio-rs/tokio/pull/4785
[#4920]: https://github.com/tokio-rs/tokio/pull/4920
[#4938]: https://github.com/tokio-rs/tokio/pull/4938

# 0.7.3 (June 4, 2022)

### Changed

- tracing: don't require default tracing features ([#4592])
- util: simplify implementation of `ReusableBoxFuture` ([#4675])

### Added (unstable)

- task: add `JoinMap` ([#4640], [#4697])

[#4592]: https://github.com/tokio-rs/tokio/pull/4592
[#4640]: https://github.com/tokio-rs/tokio/pull/4640
[#4675]: https://github.com/tokio-rs/tokio/pull/4675
[#4697]: https://github.com/tokio-rs/tokio/pull/4697

# 0.7.2 (May 14, 2022)

This release contains a rewrite of `CancellationToken` that fixes a memory leak. ([#4652])

[#4652]: https://github.com/tokio-rs/tokio/pull/4652

# 0.7.1 (February 21, 2022)

### Added

- codec: add `length_field_type` to `LengthDelimitedCodec` builder ([#4508])
- io: add `StreamReader::into_inner_with_chunk()` ([#4559])

### Changed

- switch from log to tracing ([#4539])

### Fixed

- sync: fix waker update condition in `CancellationToken` ([#4497])
- bumped tokio dependency to 1.6 to satisfy minimum requirements ([#4490])

[#4490]: https://github.com/tokio-rs/tokio/pull/4490
[#4497]: https://github.com/tokio-rs/tokio/pull/4497
[#4508]: https://github.com/tokio-rs/tokio/pull/4508
[#4539]: https://github.com/tokio-rs/tokio/pull/4539
[#4559]: https://github.com/tokio-rs/tokio/pull/4559

# 0.7.0 (February 9, 2022)

### Added

- task: add `spawn_pinned` ([#3370])
- time: add `shrink_to_fit` and `compact` methods to `DelayQueue` ([#4170])
- codec: improve `Builder::max_frame_length` docs ([#4352])
- codec: add mutable reference getters for codecs to pinned `Framed` ([#4372])
- net: add generic trait to combine `UnixListener` and `TcpListener` ([#4385])
- codec: implement `Framed::map_codec` ([#4427])
- codec: implement `Encoder` for `BytesCodec` ([#4465])

### Changed

- sync: add lifetime parameter to `ReusableBoxFuture` ([#3762])
- sync: refactored `PollSender` to fix a subtly broken `Sink` implementation ([#4214])
- time: remove error case from the infallible `DelayQueue::poll_elapsed` ([#4241])

[#3370]: https://github.com/tokio-rs/tokio/pull/3370
[#4170]: https://github.com/tokio-rs/tokio/pull/4170
[#4352]: https://github.com/tokio-rs/tokio/pull/4352
[#4372]: https://github.com/tokio-rs/tokio/pull/4372
[#4385]: https://github.com/tokio-rs/tokio/pull/4385
[#4427]: https://github.com/tokio-rs/tokio/pull/4427
[#4465]: https://github.com/tokio-rs/tokio/pull/4465
[#3762]: https://github.com/tokio-rs/tokio/pull/3762
[#4214]: https://github.com/tokio-rs/tokio/pull/4214
[#4241]: https://github.com/tokio-rs/tokio/pull/4241

# 0.6.10 (May 14, 2021)

This is a backport for the memory leak in `CancellationToken` that was originally fixed in 0.7.2. ([#4652])

[#4652]: https://github.com/tokio-rs/tokio/pull/4652

# 0.6.9 (October 29, 2021)

### Added

- codec: implement `Clone` for `LengthDelimitedCodec` ([#4089])
- io: add `SyncIoBridge` ([#4146])

### Fixed

- time: update deadline on removal in `DelayQueue` ([#4178])
- codec: Update stream impl for Framed to return None after Err ([#4166])

[#4089]: https://github.com/tokio-rs/tokio/pull/4089
[#4146]: https://github.com/tokio-rs/tokio/pull/4146
[#4166]: https://github.com/tokio-rs/tokio/pull/4166
[#4178]: https://github.com/tokio-rs/tokio/pull/4178

# 0.6.8 (September 3, 2021)

### Added

- sync: add drop guard for `CancellationToken` ([#3839])
- compat: added `AsyncSeek` compat ([#4078])
- time: expose `Key` used in `DelayQueue`'s `Expired` ([#4081])
- io: add `with_capacity` to `ReaderStream` ([#4086])

### Fixed

- codec: remove unnecessary `doc(cfg(...))` ([#3989])

[#3839]: https://github.com/tokio-rs/tokio/pull/3839
[#4078]: https://github.com/tokio-rs/tokio/pull/4078
[#4081]: https://github.com/tokio-rs/tokio/pull/4081
[#4086]: https://github.com/tokio-rs/tokio/pull/4086
[#3989]: https://github.com/tokio-rs/tokio/pull/3989

# 0.6.7 (May 14, 2021)

### Added

- udp: make `UdpFramed` take `Borrow` ([#3451])
- compat: implement `AsRawFd`/`AsRawHandle` for `Compat` ([#3765])

[#3451]: https://github.com/tokio-rs/tokio/pull/3451
[#3765]: https://github.com/tokio-rs/tokio/pull/3765

# 0.6.6 (April 12, 2021)

### Added

- util: makes `Framed` and `FramedStream` resumable after eof ([#3272])
- util: add `PollSemaphore::{add_permits, available_permits}` ([#3683])

### Fixed

- chore: avoid allocation if `PollSemaphore` is unused ([#3634])

[#3272]: https://github.com/tokio-rs/tokio/pull/3272
[#3634]: https://github.com/tokio-rs/tokio/pull/3634
[#3683]: https://github.com/tokio-rs/tokio/pull/3683

# 0.6.5 (March 20, 2021)

### Fixed

- util: annotate time module as requiring `time` feature ([#3606])

[#3606]: https://github.com/tokio-rs/tokio/pull/3606

# 0.6.4 (March 9, 2021)

### Added

- codec: `AnyDelimiter` codec ([#3406])
- sync: add pollable `mpsc::Sender` ([#3490])

### Fixed

- codec: `LinesCodec` should only return `MaxLineLengthExceeded` once per line ([#3556])
- sync: fuse PollSemaphore ([#3578])

[#3406]: https://github.com/tokio-rs/tokio/pull/3406
[#3490]: https://github.com/tokio-rs/tokio/pull/3490
[#3556]: https://github.com/tokio-rs/tokio/pull/3556
[#3578]: https://github.com/tokio-rs/tokio/pull/3578

# 0.6.3 (January 31, 2021)

### Added

- sync: add `ReusableBoxFuture` utility ([#3464])

### Changed

- sync: use `ReusableBoxFuture` for `PollSemaphore` ([#3463])
- deps: remove `async-stream` dependency ([#3463])
- deps: remove `tokio-stream` dependency ([#3487])

# 0.6.2 (January 21, 2021)

### Added

- sync: add pollable `Semaphore` ([#3444])

### Fixed

- time: fix panics on updating `DelayQueue` entries ([#3270])

# 0.6.1 (January 12, 2021)

### Added

- codec: `get_ref()`, `get_mut()`, `get_pin_mut()` and `into_inner()` for `Framed`, `FramedRead`, `FramedWrite` and `StreamReader` ([#3364]).
- codec: `write_buffer()` and `write_buffer_mut()` for `Framed` and `FramedWrite` ([#3387]).

# 0.6.0 (December 23, 2020)

### Changed

- depend on `tokio` 1.0.

### Added

- rt: add constructors to `TokioContext` (#3221).

# 0.5.1 (December 3, 2020)

### Added

- io: `poll_read_buf` util fn (#2972).
- io: `poll_write_buf` util fn with vectored write support (#3156).

# 0.5.0 (October 30, 2020)

### Changed

- io: update `bytes` to 0.6 (#3071).

# 0.4.0 (October 15, 2020)

### Added

- sync: `CancellationToken` for coordinating task cancellation (#2747).
- rt: `TokioContext` sets the Tokio runtime for the duration of a future (#2791)
- io: `StreamReader`/`ReaderStream` map between `AsyncRead` values and `Stream` of bytes (#2788).
- time: `DelayQueue` to manage many delays (#2897).

# 0.3.1 (March 18, 2020)

### Fixed

- Adjust minimum-supported Tokio version to v0.2.5 to account for an internal dependency on features in that version of Tokio. ([#2326])

# 0.3.0 (March 4, 2020)

### Changed

- **Breaking Change**: Change `Encoder` trait to take a generic `Item` parameter, which allows codec writers to pass references into `Framed` and `FramedWrite` types. ([#1746])

### Added

- Add futures-io/tokio::io compatibility layer. ([#2117])
- Add `Framed::with_capacity`. ([#2215])

### Fixed

- Use advance over split_to when data is not needed. ([#2198])

# 0.2.0 (November 26, 2019)

- Initial release

[#3487]: https://github.com/tokio-rs/tokio/pull/3487
[#3464]: https://github.com/tokio-rs/tokio/pull/3464
[#3463]: https://github.com/tokio-rs/tokio/pull/3463
[#3444]: https://github.com/tokio-rs/tokio/pull/3444
[#3387]: https://github.com/tokio-rs/tokio/pull/3387
[#3364]: https://github.com/tokio-rs/tokio/pull/3364
[#3270]: https://github.com/tokio-rs/tokio/pull/3270
[#2326]: https://github.com/tokio-rs/tokio/pull/2326
[#2215]: https://github.com/tokio-rs/tokio/pull/2215
[#2198]: https://github.com/tokio-rs/tokio/pull/2198
[#2117]: https://github.com/tokio-rs/tokio/pull/2117
[#1746]: https://github.com/tokio-rs/tokio/pull/1746
tokio-util-0.7.10/Cargo.toml0000644000000046160000000000100112260ustar
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.

[package]
edition = "2021"
rust-version = "1.56"
name = "tokio-util"
version = "0.7.10"
authors = ["Tokio Contributors "]
description = """
Additional utilities for working with Tokio.
"""
homepage = "https://tokio.rs"
readme = "README.md"
categories = ["asynchronous"]
license = "MIT"
repository = "https://github.com/tokio-rs/tokio"

[package.metadata.docs.rs]
all-features = true
rustc-args = [
    "--cfg",
    "docsrs",
    "--cfg",
    "tokio_unstable",
]
rustdoc-args = [
    "--cfg",
    "docsrs",
    "--cfg",
    "tokio_unstable",
]

[dependencies.bytes]
version = "1.0.0"

[dependencies.futures-core]
version = "0.3.0"

[dependencies.futures-io]
version = "0.3.0"
optional = true

[dependencies.futures-sink]
version = "0.3.0"

[dependencies.futures-util]
version = "0.3.0"
optional = true

[dependencies.pin-project-lite]
version = "0.2.11"

[dependencies.slab]
version = "0.4.4"
optional = true

[dependencies.tokio]
version = "1.28.0"
features = ["sync"]

[dependencies.tracing]
version = "0.1.25"
features = ["std"]
optional = true
default-features = false

[dev-dependencies.async-stream]
version = "0.3.0"

[dev-dependencies.futures]
version = "0.3.0"

[dev-dependencies.futures-test]
version = "0.3.5"

[dev-dependencies.parking_lot]
version = "0.12.0"

[dev-dependencies.tempfile]
version = "3.1.0"

[dev-dependencies.tokio]
version = "1.0.0"
features = ["full"]

[dev-dependencies.tokio-stream]
version = "0.1"

[dev-dependencies.tokio-test]
version = "0.4.0"

[features]
__docs_rs = ["futures-util"]
codec = ["tracing"]
compat = ["futures-io"]
default = []
full = [
    "codec",
    "compat",
    "io-util",
    "time",
    "net",
    "rt",
]
io = []
io-util = [
    "io",
    "tokio/rt",
    "tokio/io-util",
]
net = ["tokio/net"]
rt = [
    "tokio/rt",
    "tokio/sync",
    "futures-util",
    "hashbrown",
]
time = [
    "tokio/time",
    "slab",
]

[target."cfg(tokio_unstable)".dependencies.hashbrown]
version = "0.14.0"
optional = true
tokio-util-0.7.10/Cargo.toml.orig000064400000000000000000000040131046102023000146760ustar 00000000000000
[package]
name = "tokio-util"
# When releasing to crates.io:
# - Remove path dependencies
# - Update CHANGELOG.md.
# - Create "tokio-util-0.7.x" git tag.
version = "0.7.10"
edition = "2021"
rust-version = "1.56"
authors = ["Tokio Contributors "]
license = "MIT"
repository = "https://github.com/tokio-rs/tokio"
homepage = "https://tokio.rs"
description = """
Additional utilities for working with Tokio.
"""
categories = ["asynchronous"]

[features]
# No features on by default
default = []

# Shorthand for enabling everything
full = ["codec", "compat", "io-util", "time", "net", "rt"]

net = ["tokio/net"]
compat = ["futures-io",]
codec = ["tracing"]
time = ["tokio/time","slab"]
io = []
io-util = ["io", "tokio/rt", "tokio/io-util"]
rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"]

__docs_rs = ["futures-util"]

[dependencies]
tokio = { version = "1.28.0", path = "../tokio", features = ["sync"] }
bytes = "1.0.0"
futures-core = "0.3.0"
futures-sink = "0.3.0"
futures-io = { version = "0.3.0", optional = true }
futures-util = { version = "0.3.0", optional = true }
pin-project-lite = "0.2.11"
slab = { version = "0.4.4", optional = true } # Backs `DelayQueue`
tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true }

[target.'cfg(tokio_unstable)'.dependencies]
hashbrown = { version = "0.14.0", optional = true }

[dev-dependencies]
tokio = { version = "1.0.0", path = "../tokio", features = ["full"] }
tokio-test = { version = "0.4.0", path = "../tokio-test" }
tokio-stream = { version = "0.1", path = "../tokio-stream" }

async-stream = "0.3.0"
futures = "0.3.0"
futures-test = "0.3.5"
parking_lot = "0.12.0"
tempfile = "3.1.0"

[package.metadata.docs.rs]
all-features = true
# enable unstable features in the documentation
rustdoc-args = ["--cfg", "docsrs", "--cfg", "tokio_unstable"]
# it's necessary to _also_ pass `--cfg tokio_unstable` to rustc, or else
# dependencies will not be enabled, and the docs build will fail.
rustc-args = ["--cfg", "docsrs", "--cfg", "tokio_unstable"]
tokio-util-0.7.10/LICENSE000064400000000000000000000020461046102023000130200ustar 00000000000000
Copyright (c) 2023 Tokio Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
tokio-util-0.7.10/README.md000064400000000000000000000005001046102023000132630ustar 00000000000000
# tokio-util

Utilities for working with Tokio.

## License

This project is licensed under the [MIT license](LICENSE).

### Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in Tokio by you, shall be licensed as MIT, without any additional terms or conditions.
tokio-util-0.7.10/src/cfg.rs000064400000000000000000000026761046102023000137200ustar 00000000000000
macro_rules! cfg_codec {
    ($($item:item)*) => {
        $(
            #[cfg(feature = "codec")]
            #[cfg_attr(docsrs, doc(cfg(feature = "codec")))]
            $item
        )*
    }
}

macro_rules! cfg_compat {
    ($($item:item)*) => {
        $(
            #[cfg(feature = "compat")]
            #[cfg_attr(docsrs, doc(cfg(feature = "compat")))]
            $item
        )*
    }
}

macro_rules! cfg_net {
    ($($item:item)*) => {
        $(
            #[cfg(all(feature = "net", feature = "codec"))]
            #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))]
            $item
        )*
    }
}

macro_rules! cfg_io {
    ($($item:item)*) => {
        $(
            #[cfg(feature = "io")]
            #[cfg_attr(docsrs, doc(cfg(feature = "io")))]
            $item
        )*
    }
}

cfg_io! {
    macro_rules! cfg_io_util {
        ($($item:item)*) => {
            $(
                #[cfg(feature = "io-util")]
                #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
                $item
            )*
        }
    }
}

macro_rules! cfg_rt {
    ($($item:item)*) => {
        $(
            #[cfg(feature = "rt")]
            #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
            $item
        )*
    }
}

macro_rules! cfg_time {
    ($($item:item)*) => {
        $(
            #[cfg(feature = "time")]
            #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
            $item
        )*
    }
}
tokio-util-0.7.10/src/codec/any_delimiter_codec.rs000064400000000000000000000226521046102023000202140ustar 00000000000000
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;

use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{cmp, fmt, io, str, usize};

const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r";
const DEFAULT_SEQUENCE_WRITER: &[u8] = b",";
/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
///
/// # Example
/// Decode string of bytes containing various different delimiters.
///
/// [`BytesMut`]: bytes::BytesMut
/// [`Error`]: std::io::Error
///
/// ```
/// use tokio_util::codec::{AnyDelimiterCodec, Decoder};
/// use bytes::{BufMut, BytesMut};
///
/// #
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), std::io::Error> {
/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec());
/// let buf = &mut BytesMut::new();
/// buf.reserve(200);
/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r");
/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("", codec.decode(buf).unwrap().unwrap());
/// assert_eq!(None, codec.decode(buf).unwrap());
/// # Ok(())
/// # }
/// ```
///
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct AnyDelimiterCodec {
    // Stored index of the next index to examine for the delimiter character.
    // This is used to optimize searching.
    // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`,
    // because that is the next index to examine.
    // The next time `decode` is called with `abcde}`, the method will
    // only look at `de}` before returning.
    next_index: usize,

    /// The maximum length for a given chunk. If `usize::MAX`, chunks will be
    /// read until a delimiter character is reached.
    max_length: usize,

    /// Are we currently discarding the remainder of a chunk which was over
    /// the length limit?
    is_discarding: bool,

    /// The bytes that are used for searching during decode
    seek_delimiters: Vec<u8>,

    /// The bytes that are used for encoding
    sequence_writer: Vec<u8>,
}

impl AnyDelimiterCodec {
    /// Returns an `AnyDelimiterCodec` for splitting up data into chunks.
    ///
    /// # Note
    ///
    /// The returned `AnyDelimiterCodec` will not have an upper bound on the length
    /// of a buffered chunk.
    /// See the documentation for [`new_with_max_length`]
    /// for information on why this could be a potential security risk.
    ///
    /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length()
    pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec {
        AnyDelimiterCodec {
            next_index: 0,
            max_length: usize::MAX,
            is_discarding: false,
            seek_delimiters,
            sequence_writer,
        }
    }

    /// Returns an `AnyDelimiterCodec` with a maximum chunk length limit.
    ///
    /// If this is set, calls to `AnyDelimiterCodec::decode` will return a
    /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls
    /// will discard up to `limit` bytes from that chunk until a delimiter
    /// character is reached, returning `None` until the delimiter over the limit
    /// has been fully discarded. After that point, calls to `decode` will
    /// function as normal.
    ///
    /// # Note
    ///
    /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which
    /// will be exposed to untrusted input. Otherwise, the size of the buffer
    /// that holds the chunk currently being read is unbounded. An attacker could
    /// exploit this unbounded buffer by sending an unbounded amount of input
    /// without any delimiter characters, causing unbounded memory consumption.
    ///
    /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError
    pub fn new_with_max_length(
        seek_delimiters: Vec<u8>,
        sequence_writer: Vec<u8>,
        max_length: usize,
    ) -> Self {
        AnyDelimiterCodec {
            max_length,
            ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer)
        }
    }

    /// Returns the maximum chunk length when decoding.
    ///
    /// ```
    /// use std::usize;
    /// use tokio_util::codec::AnyDelimiterCodec;
    ///
    /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec());
    /// assert_eq!(codec.max_length(), usize::MAX);
    /// ```
    /// ```
    /// use tokio_util::codec::AnyDelimiterCodec;
    ///
    /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256);
    /// assert_eq!(codec.max_length(), 256);
    /// ```
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}

impl Decoder for AnyDelimiterCodec {
    type Item = Bytes;
    type Error = AnyDelimiterCodecError;

    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
        loop {
            // Determine how far into the buffer we'll search for a delimiter. If
            // there's no max_length set, we'll read to the end of the buffer.
            let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());

            let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| {
                self.seek_delimiters
                    .iter()
                    .any(|delimiter| *b == *delimiter)
            });

            match (self.is_discarding, new_chunk_offset) {
                (true, Some(offset)) => {
                    // If we found a new chunk, discard up to that offset and
                    // then stop discarding. On the next iteration, we'll try
                    // to read a chunk normally.
                    buf.advance(offset + self.next_index + 1);
                    self.is_discarding = false;
                    self.next_index = 0;
                }
                (true, None) => {
                    // Otherwise, we didn't find a new chunk, so we'll discard
                    // everything we read. On the next iteration, we'll continue
                    // discarding up to max_len bytes unless we find a new chunk.
                    buf.advance(read_to);
                    self.next_index = 0;
                    if buf.is_empty() {
                        return Ok(None);
                    }
                }
                (false, Some(offset)) => {
                    // Found a chunk!
                    let new_chunk_index = offset + self.next_index;
                    self.next_index = 0;
                    let mut chunk = buf.split_to(new_chunk_index + 1);
                    chunk.truncate(chunk.len() - 1);
                    let chunk = chunk.freeze();
                    return Ok(Some(chunk));
                }
                (false, None) if buf.len() > self.max_length => {
                    // Reached the maximum length without finding a
                    // new chunk, return an error and start discarding on the
                    // next call.
                    self.is_discarding = true;
                    return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded);
                }
                (false, None) => {
                    // We didn't find a chunk or reach the length limit, so the next
                    // call will resume searching at the current offset.
                    self.next_index = read_to;
                    return Ok(None);
                }
            }
        }
    }

    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
        Ok(match self.decode(buf)? {
            Some(frame) => Some(frame),
            None => {
                // return remaining data, if any
                if buf.is_empty() {
                    None
                } else {
                    let chunk = buf.split_to(buf.len());
                    self.next_index = 0;
                    Some(chunk.freeze())
                }
            }
        })
    }
}

impl<T> Encoder<T> for AnyDelimiterCodec
where
    T: AsRef<str>,
{
    type Error = AnyDelimiterCodecError;

    fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> {
        let chunk = chunk.as_ref();
        buf.reserve(chunk.len() + 1);
        buf.put(chunk.as_bytes());
        buf.put(self.sequence_writer.as_ref());

        Ok(())
    }
}

impl Default for AnyDelimiterCodec {
    fn default() -> Self {
        Self::new(
            DEFAULT_SEEK_DELIMITERS.to_vec(),
            DEFAULT_SEQUENCE_WRITER.to_vec(),
        )
    }
}

/// An error occurred while encoding or decoding a chunk.
#[derive(Debug)]
pub enum AnyDelimiterCodecError {
    /// The maximum chunk length was exceeded.
    MaxChunkLengthExceeded,

    /// An IO error occurred.
    Io(io::Error),
}

impl fmt::Display for AnyDelimiterCodecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AnyDelimiterCodecError::MaxChunkLengthExceeded => {
                write!(f, "max chunk length exceeded")
            }
            AnyDelimiterCodecError::Io(e) => write!(f, "{}", e),
        }
    }
}

impl From<io::Error> for AnyDelimiterCodecError {
    fn from(e: io::Error) -> AnyDelimiterCodecError {
        AnyDelimiterCodecError::Io(e)
    }
}

impl std::error::Error for AnyDelimiterCodecError {}
tokio-util-0.7.10/src/codec/bytes_codec.rs000064400000000000000000000043051046102023000165100ustar 00000000000000
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;

use bytes::{BufMut, Bytes, BytesMut};
use std::io;

/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
///
/// # Example
///
/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`.
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`BytesMut`]: bytes::BytesMut
/// [`Error`]: std::io::Error
///
/// ```
/// # mod hidden {
/// # #[allow(unused_imports)]
/// use tokio::fs::File;
/// # }
/// use tokio::io::AsyncRead;
/// use tokio_util::codec::{FramedRead, BytesCodec};
///
/// # enum File {}
/// # impl File {
/// #     async fn open(_name: &str) -> Result<impl AsyncRead, std::io::Error> {
/// #         use std::io::Cursor;
/// #         Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5]))
/// #     }
/// # }
/// #
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), std::io::Error> {
/// let my_async_read = File::open("filename.txt").await?;
/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new());
/// # Ok(())
/// # }
/// ```
///
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
pub struct BytesCodec(());

impl BytesCodec {
    /// Creates a new `BytesCodec` for shipping around raw bytes.
    pub fn new() -> BytesCodec {
        BytesCodec(())
    }
}

impl Decoder for BytesCodec {
    type Item = BytesMut;
    type Error = io::Error;

    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
        if !buf.is_empty() {
            let len = buf.len();
            Ok(Some(buf.split_to(len)))
        } else {
            Ok(None)
        }
    }
}

impl Encoder<Bytes> for BytesCodec {
    type Error = io::Error;

    fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
        buf.reserve(data.len());
        buf.put(data);
        Ok(())
    }
}

impl Encoder<BytesMut> for BytesCodec {
    type Error = io::Error;

    fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
        buf.reserve(data.len());
        buf.put(data);
        Ok(())
    }
}
tokio-util-0.7.10/src/codec/decoder.rs000064400000000000000000000176351046102023000156440ustar 00000000000000
use crate::codec::Framed;
use tokio::io::{AsyncRead, AsyncWrite};

use bytes::BytesMut;
use std::io;

/// Decoding of frames via buffers.
///
/// This trait is used when constructing an instance of [`Framed`] or
/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has
/// already been buffered in `src` and decodes the data into a stream of
/// `Self::Item` frames.
///
/// Implementations are able to track state on `self`, which enables
/// implementing stateful streaming parsers. In many cases, though, this type
/// will simply be a unit struct (e.g. `struct HttpDecoder`).
///
/// For some underlying data-sources, namely files and FIFOs,
/// it's possible to temporarily read 0 bytes by reaching EOF.
///
/// In these cases `decode_eof` will be called until it signals
/// fulfillment of all closing frames by returning `Ok(None)`.
/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`]
/// will not invoke `decode` or `decode_eof` again, until data can be read
/// during a retry.
///
/// It is up to the Decoder to keep track of a restart after an EOF,
/// and to decide how to handle such an event by, for example,
/// allowing frames to cross EOF boundaries, re-emitting opening frames, or
/// resetting the entire internal state.
///
/// [`Framed`]: crate::codec::Framed
/// [`FramedRead`]: crate::codec::FramedRead
pub trait Decoder {
    /// The type of decoded frames.
    type Item;

    /// The type of unrecoverable frame decoding errors.
    ///
    /// If an individual message is ill-formed but can be ignored without
    /// interfering with the processing of future messages, it may be more
    /// useful to report the failure as an `Item`.
    ///
    /// `From<io::Error>` is required in the interest of making `Error` suitable
    /// for returning directly from a [`FramedRead`], and to enable the default
    /// implementation of `decode_eof` to yield an `io::Error` when the decoder
    /// fails to consume all available data.
    ///
    /// Note that implementors of this trait can simply indicate `type Error =
    /// io::Error` to use I/O errors as this type.
    ///
    /// [`FramedRead`]: crate::codec::FramedRead
    type Error: From<io::Error>;

    /// Attempts to decode a frame from the provided buffer of bytes.
    ///
    /// This method is called by [`FramedRead`] whenever bytes are ready to be
    /// parsed. The provided buffer of bytes is what's been read so far, and
    /// this instance of `Decoder` can determine whether an entire frame is in
    /// the buffer and is ready to be returned.
    ///
    /// If an entire frame is available, then this instance will remove those
    /// bytes from the buffer provided and return them as a decoded
    /// frame. Note that removing bytes from the provided buffer doesn't always
    /// necessarily copy the bytes, so this should be an efficient operation in
    /// most circumstances.
    ///
    /// If the bytes look valid, but a frame isn't fully available yet, then
    /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that
    /// it needs to read some more bytes before calling this method again.
    ///
    /// Note that the bytes provided may be empty. If a previous call to
    /// `decode` consumed all the bytes in the buffer then `decode` will be
    /// called again until it returns `Ok(None)`, indicating that more bytes need to
    /// be read.
    ///
    /// Finally, if the bytes in the buffer are malformed then an error is
    /// returned indicating why. This informs [`Framed`] that the stream is now
    /// corrupt and should be terminated.
    ///
    /// [`Framed`]: crate::codec::Framed
    /// [`FramedRead`]: crate::codec::FramedRead
    ///
    /// # Buffer management
    ///
    /// Before returning from the function, implementations should ensure that
    /// the buffer has appropriate capacity in anticipation of future calls to
    /// `decode`. Failing to do so leads to inefficiency.
    ///
    /// For example, if frames have a fixed length, or if the length of the
    /// current frame is known from a header, a possible buffer management
    /// strategy is:
    ///
    /// ```no_run
    /// # use std::io;
    /// #
    /// # use bytes::BytesMut;
    /// # use tokio_util::codec::Decoder;
    /// #
    /// # struct MyCodec;
    /// #
    /// impl Decoder for MyCodec {
    ///     // ...
    ///     # type Item = BytesMut;
    ///     # type Error = io::Error;
    ///
    ///     fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
    ///         // ...
    ///
    ///         // Reserve enough to complete decoding of the current frame.
    ///         let current_frame_len: usize = 1000; // Example.
    ///         // And to start decoding the next frame.
    ///         let next_frame_header_len: usize = 10; // Example.
    ///         src.reserve(current_frame_len + next_frame_header_len);
    ///
    ///         return Ok(None);
    ///     }
    /// }
    /// ```
    ///
    /// An optimal buffer management strategy minimizes reallocations and
    /// over-allocations.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;

    /// A default method available to be called when there are no more bytes
    /// available to be read from the underlying I/O.
    ///
    /// This method defaults to calling `decode` and returns an error if
    /// `Ok(None)` is returned while there is unconsumed data in `buf`.
    /// Typically this doesn't need to be implemented unless the framing
    /// protocol differs near the end of the stream, or if you need to construct
    /// frames _across_ EOF boundaries on sources that can be resumed.
    ///
    /// Note that the `buf` argument may be empty. If a previous call to
    /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be
    /// called again until it returns `Ok(None)`, indicating that there are no more
    /// frames to yield. This behavior enables returning finalization frames
    /// that may not be based on inbound data.
    ///
    /// Once `Ok(None)` has been returned, `decode_eof` won't be called again until
    /// an attempt to resume the stream has been made, where the underlying stream
    /// actually returned more data.
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match self.decode(buf)? {
            Some(frame) => Ok(Some(frame)),
            None => {
                if buf.is_empty() {
                    Ok(None)
                } else {
                    Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into())
                }
            }
        }
    }

    /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
    /// `Io` object, using `Decoder` and `Encoder` to read and write the raw data.
    ///
    /// Raw I/O objects work with byte sequences, but higher-level code usually
    /// wants to batch these into meaningful chunks, called "frames". This
    /// method layers framing on top of an I/O object, by using the codec
    /// traits to handle encoding and decoding of message frames. Note that
    /// the incoming and outgoing frame types may be distinct.
    ///
    /// This function returns a *single* object that is both `Stream` and
    /// `Sink`; grouping this into a single object is often useful for layering
    /// things like gzip or TLS, which require both read and write access to the
    /// underlying object.
    ///
    /// If you want to work more directly with the streams and sink, consider
    /// calling `split` on the [`Framed`] returned by this method, which will
    /// break them into separate objects, allowing them to interact more easily.
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`Sink`]: futures_sink::Sink
    /// [`Framed`]: crate::codec::Framed
    fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self>
    where
        Self: Sized,
    {
        Framed::new(io, self)
    }
}
tokio-util-0.7.10/src/codec/encoder.rs000064400000000000000000000016021046102023000156410ustar 00000000000000
use bytes::BytesMut;
use std::io;

/// Trait of helper objects to write out messages as bytes, for use with
/// [`FramedWrite`].
///
/// [`FramedWrite`]: crate::codec::FramedWrite
pub trait Encoder<Item> {
    /// The type of encoding errors.
    ///
    /// [`FramedWrite`] requires the `Encoder`'s error type to implement
    /// `From<io::Error>` in the interest of letting it return `Error`s directly.
    ///
    /// [`FramedWrite`]: crate::codec::FramedWrite
    type Error: From<io::Error>;

    /// Encodes a frame into the buffer provided.
    ///
    /// This method will encode `item` into the byte buffer provided by `dst`.
    /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and
    /// will be written out when possible.
    ///
    /// [`FramedWrite`]: crate::codec::FramedWrite
    fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>;
}
tokio-util-0.7.10/src/codec/framed.rs000064400000000000000000000332641046102023000154710ustar 00000000000000
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};

use futures_core::Stream;
use tokio::io::{AsyncRead, AsyncWrite};

use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using
    /// the `Encoder` and `Decoder` traits to encode and decode frames.
    ///
    /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or
    /// by using the `new` function seen below.
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`Sink`]: futures_sink::Sink
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`Decoder::framed`]: crate::codec::Decoder::framed()
    pub struct Framed<T, U> {
        #[pin]
        inner: FramedImpl<T, U, RWFrames>
    }
}

impl<T, U> Framed<T, U>
where
    T: AsyncRead + AsyncWrite,
{
    /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
    /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
    ///
    /// Raw I/O objects work with byte sequences, but higher-level code usually
    /// wants to batch these into meaningful chunks, called "frames". This
    /// method layers framing on top of an I/O object, by using the codec
    /// traits to handle encoding and decoding of message frames. Note that
    /// the incoming and outgoing frame types may be distinct.
    ///
    /// This function returns a *single* object that is both [`Stream`] and
    /// [`Sink`]; grouping this into a single object is often useful for layering
    /// things like gzip or TLS, which require both read and write access to the
    /// underlying object.
    ///
    /// If you want to work more directly with the streams and sink, consider
    /// calling [`split`] on the `Framed` returned by this method, which will
    /// break them into separate objects, allowing them to interact more easily.
    ///
    /// Note that, for some byte sources, the stream can be resumed after an EOF
    /// by reading from it, even after it has returned `None`. Repeated attempts
    /// to do so, without new data available, continue to return `None` without
    /// creating more (closing) frames.
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`Sink`]: futures_sink::Sink
    /// [`Decoder`]: crate::codec::Decoder
    /// [`Encoder`]: crate::codec::Encoder
    /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
    pub fn new(inner: T, codec: U) -> Framed<T, U> {
        Framed {
            inner: FramedImpl {
                inner,
                codec,
                state: Default::default(),
            },
        }
    }

    /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
    /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data,
    /// with a specific read buffer initial capacity.
    ///
    /// Raw I/O objects work with byte sequences, but higher-level code usually
    /// wants to batch these into meaningful chunks, called "frames". This
    /// method layers framing on top of an I/O object, by using the codec
    /// traits to handle encoding and decoding of message frames. Note that
    /// the incoming and outgoing frame types may be distinct.
    ///
    /// This function returns a *single* object that is both [`Stream`] and
    /// [`Sink`]; grouping this into a single object is often useful for layering
    /// things like gzip or TLS, which require both read and write access to the
    /// underlying object.
    ///
    /// If you want to work more directly with the streams and sink, consider
    /// calling [`split`] on the `Framed` returned by this method, which will
    /// break them into separate objects, allowing them to interact more easily.
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`Sink`]: futures_sink::Sink
    /// [`Decoder`]: crate::codec::Decoder
    /// [`Encoder`]: crate::codec::Encoder
    /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
    pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed<T, U> {
        Framed {
            inner: FramedImpl {
                inner,
                codec,
                state: RWFrames {
                    read: ReadFrame {
                        eof: false,
                        is_readable: false,
                        buffer: BytesMut::with_capacity(capacity),
                        has_errored: false,
                    },
                    write: WriteFrame::default(),
                },
            },
        }
    }
}

impl<T, U> Framed<T, U> {
    /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
    /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
    ///
    /// Raw I/O objects work with byte sequences, but higher-level code usually
    /// wants to batch these into meaningful chunks, called "frames". This
    /// method layers framing on top of an I/O object, by using the codec
    /// traits to handle encoding and decoding of message frames. Note that
    /// the incoming and outgoing frame types may be distinct.
    ///
    /// This function returns a *single* object that is both [`Stream`] and
    /// [`Sink`]; grouping this into a single object is often useful for layering
    /// things like gzip or TLS, which require both read and write access to the
    /// underlying object.
    ///
    /// This object takes a stream, a read buffer, and a write buffer. These fields
    /// can be obtained from an existing `Framed` with the [`into_parts`] method.
    ///
    /// If you want to work more directly with the streams and sink, consider
    /// calling [`split`] on the `Framed` returned by this method, which will
    /// break them into separate objects, allowing them to interact more easily.
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`Sink`]: futures_sink::Sink
    /// [`Decoder`]: crate::codec::Decoder
    /// [`Encoder`]: crate::codec::Encoder
    /// [`into_parts`]: crate::codec::Framed::into_parts()
    /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
    pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
        Framed {
            inner: FramedImpl {
                inner: parts.io,
                codec: parts.codec,
                state: RWFrames {
                    read: parts.read_buf.into(),
                    write: parts.write_buf.into(),
                },
            },
        }
    }

    /// Returns a reference to the underlying I/O stream wrapped by
    /// `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_ref(&self) -> &T {
        &self.inner.inner
    }

    /// Returns a mutable reference to the underlying I/O stream wrapped by
    /// `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner.inner
    }

    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
    /// `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().inner.project().inner
    }

    /// Returns a reference to the underlying codec wrapped by
    /// `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec(&self) -> &U {
        &self.inner.codec
    }

    /// Returns a mutable reference to the underlying codec wrapped by
    /// `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec_mut(&mut self) -> &mut U {
        &mut self.inner.codec
    }

    /// Maps the codec `U` to `C`, preserving the read and write buffers
    /// wrapped by `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn map_codec<C, F>(self, map: F) -> Framed<T, C>
    where
        F: FnOnce(U) -> C,
    {
        // This could be potentially simplified once rust-lang/rust#86555 hits stable
        let parts = self.into_parts();
        Framed::from_parts(FramedParts {
            io: parts.io,
            codec: map(parts.codec),
            read_buf: parts.read_buf,
            write_buf: parts.write_buf,
            _priv: (),
        })
    }

    /// Returns a mutable reference to the underlying codec wrapped by
    /// `Framed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U {
        self.project().inner.project().codec
    }

    /// Returns a reference to the read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        &self.inner.state.read.buffer
    }

    /// Returns a mutable reference to the read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.inner.state.read.buffer
    }

    /// Returns a reference to the write buffer.
    pub fn write_buffer(&self) -> &BytesMut {
        &self.inner.state.write.buffer
    }

    /// Returns a mutable reference to the write buffer.
    pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.inner.state.write.buffer
    }

    /// Returns the backpressure boundary.
    pub fn backpressure_boundary(&self) -> usize {
        self.inner.state.write.backpressure_boundary
    }

    /// Updates the backpressure boundary.
    pub fn set_backpressure_boundary(&mut self, boundary: usize) {
        self.inner.state.write.backpressure_boundary = boundary;
    }

    /// Consumes the `Framed`, returning its underlying I/O stream.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn into_inner(self) -> T {
        self.inner.inner
    }

    /// Consumes the `Framed`, returning its underlying I/O stream, the buffer
    /// with unprocessed data, and the codec.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn into_parts(self) -> FramedParts<T, U> {
        FramedParts {
            io: self.inner.inner,
            codec: self.inner.codec,
            read_buf: self.inner.state.read.buffer,
            write_buf: self.inner.state.write.buffer,
            _priv: (),
        }
    }
}

// This impl just defers to the underlying FramedImpl
impl<T, U> Stream for Framed<T, U>
where
    T: AsyncRead,
    U: Decoder,
{
    type Item = Result<U::Item, U::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }
}

// This impl just defers to the underlying FramedImpl
impl<T, I, U> Sink<I> for Framed<T, U>
where
    T: AsyncWrite,
    U: Encoder<I>,
    U::Error: From<io::Error>,
{
    type Error = U::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}

impl<T, U> fmt::Debug for Framed<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Framed")
            .field("io", self.get_ref())
            .field("codec", self.codec())
            .finish()
    }
}

/// `FramedParts` contains an export of the data of a Framed transport.
/// It can be used to construct a new [`Framed`] with a different codec.
/// It contains all current buffers and the inner transport.
///
/// [`Framed`]: crate::codec::Framed
#[derive(Debug)]
#[allow(clippy::manual_non_exhaustive)]
pub struct FramedParts<T, U> {
    /// The inner transport used to read bytes from and write bytes to.
    pub io: T,

    /// The codec
    pub codec: U,

    /// The buffer with read but unprocessed data.
    pub read_buf: BytesMut,

    /// A buffer with unprocessed data which are not written yet.
    pub write_buf: BytesMut,

    /// This private field allows us to add additional fields in the future in a
    /// backwards compatible way.
    _priv: (),
}

impl<T, U> FramedParts<T, U> {
    /// Create a new, default, `FramedParts`
    pub fn new<I>(io: T, codec: U) -> FramedParts<T, U>
    where
        U: Encoder<I>,
    {
        FramedParts {
            io,
            codec,
            read_buf: BytesMut::new(),
            write_buf: BytesMut::new(),
            _priv: (),
        }
    }
}
tokio-util-0.7.10/src/codec/framed_impl.rs000064400000000000000000000310041046102023000165000ustar 00000000000000
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;

use futures_core::Stream;
use tokio::io::{AsyncRead, AsyncWrite};

use bytes::BytesMut;
use futures_core::ready;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::borrow::{Borrow, BorrowMut};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tracing::trace;

pin_project!
{
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        #[pin]
        pub(crate) inner: T,
        pub(crate) state: State,
        pub(crate) codec: U,
    }
}

const INITIAL_CAPACITY: usize = 8 * 1024;

#[derive(Debug)]
pub(crate) struct ReadFrame {
    pub(crate) eof: bool,
    pub(crate) is_readable: bool,
    pub(crate) buffer: BytesMut,
    pub(crate) has_errored: bool,
}

pub(crate) struct WriteFrame {
    pub(crate) buffer: BytesMut,
    pub(crate) backpressure_boundary: usize,
}

#[derive(Default)]
pub(crate) struct RWFrames {
    pub(crate) read: ReadFrame,
    pub(crate) write: WriteFrame,
}

impl Default for ReadFrame {
    fn default() -> Self {
        Self {
            eof: false,
            is_readable: false,
            buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
            has_errored: false,
        }
    }
}

impl Default for WriteFrame {
    fn default() -> Self {
        Self {
            buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
            backpressure_boundary: INITIAL_CAPACITY,
        }
    }
}

impl From<BytesMut> for ReadFrame {
    fn from(mut buffer: BytesMut) -> Self {
        let size = buffer.capacity();
        if size < INITIAL_CAPACITY {
            buffer.reserve(INITIAL_CAPACITY - size);
        }

        Self {
            buffer,
            is_readable: size > 0,
            eof: false,
            has_errored: false,
        }
    }
}

impl From<BytesMut> for WriteFrame {
    fn from(mut buffer: BytesMut) -> Self {
        let size = buffer.capacity();
        if size < INITIAL_CAPACITY {
            buffer.reserve(INITIAL_CAPACITY - size);
        }

        Self {
            buffer,
            backpressure_boundary: INITIAL_CAPACITY,
        }
    }
}

impl Borrow<ReadFrame> for RWFrames {
    fn borrow(&self) -> &ReadFrame {
        &self.read
    }
}

impl BorrowMut<ReadFrame> for RWFrames {
    fn borrow_mut(&mut self) -> &mut ReadFrame {
        &mut self.read
    }
}

impl Borrow<WriteFrame> for RWFrames {
    fn borrow(&self) -> &WriteFrame {
        &self.write
    }
}

impl BorrowMut<WriteFrame> for RWFrames {
    fn borrow_mut(&mut self) -> &mut WriteFrame {
        &mut self.write
    }
}

impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop
implements a state machine with each state corresponding // to a combination of the `is_readable` and `eof` flags. States persist across // loop entries and most state transitions occur with a return. // // The initial state is `reading`. // // | state | eof | is_readable | has_errored | // |---------|-------|-------------|-------------| // | reading | false | false | false | // | framing | false | true | false | // | pausing | true | true | false | // | paused | true | false | false | // | errored | | | true | // `decode_eof` returns Err // ┌────────────────────────────────────────────────────────┐ // `decode_eof` returns │ │ // `Ok(Some)` │ │ // ┌─────┐ │ `decode_eof` returns After returning │ // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ // Pending read │ │ │ │ │ │ // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ // │ │ │ ┌──────┐ │ Pending │ │ // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ // └──┬─▲────┘ └─────┬──┬┘ │ │ // │ │ │ │ `decode` returns Err │ │ // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ // │ read returns Err │ // └────────────────────────────────────────────────────────────────────────────────────────────┘ loop { // Return `None` if we have encountered an error from the underlying decoder // See: https://github.com/tokio-rs/tokio/issues/3976 if state.has_errored { // preparing has_errored -> paused trace!("Returning None and setting paused"); state.is_readable = false; state.has_errored = false; return Poll::Ready(None); } // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", // i.e. it _might_ contain data consumable as a frame or closing frame. // Both signal that there is no such data by returning `None`. 
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> paused
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}

impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
where
    T: AsyncWrite,
    U: Encoder<I>,
    U::Error: From<io::Error>,
    W: BorrowMut<WriteFrame>,
{
    type Error = U::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
            self.as_mut().poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let pinned = self.project();
        pinned
            .codec
            .encode(item, &mut pinned.state.borrow_mut().buffer)?;
        Ok(())
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        use crate::util::poll_write_buf;
        trace!("flushing framed transport");
        let mut pinned = self.project();

        while !pinned.state.borrow_mut().buffer.is_empty() {
            let WriteFrame { buffer, ..
            } = pinned.state.borrow_mut();
            trace!(remaining = buffer.len(), "writing;");

            let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;

            if n == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                )
                .into()));
            }
        }

        // Try flushing the underlying IO
        ready!(pinned.inner.poll_flush(cx))?;

        trace!("framed transport flushed");
        Poll::Ready(Ok(()))
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.as_mut().poll_flush(cx))?;
        ready!(self.project().inner.poll_shutdown(cx))?;

        Poll::Ready(Ok(()))
    }
}
tokio-util-0.7.10/src/codec/framed_read.rs000064400000000000000000000133771046102023000164670ustar 00000000000000
use crate::codec::framed_impl::{FramedImpl, ReadFrame};
use crate::codec::Decoder;

use futures_core::Stream;
use tokio::io::AsyncRead;

use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A [`Stream`] of messages decoded from an [`AsyncRead`].
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`AsyncRead`]: tokio::io::AsyncRead
    pub struct FramedRead<T, D> {
        #[pin]
        inner: FramedImpl<T, D, ReadFrame>,
    }
}

// ===== impl FramedRead =====

impl<T, D> FramedRead<T, D>
where
    T: AsyncRead,
    D: Decoder,
{
    /// Creates a new `FramedRead` with the given `decoder`.
    pub fn new(inner: T, decoder: D) -> FramedRead<T, D> {
        FramedRead {
            inner: FramedImpl {
                inner,
                codec: decoder,
                state: Default::default(),
            },
        }
    }

    /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity`
    /// initial size.
    pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> {
        FramedRead {
            inner: FramedImpl {
                inner,
                codec: decoder,
                state: ReadFrame {
                    eof: false,
                    is_readable: false,
                    buffer: BytesMut::with_capacity(capacity),
                    has_errored: false,
                },
            },
        }
    }
}

impl<T, D> FramedRead<T, D> {
    /// Returns a reference to the underlying I/O stream wrapped by
    /// `FramedRead`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_ref(&self) -> &T {
        &self.inner.inner
    }

    /// Returns a mutable reference to the underlying I/O stream wrapped by
    /// `FramedRead`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner.inner
    }

    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
    /// `FramedRead`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().inner.project().inner
    }

    /// Consumes the `FramedRead`, returning its underlying I/O stream.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn into_inner(self) -> T {
        self.inner.inner
    }

    /// Returns a reference to the underlying decoder.
    pub fn decoder(&self) -> &D {
        &self.inner.codec
    }

    /// Returns a mutable reference to the underlying decoder.
    pub fn decoder_mut(&mut self) -> &mut D {
        &mut self.inner.codec
    }

    /// Maps the decoder `D` to `C`, preserving the read buffer
    /// wrapped by `FramedRead`.
    pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
    where
        F: FnOnce(D) -> C,
    {
        // This could be potentially simplified once rust-lang/rust#86555 hits stable
        let FramedImpl {
            inner,
            state,
            codec,
        } = self.inner;
        FramedRead {
            inner: FramedImpl {
                inner,
                state,
                codec: map(codec),
            },
        }
    }

    /// Returns a mutable reference to the underlying decoder.
    pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
        self.project().inner.project().codec
    }

    /// Returns a reference to the read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        &self.inner.state.buffer
    }

    /// Returns a mutable reference to the read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.inner.state.buffer
    }
}

// This impl just defers to the underlying FramedImpl
impl<T, D> Stream for FramedRead<T, D>
where
    T: AsyncRead,
    D: Decoder,
{
    type Item = Result<D::Item, D::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }
}

// This impl just defers to the underlying T: Sink
impl<T, I, D> Sink<I> for FramedRead<T, D>
where
    T: Sink<I>,
{
    type Error = T::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.project().inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        self.project().inner.project().inner.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.project().inner.poll_close(cx)
    }
}

impl<T, D> fmt::Debug for FramedRead<T, D>
where
    T: fmt::Debug,
    D: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FramedRead")
            .field("inner", &self.get_ref())
            .field("decoder", &self.decoder())
            .field("eof", &self.inner.state.eof)
            .field("is_readable", &self.inner.state.is_readable)
            .field("buffer", &self.read_buffer())
            .finish()
    }
}
tokio-util-0.7.10/src/codec/framed_write.rs000064400000000000000000000125051046102023000166760ustar 00000000000000
use crate::codec::encoder::Encoder;
use crate::codec::framed_impl::{FramedImpl, WriteFrame};

use futures_core::Stream;
use tokio::io::AsyncWrite;

use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A [`Sink`] of frames encoded to an `AsyncWrite`.
    ///
    /// [`Sink`]: futures_sink::Sink
    pub struct FramedWrite<T, E> {
        #[pin]
        inner: FramedImpl<T, E, WriteFrame>,
    }
}

impl<T, E> FramedWrite<T, E>
where
    T: AsyncWrite,
{
    /// Creates a new `FramedWrite` with the given `encoder`.
    pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
        FramedWrite {
            inner: FramedImpl {
                inner,
                codec: encoder,
                state: WriteFrame::default(),
            },
        }
    }
}

impl<T, E> FramedWrite<T, E> {
    /// Returns a reference to the underlying I/O stream wrapped by
    /// `FramedWrite`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_ref(&self) -> &T {
        &self.inner.inner
    }

    /// Returns a mutable reference to the underlying I/O stream wrapped by
    /// `FramedWrite`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner.inner
    }

    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
    /// `FramedWrite`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().inner.project().inner
    }

    /// Consumes the `FramedWrite`, returning its underlying I/O stream.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn into_inner(self) -> T {
        self.inner.inner
    }

    /// Returns a reference to the underlying encoder.
    pub fn encoder(&self) -> &E {
        &self.inner.codec
    }

    /// Returns a mutable reference to the underlying encoder.
    pub fn encoder_mut(&mut self) -> &mut E {
        &mut self.inner.codec
    }

    /// Maps the encoder `E` to `C`, preserving the write buffer
    /// wrapped by `FramedWrite`.
    pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
    where
        F: FnOnce(E) -> C,
    {
        // This could be potentially simplified once rust-lang/rust#86555 hits stable
        let FramedImpl {
            inner,
            state,
            codec,
        } = self.inner;
        FramedWrite {
            inner: FramedImpl {
                inner,
                state,
                codec: map(codec),
            },
        }
    }

    /// Returns a mutable reference to the underlying encoder.
    pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
        self.project().inner.project().codec
    }

    /// Returns a reference to the write buffer.
    pub fn write_buffer(&self) -> &BytesMut {
        &self.inner.state.buffer
    }

    /// Returns a mutable reference to the write buffer.
    pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.inner.state.buffer
    }

    /// Returns backpressure boundary
    pub fn backpressure_boundary(&self) -> usize {
        self.inner.state.backpressure_boundary
    }

    /// Updates backpressure boundary
    pub fn set_backpressure_boundary(&mut self, boundary: usize) {
        self.inner.state.backpressure_boundary = boundary;
    }
}

// This impl just defers to the underlying FramedImpl
impl<T, I, E> Sink<I> for FramedWrite<T, E>
where
    T: AsyncWrite,
    E: Encoder<I>,
    E::Error: From<io::Error>,
{
    type Error = E::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}

// This impl just defers to the underlying T: Stream
impl<T, D> Stream for FramedWrite<T, D>
where
    T: Stream,
{
    type Item = T::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.project().inner.poll_next(cx)
    }
}

impl<T, U> fmt::Debug for FramedWrite<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FramedWrite")
            .field("inner", &self.get_ref())
            .field("encoder", &self.encoder())
            .field("buffer", &self.inner.state.buffer)
            .finish()
    }
}
tokio-util-0.7.10/src/codec/length_delimited.rs
//! Frame a stream of bytes based on a length prefix
//!
//! Many protocols delimit their frames by prefacing frame data with a
//! frame head that specifies the length of the frame. The
//! `length_delimited` module provides utilities for handling the length
//! based framing. This allows the consumer to work with entire frames
//! without having to worry about buffering or other framing logic.
//!
//! # Getting started
//!
//! If implementing a protocol from scratch, using length delimited framing
//! is an easy way to get started. [`LengthDelimitedCodec::new()`] will
//! return a length delimited codec using default configuration values.
//! This can then be used to construct a framer to adapt a full-duplex
//! byte stream into a stream of frames.
//!
//! ```
//! use tokio::io::{AsyncRead, AsyncWrite};
//! use tokio_util::codec::{Framed, LengthDelimitedCodec};
//!
//! fn bind_transport<T: AsyncRead + AsyncWrite>(io: T)
//!     -> Framed<T, LengthDelimitedCodec>
//! {
//!     Framed::new(io, LengthDelimitedCodec::new())
//! }
//! # pub fn main() {}
//! ```
//!
//! The returned transport implements `Sink<Bytes>` and a `Stream` of
//! `BytesMut`. It encodes the frame with a big-endian `u32` header denoting
//! the frame payload length:
//!
//! ```text
//! +----------+--------------------------------+
//! | len: u32 |          frame payload         |
//! +----------+--------------------------------+
//! ```
//!
//! Specifically, given the following:
//!
//! ```
//! use tokio::io::{AsyncRead, AsyncWrite};
//! use tokio_util::codec::{Framed, LengthDelimitedCodec};
//!
//! use futures::SinkExt;
//! use bytes::Bytes;
//!
//! async fn write_frame<T>(io: T) -> Result<(), Box<dyn std::error::Error>>
//! where
//!     T: AsyncRead + AsyncWrite + Unpin,
//! {
//!     let mut transport = Framed::new(io, LengthDelimitedCodec::new());
//!     let frame = Bytes::from("hello world");
//!
//!     transport.send(frame).await?;
//!     Ok(())
//! }
//! ```
//!
//! The encoded frame will look like this:
//!
//! ```text
//! +---- len: u32 ----+---- data ----+
//! | \x00\x00\x00\x0b |  hello world |
//! +------------------+--------------+
//! ```
//!
//! # Decoding
//!
//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`],
//! such that each yielded [`BytesMut`] value contains the contents of an
//! entire frame. There are many configuration parameters enabling
//! [`FramedRead`] to handle a wide range of protocols. Here are some
//! examples that will cover the various options at a high level.
//!
//! ## Example 1
//!
//! The following will parse a `u16` length field at offset 0, including the
//! frame head in the yielded `BytesMut`.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(0) // default value
//!     .length_field_type::<u16>()
//!     .length_adjustment(2)   // add the head size to the length
//!     .num_skip(0)            // do not strip the frame head
//!     .new_read(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!          INPUT                           DECODED
//! +-- len ---+--- Payload ---+     +-- len ---+--- Payload ---+
//! | \x00\x0B |  Hello world  | --> | \x00\x0B |  Hello world  |
//! +----------+---------------+     +----------+---------------+
//! ```
//!
//! The value of the length field is 11 (`\x0B`) which represents the length
//! of the payload, `Hello world`. By default, [`FramedRead`] assumes that
//! the length field represents the number of bytes that **follow** the
//! length field. Thus, the entire frame has a length of 13: 2 bytes for the
//! frame head + 11 bytes for the payload. Because `num_skip(0)` leaves the
//! head in the buffer, `length_adjustment(2)` adds the head size so that all
//! 13 bytes are yielded.
//!
//! ## Example 2
//!
//! The following will parse a `u16` length field at offset 0, omitting the
//! frame head in the yielded `BytesMut`.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(0) // default value
//!     .length_field_type::<u16>()
//!     .length_adjustment(0)   // default value
//!     // `num_skip` is not needed, the default is to skip
//!     .new_read(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!          INPUT                        DECODED
//! +-- len ---+--- Payload ---+     +--- Payload ---+
//! | \x00\x0B |  Hello world  | --> |  Hello world  |
//! +----------+---------------+     +---------------+
//! ```
//!
//! This is similar to the first example, the only difference is that the
//! frame head is **not** included in the yielded `BytesMut` value.
//!
//! ## Example 3
//!
//! The following will parse a `u16` length field at offset 0, including the
//! frame head in the yielded `BytesMut`. In this case, the length field
//! **includes** the frame head length.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(0) // default value
//!     .length_field_type::<u16>()
//!     .length_adjustment(0)   // default value, the length already counts the head
//!     .num_skip(0)            // do not strip the frame head
//!     .new_read(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!          INPUT                           DECODED
//! +-- len ---+--- Payload ---+     +-- len ---+--- Payload ---+
//! | \x00\x0D |  Hello world  | --> | \x00\x0D |  Hello world  |
//! +----------+---------------+     +----------+---------------+
//! ```
//!
//! In most cases, the length field represents the length of the payload
//! only, as shown in the previous examples. However, in some protocols the
//! length field represents the length of the whole frame, including the
//! head. Here the head is kept in the yielded `BytesMut` (`num_skip(0)`), so
//! the value of the length field already equals the number of bytes to
//! yield and no adjustment is needed. If the head were stripped instead (the
//! default `num_skip`), a negative `length_adjustment` of `-2` would convert
//! the value provided in the frame head to the payload length.
//!
//! ## Example 4
//!
//! The following will parse a 3 byte length field at offset 0 in a 5 byte
//! frame head, including the frame head in the yielded `BytesMut`.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(0) // default value
//!     .length_field_length(3)
//!     .length_adjustment(3 + 2) // length field + remaining head
//!     .num_skip(0)
//!     .new_read(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!                  INPUT
//! +---- len -----+- head -+--- Payload ---+
//! | \x00\x00\x0B | \xCAFE |  Hello world  |
//! +--------------+--------+---------------+
//!
//!                  DECODED
//! +---- len -----+- head -+--- Payload ---+
//! | \x00\x00\x0B | \xCAFE |  Hello world  |
//! +--------------+--------+---------------+
//! ```
//!
//! A more advanced example that shows a case where there is extra frame
//! head data between the length field and the payload. In such cases, it is
//! usually desirable to include the frame head as part of the yielded
//! `BytesMut`. This lets consumers of the length delimited framer
//! process the frame head as needed.
//!
//! The positive `length_adjustment` value lets `FramedRead` factor the
//! head into the frame length calculation. Because `num_skip(0)` keeps the
//! whole head, the adjustment counts both the 3 byte length field and the
//! 2 remaining head bytes.
//!
//! ## Example 5
//!
//! The following will parse a `u16` length field at offset 1 of a 4 byte
//! frame head. The first byte and the length field will be omitted from the
//! yielded `BytesMut`, but the trailing byte of the frame head (`hdr2`) will
//! be included.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(1) // length of hdr1
//!     .length_field_type::<u16>()
//!     .length_adjustment(1)   // length of hdr2
//!     .num_skip(3)            // length of hdr1 + LEN
//!     .new_read(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!                  INPUT
//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+
//! |  \xCA  | \x00\x0B |  \xFE  |  Hello world  |
//! +--------+----------+--------+---------------+
//!
//!          DECODED
//! +- hdr2 -+--- Payload ---+
//! |  \xFE  |  Hello world  |
//! +--------+---------------+
//! ```
//!
//! The length field is situated in the middle of the frame head. In this
//! case, the first byte in the frame head could be a version or some other
//! identifier that is not needed for processing. On the other hand, the
//! second half of the head is needed.
//!
//! `length_field_offset` indicates how many bytes to skip before starting
//! to read the length field. `length_adjustment` is added to the value of
//! the length field; here it accounts for the part of the head that follows
//! the length field (`hdr2`) and is yielded together with the payload.
//!
//! ## Example 6
//!
//! The following will parse a `u16` length field at offset 1 of a 4 byte
//! frame head. The first byte and the length field will be omitted from the
//! yielded `BytesMut`, but the trailing byte of the frame head (`hdr2`) will
//! be included. In this case, the length field **includes** the frame head
//! length.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(1) // length of hdr1
//!     .length_field_type::<u16>()
//!     .length_adjustment(-3)  // length of hdr1 + LEN, negative
//!     .num_skip(3)
//!     .new_read(io);
//! # }
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!                  INPUT
//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+
//! |  \xCA  | \x00\x0F |  \xFE  |  Hello world  |
//! +--------+----------+--------+---------------+
//!
//!          DECODED
//! +- hdr2 -+--- Payload ---+
//! |  \xFE  |  Hello world  |
//! +--------+---------------+
//! ```
//!
//! Similar to the example above, the difference is that the length field
//! represents the length of the entire frame instead of just the payload.
//! The length of `hdr1` and `len` must be counted in `length_adjustment`.
//! Note that the length of `hdr2` does **not** need to be explicitly set
//! anywhere because it already is factored into the total frame length that
//! is read from the byte stream.
//!
//! ## Example 7
//!
//! The following will parse a 3 byte length field at offset 0 in a 4 byte
//! frame head, excluding the 4th byte from the yielded `BytesMut`.
//!
//! ```
//! # use tokio::io::AsyncRead;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn bind_read<T: AsyncRead>(io: T) {
//! LengthDelimitedCodec::builder()
//!     .length_field_offset(0) // default value
//!     .length_field_length(3)
//!     .length_adjustment(0)   // default value
//!     .num_skip(4)            // skip the first 4 bytes
//!     .new_read(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! The following frame will be decoded as such:
//!
//! ```text
//!                 INPUT                       DECODED
//! +------- len ------+--- Payload ---+    +--- Payload ---+
//! | \x00\x00\x0B\xFF |  Hello world  | => |  Hello world  |
//! +------------------+---------------+    +---------------+
//! ```
//!
//! A simple example where there are unused bytes between the length field
//! and the payload.
//!
//! # Encoding
//!
//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of `Bytes`,
//! such that each submitted `Bytes` value is prefaced by a length field.
//! There are fewer configuration options than [`FramedRead`]. Given
//! protocols that have more complex frame heads, an encoder should probably
//! be written by hand using [`Encoder`].
//!
//! Here is a simple example, given a `FramedWrite` with the following
//! configuration:
//!
//! ```
//! # use tokio::io::AsyncWrite;
//! # use tokio_util::codec::LengthDelimitedCodec;
//! # fn write_frame<T: AsyncWrite>(io: T) {
//! # let _ =
//! LengthDelimitedCodec::builder()
//!     .length_field_type::<u16>()
//!     .new_write(io);
//! # }
//! # pub fn main() {}
//! ```
//!
//! A payload of `hello world` will be encoded as:
//!
//! ```text
//! +- len: u16 -+---- data ----+
//! |  \x00\x0b  |  hello world |
//! +------------+--------------+
//! ```
//!
//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new
//! [`FramedRead`]: struct@FramedRead
//! [`FramedWrite`]: struct@FramedWrite
//! [`AsyncRead`]: trait@tokio::io::AsyncRead
//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite
//! [`Encoder`]: trait@Encoder
//! [`BytesMut`]: bytes::BytesMut

use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};

use tokio::io::{AsyncRead, AsyncWrite};

use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::error::Error as StdError;
use std::io::{self, Cursor};
use std::{cmp, fmt, mem};

/// Configure length delimited `LengthDelimitedCodec`s.
///
/// `Builder` enables constructing configured length delimited codecs. Note
/// that not all configuration settings apply to both encoding and decoding. See
/// the documentation for specific methods for more detail.
#[derive(Debug, Clone, Copy)]
pub struct Builder {
    // Maximum frame length
    max_frame_len: usize,

    // Number of bytes representing the field length
    length_field_len: usize,

    // Number of bytes in the header before the length field
    length_field_offset: usize,

    // Adjust the length specified in the header field by this amount
    length_adjustment: isize,

    // Total number of bytes to skip before reading the payload, if not set,
    // `length_field_len + length_field_offset`
    num_skip: Option<usize>,

    // Length field byte order (little or big endian)
    length_field_is_big_endian: bool,
}

/// An error when the number of bytes read is more than max frame length.
pub struct LengthDelimitedCodecError {
    _priv: (),
}

/// A codec for frames delimited by a frame head specifying their lengths.
///
/// This allows the consumer to work with entire frames without having to worry
/// about buffering or other framing logic.
///
/// See [module level] documentation for more detail.
///
/// [module level]: index.html
#[derive(Debug, Clone)]
pub struct LengthDelimitedCodec {
    // Configuration values
    builder: Builder,

    // Read state
    state: DecodeState,
}

#[derive(Debug, Clone, Copy)]
enum DecodeState {
    Head,
    Data(usize),
}

// ===== impl LengthDelimitedCodec ======

impl LengthDelimitedCodec {
    /// Creates a new `LengthDelimitedCodec` with the default configuration values.
    pub fn new() -> Self {
        Self {
            builder: Builder::new(),
            state: DecodeState::Head,
        }
    }

    /// Creates a new length delimited codec builder with default configuration
    /// values.
    pub fn builder() -> Builder {
        Builder::new()
    }

    /// Returns the current max frame setting
    ///
    /// This is the largest size this codec will accept from the wire. Larger
    /// frames will be rejected.
    pub fn max_frame_length(&self) -> usize {
        self.builder.max_frame_len
    }

    /// Updates the max frame setting.
    ///
    /// The change takes effect the next time a frame is decoded. In other
    /// words, if a frame is currently in process of being decoded with a frame
    /// size greater than `val` but less than the max frame length in effect
    /// before calling this function, then the frame will be allowed.
    pub fn set_max_frame_length(&mut self, val: usize) {
        self.builder.max_frame_length(val);
    }

    fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
        let head_len = self.builder.num_head_bytes();
        let field_len = self.builder.length_field_len;

        if src.len() < head_len {
            // Not enough data
            return Ok(None);
        }

        let n = {
            let mut src = Cursor::new(&mut *src);

            // Skip the required bytes
            src.advance(self.builder.length_field_offset);

            // match endianness
            let n = if self.builder.length_field_is_big_endian {
                src.get_uint(field_len)
            } else {
                src.get_uint_le(field_len)
            };

            if n > self.builder.max_frame_len as u64 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    LengthDelimitedCodecError { _priv: () },
                ));
            }

            // The check above ensures there is no overflow
            let n = n as usize;

            // Adjust `n` with bounds checking
            let n = if self.builder.length_adjustment < 0 {
                n.checked_sub(-self.builder.length_adjustment as usize)
            } else {
                n.checked_add(self.builder.length_adjustment as usize)
            };

            // Error handling
            match n {
                Some(n) => n,
                None => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "provided length would overflow after adjustment",
                    ));
                }
            }
        };

        src.advance(self.builder.get_num_skip());

        // Ensure that the buffer has enough space to read the incoming
        // payload
        src.reserve(n.saturating_sub(src.len()));

        Ok(Some(n))
    }

    fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option<BytesMut> {
        // At this point, the buffer has already had the required capacity
        // reserved. All there is to do is read.
        if src.len() < n {
            return None;
        }

        Some(src.split_to(n))
    }
}

impl Decoder for LengthDelimitedCodec {
    type Item = BytesMut;
    type Error = io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
        let n = match self.state {
            DecodeState::Head => match self.decode_head(src)? {
                Some(n) => {
                    self.state = DecodeState::Data(n);
                    n
                }
                None => return Ok(None),
            },
            DecodeState::Data(n) => n,
        };

        match self.decode_data(n, src) {
            Some(data) => {
                // Update the decode state
                self.state = DecodeState::Head;

                // Make sure the buffer has enough space to read the next head
                src.reserve(self.builder.num_head_bytes().saturating_sub(src.len()));

                Ok(Some(data))
            }
            None => Ok(None),
        }
    }
}

impl Encoder<Bytes> for LengthDelimitedCodec {
    type Error = io::Error;

    fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
        let n = data.len();

        if n > self.builder.max_frame_len {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                LengthDelimitedCodecError { _priv: () },
            ));
        }

        // Adjust `n` with bounds checking
        let n = if self.builder.length_adjustment < 0 {
            n.checked_add(-self.builder.length_adjustment as usize)
        } else {
            n.checked_sub(self.builder.length_adjustment as usize)
        };

        let n = n.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "provided length would overflow after adjustment",
            )
        })?;

        // Reserve capacity in the destination buffer to fit the frame and
        // length field (plus adjustment).
        dst.reserve(self.builder.length_field_len + n);

        if self.builder.length_field_is_big_endian {
            dst.put_uint(n as u64, self.builder.length_field_len);
        } else {
            dst.put_uint_le(n as u64, self.builder.length_field_len);
        }

        // Write the frame to the buffer
        dst.extend_from_slice(&data[..]);

        Ok(())
    }
}

impl Default for LengthDelimitedCodec {
    fn default() -> Self {
        Self::new()
    }
}

// ===== impl Builder =====

mod builder {
    /// Types that can be used with `Builder::length_field_type`.
    pub trait LengthFieldType {}

    impl LengthFieldType for u8 {}
    impl LengthFieldType for u16 {}
    impl LengthFieldType for u32 {}
    impl LengthFieldType for u64 {}

    #[cfg(any(
        target_pointer_width = "8",
        target_pointer_width = "16",
        target_pointer_width = "32",
        target_pointer_width = "64",
    ))]
    impl LengthFieldType for usize {}
}

impl Builder {
    /// Creates a new length delimited codec builder with default configuration
    /// values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(0)
    ///     .length_field_type::<u16>()
    ///     .length_adjustment(0)
    ///     .num_skip(0)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new() -> Builder {
        Builder {
            // Default max frame length of 8MB
            max_frame_len: 8 * 1_024 * 1_024,

            // Default byte length of 4
            length_field_len: 4,

            // Default to the header field being at the start of the header.
            length_field_offset: 0,

            length_adjustment: 0,

            // Total number of bytes to skip before reading the payload, if not set,
            // `length_field_len + length_field_offset`
            num_skip: None,

            // Default to reading the length field in network (big) endian.
            length_field_is_big_endian: true,
        }
    }

    /// Read the length field as a big endian integer
    ///
    /// This is the default setting.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .big_endian()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn big_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = true;
        self
    }

    /// Read the length field as a little endian integer
    ///
    /// The default setting is big endian.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .little_endian()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn little_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = false;
        self
    }

    /// Read the length field as a native endian integer
    ///
    /// The default setting is big endian.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .native_endian()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn native_endian(&mut self) -> &mut Self {
        if cfg!(target_endian = "big") {
            self.big_endian()
        } else {
            self.little_endian()
        }
    }

    /// Sets the max frame length in bytes
    ///
    /// This configuration option applies to both encoding and decoding. The
    /// default value is 8MB.
    ///
    /// When decoding, the length field read from the byte stream is checked
    /// against this setting **before** any adjustments are applied. When
    /// encoding, the length of the submitted payload is checked against this
    /// setting.
    ///
    /// When frames exceed the max length, an `io::Error` with the custom value
    /// of the `LengthDelimitedCodecError` type will be returned.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .max_frame_length(8 * 1024 * 1024)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn max_frame_length(&mut self, val: usize) -> &mut Self {
        self.max_frame_len = val;
        self
    }

    /// Sets the unsigned integer type used to represent the length field.
    ///
    /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on
    /// 64-bit targets).
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_type::<u32>()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    ///
    /// Unlike [`Builder::length_field_length`], this does not fail at runtime
    /// and instead produces a compile error:
    ///
    /// ```compile_fail
    /// # use tokio::io::AsyncRead;
    /// # use tokio_util::codec::LengthDelimitedCodec;
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_type::<u128>()
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_field_type<T: builder::LengthFieldType>(&mut self) -> &mut Self {
        self.length_field_length(mem::size_of::<T>())
    }

    /// Sets the number of bytes used to represent the length field
    ///
    /// The default value is `4`. The max value is `8`.
    ///
    /// This configuration option applies to both encoding and decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_length(4)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_field_length(&mut self, val: usize) -> &mut Self {
        assert!(val > 0 && val <= 8, "invalid length field length");
        self.length_field_len = val;
        self
    }

    /// Sets the number of bytes in the header before the length field
    ///
    /// This configuration option only applies to decoding.
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(1)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_field_offset(&mut self, val: usize) -> &mut Self {
        self.length_field_offset = val;
        self
    }

    /// Delta between the payload length specified in the header and the real
    /// payload length
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_adjustment(-2)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn length_adjustment(&mut self, val: isize) -> &mut Self {
        self.length_adjustment = val;
        self
    }

    /// Sets the number of bytes to skip before reading the payload
    ///
    /// Default value is `length_field_len + length_field_offset`
    ///
    /// This configuration option only applies to decoding
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .num_skip(4)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn num_skip(&mut self, val: usize) -> &mut Self {
        self.num_skip = Some(val);
        self
    }

    /// Create a configured length delimited `LengthDelimitedCodec`
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::codec::LengthDelimitedCodec;
    /// # pub fn main() {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(0)
    ///     .length_field_type::<u16>()
    ///     .length_adjustment(0)
    ///     .num_skip(0)
    ///     .new_codec();
    /// # }
    /// ```
    pub fn new_codec(&self) -> LengthDelimitedCodec {
        LengthDelimitedCodec {
            builder: *self,
            state: DecodeState::Head,
        }
    }

    /// Create a configured length delimited `FramedRead`
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncRead;
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// # fn bind_read<T: AsyncRead>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_offset(0)
    ///     .length_field_type::<u16>()
    ///     .length_adjustment(0)
    ///     .num_skip(0)
    ///     .new_read(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec>
    where
        T: AsyncRead,
    {
        FramedRead::new(upstream, self.new_codec())
    }

    /// Create a configured length delimited `FramedWrite`
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::AsyncWrite;
    /// # use tokio_util::codec::LengthDelimitedCodec;
    /// # fn write_frame<T: AsyncWrite>(io: T) {
    /// LengthDelimitedCodec::builder()
    ///     .length_field_type::<u16>()
    ///     .new_write(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec>
    where
        T: AsyncWrite,
    {
        FramedWrite::new(inner, self.new_codec())
    }

    /// Create a configured length delimited `Framed`
    ///
    /// # Examples
    ///
    /// ```
    /// # use tokio::io::{AsyncRead, AsyncWrite};
    /// # use tokio_util::codec::LengthDelimitedCodec;
    /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) {
    /// # let _ =
    /// LengthDelimitedCodec::builder()
    ///     .length_field_type::<u16>()
    ///     .new_framed(io);
    /// # }
    /// # pub fn main() {}
    /// ```
    pub fn new_framed<T>(&self, inner: T) -> Framed<T, LengthDelimitedCodec>
    where
        T: AsyncRead + AsyncWrite,
    {
        Framed::new(inner, self.new_codec())
    }

    fn num_head_bytes(&self) -> usize {
        let num = self.length_field_offset + self.length_field_len;
        cmp::max(num, self.num_skip.unwrap_or(0))
    }

    fn get_num_skip(&self) -> usize {
        self.num_skip
            .unwrap_or(self.length_field_offset + self.length_field_len)
    }
}

impl Default for Builder {
    fn default() -> Self {
        Self::new()
    }
}

// ===== impl LengthDelimitedCodecError =====

impl fmt::Debug for LengthDelimitedCodecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LengthDelimitedCodecError").finish()
    }
}

impl fmt::Display for LengthDelimitedCodecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("frame size too big")
    }
}

impl StdError for LengthDelimitedCodecError {}
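The decode-side settings in `length_delimited.rs` interact in a way that is easy to get wrong: the adjusted length counts the bytes yielded *after* the `num_skip` prefix has been dropped. The following is a minimal, std-only sketch of that arithmetic, written for illustration and not taken from the crate. `HeadConfig` and `split_frame` are hypothetical names, the input is a plain byte slice rather than a `BytesMut`, the max frame length check is left out, and `length_field_len` is assumed to be between 1 and 8.

```rust
/// Hypothetical stand-in for the decode-side builder settings.
/// This type is NOT part of tokio-util.
#[derive(Clone, Copy, Debug)]
pub struct HeadConfig {
    pub length_field_offset: usize,
    pub length_field_len: usize, // assumed to be in 1..=8
    pub length_adjustment: isize,
    pub num_skip: Option<usize>,
    pub big_endian: bool,
}

/// Tries to split one frame off the front of `src`.
///
/// Returns `Some((frame, rest))` once a whole frame is buffered, and `None`
/// when more bytes are needed or the adjusted length does not fit in `usize`.
pub fn split_frame(cfg: HeadConfig, src: &[u8]) -> Option<(&[u8], &[u8])> {
    let field_end = cfg.length_field_offset.checked_add(cfg.length_field_len)?;
    // By default everything up to the end of the length field is skipped.
    let skip = cfg.num_skip.unwrap_or(field_end);
    // Wait until both the length field and the skipped prefix are buffered.
    if src.len() < field_end.max(skip) {
        return None;
    }

    // Read the length field as an unsigned integer in the configured byte order.
    let field = &src[cfg.length_field_offset..field_end];
    let mut raw: u64 = 0;
    if cfg.big_endian {
        for &b in field {
            raw = (raw << 8) | u64::from(b);
        }
    } else {
        for &b in field.iter().rev() {
            raw = (raw << 8) | u64::from(b);
        }
    }
    if raw > usize::MAX as u64 {
        return None;
    }
    let len = raw as usize;

    // Apply the signed adjustment with overflow and underflow checks.
    let len = if cfg.length_adjustment < 0 {
        len.checked_sub(cfg.length_adjustment.unsigned_abs())?
    } else {
        len.checked_add(cfg.length_adjustment as usize)?
    };

    // The yielded frame is the `len` bytes that follow the skipped prefix.
    let end = skip.checked_add(len)?;
    if src.len() < end {
        return None;
    }
    Some((&src[skip..end], &src[end..]))
}
```

With a 2 byte big endian length field at offset 0 and no `num_skip` override, `split_frame` applied to `b"\x00\x0BHello world"` yields the 11 byte payload. Keeping the head in the frame takes both `num_skip: Some(0)` and `length_adjustment: 2`, because the skipped prefix is then empty and the 2 head bytes have to be counted in the yielded length.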
tokio-util-0.7.10/src/codec/lines_codec.rs
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;

use bytes::{Buf, BufMut, BytesMut};
use std::{cmp, fmt, io, str, usize};

/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines.
///
/// This uses the `\n` character as the line ending on all platforms.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct LinesCodec {
    // Stored index of the next index to examine for a `\n` character.
    // This is used to optimize searching.
    // For example, if `decode` was called with `abc`, it would hold `3`,
    // because that is the next index to examine.
    // The next time `decode` is called with `abcde\n`, the method will
    // only look at `de\n` before returning.
    next_index: usize,

    /// The maximum length for a given line. If `usize::MAX`, lines will be
    /// read until a `\n` character is reached.
    max_length: usize,

    /// Are we currently discarding the remainder of a line which was over
    /// the length limit?
    is_discarding: bool,
}

impl LinesCodec {
    /// Returns a `LinesCodec` for splitting up data into lines.
    ///
    /// # Note
    ///
    /// The returned `LinesCodec` will not have an upper bound on the length
    /// of a buffered line. See the documentation for [`new_with_max_length`]
    /// for information on why this could be a potential security risk.
    ///
    /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length()
    pub fn new() -> LinesCodec {
        LinesCodec {
            next_index: 0,
            max_length: usize::MAX,
            is_discarding: false,
        }
    }

    /// Returns a `LinesCodec` with a maximum line length limit.
    ///
    /// If this is set, calls to `LinesCodec::decode` will return a
    /// [`LinesCodecError`] when a line exceeds the length limit.
    /// Subsequent calls will discard up to `max_length` bytes from that line
    /// until a newline character is reached, returning `None` until the line
    /// over the limit has been fully discarded. After that point, calls to
    /// `decode` will function as normal.
    ///
    /// # Note
    ///
    /// Setting a length limit is highly recommended for any `LinesCodec` which
    /// will be exposed to untrusted input. Otherwise, the size of the buffer
    /// that holds the line currently being read is unbounded. An attacker could
    /// exploit this unbounded buffer by sending an unbounded amount of input
    /// without any `\n` characters, causing unbounded memory consumption.
    ///
    /// [`LinesCodecError`]: crate::codec::LinesCodecError
    pub fn new_with_max_length(max_length: usize) -> Self {
        LinesCodec {
            max_length,
            ..LinesCodec::new()
        }
    }

    /// Returns the maximum line length when decoding.
    ///
    /// ```
    /// use std::usize;
    /// use tokio_util::codec::LinesCodec;
    ///
    /// let codec = LinesCodec::new();
    /// assert_eq!(codec.max_length(), usize::MAX);
    /// ```
    /// ```
    /// use tokio_util::codec::LinesCodec;
    ///
    /// let codec = LinesCodec::new_with_max_length(256);
    /// assert_eq!(codec.max_length(), 256);
    /// ```
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}

fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
    str::from_utf8(buf)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8"))
}

fn without_carriage_return(s: &[u8]) -> &[u8] {
    if let Some(&b'\r') = s.last() {
        &s[..s.len() - 1]
    } else {
        s
    }
}

impl Decoder for LinesCodec {
    type Item = String;
    type Error = LinesCodecError;

    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
        loop {
            // Determine how far into the buffer we'll search for a newline. If
            // there's no max_length set, we'll read to the end of the buffer.
            let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());

            let newline_offset = buf[self.next_index..read_to]
                .iter()
                .position(|b| *b == b'\n');

            match (self.is_discarding, newline_offset) {
                (true, Some(offset)) => {
                    // If we found a newline, discard up to that offset and
                    // then stop discarding. On the next iteration, we'll try
                    // to read a line normally.
                    buf.advance(offset + self.next_index + 1);
                    self.is_discarding = false;
                    self.next_index = 0;
                }
                (true, None) => {
                    // Otherwise, we didn't find a newline, so we'll discard
                    // everything we read. On the next iteration, we'll continue
                    // discarding up to max_len bytes unless we find a newline.
                    buf.advance(read_to);
                    self.next_index = 0;
                    if buf.is_empty() {
                        return Ok(None);
                    }
                }
                (false, Some(offset)) => {
                    // Found a line!
                    let newline_index = offset + self.next_index;
                    self.next_index = 0;
                    let line = buf.split_to(newline_index + 1);
                    let line = &line[..line.len() - 1];
                    let line = without_carriage_return(line);
                    let line = utf8(line)?;
                    return Ok(Some(line.to_string()));
                }
                (false, None) if buf.len() > self.max_length => {
                    // Reached the maximum length without finding a
                    // newline, return an error and start discarding on the
                    // next call.
                    self.is_discarding = true;
                    return Err(LinesCodecError::MaxLineLengthExceeded);
                }
                (false, None) => {
                    // We didn't find a line or reach the length limit, so the next
                    // call will resume searching at the current offset.
                    self.next_index = read_to;
                    return Ok(None);
                }
            }
        }
    }

    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
        Ok(match self.decode(buf)? {
            Some(frame) => Some(frame),
            None => {
                // No terminating newline - return remaining data, if any
                if buf.is_empty() || buf == &b"\r"[..]
                {
                    None
                } else {
                    let line = buf.split_to(buf.len());
                    let line = without_carriage_return(&line);
                    let line = utf8(line)?;
                    self.next_index = 0;
                    Some(line.to_string())
                }
            }
        })
    }
}

impl<T> Encoder<T> for LinesCodec
where
    T: AsRef<str>,
{
    type Error = LinesCodecError;

    fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> {
        let line = line.as_ref();
        buf.reserve(line.len() + 1);
        buf.put(line.as_bytes());
        buf.put_u8(b'\n');
        Ok(())
    }
}

impl Default for LinesCodec {
    fn default() -> Self {
        Self::new()
    }
}

/// An error occurred while encoding or decoding a line.
#[derive(Debug)]
pub enum LinesCodecError {
    /// The maximum line length was exceeded.
    MaxLineLengthExceeded,
    /// An IO error occurred.
    Io(io::Error),
}

impl fmt::Display for LinesCodecError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"),
            LinesCodecError::Io(e) => write!(f, "{}", e),
        }
    }
}

impl From<io::Error> for LinesCodecError {
    fn from(e: io::Error) -> LinesCodecError {
        LinesCodecError::Io(e)
    }
}

impl std::error::Error for LinesCodecError {}
tokio-util-0.7.10/src/codec/mod.rs
//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink
//!
//! Raw I/O objects work with byte sequences, but higher-level code usually
//! wants to batch these into meaningful chunks, called "frames".
//!
//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and
//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`].
//! Framed streams are also known as transports.
//!
//! # The Decoder trait
//!
//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an
//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify
//! how sequences of bytes are turned into a sequence of frames, and to
//! determine where the boundaries between frames are. The job of the
//! `FramedRead` is to repeatedly switch between reading more data from the IO
//! resource, and asking the decoder whether we have received enough data to
//! decode another frame of data.
//!
//! The main method on the `Decoder` trait is the [`decode`] method. This method
//! takes as argument the data that has been read so far, and when it is called,
//! it will be in one of the following situations:
//!
//!  1. The buffer contains less than a full frame.
//!  2. The buffer contains exactly a full frame.
//!  3. The buffer contains more than a full frame.
//!
//! In the first situation, the decoder should return `Ok(None)`.
//!
//! In the second situation, the decoder should clear the provided buffer and
//! return `Ok(Some(the_decoded_frame))`.
//!
//! In the third situation, the decoder should use a method such as [`split_to`]
//! or [`advance`] to modify the buffer such that the frame is removed from the
//! buffer, but any data in the buffer after that frame should still remain in
//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in
//! this case.
//!
//! Finally the decoder may return an error if the data is invalid in some way.
//! The decoder should _not_ return an error just because it has yet to receive
//! a full frame.
//!
//! It is guaranteed that, from one call to `decode` to another, the provided
//! buffer will contain the exact same data as before, except that if more data
//! has arrived through the IO resource, that data will have been appended to
//! the buffer. This means that reading frames from a `FramedRead` is
//! essentially equivalent to the following loop:
//!
//! ```no_run
//! use tokio::io::AsyncReadExt;
//! # // This uses async_stream to create an example that compiles.
//! # fn foo() -> impl futures_core::Stream<Item = std::io::Result<bytes::BytesMut>> { async_stream::try_stream! {
//! # use tokio_util::codec::Decoder;
//! # let mut decoder = tokio_util::codec::BytesCodec::new();
//! # let io_resource = &mut &[0u8, 1, 2, 3][..];
//!
//! let mut buf = bytes::BytesMut::new();
//! loop {
//!     // The read_buf call will append to buf rather than overwrite existing data.
//!     let len = io_resource.read_buf(&mut buf).await?;
//!
//!     if len == 0 {
//!         while let Some(frame) = decoder.decode_eof(&mut buf)? {
//!             yield frame;
//!         }
//!         break;
//!     }
//!
//!     while let Some(frame) = decoder.decode(&mut buf)? {
//!         yield frame;
//!     }
//! }
//! # }}
//! ```
//! The example above uses `yield` whenever the `Stream` produces an item.
//!
//! ## Example decoder
//!
//! As an example, consider a protocol that can be used to send strings where
//! each frame is a four byte integer that contains the length of the frame,
//! followed by that many bytes of string data. The decoder fails with an error
//! if the string data is not valid utf-8 or too long.
//!
//! Such a decoder can be written like this:
//! ```
//! use tokio_util::codec::Decoder;
//! use bytes::{BytesMut, Buf};
//!
//! struct MyStringDecoder {}
//!
//! const MAX: usize = 8 * 1024 * 1024;
//!
//! impl Decoder for MyStringDecoder {
//!     type Item = String;
//!     type Error = std::io::Error;
//!
//!     fn decode(
//!         &mut self,
//!         src: &mut BytesMut
//!     ) -> Result<Option<Self::Item>, Self::Error> {
//!         if src.len() < 4 {
//!             // Not enough data to read length marker.
//!             return Ok(None);
//!         }
//!
//!         // Read length marker.
//!         let mut length_bytes = [0u8; 4];
//!         length_bytes.copy_from_slice(&src[..4]);
//!         let length = u32::from_le_bytes(length_bytes) as usize;
//!
//!         // Check that the length is not too large to avoid a denial of
//!         // service attack where the server runs out of memory.
//!         if length > MAX {
//!             return Err(std::io::Error::new(
//!                 std::io::ErrorKind::InvalidData,
//!                 format!("Frame of length {} is too large.", length)
//!             ));
//!         }
//!
//!         if src.len() < 4 + length {
//!             // The full string has not yet arrived.
//!             //
//!             // We reserve more space in the buffer. This is not strictly
//!             // necessary, but is a good idea performance-wise.
//!             src.reserve(4 + length - src.len());
//!
//!             // We inform the Framed that we need more bytes to form the next
//!             // frame.
//!             return Ok(None);
//!         }
//!
//!         // Use advance to modify src such that it no longer contains
//!         // this frame.
//!         let data = src[4..4 + length].to_vec();
//!         src.advance(4 + length);
//!
//!         // Convert the data to a string, or fail if it is not valid utf-8.
//!         match String::from_utf8(data) {
//!             Ok(string) => Ok(Some(string)),
//!             Err(utf8_error) => {
//!                 Err(std::io::Error::new(
//!                     std::io::ErrorKind::InvalidData,
//!                     utf8_error.utf8_error(),
//!                 ))
//!             },
//!         }
//!     }
//! }
//! ```
//!
//! # The Encoder trait
//!
//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn
//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to
//! specify how frames are turned into a sequence of bytes. The job of the
//! `FramedWrite` is to take the resulting sequence of bytes and write it to the
//! IO resource.
//!
//! The main method on the `Encoder` trait is the [`encode`] method. This method
//! takes an item that is being written, and a buffer to write the item to. The
//! buffer may already contain data, and in this case, the encoder should append
//! the new frame to the buffer rather than overwrite the existing data.
//!
//! It is guaranteed that, from one call to `encode` to another, the provided
//! buffer will contain the exact same data as before, except that some of the
//! data may have been removed from the front of the buffer. Writing to a
//! `FramedWrite` is essentially equivalent to the following loop:
//!
//! ```no_run
//! use tokio::io::AsyncWriteExt;
//! use bytes::Buf; // for advance
//! # use tokio_util::codec::Encoder;
//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() }
//! # async fn no_more_frames() { }
//! # #[tokio::main] async fn main() -> std::io::Result<()> {
//! # let mut io_resource = tokio::io::sink();
//! # let mut encoder = tokio_util::codec::BytesCodec::new();
//!
//! const MAX: usize = 8192;
//!
//! let mut buf = bytes::BytesMut::new();
//! loop {
//!     tokio::select! {
//!         num_written = io_resource.write(&buf), if !buf.is_empty() => {
//!             buf.advance(num_written?);
//!         },
//!         frame = next_frame(), if buf.len() < MAX => {
//!             encoder.encode(frame, &mut buf)?;
//!         },
//!         _ = no_more_frames() => {
//!             io_resource.write_all(&buf).await?;
//!             io_resource.shutdown().await?;
//!             return Ok(());
//!         },
//!     }
//! }
//! # }
//! ```
//! Here the `next_frame` method corresponds to any frames you write to the
//! `FramedWrite`. The `no_more_frames` method corresponds to closing the
//! `FramedWrite` with [`SinkExt::close`].
//!
//! ## Example encoder
//!
//! As an example, consider a protocol that can be used to send strings where
//! each frame is a four byte integer that contains the length of the frame,
//! followed by that many bytes of string data. The encoder will fail if the
//! string is too long.
//!
//! Such an encoder can be written like this:
//! ```
//! use tokio_util::codec::Encoder;
//! use bytes::BytesMut;
//!
//! struct MyStringEncoder {}
//!
//! const MAX: usize = 8 * 1024 * 1024;
//!
//! impl Encoder<String> for MyStringEncoder {
//!     type Error = std::io::Error;
//!
//!     fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
//!         // Don't send a string if it is longer than the other end will
//!         // accept.
//!         if item.len() > MAX {
//!             return Err(std::io::Error::new(
//!                 std::io::ErrorKind::InvalidData,
//!                 format!("Frame of length {} is too large.", item.len())
//!             ));
//!         }
//!
//!         // Convert the length into a byte array.
//!         // The cast to u32 cannot overflow due to the length check above.
//!         let len_slice = u32::to_le_bytes(item.len() as u32);
//!
//!         // Reserve space in the buffer.
//!         dst.reserve(4 + item.len());
//!
//!         // Write the length and string to the buffer.
//!         dst.extend_from_slice(&len_slice);
//!         dst.extend_from_slice(item.as_bytes());
//!         Ok(())
//!     }
//! }
//! ```
//!
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite
//! [`Stream`]: futures_core::Stream
//! [`Sink`]: futures_sink::Sink
//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close
//! [`FramedRead`]: struct@crate::codec::FramedRead
//! [`FramedWrite`]: struct@crate::codec::FramedWrite
//! [`Framed`]: struct@crate::codec::Framed
//! [`Decoder`]: trait@crate::codec::Decoder
//! [`decode`]: fn@crate::codec::Decoder::decode
//! [`encode`]: fn@crate::codec::Encoder::encode
//! [`split_to`]: fn@bytes::BytesMut::split_to
//! [`advance`]: fn@bytes::Buf::advance

mod bytes_codec;
pub use self::bytes_codec::BytesCodec;

mod decoder;
pub use self::decoder::Decoder;

mod encoder;
pub use self::encoder::Encoder;

mod framed_impl;
#[allow(unused_imports)]
pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};

mod framed;
pub use self::framed::{Framed, FramedParts};

mod framed_read;
pub use self::framed_read::FramedRead;

mod framed_write;
pub use self::framed_write::FramedWrite;

pub mod length_delimited;
pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError};

mod lines_codec;
pub use self::lines_codec::{LinesCodec, LinesCodecError};

mod any_delimiter_codec;
pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError};
tokio-util-0.7.10/src/compat.rs
//! Compatibility between the `tokio::io` and `futures-io` versions of the
//! `AsyncRead` and `AsyncWrite` traits.
use futures_core::ready;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A compatibility layer that allows conversion between the
    /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
    #[derive(Copy, Clone, Debug)]
    pub struct Compat<T> {
        #[pin]
        inner: T,
        seek_pos: Option<io::SeekFrom>,
    }
}

/// Extension trait that allows converting a type implementing
/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
    /// Wraps `self` with a compatibility layer that implements
    /// `tokio::io::AsyncRead`.
    fn compat(self) -> Compat<Self>
    where
        Self: Sized,
    {
        Compat::new(self)
    }
}

impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}

/// Extension trait that allows converting a type implementing
/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
    /// Wraps `self` with a compatibility layer that implements
    /// `tokio::io::AsyncWrite`.
    fn compat_write(self) -> Compat<Self>
    where
        Self: Sized,
    {
        Compat::new(self)
    }
}

impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}

/// Extension trait that allows converting a type implementing
/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
    /// Wraps `self` with a compatibility layer that implements
    /// `futures_io::AsyncRead`.
    fn compat(self) -> Compat<Self>
    where
        Self: Sized,
    {
        Compat::new(self)
    }
}

impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}

/// Extension trait that allows converting a type implementing
/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
    /// Wraps `self` with a compatibility layer that implements
    /// `futures_io::AsyncWrite`.
    fn compat_write(self) -> Compat<Self>
    where
        Self: Sized,
    {
        Compat::new(self)
    }
}

impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}

// === impl Compat ===

impl<T> Compat<T> {
    fn new(inner: T) -> Self {
        Self {
            inner,
            seek_pos: None,
        }
    }

    /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
    /// contained within.
    pub fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
    /// contained within.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Returns the wrapped item.
    pub fn into_inner(self) -> T {
        self.inner
    }
}

impl<T> tokio::io::AsyncRead for Compat<T>
where
    T: futures_io::AsyncRead,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // We can't trust the inner type to not peek at the bytes,
        // so we must defensively initialize the buffer.
        let slice = buf.initialize_unfilled();
        let n = ready!(futures_io::AsyncRead::poll_read(
            self.project().inner,
            cx,
            slice
        ))?;
        buf.advance(n);
        Poll::Ready(Ok(()))
    }
}

impl<T> futures_io::AsyncRead for Compat<T>
where
    T: tokio::io::AsyncRead,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        slice: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut buf = tokio::io::ReadBuf::new(slice);
        ready!(tokio::io::AsyncRead::poll_read(
            self.project().inner,
            cx,
            &mut buf
        ))?;
        Poll::Ready(Ok(buf.filled().len()))
    }
}

impl<T> tokio::io::AsyncBufRead for Compat<T>
where
    T: futures_io::AsyncBufRead,
{
    fn poll_fill_buf<'a>(
        self: Pin<&'a mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&'a [u8]>> {
        futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        futures_io::AsyncBufRead::consume(self.project().inner, amt)
    }
}

impl<T> futures_io::AsyncBufRead for Compat<T>
where
    T: tokio::io::AsyncBufRead,
{
    fn poll_fill_buf<'a>(
        self: Pin<&'a mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&'a [u8]>> {
        tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        tokio::io::AsyncBufRead::consume(self.project().inner, amt)
    }
}

impl<T> tokio::io::AsyncWrite for Compat<T>
where
    T: futures_io::AsyncWrite,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
    }

    fn poll_flush(self: Pin<&mut
    Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        futures_io::AsyncWrite::poll_close(self.project().inner, cx)
    }
}

impl<T> futures_io::AsyncWrite for Compat<T>
where
    T: tokio::io::AsyncWrite,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
    }
}

impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
    fn poll_seek(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        pos: io::SeekFrom,
    ) -> Poll<io::Result<u64>> {
        if self.seek_pos != Some(pos) {
            // Ensure previous seeks have finished before starting a new one
            ready!(self.as_mut().project().inner.poll_complete(cx))?;
            self.as_mut().project().inner.start_seek(pos)?;
            *self.as_mut().project().seek_pos = Some(pos);
        }
        let res = ready!(self.as_mut().project().inner.poll_complete(cx));
        *self.as_mut().project().seek_pos = None;
        Poll::Ready(res)
    }
}

impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
    fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
        *self.as_mut().project().seek_pos = Some(pos);
        Ok(())
    }

    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let pos = match self.seek_pos {
            None => {
                // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
                // We don't have to guarantee that the value returned by
                // poll_complete called without start_seek is correct,
                // so we'll return 0.
                return Poll::Ready(Ok(0));
            }
            Some(pos) => pos,
        };
        let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
        *self.as_mut().project().seek_pos = None;
        Poll::Ready(res)
    }
}

#[cfg(unix)]
impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
        self.inner.as_raw_fd()
    }
}

#[cfg(windows)]
impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> {
    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
        self.inner.as_raw_handle()
    }
}
tokio-util-0.7.10/src/context.rs
//! Tokio context aware futures utilities.
//!
//! This module includes utilities around integrating tokio with other runtimes
//! by allowing the context to be attached to futures. This allows spawning
//! futures on other executors while still using tokio to drive them. This
//! can be useful if you need to use a tokio based library in an executor/runtime
//! that does not provide a tokio context.

use pin_project_lite::pin_project;
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::runtime::{Handle, Runtime};

pin_project! {
    /// `TokioContext` allows running futures that must be inside Tokio's
    /// context on a non-Tokio runtime.
    ///
    /// It contains a [`Handle`] to the runtime. A handle to the runtime can be
    /// obtained by calling the [`Runtime::handle()`] method.
    ///
    /// Note that the `TokioContext` wrapper only works if the `Runtime` it is
    /// connected to has not yet been destroyed. You must keep the `Runtime`
    /// alive until the future has finished executing.
    ///
    /// **Warning:** If `TokioContext` is used together with a [current thread]
    /// runtime, that runtime must be inside a call to `block_on` for the
    /// wrapped future to work. For this reason, it is recommended to use a
    /// [multi thread] runtime, even if you configure it to only spawn one
    /// worker thread.
    ///
    /// # Examples
    ///
    /// This example creates two runtimes, but only [enables time] on one of
    /// them.
    /// It then uses the context of the runtime with the timer enabled to
    /// execute a [`sleep`] future on the runtime with timing disabled.
    /// ```
    /// use tokio::time::{sleep, Duration};
    /// use tokio_util::context::RuntimeExt;
    ///
    /// // This runtime has timers enabled.
    /// let rt = tokio::runtime::Builder::new_multi_thread()
    ///     .enable_all()
    ///     .build()
    ///     .unwrap();
    ///
    /// // This runtime has timers disabled.
    /// let rt2 = tokio::runtime::Builder::new_multi_thread()
    ///     .build()
    ///     .unwrap();
    ///
    /// // Wrap the sleep future in the context of rt.
    /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
    ///
    /// // Execute the future on rt2.
    /// rt2.block_on(fut);
    /// ```
    ///
    /// [`Handle`]: struct@tokio::runtime::Handle
    /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle
    /// [`RuntimeExt`]: trait@crate::context::RuntimeExt
    /// [`new_static`]: fn@Self::new_static
    /// [`sleep`]: fn@tokio::time::sleep
    /// [current thread]: fn@tokio::runtime::Builder::new_current_thread
    /// [enables time]: fn@tokio::runtime::Builder::enable_time
    /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread
    pub struct TokioContext<F> {
        #[pin]
        inner: F,
        handle: Handle,
    }
}

impl<F> TokioContext<F> {
    /// Associate the provided future with the context of the runtime behind
    /// the provided `Handle`.
    ///
    /// This constructor uses a `'static` lifetime to opt-out of checking that
    /// the runtime still exists.
    ///
    /// # Examples
    ///
    /// This is the same as the example above, but uses the `new` constructor
    /// rather than [`RuntimeExt::wrap`].
    ///
    /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap
    ///
    /// ```
    /// use tokio::time::{sleep, Duration};
    /// use tokio_util::context::TokioContext;
    ///
    /// // This runtime has timers enabled.
    /// let rt = tokio::runtime::Builder::new_multi_thread()
    ///     .enable_all()
    ///     .build()
    ///     .unwrap();
    ///
    /// // This runtime has timers disabled.
    /// let rt2 = tokio::runtime::Builder::new_multi_thread()
    ///     .build()
    ///     .unwrap();
    ///
    /// let fut = TokioContext::new(
    ///     async { sleep(Duration::from_millis(2)).await },
    ///     rt.handle().clone(),
    /// );
    ///
    /// // Execute the future on rt2.
    /// rt2.block_on(fut);
    /// ```
    pub fn new(future: F, handle: Handle) -> TokioContext<F> {
        TokioContext {
            inner: future,
            handle,
        }
    }

    /// Obtain a reference to the handle inside this `TokioContext`.
    pub fn handle(&self) -> &Handle {
        &self.handle
    }

    /// Remove the association between the Tokio runtime and the wrapped future.
    pub fn into_inner(self) -> F {
        self.inner
    }
}

impl<F: Future> Future for TokioContext<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let me = self.project();
        let handle = me.handle;
        let fut = me.inner;

        let _enter = handle.enter();
        fut.poll(cx)
    }
}

/// Extension trait that simplifies bundling a `Handle` with a `Future`.
pub trait RuntimeExt {
    /// Create a [`TokioContext`] that wraps the provided future and runs it in
    /// this runtime's context.
    ///
    /// # Examples
    ///
    /// This example creates two runtimes, but only [enables time] on one of
    /// them. It then uses the context of the runtime with the timer enabled to
    /// execute a [`sleep`] future on the runtime with timing disabled.
    ///
    /// ```
    /// use tokio::time::{sleep, Duration};
    /// use tokio_util::context::RuntimeExt;
    ///
    /// // This runtime has timers enabled.
    /// let rt = tokio::runtime::Builder::new_multi_thread()
    ///     .enable_all()
    ///     .build()
    ///     .unwrap();
    ///
    /// // This runtime has timers disabled.
    /// let rt2 = tokio::runtime::Builder::new_multi_thread()
    ///     .build()
    ///     .unwrap();
    ///
    /// // Wrap the sleep future in the context of rt.
    /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
    ///
    /// // Execute the future on rt2.
    /// rt2.block_on(fut);
    /// ```
    ///
    /// [`TokioContext`]: struct@crate::context::TokioContext
    /// [`sleep`]: fn@tokio::time::sleep
    /// [enables time]: fn@tokio::runtime::Builder::enable_time
    fn wrap<F: Future>(&self, fut: F) -> TokioContext<F>;
}

impl RuntimeExt for Runtime {
    fn wrap<F: Future>(&self, fut: F) -> TokioContext<F> {
        TokioContext {
            inner: fut,
            handle: self.handle().clone(),
        }
    }
}
tokio-util-0.7.10/src/either.rs
//! Module defining an Either type.
use std::{
    future::Future,
    io::SeekFrom,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};

/// Combines two different futures, streams, or sinks having the same associated types into a single type.
///
/// This type implements common asynchronous traits such as [`Future`] and those in Tokio.
///
/// [`Future`]: std::future::Future
///
/// # Example
///
/// The following code will not work:
///
/// ```compile_fail
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         some_async_function()
///     } else {
///         other_async_function() // <- Will print: "`if` and `else` have incompatible types"
///     };
///
///     println!("Result is {}", result.await);
/// }
/// ```
///
/// This is because although the output types of both futures are the same, the exact future
/// types are different, but the compiler must be able to choose a single type for the
/// `result` variable.
///
/// When the output type is the same, we can wrap each future in `Either` to avoid the
/// issue:
///
/// ```
/// use tokio_util::either::Either;
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
///
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         Either::Left(some_async_function())
///     } else {
///         Either::Right(other_async_function())
///     };
///
///     let value = result.await;
///     println!("Result is {}", value);
///     # assert_eq!(value, 10);
/// }
/// ```
#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
#[derive(Debug, Clone)]
pub enum Either<L, R> {
    Left(L),
    Right(R),
}

/// A small helper macro which reduces amount of boilerplate in the actual trait method implementation.
/// It takes an invocation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either
/// enum variant held in `self`.
macro_rules!
delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {
        unsafe {
            match $self.get_unchecked_mut() {
                Self::Left(l) => Pin::new_unchecked(l).$method($($args),+),
                Self::Right(r) => Pin::new_unchecked(r).$method($($args),+),
            }
        }
    }
}

impl<L, R, O> Future for Either<L, R>
where
    L: Future<Output = O>,
    R: Future<Output = O>,
{
    type Output = O;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        delegate_call!(self.poll(cx))
    }
}

impl<L, R> AsyncRead for Either<L, R>
where
    L: AsyncRead,
    R: AsyncRead,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        delegate_call!(self.poll_read(cx, buf))
    }
}

impl<L, R> AsyncBufRead for Either<L, R>
where
    L: AsyncBufRead,
    R: AsyncBufRead,
{
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
        delegate_call!(self.poll_fill_buf(cx))
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        delegate_call!(self.consume(amt));
    }
}

impl<L, R> AsyncSeek for Either<L, R>
where
    L: AsyncSeek,
    R: AsyncSeek,
{
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
        delegate_call!(self.start_seek(position))
    }

    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
        delegate_call!(self.poll_complete(cx))
    }
}

impl<L, R> AsyncWrite for Either<L, R>
where
    L: AsyncWrite,
    R: AsyncWrite,
{
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        delegate_call!(self.poll_write(cx, buf))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
        delegate_call!(self.poll_flush(cx))
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
        delegate_call!(self.poll_shutdown(cx))
    }
}

impl<L, R> futures_core::stream::Stream for Either<L, R>
where
    L: futures_core::stream::Stream,
    R: futures_core::stream::Stream<Item = L::Item>,
{
    type Item = L::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        delegate_call!(self.poll_next(cx))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{repeat, AsyncReadExt, Repeat};
    use tokio_stream::{once, Once, StreamExt};

    #[tokio::test]
    async fn either_is_stream() {
        let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1));

        assert_eq!(Some(1u32), either.next().await);
    }

    #[tokio::test]
    async fn either_is_async_read() {
        let mut buffer = [0; 3];
        let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101));

        either.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, [0b101, 0b101, 0b101]);
    }
}
tokio-util-0.7.10/src/io/copy_to_bytes.rs
use bytes::Bytes;
use futures_core::stream::Stream;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A helper that wraps a [`Sink`]`<`[`Bytes`]`>` and converts it into a
    /// [`Sink`]`<&'a [u8]>` by copying each byte slice into an owned [`Bytes`].
    ///
    /// See the documentation for [`SinkWriter`] for an example.
    ///
    /// [`Bytes`]: bytes::Bytes
    /// [`SinkWriter`]: crate::io::SinkWriter
    /// [`Sink`]: futures_sink::Sink
    #[derive(Debug)]
    pub struct CopyToBytes<S> {
        #[pin]
        inner: S,
    }
}

impl<S> CopyToBytes<S> {
    /// Creates a new [`CopyToBytes`].
    pub fn new(inner: S) -> Self {
        Self { inner }
    }

    /// Gets a reference to the underlying sink.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Gets a mutable reference to the underlying sink.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes this [`CopyToBytes`], returning the underlying sink.
    pub fn into_inner(self) -> S {
        self.inner
    }
}

impl<'a, S> Sink<&'a [u8]> for CopyToBytes<S>
where
    S: Sink<Bytes>,
{
    type Error = S::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: &'a [u8]) -> Result<(), Self::Error> {
        self.project()
            .inner
            .start_send(Bytes::copy_from_slice(item))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}

impl<S: Stream> Stream for CopyToBytes<S> {
    type Item = S::Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }
}
tokio-util-0.7.10/src/io/inspect.rs
use futures_core::ready;
use pin_project_lite::pin_project;
use std::io::{IoSlice, Result};
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pin_project! {
    /// An adapter that lets you inspect the data that's being read.
    ///
    /// This is useful for things like hashing data as it's read in.
    pub struct InspectReader<R, F> {
        #[pin]
        reader: R,
        f: F,
    }
}

impl<R, F> InspectReader<R, F> {
    /// Create a new InspectReader, wrapping `reader` and calling `f` for the
    /// new data supplied by each read call.
    ///
    /// The closure will only be called with an empty slice if the inner reader
    /// returns without reading data into the buffer. This happens at EOF, or if
    /// `poll_read` is called with a zero-size buffer.
    pub fn new(reader: R, f: F) -> InspectReader<R, F>
    where
        R: AsyncRead,
        F: FnMut(&[u8]),
    {
        InspectReader { reader, f }
    }

    /// Consumes the `InspectReader`, returning the wrapped reader
    pub fn into_inner(self) -> R {
        self.reader
    }
}

impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let me = self.project();
        let filled_length = buf.filled().len();
        ready!(me.reader.poll_read(cx, buf))?;
        (me.f)(&buf.filled()[filled_length..]);
        Poll::Ready(Ok(()))
    }
}

impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        self.project().reader.poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        self.project().reader.poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        self.project().reader.poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize>> {
        self.project().reader.poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.reader.is_write_vectored()
    }
}

pin_project! {
    /// An adapter that lets you inspect the data that's being written.
    ///
    /// This is useful for things like hashing data as it's written out.
    pub struct InspectWriter<W, F> {
        #[pin]
        writer: W,
        f: F,
    }
}

impl<W, F> InspectWriter<W, F> {
    /// Create a new InspectWriter, wrapping `write` and calling `f` for the
    /// data successfully written by each write call.
    ///
    /// The closure `f` will never be called with an empty slice. A vectored
    /// write can result in multiple calls to `f` - at most one call to `f` per
    /// buffer supplied to `poll_write_vectored`.
    pub fn new(writer: W, f: F) -> InspectWriter<W, F>
    where
        W: AsyncWrite,
        F: FnMut(&[u8]),
    {
        InspectWriter { writer, f }
    }

    /// Consumes the `InspectWriter`, returning the wrapped writer
    pub fn into_inner(self) -> W {
        self.writer
    }
}

impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        let me = self.project();
        let res = me.writer.poll_write(cx, buf);
        if let Poll::Ready(Ok(count)) = res {
            if count != 0 {
                (me.f)(&buf[..count]);
            }
        }
        res
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let me = self.project();
        me.writer.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let me = self.project();
        me.writer.poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize>> {
        let me = self.project();
        let res = me.writer.poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(mut count)) = res {
            for buf in bufs {
                if count == 0 {
                    break;
                }
                let size = count.min(buf.len());
                if size != 0 {
                    (me.f)(&buf[..size]);
                    count -= size;
                }
            }
        }
        res
    }

    fn is_write_vectored(&self) -> bool {
        self.writer.is_write_vectored()
    }
}

impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.project().writer.poll_read(cx, buf)
    }
}

// tokio-util-0.7.10/src/io/mod.rs
//! Helpers for IO related tasks.
//!
//! The stream types are often used in combination with hyper or reqwest, as they
//! allow converting between a hyper [`Body`] and [`AsyncRead`].
//!
//! The [`SyncIoBridge`] type converts from the world of async I/O
//! to synchronous I/O; this may often come up when using synchronous APIs
//! inside [`tokio::task::spawn_blocking`].
//!
//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
//! [`AsyncRead`]: tokio::io::AsyncRead

mod copy_to_bytes;
mod inspect;
mod read_buf;
mod reader_stream;
mod sink_writer;
mod stream_reader;

cfg_io_util! {
    mod sync_bridge;
    pub use self::sync_bridge::SyncIoBridge;
}

pub use self::copy_to_bytes::CopyToBytes;
pub use self::inspect::{InspectReader, InspectWriter};
pub use self::read_buf::read_buf;
pub use self::reader_stream::ReaderStream;
pub use self::sink_writer::SinkWriter;
pub use self::stream_reader::StreamReader;
pub use crate::util::{poll_read_buf, poll_write_buf};

// tokio-util-0.7.10/src/io/read_buf.rs
use bytes::BufMut;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;

/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
///
/// [`BufMut`]: bytes::BufMut
///
/// # Example
///
/// ```
/// use bytes::{Bytes, BytesMut};
/// use tokio_stream as stream;
/// use tokio::io::Result;
/// use tokio_util::io::{StreamReader, read_buf};
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a reader from an iterator. This particular reader will always be
/// // ready.
/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
///
/// let mut buf = BytesMut::new();
/// let mut reads = 0;
///
/// loop {
///     reads += 1;
///     let n = read_buf(&mut read, &mut buf).await?;
///
///     if n == 0 {
///         break;
///     }
/// }
///
/// // one or more reads might be necessary.
/// assert!(reads >= 1);
/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
/// # Ok(())
/// # }
/// ```
pub async fn read_buf<R, B>(read: &mut R, buf: &mut B) -> io::Result<usize>
where
    R: AsyncRead + Unpin,
    B: BufMut,
{
    return ReadBufFn(read, buf).await;

    struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B);

    impl<'a, R, B> Future for ReadBufFn<'a, R, B>
    where
        R: AsyncRead + Unpin,
        B: BufMut,
    {
        type Output = io::Result<usize>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = &mut *self;
            crate::util::poll_read_buf(Pin::new(this.0), cx, this.1)
        }
    }
}

// tokio-util-0.7.10/src/io/reader_stream.rs
use bytes::{Bytes, BytesMut};
use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;

const DEFAULT_CAPACITY: usize = 4096;

pin_project! {
    /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
    ///
    /// This stream is fused. It performs the inverse operation of
    /// [`StreamReader`].
    ///
    /// # Example
    ///
    /// ```
    /// # #[tokio::main]
    /// # async fn main() -> std::io::Result<()> {
    /// use tokio_stream::StreamExt;
    /// use tokio_util::io::ReaderStream;
    ///
    /// // Create a stream of data.
    /// let data = b"hello, world!";
    /// let mut stream = ReaderStream::new(&data[..]);
    ///
    /// // Read all of the chunks into a vector.
    /// let mut stream_contents = Vec::new();
    /// while let Some(chunk) = stream.next().await {
    ///    stream_contents.extend_from_slice(&chunk?);
    /// }
    ///
    /// // Once the chunks are concatenated, we should have the
    /// // original data.
    /// assert_eq!(stream_contents, data);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`StreamReader`]: crate::io::StreamReader
    /// [`Stream`]: futures_core::Stream
    #[derive(Debug)]
    pub struct ReaderStream<R> {
        // Reader itself.
        //
        // This value is `None` if the stream has terminated.
        #[pin]
        reader: Option<R>,
        // Working buffer, used to optimize allocations.
        buf: BytesMut,
        capacity: usize,
    }
}

impl<R: AsyncRead> ReaderStream<R> {
    /// Convert an [`AsyncRead`] into a [`Stream`] with item type
    /// `Result<Bytes, std::io::Error>`.
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`Stream`]: futures_core::Stream
    pub fn new(reader: R) -> Self {
        ReaderStream {
            reader: Some(reader),
            buf: BytesMut::new(),
            capacity: DEFAULT_CAPACITY,
        }
    }

    /// Convert an [`AsyncRead`] into a [`Stream`] with item type
    /// `Result<Bytes, std::io::Error>`,
    /// with a specific read buffer initial capacity.
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`Stream`]: futures_core::Stream
    pub fn with_capacity(reader: R, capacity: usize) -> Self {
        ReaderStream {
            reader: Some(reader),
            buf: BytesMut::with_capacity(capacity),
            capacity,
        }
    }
}

impl<R: AsyncRead> Stream for ReaderStream<R> {
    type Item = std::io::Result<Bytes>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut this = self.as_mut().project();

        let reader = match this.reader.as_pin_mut() {
            Some(r) => r,
            None => return Poll::Ready(None),
        };

        if this.buf.capacity() == 0 {
            this.buf.reserve(*this.capacity);
        }

        match poll_read_buf(reader, cx, &mut this.buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => {
                self.project().reader.set(None);
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Ok(0)) => {
                self.project().reader.set(None);
                Poll::Ready(None)
            }
            Poll::Ready(Ok(_)) => {
                let chunk = this.buf.split();
                Poll::Ready(Some(Ok(chunk.freeze())))
            }
        }
    }
}

// tokio-util-0.7.10/src/io/sink_writer.rs
use futures_core::ready;
use futures_sink::Sink;

use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};

pin_project! {
    /// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`].
    ///
    /// Whenever you write to this [`SinkWriter`], the supplied bytes are
    /// forwarded to the inner [`Sink`].
    /// When `shutdown` is called on this
    /// [`SinkWriter`], the inner sink is closed.
    ///
    /// This adapter takes a `Sink<&[u8]>` and provides an [`AsyncWrite`] impl
    /// for it. Because of the lifetime, this trait is relatively rarely
    /// implemented. The main ways to get a `Sink<&[u8]>` that you can use with
    /// this type are:
    ///
    ///  * With the codec module by implementing the [`Encoder`]`<&[u8]>` trait.
    ///  * By wrapping a `Sink<Bytes>` in a [`CopyToBytes`].
    ///  * Manually implementing `Sink<&[u8]>` directly.
    ///
    /// The opposite conversion of implementing `Sink<_>` for an [`AsyncWrite`]
    /// is done using the [`codec`] module.
    ///
    /// # Example
    ///
    /// ```
    /// use bytes::Bytes;
    /// use futures_util::SinkExt;
    /// use std::io::{Error, ErrorKind};
    /// use tokio::io::AsyncWriteExt;
    /// use tokio_util::io::{SinkWriter, CopyToBytes};
    /// use tokio_util::sync::PollSender;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() -> Result<(), Error> {
    /// // We use an mpsc channel as an example of a `Sink<Bytes>`.
    /// let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
    /// let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
    ///
    /// // Wrap it in `CopyToBytes` to get a `Sink<&[u8]>`.
    /// let mut writer = SinkWriter::new(CopyToBytes::new(sink));
    ///
    /// // Write data to our interface...
    /// let data: [u8; 4] = [1, 2, 3, 4];
    /// let _ = writer.write(&data).await?;
    ///
    /// // ... and receive it.
    /// assert_eq!(data.as_slice(), &*rx.recv().await.unwrap());
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`AsyncWrite`]: tokio::io::AsyncWrite
    /// [`CopyToBytes`]: crate::io::CopyToBytes
    /// [`Encoder`]: crate::codec::Encoder
    /// [`Sink`]: futures_sink::Sink
    /// [`codec`]: crate::codec
    #[derive(Debug)]
    pub struct SinkWriter<S> {
        #[pin]
        inner: S,
    }
}

impl<S> SinkWriter<S> {
    /// Creates a new [`SinkWriter`].
    pub fn new(sink: S) -> Self {
        Self { inner: sink }
    }

    /// Gets a reference to the underlying sink.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Gets a mutable reference to the underlying sink.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes this [`SinkWriter`], returning the underlying sink.
    pub fn into_inner(self) -> S {
        self.inner
    }
}

impl<S, E> AsyncWrite for SinkWriter<S>
where
    for<'a> S: Sink<&'a [u8], Error = E>,
    E: Into<io::Error>,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let mut this = self.project();

        ready!(this.inner.as_mut().poll_ready(cx).map_err(Into::into))?;
        match this.inner.as_mut().start_send(buf) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(e) => Poll::Ready(Err(e.into())),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().inner.poll_flush(cx).map_err(Into::into)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().inner.poll_close(cx).map_err(Into::into)
    }
}

impl<S: Stream> Stream for SinkWriter<S> {
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }
}

impl<S: AsyncRead> AsyncRead for SinkWriter<S> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.project().inner.poll_read(cx, buf)
    }
}

// tokio-util-0.7.10/src/io/stream_reader.rs
use bytes::Buf;
use futures_core::stream::Stream;
use futures_sink::Sink;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};

/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
///
/// This type performs the inverse operation of [`ReaderStream`].
///
/// This type also implements the [`AsyncBufRead`] trait, so you can use it
/// to read a `Stream` of byte chunks line-by-line. See the examples below.
///
/// # Example
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::{AsyncReadExt, Result};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
///     Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Read the next chunk.
/// assert_eq!(read.read(&mut buf).await?, 4);
/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// If the stream produces errors which are not [`std::io::Error`],
/// the errors can be converted using [`StreamExt`] to map each
/// element.
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::AsyncReadExt;
/// use tokio_util::io::StreamReader;
/// use tokio_stream::StreamExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator, including an error.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
///     Result::Err("Something bad happened!")
/// ]);
///
/// // Use StreamExt to map the stream and error to a std::io::Error
/// let stream = stream.map(|result| result.map_err(|err| {
///     std::io::Error::new(std::io::ErrorKind::Other, err)
/// }));
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Reading the next chunk will produce an error
/// let error = read.read(&mut buf).await.unwrap_err();
/// assert_eq!(error.kind(), std::io::ErrorKind::Other);
/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
/// line-by-line. Note that you will usually also need to convert the error
/// type when doing this. See the second example for an explanation of how
/// to do this.
///
/// ```
/// use tokio::io::{Result, AsyncBufReadExt};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream of byte chunks.
/// let stream = tokio_stream::iter(vec![
///     Result::Ok(b"The first line.\n".as_slice()),
///     Result::Ok(b"The second line.".as_slice()),
///     Result::Ok(b"\nThe third".as_slice()),
///     Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Loop through the lines from the `StreamReader`.
/// let mut line = String::new();
/// let mut lines = Vec::new();
/// loop {
///     line.clear();
///     let len = read.read_line(&mut line).await?;
///     if len == 0 { break; }
///     lines.push(line.clone());
/// }
///
/// // Verify that we got the lines we expected.
/// assert_eq!(
///     lines,
///     vec![
///         "The first line.\n",
///         "The second line.\n",
///         "The third line.\n",
///         "The fourth line.\n",
///         "The fifth line.\n",
///     ]
/// );
/// # Ok(())
/// # }
/// ```
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`AsyncBufRead`]: tokio::io::AsyncBufRead
/// [`Stream`]: futures_core::Stream
/// [`ReaderStream`]: crate::io::ReaderStream
/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
#[derive(Debug)]
pub struct StreamReader<S, B> {
    // This field is pinned.
    inner: S,
    // This field is not pinned.
    chunk: Option<B>,
}

impl<S, B, E> StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    /// Convert a stream of byte chunks into an [`AsyncRead`].
    ///
    /// The item should be a [`Result`] with the ok variant being something that
    /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
    /// should be convertible into an [io error].
    ///
    /// [`Result`]: std::result::Result
    /// [`Buf`]: bytes::Buf
    /// [io error]: std::io::Error
    pub fn new(stream: S) -> Self {
        Self {
            inner: stream,
            chunk: None,
        }
    }

    /// Do we have a chunk and is it non-empty?
    fn has_chunk(&self) -> bool {
        if let Some(ref chunk) = self.chunk {
            chunk.remaining() > 0
        } else {
            false
        }
    }

    /// Consumes this `StreamReader`, returning a Tuple consisting
    /// of the underlying stream and an Option of the internal buffer,
    /// which is Some in case the buffer contains elements.
    pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
        if self.has_chunk() {
            (self.inner, self.chunk)
        } else {
            (self.inner, None)
        }
    }
}

impl<S, B> StreamReader<S, B> {
    /// Gets a reference to the underlying stream.
    ///
    /// It is inadvisable to directly read from the underlying stream.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Gets a mutable reference to the underlying stream.
    ///
    /// It is inadvisable to directly read from the underlying stream.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Gets a pinned mutable reference to the underlying stream.
    ///
    /// It is inadvisable to directly read from the underlying stream.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
        self.project().inner
    }

    /// Consumes this `StreamReader`, returning the underlying stream.
    ///
    /// Note that any leftover data in the internal buffer is lost.
    /// If you additionally want access to the internal buffer use
    /// [`into_inner_with_chunk`].
    ///
    /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
    pub fn into_inner(self) -> S {
        self.inner
    }
}

impl<S, B, E> AsyncRead for StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let inner_buf = match self.as_mut().poll_fill_buf(cx) {
            Poll::Ready(Ok(buf)) => buf,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Pending => return Poll::Pending,
        };
        let len = std::cmp::min(inner_buf.len(), buf.remaining());
        buf.put_slice(&inner_buf[..len]);

        self.consume(len);
        Poll::Ready(Ok(()))
    }
}

impl<S, B, E> AsyncBufRead for StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        loop {
            if self.as_mut().has_chunk() {
                // This unwrap is very sad, but it can't be avoided.
                let buf = self.project().chunk.as_ref().unwrap().chunk();
                return Poll::Ready(Ok(buf));
            } else {
                match self.as_mut().project().inner.poll_next(cx) {
                    Poll::Ready(Some(Ok(chunk))) => {
                        // Go around the loop in case the chunk is empty.
                        *self.as_mut().project().chunk = Some(chunk);
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
                    Poll::Ready(None) => return Poll::Ready(Ok(&[])),
                    Poll::Pending => return Poll::Pending,
                }
            }
        }
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        if amt > 0 {
            self.project()
                .chunk
                .as_mut()
                .expect("No chunk present")
                .advance(amt);
        }
    }
}

// The code below is a manual expansion of the code that pin-project-lite would
// generate.
// This is done because pin-project-lite fails by hitting the recursion
// limit on this struct. (Every line of documentation is handled recursively by
// the macro.)

impl<S: Unpin, B> Unpin for StreamReader<S, B> {}

struct StreamReaderProject<'a, S, B> {
    inner: Pin<&'a mut S>,
    chunk: &'a mut Option<B>,
}

impl<S, B> StreamReader<S, B> {
    #[inline]
    fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
        // SAFETY: We define that only `inner` should be pinned when `Self` is
        // and have an appropriate `impl Unpin` for this.
        let me = unsafe { Pin::into_inner_unchecked(self) };
        StreamReaderProject {
            inner: unsafe { Pin::new_unchecked(&mut me.inner) },
            chunk: &mut me.chunk,
        }
    }
}

impl<S: Sink<T, Error = E>, E, T> Sink<T> for StreamReader<S, E> {
    type Error = E;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}

// tokio-util-0.7.10/src/io/sync_bridge.rs
use std::io::{BufRead, Read, Seek, Write};
use tokio::io::{
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite,
    AsyncWriteExt,
};

/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
/// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
#[derive(Debug)]
pub struct SyncIoBridge<T> {
    src: T,
    rt: tokio::runtime::Handle,
}

impl<T: AsyncBufRead + Unpin> BufRead for SyncIoBridge<T> {
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        let src = &mut self.src;
        self.rt.block_on(AsyncBufReadExt::fill_buf(src))
    }

    fn consume(&mut self, amt: usize) {
        let src = &mut self.src;
        AsyncBufReadExt::consume(src, amt)
    }

    fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt
            .block_on(AsyncBufReadExt::read_until(src, byte, buf))
    }

    fn read_line(&mut self, buf: &mut String) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(AsyncBufReadExt::read_line(src, buf))
    }
}

impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(AsyncReadExt::read(src, buf))
    }

    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.read_to_end(buf))
    }

    fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.read_to_string(buf))
    }

    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        let src = &mut self.src;
        // The AsyncRead trait returns the count, synchronous doesn't.
        let _n = self.rt.block_on(src.read_exact(buf))?;
        Ok(())
    }
}

impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.write(buf))
    }

    fn flush(&mut self) -> std::io::Result<()> {
        let src = &mut self.src;
        self.rt.block_on(src.flush())
    }

    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let src = &mut self.src;
        self.rt.block_on(src.write_all(buf))
    }

    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.write_vectored(bufs))
    }
}

impl<T: AsyncSeek + Unpin> Seek for SyncIoBridge<T> {
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let src = &mut self.src;
        self.rt.block_on(AsyncSeekExt::seek(src, pos))
    }
}

// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time
// of this writing still unstable, we expose this as part of a standalone method.
impl<T: AsyncWrite> SyncIoBridge<T> {
    /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes.
    ///
    /// See [`tokio::io::AsyncWrite::is_write_vectored`].
    pub fn is_write_vectored(&self) -> bool {
        self.src.is_write_vectored()
    }
}

impl<T: AsyncWrite + Unpin> SyncIoBridge<T> {
    /// Shutdown this writer. This method provides a way to call the [`AsyncWriteExt::shutdown`]
    /// function of the inner [`tokio::io::AsyncWrite`] instance.
    ///
    /// # Errors
    ///
    /// This method returns the same errors as [`AsyncWriteExt::shutdown`].
    ///
    /// [`AsyncWriteExt::shutdown`]: tokio::io::AsyncWriteExt::shutdown
    pub fn shutdown(&mut self) -> std::io::Result<()> {
        let src = &mut self.src;
        self.rt.block_on(src.shutdown())
    }
}

impl<T: Unpin> SyncIoBridge<T> {
    /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
    /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
    ///
    /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`].
    /// It is hence OK to move this struct into a separate thread outside the runtime, as created
    /// by e.g. [`tokio::task::spawn_blocking`].
    ///
    /// Stated even more strongly: to make use of this bridge, you *must* move
    /// it into a separate thread outside the runtime. The synchronous I/O will use the
    /// underlying handle to block on the backing asynchronous source, via
    /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that
    /// function, an attempt to `block_on` from an asynchronous execution context
    /// will panic.
    ///
    /// # Wrapping `!Unpin` types
    ///
    /// Use e.g. `SyncIoBridge::new(Box::pin(src))`.
    ///
    /// # Panics
    ///
    /// This will panic if called outside the context of a Tokio runtime.
    #[track_caller]
    pub fn new(src: T) -> Self {
        Self::new_with_handle(src, tokio::runtime::Handle::current())
    }

    /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
    /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
    ///
    /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may
    /// be initially invoked outside of an asynchronous context.
    pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self {
        Self { src, rt }
    }

    /// Consume this bridge, returning the underlying stream.
    pub fn into_inner(self) -> T {
        self.src
    }
}

// tokio-util-0.7.10/src/lib.rs
#![allow(clippy::needless_doctest_main)]
#![warn(
    missing_debug_implementations,
    missing_docs,
    rust_2018_idioms,
    unreachable_pub
)]
#![doc(test(
    no_crate_inject,
    attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
))]
#![cfg_attr(docsrs, feature(doc_cfg))]

//! Utilities for working with Tokio.
//!
//! This crate is not versioned in lockstep with the core
//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's
//! semantic versioning policy, especially with regard to breaking changes.
//!
//! [`tokio`]: https://docs.rs/tokio

#[macro_use]
mod cfg;

mod loom;

cfg_codec! {
    pub mod codec;
}

cfg_net! {
    #[cfg(not(target_arch = "wasm32"))]
    pub mod udp;
    pub mod net;
}

cfg_compat! {
    pub mod compat;
}

cfg_io! {
    pub mod io;
}

cfg_rt! {
    pub mod context;
    pub mod task;
}

cfg_time! {
    pub mod time;
}

pub mod sync;

pub mod either;

pub use bytes;

mod util;

// tokio-util-0.7.10/src/loom.rs
pub(crate) use std::sync;

// tokio-util-0.7.10/src/net/mod.rs
//! TCP/UDP/Unix helpers for tokio.

use crate::either::Either;
use std::future::Future;
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};

#[cfg(unix)]
pub mod unix;

/// A trait for a listener: `TcpListener` and `UnixListener`.
pub trait Listener {
    /// The stream's type of this listener.
    type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite;
    /// The socket address type of this listener.
    type Addr;

    /// Polls to accept a new incoming connection to this listener.
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>;

    /// Accepts a new incoming connection from this listener.
    fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
    where
        Self: Sized,
    {
        ListenerAcceptFut { listener: self }
    }

    /// Returns the local address that this listener is bound to.
    fn local_addr(&self) -> Result<Self::Addr>;
}

impl Listener for tokio::net::TcpListener {
    type Io = tokio::net::TcpStream;
    type Addr = std::net::SocketAddr;

    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
        Self::poll_accept(self, cx)
    }

    fn local_addr(&self) -> Result<Self::Addr> {
        self.local_addr().map(Into::into)
    }
}

/// Future for accepting a new connection from a listener.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ListenerAcceptFut<'a, L> {
    listener: &'a mut L,
}

impl<'a, L> Future for ListenerAcceptFut<'a, L>
where
    L: Listener,
{
    type Output = Result<(L::Io, L::Addr)>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.listener.poll_accept(cx)
    }
}

impl<L, R> Either<L, R>
where
    L: Listener,
    R: Listener,
{
    /// Accepts a new incoming connection from this listener.
    pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> {
        match self {
            Either::Left(listener) => {
                let (stream, addr) = listener.accept().await?;
                Ok(Either::Left((stream, addr)))
            }
            Either::Right(listener) => {
                let (stream, addr) = listener.accept().await?;
                Ok(Either::Right((stream, addr)))
            }
        }
    }

    /// Returns the local address that this listener is bound to.
    pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> {
        match self {
            Either::Left(listener) => {
                let addr = listener.local_addr()?;
                Ok(Either::Left(addr))
            }
            Either::Right(listener) => {
                let addr = listener.local_addr()?;
                Ok(Either::Right(addr))
            }
        }
    }
}

// tokio-util-0.7.10/src/net/unix/mod.rs
//! Unix domain socket helpers.

use super::Listener;
use std::io::Result;
use std::task::{Context, Poll};

impl Listener for tokio::net::UnixListener {
    type Io = tokio::net::UnixStream;
    type Addr = tokio::net::unix::SocketAddr;

    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
        Self::poll_accept(self, cx)
    }

    fn local_addr(&self) -> Result<Self::Addr> {
        self.local_addr().map(Into::into)
    }
}

// tokio-util-0.7.10/src/sync/cancellation_token/guard.rs
use crate::sync::CancellationToken;

/// A wrapper for cancellation token which automatically cancels
/// it on drop. It is created using `drop_guard` method on the `CancellationToken`.
#[derive(Debug)]
pub struct DropGuard {
    pub(super) inner: Option<CancellationToken>,
}

impl DropGuard {
    /// Returns stored cancellation token and removes this drop guard instance
    /// (i.e. it will no longer cancel token). Other guards for this token
    /// are not affected.
    pub fn disarm(mut self) -> CancellationToken {
        self.inner
            .take()
            .expect("`inner` can be only None in a destructor")
    }
}

impl Drop for DropGuard {
    fn drop(&mut self) {
        if let Some(inner) = &self.inner {
            inner.cancel();
        }
    }
}

// tokio-util-0.7.10/src/sync/cancellation_token/tree_node.rs
//! This mod provides the logic for the inner tree structure of the CancellationToken.
//!
//! CancellationTokens are only light handles with references to [`TreeNode`].
//! All the logic is actually implemented in the [`TreeNode`].
//!
//! A [`TreeNode`] is part of the cancellation tree and may have one parent and an arbitrary number of
//! children.
//!
//! A [`TreeNode`] can receive the request to perform a cancellation through a CancellationToken.
//! This cancellation request will cancel the node and all of its descendants.
//!
//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no
//! more CancellationTokens pointing to it any more), it gets removed from the tree, to keep the
//! tree as small as possible.
//!
//! # Invariants
//!
//! Those invariants shall be true at any time.
//!
//! 1. A node that has no parents and no handles can no longer be cancelled.
//!    This is important during both cancellation and refcounting.
//!
//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A.
//!    This is important for deadlock safety, as it is used for lock order.
//!    Node B can only become the child of node A in two ways:
//!      - being created with `child_node()`, in which case it is trivially true that
//!        node A already existed when node B was created
- being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()` //! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C //! was younger than A, therefore B is also younger than A. //! //! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of //! node A. It is important to always restore that invariant before dropping the lock of a node. //! //! # Deadlock safety //! //! We always lock in the order of creation time. We can prove this through invariant #2. //! Specifically, through invariant #2, we know that we always have to lock a parent //! before its child. //! use crate::loom::sync::{Arc, Mutex, MutexGuard}; /// A node of the cancellation tree structure /// /// The actual data it holds is wrapped inside a mutex for synchronization. pub(crate) struct TreeNode { inner: Mutex, waker: tokio::sync::Notify, } impl TreeNode { pub(crate) fn new() -> Self { Self { inner: Mutex::new(Inner { parent: None, parent_idx: 0, children: vec![], is_cancelled: false, num_handles: 1, }), waker: tokio::sync::Notify::new(), } } pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> { self.waker.notified() } } /// The data contained inside a TreeNode. /// /// This struct exists so that the data of the node can be wrapped /// in a Mutex. struct Inner { parent: Option>, parent_idx: usize, children: Vec>, is_cancelled: bool, num_handles: usize, } /// Returns whether or not the node is cancelled pub(crate) fn is_cancelled(node: &Arc) -> bool { node.inner.lock().unwrap().is_cancelled } /// Creates a child node pub(crate) fn child_node(parent: &Arc) -> Arc { let mut locked_parent = parent.inner.lock().unwrap(); // Do not register as child if we are already cancelled. // Cancelled trees can never be uncancelled and therefore // need no connection to parents or children any more. 
if locked_parent.is_cancelled { return Arc::new(TreeNode { inner: Mutex::new(Inner { parent: None, parent_idx: 0, children: vec![], is_cancelled: true, num_handles: 1, }), waker: tokio::sync::Notify::new(), }); } let child = Arc::new(TreeNode { inner: Mutex::new(Inner { parent: Some(parent.clone()), parent_idx: locked_parent.children.len(), children: vec![], is_cancelled: false, num_handles: 1, }), waker: tokio::sync::Notify::new(), }); locked_parent.children.push(child.clone()); child } /// Disconnects the given parent from all of its children. /// /// Takes a reference to [Inner] to make sure the parent is already locked. fn disconnect_children(node: &mut Inner) { for child in std::mem::take(&mut node.children) { let mut locked_child = child.inner.lock().unwrap(); locked_child.parent_idx = 0; locked_child.parent = None; } } /// Figures out the parent of the node and locks the node and its parent atomically. /// /// The basic principle of preventing deadlocks in the tree is /// that we always lock the parent first, and then the child. /// For more info look at *deadlock safety* and *invariant #2*. /// /// Sadly, it's impossible to figure out the parent of a node without /// locking it. To then achieve locking order consistency, the node /// has to be unlocked before the parent gets locked. /// This leaves a small window where we already assume that we know the parent, /// but neither the parent nor the node is locked. Therefore, the parent could change. /// /// To prevent that this problem leaks into the rest of the code, it is abstracted /// in this function. /// /// The locked child and optionally its locked parent, if a parent exists, get passed /// to the `func` argument via (node, None) or (node, Some(parent)). 
fn with_locked_node_and_parent(node: &Arc, func: F) -> Ret where F: FnOnce(MutexGuard<'_, Inner>, Option>) -> Ret, { use std::sync::TryLockError; let mut locked_node = node.inner.lock().unwrap(); // Every time this fails, the number of ancestors of the node decreases, // so the loop must succeed after a finite number of iterations. loop { // Look up the parent of the currently locked node. let potential_parent = match locked_node.parent.as_ref() { Some(potential_parent) => potential_parent.clone(), None => return func(locked_node, None), }; // Lock the parent. This may require unlocking the child first. let locked_parent = match potential_parent.inner.try_lock() { Ok(locked_parent) => locked_parent, Err(TryLockError::WouldBlock) => { drop(locked_node); // Deadlock safety: // // Due to invariant #2, the potential parent must come before // the child in the creation order. Therefore, we can safely // lock the child while holding the parent lock. let locked_parent = potential_parent.inner.lock().unwrap(); locked_node = node.inner.lock().unwrap(); locked_parent } Err(TryLockError::Poisoned(err)) => Err(err).unwrap(), }; // If we unlocked the child, then the parent may have changed. Check // that we still have the right parent. if let Some(actual_parent) = locked_node.parent.as_ref() { if Arc::ptr_eq(actual_parent, &potential_parent) { return func(locked_node, Some(locked_parent)); } } } } /// Moves all children from `node` to `parent`. /// /// `parent` MUST have been a parent of the node when they both got locked, /// otherwise there is a potential for a deadlock as invariant #2 would be violated. /// /// To acquire the locks for node and parent, use [with_locked_node_and_parent]. 
fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) { // Pre-allocate in the parent, for performance parent.children.reserve(node.children.len()); for child in std::mem::take(&mut node.children) { { let mut child_locked = child.inner.lock().unwrap(); child_locked.parent = node.parent.clone(); child_locked.parent_idx = parent.children.len(); } parent.children.push(child); } } /// Removes a child from the parent. /// /// `parent` MUST be the parent of `node`. /// To acquire the locks for node and parent, use [with_locked_node_and_parent]. fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { // Query the position from where to remove a node let pos = node.parent_idx; node.parent = None; node.parent_idx = 0; // Unlock node, so that only one child at a time is locked. // Otherwise we would violate the lock order (see 'deadlock safety') as we // don't know the creation order of the child nodes drop(node); // If `node` is the last element in the list, we don't need any swapping if parent.children.len() == pos + 1 { parent.children.pop().unwrap(); } else { // If `node` is not the last element in the list, we need to // replace it with the last element let replacement_child = parent.children.pop().unwrap(); replacement_child.inner.lock().unwrap().parent_idx = pos; parent.children[pos] = replacement_child; } let len = parent.children.len(); if 4 * len <= parent.children.capacity() { parent.children.shrink_to(2 * len); } } /// Increases the reference count of handles. pub(crate) fn increase_handle_refcount(node: &Arc) { let mut locked_node = node.inner.lock().unwrap(); // Once no handles are left over, the node gets detached from the tree. // There should never be a new handle once all handles are dropped. assert!(locked_node.num_handles > 0); locked_node.num_handles += 1; } /// Decreases the reference count of handles. /// /// Once no handle is left, we can remove the node from the /// tree and connect its parent directly to its children. 
pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) { let num_handles = { let mut locked_node = node.inner.lock().unwrap(); locked_node.num_handles -= 1; locked_node.num_handles }; if num_handles == 0 { with_locked_node_and_parent(node, |mut node, parent| { // Remove the node from the tree match parent { Some(mut parent) => { // As we want to remove ourselves from the tree, // we have to move the children to the parent, so that // they still receive the cancellation event without us. // Moving them does not violate invariant #1. move_children_to_parent(&mut node, &mut parent); // Remove the node from the parent remove_child(&mut parent, node); } None => { // Due to invariant #1, we can assume that our // children can no longer be cancelled through us. // (as we now have neither a parent nor handles) // Therefore we can disconnect them. disconnect_children(&mut node); } } }); } } /// Cancels a node and its children. pub(crate) fn cancel(node: &Arc<TreeNode>) { let mut locked_node = node.inner.lock().unwrap(); if locked_node.is_cancelled { return; } // One by one, adopt grandchildren and then cancel and detach the child while let Some(child) = locked_node.children.pop() { // This can't deadlock because the mutex we are already // holding is the parent of child. let mut locked_child = child.inner.lock().unwrap(); // Detach the child from node // No need to modify node.children, as the child already got removed with `.pop` locked_child.parent = None; locked_child.parent_idx = 0; // If child is already cancelled, detaching is enough if locked_child.is_cancelled { continue; } // Cancel or adopt grandchildren while let Some(grandchild) = locked_child.children.pop() { // This can't deadlock because the two mutexes we are already // holding are the parent and grandparent of grandchild. 
let mut locked_grandchild = grandchild.inner.lock().unwrap(); // Detach the grandchild locked_grandchild.parent = None; locked_grandchild.parent_idx = 0; // If grandchild is already cancelled, detaching is enough if locked_grandchild.is_cancelled { continue; } // For performance reasons, only adopt grandchildren that have children. // Otherwise, just cancel them right away, no need for another iteration. if locked_grandchild.children.is_empty() { // Cancel the grandchild locked_grandchild.is_cancelled = true; locked_grandchild.children = Vec::new(); drop(locked_grandchild); grandchild.waker.notify_waiters(); } else { // Otherwise, adopt grandchild locked_grandchild.parent = Some(node.clone()); locked_grandchild.parent_idx = locked_node.children.len(); drop(locked_grandchild); locked_node.children.push(grandchild); } } // Cancel the child locked_child.is_cancelled = true; locked_child.children = Vec::new(); drop(locked_child); child.waker.notify_waiters(); // Now the child is cancelled and detached and all its children are adopted. // Just continue until all (including adopted) children are cancelled and detached. } // Cancel the node itself. locked_node.is_cancelled = true; locked_node.children = Vec::new(); drop(locked_node); node.waker.notify_waiters(); } tokio-util-0.7.10/src/sync/cancellation_token.rs000064400000000000000000000271311046102023000177620ustar 00000000000000//! An asynchronously awaitable `CancellationToken`. //! The token allows to signal a cancellation request to one or more tasks. pub(crate) mod guard; mod tree_node; use crate::loom::sync::Arc; use crate::util::MaybeDangling; use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; use guard::DropGuard; use pin_project_lite::pin_project; /// A token which can be used to signal a cancellation request to one or more /// tasks. /// /// Tasks can call [`CancellationToken::cancelled()`] in order to /// obtain a Future which will be resolved when cancellation is requested. 
/// /// Cancellation can be requested through the [`CancellationToken::cancel`] method. /// /// # Examples /// /// ```no_run /// use tokio::select; /// use tokio_util::sync::CancellationToken; /// /// #[tokio::main] /// async fn main() { /// let token = CancellationToken::new(); /// let cloned_token = token.clone(); /// /// let join_handle = tokio::spawn(async move { /// // Wait for either cancellation or a very long time /// select! { /// _ = cloned_token.cancelled() => { /// // The token was cancelled /// 5 /// } /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { /// 99 /// } /// } /// }); /// /// tokio::spawn(async move { /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; /// token.cancel(); /// }); /// /// assert_eq!(5, join_handle.await.unwrap()); /// } /// ``` pub struct CancellationToken { inner: Arc, } impl std::panic::UnwindSafe for CancellationToken {} impl std::panic::RefUnwindSafe for CancellationToken {} pin_project! { /// A Future that is resolved once the corresponding [`CancellationToken`] /// is cancelled. #[must_use = "futures do nothing unless polled"] pub struct WaitForCancellationFuture<'a> { cancellation_token: &'a CancellationToken, #[pin] future: tokio::sync::futures::Notified<'a>, } } pin_project! { /// A Future that is resolved once the corresponding [`CancellationToken`] /// is cancelled. /// /// This is the counterpart to [`WaitForCancellationFuture`] that takes /// [`CancellationToken`] by value instead of using a reference. #[must_use = "futures do nothing unless polled"] pub struct WaitForCancellationFutureOwned { // This field internally has a reference to the cancellation token, but camouflages // the relationship with `'static`. To avoid Undefined Behavior, we must ensure // that the reference is only used while the cancellation token is still alive. To // do that, we ensure that the future is the first field, so that it is dropped // before the cancellation token. 
// // We use `MaybeDanglingFuture` here because without it, the compiler could assert // the reference inside `future` to be valid even after the destructor of that // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed // as an argument to a function, the reference can be asserted to be valid for the // rest of that function.) To avoid that, we use `MaybeDangling` which tells the // compiler that the reference stored inside it might not be valid. // // See // for more info. #[pin] future: MaybeDangling>, cancellation_token: CancellationToken, } } // ===== impl CancellationToken ===== impl core::fmt::Debug for CancellationToken { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("CancellationToken") .field("is_cancelled", &self.is_cancelled()) .finish() } } impl Clone for CancellationToken { /// Creates a clone of the `CancellationToken` which will get cancelled /// whenever the current token gets cancelled, and vice versa. fn clone(&self) -> Self { tree_node::increase_handle_refcount(&self.inner); CancellationToken { inner: self.inner.clone(), } } } impl Drop for CancellationToken { fn drop(&mut self) { tree_node::decrease_handle_refcount(&self.inner); } } impl Default for CancellationToken { fn default() -> CancellationToken { CancellationToken::new() } } impl CancellationToken { /// Creates a new `CancellationToken` in the non-cancelled state. pub fn new() -> CancellationToken { CancellationToken { inner: Arc::new(tree_node::TreeNode::new()), } } /// Creates a `CancellationToken` which will get cancelled whenever the /// current token gets cancelled. Unlike a cloned `CancellationToken`, /// cancelling a child token does not cancel the parent token. /// /// If the current token is already cancelled, the child token will get /// returned in cancelled state. 
/// /// # Examples /// /// ```no_run /// use tokio::select; /// use tokio_util::sync::CancellationToken; /// /// #[tokio::main] /// async fn main() { /// let token = CancellationToken::new(); /// let child_token = token.child_token(); /// /// let join_handle = tokio::spawn(async move { /// // Wait for either cancellation or a very long time /// select! { /// _ = child_token.cancelled() => { /// // The token was cancelled /// 5 /// } /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { /// 99 /// } /// } /// }); /// /// tokio::spawn(async move { /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; /// token.cancel(); /// }); /// /// assert_eq!(5, join_handle.await.unwrap()); /// } /// ``` pub fn child_token(&self) -> CancellationToken { CancellationToken { inner: tree_node::child_node(&self.inner), } } /// Cancel the [`CancellationToken`] and all child tokens which had been /// derived from it. /// /// This will wake up all tasks which are waiting for cancellation. /// /// Be aware that cancellation is not an atomic operation. It is possible /// for another thread running in parallel with a call to `cancel` to first /// receive `true` from `is_cancelled` on one child node, and then receive /// `false` from `is_cancelled` on another child node. However, once the /// call to `cancel` returns, all child nodes have been fully cancelled. pub fn cancel(&self) { tree_node::cancel(&self.inner); } /// Returns `true` if the `CancellationToken` is cancelled. pub fn is_cancelled(&self) -> bool { tree_node::is_cancelled(&self.inner) } /// Returns a `Future` that gets fulfilled when cancellation is requested. /// /// The future will complete immediately if the token is already cancelled /// when this method is called. /// /// # Cancel safety /// /// This method is cancel safe. 
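///
/// # Examples
///
/// A minimal sketch of the "already cancelled" case described above. It
/// assumes the same `#[tokio::main]` setup as the other examples in this
/// file and uses no API beyond `cancel`, `cancelled` and `is_cancelled`.
///
/// ```
/// use tokio_util::sync::CancellationToken;
///
/// #[tokio::main]
/// async fn main() {
///     let token = CancellationToken::new();
///     token.cancel();
///
///     // The token is already cancelled, so this resolves without waiting.
///     token.cancelled().await;
///     assert!(token.is_cancelled());
/// }
/// ```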
pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { WaitForCancellationFuture { cancellation_token: self, future: self.inner.notified(), } } /// Returns a `Future` that gets fulfilled when cancellation is requested. /// /// The future will complete immediately if the token is already cancelled /// when this method is called. /// /// The function takes self by value and returns a future that owns the /// token. /// /// # Cancel safety /// /// This method is cancel safe. pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned { WaitForCancellationFutureOwned::new(self) } /// Creates a `DropGuard` for this token. /// /// Returned guard will cancel this token (and all its children) on drop /// unless disarmed. pub fn drop_guard(self) -> DropGuard { DropGuard { inner: Some(self) } } } // ===== impl WaitForCancellationFuture ===== impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("WaitForCancellationFuture").finish() } } impl<'a> Future for WaitForCancellationFuture<'a> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { let mut this = self.project(); loop { if this.cancellation_token.is_cancelled() { return Poll::Ready(()); } // No wakeups can be lost here because there is always a call to // `is_cancelled` between the creation of the future and the call to // `poll`, and the code that sets the cancelled flag does so before // waking the `Notified`. 
if this.future.as_mut().poll(cx).is_pending() { return Poll::Pending; } this.future.set(this.cancellation_token.inner.notified()); } } } // ===== impl WaitForCancellationFutureOwned ===== impl core::fmt::Debug for WaitForCancellationFutureOwned { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("WaitForCancellationFutureOwned").finish() } } impl WaitForCancellationFutureOwned { fn new(cancellation_token: CancellationToken) -> Self { WaitForCancellationFutureOwned { // cancellation_token holds a heap allocation and is guaranteed to have a // stable deref, thus it would be ok to move the cancellation_token while // the future holds a reference to it. // // # Safety // // cancellation_token is dropped after future due to the field ordering. future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }), cancellation_token, } } /// # Safety /// The returned future must be destroyed before the cancellation token is /// destroyed. unsafe fn new_future( cancellation_token: &CancellationToken, ) -> tokio::sync::futures::Notified<'static> { let inner_ptr = Arc::as_ptr(&cancellation_token.inner); // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains // valid until the strong count of the Arc drops to zero, and the caller // guarantees that they will drop the future before that happens. (*inner_ptr).notified() } } impl Future for WaitForCancellationFutureOwned { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { let mut this = self.project(); loop { if this.cancellation_token.is_cancelled() { return Poll::Ready(()); } // No wakeups can be lost here because there is always a call to // `is_cancelled` between the creation of the future and the call to // `poll`, and the code that sets the cancelled flag does so before // waking the `Notified`. 
if this.future.as_mut().poll(cx).is_pending() { return Poll::Pending; } // # Safety // // cancellation_token is dropped after future due to the field ordering. this.future.set(MaybeDangling::new(unsafe { Self::new_future(this.cancellation_token) })); } } } tokio-util-0.7.10/src/sync/mod.rs000064400000000000000000000005531046102023000147040ustar 00000000000000//! Synchronization primitives mod cancellation_token; pub use cancellation_token::{ guard::DropGuard, CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned, }; mod mpsc; pub use mpsc::{PollSendError, PollSender}; mod poll_semaphore; pub use poll_semaphore::PollSemaphore; mod reusable_box; pub use reusable_box::ReusableBoxFuture; tokio-util-0.7.10/src/sync/mpsc.rs000064400000000000000000000305121046102023000150650ustar 00000000000000use futures_sink::Sink; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, mem}; use tokio::sync::mpsc::OwnedPermit; use tokio::sync::mpsc::Sender; use super::ReusableBoxFuture; /// Error returned by the `PollSender` when the channel is closed. #[derive(Debug)] pub struct PollSendError(Option); impl PollSendError { /// Consumes the stored value, if any. /// /// If this error was encountered when calling `start_send`/`send_item`, this will be the item /// that the caller attempted to send. Otherwise, it will be `None`. pub fn into_inner(self) -> Option { self.0 } } impl fmt::Display for PollSendError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "channel closed") } } impl std::error::Error for PollSendError {} #[derive(Debug)] enum State { Idle(Sender), Acquiring, ReadyToSend(OwnedPermit), Closed, } /// A wrapper around [`mpsc::Sender`] that can be polled. /// /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender #[derive(Debug)] pub struct PollSender { sender: Option>, state: State, acquire: PollSenderFuture, } // Creates a future for acquiring a permit from the underlying channel. 
This is used to ensure // there's capacity for a send to complete. // // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to // ReusableBoxFuture has the same underlying type, and hence the same size and alignment. async fn make_acquire_future<T>( data: Option<Sender<T>>, ) -> Result<OwnedPermit<T>, PollSendError<T>> { match data { Some(sender) => sender .reserve_owned() .await .map_err(|_| PollSendError(None)), None => unreachable!("this future should not be pollable in this state"), } } type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>; #[derive(Debug)] // TODO: This should be replaced with a type_alias_impl_trait to eliminate `'static` and all the transmutes struct PollSenderFuture<T>(InnerFuture<'static, T>); impl<T> PollSenderFuture<T> { /// Create with an empty inner future with no `Send` bound. fn empty() -> Self { // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not // compatible with the transitive bounds required by `Sender`. Self(ReusableBoxFuture::new(async { unreachable!() })) } } impl<T: Send> PollSenderFuture<T> { /// Create with an empty inner future. fn new() -> Self { let v = InnerFuture::new(make_acquire_future(None)); // This is safe because `make_acquire_future(None)` is actually `'static` Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) }) } /// Poll the inner future. fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> { self.0.poll(cx) } /// Replace the inner future. fn set(&mut self, sender: Option<Sender<T>>) { let inner: *mut InnerFuture<'static, T> = &mut self.0; let inner: *mut InnerFuture<'_, T> = inner.cast(); // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T` // becomes invalid, and this casts away the type-level lifetime check for that. However, the // inner future is never moved out of this `PollSenderFuture`, so the future will not // live longer than the `PollSenderFuture` lives. 
A `PollSenderFuture` is guaranteed // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so // this is ok. let inner = unsafe { &mut *inner }; inner.set(make_acquire_future(sender)); } } impl PollSender { /// Creates a new `PollSender`. pub fn new(sender: Sender) -> Self { Self { sender: Some(sender.clone()), state: State::Idle(sender), acquire: PollSenderFuture::new(), } } fn take_state(&mut self) -> State { mem::replace(&mut self.state, State::Closed) } /// Attempts to prepare the sender to receive a value. /// /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to /// `send_item`. /// /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, /// by reserving a slot in the channel for the item to be sent. If this method returns /// `Poll::Pending`, the current task is registered to be notified (via /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. /// /// # Errors /// /// If the channel is closed, an error will be returned. This is a permanent state. pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll>> { loop { let (result, next_state) = match self.take_state() { State::Idle(sender) => { // Start trying to acquire a permit to reserve a slot for our send, and // immediately loop back around to poll it the first time. self.acquire.set(Some(sender)); (None, State::Acquiring) } State::Acquiring => match self.acquire.poll(cx) { // Channel has capacity. Poll::Ready(Ok(permit)) => { (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) } // Channel is closed. Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), // Channel doesn't have capacity yet, so we need to wait. Poll::Pending => (Some(Poll::Pending), State::Acquiring), }, // We're closed, either by choice or because the underlying sender was closed. s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), // We're already ready to send an item. 
s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), }; self.state = next_state; if let Some(result) = result { return result; } } } /// Sends an item to the channel. /// /// Before calling `send_item`, `poll_reserve` must be called with a successful return /// value of `Poll::Ready(Ok(()))`. /// /// # Errors /// /// If the channel is closed, an error will be returned. This is a permanent state. /// /// # Panics /// /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method /// will panic. #[track_caller] pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { let (result, next_state) = match self.take_state() { State::Idle(_) | State::Acquiring => { panic!("`send_item` called without first calling `poll_reserve`") } // We have a permit to send our item, so go ahead, which gets us our sender back. State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), // We're closed, either by choice or because the underlying sender was closed. State::Closed => (Err(PollSendError(Some(value))), State::Closed), }; // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. self.state = if self.sender.is_some() { next_state } else { State::Closed }; result } /// Checks whether this sender has been closed. /// /// The underlying channel that this sender was wrapping may still be open. pub fn is_closed(&self) -> bool { matches!(self.state, State::Closed) || self.sender.is_none() } /// Gets a reference to the `Sender` of the underlying channel. /// /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender /// was wrapping may still be open. pub fn get_ref(&self) -> Option<&Sender<T>> { self.sender.as_ref() } /// Closes this sender. /// /// No more messages will be able to be sent from this sender, but the underlying channel will /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. 
/// /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made /// to `send_item` in order to consume the reserved slot. After that, no further sends will be /// possible. If you do not intend to send another item, you can release the reserved slot back /// to the underlying sender by calling [`abort_send`]. /// /// [`abort_send`]: crate::sync::PollSender::abort_send /// [`Receiver`]: tokio::sync::mpsc::Receiver pub fn close(&mut self) { // Mark ourselves officially closed by dropping our main sender. self.sender = None; // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly // transition to the closed state. Otherwise, leave the existing permit in place for the // caller if they want to complete the send. match self.state { State::Idle(_) => self.state = State::Closed, State::Acquiring => { self.acquire.set(None); self.state = State::Closed; } _ => {} } } /// Aborts the current in-progress send, if any. /// /// Returns `true` if a send was aborted. If the sender was closed prior to calling /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be /// ready to attempt another send. pub fn abort_send(&mut self) -> bool { // We may have been closed in the meantime, after a call to `poll_reserve` already // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the // closed state when we actually abort a send, rather than resetting ourselves back to idle. let (result, next_state) = match self.take_state() { // We're currently trying to reserve a slot to send into. State::Acquiring => { // Replacing the future drops the in-flight one. self.acquire.set(None); // If we haven't closed yet, we have to clone our stored sender since we have no way // to get it back from the acquire future we just dropped. let state = match self.sender.clone() { Some(sender) => State::Idle(sender), None => State::Closed, }; (true, state) } // We got the permit. 
If we haven't closed yet, get the sender back. State::ReadyToSend(permit) => { let state = if self.sender.is_some() { State::Idle(permit.release()) } else { State::Closed }; (true, state) } s => (false, s), }; self.state = next_state; result } } impl Clone for PollSender { /// Clones this `PollSender`. /// /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. fn clone(&self) -> PollSender { let (sender, state) = match self.sender.clone() { Some(sender) => (Some(sender.clone()), State::Idle(sender)), None => (None, State::Closed), }; Self { sender, state, acquire: PollSenderFuture::empty(), } } } impl Sink for PollSender { type Error = PollSendError; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).poll_reserve(cx) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { Pin::into_inner(self).send_item(item) } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).close(); Poll::Ready(Ok(())) } } tokio-util-0.7.10/src/sync/poll_semaphore.rs000064400000000000000000000134731046102023000171430ustar 00000000000000use futures_core::{ready, Stream}; use std::fmt; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use super::ReusableBoxFuture; /// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. /// /// [`Semaphore`]: tokio::sync::Semaphore pub struct PollSemaphore { semaphore: Arc, permit_fut: Option<( u32, // The number of permits requested. ReusableBoxFuture<'static, Result>, )>, } impl PollSemaphore { /// Create a new `PollSemaphore`. pub fn new(semaphore: Arc) -> Self { Self { semaphore, permit_fut: None, } } /// Closes the semaphore. 
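///
/// # Examples
///
/// A minimal sketch, assuming `tokio`'s `Semaphore::try_acquire_owned`
/// is used to observe the result: once the semaphore is closed, it hands
/// out no further permits.
///
/// ```
/// use std::sync::Arc;
/// use tokio::sync::{Semaphore, TryAcquireError};
/// use tokio_util::sync::PollSemaphore;
///
/// let semaphore = PollSemaphore::new(Arc::new(Semaphore::new(1)));
/// semaphore.close();
///
/// // Acquiring from the closed semaphore now fails with `Closed`.
/// let result = semaphore.clone_inner().try_acquire_owned();
/// assert!(matches!(result, Err(TryAcquireError::Closed)));
/// ```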
pub fn close(&self) { self.semaphore.close(); } /// Obtain a clone of the inner semaphore. pub fn clone_inner(&self) -> Arc { self.semaphore.clone() } /// Get back the inner semaphore. pub fn into_inner(self) -> Arc { self.semaphore } /// Poll to acquire a permit from the semaphore. /// /// This can return the following values: /// /// - `Poll::Pending` if a permit is not currently available. /// - `Poll::Ready(Some(permit))` if a permit was acquired. /// - `Poll::Ready(None)` if the semaphore has been closed. /// /// When this method returns `Poll::Pending`, the current task is scheduled /// to receive a wakeup when a permit becomes available, or when the /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only /// the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { self.poll_acquire_many(cx, 1) } /// Poll to acquire many permits from the semaphore. /// /// This can return the following values: /// /// - `Poll::Pending` if a permit is not currently available. /// - `Poll::Ready(Some(permit))` if a permit was acquired. /// - `Poll::Ready(None)` if the semaphore has been closed. /// /// When this method returns `Poll::Pending`, the current task is scheduled /// to receive a wakeup when the permits become available, or when the /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only /// the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. pub fn poll_acquire_many( &mut self, cx: &mut Context<'_>, permits: u32, ) -> Poll> { let permit_future = match self.permit_fut.as_mut() { Some((prev_permits, fut)) if *prev_permits == permits => fut, Some((old_permits, fut_box)) => { // We're requesting a different number of permits, so replace the future // and record the new amount. 
let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); fut_box.set(fut); *old_permits = permits; fut_box } None => { // avoid allocations completely if we can grab a permit immediately match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) { Ok(permit) => return Poll::Ready(Some(permit)), Err(TryAcquireError::Closed) => return Poll::Ready(None), Err(TryAcquireError::NoPermits) => {} } let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); &mut self .permit_fut .get_or_insert((permits, ReusableBoxFuture::new(next_fut))) .1 } }; let result = ready!(permit_future.poll(cx)); // Assume we'll request the same amount of permits in a subsequent call. let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); permit_future.set(next_fut); match result { Ok(permit) => Poll::Ready(Some(permit)), Err(_closed) => { self.permit_fut = None; Poll::Ready(None) } } } /// Returns the current number of available permits. /// /// This is equivalent to the [`Semaphore::available_permits`] method on the /// `tokio::sync::Semaphore` type. /// /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits pub fn available_permits(&self) -> usize { self.semaphore.available_permits() } /// Adds `n` new permits to the semaphore. /// /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function /// will panic if the limit is exceeded. /// /// This is equivalent to the [`Semaphore::add_permits`] method on the /// `tokio::sync::Semaphore` type. 
/// /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits pub fn add_permits(&self, n: usize) { self.semaphore.add_permits(n); } } impl Stream for PollSemaphore { type Item = OwnedSemaphorePermit; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { Pin::into_inner(self).poll_acquire(cx) } } impl Clone for PollSemaphore { fn clone(&self) -> PollSemaphore { PollSemaphore::new(self.clone_inner()) } } impl fmt::Debug for PollSemaphore { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PollSemaphore") .field("semaphore", &self.semaphore) .finish() } } impl AsRef<Semaphore> for PollSemaphore { fn as_ref(&self) -> &Semaphore { &self.semaphore } } tokio-util-0.7.10/src/sync/reusable_box.rs000064400000000000000000000117051046102023000166000ustar 00000000000000use std::alloc::Layout; use std::fmt; use std::future::{self, Future}; use std::mem::{self, ManuallyDrop}; use std::pin::Pin; use std::ptr; use std::task::{Context, Poll}; /// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`. /// /// This type lets you replace the future stored in the box without /// reallocating when the size and alignment permit this. pub struct ReusableBoxFuture<'a, T> { boxed: Pin<Box<dyn Future<Output = T> + Send + 'a>>, } impl<'a, T> ReusableBoxFuture<'a, T> { /// Create a new `ReusableBoxFuture` containing the provided future. pub fn new<F>(future: F) -> Self where F: Future<Output = T> + Send + 'a, { Self { boxed: Box::pin(future), } } /// Replace the future currently stored in this box. /// /// This reallocates if and only if the layout of the provided future is /// different from the layout of the currently stored future. pub fn set<F>(&mut self, future: F) where F: Future<Output = T> + Send + 'a, { if let Err(future) = self.try_set(future) { *self = Self::new(future); } } /// Replace the future currently stored in this box. /// /// This function never reallocates, but returns an error if the provided /// future has a different size or alignment from the currently stored /// future.
pub fn try_set(&mut self, future: F) -> Result<(), F> where F: Future + Send + 'a, { // If we try to inline the contents of this function, the type checker complains because // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`. #[inline(always)] fn real_try_set<'a, F>( this: &mut ReusableBoxFuture<'a, F::Output>, future: F, ) -> Result<(), F> where F: Future + Send + 'a, { // future::Pending is a ZST so this never allocates. let boxed = mem::replace(&mut this.boxed, Box::pin(future::pending())); reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed)) } real_try_set(self, future) } /// Get a pinned reference to the underlying future. pub fn get_pin(&mut self) -> Pin<&mut (dyn Future + Send)> { self.boxed.as_mut() } /// Poll the future stored inside this box. pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll { self.get_pin().poll(cx) } } impl Future for ReusableBoxFuture<'_, T> { type Output = T; /// Poll the future stored inside this box. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { Pin::into_inner(self).get_pin().poll(cx) } } // The only method called on self.boxed is poll, which takes &mut self, so this // struct being Sync does not permit any invalid access to the Future, even if // the future is not Sync. unsafe impl Sync for ReusableBoxFuture<'_, T> {} impl fmt::Debug for ReusableBoxFuture<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ReusableBoxFuture").finish() } } fn reuse_pin_box(boxed: Pin>, new_value: U, callback: F) -> Result where F: FnOnce(Box) -> O, { let layout = Layout::for_value::(&*boxed); if layout != Layout::new::() { return Err(new_value); } // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we // always drop the `T`. 
let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) }); // When dropping the old value panics, we still want to call `callback` — so move the rest of // the code into a guard type. let guard = CallOnDrop::new(|| { let raw: *mut U = raw.cast::(); unsafe { raw.write(new_value) }; // SAFETY: // - `T` and `U` have the same layout. // - `raw` comes from a `Box` that uses the same allocator as this one. // - `raw` points to a valid instance of `U` (we just wrote it in). let boxed = unsafe { Box::from_raw(raw) }; callback(boxed) }); // Drop the old value. unsafe { ptr::drop_in_place(raw) }; // Run the rest of the code. Ok(guard.call()) } struct CallOnDrop O> { f: ManuallyDrop, } impl O> CallOnDrop { fn new(f: F) -> Self { let f = ManuallyDrop::new(f); Self { f } } fn call(self) -> O { let mut this = ManuallyDrop::new(self); let f = unsafe { ManuallyDrop::take(&mut this.f) }; f() } } impl O> Drop for CallOnDrop { fn drop(&mut self) { let f = unsafe { ManuallyDrop::take(&mut self.f) }; f(); } } tokio-util-0.7.10/src/sync/tests/loom_cancellation_token.rs000064400000000000000000000074421046102023000221550ustar 00000000000000use crate::sync::CancellationToken; use loom::{future::block_on, thread}; use tokio_test::assert_ok; #[test] fn cancel_token() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let th1 = thread::spawn(move || { block_on(async { token1.cancelled().await; }); }); let th2 = thread::spawn(move || { token.cancel(); }); assert_ok!(th1.join()); assert_ok!(th2.join()); }); } #[test] fn cancel_token_owned() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let th1 = thread::spawn(move || { block_on(async { token1.cancelled_owned().await; }); }); let th2 = thread::spawn(move || { token.cancel(); }); assert_ok!(th1.join()); assert_ok!(th2.join()); }); } #[test] fn cancel_with_child() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let 
token2 = token.clone(); let child_token = token.child_token(); let th1 = thread::spawn(move || { block_on(async { token1.cancelled().await; }); }); let th2 = thread::spawn(move || { token2.cancel(); }); let th3 = thread::spawn(move || { block_on(async { child_token.cancelled().await; }); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn drop_token_no_child() { loom::model(|| { let token = CancellationToken::new(); let token1 = token.clone(); let token2 = token.clone(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { drop(token2); }); let th3 = thread::spawn(move || { drop(token); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn drop_token_with_children() { loom::model(|| { let token1 = CancellationToken::new(); let child_token1 = token1.child_token(); let child_token2 = token1.child_token(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { drop(child_token1); }); let th3 = thread::spawn(move || { drop(child_token2); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn drop_and_cancel_token() { loom::model(|| { let token1 = CancellationToken::new(); let token2 = token1.clone(); let child_token = token1.child_token(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { token2.cancel(); }); let th3 = thread::spawn(move || { drop(child_token); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } #[test] fn cancel_parent_and_child() { loom::model(|| { let token1 = CancellationToken::new(); let token2 = token1.clone(); let child_token = token1.child_token(); let th1 = thread::spawn(move || { drop(token1); }); let th2 = thread::spawn(move || { token2.cancel(); }); let th3 = thread::spawn(move || { child_token.cancel(); }); assert_ok!(th1.join()); assert_ok!(th2.join()); assert_ok!(th3.join()); }); } 
tokio-util-0.7.10/src/sync/tests/mod.rs000064400000000000000000000000011046102023000160320ustar 00000000000000 tokio-util-0.7.10/src/task/join_map.rs000064400000000000000000000756111046102023000157160ustar 00000000000000use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; use std::borrow::Borrow; use std::collections::hash_map::RandomState; use std::fmt; use std::future::Future; use std::hash::{BuildHasher, Hash, Hasher}; use std::marker::PhantomData; use tokio::runtime::Handle; use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; /// A collection of tasks spawned on a Tokio runtime, associated with hash map /// keys. /// /// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the /// addition of a set of keys associated with each task. These keys allow /// [cancelling a task][abort] or [multiple tasks][abort_matching] in the /// `JoinMap` based on their keys, or [test whether a task corresponding to a /// given key exists][contains] in the `JoinMap`. /// /// In addition, when tasks in the `JoinMap` complete, they will return the /// associated key along with the value returned by the task, if any. /// /// A `JoinMap` can be used to await the completion of some or all of the tasks /// in the map. The map is not ordered, and the tasks will be returned in the /// order they complete. /// /// All of the tasks must have the same return type `V`. /// /// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted. /// /// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the /// documentation on unstable features][unstable] for details on how to enable /// Tokio's unstable features. /// /// # Examples /// /// Spawn multiple tasks and wait for them: /// /// ``` /// use tokio_util::task::JoinMap; /// /// #[tokio::main] /// async fn main() { /// let mut map = JoinMap::new(); /// /// for i in 0..10 { /// // Spawn a task on the `JoinMap` with `i` as its key. 
/// map.spawn(i, async move { /* ... */ }); /// } /// /// let mut seen = [false; 10]; /// /// // When a task completes, `join_next` returns the task's key along /// // with its output. /// while let Some((key, res)) = map.join_next().await { /// seen[key] = true; /// assert!(res.is_ok(), "task {} completed successfully!", key); /// } /// /// for i in 0..10 { /// assert!(seen[i]); /// } /// } /// ``` /// /// Cancel tasks based on their keys: /// /// ``` /// use tokio_util::task::JoinMap; /// /// #[tokio::main] /// async fn main() { /// let mut map = JoinMap::new(); /// /// map.spawn("hello world", async move { /* ... */ }); /// map.spawn("goodbye world", async move { /* ... */}); /// /// // Look up the "goodbye world" task in the map and abort it. /// let aborted = map.abort("goodbye world"); /// /// // `JoinMap::abort` returns `true` if a task existed for the /// // provided key. /// assert!(aborted); /// /// while let Some((key, res)) = map.join_next().await { /// if key == "goodbye world" { /// // The aborted task should complete with a cancelled `JoinError`. /// assert!(res.unwrap_err().is_cancelled()); /// } else { /// // Other tasks should complete normally. /// assert!(res.is_ok()); /// } /// } /// } /// ``` /// /// [`JoinSet`]: tokio::task::JoinSet /// [unstable]: tokio#unstable-features /// [abort]: fn@Self::abort /// [abort_matching]: fn@Self::abort_matching /// [contains]: fn@Self::contains_key #[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] pub struct JoinMap { /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`, /// indexed by their keys and task IDs. /// /// The [`Key`] type contains both the task's `K`-typed key provided when /// spawning tasks, and the task's IDs. The IDs are stored here to resolve /// hash collisions when looking up tasks based on their pre-computed hash /// (as stored in the `hashes_by_task` map). 
tasks_by_key: HashMap, AbortHandle, S>, /// A map from task IDs to the hash of the key associated with that task. /// /// This map is used to perform reverse lookups of tasks in the /// `tasks_by_key` map based on their task IDs. When a task terminates, the /// ID is provided to us by the `JoinSet`, so we can look up the hash value /// of that task's key, and then remove it from the `tasks_by_key` map using /// the raw hash code, resolving collisions by comparing task IDs. hashes_by_task: HashMap, /// The [`JoinSet`] that awaits the completion of tasks spawned on this /// `JoinMap`. tasks: JoinSet, } /// A [`JoinMap`] key. /// /// This holds both a `K`-typed key (the actual key as seen by the user), _and_ /// a task ID, so that hash collisions between `K`-typed keys can be resolved /// using either `K`'s `Eq` impl *or* by checking the task IDs. /// /// This allows looking up a task using either an actual key (such as when the /// user queries the map with a key), *or* using a task ID and a hash (such as /// when removing completed tasks from the map). #[derive(Debug)] struct Key { key: K, id: Id, } impl JoinMap { /// Creates a new empty `JoinMap`. /// /// The `JoinMap` is initially created with a capacity of 0, so it will not /// allocate until a task is first spawned on it. /// /// # Examples /// /// ``` /// use tokio_util::task::JoinMap; /// let map: JoinMap<&str, i32> = JoinMap::new(); /// ``` #[inline] #[must_use] pub fn new() -> Self { Self::with_hasher(RandomState::new()) } /// Creates an empty `JoinMap` with the specified capacity. /// /// The `JoinMap` will be able to hold at least `capacity` tasks without /// reallocating. 
/// /// # Examples /// /// ``` /// use tokio_util::task::JoinMap; /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10); /// ``` #[inline] #[must_use] pub fn with_capacity(capacity: usize) -> Self { JoinMap::with_capacity_and_hasher(capacity, Default::default()) } } impl<K, V, S: Clone> JoinMap<K, V, S> { /// Creates an empty `JoinMap` which will use the given hash builder to hash /// keys. /// /// The created map has the default initial capacity. /// /// Warning: `hash_builder` is normally randomly generated, and /// is designed to allow `JoinMap` to be resistant to attacks that /// cause many collisions and very poor performance. Setting it /// manually using this function can expose a DoS attack vector. /// /// The `hash_builder` passed should implement the [`BuildHasher`] trait for /// the `JoinMap` to be useful; see its documentation for details. #[inline] #[must_use] pub fn with_hasher(hash_builder: S) -> Self { Self::with_capacity_and_hasher(0, hash_builder) } /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder` /// to hash the keys. /// /// The `JoinMap` will be able to hold at least `capacity` elements without /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate. /// /// Warning: `hash_builder` is normally randomly generated, and /// is designed to allow `JoinMap` to be resistant to attacks that /// cause many collisions and very poor performance. Setting it /// manually using this function can expose a DoS attack vector. /// /// The `hash_builder` passed should implement the [`BuildHasher`] trait for /// the `JoinMap` to be useful; see its documentation for details. /// /// # Examples /// /// ``` /// # #[tokio::main] /// # async fn main() { /// use tokio_util::task::JoinMap; /// use std::collections::hash_map::RandomState; /// /// let s = RandomState::new(); /// let mut map = JoinMap::with_capacity_and_hasher(10, s); /// map.spawn(1, async move { "hello world!"
}); /// # } /// ``` #[inline] #[must_use] pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self { Self { tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()), hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder), tasks: JoinSet::new(), } } /// Returns the number of tasks currently in the `JoinMap`. pub fn len(&self) -> usize { let len = self.tasks_by_key.len(); debug_assert_eq!(len, self.hashes_by_task.len()); len } /// Returns whether the `JoinMap` is empty. pub fn is_empty(&self) -> bool { let empty = self.tasks_by_key.is_empty(); debug_assert_eq!(empty, self.hashes_by_task.is_empty()); empty } /// Returns the number of tasks the map can hold without reallocating. /// /// This number is a lower bound; the `JoinMap` might be able to hold /// more, but is guaranteed to be able to hold at least this many. /// /// # Examples /// /// ``` /// use tokio_util::task::JoinMap; /// /// let map: JoinMap = JoinMap::with_capacity(100); /// assert!(map.capacity() >= 100); /// ``` #[inline] pub fn capacity(&self) -> usize { let capacity = self.tasks_by_key.capacity(); debug_assert_eq!(capacity, self.hashes_by_task.capacity()); capacity } } impl JoinMap where K: Hash + Eq, V: 'static, S: BuildHasher, { /// Spawn the provided task and store it in this `JoinMap` with the provided /// key. /// /// If a task previously existed in the `JoinMap` for this key, that task /// will be cancelled and replaced with the new one. The previous task will /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will /// *not* return a cancelled [`JoinError`] for that task. /// /// # Panics /// /// This method panics if called outside of a Tokio runtime. 
/// /// [`join_next`]: Self::join_next #[track_caller] pub fn spawn(&mut self, key: K, task: F) where F: Future, F: Send + 'static, V: Send, { let task = self.tasks.spawn(task); self.insert(key, task) } /// Spawn the provided task on the provided runtime and store it in this /// `JoinMap` with the provided key. /// /// If a task previously existed in the `JoinMap` for this key, that task /// will be cancelled and replaced with the new one. The previous task will /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will /// *not* return a cancelled [`JoinError`] for that task. /// /// [`join_next`]: Self::join_next #[track_caller] pub fn spawn_on(&mut self, key: K, task: F, handle: &Handle) where F: Future, F: Send + 'static, V: Send, { let task = self.tasks.spawn_on(task, handle); self.insert(key, task); } /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided /// key. /// /// If a task previously existed in the `JoinMap` for this key, that task /// will be cancelled and replaced with the new one. The previous task will /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will /// *not* return a cancelled [`JoinError`] for that task. /// /// Note that blocking tasks cannot be cancelled after execution starts. /// Replaced blocking tasks will still run to completion if the task has begun /// to execute when it is replaced. A blocking task which is replaced before /// it has been scheduled on a blocking worker thread will be cancelled. /// /// # Panics /// /// This method panics if called outside of a Tokio runtime. /// /// [`join_next`]: Self::join_next #[track_caller] pub fn spawn_blocking(&mut self, key: K, f: F) where F: FnOnce() -> V, F: Send + 'static, V: Send, { let task = self.tasks.spawn_blocking(f); self.insert(key, task) } /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this /// `JoinMap` with the provided key. 
/// /// If a task previously existed in the `JoinMap` for this key, that task /// will be cancelled and replaced with the new one. The previous task will /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will /// *not* return a cancelled [`JoinError`] for that task. /// /// Note that blocking tasks cannot be cancelled after execution starts. /// Replaced blocking tasks will still run to completion if the task has begun /// to execute when it is replaced. A blocking task which is replaced before /// it has been scheduled on a blocking worker thread will be cancelled. /// /// [`join_next`]: Self::join_next #[track_caller] pub fn spawn_blocking_on(&mut self, key: K, f: F, handle: &Handle) where F: FnOnce() -> V, F: Send + 'static, V: Send, { let task = self.tasks.spawn_blocking_on(f, handle); self.insert(key, task); } /// Spawn the provided task on the current [`LocalSet`] and store it in this /// `JoinMap` with the provided key. /// /// If a task previously existed in the `JoinMap` for this key, that task /// will be cancelled and replaced with the new one. The previous task will /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will /// *not* return a cancelled [`JoinError`] for that task. /// /// # Panics /// /// This method panics if it is called outside of a `LocalSet`. /// /// [`LocalSet`]: tokio::task::LocalSet /// [`join_next`]: Self::join_next #[track_caller] pub fn spawn_local(&mut self, key: K, task: F) where F: Future, F: 'static, { let task = self.tasks.spawn_local(task); self.insert(key, task); } /// Spawn the provided task on the provided [`LocalSet`] and store it in /// this `JoinMap` with the provided key. /// /// If a task previously existed in the `JoinMap` for this key, that task /// will be cancelled and replaced with the new one. The previous task will /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will /// *not* return a cancelled [`JoinError`] for that task. 
/// /// [`LocalSet`]: tokio::task::LocalSet /// [`join_next`]: Self::join_next #[track_caller] pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet) where F: Future<Output = V>, F: 'static, { let task = self.tasks.spawn_local_on(task, local_set); self.insert(key, task) } fn insert(&mut self, key: K, abort: AbortHandle) { let hash = self.hash(&key); let id = abort.id(); let map_key = Key { id, key }; // Insert the new key into the map of tasks by keys. let entry = self .tasks_by_key .raw_entry_mut() .from_hash(hash, |k| k.key == map_key.key); match entry { RawEntryMut::Occupied(mut occ) => { // There was a previous task spawned with the same key! Cancel // that task, and remove its ID from the map of hashes by task IDs. let Key { id: prev_id, .. } = occ.insert_key(map_key); occ.insert(abort).abort(); let _prev_hash = self.hashes_by_task.remove(&prev_id); debug_assert_eq!(Some(hash), _prev_hash); } RawEntryMut::Vacant(vac) => { vac.insert(map_key, abort); } }; // Associate the key's hash with this task's ID, for looking up tasks by ID. let _prev = self.hashes_by_task.insert(id, hash); debug_assert!(_prev.is_none(), "no prior task should have had the same ID"); } /// Waits until one of the tasks in the map completes and returns its /// output, along with the key corresponding to that task. /// /// Returns `None` if the map is empty. /// /// # Cancel Safety /// /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`] /// statement and some other branch completes first, it is guaranteed that no tasks were /// removed from this `JoinMap`. /// /// # Returns /// /// This function returns: /// /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has /// completed. The `value` is the return value of that task, and `key` is /// the key associated with the task. /// * `Some((key, Err(err)))` if one of the tasks in this `JoinMap` has /// panicked or been aborted.
`key` is the key associated with the task /// that panicked or was aborted. /// * `None` if the `JoinMap` is empty. /// /// [`tokio::select!`]: tokio::select pub async fn join_next(&mut self) -> Option<(K, Result)> { let (res, id) = match self.tasks.join_next_with_id().await { Some(Ok((id, output))) => (Ok(output), id), Some(Err(e)) => { let id = e.id(); (Err(e), id) } None => return None, }; let key = self.remove_by_id(id)?; Some((key, res)) } /// Aborts all tasks and waits for them to finish shutting down. /// /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in /// a loop until it returns `None`. /// /// This method ignores any panics in the tasks shutting down. When this call returns, the /// `JoinMap` will be empty. /// /// [`abort_all`]: fn@Self::abort_all /// [`join_next`]: fn@Self::join_next pub async fn shutdown(&mut self) { self.abort_all(); while self.join_next().await.is_some() {} } /// Abort the task corresponding to the provided `key`. /// /// If this `JoinMap` contains a task corresponding to `key`, this method /// will abort that task and return `true`. Otherwise, if no task exists for /// `key`, this method returns `false`. /// /// # Examples /// /// Aborting a task by key: /// /// ``` /// use tokio_util::task::JoinMap; /// /// # #[tokio::main] /// # async fn main() { /// let mut map = JoinMap::new(); /// /// map.spawn("hello world", async move { /* ... */ }); /// map.spawn("goodbye world", async move { /* ... */}); /// /// // Look up the "goodbye world" task in the map and abort it. /// map.abort("goodbye world"); /// /// while let Some((key, res)) = map.join_next().await { /// if key == "goodbye world" { /// // The aborted task should complete with a cancelled `JoinError`. /// assert!(res.unwrap_err().is_cancelled()); /// } else { /// // Other tasks should complete normally. 
/// assert!(res.is_ok()); /// } /// } /// # } /// ``` /// /// `abort` returns `true` if a task was aborted: /// ``` /// use tokio_util::task::JoinMap; /// /// # #[tokio::main] /// # async fn main() { /// let mut map = JoinMap::new(); /// /// map.spawn("hello world", async move { /* ... */ }); /// map.spawn("goodbye world", async move { /* ... */}); /// /// // A task for the key "goodbye world" should exist in the map: /// assert!(map.abort("goodbye world")); /// /// // Aborting a key that does not exist will return `false`: /// assert!(!map.abort("goodbye universe")); /// # } /// ``` pub fn abort(&mut self, key: &Q) -> bool where Q: Hash + Eq, K: Borrow, { match self.get_by_key(key) { Some((_, handle)) => { handle.abort(); true } None => false, } } /// Aborts all tasks with keys matching `predicate`. /// /// `predicate` is a function called with a reference to each key in the /// map. If it returns `true` for a given key, the corresponding task will /// be cancelled. /// /// # Examples /// ``` /// use tokio_util::task::JoinMap; /// /// # // use the current thread rt so that spawned tasks don't /// # // complete in the background before they can be aborted. /// # #[tokio::main(flavor = "current_thread")] /// # async fn main() { /// let mut map = JoinMap::new(); /// /// map.spawn("hello world", async move { /// // ... /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! /// }); /// map.spawn("goodbye world", async move { /// // ... /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! /// }); /// map.spawn("hello san francisco", async move { /// // ... /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! /// }); /// map.spawn("goodbye universe", async move { /// // ... /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! 
/// }); /// /// // Abort all tasks whose keys begin with "goodbye" /// map.abort_matching(|key| key.starts_with("goodbye")); /// /// let mut seen = 0; /// while let Some((key, res)) = map.join_next().await { /// seen += 1; /// if key.starts_with("goodbye") { /// // The aborted task should complete with a cancelled `JoinError`. /// assert!(res.unwrap_err().is_cancelled()); /// } else { /// // Other tasks should complete normally. /// assert!(key.starts_with("hello")); /// assert!(res.is_ok()); /// } /// } /// /// // All spawned tasks should have completed. /// assert_eq!(seen, 4); /// # } /// ``` pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) { // Note: this method iterates over the tasks and keys *without* removing // any entries, so that the keys from aborted tasks can still be // returned when calling `join_next` in the future. for (Key { ref key, .. }, task) in &self.tasks_by_key { if predicate(key) { task.abort(); } } } /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order. /// /// If a task has completed, but its output hasn't yet been consumed by a /// call to [`join_next`], this method will still return its key. /// /// [`join_next`]: fn@Self::join_next pub fn keys(&self) -> JoinMapKeys<'_, K, V> { JoinMapKeys { iter: self.tasks_by_key.keys(), _value: PhantomData, } } /// Returns `true` if this `JoinMap` contains a task for the provided key. /// /// If the task has completed, but its output hasn't yet been consumed by a /// call to [`join_next`], this method will still return `true`. /// /// [`join_next`]: fn@Self::join_next pub fn contains_key(&self, key: &Q) -> bool where Q: Hash + Eq, K: Borrow, { self.get_by_key(key).is_some() } /// Returns `true` if this `JoinMap` contains a task with the provided /// [task ID]. /// /// If the task has completed, but its output hasn't yet been consumed by a /// call to [`join_next`], this method will still return `true`. 
/// /// [`join_next`]: fn@Self::join_next /// [task ID]: tokio::task::Id pub fn contains_task(&self, task: &Id) -> bool { self.get_by_id(task).is_some() } /// Reserves capacity for at least `additional` more tasks to be spawned /// on this `JoinMap` without reallocating for the map of task keys. The /// collection may reserve more space to avoid frequent reallocations. /// /// Note that spawning a task will still cause an allocation for the task /// itself. /// /// # Panics /// /// Panics if the new allocation size overflows [`usize`]. /// /// # Examples /// /// ``` /// use tokio_util::task::JoinMap; /// /// let mut map: JoinMap<&str, i32> = JoinMap::new(); /// map.reserve(10); /// ``` #[inline] pub fn reserve(&mut self, additional: usize) { self.tasks_by_key.reserve(additional); self.hashes_by_task.reserve(additional); } /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop /// down as much as possible while maintaining the internal rules /// and possibly leaving some space in accordance with the resize policy. /// /// # Examples /// /// ``` /// # #[tokio::main] /// # async fn main() { /// use tokio_util::task::JoinMap; /// /// let mut map: JoinMap = JoinMap::with_capacity(100); /// map.spawn(1, async move { 2 }); /// map.spawn(3, async move { 4 }); /// assert!(map.capacity() >= 100); /// map.shrink_to_fit(); /// assert!(map.capacity() >= 2); /// # } /// ``` #[inline] pub fn shrink_to_fit(&mut self) { self.hashes_by_task.shrink_to_fit(); self.tasks_by_key.shrink_to_fit(); } /// Shrinks the capacity of the map with a lower limit. It will drop /// down no lower than the supplied limit while maintaining the internal rules /// and possibly leaving some space in accordance with the resize policy. /// /// If the current capacity is less than the lower limit, this is a no-op. 
    ///
    /// # Examples
    ///
    /// ```
    /// # #[tokio::main]
    /// # async fn main() {
    /// use tokio_util::task::JoinMap;
    ///
    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
    /// map.spawn(1, async move { 2 });
    /// map.spawn(3, async move { 4 });
    /// assert!(map.capacity() >= 100);
    /// map.shrink_to(10);
    /// assert!(map.capacity() >= 10);
    /// map.shrink_to(0);
    /// assert!(map.capacity() >= 2);
    /// # }
    /// ```
    #[inline]
    pub fn shrink_to(&mut self, min_capacity: usize) {
        self.hashes_by_task.shrink_to(min_capacity);
        self.tasks_by_key.shrink_to(min_capacity)
    }

    /// Look up a task in the map by its key, returning the key and abort handle.
    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
    where
        Q: Hash + Eq,
        K: Borrow<Q>,
    {
        let hash = self.hash(key);
        self.tasks_by_key
            .raw_entry()
            .from_hash(hash, |k| k.key.borrow() == key)
    }

    /// Look up a task in the map by its task ID, returning the key and abort handle.
    fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
        let hash = self.hashes_by_task.get(id)?;
        self.tasks_by_key
            .raw_entry()
            .from_hash(*hash, |k| &k.id == id)
    }

    /// Remove a task from the map by ID, returning the key for that task.
    fn remove_by_id(&mut self, id: Id) -> Option<K> {
        // Get the hash for the given ID.
        let hash = self.hashes_by_task.remove(&id)?;

        // Remove the entry for that hash.
        let entry = self
            .tasks_by_key
            .raw_entry_mut()
            .from_hash(hash, |k| k.id == id);
        let (Key { id: _key_id, key }, handle) = match entry {
            RawEntryMut::Occupied(entry) => entry.remove_entry(),
            _ => return None,
        };
        debug_assert_eq!(_key_id, id);
        debug_assert_eq!(id, handle.id());
        self.hashes_by_task.remove(&id);
        Some(key)
    }

    /// Returns the hash for a given key.
    #[inline]
    fn hash<Q: ?Sized>(&self, key: &Q) -> u64
    where
        Q: Hash,
    {
        let mut hasher = self.tasks_by_key.hasher().build_hasher();
        key.hash(&mut hasher);
        hasher.finish()
    }
}

impl<K, V, S> JoinMap<K, V, S>
where
    V: 'static,
{
    /// Aborts all tasks on this `JoinMap`.
    ///
    /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
    /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
    pub fn abort_all(&mut self) {
        self.tasks.abort_all()
    }

    /// Removes all tasks from this `JoinMap` without aborting them.
    ///
    /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
    /// is dropped. They may still be aborted by key.
    pub fn detach_all(&mut self) {
        self.tasks.detach_all();
        self.tasks_by_key.clear();
        self.hashes_by_task.clear();
    }
}

// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
// Debug`, since no value is ever actually stored in the map.
impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format the task keys and abort handles a little nicer by just
        // printing the key and task ID pairs, without formatting the `Key`
        // struct itself or the `AbortHandle`, which would just format the
        // task's ID again.
        struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
        impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_map()
                    .entries(self.0.keys().map(|Key { key, id }| (key, id)))
                    .finish()
            }
        }

        f.debug_struct("JoinMap")
            // The `tasks_by_key` map is the only one that contains information
            // that's really worth formatting for the user, since it contains
            // the tasks' keys and IDs. The other fields are basically
            // implementation details.
            .field("tasks", &KeySet(&self.tasks_by_key))
            .finish()
    }
}

impl<K, V> Default for JoinMap<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

// === impl Key ===

impl<K: Hash> Hash for Key<K> {
    // Don't include the task ID in the hash.
    #[inline]
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        self.key.hash(hasher);
    }
}

// Because we override `Hash` for this type, we must also override the
// `PartialEq` impl, so that all instances with the same hash are equal.
impl PartialEq for Key { #[inline] fn eq(&self, other: &Self) -> bool { self.key == other.key } } impl Eq for Key {} /// An iterator over the keys of a [`JoinMap`]. #[derive(Debug, Clone)] pub struct JoinMapKeys<'a, K, V> { iter: hashbrown::hash_map::Keys<'a, Key, AbortHandle>, /// To make it easier to change JoinMap in the future, keep V as a generic /// parameter. _value: PhantomData<&'a V>, } impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> { type Item = &'a K; fn next(&mut self) -> Option<&'a K> { self.iter.next().map(|key| &key.key) } fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } } impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> { fn len(&self) -> usize { self.iter.len() } } impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {} tokio-util-0.7.10/src/task/mod.rs000064400000000000000000000006051046102023000146700ustar 00000000000000//! Extra utilities for spawning tasks #[cfg(tokio_unstable)] mod join_map; #[cfg(not(target_os = "wasi"))] mod spawn_pinned; #[cfg(not(target_os = "wasi"))] pub use spawn_pinned::LocalPoolHandle; #[cfg(tokio_unstable)] #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] pub use join_map::{JoinMap, JoinMapKeys}; pub mod task_tracker; pub use task_tracker::TaskTracker; tokio-util-0.7.10/src/task/spawn_pinned.rs000064400000000000000000000353301046102023000166010ustar 00000000000000use futures_util::future::{AbortHandle, Abortable}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::runtime::Builder; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot; use tokio::task::{spawn_local, JoinHandle, LocalSet}; /// A cloneable handle to a local pool, used for spawning `!Send` tasks. /// /// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread /// in the pool. 
Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio::{self, task };
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
///     let pool = LocalPoolHandle::new(5);
///
///     let output = pool.spawn_pinned(|| {
///         // `data` is !Send + !Sync
///         let data = Rc::new("local data");
///         let data_clone = data.clone();
///
///         async move {
///             task::spawn_local(async move {
///                 println!("{}", data_clone);
///             });
///
///             data.to_string()
///         }
///     }).await.unwrap();
///     println!("output: {}", output);
/// }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
    pool: Arc<LocalPool>,
}

impl LocalPoolHandle {
    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
    /// pool via [`LocalPoolHandle::spawn_pinned`].
    ///
    /// # Panics
    ///
    /// Panics if the pool size is less than one.
    #[track_caller]
    pub fn new(pool_size: usize) -> LocalPoolHandle {
        assert!(pool_size > 0);

        let workers = (0..pool_size)
            .map(|_| LocalWorkerHandle::new_worker())
            .collect();

        let pool = Arc::new(LocalPool { workers });

        LocalPoolHandle { pool }
    }

    /// Returns the number of threads of the Pool.
    #[inline]
    pub fn num_threads(&self) -> usize {
        self.pool.workers.len()
    }

    /// Returns the number of tasks scheduled on each worker. The indices of the
    /// worker threads correspond to the indices of the returned `Vec`.
    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
        self.pool
            .workers
            .iter()
            .map(|worker| worker.task_count.load(Ordering::SeqCst))
            .collect::<Vec<_>>()
    }

    /// Spawn a task onto a worker thread and pin it there so it can't be moved
    /// off of the thread.
Note that the future is not [`Send`], but the /// [`FnOnce`] which creates it is. /// /// # Examples /// ``` /// use std::rc::Rc; /// use tokio_util::task::LocalPoolHandle; /// /// #[tokio::main] /// async fn main() { /// // Create the local pool /// let pool = LocalPoolHandle::new(1); /// /// // Spawn a !Send future onto the pool and await it /// let output = pool /// .spawn_pinned(|| { /// // Rc is !Send + !Sync /// let local_data = Rc::new("test"); /// /// // This future holds an Rc, so it is !Send /// async move { local_data.to_string() } /// }) /// .await /// .unwrap(); /// /// assert_eq!(output, "test"); /// } /// ``` pub fn spawn_pinned(&self, create_task: F) -> JoinHandle where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static, { self.pool .spawn_pinned(create_task, WorkerChoice::LeastBurdened) } /// Differs from `spawn_pinned` only in that you can choose a specific worker thread /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest /// number of tasks scheduled. /// /// A worker thread is chosen by index. Indices are 0 based and the largest index /// is given by `num_threads() - 1` /// /// # Panics /// /// This method panics if the index is out of bounds. 
/// /// # Examples /// /// This method can be used to spawn a task on all worker threads of the pool: /// /// ``` /// use tokio_util::task::LocalPoolHandle; /// /// #[tokio::main] /// async fn main() { /// const NUM_WORKERS: usize = 3; /// let pool = LocalPoolHandle::new(NUM_WORKERS); /// let handles = (0..pool.num_threads()) /// .map(|worker_idx| { /// pool.spawn_pinned_by_idx( /// || { /// async { /// "test" /// } /// }, /// worker_idx, /// ) /// }) /// .collect::>(); /// /// for handle in handles { /// handle.await.unwrap(); /// } /// } /// ``` /// #[track_caller] pub fn spawn_pinned_by_idx(&self, create_task: F, idx: usize) -> JoinHandle where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static, { self.pool .spawn_pinned(create_task, WorkerChoice::ByIdx(idx)) } } impl Debug for LocalPoolHandle { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str("LocalPoolHandle") } } enum WorkerChoice { LeastBurdened, ByIdx(usize), } struct LocalPool { workers: Vec, } impl LocalPool { /// Spawn a `?Send` future onto a worker #[track_caller] fn spawn_pinned( &self, create_task: F, worker_choice: WorkerChoice, ) -> JoinHandle where F: FnOnce() -> Fut, F: Send + 'static, Fut: Future + 'static, Fut::Output: Send + 'static, { let (sender, receiver) = oneshot::channel(); let (worker, job_guard) = match worker_choice { WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(), WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx), }; let worker_spawner = worker.spawner.clone(); // Spawn a future onto the worker's runtime so we can immediately return // a join handle. 
worker.runtime_handle.spawn(async move { // Move the job guard into the task let _job_guard = job_guard; // Propagate aborts via Abortable/AbortHandle let (abort_handle, abort_registration) = AbortHandle::new_pair(); let _abort_guard = AbortGuard(abort_handle); // Inside the future we can't run spawn_local yet because we're not // in the context of a LocalSet. We need to send create_task to the // LocalSet task for spawning. let spawn_task = Box::new(move || { // Once we're in the LocalSet context we can call spawn_local let join_handle = spawn_local( async move { Abortable::new(create_task(), abort_registration).await }, ); // Send the join handle back to the spawner. If sending fails, // we assume the parent task was canceled, so cancel this task // as well. if let Err(join_handle) = sender.send(join_handle) { join_handle.abort() } }); // Send the callback to the LocalSet task if let Err(e) = worker_spawner.send(spawn_task) { // Propagate the error as a panic in the join handle. panic!("Failed to send job to worker: {}", e); } // Wait for the task's join handle let join_handle = match receiver.await { Ok(handle) => handle, Err(e) => { // We sent the task successfully, but failed to get its // join handle... We assume something happened to the worker // and the task was not spawned. Propagate the error as a // panic in the join handle. panic!("Worker failed to send join handle: {}", e); } }; // Wait for the task to complete let join_result = join_handle.await; match join_result { Ok(Ok(output)) => output, Ok(Err(_)) => { // Pinned task was aborted. But that only happens if this // task is aborted. So this is an impossible branch. unreachable!( "Reaching this branch means this task was previously \ aborted but it continued running anyways" ) } Err(e) => { if e.is_panic() { std::panic::resume_unwind(e.into_panic()); } else if e.is_cancelled() { // No one else should have the join handle, so this is // unexpected. Forward this error as a panic in the join // handle. 
                        panic!("spawn_pinned task was canceled: {}", e);
                    } else {
                        // Something unknown happened (not a panic or
                        // cancellation). Forward this error as a panic in the
                        // join handle.
                        panic!("spawn_pinned task failed: {}", e);
                    }
                }
            }
        })
    }

    /// Find the worker with the least number of tasks, increment its task
    /// count, and return its handle. Make sure to actually spawn a task on
    /// the worker so the task count is kept consistent with load.
    ///
    /// A job count guard is also returned to ensure the task count gets
    /// decremented when the job is done.
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                .expect("There must be at least one worker");

            // Make sure the task count hasn't changed since we chose this
            // worker. Otherwise, restart the search.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }

    #[track_caller]
    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
        let worker = &self.workers[idx];
        worker.task_count.fetch_add(1, Ordering::SeqCst);

        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
    }
}

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Decrement the job count
        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
        debug_assert!(previous_value >= 1);
    }
}

/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle); impl Drop for AbortGuard { fn drop(&mut self) { self.0.abort(); } } type PinnedFutureSpawner = Box; struct LocalWorkerHandle { runtime_handle: tokio::runtime::Handle, spawner: UnboundedSender, task_count: Arc, } impl LocalWorkerHandle { /// Create a new worker for executing pinned tasks fn new_worker() -> LocalWorkerHandle { let (sender, receiver) = unbounded_channel(); let runtime = Builder::new_current_thread() .enable_all() .build() .expect("Failed to start a pinned worker thread runtime"); let runtime_handle = runtime.handle().clone(); let task_count = Arc::new(AtomicUsize::new(0)); let task_count_clone = Arc::clone(&task_count); std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); LocalWorkerHandle { runtime_handle, spawner: sender, task_count, } } fn run( runtime: tokio::runtime::Runtime, mut task_receiver: UnboundedReceiver, task_count: Arc, ) { let local_set = LocalSet::new(); local_set.block_on(&runtime, async { while let Some(spawn_task) = task_receiver.recv().await { // Calls spawn_local(future) (spawn_task)(); } }); // If there are any tasks on the runtime associated with a LocalSet task // that has already completed, but whose output has not yet been // reported, let that task complete. // // Since the task_count is decremented when the runtime task exits, // reading that counter lets us know if any such tasks completed during // the call to `block_on`. // // Tasks on the LocalSet can't complete during this loop since they're // stored on the LocalSet and we aren't accessing it. let mut previous_task_count = task_count.load(Ordering::SeqCst); loop { // This call will also run tasks spawned on the runtime. 
runtime.block_on(tokio::task::yield_now()); let new_task_count = task_count.load(Ordering::SeqCst); if new_task_count == previous_task_count { break; } else { previous_task_count = new_task_count; } } // It's now no longer possible for a task on the runtime to be // associated with a LocalSet task that has completed. Drop both the // LocalSet and runtime to let tasks on the runtime be cancelled if and // only if they are still on the LocalSet. // // Drop the LocalSet task first so that anyone awaiting the runtime // JoinHandle will see the cancelled error after the LocalSet task // destructor has completed. drop(local_set); drop(runtime); } } tokio-util-0.7.10/src/task/task_tracker.rs000064400000000000000000000552511046102023000165750ustar 00000000000000//! Types related to the [`TaskTracker`] collection. //! //! See the documentation of [`TaskTracker`] for more information. use pin_project_lite::pin_project; use std::fmt; use std::future::Future; use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::{futures::Notified, Notify}; #[cfg(feature = "rt")] use tokio::{ runtime::Handle, task::{JoinHandle, LocalSet}, }; /// A task tracker used for waiting until tasks exit. /// /// This is usually used together with [`CancellationToken`] to implement [graceful shutdown]. The /// `CancellationToken` is used to signal to tasks that they should shut down, and the /// `TaskTracker` is used to wait for them to finish shutting down. /// /// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case /// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the /// [`wait`] method will wait until *both* of the following happen at the same time: /// /// * The `TaskTracker` must be closed using the [`close`] method. /// * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited. 
///
/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that
/// the destructor of the future has finished running. However, there might be a short amount of
/// time where [`JoinHandle::is_finished`] returns false.
///
/// # Comparison to `JoinSet`
///
/// The main Tokio crate has a similar collection known as [`JoinSet`]. The `JoinSet` type has a
/// lot more features than `TaskTracker`, so `TaskTracker` should only be used when one of its
/// unique features is required:
///
/// 1. When tasks exit, a `TaskTracker` will allow the task to immediately free its memory.
/// 2. By not closing the `TaskTracker`, [`wait`] will be prevented from returning even if
///    the `TaskTracker` is empty.
/// 3. A `TaskTracker` does not require mutable access to insert tasks.
/// 4. A `TaskTracker` can be cloned to share it with many tasks.
///
/// The first point is the most important one. A [`JoinSet`] keeps track of the return value of
/// every inserted task. This means that if the caller keeps inserting tasks and never calls
/// [`join_next`], then their return values will keep building up and consuming memory, _even if_
/// most of the tasks have already exited. This can cause the process to run out of memory. With a
/// `TaskTracker`, this does not happen. Once tasks exit, they are immediately removed from the
/// `TaskTracker`.
///
/// # Examples
///
/// For more examples, please see the topic page on [graceful shutdown].
///
/// ## Spawn tasks and wait for them to exit
///
/// This is a simple example. For this case, [`JoinSet`] should probably be used instead.
///
/// ```
/// use tokio_util::task::TaskTracker;
///
/// #[tokio::main]
/// async fn main() {
///     let tracker = TaskTracker::new();
///
///     for i in 0..10 {
///         tracker.spawn(async move {
///             println!("Task {} is running!", i);
///         });
///     }
///     // Once we spawned everything, we close the tracker.
///     tracker.close();
///
///     // Wait for everything to finish.
/// tracker.wait().await; /// /// println!("This is printed after all of the tasks."); /// } /// ``` /// /// ## Wait for tasks to exit /// /// This example shows the intended use-case of `TaskTracker`. It is used together with /// [`CancellationToken`] to implement graceful shutdown. /// ``` /// use tokio_util::sync::CancellationToken; /// use tokio_util::task::TaskTracker; /// use tokio::time::{self, Duration}; /// /// async fn background_task(num: u64) { /// for i in 0..10 { /// time::sleep(Duration::from_millis(100*num)).await; /// println!("Background task {} in iteration {}.", num, i); /// } /// } /// /// #[tokio::main] /// # async fn _hidden() {} /// # #[tokio::main(flavor = "current_thread", start_paused = true)] /// async fn main() { /// let tracker = TaskTracker::new(); /// let token = CancellationToken::new(); /// /// for i in 0..10 { /// let token = token.clone(); /// tracker.spawn(async move { /// // Use a `tokio::select!` to kill the background task if the token is /// // cancelled. /// tokio::select! { /// () = background_task(i) => { /// println!("Task {} exiting normally.", i); /// }, /// () = token.cancelled() => { /// // Do some cleanup before we really exit. /// time::sleep(Duration::from_millis(50)).await; /// println!("Task {} finished cleanup.", i); /// }, /// } /// }); /// } /// /// // Spawn a background task that will send the shutdown signal. /// { /// let tracker = tracker.clone(); /// tokio::spawn(async move { /// // Normally you would use something like ctrl-c instead of /// // sleeping. /// time::sleep(Duration::from_secs(2)).await; /// tracker.close(); /// token.cancel(); /// }); /// } /// /// // Wait for all tasks to exit. 
/// tracker.wait().await; /// /// println!("All tasks have exited now."); /// } /// ``` /// /// [`CancellationToken`]: crate::sync::CancellationToken /// [`JoinHandle::is_finished`]: tokio::task::JoinHandle::is_finished /// [`JoinSet`]: tokio::task::JoinSet /// [`close`]: Self::close /// [`join_next`]: tokio::task::JoinSet::join_next /// [`wait`]: Self::wait /// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown pub struct TaskTracker { inner: Arc, } /// Represents a task tracked by a [`TaskTracker`]. #[must_use] #[derive(Debug)] pub struct TaskTrackerToken { task_tracker: TaskTracker, } struct TaskTrackerInner { /// Keeps track of the state. /// /// The lowest bit is whether the task tracker is closed. /// /// The rest of the bits count the number of tracked tasks. state: AtomicUsize, /// Used to notify when the last task exits. on_last_exit: Notify, } pin_project! { /// A future that is tracked as a task by a [`TaskTracker`]. /// /// The associated [`TaskTracker`] cannot complete until this future is dropped. /// /// This future is returned by [`TaskTracker::track_future`]. #[must_use = "futures do nothing unless polled"] pub struct TrackedFuture { #[pin] future: F, token: TaskTrackerToken, } } pin_project! { /// A future that completes when the [`TaskTracker`] is empty and closed. /// /// This future is returned by [`TaskTracker::wait`]. #[must_use = "futures do nothing unless polled"] pub struct TaskTrackerWaitFuture<'a> { #[pin] future: Notified<'a>, inner: Option<&'a TaskTrackerInner>, } } impl TaskTrackerInner { #[inline] fn new() -> Self { Self { state: AtomicUsize::new(0), on_last_exit: Notify::new(), } } #[inline] fn is_closed_and_empty(&self) -> bool { // If empty and closed bit set, then we are done. // // The acquire load will synchronize with the release store of any previous call to // `set_closed` and `drop_task`. 
self.state.load(Ordering::Acquire) == 1 } #[inline] fn set_closed(&self) -> bool { // The AcqRel ordering makes the closed bit behave like a `Mutex` for synchronization // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}` // more meaningful for the user. Without these orderings, this assert could fail: // ``` // // thread 1 // some_other_atomic.store(true, Relaxed); // tracker.close(); // // // thread 2 // if tracker.reopen() { // assert!(some_other_atomic.load(Relaxed)); // } // ``` // However, with the AcqRel ordering, we establish a happens-before relationship from the // call to `close` and the later call to `reopen` that returned true. let state = self.state.fetch_or(1, Ordering::AcqRel); // If there are no tasks, and if it was not already closed: if state == 0 { self.notify_now(); } (state & 1) == 0 } #[inline] fn set_open(&self) -> bool { // See `set_closed` regarding the AcqRel ordering. let state = self.state.fetch_and(!1, Ordering::AcqRel); (state & 1) == 1 } #[inline] fn add_task(&self) { self.state.fetch_add(2, Ordering::Relaxed); } #[inline] fn drop_task(&self) { let state = self.state.fetch_sub(2, Ordering::Release); // If this was the last task and we are closed: if state == 3 { self.notify_now(); } } #[cold] fn notify_now(&self) { // Insert an acquire fence. This matters for `drop_task` but doesn't matter for // `set_closed` since it already uses AcqRel. // // This synchronizes with the release store of any other call to `drop_task`, and with the // release store in the call to `set_closed`. That ensures that everything that happened // before those other calls to `drop_task` or `set_closed` will be visible after this load, // and those things will also be visible to anything woken by the call to `notify_waiters`. self.state.load(Ordering::Acquire); self.on_last_exit.notify_waiters(); } } impl TaskTracker { /// Creates a new `TaskTracker`. /// /// The `TaskTracker` will start out as open. 
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(TaskTrackerInner::new()),
        }
    }

    /// Waits until this `TaskTracker` is both closed and empty.
    ///
    /// If the `TaskTracker` is already closed and empty when this method is called, then it
    /// returns immediately.
    ///
    /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker`
    /// becomes both closed and empty for a short amount of time, then it is guaranteed that all
    /// `wait` futures that were created before the short time interval will trigger, even if they
    /// are not polled during that short time interval.
    ///
    /// # Cancel safety
    ///
    /// This method is cancel safe.
    ///
    /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the
    /// condition in a `tokio::select!` loop.
    ///
    /// [aba]: https://en.wikipedia.org/wiki/ABA_problem
    #[inline]
    pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
        TaskTrackerWaitFuture {
            future: self.inner.on_last_exit.notified(),
            inner: if self.inner.is_closed_and_empty() {
                None
            } else {
                Some(&self.inner)
            },
        }
    }

    /// Close this `TaskTracker`.
    ///
    /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks.
    ///
    /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed.
    ///
    /// [`wait`]: Self::wait
    #[inline]
    pub fn close(&self) -> bool {
        self.inner.set_closed()
    }

    /// Reopen this `TaskTracker`.
    ///
    /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty.
    ///
    /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open.
    ///
    /// [`wait`]: Self::wait
    #[inline]
    pub fn reopen(&self) -> bool {
        self.inner.set_open()
    }

    /// Returns `true` if this `TaskTracker` is [closed](Self::close).
    #[inline]
    #[must_use]
    pub fn is_closed(&self) -> bool {
        (self.inner.state.load(Ordering::Acquire) & 1) != 0
    }

    /// Returns the number of tasks tracked by this `TaskTracker`.
#[inline] #[must_use] pub fn len(&self) -> usize { self.inner.state.load(Ordering::Acquire) >> 1 } /// Returns `true` if there are no tasks in this `TaskTracker`. #[inline] #[must_use] pub fn is_empty(&self) -> bool { self.inner.state.load(Ordering::Acquire) <= 1 } /// Spawn the provided future on the current Tokio runtime, and track it in this `TaskTracker`. /// /// This is equivalent to `tokio::spawn(tracker.track_future(task))`. #[inline] #[track_caller] #[cfg(feature = "rt")] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn spawn(&self, task: F) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { tokio::task::spawn(self.track_future(task)) } /// Spawn the provided future on the provided Tokio runtime, and track it in this `TaskTracker`. /// /// This is equivalent to `handle.spawn(tracker.track_future(task))`. #[inline] #[track_caller] #[cfg(feature = "rt")] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn spawn_on(&self, task: F, handle: &Handle) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { handle.spawn(self.track_future(task)) } /// Spawn the provided future on the current [`LocalSet`], and track it in this `TaskTracker`. /// /// This is equivalent to `tokio::task::spawn_local(tracker.track_future(task))`. /// /// [`LocalSet`]: tokio::task::LocalSet #[inline] #[track_caller] #[cfg(feature = "rt")] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn spawn_local(&self, task: F) -> JoinHandle where F: Future + 'static, F::Output: 'static, { tokio::task::spawn_local(self.track_future(task)) } /// Spawn the provided future on the provided [`LocalSet`], and track it in this `TaskTracker`. /// /// This is equivalent to `local_set.spawn_local(tracker.track_future(task))`. 
/// /// [`LocalSet`]: tokio::task::LocalSet #[inline] #[track_caller] #[cfg(feature = "rt")] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn spawn_local_on(&self, task: F, local_set: &LocalSet) -> JoinHandle where F: Future + 'static, F::Output: 'static, { local_set.spawn_local(self.track_future(task)) } /// Spawn the provided blocking task on the current Tokio runtime, and track it in this `TaskTracker`. /// /// This is equivalent to `tokio::task::spawn_blocking(tracker.track_future(task))`. #[inline] #[track_caller] #[cfg(feature = "rt")] #[cfg(not(target_family = "wasm"))] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn spawn_blocking(&self, task: F) -> JoinHandle where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static, { let token = self.token(); tokio::task::spawn_blocking(move || { let res = task(); drop(token); res }) } /// Spawn the provided blocking task on the provided Tokio runtime, and track it in this `TaskTracker`. /// /// This is equivalent to `handle.spawn_blocking(tracker.track_future(task))`. #[inline] #[track_caller] #[cfg(feature = "rt")] #[cfg(not(target_family = "wasm"))] #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn spawn_blocking_on(&self, task: F, handle: &Handle) -> JoinHandle where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static, { let token = self.token(); handle.spawn_blocking(move || { let res = task(); drop(token); res }) } /// Track the provided future. /// /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will /// prevent calls to [`wait`] from returning until the task is dropped. /// /// The task is removed from the collection when it is dropped, not when [`poll`] returns /// [`Poll::Ready`]. /// /// # Examples /// /// Track a future spawned with [`tokio::spawn`]. 
    ///
    /// ```
    /// # async fn my_async_fn() {}
    /// use tokio_util::task::TaskTracker;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// let tracker = TaskTracker::new();
    ///
    /// tokio::spawn(tracker.track_future(my_async_fn()));
    /// # }
    /// ```
    ///
    /// Track a future spawned on a [`JoinSet`].
    ///
    /// ```
    /// # async fn my_async_fn() {}
    /// use tokio::task::JoinSet;
    /// use tokio_util::task::TaskTracker;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// let tracker = TaskTracker::new();
    /// let mut join_set = JoinSet::new();
    ///
    /// join_set.spawn(tracker.track_future(my_async_fn()));
    /// # }
    /// ```
    ///
    /// [`JoinSet`]: tokio::task::JoinSet
    /// [`Poll::Ready`]: std::task::Poll::Ready
    /// [`poll`]: std::future::Future::poll
    /// [`wait`]: Self::wait
    #[inline]
    pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
        TrackedFuture {
            future,
            token: self.token(),
        }
    }

    /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`.
    ///
    /// This token is a lower-level utility than the spawn methods. Each token is considered to
    /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete.
    /// Furthermore, the count returned by the [`len`] method will include the tokens in the count.
    ///
    /// Dropping the token indicates to the `TaskTracker` that the task has exited.
    ///
    /// [`len`]: TaskTracker::len
    #[inline]
    pub fn token(&self) -> TaskTrackerToken {
        self.inner.add_task();
        TaskTrackerToken {
            task_tracker: self.clone(),
        }
    }

    /// Returns `true` if both task trackers correspond to the same set of tasks.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::TaskTracker;
    ///
    /// let tracker_1 = TaskTracker::new();
    /// let tracker_2 = TaskTracker::new();
    /// let tracker_1_clone = tracker_1.clone();
    ///
    /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone));
    /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2));
    /// ```
    #[inline]
    #[must_use]
    pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
        Arc::ptr_eq(&left.inner, &right.inner)
    }
}

impl Default for TaskTracker {
    /// Creates a new `TaskTracker`.
    ///
    /// The `TaskTracker` will start out as open.
    #[inline]
    fn default() -> TaskTracker {
        TaskTracker::new()
    }
}

impl Clone for TaskTracker {
    /// Returns a new `TaskTracker` that tracks the same set of tasks.
    ///
    /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in
    /// all other clones.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::TaskTracker;
    ///
    /// #[tokio::main]
    /// # async fn _hidden() {}
    /// # #[tokio::main(flavor = "current_thread")]
    /// async fn main() {
    ///     let tracker = TaskTracker::new();
    ///     let cloned = tracker.clone();
    ///
    ///     // Spawns on `tracker` are visible in `cloned`.
    ///     tracker.spawn(std::future::pending::<()>());
    ///     assert_eq!(cloned.len(), 1);
    ///
    ///     // Spawns on `cloned` are visible in `tracker`.
    ///     cloned.spawn(std::future::pending::<()>());
    ///     assert_eq!(tracker.len(), 2);
    ///
    ///     // Calling `close` is visible to `cloned`.
    ///     tracker.close();
    ///     assert!(cloned.is_closed());
    ///
    ///     // Calling `reopen` is visible to `tracker`.
    ///     cloned.reopen();
    ///     assert!(!tracker.is_closed());
    /// }
    /// ```
    #[inline]
    fn clone(&self) -> TaskTracker {
        Self {
            inner: self.inner.clone(),
        }
    }
}

fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let state = inner.state.load(Ordering::Acquire);
    let is_closed = (state & 1) != 0;
    let len = state >> 1;

    f.debug_struct("TaskTracker")
        .field("len", &len)
        .field("is_closed", &is_closed)
        .field("inner", &(inner as *const TaskTrackerInner))
        .finish()
}

impl fmt::Debug for TaskTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        debug_inner(&self.inner, f)
    }
}

impl TaskTrackerToken {
    /// Returns the [`TaskTracker`] that this token is associated with.
    #[inline]
    #[must_use]
    pub fn task_tracker(&self) -> &TaskTracker {
        &self.task_tracker
    }
}

impl Clone for TaskTrackerToken {
    /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`].
    ///
    /// This is equivalent to `token.task_tracker().token()`.
    #[inline]
    fn clone(&self) -> TaskTrackerToken {
        self.task_tracker.token()
    }
}

impl Drop for TaskTrackerToken {
    /// Dropping the token indicates to the [`TaskTracker`] that the task has exited.
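    ///
    /// A minimal sketch of this behavior, written as a doc example and relying
    /// only on the `token` and `len` methods of `TaskTracker` documented above:
    ///
    /// ```rust
    /// use tokio_util::task::TaskTracker;
    ///
    /// let tracker = TaskTracker::new();
    ///
    /// // Each live token counts as one tracked task.
    /// let token = tracker.token();
    /// assert_eq!(tracker.len(), 1);
    ///
    /// // Dropping the token removes that task from the count.
    /// drop(token);
    /// assert_eq!(tracker.len(), 0);
    /// ```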
    #[inline]
    fn drop(&mut self) {
        self.task_tracker.inner.drop_task();
    }
}

impl<F: Future> Future for TrackedFuture<F> {
    type Output = F::Output;

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        self.project().future.poll(cx)
    }
}

impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TrackedFuture")
            .field("future", &self.future)
            .field("task_tracker", self.token.task_tracker())
            .finish()
    }
}

impl<'a> Future for TaskTrackerWaitFuture<'a> {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let me = self.project();

        let inner = match me.inner.as_ref() {
            None => return Poll::Ready(()),
            Some(inner) => inner,
        };

        let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready();
        if ready {
            *me.inner = None;
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

impl<'a> fmt::Debug for TaskTrackerWaitFuture<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        struct Helper<'a>(&'a TaskTrackerInner);

        impl fmt::Debug for Helper<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                debug_inner(self.0, f)
            }
        }

        f.debug_struct("TaskTrackerWaitFuture")
            .field("future", &self.future)
            .field("task_tracker", &self.inner.map(Helper))
            .finish()
    }
}

tokio-util-0.7.10/src/time/delay_queue.rs

//! A queue of delayed elements.
//!
//! See [`DelayQueue`] for more details.
//!
//! [`DelayQueue`]: struct@DelayQueue

use crate::time::wheel::{self, Wheel};

use futures_core::ready;
use tokio::time::{sleep_until, Duration, Instant, Sleep};

use core::ops::{Index, IndexMut};
use slab::Slab;
use std::cmp;
use std::collections::HashMap;
use std::convert::From;
use std::fmt;
use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{self, Poll, Waker};

/// A queue of delayed elements.
///
/// Once an element is inserted into the `DelayQueue`, it is yielded once the
/// specified deadline has been reached.
///
/// # Usage
///
/// Elements are inserted into `DelayQueue` using the [`insert`] or
/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is
/// returned. The key is used to remove the entry or to change the deadline at
/// which it should be yielded back.
///
/// Once delays have been configured, the `DelayQueue` is used via its
/// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its
/// deadline, it is returned. If not, `Poll::Pending` is returned indicating that the
/// current task will be notified once the deadline has been reached.
///
/// # `Stream` implementation
///
/// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have
/// expired, no items are returned. In this case, `Poll::Pending` is returned and the
/// current task is registered to be notified once the next item's delay has
/// expired.
///
/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll`
/// returns `Poll::Ready(None)`. This indicates that the stream has reached an end.
/// However, if a new item is inserted *after*, `poll` will once again start
/// returning items or `Poll::Pending`.
///
/// Items are returned ordered by their expirations. Items that are configured
/// to expire first will be returned first. There are no ordering guarantees
/// for items configured to expire at the same instant. Also note that delays are
/// rounded to the closest millisecond.
///
/// # Implementation
///
/// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally
/// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same
/// performance and scalability benefits.
///
/// State associated with each entry is stored in a [`slab`].
/// This amortizes the cost of allocation,
/// and allows reuse of the memory allocated for expired entries.
///
/// Capacity can be checked using [`capacity`] and allocated preemptively by using
/// the [`reserve`] method.
///
/// # Examples
///
/// Using `DelayQueue` to manage cache entries.
///
/// ```rust,no_run
/// use tokio_util::time::{DelayQueue, delay_queue};
///
/// use futures::ready;
/// use std::collections::HashMap;
/// use std::task::{Context, Poll};
/// use std::time::Duration;
/// # type CacheKey = String;
/// # type Value = String;
///
/// struct Cache {
///     entries: HashMap<CacheKey, (Value, delay_queue::Key)>,
///     expirations: DelayQueue<CacheKey>,
/// }
///
/// const TTL_SECS: u64 = 30;
///
/// impl Cache {
///     fn insert(&mut self, key: CacheKey, value: Value) {
///         let delay = self.expirations
///             .insert(key.clone(), Duration::from_secs(TTL_SECS));
///
///         self.entries.insert(key, (value, delay));
///     }
///
///     fn get(&self, key: &CacheKey) -> Option<&Value> {
///         self.entries.get(key)
///             .map(|&(ref v, _)| v)
///     }
///
///     fn remove(&mut self, key: &CacheKey) {
///         if let Some((_, cache_key)) = self.entries.remove(key) {
///             self.expirations.remove(&cache_key);
///         }
///     }
///
///     fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<()> {
///         while let Some(entry) = ready!(self.expirations.poll_expired(cx)) {
///             self.entries.remove(entry.get_ref());
///         }
///
///         Poll::Ready(())
///     }
/// }
/// ```
///
/// [`insert`]: method@Self::insert
/// [`insert_at`]: method@Self::insert_at
/// [`Key`]: struct@Key
/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html
/// [`poll_expired`]: method@Self::poll_expired
/// [`Stream::poll_expired`]: method@Self::poll_expired
/// [`DelayQueue`]: struct@DelayQueue
/// [`sleep`]: fn@tokio::time::sleep
/// [`slab`]: slab
/// [`capacity`]: method@Self::capacity
/// [`reserve`]: method@Self::reserve
#[derive(Debug)]
pub struct DelayQueue<T> {
    /// Stores data associated with entries
    slab: SlabStorage<T>,

    /// Lookup structure tracking all delays in
    /// the queue
    wheel: Wheel<Stack<T>>,

    /// Delays that were inserted when already expired. These cannot be stored
    /// in the wheel
    expired: Stack<T>,

    /// Delay expiring when the *first* item in the queue expires
    delay: Option<Pin<Box<Sleep>>>,

    /// Wheel polling state
    wheel_now: u64,

    /// Instant at which the timer starts
    start: Instant,

    /// Waker that is invoked when we potentially need to reset the timer.
    /// Because we lazily create the timer when the first entry is created, we
    /// need to awaken any poller that polled us before that point.
    waker: Option<Waker>,
}

#[derive(Default)]
struct SlabStorage<T> {
    inner: Slab<Data<T>>,

    // A `compact` call requires a re-mapping of the `Key`s that were changed
    // during the `compact` call of the `slab`. Since the keys that were given out
    // cannot be changed retroactively we need to keep track of these re-mappings.
    // The keys of `key_map` correspond to the old keys that were given out and
    // the values to the `Key`s that were re-mapped by the `compact` call.
    key_map: HashMap<Key, KeyInternal>,

    // Index used to create new keys to hand out.
    next_key_index: usize,

    // Whether `compact` has been called, necessary in order to decide whether
    // to include keys in `key_map`.
    compact_called: bool,
}

impl<T> SlabStorage<T> {
    pub(crate) fn with_capacity(capacity: usize) -> SlabStorage<T> {
        SlabStorage {
            inner: Slab::with_capacity(capacity),
            key_map: HashMap::new(),
            next_key_index: 0,
            compact_called: false,
        }
    }

    // Inserts data into the inner slab and re-maps keys if necessary
    pub(crate) fn insert(&mut self, val: Data<T>) -> Key {
        let mut key = KeyInternal::new(self.inner.insert(val));
        let key_contained = self.key_map.contains_key(&key.into());

        if key_contained {
            // It's possible that a `compact` call creates capacity in `self.inner` in
            // such a way that a `self.inner.insert` call creates a `key` which was
            // previously given out during an `insert` call prior to the `compact` call.
            // If `key` is contained in `self.key_map`, we have encountered this exact situation.
            // We need to create a new key `key_to_give_out` and include the relation
            // `key_to_give_out` -> `key` in `self.key_map`.
            let key_to_give_out = self.create_new_key();
            assert!(!self.key_map.contains_key(&key_to_give_out.into()));
            self.key_map.insert(key_to_give_out.into(), key);
            key = key_to_give_out;
        } else if self.compact_called {
            // Include an identity mapping in `self.key_map` in order to allow us to
            // panic if a key that was handed out is removed more than once.
            self.key_map.insert(key.into(), key);
        }

        key.into()
    }

    // Re-map the key in case compact was previously called.
    // Note: Since we include identity mappings in key_map after compact was called,
    // we have information about all keys that were handed out. In the case in which
    // compact was called and we try to remove a Key that was previously removed
    // we can detect invalid keys if no key is found in `key_map`. This is necessary
    // in order to prevent situations in which a previously removed key
    // corresponds to a re-mapped key internally and which would then be incorrectly
    // removed from the slab.
    //
    // Example to illuminate this problem:
    //
    // Let's assume our `key_map` is {1 -> 2, 2 -> 1} and we call remove(1). If we
    // were to remove 1 again, we would not find it inside `key_map` anymore.
    // If we were to infer from this that no re-mapping was necessary, we would
    // incorrectly remove 1 from `self.slab.inner`, which corresponds to the
    // handed-out key 2.
    pub(crate) fn remove(&mut self, key: &Key) -> Data<T> {
        let remapped_key = if self.compact_called {
            match self.key_map.remove(key) {
                Some(key_internal) => key_internal,
                None => panic!("invalid key"),
            }
        } else {
            (*key).into()
        };

        self.inner.remove(remapped_key.index)
    }

    pub(crate) fn shrink_to_fit(&mut self) {
        self.inner.shrink_to_fit();
        self.key_map.shrink_to_fit();
    }

    pub(crate) fn compact(&mut self) {
        if !self.compact_called {
            for (key, _) in self.inner.iter() {
                self.key_map.insert(Key::new(key), KeyInternal::new(key));
            }
        }

        let mut remapping = HashMap::new();
        self.inner.compact(|_, from, to| {
            remapping.insert(from, to);
            true
        });

        // At this point `key_map` contains a mapping for every element.
        for internal_key in self.key_map.values_mut() {
            if let Some(new_internal_key) = remapping.get(&internal_key.index) {
                *internal_key = KeyInternal::new(*new_internal_key);
            }
        }

        if self.key_map.capacity() > 2 * self.key_map.len() {
            self.key_map.shrink_to_fit();
        }

        self.compact_called = true;
    }

    // Tries to re-map a `Key` that was given out to the user to its
    // corresponding internal key.
    fn remap_key(&self, key: &Key) -> Option<KeyInternal> {
        let key_map = &self.key_map;
        if self.compact_called {
            key_map.get(key).copied()
        } else {
            Some((*key).into())
        }
    }

    fn create_new_key(&mut self) -> KeyInternal {
        while self.key_map.contains_key(&Key::new(self.next_key_index)) {
            self.next_key_index = self.next_key_index.wrapping_add(1);
        }

        KeyInternal::new(self.next_key_index)
    }

    pub(crate) fn len(&self) -> usize {
        self.inner.len()
    }

    pub(crate) fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    pub(crate) fn clear(&mut self) {
        self.inner.clear();
        self.key_map.clear();
        self.compact_called = false;
    }

    pub(crate) fn reserve(&mut self, additional: usize) {
        self.inner.reserve(additional);

        if self.compact_called {
            self.key_map.reserve(additional);
        }
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    pub(crate) fn contains(&self, key: &Key) -> bool {
        let remapped_key = self.remap_key(key);

        match remapped_key {
            Some(internal_key) => self.inner.contains(internal_key.index),
            None => false,
        }
    }
}

impl<T> fmt::Debug for SlabStorage<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        if fmt.alternate() {
            fmt.debug_map().entries(self.inner.iter()).finish()
        } else {
            fmt.debug_struct("Slab")
                .field("len", &self.len())
                .field("cap", &self.capacity())
                .finish()
        }
    }
}

impl<T> Index<Key> for SlabStorage<T> {
    type Output = Data<T>;

    fn index(&self, key: Key) -> &Self::Output {
        let remapped_key = self.remap_key(&key);

        match remapped_key {
            Some(internal_key) => &self.inner[internal_key.index],
            None => panic!("Invalid index {}", key.index),
        }
    }
}

impl<T> IndexMut<Key> for SlabStorage<T> {
    fn index_mut(&mut self, key: Key) -> &mut Data<T> {
        let remapped_key = self.remap_key(&key);

        match remapped_key {
            Some(internal_key) => &mut self.inner[internal_key.index],
            None => panic!("Invalid index {}", key.index),
        }
    }
}

/// An entry in `DelayQueue` that has expired and been removed.
///
/// Values are returned by [`DelayQueue::poll_expired`].
///
/// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired
#[derive(Debug)]
pub struct Expired<T> {
    /// The data stored in the queue
    data: T,

    /// The expiration time
    deadline: Instant,

    /// The key associated with the entry
    key: Key,
}

/// Token to a value stored in a `DelayQueue`.
///
/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`]
/// documentation for more details.
///
/// [`DelayQueue`]: struct@DelayQueue
/// [`DelayQueue::insert`]: method@DelayQueue::insert
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Key {
    index: usize,
}

// Whereas `Key` is given out to users that use `DelayQueue`, internally we use
// `KeyInternal` as the key type in order to make the logic of mapping between keys
// as a result of `compact` calls clearer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct KeyInternal {
    index: usize,
}

#[derive(Debug)]
struct Stack<T> {
    /// Head of the stack
    head: Option<Key>,
    _p: PhantomData<fn() -> T>,
}

#[derive(Debug)]
struct Data<T> {
    /// The data being stored in the queue and will be returned at the requested
    /// instant.
    inner: T,

    /// The instant at which the item is returned.
    when: u64,

    /// Set to true when stored in the `expired` queue
    expired: bool,

    /// Next entry in the stack
    next: Option<Key>,

    /// Previous entry in the stack
    prev: Option<Key>,
}

/// Maximum number of entries the queue can handle
const MAX_ENTRIES: usize = (1 << 30) - 1;

impl<T> DelayQueue<T> {
    /// Creates a new, empty, `DelayQueue`.
    ///
    /// The queue will not allocate storage until items are inserted into it.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use tokio_util::time::DelayQueue;
    /// let delay_queue: DelayQueue<u32> = DelayQueue::new();
    /// ```
    pub fn new() -> DelayQueue<T> {
        DelayQueue::with_capacity(0)
    }

    /// Creates a new, empty, `DelayQueue` with the specified capacity.
    ///
    /// The queue will be able to hold at least `capacity` elements without
    /// reallocating. If `capacity` is 0, the queue will not allocate for
    /// storage.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use tokio_util::time::DelayQueue;
    /// # use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::with_capacity(10);
    ///
    /// // These insertions are done without further allocation
    /// for i in 0..10 {
    ///     delay_queue.insert(i, Duration::from_secs(i));
    /// }
    ///
    /// // This will make the queue allocate additional storage
    /// delay_queue.insert(11, Duration::from_secs(11));
    /// # }
    /// ```
    pub fn with_capacity(capacity: usize) -> DelayQueue<T> {
        DelayQueue {
            wheel: Wheel::new(),
            slab: SlabStorage::with_capacity(capacity),
            expired: Stack::default(),
            delay: None,
            wheel_now: 0,
            start: Instant::now(),
            waker: None,
        }
    }

    /// Inserts `value` into the queue set to expire at a specific instant in
    /// time.
    ///
    /// This function is identical to `insert`, but takes an `Instant` instead
    /// of a `Duration`.
    ///
    /// `value` is stored in the queue until `when` is reached, at which point
    /// `value` will be returned from [`poll_expired`]. If `when` has already been
    /// reached, then `value` is immediately made available to poll.
    ///
    /// The return value represents the insertion and is used as an argument to
    /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once
    /// `value` is removed from the queue either by calling [`poll_expired`] after
    /// `when` is reached or by calling [`remove`]. At this point, the caller
    /// must take care to not use the returned [`Key`] again as it may reference
    /// a different item in the queue.
    ///
    /// See [type] level documentation for more details.
    ///
    /// # Panics
    ///
    /// This function panics if `when` is too far in the future.
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio::time::{Duration, Instant};
    /// use tokio_util::time::DelayQueue;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert_at(
    ///     "foo", Instant::now() + Duration::from_secs(5));
    ///
    /// // Remove the entry
    /// let item = delay_queue.remove(&key);
    /// assert_eq!(*item.get_ref(), "foo");
    /// # }
    /// ```
    ///
    /// [`poll_expired`]: method@Self::poll_expired
    /// [`remove`]: method@Self::remove
    /// [`reset`]: method@Self::reset
    /// [`Key`]: struct@Key
    /// [type]: #
    #[track_caller]
    pub fn insert_at(&mut self, value: T, when: Instant) -> Key {
        assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded");

        // Normalize the deadline. Values cannot be set to expire in the past.
        let when = self.normalize_deadline(when);

        // Insert the value in the store
        let key = self.slab.insert(Data {
            inner: value,
            when,
            expired: false,
            next: None,
            prev: None,
        });

        self.insert_idx(when, key);

        // Set a new delay if the current's deadline is later than the one of the new item
        let should_set_delay = if let Some(ref delay) = self.delay {
            let current_exp = self.normalize_deadline(delay.deadline());
            current_exp > when
        } else {
            true
        };

        if should_set_delay {
            if let Some(waker) = self.waker.take() {
                waker.wake();
            }

            let delay_time = self.start + Duration::from_millis(when);
            if let Some(ref mut delay) = &mut self.delay {
                delay.as_mut().reset(delay_time);
            } else {
                self.delay = Some(Box::pin(sleep_until(delay_time)));
            }
        }

        key
    }

    /// Attempts to pull out the next value of the delay queue, registering the
    /// current task for wakeup if the value is not yet available, and returning
    /// `None` if the queue is exhausted.
    pub fn poll_expired(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Expired<T>>> {
        if !self
            .waker
            .as_ref()
            .map(|w| w.will_wake(cx.waker()))
            .unwrap_or(false)
        {
            self.waker = Some(cx.waker().clone());
        }

        let item = ready!(self.poll_idx(cx));
        Poll::Ready(item.map(|key| {
            let data = self.slab.remove(&key);
            debug_assert!(data.next.is_none());
            debug_assert!(data.prev.is_none());

            Expired {
                key,
                data: data.inner,
                deadline: self.start + Duration::from_millis(data.when),
            }
        }))
    }

    /// Inserts `value` into the queue set to expire after the requested duration
    /// elapses.
    ///
    /// This function is identical to `insert_at`, but takes a `Duration`
    /// instead of an `Instant`.
    ///
    /// `value` is stored in the queue until `timeout` duration has
    /// elapsed after `insert` was called. At that point, `value` will
    /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of
    /// zero, then `value` is immediately made available to poll.
    ///
    /// The return value represents the insertion and is used as an
    /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a
    /// token and is reused once `value` is removed from the queue
    /// either by calling [`poll_expired`] after `timeout` has elapsed
    /// or by calling [`remove`]. At this point, the caller must not
    /// use the returned [`Key`] again as it may reference a different
    /// item in the queue.
    ///
    /// See [type] level documentation for more details.
    ///
    /// # Panics
    ///
    /// This function panics if `timeout` is greater than the maximum
    /// duration supported by the timer in the current `Runtime`.
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// // Remove the entry
    /// let item = delay_queue.remove(&key);
    /// assert_eq!(*item.get_ref(), "foo");
    /// # }
    /// ```
    ///
    /// [`poll_expired`]: method@Self::poll_expired
    /// [`remove`]: method@Self::remove
    /// [`reset`]: method@Self::reset
    /// [`Key`]: struct@Key
    /// [type]: #
    #[track_caller]
    pub fn insert(&mut self, value: T, timeout: Duration) -> Key {
        self.insert_at(value, Instant::now() + timeout)
    }

    #[track_caller]
    fn insert_idx(&mut self, when: u64, key: Key) {
        use self::wheel::{InsertError, Stack};

        // Register the deadline with the timer wheel
        match self.wheel.insert(when, key, &mut self.slab) {
            Ok(_) => {}
            Err((_, InsertError::Elapsed)) => {
                self.slab[key].expired = true;
                // The delay is already expired, store it in the expired queue
                self.expired.push(key, &mut self.slab);
            }
            Err((_, err)) => panic!("invalid deadline; err={:?}", err),
        }
    }

    /// Removes the key from the expired queue or the timer wheel
    /// depending on its expiration status.
    ///
    /// # Panics
    ///
    /// Panics if the key is not contained in the expired queue or the wheel.
    #[track_caller]
    fn remove_key(&mut self, key: &Key) {
        use crate::time::wheel::Stack;

        // Special case the `expired` queue
        if self.slab[*key].expired {
            self.expired.remove(key, &mut self.slab);
        } else {
            self.wheel.remove(key, &mut self.slab);
        }
    }

    /// Removes the item associated with `key` from the queue.
    ///
    /// There must be an item associated with `key`. The function returns the
    /// removed item as well as the `Instant` at which the delay would have
    /// expired.
    ///
    /// # Panics
    ///
    /// The function panics if `key` is not contained by the queue.
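    ///
    /// A minimal sketch of the non-panicking alternative, written as a doc
    /// example: once a key has been removed, `try_remove` (documented below)
    /// reports `None` where a second `remove` would panic.
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// // The first removal succeeds and hands back the stored value.
    /// let item = delay_queue.remove(&key);
    /// assert_eq!(item.into_inner(), "foo");
    ///
    /// // Calling `remove(&key)` again would panic; `try_remove` returns `None`.
    /// assert!(delay_queue.try_remove(&key).is_none());
    /// # }
    /// ```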
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// // Remove the entry
    /// let item = delay_queue.remove(&key);
    /// assert_eq!(*item.get_ref(), "foo");
    /// # }
    /// ```
    #[track_caller]
    pub fn remove(&mut self, key: &Key) -> Expired<T> {
        let prev_deadline = self.next_deadline();

        self.remove_key(key);
        let data = self.slab.remove(key);

        let next_deadline = self.next_deadline();
        if prev_deadline != next_deadline {
            match (next_deadline, &mut self.delay) {
                (None, _) => self.delay = None,
                (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline),
                (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))),
            }
        }

        Expired {
            key: Key::new(key.index),
            data: data.inner,
            deadline: self.start + Duration::from_millis(data.when),
        }
    }

    /// Attempts to remove the item associated with `key` from the queue.
    ///
    /// Removes the item associated with `key`, and returns it along with the
    /// `Instant` at which it would have expired, if it exists.
    ///
    /// Returns `None` if `key` is not in the queue.
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// // The item is in the queue, `try_remove` returns `Some(Expired("foo"))`.
    /// let item = delay_queue.try_remove(&key);
    /// assert_eq!(item.unwrap().into_inner(), "foo");
    ///
    /// // The item is not in the queue anymore, `try_remove` returns `None`.
    /// let item = delay_queue.try_remove(&key);
    /// assert!(item.is_none());
    /// # }
    /// ```
    pub fn try_remove(&mut self, key: &Key) -> Option<Expired<T>> {
        if self.slab.contains(key) {
            Some(self.remove(key))
        } else {
            None
        }
    }

    /// Sets the delay of the item associated with `key` to expire at `when`.
    ///
    /// This function is identical to `reset` but takes an `Instant` instead of
    /// a `Duration`.
    ///
    /// The item remains in the queue but the delay is set to expire at `when`.
    /// If `when` is in the past, then the item is immediately made available to
    /// the caller.
    ///
    /// # Panics
    ///
    /// This function panics if `when` is too far in the future or if `key` is
    /// not contained by the queue.
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio::time::{Duration, Instant};
    /// use tokio_util::time::DelayQueue;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// // "foo" is scheduled to be returned in 5 seconds
    ///
    /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10));
    ///
    /// // "foo" is now scheduled to be returned in 10 seconds
    /// # }
    /// ```
    #[track_caller]
    pub fn reset_at(&mut self, key: &Key, when: Instant) {
        self.remove_key(key);

        // Normalize the deadline. Values cannot be set to expire in the past.
        let when = self.normalize_deadline(when);

        self.slab[*key].when = when;
        self.slab[*key].expired = false;

        self.insert_idx(when, *key);

        let next_deadline = self.next_deadline();
        if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) {
            // This should awaken us if necessary (ie, if already expired)
            delay.as_mut().reset(deadline);
        }
    }

    /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation.
    /// This function is not guaranteed to, and in most cases, won't decrease the capacity of the slab
    /// to the number of elements still contained in it, because elements cannot be moved to a different
    /// index. To decrease the capacity to the size of the slab use [`compact`].
    ///
    /// This function can take O(n) time even when the capacity cannot be reduced or the allocation is
    /// shrunk in place. Repeated calls run in O(1) though.
    ///
    /// [`compact`]: method@Self::compact
    pub fn shrink_to_fit(&mut self) {
        self.slab.shrink_to_fit();
    }

    /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation,
    /// to the number of elements that are contained in it.
    ///
    /// This method runs in O(n).
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::with_capacity(10);
    ///
    /// let key1 = delay_queue.insert(5, Duration::from_secs(5));
    /// let key2 = delay_queue.insert(10, Duration::from_secs(10));
    /// let key3 = delay_queue.insert(15, Duration::from_secs(15));
    ///
    /// delay_queue.remove(&key2);
    ///
    /// delay_queue.compact();
    /// assert_eq!(delay_queue.capacity(), 2);
    /// # }
    /// ```
    pub fn compact(&mut self) {
        self.slab.compact();
    }

    /// Gets the [`Key`] that [`poll_expired`] will pull out of the queue next, without
    /// pulling it out or waiting for the deadline to expire.
    ///
    /// Entries that have already expired may be returned in any order, but it is
    /// guaranteed that this method returns them in the same order as when items
    /// are popped from the `DelayQueue`.
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    ///
    /// let key1 = delay_queue.insert("foo", Duration::from_secs(10));
    /// let key2 = delay_queue.insert("bar", Duration::from_secs(5));
    /// let key3 = delay_queue.insert("baz", Duration::from_secs(15));
    ///
    /// assert_eq!(delay_queue.peek().unwrap(), key2);
    /// # }
    /// ```
    ///
    /// [`Key`]: struct@Key
    /// [`poll_expired`]: method@Self::poll_expired
    pub fn peek(&self) -> Option<Key> {
        use self::wheel::Stack;

        self.expired.peek().or_else(|| self.wheel.peek())
    }

    /// Returns the next time to poll as determined by the wheel
    fn next_deadline(&mut self) -> Option<Instant> {
        self.wheel
            .poll_at()
            .map(|poll_at| self.start + Duration::from_millis(poll_at))
    }

    /// Sets the delay of the item associated with `key` to expire after
    /// `timeout`.
    ///
    /// This function is identical to `reset_at` but takes a `Duration` instead
    /// of an `Instant`.
    ///
    /// The item remains in the queue but the delay is set to expire after
    /// `timeout`. If `timeout` is zero, then the item is immediately made
    /// available to the caller.
    ///
    /// # Panics
    ///
    /// This function panics if `timeout` is greater than the maximum supported
    /// duration or if `key` is not contained by the queue.
    ///
    /// # Examples
    ///
    /// Basic usage
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// // "foo" is scheduled to be returned in 5 seconds
    ///
    /// delay_queue.reset(&key, Duration::from_secs(10));
    ///
    /// // "foo" is now scheduled to be returned in 10 seconds
    /// # }
    /// ```
    #[track_caller]
    pub fn reset(&mut self, key: &Key, timeout: Duration) {
        self.reset_at(key, Instant::now() + timeout);
    }

    /// Clears the queue, removing all items.
    ///
    /// After calling `clear`, [`poll_expired`] will return `Poll::Ready(None)`.
    ///
    /// Note that this method has no effect on the allocated capacity.
    ///
    /// [`poll_expired`]: method@Self::poll_expired
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    ///
    /// delay_queue.insert("foo", Duration::from_secs(5));
    ///
    /// assert!(!delay_queue.is_empty());
    ///
    /// delay_queue.clear();
    ///
    /// assert!(delay_queue.is_empty());
    /// # }
    /// ```
    pub fn clear(&mut self) {
        self.slab.clear();
        self.expired = Stack::default();
        self.wheel = Wheel::new();
        self.delay = None;
    }

    /// Returns the number of elements the queue can hold without reallocating.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    ///
    /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
    /// assert_eq!(delay_queue.capacity(), 10);
    /// ```
    pub fn capacity(&self) -> usize {
        self.slab.capacity()
    }

    /// Returns the number of elements currently in the queue.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
    /// assert_eq!(delay_queue.len(), 0);
    /// delay_queue.insert(3, Duration::from_secs(5));
    /// assert_eq!(delay_queue.len(), 1);
    /// # }
    /// ```
    pub fn len(&self) -> usize {
        self.slab.len()
    }

    /// Reserves capacity for at least `additional` more items to be queued
    /// without allocating.
    ///
    /// `reserve` does nothing if the queue already has sufficient capacity for
    /// `additional` more values. If more capacity is required, a new segment of
    /// memory will be allocated and all existing values will be copied into it.
    /// As such, if the queue is already very large, a call to `reserve` can end
    /// up being expensive.
    ///
    /// The queue may reserve more than `additional` extra space in order to
    /// avoid frequent reallocations.
    ///
    /// # Panics
    ///
    /// Panics if the new capacity exceeds the maximum number of entries the
    /// queue can contain.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    ///
    /// delay_queue.insert("hello", Duration::from_secs(10));
    /// delay_queue.reserve(10);
    ///
    /// assert!(delay_queue.capacity() >= 11);
    /// # }
    /// ```
    #[track_caller]
    pub fn reserve(&mut self, additional: usize) {
        assert!(
            self.slab.capacity() + additional <= MAX_ENTRIES,
            "max queue capacity exceeded"
        );
        self.slab.reserve(additional);
    }

    /// Returns `true` if there are no items in the queue.
    ///
    /// Note that this function returns `false` even if none of the items has
    /// expired yet and a call to `poll_expired` would return `Poll::Pending`.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::time::DelayQueue;
    /// use std::time::Duration;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut delay_queue = DelayQueue::new();
    /// assert!(delay_queue.is_empty());
    ///
    /// delay_queue.insert("hello", Duration::from_secs(5));
    /// assert!(!delay_queue.is_empty());
    /// # }
    /// ```
    pub fn is_empty(&self) -> bool {
        self.slab.is_empty()
    }

    /// Polls the queue, returning the index of the next slot in the slab that
    /// should be returned.
    ///
    /// A slot should be returned when the associated deadline has been reached.
    fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Key>> {
        use self::wheel::Stack;

        let expired = self.expired.pop(&mut self.slab);

        if expired.is_some() {
            return Poll::Ready(expired);
        }

        loop {
            if let Some(ref mut delay) = self.delay {
                if !delay.is_elapsed() {
                    ready!(Pin::new(&mut *delay).poll(cx));
                }

                let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down);

                self.wheel_now = now;
            }

            // We poll the wheel to get the next value out before finding the next deadline.
            let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab);

            self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when)));

            if let Some(idx) = wheel_idx {
                return Poll::Ready(Some(idx));
            }

            if self.delay.is_none() {
                return Poll::Ready(None);
            }
        }
    }

    fn normalize_deadline(&self, when: Instant) -> u64 {
        let when = if when < self.start {
            0
        } else {
            crate::time::ms(when - self.start, crate::time::Round::Up)
        };

        cmp::max(when, self.wheel.elapsed())
    }
}

// We never put `T` in a `Pin`...
impl<T> Unpin for DelayQueue<T> {}

impl<T> Default for DelayQueue<T> {
    fn default() -> DelayQueue<T> {
        DelayQueue::new()
    }
}

impl<T> futures_core::Stream for DelayQueue<T> {
    // DelayQueue seems much more specific, where a user may care that it
    // has reached capacity, so return those errors instead of panicking.
type Item = Expired; fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { DelayQueue::poll_expired(self.get_mut(), cx) } } impl wheel::Stack for Stack { type Owned = Key; type Borrowed = Key; type Store = SlabStorage; fn is_empty(&self) -> bool { self.head.is_none() } fn push(&mut self, item: Self::Owned, store: &mut Self::Store) { // Ensure the entry is not already in a stack. debug_assert!(store[item].next.is_none()); debug_assert!(store[item].prev.is_none()); // Remove the old head entry let old = self.head.take(); if let Some(idx) = old { store[idx].prev = Some(item); } store[item].next = old; self.head = Some(item); } fn pop(&mut self, store: &mut Self::Store) -> Option { if let Some(key) = self.head { self.head = store[key].next; if let Some(idx) = self.head { store[idx].prev = None; } store[key].next = None; debug_assert!(store[key].prev.is_none()); Some(key) } else { None } } fn peek(&self) -> Option { self.head } #[track_caller] fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { let key = *item; assert!(store.contains(item)); // Ensure that the entry is in fact contained by the stack debug_assert!({ // This walks the full linked list even if an entry is found. 
let mut next = self.head; let mut contains = false; while let Some(idx) = next { let data = &store[idx]; if idx == *item { debug_assert!(!contains); contains = true; } next = data.next; } contains }); if let Some(next) = store[key].next { store[next].prev = store[key].prev; } if let Some(prev) = store[key].prev { store[prev].next = store[key].next; } else { self.head = store[key].next; } store[key].next = None; store[key].prev = None; } fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 { store[*item].when } } impl Default for Stack { fn default() -> Stack { Stack { head: None, _p: PhantomData, } } } impl Key { pub(crate) fn new(index: usize) -> Key { Key { index } } } impl KeyInternal { pub(crate) fn new(index: usize) -> KeyInternal { KeyInternal { index } } } impl From for KeyInternal { fn from(item: Key) -> Self { KeyInternal::new(item.index) } } impl From for Key { fn from(item: KeyInternal) -> Self { Key::new(item.index) } } impl Expired { /// Returns a reference to the inner value. pub fn get_ref(&self) -> &T { &self.data } /// Returns a mutable reference to the inner value. pub fn get_mut(&mut self) -> &mut T { &mut self.data } /// Consumes `self` and returns the inner value. pub fn into_inner(self) -> T { self.data } /// Returns the deadline that the expiration was set to. pub fn deadline(&self) -> Instant { self.deadline } /// Returns the key that the expiration is indexed by. pub fn key(&self) -> Key { self.key } } tokio-util-0.7.10/src/time/mod.rs000064400000000000000000000022111046102023000146570ustar 00000000000000//! Additional utilities for tracking time. //! //! This module provides additional utilities for executing code after a set period //! of time. Currently there is only one: //! //! * `DelayQueue`: A queue where items are returned once the requested delay //! has expired. //! //! This type must be used from within the context of the `Runtime`. 
use std::time::Duration;

mod wheel;

pub mod delay_queue;

#[doc(inline)]
pub use delay_queue::DelayQueue;

// ===== Internal utils =====

enum Round {
    Up,
    Down,
}

/// Convert a `Duration` to milliseconds, rounding the sub-millisecond part up
/// or down as selected by `round` and saturating at `u64::MAX`.
///
/// Saturating is fine because `u64::MAX` milliseconds is still many
/// million years.
#[inline]
fn ms(duration: Duration, round: Round) -> u64 {
    const NANOS_PER_MILLI: u32 = 1_000_000;
    const MILLIS_PER_SEC: u64 = 1_000;

    // Convert the sub-second part to whole milliseconds, rounding in the
    // requested direction.
    let millis = match round {
        Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI,
        Round::Down => duration.subsec_millis(),
    };

    duration
        .as_secs()
        .saturating_mul(MILLIS_PER_SEC)
        .saturating_add(u64::from(millis))
}
tokio-util-0.7.10/src/time/wheel/level.rs000064400000000000000000000175461046102023000163340ustar 00000000000000
use crate::time::wheel::Stack;

use std::fmt;

/// Wheel for a single level in the timer. This wheel contains 64 slots.
pub(crate) struct Level<T> {
    level: usize,

    /// Bit field tracking which slots currently contain entries.
    ///
    /// Using a bit field to track slots that contain entries allows avoiding a
    /// scan to find entries. This field is updated when entries are added or
    /// removed from a slot.
    ///
    /// The least-significant bit represents slot zero.
    occupied: u64,

    /// Slots
    slot: [T; LEVEL_MULT],
}

/// Indicates when a slot must be processed next.
#[derive(Debug)]
pub(crate) struct Expiration {
    /// The level containing the slot.
    pub(crate) level: usize,

    /// The slot index.
    pub(crate) slot: usize,

    /// The instant at which the slot needs to be processed.
    pub(crate) deadline: u64,
}

/// Level multiplier.
///
/// Being a power of 2 is very important.
const LEVEL_MULT: usize = 64;

impl<T: Stack> Level<T> {
    pub(crate) fn new(level: usize) -> Level<T> {
        // Rust's derived implementations for arrays require that the value
        // contained by the array be `Copy`. So, here we have to manually
        // initialize every single slot.
        macro_rules!
s { () => { T::default() }; } Level { level, occupied: 0, slot: [ // It does not look like the necessary traits are // derived for [T; 64]. s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(), ], } } /// Finds the slot that needs to be processed next and returns the slot and /// `Instant` at which this slot must be processed. pub(crate) fn next_expiration(&self, now: u64) -> Option { // Use the `occupied` bit field to get the index of the next slot that // needs to be processed. let slot = match self.next_occupied_slot(now) { Some(slot) => slot, None => return None, }; // From the slot index, calculate the `Instant` at which it needs to be // processed. This value *must* be in the future with respect to `now`. let level_range = level_range(self.level); let slot_range = slot_range(self.level); // TODO: This can probably be simplified w/ power of 2 math let level_start = now - (now % level_range); let mut deadline = level_start + slot as u64 * slot_range; if deadline < now { // A timer is in a slot "prior" to the current time. This can occur // because we do not have an infinite hierarchy of timer levels, and // eventually a timer scheduled for a very distant time might end up // being placed in a slot that is beyond the end of all of the // arrays. // // To deal with this, we first limit timers to being scheduled no // more than MAX_DURATION ticks in the future; that is, they're at // most one rotation of the top level away. Then, we force timers // that logically would go into the top+1 level, to instead go into // the top level's slots. 
// // What this means is that the top level's slots act as a // pseudo-ring buffer, and we rotate around them indefinitely. If we // compute a deadline before now, and it's the top level, it // therefore means we're actually looking at a slot in the future. debug_assert_eq!(self.level, super::NUM_LEVELS - 1); deadline += level_range; } debug_assert!( deadline >= now, "deadline={:016X}; now={:016X}; level={}; slot={}; occupied={:b}", deadline, now, self.level, slot, self.occupied ); Some(Expiration { level: self.level, slot, deadline, }) } fn next_occupied_slot(&self, now: u64) -> Option { if self.occupied == 0 { return None; } // Get the slot for now using Maths let now_slot = (now / slot_range(self.level)) as usize; let occupied = self.occupied.rotate_right(now_slot as u32); let zeros = occupied.trailing_zeros() as usize; let slot = (zeros + now_slot) % 64; Some(slot) } pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) { let slot = slot_for(when, self.level); self.slot[slot].push(item, store); self.occupied |= occupied_bit(slot); } pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) { let slot = slot_for(when, self.level); self.slot[slot].remove(item, store); if self.slot[slot].is_empty() { // The bit is currently set debug_assert!(self.occupied & occupied_bit(slot) != 0); // Unset the bit self.occupied ^= occupied_bit(slot); } } pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option { let ret = self.slot[slot].pop(store); if ret.is_some() && self.slot[slot].is_empty() { // The bit is currently set debug_assert!(self.occupied & occupied_bit(slot) != 0); self.occupied ^= occupied_bit(slot); } ret } pub(crate) fn peek_entry_slot(&self, slot: usize) -> Option { self.slot[slot].peek() } } impl fmt::Debug for Level { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Level") .field("occupied", &self.occupied) .finish() } } fn 
occupied_bit(slot: usize) -> u64 { 1 << slot } fn slot_range(level: usize) -> u64 { LEVEL_MULT.pow(level as u32) as u64 } fn level_range(level: usize) -> u64 { LEVEL_MULT as u64 * slot_range(level) } /// Convert a duration (milliseconds) and a level to a slot position fn slot_for(duration: u64, level: usize) -> usize { ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize } #[cfg(all(test, not(loom)))] mod test { use super::*; #[test] fn test_slot_for() { for pos in 0..64 { assert_eq!(pos as usize, slot_for(pos, 0)); } for level in 1..5 { for pos in level..64 { let a = pos * 64_usize.pow(level as u32); assert_eq!(pos as usize, slot_for(a as u64, level)); } } } } tokio-util-0.7.10/src/time/wheel/mod.rs000064400000000000000000000230471046102023000157750ustar 00000000000000mod level; pub(crate) use self::level::Expiration; use self::level::Level; mod stack; pub(crate) use self::stack::Stack; use std::borrow::Borrow; use std::fmt::Debug; use std::usize; /// Timing wheel implementation. /// /// This type provides the hashed timing wheel implementation that backs `Timer` /// and `DelayQueue`. /// /// The structure is generic over `T: Stack`. This allows handling timeout data /// being stored on the heap or in a slab. In order to support the latter case, /// the slab must be passed into each function allowing the implementation to /// lookup timer entries. /// /// See `Timer` documentation for some implementation notes. #[derive(Debug)] pub(crate) struct Wheel { /// The number of milliseconds elapsed since the wheel started. elapsed: u64, /// Timer wheel. /// /// Levels: /// /// * 1 ms slots / 64 ms range /// * 64 ms slots / ~ 4 sec range /// * ~ 4 sec slots / ~ 4 min range /// * ~ 4 min slots / ~ 4 hr range /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range levels: Vec>, } /// Number of levels. Each level has 64 slots. 
/// By using 6 levels with 64 slots
/// each, the timer is able to track time up to 2 years into the future with a
/// precision of 1 millisecond.
const NUM_LEVELS: usize = 6;

/// The maximum duration of a delay
const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1;

#[derive(Debug)]
pub(crate) enum InsertError {
    Elapsed,
    Invalid,
}

impl<T> Wheel<T>
where
    T: Stack,
{
    /// Create a new timing wheel
    pub(crate) fn new() -> Wheel<T> {
        let levels = (0..NUM_LEVELS).map(Level::new).collect();

        Wheel { elapsed: 0, levels }
    }

    /// Return the number of milliseconds that have elapsed since the timing
    /// wheel's creation.
    pub(crate) fn elapsed(&self) -> u64 {
        self.elapsed
    }

    /// Insert an entry into the timing wheel.
    ///
    /// # Arguments
    ///
    /// * `when`: The instant at which the entry should be fired. It is
    ///           represented as the number of milliseconds since the creation
    ///           of the timing wheel.
    ///
    /// * `item`: The item to insert into the wheel.
    ///
    /// * `store`: The slab or `()` when using heap storage.
    ///
    /// # Return
    ///
    /// Returns `Ok` when the item is successfully inserted, `Err` otherwise.
    ///
    /// `Err(Elapsed)` indicates that `when` represents an instant that has
    /// already passed. In this case, the caller should fire the timeout
    /// immediately.
    ///
    /// `Err(Invalid)` indicates that an invalid `when` argument has been supplied.
    pub(crate) fn insert(
        &mut self,
        when: u64,
        item: T::Owned,
        store: &mut T::Store,
    ) -> Result<(), (T::Owned, InsertError)> {
        if when <= self.elapsed {
            return Err((item, InsertError::Elapsed));
        } else if when - self.elapsed > MAX_DURATION {
            return Err((item, InsertError::Invalid));
        }

        // Get the level at which the entry should be stored
        let level = self.level_for(when);

        self.levels[level].add_entry(when, item, store);

        debug_assert!({
            self.levels[level]
                .next_expiration(self.elapsed)
                .map(|e| e.deadline >= self.elapsed)
                .unwrap_or(true)
        });

        Ok(())
    }

    /// Remove `item` from the timing wheel.
#[track_caller] pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) { let when = T::when(item, store); assert!( self.elapsed <= when, "elapsed={}; when={}", self.elapsed, when ); let level = self.level_for(when); self.levels[level].remove_entry(when, item, store); } /// Instant at which to poll pub(crate) fn poll_at(&self) -> Option { self.next_expiration().map(|expiration| expiration.deadline) } /// Next key that will expire pub(crate) fn peek(&self) -> Option { self.next_expiration() .and_then(|expiration| self.peek_entry(&expiration)) } /// Advances the timer up to the instant represented by `now`. pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option { loop { let expiration = self.next_expiration().and_then(|expiration| { if expiration.deadline > now { None } else { Some(expiration) } }); match expiration { Some(ref expiration) => { if let Some(item) = self.poll_expiration(expiration, store) { return Some(item); } self.set_elapsed(expiration.deadline); } None => { // in this case the poll did not indicate an expiration // _and_ we were not able to find a next expiration in // the current list of timers. advance to the poll's // current time and do nothing else. self.set_elapsed(now); return None; } } } } /// Returns the instant at which the next timeout expires. fn next_expiration(&self) -> Option { // Check all levels for level in 0..NUM_LEVELS { if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) { // There cannot be any expirations at a higher level that happen // before this one. 
debug_assert!(self.no_expirations_before(level + 1, expiration.deadline)); return Some(expiration); } } None } /// Used for debug assertions fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { let mut res = true; for l2 in start_level..NUM_LEVELS { if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) { if e2.deadline < before { res = false; } } } res } /// iteratively find entries that are between the wheel's current /// time and the expiration time. for each in that population either /// return it for notification (in the case of the last level) or tier /// it down to the next level (in all other cases). pub(crate) fn poll_expiration( &mut self, expiration: &Expiration, store: &mut T::Store, ) -> Option { while let Some(item) = self.pop_entry(expiration, store) { if expiration.level == 0 { debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline); return Some(item); } else { let when = T::when(item.borrow(), store); let next_level = expiration.level - 1; self.levels[next_level].add_entry(when, item, store); } } None } fn set_elapsed(&mut self, when: u64) { assert!( self.elapsed <= when, "elapsed={:?}; when={:?}", self.elapsed, when ); if when > self.elapsed { self.elapsed = when; } } fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option { self.levels[expiration.level].pop_entry_slot(expiration.slot, store) } fn peek_entry(&self, expiration: &Expiration) -> Option { self.levels[expiration.level].peek_entry_slot(expiration.slot) } fn level_for(&self, when: u64) -> usize { level_for(self.elapsed, when) } } fn level_for(elapsed: u64, when: u64) -> usize { const SLOT_MASK: u64 = (1 << 6) - 1; // Mask in the trailing bits ignored by the level calculation in order to cap // the possible leading zeros let mut masked = elapsed ^ when | SLOT_MASK; if masked >= MAX_DURATION { // Fudge the timer into the top level masked = MAX_DURATION - 1; } let leading_zeros = masked.leading_zeros() as usize; let 
significant = 63 - leading_zeros; significant / 6 } #[cfg(all(test, not(loom)))] mod test { use super::*; #[test] fn test_level_for() { for pos in 0..64 { assert_eq!( 0, level_for(0, pos), "level_for({}) -- binary = {:b}", pos, pos ); } for level in 1..5 { for pos in level..64 { let a = pos * 64_usize.pow(level as u32); assert_eq!( level, level_for(0, a as u64), "level_for({}) -- binary = {:b}", a, a ); if pos > level { let a = a - 1; assert_eq!( level, level_for(0, a as u64), "level_for({}) -- binary = {:b}", a, a ); } if pos < 64 { let a = a + 1; assert_eq!( level, level_for(0, a as u64), "level_for({}) -- binary = {:b}", a, a ); } } } } } tokio-util-0.7.10/src/time/wheel/stack.rs000064400000000000000000000016011046102023000163130ustar 00000000000000use std::borrow::Borrow; use std::cmp::Eq; use std::hash::Hash; /// Abstracts the stack operations needed to track timeouts. pub(crate) trait Stack: Default { /// Type of the item stored in the stack type Owned: Borrow; /// Borrowed item type Borrowed: Eq + Hash; /// Item storage, this allows a slab to be used instead of just the heap type Store; /// Returns `true` if the stack is empty fn is_empty(&self) -> bool; /// Push an item onto the stack fn push(&mut self, item: Self::Owned, store: &mut Self::Store); /// Pop an item from the stack fn pop(&mut self, store: &mut Self::Store) -> Option; /// Peek into the stack. 
    fn peek(&self) -> Option<Self::Owned>;

    fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);

    fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
}
tokio-util-0.7.10/src/udp/frame.rs000064400000000000000000000173321046102023000150360ustar 00000000000000
use crate::codec::{Decoder, Encoder};

use futures_core::Stream;
use tokio::{io::ReadBuf, net::UdpSocket};

use bytes::{BufMut, BytesMut};
use futures_core::ready;
use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
    borrow::Borrow,
    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};
use std::{io, mem::MaybeUninit};

/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
/// batch these into meaningful chunks, called "frames". `UdpFramed` layers
/// framing on top of the socket by using the `Encoder` and `Decoder` traits to
/// handle encoding and decoding of message frames. Note that the incoming and
/// outgoing frame types may be distinct.
///
/// A `UdpFramed` is a *single* object that is both [`Stream`] and [`Sink`];
/// grouping these into a single object is often useful for layering things which
/// require both read and write access to the underlying object.
///
/// If you want to work more directly with the stream and sink, consider
/// calling [`split`] on the `UdpFramed`, which will break
/// them into separate objects, allowing them to interact more easily.
/// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split #[must_use = "sinks do nothing unless polled"] #[derive(Debug)] pub struct UdpFramed { socket: T, codec: C, rd: BytesMut, wr: BytesMut, out_addr: SocketAddr, flushed: bool, is_readable: bool, current_addr: Option, } const INITIAL_RD_CAPACITY: usize = 64 * 1024; const INITIAL_WR_CAPACITY: usize = 8 * 1024; impl Unpin for UdpFramed {} impl Stream for UdpFramed where T: Borrow, C: Decoder, { type Item = Result<(C::Item, SocketAddr), C::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let pin = self.get_mut(); pin.rd.reserve(INITIAL_RD_CAPACITY); loop { // Are there still bytes left in the read buffer to decode? if pin.is_readable { if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { let current_addr = pin .current_addr .expect("will always be set before this line is called"); return Poll::Ready(Some(Ok((frame, current_addr)))); } // if this line has been reached then decode has returned `None`. pin.is_readable = false; pin.rd.clear(); } // We're out of data. Try and fetch more data to decode let addr = { // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a // transparent wrapper around `[MaybeUninit]`. let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit]) }; let mut read = ReadBuf::uninit(buf); let ptr = read.filled().as_ptr(); let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); assert_eq!(ptr, read.filled().as_ptr()); let addr = res?; // Safety: This is guaranteed to be the number of initialized (and read) bytes due // to the invariants provided by `ReadBuf::filled`. 
unsafe { pin.rd.advance_mut(read.filled().len()) }; addr }; pin.current_addr = Some(addr); pin.is_readable = true; } } } impl Sink<(I, SocketAddr)> for UdpFramed where T: Borrow, C: Encoder, { type Error = C::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.flushed { match self.poll_flush(cx)? { Poll::Ready(()) => {} Poll::Pending => return Poll::Pending, } } Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { let (frame, out_addr) = item; let pin = self.get_mut(); pin.codec.encode(frame, &mut pin.wr)?; pin.out_addr = out_addr; pin.flushed = false; Ok(()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.flushed { return Poll::Ready(Ok(())); } let Self { ref socket, ref mut out_addr, ref mut wr, .. } = *self; let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; let wrote_all = n == self.wr.len(); self.wr.clear(); self.flushed = true; let res = if wrote_all { Ok(()) } else { Err(io::Error::new( io::ErrorKind::Other, "failed to write entire datagram to socket", ) .into()) }; Poll::Ready(res) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.poll_flush(cx))?; Poll::Ready(Ok(())) } } impl UdpFramed where T: Borrow, { /// Create a new `UdpFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. pub fn new(socket: T, codec: C) -> UdpFramed { Self { socket, codec, out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), flushed: true, is_readable: false, current_addr: None, } } /// Returns a reference to the underlying I/O stream wrapped by `Framed`. /// /// # Note /// /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. 
    pub fn get_ref(&self) -> &T {
        &self.socket
    }

    /// Returns a mutable reference to the underlying I/O stream wrapped by `UdpFramed`.
    ///
    /// # Note
    ///
    /// Care should be taken to not tamper with the underlying stream of data
    /// coming in as it may corrupt the stream of frames otherwise being worked
    /// with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.socket
    }

    /// Returns a reference to the underlying codec wrapped by
    /// `UdpFramed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec(&self) -> &C {
        &self.codec
    }

    /// Returns a mutable reference to the underlying codec wrapped by
    /// `UdpFramed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec_mut(&mut self) -> &mut C {
        &mut self.codec
    }

    /// Returns a reference to the read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        &self.rd
    }

    /// Returns a mutable reference to the read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.rd
    }

    /// Consumes the `UdpFramed`, returning its underlying I/O stream.
    pub fn into_inner(self) -> T {
        self.socket
    }
}
tokio-util-0.7.10/src/udp/mod.rs000064400000000000000000000000661046102023000145170ustar 00000000000000
//! UDP framing

mod frame;
pub use frame::UdpFramed;
tokio-util-0.7.10/src/util/maybe_dangling.rs000064400000000000000000000042001046102023000170570ustar 00000000000000
use core::future::Future;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::task::{Context, Poll};

/// A wrapper type that tells the compiler that the contents might not be valid.
///
/// This is necessary mainly when `T` contains a reference. In that case, the
/// compiler will sometimes assume that the reference is always valid; in some
/// cases it will assume this even after the destructor of `T` runs.
For /// example, when a reference is used as a function argument, then the compiler /// will assume that the reference is valid until the function returns, even if /// the reference is destroyed during the function. When the reference is used /// as part of a self-referential struct, that assumption can be false. Wrapping /// the reference in this type prevents the compiler from making that /// assumption. /// /// # Invariants /// /// The `MaybeUninit` will always contain a valid value until the destructor runs. // // Reference // See // // TODO: replace this with an official solution once RFC #3336 or similar is available. // #[repr(transparent)] pub(crate) struct MaybeDangling(MaybeUninit); impl Drop for MaybeDangling { fn drop(&mut self) { // Safety: `0` is always initialized. unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) }; } } impl MaybeDangling { pub(crate) fn new(inner: T) -> Self { Self(MaybeUninit::new(inner)) } } impl Future for MaybeDangling { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Safety: `0` is always initialized. 
        let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) };
        fut.poll(cx)
    }
}

#[test]
fn maybedangling_runs_drop() {
    struct SetOnDrop<'a>(&'a mut bool);

    impl Drop for SetOnDrop<'_> {
        fn drop(&mut self) {
            *self.0 = true;
        }
    }

    let mut success = false;

    drop(MaybeDangling::new(SetOnDrop(&mut success)));
    assert!(success);
}
tokio-util-0.7.10/src/util/mod.rs000064400000000000000000000004331046102023000147020ustar 00000000000000
mod maybe_dangling;
#[cfg(any(feature = "io", feature = "codec"))]
mod poll_buf;

pub(crate) use maybe_dangling::MaybeDangling;
#[cfg(any(feature = "io", feature = "codec"))]
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub use poll_buf::{poll_read_buf, poll_write_buf};
tokio-util-0.7.10/src/util/poll_buf.rs000064400000000000000000000076741046102023000157430ustar 00000000000000
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use bytes::{Buf, BufMut};
use futures_core::ready;
use std::io::{self, IoSlice};
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
///
/// [`BufMut`]: bytes::BufMut
///
/// # Example
///
/// ```
/// use bytes::{Bytes, BytesMut};
/// use tokio_stream as stream;
/// use tokio::io::Result;
/// use tokio_util::io::{StreamReader, poll_read_buf};
/// use futures::future::poll_fn;
/// use std::pin::Pin;
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a reader from an iterator. This particular reader will always be
/// // ready.
/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
///
/// let mut buf = BytesMut::new();
/// let mut reads = 0;
///
/// loop {
///     reads += 1;
///     let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
///
///     if n == 0 {
///         break;
///     }
/// }
///
/// // one or more reads might be necessary.
/// assert!(reads >= 1); /// assert_eq!(&buf[..], &[0, 1, 2, 3]); /// # Ok(()) /// # } /// ``` #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] pub fn poll_read_buf( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll> { if !buf.has_remaining_mut() { return Poll::Ready(Ok(0)); } let n = { let dst = buf.chunk_mut(); // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a // transparent wrapper around `[MaybeUninit]`. let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; let mut buf = ReadBuf::uninit(dst); let ptr = buf.filled().as_ptr(); ready!(io.poll_read(cx, &mut buf)?); // Ensure the pointer does not change from under us assert_eq!(ptr, buf.filled().as_ptr()); buf.filled().len() }; // Safety: This is guaranteed to be the number of initialized (and read) // bytes due to the invariants provided by `ReadBuf::filled`. unsafe { buf.advance_mut(n); } Poll::Ready(Ok(n)) } /// Try to write data from an implementer of the [`Buf`] trait to an /// [`AsyncWrite`], advancing the buffer's internal cursor. /// /// This function will use [vectored writes] when the [`AsyncWrite`] supports /// vectored writes. /// /// # Examples /// /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements /// [`Buf`]: /// /// ```no_run /// use tokio_util::io::poll_write_buf; /// use tokio::io; /// use tokio::fs::File; /// /// use bytes::Buf; /// use std::io::Cursor; /// use std::pin::Pin; /// use futures::future::poll_fn; /// /// #[tokio::main] /// async fn main() -> io::Result<()> { /// let mut file = File::create("foo.txt").await?; /// let mut buf = Cursor::new(b"data to write"); /// /// // Loop until the entire contents of the buffer are written to /// // the file. 
///     while buf.has_remaining() {
///         poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
///     }
///
///     Ok(())
/// }
/// ```
///
/// [`Buf`]: bytes::Buf
/// [`AsyncWrite`]: tokio::io::AsyncWrite
/// [`File`]: tokio::fs::File
/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut B,
) -> Poll<io::Result<usize>> {
    const MAX_BUFS: usize = 64;

    if !buf.has_remaining() {
        return Poll::Ready(Ok(0));
    }

    let n = if io.is_write_vectored() {
        let mut slices = [IoSlice::new(&[]); MAX_BUFS];
        let cnt = buf.chunks_vectored(&mut slices);
        ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
    } else {
        ready!(io.poll_write(cx, buf.chunk()))?
    };

    buf.advance(n);

    Poll::Ready(Ok(n))
}
tokio-util-0.7.10/tests/_require_full.rs000064400000000000000000000001361046102023000163560ustar 00000000000000
#![cfg(not(feature = "full"))]
compile_error!("run tokio-util tests with `--features full`");
tokio-util-0.7.10/tests/codecs.rs000064400000000000000000000323531046102023000147670ustar 00000000000000
#![warn(rust_2018_idioms)]

use tokio_util::codec::{AnyDelimiterCodec, BytesCodec, Decoder, Encoder, LinesCodec};

use bytes::{BufMut, Bytes, BytesMut};

#[test]
fn bytes_decoder() {
    let mut codec = BytesCodec::new();
    let buf = &mut BytesMut::new();
    buf.put_slice(b"abc");
    assert_eq!("abc", codec.decode(buf).unwrap().unwrap());
    assert_eq!(None, codec.decode(buf).unwrap());
    assert_eq!(None, codec.decode(buf).unwrap());
    buf.put_slice(b"a");
    assert_eq!("a", codec.decode(buf).unwrap().unwrap());
}

#[test]
fn bytes_encoder() {
    let mut codec = BytesCodec::new();

    // Default capacity of BytesMut
    #[cfg(target_pointer_width = "64")]
    const INLINE_CAP: usize = 4 * 8 - 1;
    #[cfg(target_pointer_width = "32")]
    const INLINE_CAP: usize = 4 * 4 - 1;

    let mut buf = BytesMut::new();
    codec
        .encode(Bytes::from_static(&[0; INLINE_CAP + 1]), &mut buf)
        .unwrap();

    // Default capacity of Framed Read
const INITIAL_CAPACITY: usize = 8 * 1024; let mut buf = BytesMut::with_capacity(INITIAL_CAPACITY); codec .encode(Bytes::from_static(&[0; INITIAL_CAPACITY + 1]), &mut buf) .unwrap(); codec .encode(BytesMut::from(&b"hello"[..]), &mut buf) .unwrap(); } #[test] fn lines_decoder() { let mut codec = LinesCodec::new(); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line 1\nline 2\r\nline 3\n\r\n\r"); assert_eq!("line 1", codec.decode(buf).unwrap().unwrap()); assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); assert_eq!("line 3", codec.decode(buf).unwrap().unwrap()); assert_eq!("", codec.decode(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); } #[test] fn lines_decoder_max_length() { const MAX_LENGTH: usize = 6; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r"); assert!(codec.decode(buf).is_err()); let line = codec.decode(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("line 2", line); assert!(codec.decode(buf).is_err()); let line = codec.decode(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("line 4", line); let line = codec.decode(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("", line); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); let line = codec.decode_eof(buf).unwrap().unwrap(); assert!( line.len() <= MAX_LENGTH, 
"{:?}.len() <= {:?}", line, MAX_LENGTH ); assert_eq!("\rk", line); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); // Line that's one character too long. This could cause an out of bounds // error if we peek at the next characters using slice indexing. buf.put_slice(b"aaabbbc"); assert!(codec.decode(buf).is_err()); } #[test] fn lines_decoder_max_length_underrun() { const MAX_LENGTH: usize = 6; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"ong\n"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"line 2"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"\n"); assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); } #[test] fn lines_decoder_max_length_bursts() { const MAX_LENGTH: usize = 10; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\n"); assert!(codec.decode(buf).is_err()); } #[test] fn lines_decoder_max_length_big_burst() { const MAX_LENGTH: usize = 10; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too long!\n"); assert!(codec.decode(buf).is_err()); } #[test] fn lines_decoder_max_length_newline_between_decodes() { const MAX_LENGTH: usize = 5; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"hello"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"\nworld"); assert_eq!("hello", 
codec.decode(buf).unwrap().unwrap()); } // Regression test for [infinite loop bug](https://github.com/tokio-rs/tokio/issues/1483) #[test] fn lines_decoder_discard_repeat() { const MAX_LENGTH: usize = 1; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"aa"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"a"); assert_eq!(None, codec.decode(buf).unwrap()); } // Regression test for [subsequent calls to LinesCodec decode does not return the desired results bug](https://github.com/tokio-rs/tokio/issues/3555) #[test] fn lines_decoder_max_length_underrun_twice() { const MAX_LENGTH: usize = 11; let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"line "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too very l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\nshort\n"); assert_eq!("short", codec.decode(buf).unwrap().unwrap()); } #[test] fn lines_encoder() { let mut codec = LinesCodec::new(); let mut buf = BytesMut::new(); codec.encode("line 1", &mut buf).unwrap(); assert_eq!("line 1\n", buf); codec.encode("line 2", &mut buf).unwrap(); assert_eq!("line 1\nline 2\n", buf); } #[test] fn any_delimiters_decoder_any_character() { let mut codec = AnyDelimiterCodec::new(b",;\n\r".to_vec(), b",".to_vec()); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); assert_eq!("", codec.decode(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!("k", 
codec.decode_eof(buf).unwrap().unwrap()); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); } #[test] fn any_delimiters_decoder_max_length() { const MAX_LENGTH: usize = 7; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk 1 is too long\nchunk 2\nchunk 3\r\nchunk 4\n\r\n"); assert!(codec.decode(buf).is_err()); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("chunk 2", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("chunk 3", chunk); // \r\n cause empty chunk let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("chunk 4", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("", chunk); let chunk = codec.decode(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("", chunk); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); buf.put_slice(b"k"); assert_eq!(None, codec.decode(buf).unwrap()); let chunk = codec.decode_eof(buf).unwrap().unwrap(); assert!( chunk.len() <= MAX_LENGTH, "{:?}.len() <= {:?}", chunk, MAX_LENGTH ); assert_eq!("k", chunk); assert_eq!(None, codec.decode(buf).unwrap()); assert_eq!(None, codec.decode_eof(buf).unwrap()); // Delimiter that's one character too long. This could cause an out of bounds // error if we peek at the next characters using slice indexing. 
buf.put_slice(b"aaabbbcc"); assert!(codec.decode(buf).is_err()); } #[test] fn any_delimiter_decoder_max_length_underrun() { const MAX_LENGTH: usize = 7; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"ong\n"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"chunk 2"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b","); assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); } #[test] fn any_delimiter_decoder_max_length_underrun_twice() { const MAX_LENGTH: usize = 11; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too very l"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\nshort\n"); assert_eq!("short", codec.decode(buf).unwrap().unwrap()); } #[test] fn any_delimiter_decoder_max_length_bursts() { const MAX_LENGTH: usize = 11; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"too l"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b"ong\n"); assert!(codec.decode(buf).is_err()); } #[test] fn any_delimiter_decoder_max_length_big_burst() { const MAX_LENGTH: usize = 11; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"chunk "); assert_eq!(None, codec.decode(buf).unwrap()); 
buf.put_slice(b"too long!\n"); assert!(codec.decode(buf).is_err()); } #[test] fn any_delimiter_decoder_max_length_delimiter_between_decodes() { const MAX_LENGTH: usize = 5; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"hello"); assert_eq!(None, codec.decode(buf).unwrap()); buf.put_slice(b",world"); assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); } #[test] fn any_delimiter_decoder_discard_repeat() { const MAX_LENGTH: usize = 1; let mut codec = AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); let buf = &mut BytesMut::new(); buf.reserve(200); buf.put_slice(b"aa"); assert!(codec.decode(buf).is_err()); buf.put_slice(b"a"); assert_eq!(None, codec.decode(buf).unwrap()); } #[test] fn any_delimiter_encoder() { let mut codec = AnyDelimiterCodec::new(b",".to_vec(), b";--;".to_vec()); let mut buf = BytesMut::new(); codec.encode("chunk 1", &mut buf).unwrap(); assert_eq!("chunk 1;--;", buf); codec.encode("chunk 2", &mut buf).unwrap(); assert_eq!("chunk 1;--;chunk 2;--;", buf); } tokio-util-0.7.10/tests/compat.rs000064400000000000000000000024061046102023000150060ustar 00000000000000#![cfg(all(feature = "compat"))] #![cfg(not(target_os = "wasi"))] // WASI does not support all fs operations #![warn(rust_2018_idioms)] use futures_io::SeekFrom; use futures_util::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tempfile::NamedTempFile; use tokio::fs::OpenOptions; use tokio_util::compat::TokioAsyncWriteCompatExt; #[tokio::test] async fn compat_file_seek() -> futures_util::io::Result<()> { let temp_file = NamedTempFile::new()?; let mut file = OpenOptions::new() .read(true) .write(true) .create(true) .open(temp_file) .await? .compat_write(); file.write_all(&[0, 1, 2, 3, 4, 5]).await?; file.write_all(&[6, 7]).await?; assert_eq!(file.stream_position().await?, 8); // Modify elements at position 2. 
assert_eq!(file.seek(SeekFrom::Start(2)).await?, 2); file.write_all(&[8, 9]).await?; file.flush().await?; // Verify we still have 8 elements. assert_eq!(file.seek(SeekFrom::End(0)).await?, 8); // Seek back to the start of the file to read and verify contents. file.seek(SeekFrom::Start(0)).await?; let mut buf = Vec::new(); let num_bytes = file.read_to_end(&mut buf).await?; assert_eq!(&buf[..num_bytes], &[0, 1, 8, 9, 4, 5, 6, 7]); Ok(()) } tokio-util-0.7.10/tests/context.rs000064400000000000000000000013411046102023000152040ustar 00000000000000#![cfg(feature = "rt")] #![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads #![warn(rust_2018_idioms)] use tokio::runtime::Builder; use tokio::time::*; use tokio_util::context::RuntimeExt; #[test] fn tokio_context_with_another_runtime() { let rt1 = Builder::new_multi_thread() .worker_threads(1) // no timer! .build() .unwrap(); let rt2 = Builder::new_multi_thread() .worker_threads(1) .enable_all() .build() .unwrap(); // Without the `HandleExt.wrap()` there would be a panic because there is // no timer running, since it would be referencing runtime r1. rt1.block_on(rt2.wrap(async move { sleep(Duration::from_millis(2)).await })); } tokio-util-0.7.10/tests/framed.rs000064400000000000000000000074341046102023000147670ustar 00000000000000#![warn(rust_2018_idioms)] use tokio_stream::StreamExt; use tokio_test::assert_ok; use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; use bytes::{Buf, BufMut, BytesMut}; use std::io::{self, Read}; use std::pin::Pin; use std::task::{Context, Poll}; const INITIAL_CAPACITY: usize = 8 * 1024; /// Encode and decode u32 values. 
#[derive(Default)]
struct U32Codec {
    read_bytes: usize,
}

impl Decoder for U32Codec {
    type Item = u32;
    type Error = io::Error;

    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> {
        if buf.len() < 4 {
            return Ok(None);
        }

        let n = buf.split_to(4).get_u32();
        self.read_bytes += 4;
        Ok(Some(n))
    }
}

impl Encoder<u32> for U32Codec {
    type Error = io::Error;

    fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
        // Reserve space
        dst.reserve(4);
        dst.put_u32(item);
        Ok(())
    }
}

/// Encode and decode u64 values.
#[derive(Default)]
struct U64Codec {
    read_bytes: usize,
}

impl Decoder for U64Codec {
    type Item = u64;
    type Error = io::Error;

    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
        if buf.len() < 8 {
            return Ok(None);
        }

        let n = buf.split_to(8).get_u64();
        self.read_bytes += 8;
        Ok(Some(n))
    }
}

impl Encoder<u64> for U64Codec {
    type Error = io::Error;

    fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
        // Reserve space
        dst.reserve(8);
        dst.put_u64(item);
        Ok(())
    }
}

/// This value should never be used
struct DontReadIntoThis;

impl Read for DontReadIntoThis {
    fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
        Err(io::Error::new(
            io::ErrorKind::Other,
            "Read into something you weren't supposed to.",
        ))
    }
}

impl tokio::io::AsyncRead for DontReadIntoThis {
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        _buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        unreachable!()
    }
}

#[tokio::test]
async fn can_read_from_existing_buf() {
    let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
    parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);

    let mut framed = Framed::from_parts(parts);
    let num = assert_ok!(framed.next().await.unwrap());

    assert_eq!(num, 42);
    assert_eq!(framed.codec().read_bytes, 4);
}

#[tokio::test]
async fn can_read_from_existing_buf_after_codec_changed() {
    let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
    parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]);

    let mut framed =
Framed::from_parts(parts); let num = assert_ok!(framed.next().await.unwrap()); assert_eq!(num, 42); assert_eq!(framed.codec().read_bytes, 4); let mut framed = framed.map_codec(|codec| U64Codec { read_bytes: codec.read_bytes, }); let num = assert_ok!(framed.next().await.unwrap()); assert_eq!(num, 84); assert_eq!(framed.codec().read_bytes, 12); } #[test] fn external_buf_grows_to_init() { let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); let framed = Framed::from_parts(parts); let FramedParts { read_buf, .. } = framed.into_parts(); assert_eq!(read_buf.capacity(), INITIAL_CAPACITY); } #[test] fn external_buf_does_not_shrink() { let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); let framed = Framed::from_parts(parts); let FramedParts { read_buf, .. } = framed.into_parts(); assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); } tokio-util-0.7.10/tests/framed_read.rs000064400000000000000000000216401046102023000157550ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::io::{AsyncRead, ReadBuf}; use tokio_test::assert_ready; use tokio_test::task; use tokio_util::codec::{Decoder, FramedRead}; use bytes::{Buf, BytesMut}; use futures::Stream; use std::collections::VecDeque; use std::io; use std::pin::Pin; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; macro_rules! mock { ($($x:expr,)*) => {{ let mut v = VecDeque::new(); v.extend(vec![$($x),*]); Mock { calls: v } }}; } macro_rules! assert_read { ($e:expr, $n:expr) => {{ let val = assert_ready!($e); assert_eq!(val.unwrap().unwrap(), $n); }}; } macro_rules! 
pin {
    ($id:ident) => {
        Pin::new(&mut $id)
    };
}

struct U32Decoder;

impl Decoder for U32Decoder {
    type Item = u32;
    type Error = io::Error;

    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> {
        if buf.len() < 4 {
            return Ok(None);
        }

        let n = buf.split_to(4).get_u32();
        Ok(Some(n))
    }
}

struct U64Decoder;

impl Decoder for U64Decoder {
    type Item = u64;
    type Error = io::Error;

    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
        if buf.len() < 8 {
            return Ok(None);
        }

        let n = buf.split_to(8).get_u64();
        Ok(Some(n))
    }
}

#[test]
fn read_multi_frame_in_packet() {
    let mut task = task::spawn(());
    let mock = mock! {
        Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()),
    };
    let mut framed = FramedRead::new(mock, U32Decoder);

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 0);
        assert_read!(pin!(framed).poll_next(cx), 1);
        assert_read!(pin!(framed).poll_next(cx), 2);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
    });
}

#[test]
fn read_multi_frame_across_packets() {
    let mut task = task::spawn(());
    let mock = mock! {
        Ok(b"\x00\x00\x00\x00".to_vec()),
        Ok(b"\x00\x00\x00\x01".to_vec()),
        Ok(b"\x00\x00\x00\x02".to_vec()),
    };
    let mut framed = FramedRead::new(mock, U32Decoder);

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 0);
        assert_read!(pin!(framed).poll_next(cx), 1);
        assert_read!(pin!(framed).poll_next(cx), 2);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
    });
}

#[test]
fn read_multi_frame_in_packet_after_codec_changed() {
    let mut task = task::spawn(());
    let mock = mock! {
        Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
    };
    let mut framed = FramedRead::new(mock, U32Decoder);

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 0x04);

        let mut framed = framed.map_decoder(|_| U64Decoder);
        assert_read!(pin!(framed).poll_next(cx), 0x08);

        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
    });
}

#[test]
fn read_not_ready() {
    let mut task = task::spawn(());
    let mock = mock!
{ Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Ok(b"\x00\x00\x00\x00".to_vec()), Ok(b"\x00\x00\x00\x01".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert!(pin!(framed).poll_next(cx).is_pending()); assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_partial_then_not_ready() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert!(pin!(framed).poll_next(cx).is_pending()); assert_read!(pin!(framed).poll_next(cx), 0); assert_read!(pin!(framed).poll_next(cx), 1); assert_read!(pin!(framed).poll_next(cx), 2); assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); }); } #[test] fn read_err() { let mut task = task::spawn(()); let mock = mock! { Err(io::Error::new(io::ErrorKind::Other, "")), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_eq!( io::ErrorKind::Other, assert_ready!(pin!(framed).poll_next(cx)) .unwrap() .unwrap_err() .kind() ) }); } #[test] fn read_partial_then_err() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::Other, "")), }; let mut framed = FramedRead::new(mock, U32Decoder); task.enter(|cx, _| { assert_eq!( io::ErrorKind::Other, assert_ready!(pin!(framed).poll_next(cx)) .unwrap() .unwrap_err() .kind() ) }); } #[test] fn read_partial_would_block_then_err() { let mut task = task::spawn(()); let mock = mock! 
    {
        Ok(b"\x00\x00".to_vec()),
        Err(io::Error::new(io::ErrorKind::WouldBlock, "")),
        Err(io::Error::new(io::ErrorKind::Other, "")),
    };
    let mut framed = FramedRead::new(mock, U32Decoder);

    task.enter(|cx, _| {
        assert!(pin!(framed).poll_next(cx).is_pending());
        assert_eq!(
            io::ErrorKind::Other,
            assert_ready!(pin!(framed).poll_next(cx))
                .unwrap()
                .unwrap_err()
                .kind()
        )
    });
}

#[test]
fn huge_size() {
    let mut task = task::spawn(());
    let data = &[0; 32 * 1024][..];
    let mut framed = FramedRead::new(data, BigDecoder);

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 0);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
    });

    struct BigDecoder;

    impl Decoder for BigDecoder {
        type Item = u32;
        type Error = io::Error;

        fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> {
            if buf.len() < 32 * 1024 {
                return Ok(None);
            }
            buf.advance(32 * 1024);
            Ok(Some(0))
        }
    }
}

#[test]
fn data_remaining_is_error() {
    let mut task = task::spawn(());
    let slice = &[0; 5][..];
    let mut framed = FramedRead::new(slice, U32Decoder);

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 0);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err());
    });
}

#[test]
fn multi_frames_on_eof() {
    let mut task = task::spawn(());
    struct MyDecoder(Vec<u32>);

    impl Decoder for MyDecoder {
        type Item = u32;
        type Error = io::Error;

        fn decode(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> {
            unreachable!();
        }

        fn decode_eof(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> {
            if self.0.is_empty() {
                return Ok(None);
            }

            Ok(Some(self.0.remove(0)))
        }
    }

    let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3]));

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 0);
        assert_read!(pin!(framed).poll_next(cx), 1);
        assert_read!(pin!(framed).poll_next(cx), 2);
        assert_read!(pin!(framed).poll_next(cx), 3);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
    });
}

#[test]
fn read_eof_then_resume() {
    let mut task = task::spawn(());
    let mock = mock!
    {
        Ok(b"\x00\x00\x00\x01".to_vec()),
        Ok(b"".to_vec()),
        Ok(b"\x00\x00\x00\x02".to_vec()),
        Ok(b"".to_vec()),
        Ok(b"\x00\x00\x00\x03".to_vec()),
    };
    let mut framed = FramedRead::new(mock, U32Decoder);

    task.enter(|cx, _| {
        assert_read!(pin!(framed).poll_next(cx), 1);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
        assert_read!(pin!(framed).poll_next(cx), 2);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
        assert_read!(pin!(framed).poll_next(cx), 3);
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
    });
}

// ===== Mock ======

struct Mock {
    calls: VecDeque<io::Result<Vec<u8>>>,
}

impl AsyncRead for Mock {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        use io::ErrorKind::WouldBlock;

        match self.calls.pop_front() {
            Some(Ok(data)) => {
                debug_assert!(buf.remaining() >= data.len());
                buf.put_slice(&data);
                Ready(Ok(()))
            }
            Some(Err(ref e)) if e.kind() == WouldBlock => Pending,
            Some(Err(e)) => Ready(Err(e)),
            None => Ready(Ok(())),
        }
    }
}
tokio-util-0.7.10/tests/framed_stream.rs000064400000000000000000000020271046102023000163330ustar 00000000000000
use futures_core::stream::Stream;
use std::{io, pin::Pin};
use tokio_test::{assert_ready, io::Builder, task};
use tokio_util::codec::{BytesCodec, FramedRead};

macro_rules! pin {
    ($id:ident) => {
        Pin::new(&mut $id)
    };
}

macro_rules!
assert_read {
    ($e:expr, $n:expr) => {{
        let val = assert_ready!($e);
        assert_eq!(val.unwrap().unwrap(), $n);
    }};
}

#[tokio::test]
async fn return_none_after_error() {
    let mut io = FramedRead::new(
        Builder::new()
            .read(b"abcdef")
            .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out"))
            .read(b"more data")
            .build(),
        BytesCodec::new(),
    );

    let mut task = task::spawn(());

    task.enter(|cx, _| {
        assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec());
        assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err());
        assert!(assert_ready!(pin!(io).poll_next(cx)).is_none());
        assert_read!(pin!(io).poll_next(cx), b"more data".to_vec());
    })
}
tokio-util-0.7.10/tests/framed_write.rs000064400000000000000000000140221046102023000161700ustar 00000000000000
#![warn(rust_2018_idioms)]

use tokio::io::AsyncWrite;
use tokio_test::{assert_ready, task};
use tokio_util::codec::{Encoder, FramedWrite};

use bytes::{BufMut, BytesMut};
use futures_sink::Sink;
use std::collections::VecDeque;
use std::io::{self, Write};
use std::pin::Pin;
use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};

macro_rules! mock {
    ($($x:expr,)*) => {{
        let mut v = VecDeque::new();
        v.extend(vec![$($x),*]);
        Mock { calls: v }
    }};
}

macro_rules! pin {
    ($id:ident) => {
        Pin::new(&mut $id)
    };
}

struct U32Encoder;

impl Encoder<u32> for U32Encoder {
    type Error = io::Error;

    fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
        // Reserve space
        dst.reserve(4);
        dst.put_u32(item);
        Ok(())
    }
}

struct U64Encoder;

impl Encoder<u64> for U64Encoder {
    type Error = io::Error;

    fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
        // Reserve space
        dst.reserve(8);
        dst.put_u64(item);
        Ok(())
    }
}

#[test]
fn write_multi_frame_in_packet() {
    let mut task = task::spawn(());
    let mock = mock!
{ Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; let mut framed = FramedWrite::new(mock, U32Encoder); task.enter(|cx, _| { assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(0).is_ok()); assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(1).is_ok()); assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(2).is_ok()); // Nothing written yet assert_eq!(1, framed.get_ref().calls.len()); // Flush the writes assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); assert_eq!(0, framed.get_ref().calls.len()); }); } #[test] fn write_multi_frame_after_codec_changed() { let mut task = task::spawn(()); let mock = mock! { Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), }; let mut framed = FramedWrite::new(mock, U32Encoder); task.enter(|cx, _| { assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(0x04).is_ok()); let mut framed = framed.map_encoder(|_| U64Encoder); assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(0x08).is_ok()); // Nothing written yet assert_eq!(1, framed.get_ref().calls.len()); // Flush the writes assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); assert_eq!(0, framed.get_ref().calls.len()); }); } #[test] fn write_hits_backpressure() { const ITER: usize = 2 * 1024; let mut mock = mock! 
{ // Block the `ITER*2`th write Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), Ok(b"".to_vec()), }; for i in 0..=ITER * 2 { let mut b = BytesMut::with_capacity(4); b.put_u32(i as u32); // Append to the end match mock.calls.back_mut().unwrap() { Ok(ref mut data) => { // Write in 2kb chunks if data.len() < ITER { data.extend_from_slice(&b[..]); continue; } // else fall through and create a new buffer } _ => unreachable!(), } // Push a new chunk mock.calls.push_back(Ok(b[..].to_vec())); } // 1 'wouldblock', 8 * 2KB buffers, 1 b-byte buffer assert_eq!(mock.calls.len(), 10); let mut task = task::spawn(()); let mut framed = FramedWrite::new(mock, U32Encoder); framed.set_backpressure_boundary(ITER * 8); task.enter(|cx, _| { // Send 16KB. This fills up FramedWrite buffer for i in 0..ITER * 2 { assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); assert!(pin!(framed).start_send(i as u32).is_ok()); } // Now we poll_ready which forces a flush. The mock pops the front message // and decides to block. assert!(pin!(framed).poll_ready(cx).is_pending()); // We poll again, forcing another flush, which this time succeeds // The whole 16KB buffer is flushed assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); // Send more data. 
        // This matches the final message expected by the mock
        assert!(pin!(framed).start_send((ITER * 2) as u32).is_ok());

        // Flush the rest of the buffer
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        // Ensure the mock is empty
        assert_eq!(0, framed.get_ref().calls.len());
    })
}

// // ===== Mock ======

struct Mock {
    calls: VecDeque<io::Result<Vec<u8>>>,
}

impl Write for Mock {
    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
        match self.calls.pop_front() {
            Some(Ok(data)) => {
                assert!(src.len() >= data.len());
                assert_eq!(&data[..], &src[..data.len()]);
                Ok(data.len())
            }
            Some(Err(e)) => Err(e),
            None => panic!("unexpected write; {:?}", src),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl AsyncWrite for Mock {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        match Pin::get_mut(self).write(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
            other => Ready(other),
        }
    }
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match Pin::get_mut(self).flush() {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
            other => Ready(other),
        }
    }
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        unimplemented!()
    }
}
tokio-util-0.7.10/tests/io_inspect.rs000064400000000000000000000126571046102023000156700ustar 00000000000000
use futures::future::poll_fn;
use std::{
    io::IoSlice,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_util::io::{InspectReader, InspectWriter};

/// An AsyncRead implementation that works byte-by-byte, to catch out callers
/// who don't allow for `buf` being part-filled before the call
struct SmallReader {
    contents: Vec<u8>,
}

impl Unpin for SmallReader {}

impl AsyncRead for SmallReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if let Some(byte) = self.contents.pop() {
            buf.put_slice(&[byte])
        }
        Poll::Ready(Ok(()))
    }
}

#[tokio::test]
async fn
read_tee() {
    let contents = b"This could be really long, you know".to_vec();
    let reader = SmallReader {
        contents: contents.clone(),
    };
    let mut altout: Vec<u8> = Vec::new();
    let mut teeout = Vec::new();
    {
        let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
        tee.read_to_end(&mut teeout).await.unwrap();
    }
    assert_eq!(teeout, altout);
    assert_eq!(altout.len(), contents.len());
}

/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
/// that reads the whole of the first buffer plus one byte from the second in
/// poll_write_vectored.
///
/// This is designed to catch bugs in handling partially written buffers
#[derive(Debug)]
struct SmallWriter {
    contents: Vec<u8>,
}

impl Unpin for SmallWriter {}

impl AsyncWrite for SmallWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Just write one byte at a time
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        self.contents.push(buf[0]);
        Poll::Ready(Ok(1))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Write all of the first buffer, then one byte from the second buffer
        // This should trip up anything that doesn't correctly handle multiple
        // buffers.
        if bufs.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let mut written_len = bufs[0].len();
        self.contents.extend_from_slice(&bufs[0]);

        if bufs.len() > 1 {
            let buf = bufs[1];
            if !buf.is_empty() {
                written_len += 1;
                self.contents.push(buf[0]);
            }
        }
        Poll::Ready(Ok(written_len))
    }

    fn is_write_vectored(&self) -> bool {
        true
    }
}

#[tokio::test]
async fn write_tee() {
    let mut altout: Vec<u8> = Vec::new();
    let mut writeout = SmallWriter {
        contents: Vec::new(),
    };
    {
        let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
        tee.write_all(b"A testing string, very testing")
            .await
            .unwrap();
    }
    assert_eq!(altout, writeout.contents);
}

// This is inefficient, but works well enough for test use.
// If you want something similar for real code, you'll want to avoid all the
// fun of manipulating `bufs` - ideally, by the time you read this,
// IoSlice::advance_slices will be stable, and you can use that.
async fn write_all_vectored<W: AsyncWrite + Unpin>(
    mut writer: W,
    mut bufs: Vec<Vec<u8>>,
) -> Result<usize, std::io::Error> {
    let mut res = 0;
    while !bufs.is_empty() {
        let mut written = poll_fn(|cx| {
            let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
            Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
        })
        .await?;
        res += written;
        while written > 0 {
            let buf_len = bufs[0].len();
            if buf_len <= written {
                bufs.remove(0);
                written -= buf_len;
            } else {
                let buf = &mut bufs[0];
                let drain_len = written.min(buf.len());
                buf.drain(..drain_len);
                written -= drain_len;
            }
        }
    }
    Ok(res)
}

#[tokio::test]
async fn write_tee_vectored() {
    let mut altout: Vec<u8> = Vec::new();
    let mut writeout = SmallWriter {
        contents: Vec::new(),
    };
    let original = b"A very long string split up";
    let bufs: Vec<Vec<u8>> = original
        .split(|b| b.is_ascii_whitespace())
        .map(Vec::from)
        .collect();
    assert!(bufs.len() > 1);
    let expected: Vec<u8> = {
        let mut out = Vec::new();
        for item in &bufs {
            out.extend_from_slice(item)
        }
        out
    };
    {
        let mut bufcount = 0;
        let tee = InspectWriter::new(&mut writeout, |bytes| {
            bufcount += 1;
            altout.extend(bytes)
        });
        assert!(tee.is_write_vectored());

        write_all_vectored(tee, bufs.clone()).await.unwrap();

        assert!(bufcount >= bufs.len());
    }
    assert_eq!(altout, writeout.contents);
    assert_eq!(writeout.contents, expected);
}
tokio-util-0.7.10/tests/io_reader_stream.rs000064400000000000000000000034041046102023000170260ustar 00000000000000
#![warn(rust_2018_idioms)]

use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio_stream::StreamExt;

/// Produces at most `remaining` zeros, then returns an error.
/// Each read yields at most 31 bytes.
struct Reader {
    remaining: usize,
}

impl AsyncRead for Reader {
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = Pin::into_inner(self);
        assert_ne!(buf.remaining(), 0);
        if this.remaining > 0 {
            let n = std::cmp::min(this.remaining, buf.remaining());
            let n = std::cmp::min(n, 31);
            for x in &mut buf.initialize_unfilled_to(n)[..n] {
                *x = 0;
            }
            buf.advance(n);
            this.remaining -= n;
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
        }
    }
}

#[tokio::test]
async fn correct_behavior_on_errors() {
    let reader = Reader { remaining: 8000 };
    let mut stream = tokio_util::io::ReaderStream::new(reader);
    let mut zeros_received = 0;
    let mut had_error = false;
    loop {
        let item = stream.next().await.unwrap();
        println!("{:?}", item);
        match item {
            Ok(bytes) => {
                let bytes = &*bytes;
                for byte in bytes {
                    assert_eq!(*byte, 0);
                    zeros_received += 1;
                }
            }
            Err(_) => {
                assert!(!had_error);
                had_error = true;
                break;
            }
        }
    }

    assert!(had_error);
    assert_eq!(zeros_received, 8000);
    assert!(stream.next().await.is_none());
}
tokio-util-0.7.10/tests/io_sink_writer.rs000064400000000000000000000043221046102023000165510ustar 00000000000000
#![warn(rust_2018_idioms)]

use bytes::Bytes;
use futures_util::SinkExt;
use std::io::{self, Error, ErrorKind};
use tokio::io::AsyncWriteExt;
use tokio_util::codec::{Encoder, FramedWrite};
use tokio_util::io::{CopyToBytes, SinkWriter};
use
tokio_util::sync::PollSender;

#[tokio::test]
async fn test_copied_sink_writer() -> Result<(), Error> {
    // Construct a channel pair to send data across and wrap a pollable sink.
    // Note that the sink must mimic a writable object, e.g. have `std::io::Error`
    // as its error type.
    // As `PollSender` requires an owned copy of the buffer, we wrap it additionally
    // with a `CopyToBytes` helper.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
    let mut writer = SinkWriter::new(CopyToBytes::new(
        PollSender::new(tx).sink_map_err(|_| io::Error::from(ErrorKind::BrokenPipe)),
    ));

    // Write data to our interface...
    let data: [u8; 4] = [1, 2, 3, 4];
    let _ = writer.write(&data).await;

    // ... and receive it.
    assert_eq!(data.to_vec(), rx.recv().await.unwrap().to_vec());

    Ok(())
}

/// A trivial encoder.
struct SliceEncoder;

impl SliceEncoder {
    fn new() -> Self {
        Self {}
    }
}

impl<'a> Encoder<&'a [u8]> for SliceEncoder {
    type Error = Error;

    fn encode(&mut self, item: &'a [u8], dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
        // This is where we'd write packet headers, lengths, etc. in a real encoder.
        // For simplicity and demonstration purposes, we just pack a copy of
        // the slice at the end of a buffer.
        dst.extend_from_slice(item);
        Ok(())
    }
}

#[tokio::test]
async fn test_direct_sink_writer() -> Result<(), Error> {
    // We define a framed writer which accepts byte slices
    // and 'reverse' this construction immediately.
    let framed_byte_lc = FramedWrite::new(Vec::new(), SliceEncoder::new());
    let mut writer = SinkWriter::new(framed_byte_lc);

    // Write multiple slices to the sink...
    let _ = writer.write(&[1, 2, 3]).await;
    let _ = writer.write(&[4, 5, 6]).await;

    // ... and compare it with the buffer.
    assert_eq!(
        writer.into_inner().write_buffer().to_vec().as_slice(),
        &[1, 2, 3, 4, 5, 6]
    );

    Ok(())
}
tokio-util-0.7.10/tests/io_stream_reader.rs000064400000000000000000000016511046102023000170300ustar 00000000000000
#![warn(rust_2018_idioms)]

use bytes::Bytes;
use tokio::io::AsyncReadExt;
use tokio_stream::iter;
use tokio_util::io::StreamReader;

#[tokio::test]
async fn test_stream_reader() -> std::io::Result<()> {
    let stream = iter(vec![
        std::io::Result::Ok(Bytes::from_static(&[])),
        Ok(Bytes::from_static(&[0, 1, 2, 3])),
        Ok(Bytes::from_static(&[])),
        Ok(Bytes::from_static(&[4, 5, 6, 7])),
        Ok(Bytes::from_static(&[])),
        Ok(Bytes::from_static(&[8, 9, 10, 11])),
        Ok(Bytes::from_static(&[])),
    ]);

    let mut read = StreamReader::new(stream);

    let mut buf = [0; 5];
    read.read_exact(&mut buf).await?;
    assert_eq!(buf, [0, 1, 2, 3, 4]);

    assert_eq!(read.read(&mut buf).await?, 3);
    assert_eq!(&buf[..3], [5, 6, 7]);

    assert_eq!(read.read(&mut buf).await?, 4);
    assert_eq!(&buf[..4], [8, 9, 10, 11]);

    assert_eq!(read.read(&mut buf).await?, 0);

    Ok(())
}
tokio-util-0.7.10/tests/io_sync_bridge.rs000064400000000000000000000041161046102023000165020ustar 00000000000000
#![cfg(feature = "io-util")]
#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads

use std::error::Error;
use std::io::{Cursor, Read, Result as IoResult, Write};
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::io::SyncIoBridge;

async fn test_reader_len(
    r: impl AsyncRead + Unpin + Send + 'static,
    expected_len: usize,
) -> IoResult<()> {
    let mut r = SyncIoBridge::new(r);
    let res = tokio::task::spawn_blocking(move || {
        let mut buf = Vec::new();
        r.read_to_end(&mut buf)?;
        Ok::<_, std::io::Error>(buf)
    })
    .await?;
    assert_eq!(res?.len(), expected_len);
    Ok(())
}

#[tokio::test]
async fn test_async_read_to_sync() -> Result<(), Box<dyn Error>> {
    test_reader_len(tokio::io::empty(), 0).await?;
    let buf = b"hello world";
    test_reader_len(Cursor::new(buf), buf.len()).await?;
    Ok(())
}

#[tokio::test]
async fn test_async_write_to_sync() ->
Result<(), Box<dyn Error>> {
    let mut dest = Vec::new();
    let src = b"hello world";
    let dest = tokio::task::spawn_blocking(move || -> Result<_, String> {
        let mut w = SyncIoBridge::new(Cursor::new(&mut dest));
        std::io::copy(&mut Cursor::new(src), &mut w).map_err(|e| e.to_string())?;
        Ok(dest)
    })
    .await??;
    assert_eq!(dest.as_slice(), src);
    Ok(())
}

#[tokio::test]
async fn test_into_inner() -> Result<(), Box<dyn Error>> {
    let mut buf = Vec::new();
    SyncIoBridge::new(tokio::io::empty())
        .into_inner()
        .read_to_end(&mut buf)
        .await
        .unwrap();
    assert_eq!(buf.len(), 0);
    Ok(())
}

#[tokio::test]
async fn test_shutdown() -> Result<(), Box<dyn Error>> {
    let (s1, mut s2) = tokio::io::duplex(1024);
    let (_rh, wh) = tokio::io::split(s1);
    tokio::task::spawn_blocking(move || -> std::io::Result<_> {
        let mut wh = SyncIoBridge::new(wh);
        wh.write_all(b"hello")?;
        wh.shutdown()?;
        assert!(wh.write_all(b" world").is_err());
        Ok(())
    })
    .await??;
    let mut buf = vec![];
    s2.read_to_end(&mut buf).await?;
    assert_eq!(buf, b"hello");
    Ok(())
}
tokio-util-0.7.10/tests/length_delimited.rs000064400000000000000000000471621046102023000170340ustar 00000000000000
#![warn(rust_2018_idioms)]

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_test::task;
use tokio_test::{
    assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
};
use tokio_util::codec::*;

use bytes::{BufMut, Bytes, BytesMut};
use futures::{pin_mut, Sink, Stream};
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

macro_rules! mock {
    ($($x:expr,)*) => {{
        let mut v = VecDeque::new();
        v.extend(vec![$($x),*]);
        Mock { calls: v }
    }};
}

macro_rules! assert_next_eq {
    ($io:ident, $expect:expr) => {{
        task::spawn(()).enter(|cx, _| {
            let res = assert_ready!($io.as_mut().poll_next(cx));
            match res {
                Some(Ok(v)) => assert_eq!(v, $expect.as_ref()),
                Some(Err(e)) => panic!("error = {:?}", e),
                None => panic!("none"),
            }
        });
    }};
}

macro_rules!
assert_next_pending { ($io:ident) => {{ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { Poll::Ready(Some(Ok(v))) => panic!("value = {:?}", v), Poll::Ready(Some(Err(e))) => panic!("error = {:?}", e), Poll::Ready(None) => panic!("done"), Poll::Pending => {} }); }}; } macro_rules! assert_next_err { ($io:ident) => {{ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) { Poll::Ready(Some(Ok(v))) => panic!("value = {:?}", v), Poll::Ready(Some(Err(_))) => {} Poll::Ready(None) => panic!("done"), Poll::Pending => panic!("pending"), }); }}; } macro_rules! assert_done { ($io:ident) => {{ task::spawn(()).enter(|cx, _| { let res = assert_ready!($io.as_mut().poll_next(cx)); match res { Some(Ok(v)) => panic!("value = {:?}", v), Some(Err(e)) => panic!("error = {:?}", e), None => {} } }); }}; } #[test] fn read_empty_io_yields_nothing() { let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new())); pin_mut!(io); assert_done!(io); } #[test] fn read_single_frame_one_packet() { let io = FramedRead::new( mock! { data(b"\x00\x00\x00\x09abcdefghi"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_frame_one_packet_little_endian() { let io = length_delimited::Builder::new() .little_endian() .new_read(mock! { data(b"\x09\x00\x00\x00abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_frame_one_packet_native_endian() { let d = if cfg!(target_endian = "big") { b"\x00\x00\x00\x09abcdefghi" } else { b"\x09\x00\x00\x00abcdefghi" }; let io = length_delimited::Builder::new() .native_endian() .new_read(mock! 
{ data(d), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_multi_frame_one_packet() { let mut d: Vec<u8> = vec![]; d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi"); d.extend_from_slice(b"\x00\x00\x00\x03123"); d.extend_from_slice(b"\x00\x00\x00\x0bhello world"); let io = FramedRead::new( mock! { data(&d), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_done!(io); } #[test] fn read_single_frame_multi_packet() { let io = FramedRead::new( mock! { data(b"\x00\x00"), data(b"\x00\x09abc"), data(b"defghi"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_multi_frame_multi_packet() { let io = FramedRead::new( mock! { data(b"\x00\x00"), data(b"\x00\x09abc"), data(b"defghi"), data(b"\x00\x00\x00\x0312"), data(b"3\x00\x00\x00\x0bhello world"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_done!(io); } #[test] fn read_single_frame_multi_packet_wait() { let io = FramedRead::new( mock! { data(b"\x00\x00"), Poll::Pending, data(b"\x00\x09abc"), Poll::Pending, data(b"defghi"), Poll::Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_pending!(io); assert_done!(io); } #[test] fn read_multi_frame_multi_packet_wait() { let io = FramedRead::new( mock! 
{ data(b"\x00\x00"), Poll::Pending, data(b"\x00\x09abc"), Poll::Pending, data(b"defghi"), Poll::Pending, data(b"\x00\x00\x00\x0312"), Poll::Pending, data(b"3\x00\x00\x00\x0bhello world"), Poll::Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_pending!(io); assert_next_pending!(io); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_next_pending!(io); assert_done!(io); } #[test] fn read_incomplete_head() { let io = FramedRead::new( mock! { data(b"\x00\x00"), }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_err!(io); } #[test] fn read_incomplete_head_multi() { let io = FramedRead::new( mock! { Poll::Pending, data(b"\x00"), Poll::Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_err!(io); } #[test] fn read_incomplete_payload() { let io = FramedRead::new( mock! { data(b"\x00\x00\x00\x09ab"), Poll::Pending, data(b"cd"), Poll::Pending, }, LengthDelimitedCodec::new(), ); pin_mut!(io); assert_next_pending!(io); assert_next_pending!(io); assert_next_err!(io); } #[test] fn read_max_frame_len() { let io = length_delimited::Builder::new() .max_frame_length(5) .new_read(mock! { data(b"\x00\x00\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_err!(io); } #[test] fn read_update_max_frame_len_at_rest() { let io = length_delimited::Builder::new().new_read(mock! { data(b"\x00\x00\x00\x09abcdefghi"), data(b"\x00\x00\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); io.decoder_mut().set_max_frame_length(5); assert_next_err!(io); } #[test] fn read_update_max_frame_len_in_flight() { let io = length_delimited::Builder::new().new_read(mock! 
{ data(b"\x00\x00\x00\x09abcd"), Poll::Pending, data(b"efghi"), data(b"\x00\x00\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_pending!(io); io.decoder_mut().set_max_frame_length(5); assert_next_eq!(io, b"abcdefghi"); assert_next_err!(io); } #[test] fn read_one_byte_length_field() { let io = length_delimited::Builder::new() .length_field_length(1) .new_read(mock! { data(b"\x09abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_header_offset() { let io = length_delimited::Builder::new() .length_field_length(2) .length_field_offset(4) .new_read(mock! { data(b"zzzz\x00\x09abcdefghi"), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_done!(io); } #[test] fn read_single_multi_frame_one_packet_skip_none_adjusted() { let mut d: Vec<u8> = vec![]; d.extend_from_slice(b"xx\x00\x09abcdefghi"); d.extend_from_slice(b"yy\x00\x03123"); d.extend_from_slice(b"zz\x00\x0bhello world"); let io = length_delimited::Builder::new() .length_field_length(2) .length_field_offset(2) .num_skip(0) .length_adjustment(4) .new_read(mock! { data(&d), }); pin_mut!(io); assert_next_eq!(io, b"xx\x00\x09abcdefghi"); assert_next_eq!(io, b"yy\x00\x03123"); assert_next_eq!(io, b"zz\x00\x0bhello world"); assert_done!(io); } #[test] fn read_single_frame_length_adjusted() { let mut d: Vec<u8> = vec![]; d.extend_from_slice(b"\x00\x00\x0b\x0cHello world"); let io = length_delimited::Builder::new() .length_field_offset(0) .length_field_length(3) .length_adjustment(0) .num_skip(4) .new_read(mock! { data(&d), }); pin_mut!(io); assert_next_eq!(io, b"Hello world"); assert_done!(io); } #[test] fn read_single_multi_frame_one_packet_length_includes_head() { let mut d: Vec<u8> = vec![]; d.extend_from_slice(b"\x00\x0babcdefghi"); d.extend_from_slice(b"\x00\x05123"); d.extend_from_slice(b"\x00\x0dhello world"); let io = length_delimited::Builder::new() .length_field_length(2) .length_adjustment(-2) .new_read(mock! 
{ data(&d), }); pin_mut!(io); assert_next_eq!(io, b"abcdefghi"); assert_next_eq!(io, b"123"); assert_next_eq!(io, b"hello world"); assert_done!(io); } #[test] fn write_single_frame_length_adjusted() { let io = length_delimited::Builder::new() .length_adjustment(-2) .new_write(mock! { data(b"\x00\x00\x00\x0b"), data(b"abcdefghi"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_nothing_yields_nothing() { let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new()); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.poll_flush(cx)); }); } #[test] fn write_single_frame_one_packet() { let io = FramedWrite::new( mock! { data(b"\x00\x00\x00\x09"), data(b"abcdefghi"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_multi_frame_one_packet() { let io = FramedWrite::new( mock! 
{ data(b"\x00\x00\x00\x09"), data(b"abcdefghi"), data(b"\x00\x00\x00\x03"), data(b"123"), data(b"\x00\x00\x00\x0b"), data(b"hello world"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("123"))); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_multi_frame_multi_packet() { let io = FramedWrite::new( mock! { data(b"\x00\x00\x00\x09"), data(b"abcdefghi"), flush(), data(b"\x00\x00\x00\x03"), data(b"123"), flush(), data(b"\x00\x00\x00\x0b"), data(b"hello world"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("123"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_frame_would_block() { let io = FramedWrite::new( mock! 
{ Poll::Pending, data(b"\x00\x00"), Poll::Pending, data(b"\x00\x09"), data(b"abcdefghi"), flush(), }, LengthDelimitedCodec::new(), ); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_pending!(io.as_mut().poll_flush(cx)); assert_pending!(io.as_mut().poll_flush(cx)); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_frame_little_endian() { let io = length_delimited::Builder::new() .little_endian() .new_write(mock! { data(b"\x09\x00\x00\x00"), data(b"abcdefghi"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_single_frame_with_short_length_field() { let io = length_delimited::Builder::new() .length_field_length(1) .new_write(mock! { data(b"\x09"), data(b"abcdefghi"), flush(), }); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); assert_ready_ok!(io.as_mut().poll_flush(cx)); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_max_frame_len() { let io = length_delimited::Builder::new() .max_frame_length(5) .new_write(mock! {}); pin_mut!(io); task::spawn(()).enter(|cx, _| { assert_ready_ok!(io.as_mut().poll_ready(cx)); assert_err!(io.as_mut().start_send(Bytes::from("abcdef"))); assert!(io.get_ref().calls.is_empty()); }); } #[test] fn write_update_max_frame_len_at_rest() { let io = length_delimited::Builder::new().new_write(mock! 
{
        data(b"\x00\x00\x00\x06"),
        data(b"abcdef"),
        flush(),
    });
    pin_mut!(io);

    task::spawn(()).enter(|cx, _| {
        assert_ready_ok!(io.as_mut().poll_ready(cx));
        assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));

        assert_ready_ok!(io.as_mut().poll_flush(cx));

        io.encoder_mut().set_max_frame_length(5);

        assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));

        assert!(io.get_ref().calls.is_empty());
    });
}

#[test]
fn write_update_max_frame_len_in_flight() {
    let io = length_delimited::Builder::new().new_write(mock! {
        data(b"\x00\x00\x00\x06"),
        data(b"ab"),
        Poll::Pending,
        data(b"cdef"),
        flush(),
    });
    pin_mut!(io);

    task::spawn(()).enter(|cx, _| {
        assert_ready_ok!(io.as_mut().poll_ready(cx));
        assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));

        assert_pending!(io.as_mut().poll_flush(cx));

        io.encoder_mut().set_max_frame_length(5);

        assert_ready_ok!(io.as_mut().poll_flush(cx));

        assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
        assert!(io.get_ref().calls.is_empty());
    });
}

#[test]
fn write_zero() {
    let io = length_delimited::Builder::new().new_write(mock! {});
    pin_mut!(io);

    task::spawn(()).enter(|cx, _| {
        assert_ready_ok!(io.as_mut().poll_ready(cx));
        assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
        assert_ready_err!(io.as_mut().poll_flush(cx));

        assert!(io.get_ref().calls.is_empty());
    });
}

#[test]
fn encode_overflow() {
    // Test reproducing tokio-rs/tokio#681.
    let mut codec = length_delimited::Builder::new().new_codec();
    let mut buf = BytesMut::with_capacity(1024);

    // Put some data into the buffer without resizing it to hold more.
    let some_as = std::iter::repeat(b'a').take(1024).collect::<Vec<_>>();
    buf.put_slice(&some_as[..]);

    // Trying to encode the length header should resize the buffer if it won't fit.
    codec.encode(Bytes::from("hello"), &mut buf).unwrap();
}

// ===== Test utils =====

struct Mock {
    calls: VecDeque<Poll<io::Result<Op>>>,
}

enum Op {
    Data(Vec<u8>),
    Flush,
}

impl AsyncRead for Mock {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        dst: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.calls.pop_front() {
            Some(Poll::Ready(Ok(Op::Data(data)))) => {
                debug_assert!(dst.remaining() >= data.len());
                dst.put_slice(&data);
                Poll::Ready(Ok(()))
            }
            Some(Poll::Ready(Ok(_))) => panic!(),
            Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
            Some(Poll::Pending) => Poll::Pending,
            None => Poll::Ready(Ok(())),
        }
    }
}

impl AsyncWrite for Mock {
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        src: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        match self.calls.pop_front() {
            Some(Poll::Ready(Ok(Op::Data(data)))) => {
                let len = data.len();
                assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src);
                assert_eq!(&data[..], &src[..len]);
                Poll::Ready(Ok(len))
            }
            Some(Poll::Ready(Ok(_))) => panic!(),
            Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
            Some(Poll::Pending) => Poll::Pending,
            None => Poll::Ready(Ok(0)),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.calls.pop_front() {
            Some(Poll::Ready(Ok(Op::Flush))) => Poll::Ready(Ok(())),
            Some(Poll::Ready(Ok(_))) => panic!(),
            Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
            Some(Poll::Pending) => Poll::Pending,
            None => Poll::Ready(Ok(())),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}

impl<'a> From<&'a [u8]> for Op {
    fn from(src: &'a [u8]) -> Op {
        Op::Data(src.into())
    }
}

impl From<Vec<u8>> for Op {
    fn from(src: Vec<u8>) -> Op {
        Op::Data(src)
    }
}

fn data(bytes: &[u8]) -> Poll<io::Result<Op>> {
    Poll::Ready(Ok(bytes.into()))
}

fn flush() -> Poll<io::Result<Op>> {
    Poll::Ready(Ok(Op::Flush))
}
tokio-util-0.7.10/tests/mpsc.rs000064400000000000000000000161211046102023000144640ustar 00000000000000
use futures::future::poll_fn;
use tokio::sync::mpsc::channel;
use tokio_test::task::spawn;
use
tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
use tokio_util::sync::PollSender;

#[tokio::test]
async fn simple() {
    let (send, mut recv) = channel(3);
    let mut send = PollSender::new(send);

    for i in 1..=3i32 {
        let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
        assert_ready_ok!(reserve.poll());
        send.send_item(i).unwrap();
    }

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_pending!(reserve.poll());

    assert_eq!(recv.recv().await.unwrap(), 1);
    assert!(reserve.is_woken());
    assert_ready_ok!(reserve.poll());

    drop(recv);

    send.send_item(42).unwrap();
}

#[tokio::test]
async fn simple_ref() {
    let v = vec![1, 2, 3i32];

    let (send, mut recv) = channel(3);
    let mut send = PollSender::new(send);

    for vi in v.iter() {
        let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
        assert_ready_ok!(reserve.poll());
        send.send_item(vi).unwrap();
    }

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_pending!(reserve.poll());

    assert_eq!(*recv.recv().await.unwrap(), 1);
    assert!(reserve.is_woken());
    assert_ready_ok!(reserve.poll());
    drop(recv);
    send.send_item(&42).unwrap();
}

#[tokio::test]
async fn repeated_poll_reserve() {
    let (send, mut recv) = channel::<i32>(1);
    let mut send = PollSender::new(send);

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_ok!(reserve.poll());
    assert_ready_ok!(reserve.poll());
    send.send_item(1).unwrap();

    assert_eq!(recv.recv().await.unwrap(), 1);
}

#[tokio::test]
async fn abort_send() {
    let (send, mut recv) = channel(3);
    let mut send = PollSender::new(send);
    let send2 = send.get_ref().cloned().unwrap();

    for i in 1..=3i32 {
        let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
        assert_ready_ok!(reserve.poll());
        send.send_item(i).unwrap();
    }

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_pending!(reserve.poll());
    assert_eq!(recv.recv().await.unwrap(), 1);
    assert!(reserve.is_woken());
    assert_ready_ok!(reserve.poll());

    let mut send2_send =
spawn(send2.send(5));
    assert_pending!(send2_send.poll());
    assert!(send.abort_send());
    assert!(send2_send.is_woken());
    assert_ready_ok!(send2_send.poll());

    assert_eq!(recv.recv().await.unwrap(), 2);
    assert_eq!(recv.recv().await.unwrap(), 3);
    assert_eq!(recv.recv().await.unwrap(), 5);
}

#[tokio::test]
async fn close_sender_last() {
    let (send, mut recv) = channel::<i32>(3);
    let mut send = PollSender::new(send);

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    send.close();

    assert!(recv_task.is_woken());
    assert!(assert_ready!(recv_task.poll()).is_none());
}

#[tokio::test]
async fn close_sender_not_last() {
    let (send, mut recv) = channel::<i32>(3);
    let mut send = PollSender::new(send);
    let send2 = send.get_ref().cloned().unwrap();

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    send.close();

    assert!(!recv_task.is_woken());
    assert_pending!(recv_task.poll());

    drop(send2);

    assert!(recv_task.is_woken());
    assert!(assert_ready!(recv_task.poll()).is_none());
}

#[tokio::test]
async fn close_sender_before_reserve() {
    let (send, mut recv) = channel::<i32>(3);
    let mut send = PollSender::new(send);

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    send.close();

    assert!(recv_task.is_woken());
    assert!(assert_ready!(recv_task.poll()).is_none());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_err!(reserve.poll());
}

#[tokio::test]
async fn close_sender_after_pending_reserve() {
    let (send, mut recv) = channel::<i32>(1);
    let mut send = PollSender::new(send);

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_ok!(reserve.poll());
    send.send_item(1).unwrap();

    assert!(recv_task.is_woken());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_pending!(reserve.poll());
    drop(reserve);

    send.close();

    assert!(send.is_closed());
    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_err!(reserve.poll());
}

#[tokio::test]
async fn close_sender_after_successful_reserve() {
    let (send, mut recv) = channel::<i32>(3);
    let mut send = PollSender::new(send);

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_ok!(reserve.poll());
    drop(reserve);

    send.close();
    assert!(send.is_closed());
    assert!(!recv_task.is_woken());
    assert_pending!(recv_task.poll());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_ok!(reserve.poll());
}

#[tokio::test]
async fn abort_send_after_pending_reserve() {
    let (send, mut recv) = channel::<i32>(1);
    let mut send = PollSender::new(send);

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_ok!(reserve.poll());
    send.send_item(1).unwrap();

    assert_eq!(send.get_ref().unwrap().capacity(), 0);
    assert!(!send.abort_send());

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_pending!(reserve.poll());

    assert!(send.abort_send());
    assert_eq!(send.get_ref().unwrap().capacity(), 0);
}

#[tokio::test]
async fn abort_send_after_successful_reserve() {
    let (send, mut recv) = channel::<i32>(1);
    let mut send = PollSender::new(send);

    let mut recv_task = spawn(recv.recv());
    assert_pending!(recv_task.poll());

    assert_eq!(send.get_ref().unwrap().capacity(), 1);
    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_ok!(reserve.poll());
    assert_eq!(send.get_ref().unwrap().capacity(), 0);

    assert!(send.abort_send());
    assert_eq!(send.get_ref().unwrap().capacity(), 1);
}

#[tokio::test]
async fn closed_when_receiver_drops() {
    let (send, _) = channel::<i32>(1);
    let mut send = PollSender::new(send);

    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
    assert_ready_err!(reserve.poll());
}

#[should_panic]
#[test]
fn start_send_panics_when_idle() {
    let (send, _) = channel::<i32>(3);
    let mut send =
PollSender::new(send); send.send_item(1).unwrap(); } #[should_panic] #[test] fn start_send_panics_when_acquiring() { let (send, _) = channel::<i32>(1); let mut send = PollSender::new(send); let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); assert_ready_ok!(reserve.poll()); send.send_item(1).unwrap(); let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx))); assert_pending!(reserve.poll()); send.send_item(2).unwrap(); } tokio-util-0.7.10/tests/panic.rs000064400000000000000000000142461046102023000146220ustar 00000000000000#![warn(rust_2018_idioms)] #![cfg(all(feature = "full", not(target_os = "wasi")))] // Wasi doesn't support panic recovery use parking_lot::{const_mutex, Mutex}; use std::error::Error; use std::panic; use std::sync::Arc; use tokio::runtime::Runtime; use tokio::sync::mpsc::channel; use tokio::time::{Duration, Instant}; use tokio_test::task; use tokio_util::io::SyncIoBridge; use tokio_util::sync::PollSender; use tokio_util::task::LocalPoolHandle; use tokio_util::time::DelayQueue; // Taken from tokio-util::time::wheel, if that changes then const MAX_DURATION_MS: u64 = (1 << (36)) - 1; fn test_panic<Func: FnOnce() + panic::UnwindSafe>(func: Func) -> Option<String> { static PANIC_MUTEX: Mutex<()> = const_mutex(()); { let _guard = PANIC_MUTEX.lock(); let panic_file: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None)); let prev_hook = panic::take_hook(); { let panic_file = panic_file.clone(); panic::set_hook(Box::new(move |panic_info| { let panic_location = panic_info.location().unwrap(); panic_file .lock() .clone_from(&Some(panic_location.file().to_string())); })); } let result = panic::catch_unwind(func); // Return to the previously set panic hook (maybe default) so that we get nice error // messages in the tests. 
panic::set_hook(prev_hook); if result.is_err() { panic_file.lock().clone() } else { None } } } #[test] fn sync_bridge_new_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let _ = SyncIoBridge::new(tokio::io::empty()); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn poll_sender_send_item_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let (send, _) = channel::<u32>(3); let mut send = PollSender::new(send); let _ = send.send_item(42); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn local_pool_handle_new_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let _ = LocalPoolHandle::new(0); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn local_pool_handle_spawn_pinned_by_idx_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let handle = LocalPoolHandle::new(2); handle.spawn_pinned_by_idx(|| async { "test" }, 3); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn delay_queue_insert_at_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let mut queue = task::spawn(DelayQueue::with_capacity(3)); //let st = std::time::Instant::from(SystemTime::UNIX_EPOCH); let _k = queue.insert_at( "1", Instant::now() + Duration::from_millis(MAX_DURATION_MS + 1), ); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn delay_queue_insert_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let mut queue = task::spawn(DelayQueue::with_capacity(3)); let _k = 
queue.insert("1", Duration::from_millis(MAX_DURATION_MS + 1)); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn delay_queue_remove_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let mut queue = task::spawn(DelayQueue::with_capacity(3)); let key = queue.insert_at("1", Instant::now()); queue.remove(&key); queue.remove(&key); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn delay_queue_reset_at_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let mut queue = task::spawn(DelayQueue::with_capacity(3)); let key = queue.insert_at("1", Instant::now()); queue.reset_at( &key, Instant::now() + Duration::from_millis(MAX_DURATION_MS + 1), ); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn delay_queue_reset_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let mut queue = task::spawn(DelayQueue::with_capacity(3)); let key = queue.insert_at("1", Instant::now()); queue.reset(&key, Duration::from_millis(MAX_DURATION_MS + 1)); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } #[test] fn delay_queue_reserve_panic_caller() -> Result<(), Box<dyn Error>> { let panic_location_file = test_panic(|| { let rt = basic(); rt.block_on(async { let mut queue = task::spawn(DelayQueue::<u32>::with_capacity(3)); queue.reserve((1 << 30) as usize); }); }); // The panic location should be in this file assert_eq!(&panic_location_file.unwrap(), file!()); Ok(()) } fn basic() -> Runtime { tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap() } 
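The `test_panic` helper in `tests/panic.rs` swaps in a temporary panic hook, records the source file the panic is attributed to, and then restores the previous hook. The following is a minimal std-only sketch of that idea, not the crate's own code: `panic_file` and `checked_div` are hypothetical names, `std::sync::Mutex` stands in for `parking_lot::Mutex`, and the global lock that the original uses to serialize hook swaps between tests is left out for brevity.

```rust
use std::panic;
use std::sync::{Arc, Mutex};

/// Hypothetical helper: runs `func` and returns the source file that the
/// panic was attributed to, or `None` if `func` did not panic.
fn panic_file<F: FnOnce() + panic::UnwindSafe>(func: F) -> Option<String> {
    let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

    // Remember the current hook so it can be restored afterwards.
    let prev_hook = panic::take_hook();
    {
        let captured = Arc::clone(&captured);
        panic::set_hook(Box::new(move |info| {
            // Record the file reported for the panic location, if any.
            if let Some(location) = info.location() {
                *captured.lock().unwrap() = Some(location.file().to_string());
            }
        }));
    }

    let result = panic::catch_unwind(func);

    // Restore the previous hook so later panics are reported normally.
    panic::set_hook(prev_hook);

    if result.is_err() {
        captured.lock().unwrap().clone()
    } else {
        None
    }
}

/// Hypothetical function under test. With `#[track_caller]`, a panic raised
/// here is attributed to the caller's location rather than to this body.
#[track_caller]
fn checked_div(a: u32, b: u32) -> u32 {
    if b == 0 {
        panic!("division by zero");
    }
    a / b
}
```

Calling `panic_file(|| { let _ = checked_div(1, 0); })` from a test and comparing the result with `file!()` mirrors the `*_panic_caller` tests above. Because the panic hook is process-wide, a real test suite would likely need a lock around the swap, as the original helper has.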
tokio-util-0.7.10/tests/poll_semaphore.rs000064400000000000000000000051401046102023000165320ustar 00000000000000use std::future::Future; use std::sync::Arc; use std::task::Poll; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio_util::sync::PollSemaphore; type SemRet = Option<OwnedSemaphorePermit>; fn semaphore_poll( sem: &mut PollSemaphore, ) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> { let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx)); tokio_test::task::spawn(fut) } fn semaphore_poll_many( sem: &mut PollSemaphore, permits: u32, ) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> { let fut = futures::future::poll_fn(move |cx| sem.poll_acquire_many(cx, permits)); tokio_test::task::spawn(fut) } #[tokio::test] async fn it_works() { let sem = Arc::new(Semaphore::new(1)); let mut poll_sem = PollSemaphore::new(sem.clone()); let permit = sem.acquire().await.unwrap(); let mut poll = semaphore_poll(&mut poll_sem); assert!(poll.poll().is_pending()); drop(permit); assert!(matches!(poll.poll(), Poll::Ready(Some(_)))); drop(poll); sem.close(); assert!(semaphore_poll(&mut poll_sem).await.is_none()); // Check that it is fused. 
assert!(semaphore_poll(&mut poll_sem).await.is_none()); assert!(semaphore_poll(&mut poll_sem).await.is_none()); } #[tokio::test] async fn can_acquire_many_permits() { let sem = Arc::new(Semaphore::new(4)); let mut poll_sem = PollSemaphore::new(sem.clone()); let permit1 = semaphore_poll(&mut poll_sem).poll(); assert!(matches!(permit1, Poll::Ready(Some(_)))); let permit2 = semaphore_poll_many(&mut poll_sem, 2).poll(); assert!(matches!(permit2, Poll::Ready(Some(_)))); assert_eq!(sem.available_permits(), 1); drop(permit2); let mut permit4 = semaphore_poll_many(&mut poll_sem, 4); assert!(permit4.poll().is_pending()); drop(permit1); let permit4 = permit4.poll(); assert!(matches!(permit4, Poll::Ready(Some(_)))); assert_eq!(sem.available_permits(), 0); } #[tokio::test] async fn can_poll_different_amounts_of_permits() { let sem = Arc::new(Semaphore::new(4)); let mut poll_sem = PollSemaphore::new(sem.clone()); assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending()); assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready()); let permit = sem.acquire_many(4).await.unwrap(); assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending()); assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_pending()); drop(permit); assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending()); assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready()); } tokio-util-0.7.10/tests/reusable_box.rs000064400000000000000000000045471046102023000162050ustar 00000000000000use futures::future::FutureExt; use std::alloc::Layout; use std::future::Future; use std::marker::PhantomPinned; use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; use tokio_util::sync::ReusableBoxFuture; #[test] // Clippy false positive; it's useful to be able to test the trait impls for any lifetime #[allow(clippy::extra_unused_lifetimes)] fn traits<'a>() { fn assert_traits<T: Send + Sync + Unpin>() {} // Use a type that is !Unpin assert_traits::<ReusableBoxFuture<'a, PhantomPinned>>(); // Use a type that is !Send + !Sync assert_traits::<ReusableBoxFuture<'a, Rc<()>>>(); 
} #[test] fn test_different_futures() { let fut = async move { 10 }; // Not zero sized! assert_eq!(Layout::for_value(&fut).size(), 1); let mut b = ReusableBoxFuture::new(fut); assert_eq!(b.get_pin().now_or_never(), Some(10)); b.try_set(async move { 20 }) .unwrap_or_else(|_| panic!("incorrect size")); assert_eq!(b.get_pin().now_or_never(), Some(20)); b.try_set(async move { 30 }) .unwrap_or_else(|_| panic!("incorrect size")); assert_eq!(b.get_pin().now_or_never(), Some(30)); } #[test] fn test_different_sizes() { let fut1 = async move { 10 }; let val = [0u32; 1000]; let fut2 = async move { val[0] }; let fut3 = ZeroSizedFuture {}; assert_eq!(Layout::for_value(&fut1).size(), 1); assert_eq!(Layout::for_value(&fut2).size(), 4004); assert_eq!(Layout::for_value(&fut3).size(), 0); let mut b = ReusableBoxFuture::new(fut1); assert_eq!(b.get_pin().now_or_never(), Some(10)); b.set(fut2); assert_eq!(b.get_pin().now_or_never(), Some(0)); b.set(fut3); assert_eq!(b.get_pin().now_or_never(), Some(5)); } struct ZeroSizedFuture {} impl Future for ZeroSizedFuture { type Output = u32; fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { Poll::Ready(5) } } #[test] fn test_zero_sized() { let fut = ZeroSizedFuture {}; // Zero sized! 
assert_eq!(Layout::for_value(&fut).size(), 0); let mut b = ReusableBoxFuture::new(fut); assert_eq!(b.get_pin().now_or_never(), Some(5)); assert_eq!(b.get_pin().now_or_never(), Some(5)); b.try_set(ZeroSizedFuture {}) .unwrap_or_else(|_| panic!("incorrect size")); assert_eq!(b.get_pin().now_or_never(), Some(5)); assert_eq!(b.get_pin().now_or_never(), Some(5)); } tokio-util-0.7.10/tests/spawn_pinned.rs000064400000000000000000000163341046102023000162150ustar 00000000000000#![warn(rust_2018_idioms)] #![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads use std::rc::Rc; use std::sync::Arc; use tokio::sync::Barrier; use tokio_util::task; /// Simple test of running a !Send future via spawn_pinned #[tokio::test] async fn can_spawn_not_send_future() { let pool = task::LocalPoolHandle::new(1); let output = pool .spawn_pinned(|| { // Rc is !Send + !Sync let local_data = Rc::new("test"); // This future holds an Rc, so it is !Send async move { local_data.to_string() } }) .await .unwrap(); assert_eq!(output, "test"); } /// Dropping the join handle still lets the task execute #[test] fn can_drop_future_and_still_get_output() { let pool = task::LocalPoolHandle::new(1); let (sender, receiver) = std::sync::mpsc::channel(); let _ = pool.spawn_pinned(move || { // Rc is !Send + !Sync let local_data = Rc::new("test"); // This future holds an Rc, so it is !Send async move { let _ = sender.send(local_data.to_string()); } }); assert_eq!(receiver.recv(), Ok("test".to_string())); } #[test] #[should_panic(expected = "assertion failed: pool_size > 0")] fn cannot_create_zero_sized_pool() { let _pool = task::LocalPoolHandle::new(0); } /// We should be able to spawn multiple futures onto the pool at the same time. 
#[tokio::test] async fn can_spawn_multiple_futures() { let pool = task::LocalPoolHandle::new(2); let join_handle1 = pool.spawn_pinned(|| { let local_data = Rc::new("test1"); async move { local_data.to_string() } }); let join_handle2 = pool.spawn_pinned(|| { let local_data = Rc::new("test2"); async move { local_data.to_string() } }); assert_eq!(join_handle1.await.unwrap(), "test1"); assert_eq!(join_handle2.await.unwrap(), "test2"); } /// A panic in the spawned task causes the join handle to return an error. /// But, you can continue to spawn tasks. #[tokio::test] async fn task_panic_propagates() { let pool = task::LocalPoolHandle::new(1); let join_handle = pool.spawn_pinned(|| async { panic!("Test panic"); }); let result = join_handle.await; assert!(result.is_err()); let error = result.unwrap_err(); assert!(error.is_panic()); let panic_str = error.into_panic().downcast::<&'static str>().unwrap(); assert_eq!(*panic_str, "Test panic"); // Trying again with a "safe" task still works let join_handle = pool.spawn_pinned(|| async { "test" }); let result = join_handle.await; assert!(result.is_ok()); assert_eq!(result.unwrap(), "test"); } /// A panic during task creation causes the join handle to return an error. /// But, you can continue to spawn tasks. 
#[tokio::test] async fn callback_panic_does_not_kill_worker() { let pool = task::LocalPoolHandle::new(1); let join_handle = pool.spawn_pinned(|| { panic!("Test panic"); #[allow(unreachable_code)] async {} }); let result = join_handle.await; assert!(result.is_err()); let error = result.unwrap_err(); assert!(error.is_panic()); let panic_str = error.into_panic().downcast::<&'static str>().unwrap(); assert_eq!(*panic_str, "Test panic"); // Trying again with a "safe" callback works let join_handle = pool.spawn_pinned(|| async { "test" }); let result = join_handle.await; assert!(result.is_ok()); assert_eq!(result.unwrap(), "test"); } /// Canceling the task via the returned join handle cancels the spawned task /// (which has a different, internal join handle). #[tokio::test] async fn task_cancellation_propagates() { let pool = task::LocalPoolHandle::new(1); let notify_dropped = Arc::new(()); let weak_notify_dropped = Arc::downgrade(&notify_dropped); let (start_sender, start_receiver) = tokio::sync::oneshot::channel(); let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>(); let join_handle = pool.spawn_pinned(|| async move { let _drop_sender = drop_sender; // Move the Arc into the task let _notify_dropped = notify_dropped; let _ = start_sender.send(()); // Keep the task running until it gets aborted futures::future::pending::<()>().await; }); // Wait for the task to start let _ = start_receiver.await; join_handle.abort(); // Wait for the inner task to abort, dropping the sender. // The top level join handle aborts quicker than the inner task (the abort // needs to propagate and get processed on the worker thread), so we can't // just await the top level join handle. let _ = drop_receiver.await; // Check that the Arc has been dropped. This verifies that the inner task // was canceled as well. assert!(weak_notify_dropped.upgrade().is_none()); } /// Tasks should be given to the least burdened worker. 
When spawning two tasks /// on a pool with two empty workers the tasks should be spawned on separate /// workers. #[tokio::test] async fn tasks_are_balanced() { let pool = task::LocalPoolHandle::new(2); // Spawn a task so one thread has a task count of 1 let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel(); let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel(); let join_handle1 = pool.spawn_pinned(|| async move { let _ = start_sender1.send(()); let _ = end_receiver1.await; std::thread::current().id() }); // Wait for the first task to start up let _ = start_receiver1.await; // This task should be spawned on the other thread let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel(); let join_handle2 = pool.spawn_pinned(|| async move { let _ = start_sender2.send(()); std::thread::current().id() }); // Wait for the second task to start up let _ = start_receiver2.await; // Allow the first task to end let _ = end_sender1.send(()); let thread_id1 = join_handle1.await.unwrap(); let thread_id2 = join_handle2.await.unwrap(); // Since the first task was active when the second task spawned, they should // be on separate workers/threads. 
assert_ne!(thread_id1, thread_id2); } #[tokio::test] async fn spawn_by_idx() { let pool = task::LocalPoolHandle::new(3); let barrier = Arc::new(Barrier::new(4)); let barrier1 = barrier.clone(); let barrier2 = barrier.clone(); let barrier3 = barrier.clone(); let handle1 = pool.spawn_pinned_by_idx( || async move { barrier1.wait().await; std::thread::current().id() }, 0, ); let _ = pool.spawn_pinned_by_idx( || async move { barrier2.wait().await; std::thread::current().id() }, 0, ); let handle2 = pool.spawn_pinned_by_idx( || async move { barrier3.wait().await; std::thread::current().id() }, 1, ); let loads = pool.get_task_loads_for_each_worker(); barrier.wait().await; assert_eq!(loads[0], 2); assert_eq!(loads[1], 1); assert_eq!(loads[2], 0); let thread_id1 = handle1.await.unwrap(); let thread_id2 = handle2.await.unwrap(); assert_ne!(thread_id1, thread_id2); } tokio-util-0.7.10/tests/sync_cancellation_token.rs000064400000000000000000000273121046102023000204160ustar 00000000000000#![warn(rust_2018_idioms)] use tokio::pin; use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; use core::future::Future; use core::task::{Context, Poll}; use futures_test::task::new_count_waker; #[test] fn cancel_token() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); assert!(!token.is_cancelled()); let wait_fut = token.cancelled(); pin!(wait_fut); assert_eq!( Poll::Pending, wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); let wait_fut_2 = token.cancelled(); pin!(wait_fut_2); token.cancel(); assert_eq!(wake_counter, 1); assert!(token.is_cancelled()); assert_eq!( Poll::Ready(()), wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn cancel_token_owned() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); assert!(!token.is_cancelled()); let wait_fut = 
token.clone().cancelled_owned(); pin!(wait_fut); assert_eq!( Poll::Pending, wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); let wait_fut_2 = token.clone().cancelled_owned(); pin!(wait_fut_2); token.cancel(); assert_eq!(wake_counter, 1); assert!(token.is_cancelled()); assert_eq!( Poll::Ready(()), wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn cancel_token_owned_drop_test() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); let future = token.cancelled_owned(); pin!(future); assert_eq!( Poll::Pending, future.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); // let future be dropped while pinned and under pending state to // find potential memory related bugs. } #[test] fn cancel_child_token_through_parent() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); let child_token = token.child_token(); assert!(!child_token.is_cancelled()); let child_fut = child_token.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); token.cancel(); assert_eq!(wake_counter, 2); assert!(token.is_cancelled()); assert!(child_token.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn cancel_grandchild_token_through_parent_if_child_was_dropped() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); let intermediate_token = token.child_token(); let child_token = intermediate_token.child_token(); 
drop(intermediate_token); assert!(!child_token.is_cancelled()); let child_fut = child_token.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); token.cancel(); assert_eq!(wake_counter, 2); assert!(token.is_cancelled()); assert!(child_token.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn cancel_child_token_without_parent() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); let child_token_1 = token.child_token(); let child_fut = child_token_1.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Pending, child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); child_token_1.cancel(); assert_eq!(wake_counter, 1); assert!(!token.is_cancelled()); assert!(child_token_1.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); let child_token_2 = token.child_token(); let child_fut_2 = child_token_2.cancelled(); pin!(child_fut_2); assert_eq!( Poll::Pending, child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); token.cancel(); assert_eq!(wake_counter, 3); assert!(token.is_cancelled()); assert!(child_token_2.is_cancelled()); assert_eq!( Poll::Ready(()), child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), 
parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); } #[test] fn create_child_token_after_parent_was_cancelled() { for drop_child_first in [true, false].iter().cloned() { let (waker, wake_counter) = new_count_waker(); let token = CancellationToken::new(); token.cancel(); let child_token = token.child_token(); assert!(child_token.is_cancelled()); { let child_fut = child_token.cancelled(); pin!(child_fut); let parent_fut = token.cancelled(); pin!(parent_fut); assert_eq!( Poll::Ready(()), child_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); } if drop_child_first { drop(child_token); drop(token); } else { drop(token); drop(child_token); } } } #[test] fn drop_multiple_child_tokens() { for drop_first_child_first in &[true, false] { let token = CancellationToken::new(); let mut child_tokens = [None, None, None]; for child in &mut child_tokens { *child = Some(token.child_token()); } assert!(!token.is_cancelled()); assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); for i in 0..child_tokens.len() { if *drop_first_child_first { child_tokens[i] = None; } else { child_tokens[child_tokens.len() - 1 - i] = None; } assert!(!token.is_cancelled()); } drop(token); } } #[test] fn cancel_only_all_descendants() { // ARRANGE let (waker, wake_counter) = new_count_waker(); let parent_token = CancellationToken::new(); let token = parent_token.child_token(); let sibling_token = parent_token.child_token(); let child1_token = token.child_token(); let child2_token = token.child_token(); let grandchild_token = child1_token.child_token(); let grandchild2_token = child1_token.child_token(); let great_grandchild_token = grandchild_token.child_token(); assert!(!parent_token.is_cancelled()); assert!(!token.is_cancelled()); assert!(!sibling_token.is_cancelled()); assert!(!child1_token.is_cancelled()); assert!(!child2_token.is_cancelled()); 
assert!(!grandchild_token.is_cancelled()); assert!(!grandchild2_token.is_cancelled()); assert!(!great_grandchild_token.is_cancelled()); let parent_fut = parent_token.cancelled(); let fut = token.cancelled(); let sibling_fut = sibling_token.cancelled(); let child1_fut = child1_token.cancelled(); let child2_fut = child2_token.cancelled(); let grandchild_fut = grandchild_token.cancelled(); let grandchild2_fut = grandchild2_token.cancelled(); let great_grandchild_fut = great_grandchild_token.cancelled(); pin!(parent_fut); pin!(fut); pin!(sibling_fut); pin!(child1_fut); pin!(child2_fut); pin!(grandchild_fut); pin!(grandchild2_fut); pin!(great_grandchild_fut); assert_eq!( Poll::Pending, parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, sibling_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, grandchild_fut .as_mut() .poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, grandchild2_fut .as_mut() .poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Pending, great_grandchild_fut .as_mut() .poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 0); // ACT token.cancel(); // ASSERT assert_eq!(wake_counter, 6); assert!(!parent_token.is_cancelled()); assert!(token.is_cancelled()); assert!(!sibling_token.is_cancelled()); assert!(child1_token.is_cancelled()); assert!(child2_token.is_cancelled()); assert!(grandchild_token.is_cancelled()); assert!(grandchild2_token.is_cancelled()); assert!(great_grandchild_token.is_cancelled()); assert_eq!( Poll::Ready(()), fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( 
Poll::Ready(()), child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), grandchild_fut .as_mut() .poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), grandchild2_fut .as_mut() .poll(&mut Context::from_waker(&waker)) ); assert_eq!( Poll::Ready(()), great_grandchild_fut .as_mut() .poll(&mut Context::from_waker(&waker)) ); assert_eq!(wake_counter, 6); } #[test] fn drop_parent_before_child_tokens() { let token = CancellationToken::new(); let child1 = token.child_token(); let child2 = token.child_token(); drop(token); assert!(!child1.is_cancelled()); drop(child1); drop(child2); } #[test] fn derives_send_sync() { fn assert_send<T: Send>() {} fn assert_sync<T: Sync>() {} assert_send::<CancellationToken>(); assert_sync::<CancellationToken>(); assert_send::<WaitForCancellationFuture<'static>>(); assert_sync::<WaitForCancellationFuture<'static>>(); } tokio-util-0.7.10/tests/task_join_map.rs000064400000000000000000000164711046102023000163500ustar 00000000000000#![warn(rust_2018_idioms)] #![cfg(all(feature = "rt", tokio_unstable))] use tokio::sync::oneshot; use tokio::time::Duration; use tokio_util::task::JoinMap; use futures::future::FutureExt; fn rt() -> tokio::runtime::Runtime { tokio::runtime::Builder::new_current_thread() .build() .unwrap() } #[tokio::test(start_paused = true)] async fn test_with_sleep() { let mut map = JoinMap::new(); for i in 0..10 { map.spawn(i, async move { i }); assert_eq!(map.len(), 1 + i); } map.detach_all(); assert_eq!(map.len(), 0); assert!(matches!(map.join_next().await, None)); for i in 0..10 { map.spawn(i, async move { tokio::time::sleep(Duration::from_secs(i as u64)).await; i }); assert_eq!(map.len(), 1 + i); } let mut seen = [false; 10]; while let Some((k, res)) = map.join_next().await { seen[k] = true; assert_eq!(res.expect("task should have completed successfully"), k); } for was_seen in &seen { assert!(was_seen); } assert!(matches!(map.join_next().await, None)); // Do it again. 
for i in 0..10 { map.spawn(i, async move { tokio::time::sleep(Duration::from_secs(i as u64)).await; i }); } let mut seen = [false; 10]; while let Some((k, res)) = map.join_next().await { seen[k] = true; assert_eq!(res.expect("task should have completed successfully"), k); } for was_seen in &seen { assert!(was_seen); } assert!(matches!(map.join_next().await, None)); } #[tokio::test] async fn test_abort_on_drop() { let mut map = JoinMap::new(); let mut recvs = Vec::new(); for i in 0..16 { let (send, recv) = oneshot::channel::<()>(); recvs.push(recv); map.spawn(i, async { // This task will never complete on its own. futures::future::pending::<()>().await; drop(send); }); } drop(map); for recv in recvs { // The task is aborted soon and we will receive an error. assert!(recv.await.is_err()); } } #[tokio::test] async fn alternating() { let mut map = JoinMap::new(); assert_eq!(map.len(), 0); map.spawn(1, async {}); assert_eq!(map.len(), 1); map.spawn(2, async {}); assert_eq!(map.len(), 2); for i in 0..16 { let (_, res) = map.join_next().await.unwrap(); assert!(res.is_ok()); assert_eq!(map.len(), 1); map.spawn(i, async {}); assert_eq!(map.len(), 2); } } #[tokio::test] async fn test_keys() { use std::collections::HashSet; let mut map = JoinMap::new(); assert_eq!(map.len(), 0); map.spawn(1, async {}); assert_eq!(map.len(), 1); map.spawn(2, async {}); assert_eq!(map.len(), 2); let keys = map.keys().collect::<HashSet<_>>(); assert!(keys.contains(&1)); assert!(keys.contains(&2)); let _ = map.join_next().await.unwrap(); let _ = map.join_next().await.unwrap(); assert_eq!(map.len(), 0); let keys = map.keys().collect::<HashSet<_>>(); assert!(keys.is_empty()); } #[tokio::test(start_paused = true)] async fn abort_by_key() { let mut map = JoinMap::new(); let mut num_canceled = 0; let mut num_completed = 0; for i in 0..16 { map.spawn(i, async move { tokio::time::sleep(Duration::from_secs(i as u64)).await; }); } for i in 0..16 { if i % 2 != 0 { // abort odd-numbered tasks. 
map.abort(&i); } } while let Some((key, res)) = map.join_next().await { match res { Ok(()) => { num_completed += 1; assert_eq!(key % 2, 0); assert!(!map.contains_key(&key)); } Err(e) => { num_canceled += 1; assert!(e.is_cancelled()); assert_ne!(key % 2, 0); assert!(!map.contains_key(&key)); } } } assert_eq!(num_canceled, 8); assert_eq!(num_completed, 8); } #[tokio::test(start_paused = true)] async fn abort_by_predicate() { let mut map = JoinMap::new(); let mut num_canceled = 0; let mut num_completed = 0; for i in 0..16 { map.spawn(i, async move { tokio::time::sleep(Duration::from_secs(i as u64)).await; }); } // abort odd-numbered tasks. map.abort_matching(|key| key % 2 != 0); while let Some((key, res)) = map.join_next().await { match res { Ok(()) => { num_completed += 1; assert_eq!(key % 2, 0); assert!(!map.contains_key(&key)); } Err(e) => { num_canceled += 1; assert!(e.is_cancelled()); assert_ne!(key % 2, 0); assert!(!map.contains_key(&key)); } } } assert_eq!(num_canceled, 8); assert_eq!(num_completed, 8); } #[test] fn runtime_gone() { let mut map = JoinMap::new(); { let rt = rt(); map.spawn_on("key", async { 1 }, rt.handle()); drop(rt); } let (key, res) = rt().block_on(map.join_next()).unwrap(); assert_eq!(key, "key"); assert!(res.unwrap_err().is_cancelled()); } // This ensures that `join_next` works correctly when the coop budget is // exhausted. #[tokio::test(flavor = "current_thread")] async fn join_map_coop() { // Large enough to trigger coop. const TASK_NUM: u32 = 1000; static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); let mut map = JoinMap::new(); for i in 0..TASK_NUM { map.spawn(i, async move { SEM.add_permits(1); i }); } // Wait for all tasks to complete. // // Since this is a `current_thread` runtime, there's no race condition // between the last permit being added and the task completing. 
let _ = SEM.acquire_many(TASK_NUM).await.unwrap(); let mut count = 0; let mut coop_count = 0; loop { match map.join_next().now_or_never() { Some(Some((key, Ok(i)))) => assert_eq!(key, i), Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err), None => { coop_count += 1; tokio::task::yield_now().await; continue; } Some(None) => break, } count += 1; } assert!(coop_count >= 1); assert_eq!(count, TASK_NUM); } #[tokio::test(start_paused = true)] async fn abort_all() { let mut map: JoinMap<usize, ()> = JoinMap::new(); for i in 0..5 { map.spawn(i, futures::future::pending()); } for i in 5..10 { map.spawn(i, async { tokio::time::sleep(Duration::from_secs(1)).await; }); } // The join map will now have 5 pending tasks and 5 ready tasks. tokio::time::sleep(Duration::from_secs(2)).await; map.abort_all(); assert_eq!(map.len(), 10); let mut count = 0; let mut seen = [false; 10]; while let Some((k, res)) = map.join_next().await { seen[k] = true; if let Err(err) = res { assert!(err.is_cancelled()); } count += 1; } assert_eq!(count, 10); assert_eq!(map.len(), 0); for was_seen in &seen { assert!(was_seen); } } tokio-util-0.7.10/tests/task_tracker.rs000064400000000000000000000074761046102023000162140ustar 00000000000000#![warn(rust_2018_idioms)] use tokio_test::{assert_pending, assert_ready, task}; use tokio_util::task::TaskTracker; #[test] fn open_close() { let tracker = TaskTracker::new(); assert!(!tracker.is_closed()); assert!(tracker.is_empty()); assert_eq!(tracker.len(), 0); tracker.close(); assert!(tracker.is_closed()); assert!(tracker.is_empty()); assert_eq!(tracker.len(), 0); tracker.reopen(); assert!(!tracker.is_closed()); tracker.reopen(); assert!(!tracker.is_closed()); assert!(tracker.is_empty()); assert_eq!(tracker.len(), 0); tracker.close(); assert!(tracker.is_closed()); tracker.close(); assert!(tracker.is_closed()); assert!(tracker.is_empty()); assert_eq!(tracker.len(), 0); } #[test] fn token_len() { let tracker = TaskTracker::new(); let mut tokens = Vec::new(); for i 
in 0..10 { assert_eq!(tracker.len(), i); tokens.push(tracker.token()); } assert!(!tracker.is_empty()); assert_eq!(tracker.len(), 10); for (i, token) in tokens.into_iter().enumerate() { drop(token); assert_eq!(tracker.len(), 9 - i); } } #[test] fn notify_immediately() { let tracker = TaskTracker::new(); tracker.close(); let mut wait = task::spawn(tracker.wait()); assert_ready!(wait.poll()); } #[test] fn notify_immediately_on_reopen() { let tracker = TaskTracker::new(); tracker.close(); let mut wait = task::spawn(tracker.wait()); tracker.reopen(); assert_ready!(wait.poll()); } #[test] fn notify_on_close() { let tracker = TaskTracker::new(); let mut wait = task::spawn(tracker.wait()); assert_pending!(wait.poll()); tracker.close(); assert_ready!(wait.poll()); } #[test] fn notify_on_close_reopen() { let tracker = TaskTracker::new(); let mut wait = task::spawn(tracker.wait()); assert_pending!(wait.poll()); tracker.close(); tracker.reopen(); assert_ready!(wait.poll()); } #[test] fn notify_on_last_task() { let tracker = TaskTracker::new(); tracker.close(); let token = tracker.token(); let mut wait = task::spawn(tracker.wait()); assert_pending!(wait.poll()); drop(token); assert_ready!(wait.poll()); } #[test] fn notify_on_last_task_respawn() { let tracker = TaskTracker::new(); tracker.close(); let token = tracker.token(); let mut wait = task::spawn(tracker.wait()); assert_pending!(wait.poll()); drop(token); let token2 = tracker.token(); assert_ready!(wait.poll()); drop(token2); } #[test] fn no_notify_on_respawn_if_open() { let tracker = TaskTracker::new(); let token = tracker.token(); let mut wait = task::spawn(tracker.wait()); assert_pending!(wait.poll()); drop(token); let token2 = tracker.token(); assert_pending!(wait.poll()); drop(token2); } #[test] fn close_during_exit() { const ITERS: usize = 5; for close_spot in 0..=ITERS { let tracker = TaskTracker::new(); let tokens: Vec<_> = (0..ITERS).map(|_| tracker.token()).collect(); let mut wait = task::spawn(tracker.wait()); 
for (i, token) in tokens.into_iter().enumerate() { assert_pending!(wait.poll()); if i == close_spot { tracker.close(); assert_pending!(wait.poll()); } drop(token); } if close_spot == ITERS { assert_pending!(wait.poll()); tracker.close(); } assert_ready!(wait.poll()); } } #[test] fn notify_many() { let tracker = TaskTracker::new(); let mut waits: Vec<_> = (0..10).map(|_| task::spawn(tracker.wait())).collect(); for wait in &mut waits { assert_pending!(wait.poll()); } tracker.close(); for wait in &mut waits { assert_ready!(wait.poll()); } } tokio-util-0.7.10/tests/time_delay_queue.rs000064400000000000000000000522201046102023000170420ustar 00000000000000#![allow(clippy::disallowed_names)] #![warn(rust_2018_idioms)] #![cfg(feature = "full")] use futures::StreamExt; use tokio::time::{self, sleep, sleep_until, Duration, Instant}; use tokio_test::{assert_pending, assert_ready, task}; use tokio_util::time::DelayQueue; macro_rules! poll { ($queue:ident) => { $queue.enter(|cx, mut queue| queue.poll_expired(cx)) }; } macro_rules! 
assert_ready_some { ($e:expr) => {{ match assert_ready!($e) { Some(v) => v, None => panic!("None"), } }}; } #[tokio::test] async fn single_immediate_delay() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let _key = queue.insert_at("foo", Instant::now()); // Advance time by 1ms to handle the rounding sleep(ms(1)).await; assert_ready_some!(poll!(queue)); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()) } #[tokio::test] async fn multi_immediate_delays() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let _k = queue.insert_at("1", Instant::now()); let _k = queue.insert_at("2", Instant::now()); let _k = queue.insert_at("3", Instant::now()); sleep(ms(1)).await; let mut res = vec![]; while res.len() < 3 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); } let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); res.sort_unstable(); assert_eq!("1", res[0]); assert_eq!("2", res[1]); assert_eq!("3", res[2]); } #[tokio::test] async fn single_short_delay() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let _key = queue.insert_at("foo", Instant::now() + ms(5)); assert_pending!(poll!(queue)); sleep(ms(1)).await; assert!(!queue.is_woken()); sleep(ms(5)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)); assert_eq!(*entry.get_ref(), "foo"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); } #[tokio::test] async fn multi_delay_at_start() { time::pause(); let long = 262_144 + 9 * 4096; let delays = &[1000, 2, 234, long, 60, 10]; let mut queue = task::spawn(DelayQueue::new()); // Setup the delays for &i in delays { let _key = queue.insert_at(i, Instant::now() + ms(i)); } assert_pending!(poll!(queue)); assert!(!queue.is_woken()); let start = Instant::now(); for elapsed in 0..1200 { println!("elapsed: {:?}", elapsed); let elapsed = elapsed + 1; tokio::time::sleep_until(start + ms(elapsed)).await; if delays.contains(&elapsed) {
assert!(queue.is_woken()); assert_ready!(poll!(queue)); assert_pending!(poll!(queue)); } else if queue.is_woken() { let cascade = &[192, 960]; assert!( cascade.contains(&elapsed), "elapsed={} dt={:?}", elapsed, Instant::now() - start ); assert_pending!(poll!(queue)); } } println!("finished multi_delay_start"); } #[tokio::test] async fn insert_in_past_fires_immediately() { println!("running insert_in_past_fires_immediately"); time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); sleep(ms(10)).await; queue.insert_at("foo", now); assert_ready!(poll!(queue)); println!("finished insert_in_past_fires_immediately"); } #[tokio::test] async fn remove_entry() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let key = queue.insert_at("foo", Instant::now() + ms(5)); assert_pending!(poll!(queue)); let entry = queue.remove(&key); assert_eq!(entry.into_inner(), "foo"); sleep(ms(10)).await; let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); } #[tokio::test] async fn reset_entry() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now + ms(5)); assert_pending!(poll!(queue)); sleep(ms(1)).await; queue.reset_at(&key, now + ms(10)); assert_pending!(poll!(queue)); sleep(ms(7)).await; assert!(!queue.is_woken()); assert_pending!(poll!(queue)); sleep(ms(3)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)); assert_eq!(*entry.get_ref(), "foo"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()) } // Reproduces tokio-rs/tokio#849. #[tokio::test] async fn reset_much_later() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); sleep(ms(1)).await; let key = queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep(ms(3)).await; queue.reset_at(&key, now + ms(10)); sleep(ms(20)).await; assert!(queue.is_woken()); } // Reproduces tokio-rs/tokio#849. 
#[tokio::test] async fn reset_twice() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); sleep(ms(1)).await; let key = queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep(ms(3)).await; queue.reset_at(&key, now + ms(50)); sleep(ms(20)).await; queue.reset_at(&key, now + ms(40)); sleep(ms(20)).await; assert!(queue.is_woken()); } /// Regression test: Given an entry inserted with a deadline in the past, so /// that it is placed directly on the expired queue, reset the entry to a /// deadline in the future. Validate that this leaves the entry and queue in an /// internally consistent state by running an additional reset on the entry /// before polling it to completion. #[tokio::test] async fn repeatedly_reset_entry_inserted_as_expired() { time::pause(); // Instants before the start of the test seem to break in wasm. time::sleep(ms(1000)).await; let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now - ms(100)); queue.reset_at(&key, now + ms(100)); queue.reset_at(&key, now + ms(50)); assert_pending!(poll!(queue)); time::sleep_until(now + ms(60)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()); } #[tokio::test] async fn remove_expired_item() { time::pause(); let mut queue = DelayQueue::new(); let now = Instant::now(); sleep(ms(10)).await; let key = queue.insert_at("foo", now); let entry = queue.remove(&key); assert_eq!(entry.into_inner(), "foo"); } /// Regression test: it should be possible to remove entries which fall in the /// 0th slot of the internal timer wheel — that is, entries whose expiration /// (a) falls at the beginning of one of the wheel's hierarchical levels and (b) /// is equal to the wheel's current elapsed time. 
#[tokio::test] async fn remove_at_timer_wheel_threshold() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key1 = queue.insert_at("foo", now + ms(64)); let key2 = queue.insert_at("bar", now + ms(64)); sleep(ms(80)).await; let entry = assert_ready_some!(poll!(queue)).into_inner(); match entry { "foo" => { let entry = queue.remove(&key2).into_inner(); assert_eq!(entry, "bar"); } "bar" => { let entry = queue.remove(&key1).into_inner(); assert_eq!(entry, "foo"); } other => panic!("other: {:?}", other), } } #[tokio::test] async fn expires_before_last_insert() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("foo", now + ms(10_000)); // Delay should be set to 8.192s here. assert_pending!(poll!(queue)); // Delay should be set to the delay of the new item here queue.insert_at("bar", now + ms(600)); assert_pending!(poll!(queue)); sleep(ms(600)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "bar"); } #[tokio::test] async fn multi_reset() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let one = queue.insert_at("one", now + ms(200)); let two = queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&one, now + ms(300)); queue.reset_at(&two, now + ms(350)); queue.reset_at(&one, now + ms(400)); sleep(ms(310)).await; assert_pending!(poll!(queue)); sleep(ms(50)).await; let entry = assert_ready_some!(poll!(queue)); assert_eq!(*entry.get_ref(), "two"); assert_pending!(poll!(queue)); sleep(ms(50)).await; let entry = assert_ready_some!(poll!(queue)); assert_eq!(*entry.get_ref(), "one"); let entry = assert_ready!(poll!(queue)); assert!(entry.is_none()) } #[tokio::test] async fn expire_first_key_when_reset_to_expire_earlier() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let one = queue.insert_at("one", 
now + ms(200)); queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&one, now + ms(100)); sleep(ms(100)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "one"); } #[tokio::test] async fn expire_second_key_when_reset_to_expire_earlier() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("one", now + ms(200)); let two = queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&two, now + ms(100)); sleep(ms(100)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "two"); } #[tokio::test] async fn reset_first_expiring_item_to_expire_later() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let one = queue.insert_at("one", now + ms(200)); let _two = queue.insert_at("two", now + ms(250)); assert_pending!(poll!(queue)); queue.reset_at(&one, now + ms(300)); sleep(ms(250)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "two"); } #[tokio::test] async fn insert_before_first_after_poll() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let _one = queue.insert_at("one", now + ms(200)); assert_pending!(poll!(queue)); let _two = queue.insert_at("two", now + ms(100)); sleep(ms(99)).await; assert_pending!(poll!(queue)); sleep(ms(1)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "two"); } #[tokio::test] async fn insert_after_ready_poll() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("1", now + ms(100)); queue.insert_at("2", now + ms(100)); queue.insert_at("3", now + ms(100)); assert_pending!(poll!(queue)); sleep(ms(100)).await; assert!(queue.is_woken()); let mut res = vec![]; 
while res.len() < 3 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); queue.insert_at("foo", now + ms(500)); } res.sort_unstable(); assert_eq!("1", res[0]); assert_eq!("2", res[1]); assert_eq!("3", res[2]); } #[tokio::test] async fn reset_later_after_slot_starts() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let foo = queue.insert_at("foo", now + ms(100)); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(80)).await; assert!(!queue.is_woken()); // At this point the queue hasn't been polled, so `elapsed` on the wheel // for the queue is still at 0 and hence the 1ms resolution slots cover // [0-64). Resetting the time on the entry to 120 causes it to get put in // the [64-128) slot. As the queue knows that the first entry is within // that slot, but doesn't know when, it must wake immediately to advance // the wheel. queue.reset_at(&foo, now + ms(120)); assert!(queue.is_woken()); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(119)).await; assert!(!queue.is_woken()); sleep(ms(1)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); } #[tokio::test] async fn reset_inserted_expired() { time::pause(); // Instants before the start of the test seem to break in wasm. 
time::sleep(ms(1000)).await; let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now - ms(100)); // this causes the panic described in #2473 queue.reset_at(&key, now + ms(100)); assert_eq!(1, queue.len()); sleep(ms(200)).await; let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); assert_eq!(queue.len(), 0); } #[tokio::test] async fn reset_earlier_after_slot_starts() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let foo = queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(80)).await; assert!(!queue.is_woken()); // At this point the queue hasn't been polled, so `elapsed` on the wheel // for the queue is still at 0 and hence the 1ms resolution slots cover // [0-64). Resetting the time on the entry to 120 causes it to get put in // the [64-128) slot. As the queue knows that the first entry is within // that slot, but doesn't know when, it must wake immediately to advance // the wheel. 
queue.reset_at(&foo, now + ms(120)); assert!(queue.is_woken()); assert_pending!(poll!(queue)); sleep_until(now + Duration::from_millis(119)).await; assert!(!queue.is_woken()); sleep(ms(1)).await; assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "foo"); } #[tokio::test] async fn insert_in_past_after_poll_fires_immediately() { time::pause(); let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("foo", now + ms(200)); assert_pending!(poll!(queue)); sleep(ms(80)).await; assert!(!queue.is_woken()); queue.insert_at("bar", now + ms(40)); assert!(queue.is_woken()); let entry = assert_ready_some!(poll!(queue)).into_inner(); assert_eq!(entry, "bar"); } #[tokio::test] async fn delay_queue_poll_expired_when_empty() { let mut delay_queue = task::spawn(DelayQueue::new()); let key = delay_queue.insert(0, std::time::Duration::from_secs(10)); assert_pending!(poll!(delay_queue)); delay_queue.remove(&key); assert!(assert_ready!(poll!(delay_queue)).is_none()); } #[tokio::test(start_paused = true)] async fn compact_expire_empty() { let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("foo1", now + ms(10)); queue.insert_at("foo2", now + ms(10)); sleep(ms(10)).await; let mut res = vec![]; while res.len() < 2 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); } queue.compact(); assert_eq!(queue.len(), 0); assert_eq!(queue.capacity(), 0); } #[tokio::test(start_paused = true)] async fn compact_remove_empty() { let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key1 = queue.insert_at("foo1", now + ms(10)); let key2 = queue.insert_at("foo2", now + ms(10)); queue.remove(&key1); queue.remove(&key2); queue.compact(); assert_eq!(queue.len(), 0); assert_eq!(queue.capacity(), 0); } #[tokio::test(start_paused = true)] // Trigger a re-mapping of keys in the slab due to a `compact` call and // test removal of re-mapped 
keys async fn compact_remove_remapped_keys() { let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); queue.insert_at("foo1", now + ms(10)); queue.insert_at("foo2", now + ms(10)); // should be assigned indices 3 and 4 let key3 = queue.insert_at("foo3", now + ms(20)); let key4 = queue.insert_at("foo4", now + ms(20)); sleep(ms(10)).await; let mut res = vec![]; while res.len() < 2 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); } // items corresponding to `foo3` and `foo4` will be assigned // new indices here queue.compact(); queue.insert_at("foo5", now + ms(10)); // test removal of re-mapped keys let expired3 = queue.remove(&key3); let expired4 = queue.remove(&key4); assert_eq!(expired3.into_inner(), "foo3"); assert_eq!(expired4.into_inner(), "foo4"); queue.compact(); assert_eq!(queue.len(), 1); assert_eq!(queue.capacity(), 1); } #[tokio::test(start_paused = true)] async fn compact_change_deadline() { let mut queue = task::spawn(DelayQueue::new()); let mut now = Instant::now(); queue.insert_at("foo1", now + ms(10)); queue.insert_at("foo2", now + ms(10)); // should be assigned indices 3 and 4 queue.insert_at("foo3", now + ms(20)); let key4 = queue.insert_at("foo4", now + ms(20)); sleep(ms(10)).await; let mut res = vec![]; while res.len() < 2 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); } // items corresponding to `foo3` and `foo4` should be assigned // new indices queue.compact(); now = Instant::now(); queue.insert_at("foo5", now + ms(10)); let key6 = queue.insert_at("foo6", now + ms(10)); queue.reset_at(&key4, now + ms(20)); queue.reset_at(&key6, now + ms(20)); // foo3 and foo5 will expire sleep(ms(10)).await; while res.len() < 4 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); } sleep(ms(10)).await; while res.len() < 6 { let entry = assert_ready_some!(poll!(queue)); res.push(entry.into_inner()); } let entry = assert_ready!(poll!(queue)); 
assert!(entry.is_none()); } #[tokio::test(start_paused = true)] async fn item_expiry_greater_than_wheel() { // This function tests that a delay queue that has existed for at least 2^36 milliseconds won't panic when a new item is inserted. let mut queue = DelayQueue::new(); for _ in 0..2 { tokio::time::advance(Duration::from_millis(1 << 35)).await; queue.insert(0, Duration::from_millis(0)); queue.next().await; } // This should not panic let no_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { queue.insert(1, Duration::from_millis(1)); })); assert!(no_panic.is_ok()); } #[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")] #[tokio::test(start_paused = true)] async fn remove_after_compact() { let now = Instant::now(); let mut queue = DelayQueue::new(); let foo_key = queue.insert_at("foo", now + ms(10)); queue.insert_at("bar", now + ms(20)); queue.remove(&foo_key); queue.compact(); let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { queue.remove(&foo_key); })); assert!(panic.is_err()); } #[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")] #[tokio::test(start_paused = true)] async fn remove_after_compact_poll() { let now = Instant::now(); let mut queue = task::spawn(DelayQueue::new()); let foo_key = queue.insert_at("foo", now + ms(10)); queue.insert_at("bar", now + ms(20)); sleep(ms(10)).await; assert_eq!(assert_ready_some!(poll!(queue)).key(), foo_key); queue.compact(); let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { queue.remove(&foo_key); })); assert!(panic.is_err()); } #[tokio::test(start_paused = true)] async fn peek() { let mut queue = task::spawn(DelayQueue::new()); let now = Instant::now(); let key = queue.insert_at("foo", now + ms(5)); let key2 = queue.insert_at("bar", now); let key3 = queue.insert_at("baz", now + ms(10)); assert_eq!(queue.peek(), Some(key2)); sleep(ms(6)).await; assert_eq!(queue.peek(), Some(key2)); let entry = 
assert_ready_some!(poll!(queue)); assert_eq!(entry.get_ref(), &"bar"); assert_eq!(queue.peek(), Some(key)); let entry = assert_ready_some!(poll!(queue)); assert_eq!(entry.get_ref(), &"foo"); assert_eq!(queue.peek(), Some(key3)); assert_pending!(poll!(queue)); sleep(ms(5)).await; assert_eq!(queue.peek(), Some(key3)); let entry = assert_ready_some!(poll!(queue)); assert_eq!(entry.get_ref(), &"baz"); assert!(queue.peek().is_none()); } fn ms(n: u64) -> Duration { Duration::from_millis(n) } tokio-util-0.7.10/tests/udp.rs000064400000000000000000000076721046102023000143250ustar 00000000000000#![warn(rust_2018_idioms)] #![cfg(not(target_os = "wasi"))] // Wasi doesn't support UDP use tokio::net::UdpSocket; use tokio_stream::StreamExt; use tokio_util::codec::{Decoder, Encoder, LinesCodec}; use tokio_util::udp::UdpFramed; use bytes::{BufMut, BytesMut}; use futures::future::try_join; use futures::future::FutureExt; use futures::sink::SinkExt; use std::io; use std::sync::Arc; #[cfg_attr( any(target_os = "macos", target_os = "ios", target_os = "tvos"), allow(unused_assignments) )] #[tokio::test] async fn send_framed_byte_codec() -> std::io::Result<()> { let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; let a_addr = a_soc.local_addr()?; let b_addr = b_soc.local_addr()?; // test sending & receiving bytes { let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, ByteCodec); let msg = b"4567"; let send = a.send((msg, b_addr)); let recv = b.next().map(|e| e.unwrap()); let (_, received) = try_join(send, recv).await.unwrap(); let (data, addr) = received; assert_eq!(msg, &*data); assert_eq!(a_addr, addr); a_soc = a.into_inner(); b_soc = b.into_inner(); } #[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))] // test sending & receiving an empty message { let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, ByteCodec); let msg = b""; let send = a.send((msg, 
b_addr)); let recv = b.next().map(|e| e.unwrap()); let (_, received) = try_join(send, recv).await.unwrap(); let (data, addr) = received; assert_eq!(msg, &*data); assert_eq!(a_addr, addr); } Ok(()) } pub struct ByteCodec; impl Decoder for ByteCodec { type Item = Vec<u8>; type Error = io::Error; fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Vec<u8>>, io::Error> { let len = buf.len(); Ok(Some(buf.split_to(len).to_vec())) } } impl Encoder<&[u8]> for ByteCodec { type Error = io::Error; fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { buf.reserve(data.len()); buf.put_slice(data); Ok(()) } } #[tokio::test] async fn send_framed_lines_codec() -> std::io::Result<()> { let a_soc = UdpSocket::bind("127.0.0.1:0").await?; let b_soc = UdpSocket::bind("127.0.0.1:0").await?; let a_addr = a_soc.local_addr()?; let b_addr = b_soc.local_addr()?; let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, LinesCodec::new()); let msg = b"1\r\n2\r\n3\r\n".to_vec(); a.send((&msg, b_addr)).await?; assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); Ok(()) } #[tokio::test] async fn framed_half() -> std::io::Result<()> { let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?); let b_soc = a_soc.clone(); let a_addr = a_soc.local_addr()?; let b_addr = b_soc.local_addr()?; let mut a = UdpFramed::new(a_soc, ByteCodec); let mut b = UdpFramed::new(b_soc, LinesCodec::new()); let msg = b"1\r\n2\r\n3\r\n".to_vec(); a.send((&msg, b_addr)).await?; let msg = b"4\r\n5\r\n6\r\n".to_vec(); a.send((&msg, b_addr)).await?; assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(),
("4".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr)); assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr)); Ok(()) }