tower-http-0.4.4/.cargo_vcs_info.json

{
  "git": {
    "sha1": "466f0e0a55a0981e8a3b263a51b6f1e33162a28b"
  },
  "path_in_vcs": "tower-http"
}

tower-http-0.4.4/CHANGELOG.md

# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

# Unreleased

## Added

- None.

## Changed

- None.

## Removed

- None.

## Fixed

- None.

# 0.4.4 (September 1, 2023)

## Added

- **trace:** Default implementations for trace bodies.

# 0.4.3 (July 20, 2023)

## Fixed

- **compression:** Fix accidental breaking change in 0.4.2.

# 0.4.2 (July 19, 2023)

## Added

- **cors:** Add support for private network preflights ([#373])
- **compression:** Implement `Default` for `DecompressionBody` ([#370])

## Changed

- **compression:** Update to async-compression 0.4 ([#371])

## Fixed

- **compression:** Override default brotli compression level 11 -> 4 ([#356])
- **trace:** Simplify dynamic tracing level application ([#380])
- **normalize_path:** Fix path normalization for preceding slashes ([#359])

[#356]: https://github.com/tower-rs/tower-http/pull/356
[#359]: https://github.com/tower-rs/tower-http/pull/359
[#370]: https://github.com/tower-rs/tower-http/pull/370
[#371]: https://github.com/tower-rs/tower-http/pull/371
[#373]: https://github.com/tower-rs/tower-http/pull/373
[#380]: https://github.com/tower-rs/tower-http/pull/380

# 0.4.1 (June 20, 2023)

## Added

- **request_id:** Derive `Default` for `MakeRequestUuid` ([#335])
- **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336])
- **compression:** Expose compression quality on the CompressionLayer ([#333])

## Fixed

- **compression:** Improve parsing of `Accept-Encoding` request header ([#220])
- **normalize_path:** Fix path normalization of index route ([#347])
- **decompression:** Enable `multiple_members` for `GzipDecoder` ([#354])

[#347]: https://github.com/tower-rs/tower-http/pull/347
[#333]: https://github.com/tower-rs/tower-http/pull/333
[#220]: https://github.com/tower-rs/tower-http/pull/220
[#335]: https://github.com/tower-rs/tower-http/pull/335
[#336]: https://github.com/tower-rs/tower-http/pull/336
[#354]: https://github.com/tower-rs/tower-http/pull/354

# 0.4.0 (February 24, 2023)

## Added

- **decompression:** Add `RequestDecompression` middleware ([#282])
- **compression:** Implement `Default` for `CompressionBody` ([#323])
- **compression, decompression:** Support zstd (de)compression ([#322])

## Changed

- **serve_dir:** `ServeDir` and `ServeFile`'s error types are now `Infallible` and any IO errors
  will be converted into responses.
  Use `try_call` to generate error responses manually (BREAKING) ([#283])
- **serve_dir:** `ServeDir::fallback` and `ServeDir::not_found_service` now require the fallback
  service to use `Infallible` as its error type (BREAKING) ([#283])
- **compression, decompression:** Tweak preferred compression encodings ([#325])

## Removed

- Removed `RequireAuthorization` in favor of `ValidateRequest` (BREAKING) ([#290])

## Fixed

- **serve_dir:** Don't include identity in Content-Encoding header ([#317])
- **compression:** Do compress SVGs ([#321])
- **serve_dir:** In `ServeDir`, convert `io::ErrorKind::NotADirectory` to `404 Not Found` ([#331])

[#282]: https://github.com/tower-rs/tower-http/pull/282
[#283]: https://github.com/tower-rs/tower-http/pull/283
[#290]: https://github.com/tower-rs/tower-http/pull/290
[#317]: https://github.com/tower-rs/tower-http/pull/317
[#321]: https://github.com/tower-rs/tower-http/pull/321
[#322]: https://github.com/tower-rs/tower-http/pull/322
[#323]: https://github.com/tower-rs/tower-http/pull/323
[#325]: https://github.com/tower-rs/tower-http/pull/325
[#331]: https://github.com/tower-rs/tower-http/pull/331

# 0.3.5 (December 02, 2022)

## Added

- Add `NormalizePath` middleware ([#275])
- Add `ValidateRequest` middleware ([#289])
- Add `RequestBodyTimeout` middleware ([#303])

## Changed

- Bump Minimum Supported Rust Version to 1.60 ([#299])

## Fixed

- **trace:** Correctly identify gRPC requests in default `on_response` callback ([#278])
- **cors:** Panic if a wildcard (`*`) is passed to `AllowOrigin::list`. Use `AllowOrigin::any()`
  instead ([#285])
- **serve_dir:** Call the fallback on non-UTF-8 request paths ([#310])

[#275]: https://github.com/tower-rs/tower-http/pull/275
[#278]: https://github.com/tower-rs/tower-http/pull/278
[#285]: https://github.com/tower-rs/tower-http/pull/285
[#289]: https://github.com/tower-rs/tower-http/pull/289
[#299]: https://github.com/tower-rs/tower-http/pull/299
[#303]: https://github.com/tower-rs/tower-http/pull/303
[#310]: https://github.com/tower-rs/tower-http/pull/310

# 0.3.4 (June 06, 2022)

## Added

- Add `Timeout` middleware ([#270])
- Add `RequestBodyLimit` middleware ([#271])

[#270]: https://github.com/tower-rs/tower-http/pull/270
[#271]: https://github.com/tower-rs/tower-http/pull/271

# 0.3.3 (May 08, 2022)

## Added

- **serve_dir:** Add `ServeDir::call_fallback_on_method_not_allowed` to allow calling the fallback
  for requests that aren't `GET` or `HEAD` ([#264])
- **request_id:** Add `MakeRequestUuid` for generating request ids using UUIDs ([#266])

[#264]: https://github.com/tower-rs/tower-http/pull/264
[#266]: https://github.com/tower-rs/tower-http/pull/266

## Fixed

- **serve_dir:** Include `Allow` header for `405 Method Not Allowed` responses ([#263])

[#263]: https://github.com/tower-rs/tower-http/pull/263

# 0.3.2 (April 29, 2022)

## Fixed

- **serve_dir**: Fix empty request parts being passed to `ServeDir`'s fallback instead of the
  actual ones ([#258])

[#258]: https://github.com/tower-rs/tower-http/pull/258

# 0.3.1 (April 28, 2022)

## Fixed

- **cors**: Only send a single origin in the `Access-Control-Allow-Origin` header when a list of
  allowed origins is configured (the previous behavior of sending a comma-separated list, as is
  done for allowed methods and allowed headers, is not allowed by any standard)

# 0.3.0 (April 25, 2022)

## Added

- **fs**: Add `ServeDir::{fallback, not_found_service}` for calling another service if the file
  cannot be found ([#243])
- **fs**: Add `SetStatus` to override status codes ([#248])
- `ServeDir` and `ServeFile`
  now respond with `405 Method Not Allowed` to requests where the method isn't `GET` or `HEAD`
  ([#249])
- **cors**: Added `CorsLayer::very_permissive` which is like `CorsLayer::permissive` except it
  (truly) allows credentials. This is made possible by mirroring the request's origin as well as
  method and headers back as CORS-whitelisted ones ([#237])
- **cors**: Allow customizing the value(s) for the `Vary` header ([#237])

## Changed

- **cors**: Removed `allow-credentials: true` from `CorsLayer::permissive`. It never actually took
  effect in compliant browsers because it is mutually exclusive with the `*` wildcard (`Any`) on
  origins, methods and headers ([#237])
- **cors**: Rewrote the CORS middleware. Almost all existing usage patterns will continue to work.
  (BREAKING) ([#237])
- **cors**: The CORS middleware will now panic if you try to use `Any` in combination with
  `.allow_credentials(true)`. This configuration worked before, but resulted in browsers ignoring
  the `allow-credentials` header, which defeats the purpose of setting it and can be very annoying
  to debug ([#237])

## Fixed

- **fs**: Fix content-length calculation on range requests ([#228])

[#228]: https://github.com/tower-rs/tower-http/pull/228
[#237]: https://github.com/tower-rs/tower-http/pull/237
[#243]: https://github.com/tower-rs/tower-http/pull/243
[#248]: https://github.com/tower-rs/tower-http/pull/248
[#249]: https://github.com/tower-rs/tower-http/pull/249

# 0.2.4 (March 5, 2022)

## Added

- Added `CatchPanic` middleware which catches panics and converts them into
  `500 Internal Server Error` responses ([#214])

## Fixed

- Make parsing of `Accept-Encoding` more robust ([#220])

[#214]: https://github.com/tower-rs/tower-http/pull/214
[#220]: https://github.com/tower-rs/tower-http/pull/220

# 0.2.3 (February 18, 2022)

## Changed

- Update to tokio-util 0.7 ([#221])

## Fixed

- The CORS layer / service methods `allow_headers`, `allow_methods`, `allow_origin` and
  `expose_headers` now do nothing if given an empty `Vec`, instead of sending the respective
  header with an empty value ([#218])

[#218]: https://github.com/tower-rs/tower-http/pull/218
[#221]: https://github.com/tower-rs/tower-http/pull/221

# 0.2.2 (February 8, 2022)

## Fixed

- Add `Vary` headers for CORS preflight responses ([#216])

[#216]: https://github.com/tower-rs/tower-http/pull/216

# 0.2.1 (January 21, 2022)

## Added

- Support `Last-Modified` (and friends) headers in `ServeDir` and `ServeFile` ([#145])
- Add `AsyncRequireAuthorization::layer` ([#195])

## Fixed

- Fix build error for certain feature sets ([#209])
- `Cors`: Set `Vary` header ([#199])
- `ServeDir` and `ServeFile`: Fix potential directory traversal attack due to improper path
  validation on Windows ([#204])

[#145]: https://github.com/tower-rs/tower-http/pull/145
[#195]: https://github.com/tower-rs/tower-http/pull/195
[#199]: https://github.com/tower-rs/tower-http/pull/199
[#204]: https://github.com/tower-rs/tower-http/pull/204
[#209]: https://github.com/tower-rs/tower-http/pull/209

# 0.2.0 (December 1, 2021)

## Added

- **builder**: Add `ServiceBuilderExt` which adds methods to `tower::ServiceBuilder` for adding
  middleware from tower-http ([#106])
- **request_id**: Add `SetRequestId` and `PropagateRequestId` middleware ([#150])
- **trace**: Add `DefaultMakeSpan::level` to make log level of tracing spans easily configurable
  ([#124])
- **trace**: Add `LatencyUnit::Seconds` for formatting latencies as seconds ([#179])
- **trace**: Support customizing which status codes are considered failures by
  `GrpcErrorsAsFailures`
  ([#189])
- **compression**: Support specifying predicates to choose when responses should be compressed.
  This can be used to disable compression of small responses, responses with a certain
  `content-type`, or something user-defined ([#172])
- **fs**: Ability to serve precompressed files ([#156])
- **fs**: Support `Range` requests ([#173])
- **fs**: Properly support HEAD requests which return no body and have the `Content-Length`
  header set ([#169])

## Changed

- `AddAuthorization`, `InFlightRequests`, `SetRequestHeader`, `SetResponseHeader`, `AddExtension`,
  `MapRequestBody` and `MapResponseBody` now require the underlying service to use `http::Request`
  and `http::Response` as request and response types ([#182]) (BREAKING)
- **set_header**: Remove unnecessary generic parameter from `SetRequestHeaderLayer` and
  `SetResponseHeaderLayer`. This removes the need (and possibility) to specify a body type for
  these layers ([#148]) (BREAKING)
- **compression, decompression**: Change the response body error type to
  `Box<dyn std::error::Error + Send + Sync>`. This makes them usable if the body they're wrapping
  uses `Box<dyn std::error::Error + Send + Sync>` as its error type, which they previously weren't
  ([#166]) (BREAKING)
- **fs**: Change response body type of `ServeDir` and `ServeFile` to `ServeFileSystemResponseBody`
  and `ServeFileSystemResponseFuture` ([#187]) (BREAKING)
- **auth**: Change `AuthorizeRequest` and `AsyncAuthorizeRequest` traits to be simpler ([#192])
  (BREAKING)

## Removed

- **compression, decompression**: Remove `BodyOrIoError`. It's been replaced with
  `Box<dyn std::error::Error + Send + Sync>` ([#166]) (BREAKING)
- **compression, decompression**: Remove the `compression` and `decompression` features. They were
  unnecessary and `compression-full`/`decompression-full` can be used to get full
  compression/decompression support. For more granular control, `[compression|decompression]-gzip`,
  `[compression|decompression]-br` and `[compression|decompression]-deflate` may be used instead
  ([#170]) (BREAKING)

[#106]: https://github.com/tower-rs/tower-http/pull/106
[#124]: https://github.com/tower-rs/tower-http/pull/124
[#148]: https://github.com/tower-rs/tower-http/pull/148
[#150]: https://github.com/tower-rs/tower-http/pull/150
[#156]: https://github.com/tower-rs/tower-http/pull/156
[#166]: https://github.com/tower-rs/tower-http/pull/166
[#169]: https://github.com/tower-rs/tower-http/pull/169
[#170]: https://github.com/tower-rs/tower-http/pull/170
[#172]: https://github.com/tower-rs/tower-http/pull/172
[#173]: https://github.com/tower-rs/tower-http/pull/173
[#179]: https://github.com/tower-rs/tower-http/pull/179
[#182]: https://github.com/tower-rs/tower-http/pull/182
[#187]: https://github.com/tower-rs/tower-http/pull/187
[#189]: https://github.com/tower-rs/tower-http/pull/189
[#192]: https://github.com/tower-rs/tower-http/pull/192

# 0.1.2 (November 13, 2021)

- New middleware: Add `Cors` for setting [CORS] headers ([#112])
- New middleware: Add `AsyncRequireAuthorization` ([#118])
- `Compression`: Don't recompress HTTP responses ([#140])
- `Compression` and `Decompression`: Pass configuration from layer into middleware ([#132])
- `ServeDir` and `ServeFile`: Improve performance ([#137])
- `Compression`: Remove needless bounds on `ResBody::Error` ([#117])
- `ServeDir`: Percent decode path segments ([#129])
- `ServeDir`: Use correct redirection status ([#130])
- `ServeDir`: Return `404 Not Found` on requests to directories if
  `append_index_html_on_directories` is set to `false` ([#122])

[#112]: https://github.com/tower-rs/tower-http/pull/112
[#118]: https://github.com/tower-rs/tower-http/pull/118
[#140]: https://github.com/tower-rs/tower-http/pull/140
[#132]: https://github.com/tower-rs/tower-http/pull/132
[#137]: https://github.com/tower-rs/tower-http/pull/137
[#117]: https://github.com/tower-rs/tower-http/pull/117
[#129]: https://github.com/tower-rs/tower-http/pull/129
[#130]: https://github.com/tower-rs/tower-http/pull/130
[#122]: https://github.com/tower-rs/tower-http/pull/122

# 0.1.1 (July 2, 2021)

- Add example of using `SharedClassifier`.
- Add `StatusInRangeAsFailures` which is a response classifier that considers responses with a
  status code in a certain range as failures. Useful for HTTP clients where both server errors
  (5xx) and client errors (4xx) are considered failures.
- Implement `Debug` for `NeverClassifyEos`.
- Update iri-string to 0.4.
- Add `ClassifyResponse::map_failure_class` and `ClassifyEos::map_failure_class` for transforming
  the failure classification using a function.
- Clarify exactly when each `Trace` callback is called.
- Add `AddAuthorizationLayer` for setting the `Authorization` header on requests.

# 0.1.0 (May 27, 2021)

- Initial release.

[CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS

tower-http-0.4.4/Cargo.toml

# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2018" rust-version = "1.60" name = "tower-http" version = "0.4.4" authors = ["Tower Maintainers "] description = "Tower middleware and utilities for HTTP clients and servers" homepage = "https://github.com/tower-rs/tower-http" readme = "README.md" keywords = [ "io", "async", "futures", "service", "http", ] categories = [ "asynchronous", "network-programming", "web-programming", ] license = "MIT" repository = "https://github.com/tower-rs/tower-http" [package.metadata.cargo-public-api-crates] allowed = [ "bytes", "http", "http_body", "mime", "tokio", "tower", "tower_layer", "tower_service", "tracing", "tracing_core", ] [package.metadata.docs.rs] all-features = true rustdoc-args = [ "--cfg", "docsrs", ] [package.metadata.playground] features = ["full"] [dependencies.async-compression] version = "0.4" features = ["tokio"] optional = true [dependencies.base64] version = "0.21" optional = true [dependencies.bitflags] version = "2.0.2" [dependencies.bytes] version = "1" [dependencies.futures-core] version = "0.3" [dependencies.futures-util] version = "0.3.14" features = [] default_features = false [dependencies.http] version = "0.2.7" [dependencies.http-body] version = "0.4.5" [dependencies.http-range-header] version = "0.3.0" [dependencies.httpdate] version = "1.0" optional = true [dependencies.iri-string] version = "0.7.0" optional = true [dependencies.mime] version = "0.3.17" optional = true default_features = false [dependencies.mime_guess] version = "2" optional = true default_features = false [dependencies.percent-encoding] version = "2.1.0" optional = true [dependencies.pin-project-lite] version = "0.2.7" [dependencies.tokio] version = "1.6" optional = true default_features = false [dependencies.tokio-util] version = "0.7" features = ["io"] optional = true default_features = false [dependencies.tower] version =
"0.4.1" optional = true [dependencies.tower-layer] version = "0.3" [dependencies.tower-service] version = "0.3" [dependencies.tracing] version = "0.1" optional = true default_features = false [dependencies.uuid] version = "1.0" features = ["v4"] optional = true [dev-dependencies.brotli] version = "3" [dev-dependencies.bytes] version = "1" [dev-dependencies.flate2] version = "1.0" [dev-dependencies.futures] version = "0.3" [dev-dependencies.hyper] version = "0.14" features = ["full"] [dev-dependencies.once_cell] version = "1" [dev-dependencies.serde_json] version = "1.0" [dev-dependencies.tokio] version = "1" features = ["full"] [dev-dependencies.tower] version = "0.4.10" features = [ "buffer", "util", "retry", "make", "timeout", ] [dev-dependencies.tracing-subscriber] version = "0.3" [dev-dependencies.uuid] version = "1.0" features = ["v4"] [dev-dependencies.zstd] version = "0.12" [features] add-extension = [] auth = [ "base64", "validate-request", ] catch-panic = [ "tracing", "futures-util/std", ] compression-br = [ "async-compression/brotli", "tokio-util", "tokio", ] compression-deflate = [ "async-compression/zlib", "tokio-util", "tokio", ] compression-full = [ "compression-br", "compression-deflate", "compression-gzip", "compression-zstd", ] compression-gzip = [ "async-compression/gzip", "tokio-util", "tokio", ] compression-zstd = [ "async-compression/zstd", "tokio-util", "tokio", ] cors = [] decompression-br = [ "async-compression/brotli", "tokio-util", "tokio", ] decompression-deflate = [ "async-compression/zlib", "tokio-util", "tokio", ] decompression-full = [ "decompression-br", "decompression-deflate", "decompression-gzip", "decompression-zstd", ] decompression-gzip = [ "async-compression/gzip", "tokio-util", "tokio", ] decompression-zstd = [ "async-compression/zstd", "tokio-util", "tokio", ] default = [] follow-redirect = [ "iri-string", "tower/util", ] fs = [ "tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc", "tracing", ] full = [ "add-extension", "auth", "catch-panic", "compression-full", "cors", "decompression-full", "follow-redirect", "fs", "limit", "map-request-body", "map-response-body", "metrics", "normalize-path", "propagate-header", "redirect", "request-id", "sensitive-headers", "set-header", "set-status", "timeout", "trace", "util", "validate-request", ] limit = [] map-request-body = [] map-response-body = [] metrics = ["tokio/time"] normalize-path = [] propagate-header = [] redirect = [] request-id = ["uuid"] sensitive-headers = [] set-header = [] set-status = [] timeout = ["tokio/time"] trace = ["tracing"] util = ["tower"] validate-request = ["mime"] tower-http-0.4.4/Cargo.toml.orig000064400000000000000000000101111046102023000146270ustar 00000000000000[package] name = "tower-http" description = "Tower middleware and utilities for HTTP clients and servers" version = "0.4.4" authors = ["Tower Maintainers "] edition = "2018" license = "MIT" readme = "../README.md" repository = "https://github.com/tower-rs/tower-http" homepage = "https://github.com/tower-rs/tower-http" categories = ["asynchronous", "network-programming", "web-programming"] keywords = ["io", "async", "futures", "service", "http"] rust-version = "1.60" [dependencies] bitflags = "2.0.2" bytes = "1" futures-core = "0.3" futures-util = { version = "0.3.14", default_features = false, features = [] } http = "0.2.7" http-body = "0.4.5" pin-project-lite = "0.2.7" tower-layer = "0.3" tower-service = "0.3" # optional dependencies 
async-compression = { version = "0.4", optional = true, features = ["tokio"] } base64 = { version = "0.21", optional = true } http-range-header = "0.3.0" iri-string = { version = "0.7.0", optional = true } mime = { version = "0.3.17", optional = true, default_features = false } mime_guess = { version = "2", optional = true, default_features = false } percent-encoding = { version = "2.1.0", optional = true } tokio = { version = "1.6", optional = true, default_features = false } tokio-util = { version = "0.7", optional = true, default_features = false, features = ["io"] } tower = { version = "0.4.1", optional = true } tracing = { version = "0.1", default_features = false, optional = true } httpdate = { version = "1.0", optional = true } uuid = { version = "1.0", features = ["v4"], optional = true } [dev-dependencies] bytes = "1" flate2 = "1.0" brotli = "3" futures = "0.3" hyper = { version = "0.14", features = ["full"] } once_cell = "1" tokio = { version = "1", features = ["full"] } tower = { version = "0.4.10", features = ["buffer", "util", "retry", "make", "timeout"] } tracing-subscriber = "0.3" uuid = { version = "1.0", features = ["v4"] } serde_json = "1.0" zstd = "0.12" [features] default = [] full = [ "add-extension", "auth", "catch-panic", "compression-full", "cors", "decompression-full", "follow-redirect", "fs", "limit", "map-request-body", "map-response-body", "metrics", "normalize-path", "propagate-header", "redirect", "request-id", "sensitive-headers", "set-header", "set-status", "timeout", "trace", "util", "validate-request", ] add-extension = [] auth = ["base64", "validate-request"] catch-panic = ["tracing", "futures-util/std"] cors = [] follow-redirect = ["iri-string", "tower/util"] fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc", "tracing"] limit = [] map-request-body = [] map-response-body = [] metrics = ["tokio/time"] normalize-path = [] propagate-header = [] redirect = [] request-id = ["uuid"] sensitive-headers = [] set-header = [] set-status = [] timeout = ["tokio/time"] trace = ["tracing"] util = ["tower"] validate-request = ["mime"] compression-br = ["async-compression/brotli", "tokio-util", "tokio"] compression-deflate = ["async-compression/zlib", "tokio-util", "tokio"] compression-full = ["compression-br", "compression-deflate", "compression-gzip", "compression-zstd"] compression-gzip = ["async-compression/gzip", "tokio-util", "tokio"] compression-zstd = ["async-compression/zstd", "tokio-util", "tokio"] decompression-br = ["async-compression/brotli", "tokio-util", "tokio"] decompression-deflate = ["async-compression/zlib", "tokio-util", "tokio"] decompression-full = ["decompression-br", "decompression-deflate", "decompression-gzip", "decompression-zstd"] decompression-gzip = ["async-compression/gzip", "tokio-util", "tokio"] decompression-zstd = ["async-compression/zstd", "tokio-util", "tokio"] [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [package.metadata.playground] features = ["full"] [package.metadata.cargo-public-api-crates] allowed = [ "bytes", "http", "http_body", "mime", "tokio", "tower", "tower_layer", "tower_service", "tracing", "tracing_core", ] tower-http-0.4.4/LICENSE000064400000000000000000000020531046102023000127530ustar 00000000000000Copyright (c) 2019-2021 Tower Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to 
deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. tower-http-0.4.4/README.md000064400000000000000000000063231046102023000132310ustar 00000000000000# Tower HTTP Tower middleware and utilities for HTTP clients and servers. [![Build status](https://github.com/tower-rs/tower-http/workflows/CI/badge.svg)](https://github.com/tower-rs/tower-http/actions) [![Crates.io](https://img.shields.io/crates/v/tower-http)](https://crates.io/crates/tower-http) [![Documentation](https://docs.rs/tower-http/badge.svg)](https://docs.rs/tower-http) [![Crates.io](https://img.shields.io/crates/l/tower-http)](tower-http/LICENSE) More information about this crate can be found in the [crate documentation][docs]. ## Middleware Tower HTTP contains lots of middleware that are generally useful when building HTTP servers and clients. Some of the highlights are: - `Trace` adds high level logging of requests and responses. Supports both regular HTTP requests as well as gRPC. - `Compression` and `Decompression` to compress/decompress response bodies. - `FollowRedirect` to automatically follow redirection responses. See the [docs] for the complete list of middleware. Middleware uses the [http] crate as the HTTP interface so they're compatible with any library or framework that also uses [http]. For example [hyper]. The middleware were originally extracted from one of [@EmbarkStudios] internal projects. ## Examples The [examples] folder contains various examples of how to use Tower HTTP: - [warp-key-value-store]: A key/value store with an HTTP API built with warp. - [tonic-key-value-store]: A key/value store with a gRPC API and client built with tonic. - [axum-key-value-store]: A key/value store with an HTTP API built with axum. ## Minimum supported Rust version tower-http's MSRV is 1.60. ## Getting Help If you're new to tower its [guides] might help. In the tower-http repo we also have a [number of examples][examples] showing how to put everything together. You're also welcome to ask in the [`#tower` Discord channel][chat] or open an [issue] with your question. ## Contributing :balloon: Thanks for your help improving the project! We are so happy to have you! We have a [contributing guide][guide] to help you get involved in the Tower HTTP project. [guide]: CONTRIBUTING.md ## License This project is licensed under the [MIT license](tower-http/LICENSE). ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in Tower HTTP by you, shall be licensed as MIT, without any additional terms or conditions. 
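Below is a minimal, illustrative sketch (not taken from the upstream README) of composing a couple of the middleware highlighted above with `ServiceBuilder`. It assumes the `trace` and `compression-full` features of tower-http, `tower` with its `util` feature, `hyper` 0.14 for the body type, and `tokio` as the async runtime:

```rust
use std::convert::Infallible;

use http::{Request, Response};
use hyper::Body;
use tower::{Service, ServiceBuilder, ServiceExt};
use tower_http::{compression::CompressionLayer, trace::TraceLayer};

async fn handle(_request: Request<Body>) -> Result<Response<Body>, Infallible> {
    Ok(Response::new(Body::from("Hello, World!")))
}

#[tokio::main]
async fn main() {
    // Stack tracing and response compression onto a plain handler.
    let mut service = ServiceBuilder::new()
        .layer(TraceLayer::new_for_http())
        .layer(CompressionLayer::new())
        .service_fn(handle);

    // Drive it like any other `tower::Service`.
    let response = service
        .ready()
        .await
        .unwrap()
        .call(Request::new(Body::empty()))
        .await
        .unwrap();

    assert!(response.status().is_success());
}
```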
[@EmbarkStudios]: https://github.com/EmbarkStudios [examples]: https://github.com/tower-rs/tower-http/tree/master/examples [http]: https://crates.io/crates/http [tonic-key-value-store]: https://github.com/tower-rs/tower-http/tree/master/examples/tonic-key-value-store [warp-key-value-store]: https://github.com/tower-rs/tower-http/tree/master/examples/warp-key-value-store [axum-key-value-store]: https://github.com/tower-rs/tower-http/tree/master/examples/axum-key-value-store [chat]: https://discord.gg/tokio [docs]: https://docs.rs/tower-http [hyper]: https://github.com/hyperium/hyper [issue]: https://github.com/tower-rs/tower-http/issues/new [milestone]: https://github.com/tower-rs/tower-http/milestones [examples]: https://github.com/tower-rs/tower-http/tree/master/examples [guides]: https://github.com/tower-rs/tower/tree/master/guides tower-http-0.4.4/src/add_extension.rs000064400000000000000000000106541046102023000157350ustar 00000000000000//! Middleware that clones a value into each request's [extensions]. //! //! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html //! //! # Example //! //! ``` //! use tower_http::add_extension::AddExtensionLayer; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use http::{Request, Response}; //! use hyper::Body; //! use std::{sync::Arc, convert::Infallible}; //! //! # struct DatabaseConnectionPool; //! # impl DatabaseConnectionPool { //! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool } //! # } //! # //! // Shared state across all request handlers --- in this case, a pool of database connections. //! struct State { //! pool: DatabaseConnectionPool, //! } //! //! async fn handle(req: Request) -> Result, Infallible> { //! // Grab the state from the request extensions. //! let state = req.extensions().get::>().unwrap(); //! //! Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! // Construct the shared state. //! let state = State { //! pool: DatabaseConnectionPool::new(), //! }; //! //! let mut service = ServiceBuilder::new() //! // Share an `Arc` with all requests. //! .layer(AddExtensionLayer::new(Arc::new(state))) //! .service_fn(handle); //! //! // Call the service. //! let response = service //! .ready() //! .await? //! .call(Request::new(Body::empty())) //! .await?; //! # Ok(()) //! # } //! ``` use http::{Request, Response}; use std::task::{Context, Poll}; use tower_layer::Layer; use tower_service::Service; /// [`Layer`] for adding some shareable value to [request extensions]. /// /// See the [module docs](crate::add_extension) for more details. /// /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[derive(Clone, Copy, Debug)] pub struct AddExtensionLayer { value: T, } impl AddExtensionLayer { /// Create a new [`AddExtensionLayer`]. pub fn new(value: T) -> Self { AddExtensionLayer { value } } } impl Layer for AddExtensionLayer where T: Clone, { type Service = AddExtension; fn layer(&self, inner: S) -> Self::Service { AddExtension { inner, value: self.value.clone(), } } } /// Middleware for adding some shareable value to [request extensions]. /// /// See the [module docs](crate::add_extension) for more details. /// /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[derive(Clone, Copy, Debug)] pub struct AddExtension { inner: S, value: T, } impl AddExtension { /// Create a new [`AddExtension`]. 
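    ///
    /// A small usage sketch (not part of the original docs; `MyState` and `inner_service`
    /// are placeholders for your own types):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// // Wrap an existing service so every request carries `Arc<MyState>` in its extensions.
    /// let service = AddExtension::new(inner_service, Arc::new(MyState::default()));
    /// ```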
pub fn new(inner: S, value: T) -> Self { Self { inner, value } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `AddExtension` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(value: T) -> AddExtensionLayer { AddExtensionLayer::new(value) } } impl Service> for AddExtension where S: Service, Response = Response>, T: Clone + Send + Sync + 'static, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { req.extensions_mut().insert(self.value.clone()); self.inner.call(req) } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use http::Response; use hyper::Body; use std::{convert::Infallible, sync::Arc}; use tower::{service_fn, ServiceBuilder, ServiceExt}; struct State(i32); #[tokio::test] async fn basic() { let state = Arc::new(State(1)); let svc = ServiceBuilder::new() .layer(AddExtensionLayer::new(state)) .service(service_fn(|req: Request| async move { let state = req.extensions().get::>().unwrap(); Ok::<_, Infallible>(Response::new(state.0)) })); let res = svc .oneshot(Request::new(Body::empty())) .await .unwrap() .into_body(); assert_eq!(1, res); } } tower-http-0.4.4/src/auth/add_authorization.rs000064400000000000000000000211371046102023000175600ustar 00000000000000//! Add authorization to requests using the [`Authorization`] header. //! //! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization //! //! # Example //! //! ``` //! use tower_http::validate_request::{ValidateRequestHeader, ValidateRequestHeaderLayer}; //! use tower_http::auth::AddAuthorizationLayer; //! use hyper::{Request, Response, Body, Error}; //! use http::{StatusCode, header::AUTHORIZATION}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! # async fn handle(request: Request) -> Result, Error> { //! # Ok(Response::new(Body::empty())) //! # } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let service_that_requires_auth = ValidateRequestHeader::basic( //! # tower::service_fn(handle), //! # "username", //! # "password", //! # ); //! let mut client = ServiceBuilder::new() //! // Use basic auth with the given username and password //! .layer(AddAuthorizationLayer::basic("username", "password")) //! .service(service_that_requires_auth); //! //! // Make a request, we don't have to add the `Authorization` header manually //! let response = client //! .ready() //! .await? //! .call(Request::new(Body::empty())) //! .await?; //! //! assert_eq!(StatusCode::OK, response.status()); //! # Ok(()) //! # } //! ``` use base64::Engine as _; use http::{HeaderValue, Request, Response}; use std::{ convert::TryFrom, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; /// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the /// [`Authorization`] header. /// /// See the [module docs](crate::auth::add_authorization) for an example. /// /// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this /// middleware. 
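///
/// A brief sketch (not part of the original docs) of attaching this layer to a client stack;
/// `client_service` is a placeholder for any service taking `http::Request`:
///
/// ```ignore
/// use tower::ServiceBuilder;
///
/// // Every outgoing request gets an `Authorization: Bearer <token>` header added for it.
/// let client = ServiceBuilder::new()
///     .layer(AddAuthorizationLayer::bearer("my-token"))
///     .service(client_service);
/// ```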
/// /// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization /// [`SetRequestHeader`]: crate::set_header::SetRequestHeader #[derive(Debug, Clone)] pub struct AddAuthorizationLayer { value: HeaderValue, } impl AddAuthorizationLayer { /// Authorize requests using a username and password pair. /// /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is /// `base64_encode("{username}:{password}")`. /// /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS /// with this method. However use of HTTPS/TLS is not enforced by this middleware. pub fn basic(username: &str, password: &str) -> Self { let encoded = BASE64.encode(format!("{}:{}", username, password)); let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap(); Self { value } } /// Authorize requests using a "bearer token". Commonly used for OAuth 2. /// /// The `Authorization` header will be set to `Bearer {token}`. /// /// # Panics /// /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(token: &str) -> Self { let value = HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header"); Self { value } } /// Mark the header as [sensitive]. /// /// This can for example be used to hide the header value from logs. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[allow(clippy::wrong_self_convention)] pub fn as_sensitive(mut self, sensitive: bool) -> Self { self.value.set_sensitive(sensitive); self } } impl Layer for AddAuthorizationLayer { type Service = AddAuthorization; fn layer(&self, inner: S) -> Self::Service { AddAuthorization { inner, value: self.value.clone(), } } } /// Middleware that adds authorization all requests using the [`Authorization`] header. /// /// See the [module docs](crate::auth::add_authorization) for an example. /// /// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this /// middleware. /// /// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization /// [`SetRequestHeader`]: crate::set_header::SetRequestHeader #[derive(Debug, Clone)] pub struct AddAuthorization { inner: S, value: HeaderValue, } impl AddAuthorization { /// Authorize requests using a username and password pair. /// /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is /// `base64_encode("{username}:{password}")`. /// /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS /// with this method. However use of HTTPS/TLS is not enforced by this middleware. pub fn basic(inner: S, username: &str, password: &str) -> Self { AddAuthorizationLayer::basic(username, password).layer(inner) } /// Authorize requests using a "bearer token". Commonly used for OAuth 2. /// /// The `Authorization` header will be set to `Bearer {token}`. /// /// # Panics /// /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(inner: S, token: &str) -> Self { AddAuthorizationLayer::bearer(token).layer(inner) } define_inner_service_accessors!(); /// Mark the header as [sensitive]. /// /// This can for example be used to hide the header value from logs. 
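    ///
    /// A small sketch (not part of the original docs), mirroring the crate's own
    /// `making_header_sensitive` test below:
    ///
    /// ```ignore
    /// // Add bearer auth and keep the header value out of logs.
    /// let client = AddAuthorization::bearer(inner_service, "my-token").as_sensitive(true);
    /// ```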
/// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[allow(clippy::wrong_self_convention)] pub fn as_sensitive(mut self, sensitive: bool) -> Self { self.value.set_sensitive(sensitive); self } } impl Service> for AddAuthorization where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { req.headers_mut() .insert(http::header::AUTHORIZATION, self.value.clone()); self.inner.call(req) } } #[cfg(test)] mod tests { use crate::validate_request::ValidateRequestHeaderLayer; #[allow(unused_imports)] use super::*; use http::{Response, StatusCode}; use hyper::Body; use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; #[tokio::test] async fn basic() { // service that requires auth for all requests let svc = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) .service_fn(echo); // make a client that adds auth let mut client = AddAuthorization::basic(svc, "foo", "bar"); let res = client .ready() .await .unwrap() .call(Request::new(Body::empty())) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn token() { // service that requires auth for all requests let svc = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::bearer("foo")) .service_fn(echo); // make a client that adds auth let mut client = AddAuthorization::bearer(svc, "foo"); let res = client .ready() .await .unwrap() .call(Request::new(Body::empty())) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn making_header_sensitive() { let svc = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::bearer("foo")) .service_fn(|request: Request| async move { let auth = request.headers().get(http::header::AUTHORIZATION).unwrap(); assert!(auth.is_sensitive()); Ok::<_, hyper::Error>(Response::new(Body::empty())) }); let mut client = AddAuthorization::bearer(svc, "foo").as_sensitive(true); let res = client .ready() .await .unwrap() .call(Request::new(Body::empty())) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } } tower-http-0.4.4/src/auth/async_require_authorization.rs000064400000000000000000000271121046102023000217000ustar 00000000000000//! Authorize requests using the [`Authorization`] header asynchronously. //! //! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization //! //! # Example //! //! ``` //! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; //! use hyper::{Request, Response, Body, Error}; //! use http::{StatusCode, header::AUTHORIZATION}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use futures_util::future::BoxFuture; //! //! #[derive(Clone, Copy)] //! struct MyAuth; //! //! impl AsyncAuthorizeRequest for MyAuth //! where //! B: Send + Sync + 'static, //! { //! type RequestBody = B; //! type ResponseBody = Body; //! type Future = BoxFuture<'static, Result, Response>>; //! //! fn authorize(&mut self, mut request: Request) -> Self::Future { //! Box::pin(async { //! if let Some(user_id) = check_auth(&request).await { //! // Set `user_id` as a request extension so it can be accessed by other //! // services down the stack. //! request.extensions_mut().insert(user_id); //! //! Ok(request) //! 
} else { //! let unauthorized_response = Response::builder() //! .status(StatusCode::UNAUTHORIZED) //! .body(Body::empty()) //! .unwrap(); //! //! Err(unauthorized_response) //! } //! }) //! } //! } //! //! async fn check_auth(request: &Request) -> Option { //! // ... //! # None //! } //! //! #[derive(Debug)] //! struct UserId(String); //! //! async fn handle(request: Request) -> Result, Error> { //! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the //! // request was authorized and `UserId` will be present. //! let user_id = request //! .extensions() //! .get::() //! .expect("UserId will be there if request was authorized"); //! //! println!("request from {:?}", user_id); //! //! Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let service = ServiceBuilder::new() //! // Authorize requests using `MyAuth` //! .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` //! //! Or using a closure: //! //! ``` //! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; //! use hyper::{Request, Response, Body, Error}; //! use http::StatusCode; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use futures_util::future::BoxFuture; //! //! async fn check_auth(request: &Request) -> Option { //! // ... //! # None //! } //! //! #[derive(Debug)] //! struct UserId(String); //! //! async fn handle(request: Request) -> Result, Error> { //! # todo!(); //! // ... //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let service = ServiceBuilder::new() //! .layer(AsyncRequireAuthorizationLayer::new(|request: Request| async move { //! if let Some(user_id) = check_auth(&request).await { //! Ok(request) //! } else { //! let unauthorized_response = Response::builder() //! .status(StatusCode::UNAUTHORIZED) //! .body(Body::empty()) //! .unwrap(); //! //! Err(unauthorized_response) //! } //! })) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` use futures_core::ready; use http::{Request, Response}; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the /// [`Authorization`] header. /// /// See the [module docs](crate::auth::async_require_authorization) for an example. /// /// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization #[derive(Debug, Clone)] pub struct AsyncRequireAuthorizationLayer { auth: T, } impl AsyncRequireAuthorizationLayer { /// Authorize requests using a custom scheme. pub fn new(auth: T) -> AsyncRequireAuthorizationLayer { Self { auth } } } impl Layer for AsyncRequireAuthorizationLayer where T: Clone, { type Service = AsyncRequireAuthorization; fn layer(&self, inner: S) -> Self::Service { AsyncRequireAuthorization::new(inner, self.auth.clone()) } } /// Middleware that authorizes all requests using the [`Authorization`] header. /// /// See the [module docs](crate::auth::async_require_authorization) for an example. /// /// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization #[derive(Clone, Debug)] pub struct AsyncRequireAuthorization { inner: S, auth: T, } impl AsyncRequireAuthorization { define_inner_service_accessors!(); } impl AsyncRequireAuthorization { /// Authorize requests using a custom scheme. 
/// /// The `Authorization` header is required to have the value provided. pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization { Self { inner, auth } } /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`] /// middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer { AsyncRequireAuthorizationLayer::new(auth) } } impl Service> for AsyncRequireAuthorization where Auth: AsyncAuthorizeRequest, S: Service, Response = Response> + Clone, { type Response = Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let inner = self.inner.clone(); let authorize = self.auth.authorize(req); ResponseFuture { state: State::Authorize { authorize }, service: inner, } } } pin_project! { /// Response future for [`AsyncRequireAuthorization`]. pub struct ResponseFuture where Auth: AsyncAuthorizeRequest, S: Service>, { #[pin] state: State, service: S, } } pin_project! { #[project = StateProj] enum State { Authorize { #[pin] authorize: A, }, Authorized { #[pin] fut: SFut, }, } } impl Future for ResponseFuture where Auth: AsyncAuthorizeRequest, S: Service, Response = Response>, { type Output = Result, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); loop { match this.state.as_mut().project() { StateProj::Authorize { authorize } => { let auth = ready!(authorize.poll(cx)); match auth { Ok(req) => { let fut = this.service.call(req); this.state.set(State::Authorized { fut }) } Err(res) => { return Poll::Ready(Ok(res)); } }; } StateProj::Authorized { fut } => { return fut.poll(cx); } } } } } /// Trait for authorizing requests. pub trait AsyncAuthorizeRequest { /// The type of request body returned by `authorize`. /// /// Set this to `B` unless you need to change the request body type. type RequestBody; /// The body type used for responses to unauthorized requests. type ResponseBody; /// The Future type returned by `authorize` type Future: Future, Response>>; /// Authorize the request. /// /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not. 
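    ///
    /// A minimal sketch of that contract (a hypothetical closure-based authorizer, as allowed
    /// by the blanket `FnMut` implementation below):
    ///
    /// ```ignore
    /// let auth = |request: Request<Body>| async move {
    ///     if request.headers().contains_key("x-api-key") {
    ///         // Let the request through to the wrapped service.
    ///         Ok(request)
    ///     } else {
    ///         // Short-circuit with this response instead.
    ///         Err(Response::builder().status(401).body(Body::empty()).unwrap())
    ///     }
    /// };
    /// ```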
fn authorize(&mut self, request: Request) -> Self::Future; } impl AsyncAuthorizeRequest for F where F: FnMut(Request) -> Fut, Fut: Future, Response>>, { type RequestBody = ReqBody; type ResponseBody = ResBody; type Future = Fut; fn authorize(&mut self, request: Request) -> Self::Future { self(request) } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use futures_util::future::BoxFuture; use http::{header, StatusCode}; use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; #[derive(Clone, Copy)] struct MyAuth; impl AsyncAuthorizeRequest for MyAuth where B: Send + 'static, { type RequestBody = B; type ResponseBody = Body; type Future = BoxFuture<'static, Result, Response>>; fn authorize(&mut self, mut request: Request) -> Self::Future { Box::pin(async move { let authorized = request .headers() .get(header::AUTHORIZATION) .and_then(|it| it.to_str().ok()) .and_then(|it| it.strip_prefix("Bearer ")) .map(|it| it == "69420") .unwrap_or(false); if authorized { let user_id = UserId("6969".to_owned()); request.extensions_mut().insert(user_id); Ok(request) } else { Err(Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Body::empty()) .unwrap()) } }) } } #[derive(Debug)] struct UserId(String); #[tokio::test] async fn require_async_auth_works() { let mut service = ServiceBuilder::new() .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) .service_fn(echo); let request = Request::get("/") .header(header::AUTHORIZATION, "Bearer 69420") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn require_async_auth_401() { let mut service = ServiceBuilder::new() .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) .service_fn(echo); let request = Request::get("/") .header(header::AUTHORIZATION, "Bearer deez") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } } tower-http-0.4.4/src/auth/mod.rs000064400000000000000000000005571046102023000146320ustar 00000000000000//! Authorization related middleware. pub mod add_authorization; pub mod async_require_authorization; pub mod require_authorization; #[doc(inline)] pub use self::{ add_authorization::{AddAuthorization, AddAuthorizationLayer}, async_require_authorization::{ AsyncAuthorizeRequest, AsyncRequireAuthorization, AsyncRequireAuthorizationLayer, }, }; tower-http-0.4.4/src/auth/require_authorization.rs000064400000000000000000000301161046102023000205010ustar 00000000000000//! Authorize requests using [`ValidateRequest`]. //! //! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization //! //! # Example //! //! ``` //! use tower_http::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer}; //! use hyper::{Request, Response, Body, Error}; //! use http::{StatusCode, header::AUTHORIZATION}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! //! async fn handle(request: Request) -> Result, Error> { //! Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut service = ServiceBuilder::new() //! // Require the `Authorization` header to be `Bearer passwordlol` //! .layer(ValidateRequestHeaderLayer::bearer("passwordlol")) //! .service_fn(handle); //! //! 
// Requests with the correct token are allowed through //! let request = Request::builder() //! .header(AUTHORIZATION, "Bearer passwordlol") //! .body(Body::empty()) //! .unwrap(); //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! //! assert_eq!(StatusCode::OK, response.status()); //! //! // Requests with an invalid token get a `401 Unauthorized` response //! let request = Request::builder() //! .body(Body::empty()) //! .unwrap(); //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! //! assert_eq!(StatusCode::UNAUTHORIZED, response.status()); //! # Ok(()) //! # } //! ``` //! //! Custom validation can be made by implementing [`ValidateRequest`]. use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer}; use base64::Engine as _; use http::{ header::{self, HeaderValue}, Request, Response, StatusCode, }; use http_body::Body; use std::{fmt, marker::PhantomData}; const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; impl ValidateRequestHeader> { /// Authorize requests using a username and password pair. /// /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is /// `base64_encode("{username}:{password}")`. /// /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS /// with this method. However use of HTTPS/TLS is not enforced by this middleware. pub fn basic(inner: S, username: &str, value: &str) -> Self where ResBody: Body + Default, { Self::custom(inner, Basic::new(username, value)) } } impl ValidateRequestHeaderLayer> { /// Authorize requests using a username and password pair. /// /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is /// `base64_encode("{username}:{password}")`. /// /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS /// with this method. However use of HTTPS/TLS is not enforced by this middleware. pub fn basic(username: &str, password: &str) -> Self where ResBody: Body + Default, { Self::custom(Basic::new(username, password)) } } impl ValidateRequestHeader> { /// Authorize requests using a "bearer token". Commonly used for OAuth 2. /// /// The `Authorization` header is required to be `Bearer {token}`. /// /// # Panics /// /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(inner: S, token: &str) -> Self where ResBody: Body + Default, { Self::custom(inner, Bearer::new(token)) } } impl ValidateRequestHeaderLayer> { /// Authorize requests using a "bearer token". Commonly used for OAuth 2. /// /// The `Authorization` header is required to be `Bearer {token}`. /// /// # Panics /// /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(token: &str) -> Self where ResBody: Body + Default, { Self::custom(Bearer::new(token)) } } /// Type that performs "bearer token" authorization. /// /// See [`ValidateRequestHeader::bearer`] for more details. 
pub struct Bearer { header_value: HeaderValue, _ty: PhantomData ResBody>, } impl Bearer { fn new(token: &str) -> Self where ResBody: Body + Default, { Self { header_value: format!("Bearer {}", token) .parse() .expect("token is not a valid header value"), _ty: PhantomData, } } } impl Clone for Bearer { fn clone(&self) -> Self { Self { header_value: self.header_value.clone(), _ty: PhantomData, } } } impl fmt::Debug for Bearer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Bearer") .field("header_value", &self.header_value) .finish() } } impl ValidateRequest for Bearer where ResBody: Body + Default, { type ResponseBody = ResBody; fn validate(&mut self, request: &mut Request) -> Result<(), Response> { match request.headers().get(header::AUTHORIZATION) { Some(actual) if actual == self.header_value => Ok(()), _ => { let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::UNAUTHORIZED; Err(res) } } } } /// Type that performs basic authorization. /// /// See [`ValidateRequestHeader::basic`] for more details. pub struct Basic { header_value: HeaderValue, _ty: PhantomData ResBody>, } impl Basic { fn new(username: &str, password: &str) -> Self where ResBody: Body + Default, { let encoded = BASE64.encode(format!("{}:{}", username, password)); let header_value = format!("Basic {}", encoded).parse().unwrap(); Self { header_value, _ty: PhantomData, } } } impl Clone for Basic { fn clone(&self) -> Self { Self { header_value: self.header_value.clone(), _ty: PhantomData, } } } impl fmt::Debug for Basic { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Basic") .field("header_value", &self.header_value) .finish() } } impl ValidateRequest for Basic where ResBody: Body + Default, { type ResponseBody = ResBody; fn validate(&mut self, request: &mut Request) -> Result<(), Response> { match request.headers().get(header::AUTHORIZATION) { Some(actual) if actual == self.header_value => Ok(()), _ => { let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::UNAUTHORIZED; res.headers_mut() .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap()); Err(res) } } } } #[cfg(test)] mod tests { use crate::validate_request::ValidateRequestHeaderLayer; #[allow(unused_imports)] use super::*; use http::header; use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; use tower_service::Service; #[tokio::test] async fn valid_basic_token() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) .service_fn(echo); let request = Request::get("/") .header( header::AUTHORIZATION, format!("Basic {}", BASE64.encode("foo:bar")), ) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn invalid_basic_token() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) .service_fn(echo); let request = Request::get("/") .header( header::AUTHORIZATION, format!("Basic {}", BASE64.encode("wrong:credentials")), ) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap(); assert_eq!(www_authenticate, "Basic"); } #[tokio::test] async fn valid_bearer_token() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::bearer("foobar")) .service_fn(echo); let 
request = Request::get("/") .header(header::AUTHORIZATION, "Bearer foobar") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn basic_auth_is_case_sensitive_in_prefix() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) .service_fn(echo); let request = Request::get("/") .header( header::AUTHORIZATION, format!("basic {}", BASE64.encode("foo:bar")), ) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn basic_auth_is_case_sensitive_in_value() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::basic("foo", "bar")) .service_fn(echo); let request = Request::get("/") .header( header::AUTHORIZATION, format!("Basic {}", BASE64.encode("Foo:bar")), ) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn invalid_bearer_token() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::bearer("foobar")) .service_fn(echo); let request = Request::get("/") .header(header::AUTHORIZATION, "Bearer wat") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn bearer_token_is_case_sensitive_in_prefix() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::bearer("foobar")) .service_fn(echo); let request = Request::get("/") .header(header::AUTHORIZATION, "bearer foobar") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn bearer_token_is_case_sensitive_in_token() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::bearer("foobar")) .service_fn(echo); let request = Request::get("/") .header(header::AUTHORIZATION, "Bearer Foobar") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } } tower-http-0.4.4/src/builder.rs000064400000000000000000000475701046102023000145460ustar 00000000000000use tower::ServiceBuilder; #[cfg(feature = "trace")] use crate::classify::{GrpcErrorsAsFailures, ServerErrorsAsFailures, SharedClassifier}; #[allow(unused_imports)] use http::header::HeaderName; #[allow(unused_imports)] use tower_layer::Stack; /// Extension trait that adds methods to [`tower::ServiceBuilder`] for adding middleware from /// tower-http. 
/// /// [`Service`]: tower::Service /// /// # Example /// /// ```rust /// use http::{Request, Response, header::HeaderName}; /// use hyper::Body; /// use std::{time::Duration, convert::Infallible}; /// use tower::{ServiceBuilder, ServiceExt, Service}; /// use tower_http::ServiceBuilderExt; /// /// async fn handle(request: Request) -> Result, Infallible> { /// Ok(Response::new(Body::empty())) /// } /// /// # #[tokio::main] /// # async fn main() { /// let service = ServiceBuilder::new() /// // Methods from tower /// .timeout(Duration::from_secs(30)) /// // Methods from tower-http /// .trace_for_http() /// .compression() /// .propagate_header(HeaderName::from_static("x-request-id")) /// .service_fn(handle); /// # let mut service = service; /// # service.ready().await.unwrap().call(Request::new(Body::empty())).await.unwrap(); /// # } /// ``` #[cfg(feature = "util")] // ^ work around rustdoc not inferring doc(cfg)s for cfg's from surrounding scopes pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { /// Propagate a header from the request to the response. /// /// See [`tower_http::propagate_header`] for more details. /// /// [`tower_http::propagate_header`]: crate::propagate_header #[cfg(feature = "propagate-header")] fn propagate_header( self, header: HeaderName, ) -> ServiceBuilder>; /// Add some shareable value to [request extensions]. /// /// See [`tower_http::add_extension`] for more details. /// /// [`tower_http::add_extension`]: crate::add_extension /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[cfg(feature = "add-extension")] fn add_extension( self, value: T, ) -> ServiceBuilder, L>>; /// Apply a transformation to the request body. /// /// See [`tower_http::map_request_body`] for more details. /// /// [`tower_http::map_request_body`]: crate::map_request_body #[cfg(feature = "map-request-body")] fn map_request_body( self, f: F, ) -> ServiceBuilder, L>>; /// Apply a transformation to the response body. /// /// See [`tower_http::map_response_body`] for more details. /// /// [`tower_http::map_response_body`]: crate::map_response_body #[cfg(feature = "map-response-body")] fn map_response_body( self, f: F, ) -> ServiceBuilder, L>>; /// Compresses response bodies. /// /// See [`tower_http::compression`] for more details. /// /// [`tower_http::compression`]: crate::compression #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd", ))] fn compression(self) -> ServiceBuilder>; /// Decompress response bodies. /// /// See [`tower_http::decompression`] for more details. /// /// [`tower_http::decompression`]: crate::decompression #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd", ))] fn decompression(self) -> ServiceBuilder>; /// High level tracing that classifies responses using HTTP status codes. /// /// This method does not support customizing the output, to do that use [`TraceLayer`] /// instead. /// /// See [`tower_http::trace`] for more details. /// /// [`tower_http::trace`]: crate::trace /// [`TraceLayer`]: crate::trace::TraceLayer #[cfg(feature = "trace")] fn trace_for_http( self, ) -> ServiceBuilder>, L>>; /// High level tracing that classifies responses using gRPC headers. /// /// This method does not support customizing the output, to do that use [`TraceLayer`] /// instead. /// /// See [`tower_http::trace`] for more details. 
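///
/// # Example
///
/// A minimal sketch of attaching gRPC-flavored tracing; `handle` is a hypothetical
/// stand-in for a real gRPC service, not part of tower-http:
///
/// ```rust
/// use http::{Request, Response};
/// use hyper::Body;
/// use std::convert::Infallible;
/// use tower::{Service, ServiceBuilder, ServiceExt};
/// use tower_http::ServiceBuilderExt;
///
/// async fn handle(request: Request<Body>) -> Result<Response<Body>, Infallible> {
///     Ok(Response::new(Body::empty()))
/// }
///
/// # #[tokio::main]
/// # async fn main() {
/// let mut service = ServiceBuilder::new()
///     // Failures are determined from the `grpc-status` trailer rather than the HTTP status.
///     .trace_for_grpc()
///     .service_fn(handle);
/// # service.ready().await.unwrap().call(Request::new(Body::empty())).await.unwrap();
/// # }
/// ```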
/// /// [`tower_http::trace`]: crate::trace /// [`TraceLayer`]: crate::trace::TraceLayer #[cfg(feature = "trace")] fn trace_for_grpc( self, ) -> ServiceBuilder>, L>>; /// Follow redirect resposes using the [`Standard`] policy. /// /// See [`tower_http::follow_redirect`] for more details. /// /// [`tower_http::follow_redirect`]: crate::follow_redirect /// [`Standard`]: crate::follow_redirect::policy::Standard #[cfg(feature = "follow-redirect")] fn follow_redirects( self, ) -> ServiceBuilder< Stack< crate::follow_redirect::FollowRedirectLayer, L, >, >; /// Mark headers as [sensitive] on both requests and responses. /// /// See [`tower_http::sensitive_headers`] for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive /// [`tower_http::sensitive_headers`]: crate::sensitive_headers #[cfg(feature = "sensitive-headers")] fn sensitive_headers( self, headers: I, ) -> ServiceBuilder> where I: IntoIterator; /// Mark headers as [sensitive] on both requests. /// /// See [`tower_http::sensitive_headers`] for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive /// [`tower_http::sensitive_headers`]: crate::sensitive_headers #[cfg(feature = "sensitive-headers")] fn sensitive_request_headers( self, headers: std::sync::Arc<[HeaderName]>, ) -> ServiceBuilder>; /// Mark headers as [sensitive] on both responses. /// /// See [`tower_http::sensitive_headers`] for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive /// [`tower_http::sensitive_headers`]: crate::sensitive_headers #[cfg(feature = "sensitive-headers")] fn sensitive_response_headers( self, headers: std::sync::Arc<[HeaderName]>, ) -> ServiceBuilder>; /// Insert a header into the request. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. /// /// See [`tower_http::set_header`] for more details. /// /// [`tower_http::set_header`]: crate::set_header #[cfg(feature = "set-header")] fn override_request_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>>; /// Append a header into the request. /// /// If previous values exist, the header will have multiple values. /// /// See [`tower_http::set_header`] for more details. /// /// [`tower_http::set_header`]: crate::set_header #[cfg(feature = "set-header")] fn append_request_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>>; /// Insert a header into the request, if the header is not already present. /// /// See [`tower_http::set_header`] for more details. /// /// [`tower_http::set_header`]: crate::set_header #[cfg(feature = "set-header")] fn insert_request_header_if_not_present( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>>; /// Insert a header into the response. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. /// /// See [`tower_http::set_header`] for more details. /// /// [`tower_http::set_header`]: crate::set_header #[cfg(feature = "set-header")] fn override_response_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>>; /// Append a header into the response. /// /// If previous values exist, the header will have multiple values. /// /// See [`tower_http::set_header`] for more details. 
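///
/// # Example
///
/// A small sketch, assuming the inner service is a plain `service_fn` handler
/// (`handle` is a hypothetical placeholder):
///
/// ```rust
/// use http::{header, HeaderValue, Request, Response};
/// use hyper::Body;
/// use std::convert::Infallible;
/// use tower::ServiceBuilder;
/// use tower_http::ServiceBuilderExt;
///
/// async fn handle(request: Request<Body>) -> Result<Response<Body>, Infallible> {
///     Ok(Response::new(Body::empty()))
/// }
///
/// let service = ServiceBuilder::new()
///     // Adds another `Set-Cookie` header instead of replacing any set by `handle`.
///     .append_response_header(
///         header::SET_COOKIE,
///         HeaderValue::from_static("flavor=chocolate-chip"),
///     )
///     .service_fn(handle);
/// ```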
/// /// [`tower_http::set_header`]: crate::set_header #[cfg(feature = "set-header")] fn append_response_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>>; /// Insert a header into the response, if the header is not already present. /// /// See [`tower_http::set_header`] for more details. /// /// [`tower_http::set_header`]: crate::set_header #[cfg(feature = "set-header")] fn insert_response_header_if_not_present( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>>; /// Add request id header and extension. /// /// See [`tower_http::request_id`] for more details. /// /// [`tower_http::request_id`]: crate::request_id #[cfg(feature = "request-id")] fn set_request_id( self, header_name: HeaderName, make_request_id: M, ) -> ServiceBuilder, L>> where M: crate::request_id::MakeRequestId; /// Add request id header and extension, using `x-request-id` as the header name. /// /// See [`tower_http::request_id`] for more details. /// /// [`tower_http::request_id`]: crate::request_id #[cfg(feature = "request-id")] fn set_x_request_id( self, make_request_id: M, ) -> ServiceBuilder, L>> where M: crate::request_id::MakeRequestId, { self.set_request_id( HeaderName::from_static(crate::request_id::X_REQUEST_ID), make_request_id, ) } /// Propgate request ids from requests to responses. /// /// See [`tower_http::request_id`] for more details. /// /// [`tower_http::request_id`]: crate::request_id #[cfg(feature = "request-id")] fn propagate_request_id( self, header_name: HeaderName, ) -> ServiceBuilder>; /// Propgate request ids from requests to responses, using `x-request-id` as the header name. /// /// See [`tower_http::request_id`] for more details. /// /// [`tower_http::request_id`]: crate::request_id #[cfg(feature = "request-id")] fn propagate_x_request_id( self, ) -> ServiceBuilder> { self.propagate_request_id(HeaderName::from_static(crate::request_id::X_REQUEST_ID)) } /// Catch panics and convert them into `500 Internal Server` responses. /// /// See [`tower_http::catch_panic`] for more details. /// /// [`tower_http::catch_panic`]: crate::catch_panic #[cfg(feature = "catch-panic")] fn catch_panic( self, ) -> ServiceBuilder< Stack, L>, >; /// Intercept requests with over-sized payloads and convert them into /// `413 Payload Too Large` responses. /// /// See [`tower_http::limit`] for more details. /// /// [`tower_http::limit`]: crate::limit #[cfg(feature = "limit")] fn request_body_limit( self, limit: usize, ) -> ServiceBuilder>; /// Remove trailing slashes from paths. /// /// See [`tower_http::normalize_path`] for more details. 
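///
/// # Example
///
/// A minimal sketch (`handle` is a hypothetical placeholder service):
///
/// ```rust
/// use http::{Request, Response};
/// use hyper::Body;
/// use std::convert::Infallible;
/// use tower::ServiceBuilder;
/// use tower_http::ServiceBuilderExt;
///
/// async fn handle(request: Request<Body>) -> Result<Response<Body>, Infallible> {
///     // `handle` sees the path `/foo` even if the client requested `/foo/`.
///     Ok(Response::new(Body::empty()))
/// }
///
/// let service = ServiceBuilder::new()
///     .trim_trailing_slash()
///     .service_fn(handle);
/// ```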
/// /// [`tower_http::normalize_path`]: crate::normalize_path #[cfg(feature = "normalize-path")] fn trim_trailing_slash( self, ) -> ServiceBuilder>; } impl crate::sealed::Sealed for ServiceBuilder {} impl ServiceBuilderExt for ServiceBuilder { #[cfg(feature = "propagate-header")] fn propagate_header( self, header: HeaderName, ) -> ServiceBuilder> { self.layer(crate::propagate_header::PropagateHeaderLayer::new(header)) } #[cfg(feature = "add-extension")] fn add_extension( self, value: T, ) -> ServiceBuilder, L>> { self.layer(crate::add_extension::AddExtensionLayer::new(value)) } #[cfg(feature = "map-request-body")] fn map_request_body( self, f: F, ) -> ServiceBuilder, L>> { self.layer(crate::map_request_body::MapRequestBodyLayer::new(f)) } #[cfg(feature = "map-response-body")] fn map_response_body( self, f: F, ) -> ServiceBuilder, L>> { self.layer(crate::map_response_body::MapResponseBodyLayer::new(f)) } #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd", ))] fn compression(self) -> ServiceBuilder> { self.layer(crate::compression::CompressionLayer::new()) } #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd", ))] fn decompression(self) -> ServiceBuilder> { self.layer(crate::decompression::DecompressionLayer::new()) } #[cfg(feature = "trace")] fn trace_for_http( self, ) -> ServiceBuilder>, L>> { self.layer(crate::trace::TraceLayer::new_for_http()) } #[cfg(feature = "trace")] fn trace_for_grpc( self, ) -> ServiceBuilder>, L>> { self.layer(crate::trace::TraceLayer::new_for_grpc()) } #[cfg(feature = "follow-redirect")] fn follow_redirects( self, ) -> ServiceBuilder< Stack< crate::follow_redirect::FollowRedirectLayer, L, >, > { self.layer(crate::follow_redirect::FollowRedirectLayer::new()) } #[cfg(feature = "sensitive-headers")] fn sensitive_headers( self, headers: I, ) -> ServiceBuilder> where I: IntoIterator, { self.layer(crate::sensitive_headers::SetSensitiveHeadersLayer::new( headers, )) } #[cfg(feature = "sensitive-headers")] fn sensitive_request_headers( self, headers: std::sync::Arc<[HeaderName]>, ) -> ServiceBuilder> { self.layer(crate::sensitive_headers::SetSensitiveRequestHeadersLayer::from_shared(headers)) } #[cfg(feature = "sensitive-headers")] fn sensitive_response_headers( self, headers: std::sync::Arc<[HeaderName]>, ) -> ServiceBuilder> { self.layer(crate::sensitive_headers::SetSensitiveResponseHeadersLayer::from_shared(headers)) } #[cfg(feature = "set-header")] fn override_request_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>> { self.layer(crate::set_header::SetRequestHeaderLayer::overriding( header_name, make, )) } #[cfg(feature = "set-header")] fn append_request_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>> { self.layer(crate::set_header::SetRequestHeaderLayer::appending( header_name, make, )) } #[cfg(feature = "set-header")] fn insert_request_header_if_not_present( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>> { self.layer(crate::set_header::SetRequestHeaderLayer::if_not_present( header_name, make, )) } #[cfg(feature = "set-header")] fn override_response_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>> { self.layer(crate::set_header::SetResponseHeaderLayer::overriding( header_name, make, )) } #[cfg(feature = "set-header")] fn append_response_header( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>> { 
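// Delegate to `SetResponseHeaderLayer` in appending mode: any values already present
// for `header_name` are kept and the new value is added after them.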
self.layer(crate::set_header::SetResponseHeaderLayer::appending( header_name, make, )) } #[cfg(feature = "set-header")] fn insert_response_header_if_not_present( self, header_name: HeaderName, make: M, ) -> ServiceBuilder, L>> { self.layer(crate::set_header::SetResponseHeaderLayer::if_not_present( header_name, make, )) } #[cfg(feature = "request-id")] fn set_request_id( self, header_name: HeaderName, make_request_id: M, ) -> ServiceBuilder, L>> where M: crate::request_id::MakeRequestId, { self.layer(crate::request_id::SetRequestIdLayer::new( header_name, make_request_id, )) } #[cfg(feature = "request-id")] fn propagate_request_id( self, header_name: HeaderName, ) -> ServiceBuilder> { self.layer(crate::request_id::PropagateRequestIdLayer::new(header_name)) } #[cfg(feature = "catch-panic")] fn catch_panic( self, ) -> ServiceBuilder< Stack, L>, > { self.layer(crate::catch_panic::CatchPanicLayer::new()) } #[cfg(feature = "limit")] fn request_body_limit( self, limit: usize, ) -> ServiceBuilder> { self.layer(crate::limit::RequestBodyLimitLayer::new(limit)) } #[cfg(feature = "normalize-path")] fn trim_trailing_slash( self, ) -> ServiceBuilder> { self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash()) } } tower-http-0.4.4/src/catch_panic.rs000064400000000000000000000272061046102023000153460ustar 00000000000000//! Convert panics into responses. //! //! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result` //! whenever possible. //! //! # Example //! //! ```rust //! use http::{Request, Response, header::HeaderName}; //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::catch_panic::CatchPanicLayer; //! use hyper::Body; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! async fn handle(req: Request) -> Result, Infallible> { //! panic!("something went wrong...") //! } //! //! let mut svc = ServiceBuilder::new() //! // Catch panics and convert them into responses. //! .layer(CatchPanicLayer::new()) //! .service_fn(handle); //! //! // Call the service. //! let request = Request::new(Body::empty()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), 500); //! # //! # Ok(()) //! # } //! ``` //! //! Using a custom panic handler: //! //! ```rust //! use http::{Request, StatusCode, Response, header::{self, HeaderName}}; //! use std::{any::Any, convert::Infallible}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::catch_panic::CatchPanicLayer; //! use hyper::Body; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! async fn handle(req: Request) -> Result, Infallible> { //! panic!("something went wrong...") //! } //! //! fn handle_panic(err: Box) -> Response { //! let details = if let Some(s) = err.downcast_ref::() { //! s.clone() //! } else if let Some(s) = err.downcast_ref::<&str>() { //! s.to_string() //! } else { //! "Unknown panic message".to_string() //! }; //! //! let body = serde_json::json!({ //! "error": { //! "kind": "panic", //! "details": details, //! } //! }); //! let body = serde_json::to_string(&body).unwrap(); //! //! Response::builder() //! .status(StatusCode::INTERNAL_SERVER_ERROR) //! .header(header::CONTENT_TYPE, "application/json") //! .body(Body::from(body)) //! .unwrap() //! } //! //! let svc = ServiceBuilder::new() //! // Use `handle_panic` to create the response. //! .layer(CatchPanicLayer::custom(handle_panic)) //! 
.service_fn(handle); //! # //! # Ok(()) //! # } //! ``` use bytes::Bytes; use futures_core::ready; use futures_util::future::{CatchUnwind, FutureExt}; use http::{HeaderValue, Request, Response, StatusCode}; use http_body::{combinators::UnsyncBoxBody, Body, Full}; use pin_project_lite::pin_project; use std::{ any::Any, future::Future, panic::AssertUnwindSafe, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; use crate::BoxError; /// Layer that applies the [`CatchPanic`] middleware that catches panics and converts them into /// `500 Internal Server` responses. /// /// See the [module docs](self) for an example. #[derive(Debug, Clone, Copy, Default)] pub struct CatchPanicLayer { panic_handler: T, } impl CatchPanicLayer { /// Create a new `CatchPanicLayer` with the default panic handler. pub fn new() -> Self { CatchPanicLayer { panic_handler: DefaultResponseForPanic, } } } impl CatchPanicLayer { /// Create a new `CatchPanicLayer` with a custom panic handler. pub fn custom(panic_handler: T) -> Self where T: ResponseForPanic, { Self { panic_handler } } } impl Layer for CatchPanicLayer where T: Clone, { type Service = CatchPanic; fn layer(&self, inner: S) -> Self::Service { CatchPanic { inner, panic_handler: self.panic_handler.clone(), } } } /// Middleware that catches panics and converts them into `500 Internal Server` responses. /// /// See the [module docs](self) for an example. #[derive(Debug, Clone, Copy)] pub struct CatchPanic { inner: S, panic_handler: T, } impl CatchPanic { /// Create a new `CatchPanic` with the default panic handler. pub fn new(inner: S) -> Self { Self { inner, panic_handler: DefaultResponseForPanic, } } } impl CatchPanic { define_inner_service_accessors!(); /// Create a new `CatchPanic` with a custom panic handler. pub fn custom(inner: S, panic_handler: T) -> Self where T: ResponseForPanic, { Self { inner, panic_handler, } } } impl Service> for CatchPanic where S: Service, Response = Response>, ResBody: Body + Send + 'static, ResBody::Error: Into, T: ResponseForPanic + Clone, T::ResponseBody: Body + Send + 'static, ::Error: Into, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) { Ok(future) => ResponseFuture { kind: Kind::Future { future: AssertUnwindSafe(future).catch_unwind(), panic_handler: Some(self.panic_handler.clone()), }, }, Err(panic_err) => ResponseFuture { kind: Kind::Panicked { panic_err: Some(panic_err), panic_handler: Some(self.panic_handler.clone()), }, }, } } } pin_project! { /// Response future for [`CatchPanic`]. pub struct ResponseFuture { #[pin] kind: Kind, } } pin_project! 
{ #[project = KindProj] enum Kind { Panicked { panic_err: Option>, panic_handler: Option, }, Future { #[pin] future: CatchUnwind>, panic_handler: Option, } } } impl Future for ResponseFuture where F: Future, E>>, ResBody: Body + Send + 'static, ResBody::Error: Into, T: ResponseForPanic, T::ResponseBody: Body + Send + 'static, ::Error: Into, { type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { KindProj::Panicked { panic_err, panic_handler, } => { let panic_handler = panic_handler .take() .expect("future polled after completion"); let panic_err = panic_err.take().expect("future polled after completion"); Poll::Ready(Ok(response_for_panic(panic_handler, panic_err))) } KindProj::Future { future, panic_handler, } => match ready!(future.poll(cx)) { Ok(Ok(res)) => { Poll::Ready(Ok(res.map(|body| body.map_err(Into::into).boxed_unsync()))) } Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)), Err(panic_err) => Poll::Ready(Ok(response_for_panic( panic_handler .take() .expect("future polled after completion"), panic_err, ))), }, } } } fn response_for_panic( mut panic_handler: T, err: Box, ) -> Response> where T: ResponseForPanic, T::ResponseBody: Body + Send + 'static, ::Error: Into, { panic_handler .response_for_panic(err) .map(|body| body.map_err(Into::into).boxed_unsync()) } /// Trait for creating responses from panics. pub trait ResponseForPanic: Clone { /// The body type used for responses to panics. type ResponseBody; /// Create a response from the panic error. fn response_for_panic( &mut self, err: Box, ) -> Response; } impl ResponseForPanic for F where F: FnMut(Box) -> Response + Clone, { type ResponseBody = B; fn response_for_panic( &mut self, err: Box, ) -> Response { self(err) } } /// The default `ResponseForPanic` used by `CatchPanic`. /// /// It will log the panic message and return a `500 Internal Server` error response with an empty /// body. 
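///
/// # Example
///
/// This is the handler [`CatchPanicLayer::new`] uses; passing it explicitly is a
/// sketch of the same thing spelled out:
///
/// ```rust
/// use tower_http::catch_panic::{CatchPanicLayer, DefaultResponseForPanic};
///
/// // Equivalent to `CatchPanicLayer::new()`.
/// let layer = CatchPanicLayer::custom(DefaultResponseForPanic);
/// ```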
#[derive(Debug, Default, Clone, Copy)] #[non_exhaustive] pub struct DefaultResponseForPanic; impl ResponseForPanic for DefaultResponseForPanic { type ResponseBody = Full; fn response_for_panic( &mut self, err: Box, ) -> Response { if let Some(s) = err.downcast_ref::() { tracing::error!("Service panicked: {}", s); } else if let Some(s) = err.downcast_ref::<&str>() { tracing::error!("Service panicked: {}", s); } else { tracing::error!( "Service panicked but `CatchPanic` was unable to downcast the panic info" ); }; let mut res = Response::new(Full::from("Service panicked")); *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; #[allow(clippy::declare_interior_mutable_const)] const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); res.headers_mut() .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); res } } #[cfg(test)] mod tests { #![allow(unreachable_code)] use super::*; use hyper::{Body, Response}; use std::convert::Infallible; use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] async fn panic_before_returning_future() { let svc = ServiceBuilder::new() .layer(CatchPanicLayer::new()) .service_fn(|_: Request| { panic!("service panic"); async { Ok::<_, Infallible>(Response::new(Body::empty())) } }); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); let body = hyper::body::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } #[tokio::test] async fn panic_in_future() { let svc = ServiceBuilder::new() .layer(CatchPanicLayer::new()) .service_fn(|_: Request| async { panic!("future panic"); Ok::<_, Infallible>(Response::new(Body::empty())) }); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); let body = hyper::body::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } } tower-http-0.4.4/src/classify/grpc_errors_as_failures.rs000064400000000000000000000301131046102023000216220ustar 00000000000000use super::{ClassifiedResponse, ClassifyEos, ClassifyResponse, SharedClassifier}; use bitflags::bitflags; use http::{HeaderMap, Response}; use std::{fmt, num::NonZeroI32}; /// gRPC status codes. Used in [`GrpcErrorsAsFailures`]. /// /// These variants match the [gRPC status codes]. /// /// [gRPC status codes]: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc #[derive(Clone, Copy, Debug)] pub enum GrpcCode { /// The operation completed successfully. Ok, /// The operation was cancelled. Cancelled, /// Unknown error. Unknown, /// Client specified an invalid argument. InvalidArgument, /// Deadline expired before operation could complete. DeadlineExceeded, /// Some requested entity was not found. NotFound, /// Some entity that we attempted to create already exists. AlreadyExists, /// The caller does not have permission to execute the specified operation. PermissionDenied, /// Some resource has been exhausted. ResourceExhausted, /// The system is not in a state required for the operation's execution. FailedPrecondition, /// The operation was aborted. Aborted, /// Operation was attempted past the valid range. OutOfRange, /// Operation is not implemented or not supported. Unimplemented, /// Internal error. Internal, /// The service is currently unavailable. Unavailable, /// Unrecoverable data loss or corruption. 
DataLoss, /// The request does not have valid authentication credentials Unauthenticated, } impl GrpcCode { pub(crate) fn into_bitmask(self) -> GrpcCodeBitmask { match self { Self::Ok => GrpcCodeBitmask::OK, Self::Cancelled => GrpcCodeBitmask::CANCELLED, Self::Unknown => GrpcCodeBitmask::UNKNOWN, Self::InvalidArgument => GrpcCodeBitmask::INVALID_ARGUMENT, Self::DeadlineExceeded => GrpcCodeBitmask::DEADLINE_EXCEEDED, Self::NotFound => GrpcCodeBitmask::NOT_FOUND, Self::AlreadyExists => GrpcCodeBitmask::ALREADY_EXISTS, Self::PermissionDenied => GrpcCodeBitmask::PERMISSION_DENIED, Self::ResourceExhausted => GrpcCodeBitmask::RESOURCE_EXHAUSTED, Self::FailedPrecondition => GrpcCodeBitmask::FAILED_PRECONDITION, Self::Aborted => GrpcCodeBitmask::ABORTED, Self::OutOfRange => GrpcCodeBitmask::OUT_OF_RANGE, Self::Unimplemented => GrpcCodeBitmask::UNIMPLEMENTED, Self::Internal => GrpcCodeBitmask::INTERNAL, Self::Unavailable => GrpcCodeBitmask::UNAVAILABLE, Self::DataLoss => GrpcCodeBitmask::DATA_LOSS, Self::Unauthenticated => GrpcCodeBitmask::UNAUTHENTICATED, } } } bitflags! { #[derive(Debug, Clone, Copy)] pub(crate) struct GrpcCodeBitmask: u32 { const OK = 0b00000000000000001; const CANCELLED = 0b00000000000000010; const UNKNOWN = 0b00000000000000100; const INVALID_ARGUMENT = 0b00000000000001000; const DEADLINE_EXCEEDED = 0b00000000000010000; const NOT_FOUND = 0b00000000000100000; const ALREADY_EXISTS = 0b00000000001000000; const PERMISSION_DENIED = 0b00000000010000000; const RESOURCE_EXHAUSTED = 0b00000000100000000; const FAILED_PRECONDITION = 0b00000001000000000; const ABORTED = 0b00000010000000000; const OUT_OF_RANGE = 0b00000100000000000; const UNIMPLEMENTED = 0b00001000000000000; const INTERNAL = 0b00010000000000000; const UNAVAILABLE = 0b00100000000000000; const DATA_LOSS = 0b01000000000000000; const UNAUTHENTICATED = 0b10000000000000000; } } impl GrpcCodeBitmask { fn try_from_u32(code: u32) -> Option { match code { 0 => Some(Self::OK), 1 => Some(Self::CANCELLED), 2 => Some(Self::UNKNOWN), 3 => Some(Self::INVALID_ARGUMENT), 4 => Some(Self::DEADLINE_EXCEEDED), 5 => Some(Self::NOT_FOUND), 6 => Some(Self::ALREADY_EXISTS), 7 => Some(Self::PERMISSION_DENIED), 8 => Some(Self::RESOURCE_EXHAUSTED), 9 => Some(Self::FAILED_PRECONDITION), 10 => Some(Self::ABORTED), 11 => Some(Self::OUT_OF_RANGE), 12 => Some(Self::UNIMPLEMENTED), 13 => Some(Self::INTERNAL), 14 => Some(Self::UNAVAILABLE), 15 => Some(Self::DATA_LOSS), 16 => Some(Self::UNAUTHENTICATED), _ => None, } } } /// Response classifier for gRPC responses. /// /// gRPC doesn't use normal HTTP statuses for indicating success or failure but instead a special /// header that might appear in a trailer. /// /// Responses are considered successful if /// /// - `grpc-status` header value matches [`GrpcErrorsAsFailures`] (only `Ok` by /// default). /// - `grpc-status` header is missing. /// - `grpc-status` header value isn't a valid `String`. /// - `grpc-status` header value can't parsed into an `i32`. /// /// All others are considered failures. #[derive(Debug, Clone)] pub struct GrpcErrorsAsFailures { success_codes: GrpcCodeBitmask, } impl Default for GrpcErrorsAsFailures { fn default() -> Self { Self::new() } } impl GrpcErrorsAsFailures { /// Create a new [`GrpcErrorsAsFailures`]. pub fn new() -> Self { Self { success_codes: GrpcCodeBitmask::OK, } } /// Change which gRPC codes are considered success. /// /// Defaults to only considering `Ok` as success. /// /// `Ok` will always be considered a success. 
/// /// # Example /// /// Servers might not want to consider `Invalid Argument` or `Not Found` as failures since /// thats likely the clients fault: /// /// ```rust /// use tower_http::classify::{GrpcErrorsAsFailures, GrpcCode}; /// /// let classifier = GrpcErrorsAsFailures::new() /// .with_success(GrpcCode::InvalidArgument) /// .with_success(GrpcCode::NotFound); /// ``` pub fn with_success(mut self, code: GrpcCode) -> Self { self.success_codes |= code.into_bitmask(); self } /// Returns a [`MakeClassifier`](super::MakeClassifier) that produces `GrpcErrorsAsFailures`. /// /// This is a convenience function that simply calls `SharedClassifier::new`. pub fn make_classifier() -> SharedClassifier { SharedClassifier::new(Self::new()) } } impl ClassifyResponse for GrpcErrorsAsFailures { type FailureClass = GrpcFailureClass; type ClassifyEos = GrpcEosErrorsAsFailures; fn classify_response( self, res: &Response, ) -> ClassifiedResponse { match classify_grpc_metadata(res.headers(), self.success_codes) { ParsedGrpcStatus::Success | ParsedGrpcStatus::HeaderNotString | ParsedGrpcStatus::HeaderNotInt => ClassifiedResponse::Ready(Ok(())), ParsedGrpcStatus::NonSuccess(status) => { ClassifiedResponse::Ready(Err(GrpcFailureClass::Code(status))) } ParsedGrpcStatus::GrpcStatusHeaderMissing => { ClassifiedResponse::RequiresEos(GrpcEosErrorsAsFailures { success_codes: self.success_codes, }) } } } fn classify_error(self, error: &E) -> Self::FailureClass where E: fmt::Display + 'static, { GrpcFailureClass::Error(error.to_string()) } } /// The [`ClassifyEos`] for [`GrpcErrorsAsFailures`]. #[derive(Debug, Clone)] pub struct GrpcEosErrorsAsFailures { success_codes: GrpcCodeBitmask, } impl ClassifyEos for GrpcEosErrorsAsFailures { type FailureClass = GrpcFailureClass; fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> { if let Some(trailers) = trailers { match classify_grpc_metadata(trailers, self.success_codes) { ParsedGrpcStatus::Success | ParsedGrpcStatus::GrpcStatusHeaderMissing | ParsedGrpcStatus::HeaderNotString | ParsedGrpcStatus::HeaderNotInt => Ok(()), ParsedGrpcStatus::NonSuccess(status) => Err(GrpcFailureClass::Code(status)), } } else { Ok(()) } } fn classify_error(self, error: &E) -> Self::FailureClass where E: fmt::Display + 'static, { GrpcFailureClass::Error(error.to_string()) } } impl Default for GrpcEosErrorsAsFailures { fn default() -> Self { Self { success_codes: GrpcCodeBitmask::OK, } } } /// The failure class for [`GrpcErrorsAsFailures`]. #[derive(Debug)] pub enum GrpcFailureClass { /// A gRPC response was classified as a failure with the corresponding status. Code(std::num::NonZeroI32), /// A gRPC response was classified as an error with the corresponding error description. Error(String), } impl fmt::Display for GrpcFailureClass { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Code(code) => write!(f, "Code: {}", code), Self::Error(error) => write!(f, "Error: {}", error), } } } #[allow(clippy::if_let_some_result)] pub(crate) fn classify_grpc_metadata( headers: &HeaderMap, success_codes: GrpcCodeBitmask, ) -> ParsedGrpcStatus { macro_rules! 
or_else { ($expr:expr, $other:ident) => { if let Some(value) = $expr { value } else { return ParsedGrpcStatus::$other; } }; } let status = or_else!(headers.get("grpc-status"), GrpcStatusHeaderMissing); let status = or_else!(status.to_str().ok(), HeaderNotString); let status = or_else!(status.parse::().ok(), HeaderNotInt); if GrpcCodeBitmask::try_from_u32(status as _) .filter(|code| success_codes.contains(*code)) .is_some() { ParsedGrpcStatus::Success } else { ParsedGrpcStatus::NonSuccess(NonZeroI32::new(status).unwrap()) } } #[derive(Debug, PartialEq, Eq)] pub(crate) enum ParsedGrpcStatus { Success, NonSuccess(NonZeroI32), GrpcStatusHeaderMissing, // these two are treated as `Success` but kept separate for clarity HeaderNotString, HeaderNotInt, } #[cfg(test)] mod tests { use super::*; macro_rules! classify_grpc_metadata_test { ( name: $name:ident, status: $status:expr, success_flags: $success_flags:expr, expected: $expected:expr, ) => { #[test] fn $name() { let mut headers = HeaderMap::new(); headers.insert("grpc-status", $status.parse().unwrap()); let status = classify_grpc_metadata(&headers, $success_flags); assert_eq!(status, $expected); } }; } classify_grpc_metadata_test! { name: basic_ok, status: "0", success_flags: GrpcCodeBitmask::OK, expected: ParsedGrpcStatus::Success, } classify_grpc_metadata_test! { name: basic_error, status: "1", success_flags: GrpcCodeBitmask::OK, expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(1).unwrap()), } classify_grpc_metadata_test! { name: two_success_codes_first_matches, status: "0", success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT, expected: ParsedGrpcStatus::Success, } classify_grpc_metadata_test! { name: two_success_codes_second_matches, status: "3", success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT, expected: ParsedGrpcStatus::Success, } classify_grpc_metadata_test! { name: two_success_codes_none_matches, status: "16", success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT, expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(16).unwrap()), } } tower-http-0.4.4/src/classify/map_failure_class.rs000064400000000000000000000043741046102023000204010ustar 00000000000000use super::{ClassifiedResponse, ClassifyEos, ClassifyResponse}; use http::{HeaderMap, Response}; use std::fmt; /// Response classifier that transforms the failure class of some other /// classifier. /// /// Created with [`ClassifyResponse::map_failure_class`] or /// [`ClassifyEos::map_failure_class`]. 
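///
/// # Example
///
/// A short sketch of how this type is typically obtained:
///
/// ```rust
/// use tower_http::classify::{ClassifyResponse, ServerErrorsAsFailures};
///
/// // `map_failure_class` wraps `ServerErrorsAsFailures` in a `MapFailureClass`,
/// // turning its failure class into a plain `String`.
/// let classifier = ServerErrorsAsFailures::new()
///     .map_failure_class(|class| format!("classified as a failure: {}", class));
/// ```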
#[derive(Clone, Copy)] pub struct MapFailureClass { inner: C, f: F, } impl MapFailureClass { pub(super) fn new(classify: C, f: F) -> Self { Self { inner: classify, f } } } impl fmt::Debug for MapFailureClass where C: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapFailureClass") .field("inner", &self.inner) .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } impl ClassifyResponse for MapFailureClass where C: ClassifyResponse, F: FnOnce(C::FailureClass) -> NewClass, { type FailureClass = NewClass; type ClassifyEos = MapFailureClass; fn classify_response( self, res: &Response, ) -> ClassifiedResponse { match self.inner.classify_response(res) { ClassifiedResponse::Ready(result) => ClassifiedResponse::Ready(result.map_err(self.f)), ClassifiedResponse::RequiresEos(classify_eos) => { let mapped_classify_eos = MapFailureClass::new(classify_eos, self.f); ClassifiedResponse::RequiresEos(mapped_classify_eos) } } } fn classify_error(self, error: &E) -> Self::FailureClass where E: std::fmt::Display + 'static, { (self.f)(self.inner.classify_error(error)) } } impl ClassifyEos for MapFailureClass where C: ClassifyEos, F: FnOnce(C::FailureClass) -> NewClass, { type FailureClass = NewClass; fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> { self.inner.classify_eos(trailers).map_err(self.f) } fn classify_error(self, error: &E) -> Self::FailureClass where E: std::fmt::Display + 'static, { (self.f)(self.inner.classify_error(error)) } } tower-http-0.4.4/src/classify/mod.rs000064400000000000000000000345661046102023000155150ustar 00000000000000//! Tools for classifying responses as either success or failure. use http::{HeaderMap, Request, Response, StatusCode}; use std::{convert::Infallible, fmt, marker::PhantomData}; pub(crate) mod grpc_errors_as_failures; mod map_failure_class; mod status_in_range_is_error; pub use self::{ grpc_errors_as_failures::{ GrpcCode, GrpcEosErrorsAsFailures, GrpcErrorsAsFailures, GrpcFailureClass, }, map_failure_class::MapFailureClass, status_in_range_is_error::{StatusInRangeAsFailures, StatusInRangeFailureClass}, }; /// Trait for producing response classifiers from a request. /// /// This is useful when a classifier depends on data from the request. For example, this could /// include the URI or HTTP method. /// /// This trait is generic over the [`Error` type] of the `Service`s used with the classifier. /// This is necessary for [`ClassifyResponse::classify_error`]. /// /// [`Error` type]: https://docs.rs/tower/latest/tower/trait.Service.html#associatedtype.Error pub trait MakeClassifier { /// The response classifier produced. type Classifier: ClassifyResponse< FailureClass = Self::FailureClass, ClassifyEos = Self::ClassifyEos, >; /// The type of failure classifications. /// /// This might include additional information about the error, such as /// whether it was a client or server error, or whether or not it should /// be considered retryable. type FailureClass; /// The type used to classify the response end of stream (EOS). type ClassifyEos: ClassifyEos; /// Returns a response classifier for this request fn make_classifier(&self, req: &Request) -> Self::Classifier; } /// A [`MakeClassifier`] that produces new classifiers by cloning an inner classifier. /// /// When a type implementing [`ClassifyResponse`] doesn't depend on information /// from the request, [`SharedClassifier`] can be used to turn an instance of that type /// into a [`MakeClassifier`]. 
/// /// # Example /// /// ``` /// use std::fmt; /// use tower_http::classify::{ /// ClassifyResponse, ClassifiedResponse, NeverClassifyEos, /// SharedClassifier, MakeClassifier, /// }; /// use http::Response; /// /// // A response classifier that only considers errors to be failures. /// #[derive(Clone, Copy)] /// struct MyClassifier; /// /// impl ClassifyResponse for MyClassifier { /// type FailureClass = String; /// type ClassifyEos = NeverClassifyEos; /// /// fn classify_response( /// self, /// _res: &Response, /// ) -> ClassifiedResponse { /// ClassifiedResponse::Ready(Ok(())) /// } /// /// fn classify_error(self, error: &E) -> Self::FailureClass /// where /// E: fmt::Display + 'static, /// { /// error.to_string() /// } /// } /// /// // Some function that requires a `MakeClassifier` /// fn use_make_classifier(make: M) { /// // ... /// } /// /// // `MyClassifier` doesn't implement `MakeClassifier` but since it doesn't /// // care about the incoming request we can make `MyClassifier`s by cloning. /// // That is what `SharedClassifier` does. /// let make_classifier = SharedClassifier::new(MyClassifier); /// /// // We now have a `MakeClassifier`! /// use_make_classifier(make_classifier); /// ``` #[derive(Debug, Clone)] pub struct SharedClassifier { classifier: C, } impl SharedClassifier { /// Create a new `SharedClassifier` from the given classifier. pub fn new(classifier: C) -> Self where C: ClassifyResponse + Clone, { Self { classifier } } } impl MakeClassifier for SharedClassifier where C: ClassifyResponse + Clone, { type FailureClass = C::FailureClass; type ClassifyEos = C::ClassifyEos; type Classifier = C; fn make_classifier(&self, _req: &Request) -> Self::Classifier { self.classifier.clone() } } /// Trait for classifying responses as either success or failure. Designed to support both unary /// requests (single request for a single response) as well as streaming responses. /// /// Response classifiers are used in cases where middleware needs to determine /// whether a response completed successfully or failed. For example, they may /// be used by logging or metrics middleware to record failures differently /// from successes. /// /// Furthermore, when a response fails, a response classifier may provide /// additional information about the failure. This can, for example, be used to /// build [retry policies] by indicating whether or not a particular failure is /// retryable. /// /// [retry policies]: https://docs.rs/tower/latest/tower/retry/trait.Policy.html pub trait ClassifyResponse { /// The type returned when a response is classified as a failure. /// /// Depending on the classifier, this may simply indicate that the /// request failed, or it may contain additional information about /// the failure, such as whether or not it is retryable. type FailureClass; /// The type used to classify the response end of stream (EOS). type ClassifyEos: ClassifyEos; /// Attempt to classify the beginning of a response. /// /// In some cases, the response can be classified immediately, without /// waiting for a body to complete. This may include: /// /// - When the response has an error status code. /// - When a successful response does not have a streaming body. /// - When the classifier does not care about streaming bodies. /// /// When the response can be classified immediately, `classify_response` /// returns a [`ClassifiedResponse::Ready`] which indicates whether the /// response succeeded or failed. 
/// /// In other cases, however, the classifier may need to wait until the /// response body stream completes before it can classify the response. /// For example, gRPC indicates RPC failures using the `grpc-status` /// trailer. In this case, `classify_response` returns a /// [`ClassifiedResponse::RequiresEos`] containing a type which will /// be used to classify the response when the body stream ends. fn classify_response( self, res: &Response, ) -> ClassifiedResponse; /// Classify an error. /// /// Errors are always errors (doh) but sometimes it might be useful to have multiple classes of /// errors. A retry policy might allow retrying some errors and not others. fn classify_error(self, error: &E) -> Self::FailureClass where E: fmt::Display + 'static; /// Transform the failure classification using a function. /// /// # Example /// /// ``` /// use tower_http::classify::{ /// ServerErrorsAsFailures, ServerErrorsFailureClass, /// ClassifyResponse, ClassifiedResponse /// }; /// use http::{Response, StatusCode}; /// use http_body::Empty; /// use bytes::Bytes; /// /// fn transform_failure_class(class: ServerErrorsFailureClass) -> NewFailureClass { /// match class { /// // Convert status codes into u16 /// ServerErrorsFailureClass::StatusCode(status) => { /// NewFailureClass::Status(status.as_u16()) /// } /// // Don't change errors. /// ServerErrorsFailureClass::Error(error) => { /// NewFailureClass::Error(error) /// } /// } /// } /// /// enum NewFailureClass { /// Status(u16), /// Error(String), /// } /// /// // Create a classifier who's failure class will be transformed by `transform_failure_class` /// let classifier = ServerErrorsAsFailures::new().map_failure_class(transform_failure_class); /// /// let response = Response::builder() /// .status(StatusCode::INTERNAL_SERVER_ERROR) /// .body(Empty::::new()) /// .unwrap(); /// /// let classification = classifier.classify_response(&response); /// /// assert!(matches!( /// classification, /// ClassifiedResponse::Ready(Err(NewFailureClass::Status(500))) /// )); /// ``` fn map_failure_class(self, f: F) -> MapFailureClass where Self: Sized, F: FnOnce(Self::FailureClass) -> NewClass, { MapFailureClass::new(self, f) } } /// Trait for classifying end of streams (EOS) as either success or failure. pub trait ClassifyEos { /// The type of failure classifications. type FailureClass; /// Perform the classification from response trailers. fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass>; /// Classify an error. /// /// Errors are always errors (doh) but sometimes it might be useful to have multiple classes of /// errors. A retry policy might allow retrying some errors and not others. fn classify_error(self, error: &E) -> Self::FailureClass where E: fmt::Display + 'static; /// Transform the failure classification using a function. /// /// See [`ClassifyResponse::map_failure_class`] for more details. fn map_failure_class(self, f: F) -> MapFailureClass where Self: Sized, F: FnOnce(Self::FailureClass) -> NewClass, { MapFailureClass::new(self, f) } } /// Result of doing a classification. #[derive(Debug)] pub enum ClassifiedResponse { /// The response was able to be classified immediately. Ready(Result<(), FailureClass>), /// We have to wait until the end of a streaming response to classify it. RequiresEos(ClassifyEos), } /// A [`ClassifyEos`] type that can be used in [`ClassifyResponse`] implementations that never have /// to classify streaming responses. /// /// `NeverClassifyEos` exists only as type. 
`NeverClassifyEos` values cannot be constructed. pub struct NeverClassifyEos { _output_ty: PhantomData T>, _never: Infallible, } impl ClassifyEos for NeverClassifyEos { type FailureClass = T; fn classify_eos(self, _trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> { // `NeverClassifyEos` contains an `Infallible` so it can never be constructed unreachable!() } fn classify_error(self, _error: &E) -> Self::FailureClass where E: fmt::Display + 'static, { // `NeverClassifyEos` contains an `Infallible` so it can never be constructed unreachable!() } } impl fmt::Debug for NeverClassifyEos { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("NeverClassifyEos").finish() } } /// The default classifier used for normal HTTP responses. /// /// Responses with a `5xx` status code are considered failures, all others are considered /// successes. #[derive(Clone, Debug, Default)] pub struct ServerErrorsAsFailures { _priv: (), } impl ServerErrorsAsFailures { /// Create a new [`ServerErrorsAsFailures`]. pub fn new() -> Self { Self::default() } /// Returns a [`MakeClassifier`] that produces `ServerErrorsAsFailures`. /// /// This is a convenience function that simply calls `SharedClassifier::new`. pub fn make_classifier() -> SharedClassifier { SharedClassifier::new(Self::new()) } } impl ClassifyResponse for ServerErrorsAsFailures { type FailureClass = ServerErrorsFailureClass; type ClassifyEos = NeverClassifyEos; fn classify_response( self, res: &Response, ) -> ClassifiedResponse { if res.status().is_server_error() { ClassifiedResponse::Ready(Err(ServerErrorsFailureClass::StatusCode(res.status()))) } else { ClassifiedResponse::Ready(Ok(())) } } fn classify_error(self, error: &E) -> Self::FailureClass where E: fmt::Display + 'static, { ServerErrorsFailureClass::Error(error.to_string()) } } /// The failure class for [`ServerErrorsAsFailures`]. #[derive(Debug)] pub enum ServerErrorsFailureClass { /// A response was classified as a failure with the corresponding status. StatusCode(StatusCode), /// A response was classified as an error with the corresponding error description. Error(String), } impl fmt::Display for ServerErrorsFailureClass { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::StatusCode(code) => write!(f, "Status code: {}", code), Self::Error(error) => write!(f, "Error: {}", error), } } } // Just verify that we can actually use this response classifier to determine retries as well #[cfg(test)] mod usable_for_retries { #[allow(unused_imports)] use super::*; use hyper::{Request, Response}; use tower::retry::Policy; trait IsRetryable { fn is_retryable(&self) -> bool; } #[derive(Clone)] struct RetryBasedOnClassification { classifier: C, // ... 
} impl Policy, Response, E> for RetryBasedOnClassification where C: ClassifyResponse + Clone, E: fmt::Display + 'static, C::FailureClass: IsRetryable, ResB: http_body::Body, Request: Clone, E: std::error::Error + 'static, { type Future = futures::future::Ready>; fn retry( &self, _req: &Request, res: Result<&Response, &E>, ) -> Option { match res { Ok(res) => { if let ClassifiedResponse::Ready(class) = self.classifier.clone().classify_response(res) { if class.err()?.is_retryable() { return Some(futures::future::ready(self.clone())); } } None } Err(err) => self .classifier .clone() .classify_error(err) .is_retryable() .then(|| futures::future::ready(self.clone())), } } fn clone_request(&self, req: &Request) -> Option> { Some(req.clone()) } } } tower-http-0.4.4/src/classify/status_in_range_is_error.rs000064400000000000000000000114461046102023000220170ustar 00000000000000use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier}; use http::StatusCode; use std::{fmt, ops::RangeInclusive}; /// Response classifier that considers responses with a status code within some range to be /// failures. /// /// # Example /// /// A client with tracing where server errors _and_ client errors are considered failures. /// /// ```no_run /// use tower_http::{trace::TraceLayer, classify::StatusInRangeAsFailures}; /// use tower::{ServiceBuilder, Service, ServiceExt}; /// use hyper::{Client, Body}; /// use http::{Request, Method}; /// /// # async fn foo() -> Result<(), tower::BoxError> { /// let classifier = StatusInRangeAsFailures::new(400..=599); /// /// let mut client = ServiceBuilder::new() /// .layer(TraceLayer::new(classifier.into_make_classifier())) /// .service(Client::new()); /// /// let request = Request::builder() /// .method(Method::GET) /// .uri("https://example.com") /// .body(Body::empty()) /// .unwrap(); /// /// let response = client.ready().await?.call(request).await?; /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] pub struct StatusInRangeAsFailures { range: RangeInclusive, } impl StatusInRangeAsFailures { /// Creates a new `StatusInRangeAsFailures`. /// /// # Panics /// /// Panics if the start or end of `range` aren't valid status codes as determined by /// [`StatusCode::from_u16`]. /// /// [`StatusCode::from_u16`]: https://docs.rs/http/latest/http/status/struct.StatusCode.html#method.from_u16 pub fn new(range: RangeInclusive) -> Self { assert!( StatusCode::from_u16(*range.start()).is_ok(), "range start isn't a valid status code" ); assert!( StatusCode::from_u16(*range.end()).is_ok(), "range end isn't a valid status code" ); Self { range } } /// Creates a new `StatusInRangeAsFailures` that classifies client and server responses as /// failures. /// /// This is a convenience for `StatusInRangeAsFailures::new(400..=599)`. pub fn new_for_client_and_server_errors() -> Self { Self::new(400..=599) } /// Convert this `StatusInRangeAsFailures` into a [`MakeClassifier`]. 
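///
/// # Example
///
/// A minimal sketch:
///
/// ```rust
/// use tower_http::classify::StatusInRangeAsFailures;
///
/// // Treat 4xx and 5xx responses as failures and turn the classifier into a
/// // `MakeClassifier` that clones it for every request.
/// let make_classifier = StatusInRangeAsFailures::new(400..=599).into_make_classifier();
/// ```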
/// /// [`MakeClassifier`]: super::MakeClassifier pub fn into_make_classifier(self) -> SharedClassifier { SharedClassifier::new(self) } } impl ClassifyResponse for StatusInRangeAsFailures { type FailureClass = StatusInRangeFailureClass; type ClassifyEos = NeverClassifyEos; fn classify_response( self, res: &http::Response, ) -> ClassifiedResponse { if self.range.contains(&res.status().as_u16()) { let class = StatusInRangeFailureClass::StatusCode(res.status()); ClassifiedResponse::Ready(Err(class)) } else { ClassifiedResponse::Ready(Ok(())) } } fn classify_error(self, error: &E) -> Self::FailureClass where E: std::fmt::Display + 'static, { StatusInRangeFailureClass::Error(error.to_string()) } } /// The failure class for [`StatusInRangeAsFailures`]. #[derive(Debug)] pub enum StatusInRangeFailureClass { /// A response was classified as a failure with the corresponding status. StatusCode(StatusCode), /// A response was classified as an error with the corresponding error description. Error(String), } impl fmt::Display for StatusInRangeFailureClass { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::StatusCode(code) => write!(f, "Status code: {}", code), Self::Error(error) => write!(f, "Error: {}", error), } } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use http::Response; #[test] fn basic() { let classifier = StatusInRangeAsFailures::new(400..=599); assert!(matches!( dbg!(classifier .clone() .classify_response(&response_with_status(200))), ClassifiedResponse::Ready(Ok(())), )); assert!(matches!( dbg!(classifier .clone() .classify_response(&response_with_status(400))), ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode( StatusCode::BAD_REQUEST ))), )); assert!(matches!( dbg!(classifier.classify_response(&response_with_status(500))), ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode( StatusCode::INTERNAL_SERVER_ERROR ))), )); } fn response_with_status(status: u16) -> Response<()> { Response::builder().status(status).body(()).unwrap() } } tower-http-0.4.4/src/compression/body.rs000064400000000000000000000263051046102023000164070ustar 00000000000000#![allow(unused_imports)] use crate::compression::CompressionLevel; use crate::{ compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody}, BoxError, }; #[cfg(feature = "compression-br")] use async_compression::tokio::bufread::BrotliEncoder; #[cfg(feature = "compression-gzip")] use async_compression::tokio::bufread::GzipEncoder; #[cfg(feature = "compression-deflate")] use async_compression::tokio::bufread::ZlibEncoder; #[cfg(feature = "compression-zstd")] use async_compression::tokio::bufread::ZstdEncoder; use bytes::{Buf, Bytes}; use futures_util::ready; use http::HeaderMap; use http_body::Body; use pin_project_lite::pin_project; use std::{ io, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tokio_util::io::StreamReader; use super::pin_project_cfg::pin_project_cfg; pin_project! { /// Response body of [`Compression`]. 
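///
/// This type is normally produced by the [`Compression`] middleware rather than
/// constructed by hand. A rough sketch of where it shows up (`handle` is a
/// hypothetical placeholder service):
///
/// ```rust
/// use http::{Request, Response};
/// use hyper::Body;
/// use std::convert::Infallible;
/// use tower::ServiceBuilder;
/// use tower_http::compression::CompressionLayer;
///
/// async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
///     Ok(Response::new(Body::empty()))
/// }
///
/// // Responses from this service use `CompressionBody<Body>` as their body type.
/// let service = ServiceBuilder::new()
///     .layer(CompressionLayer::new())
///     .service_fn(handle);
/// ```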
/// /// [`Compression`]: super::Compression pub struct CompressionBody where B: Body, { #[pin] pub(crate) inner: BodyInner, } } impl Default for CompressionBody where B: Body + Default, { fn default() -> Self { Self { inner: BodyInner::Identity { inner: B::default(), }, } } } impl CompressionBody where B: Body, { pub(crate) fn new(inner: BodyInner) -> Self { Self { inner } } /// Get a reference to the inner body pub fn get_ref(&self) -> &B { match &self.inner { #[cfg(feature = "compression-gzip")] BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "compression-deflate")] BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "compression-br")] BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "compression-zstd")] BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), BodyInner::Identity { inner } => inner, } } /// Get a mutable reference to the inner body pub fn get_mut(&mut self) -> &mut B { match &mut self.inner { #[cfg(feature = "compression-gzip")] BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "compression-deflate")] BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "compression-br")] BodyInner::Brotli { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "compression-zstd")] BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), BodyInner::Identity { inner } => inner, } } /// Get a pinned mutable reference to the inner body pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> { match self.project().inner.project() { #[cfg(feature = "compression-gzip")] BodyInnerProj::Gzip { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), #[cfg(feature = "compression-deflate")] BodyInnerProj::Deflate { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), #[cfg(feature = "compression-br")] BodyInnerProj::Brotli { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), #[cfg(feature = "compression-zstd")] BodyInnerProj::Zstd { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), BodyInnerProj::Identity { inner } => inner, } } /// Consume `self`, returning the inner body pub fn into_inner(self) -> B { match self.inner { #[cfg(feature = "compression-gzip")] BodyInner::Gzip { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), #[cfg(feature = "compression-deflate")] BodyInner::Deflate { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), #[cfg(feature = "compression-br")] BodyInner::Brotli { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), #[cfg(feature = "compression-zstd")] BodyInner::Zstd { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), BodyInner::Identity { inner } => inner, } } } #[cfg(feature = "compression-gzip")] type GzipBody = WrapBody>; #[cfg(feature = "compression-deflate")] type DeflateBody = WrapBody>; #[cfg(feature = "compression-br")] type BrotliBody = WrapBody>; #[cfg(feature = "compression-zstd")] type ZstdBody = WrapBody>; pin_project_cfg! 
{ #[project = BodyInnerProj] pub(crate) enum BodyInner where B: Body, { #[cfg(feature = "compression-gzip")] Gzip { #[pin] inner: GzipBody, }, #[cfg(feature = "compression-deflate")] Deflate { #[pin] inner: DeflateBody, }, #[cfg(feature = "compression-br")] Brotli { #[pin] inner: BrotliBody, }, #[cfg(feature = "compression-zstd")] Zstd { #[pin] inner: ZstdBody, }, Identity { #[pin] inner: B, }, } } impl BodyInner { #[cfg(feature = "compression-gzip")] pub(crate) fn gzip(inner: WrapBody>) -> Self { Self::Gzip { inner } } #[cfg(feature = "compression-deflate")] pub(crate) fn deflate(inner: WrapBody>) -> Self { Self::Deflate { inner } } #[cfg(feature = "compression-br")] pub(crate) fn brotli(inner: WrapBody>) -> Self { Self::Brotli { inner } } #[cfg(feature = "compression-zstd")] pub(crate) fn zstd(inner: WrapBody>) -> Self { Self::Zstd { inner } } pub(crate) fn identity(inner: B) -> Self { Self::Identity { inner } } } impl Body for CompressionBody where B: Body, B::Error: Into, { type Data = Bytes; type Error = BoxError; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { match self.project().inner.project() { #[cfg(feature = "compression-gzip")] BodyInnerProj::Gzip { inner } => inner.poll_data(cx), #[cfg(feature = "compression-deflate")] BodyInnerProj::Deflate { inner } => inner.poll_data(cx), #[cfg(feature = "compression-br")] BodyInnerProj::Brotli { inner } => inner.poll_data(cx), #[cfg(feature = "compression-zstd")] BodyInnerProj::Zstd { inner } => inner.poll_data(cx), BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { Some(Ok(mut buf)) => { let bytes = buf.copy_to_bytes(buf.remaining()); Poll::Ready(Some(Ok(bytes))) } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), }, } } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { match self.project().inner.project() { #[cfg(feature = "compression-gzip")] BodyInnerProj::Gzip { inner } => inner.poll_trailers(cx), #[cfg(feature = "compression-deflate")] BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), #[cfg(feature = "compression-br")] BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), #[cfg(feature = "compression-zstd")] BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), } } } #[cfg(feature = "compression-gzip")] impl DecorateAsyncRead for GzipEncoder where B: Body, { type Input = AsyncReadBody; type Output = GzipEncoder; fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { GzipEncoder::with_quality(input, quality.into_async_compression()) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } #[cfg(feature = "compression-deflate")] impl DecorateAsyncRead for ZlibEncoder where B: Body, { type Input = AsyncReadBody; type Output = ZlibEncoder; fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { ZlibEncoder::with_quality(input, quality.into_async_compression()) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } #[cfg(feature = "compression-br")] impl DecorateAsyncRead for BrotliEncoder where B: Body, { type Input = AsyncReadBody; type Output = BrotliEncoder; fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { // The brotli crate used under the hood here has a default compression level of 11, // which is the max for brotli. 
This causes extremely slow compression times, so we // manually set a default of 4 here. // // This is the same default used by NGINX for on-the-fly brotli compression. let level = match quality { CompressionLevel::Default => async_compression::Level::Precise(4), other => other.into_async_compression(), }; BrotliEncoder::with_quality(input, level) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } #[cfg(feature = "compression-zstd")] impl DecorateAsyncRead for ZstdEncoder where B: Body, { type Input = AsyncReadBody; type Output = ZstdEncoder; fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { ZstdEncoder::with_quality(input, quality.into_async_compression()) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } tower-http-0.4.4/src/compression/future.rs000064400000000000000000000101661046102023000167620ustar 00000000000000#![allow(unused_imports)] use super::{body::BodyInner, CompressionBody}; use crate::compression::predicate::Predicate; use crate::compression::CompressionLevel; use crate::compression_utils::WrapBody; use crate::content_encoding::Encoding; use futures_util::ready; use http::{header, HeaderMap, HeaderValue, Response}; use http_body::Body; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; pin_project! { /// Response future of [`Compression`]. /// /// [`Compression`]: super::Compression #[derive(Debug)] pub struct ResponseFuture { #[pin] pub(crate) inner: F, pub(crate) encoding: Encoding, pub(crate) predicate: P, pub(crate) quality: CompressionLevel, } } impl Future for ResponseFuture where F: Future, E>>, B: Body, P: Predicate, { type Output = Result>, E>; #[allow(unreachable_code, unused_mut, unused_variables)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = ready!(self.as_mut().project().inner.poll(cx)?); // never recompress responses that are already compressed let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING) && self.predicate.should_compress(&res); let (mut parts, body) = res.into_parts(); let body = match (should_compress, self.encoding) { // if compression is _not_ support or the client doesn't accept it (false, _) | (_, Encoding::Identity) => { return Poll::Ready(Ok(Response::from_parts( parts, CompressionBody::new(BodyInner::identity(body)), ))) } #[cfg(feature = "compression-gzip")] (_, Encoding::Gzip) => { CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality))) } #[cfg(feature = "compression-deflate")] (_, Encoding::Deflate) => { CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality))) } #[cfg(feature = "compression-br")] (_, Encoding::Brotli) => { CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality))) } #[cfg(feature = "compression-zstd")] (_, Encoding::Zstd) => { CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality))) } #[cfg(feature = "fs")] (true, _) => { // This should never happen because the `AcceptEncoding` struct which is used to determine // `self.encoding` will only enable the different compression algorithms if the // corresponding crate feature has been enabled. This means // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the // features enabled. // // The match arm is still required though because the `fs` feature uses the // Encoding struct independently and requires no compression logic to be enabled. 
// This means a combination of an individual compression feature and `fs` will fail // to compile without this branch even though it will never be reached. // // To safeguard against refactors that changes this relationship or other bugs the // server will return an uncompressed response instead of panicking since that could // become a ddos attack vector. return Poll::Ready(Ok(Response::from_parts( parts, CompressionBody::new(BodyInner::identity(body)), ))); } }; parts.headers.remove(header::CONTENT_LENGTH); parts .headers .insert(header::CONTENT_ENCODING, self.encoding.into_header_value()); let res = Response::from_parts(parts, body); Poll::Ready(Ok(res)) } } tower-http-0.4.4/src/compression/layer.rs000064400000000000000000000144141046102023000165640ustar 00000000000000use super::{Compression, Predicate}; use crate::compression::predicate::DefaultPredicate; use crate::compression::CompressionLevel; use crate::compression_utils::AcceptEncoding; use tower_layer::Layer; /// Compress response bodies of the underlying service. /// /// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the /// `Content-Encoding` header to responses. /// /// See the [module docs](crate::compression) for more details. #[derive(Clone, Debug, Default)] pub struct CompressionLayer

{
    accept: AcceptEncoding,
    predicate: P,
    quality: CompressionLevel,
}

impl<S, P> Layer<S> for CompressionLayer<P>

where P: Predicate, { type Service = Compression; fn layer(&self, inner: S) -> Self::Service { Compression { inner, accept: self.accept, predicate: self.predicate.clone(), quality: self.quality, } } } impl CompressionLayer { /// Create a new [`CompressionLayer`] pub fn new() -> Self { Self::default() } /// Sets whether to enable the gzip encoding. #[cfg(feature = "compression-gzip")] pub fn gzip(mut self, enable: bool) -> Self { self.accept.set_gzip(enable); self } /// Sets whether to enable the Deflate encoding. #[cfg(feature = "compression-deflate")] pub fn deflate(mut self, enable: bool) -> Self { self.accept.set_deflate(enable); self } /// Sets whether to enable the Brotli encoding. #[cfg(feature = "compression-br")] pub fn br(mut self, enable: bool) -> Self { self.accept.set_br(enable); self } /// Sets whether to enable the Zstd encoding. #[cfg(feature = "compression-zstd")] pub fn zstd(mut self, enable: bool) -> Self { self.accept.set_zstd(enable); self } /// Sets the compression quality. pub fn quality(mut self, quality: CompressionLevel) -> Self { self.quality = quality; self } /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. pub fn no_gzip(mut self) -> Self { self.accept.set_gzip(false); self } /// Disables the Deflate encoding. /// /// This method is available even if the `deflate` crate feature is disabled. pub fn no_deflate(mut self) -> Self { self.accept.set_deflate(false); self } /// Disables the Brotli encoding. /// /// This method is available even if the `br` crate feature is disabled. pub fn no_br(mut self) -> Self { self.accept.set_br(false); self } /// Disables the Zstd encoding. /// /// This method is available even if the `zstd` crate feature is disabled. pub fn no_zstd(mut self) -> Self { self.accept.set_zstd(false); self } /// Replace the current compression predicate. /// /// See [`Compression::compress_when`] for more details. pub fn compress_when(self, predicate: C) -> CompressionLayer where C: Predicate, { CompressionLayer { accept: self.accept, predicate, quality: self.quality, } } } #[cfg(test)] mod tests { use super::*; use http::{header::ACCEPT_ENCODING, Request, Response}; use http_body::Body as _; use hyper::Body; use tokio::fs::File; // for Body::data use bytes::{Bytes, BytesMut}; use std::convert::Infallible; use tokio_util::io::ReaderStream; use tower::{Service, ServiceBuilder, ServiceExt}; async fn handle(_req: Request) -> Result, Infallible> { // Open the file. let file = File::open("Cargo.toml").await.expect("file missing"); // Convert the file into a `Stream`. let stream = ReaderStream::new(file); // Convert the `Stream` into a `Body`. let body = Body::wrap_stream(stream); // Create response. Ok(Response::new(body)) } #[tokio::test] async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> { let deflate_only_layer = CompressionLayer::new() .quality(CompressionLevel::Best) .no_br() .no_gzip(); let mut service = ServiceBuilder::new() // Compress responses based on the `Accept-Encoding` header. 
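            // Only deflate should be negotiated here: gzip and br are disabled on the
            // layer above, and the request below advertises `gzip, deflate, br`.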
.layer(deflate_only_layer) .service_fn(handle); // Call the service with the deflate only layer let request = Request::builder() .header(ACCEPT_ENCODING, "gzip, deflate, br") .body(Body::empty())?; let response = service.ready().await?.call(request).await?; assert_eq!(response.headers()["content-encoding"], "deflate"); // Read the body let mut body = response.into_body(); let mut bytes = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk?; bytes.extend_from_slice(&chunk[..]); } let bytes: Bytes = bytes.freeze(); let deflate_bytes_len = bytes.len(); let br_only_layer = CompressionLayer::new() .quality(CompressionLevel::Best) .no_gzip() .no_deflate(); let mut service = ServiceBuilder::new() // Compress responses based on the `Accept-Encoding` header. .layer(br_only_layer) .service_fn(handle); // Call the service with the br only layer let request = Request::builder() .header(ACCEPT_ENCODING, "gzip, deflate, br") .body(Body::empty())?; let response = service.ready().await?.call(request).await?; assert_eq!(response.headers()["content-encoding"], "br"); // Read the body let mut body = response.into_body(); let mut bytes = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk?; bytes.extend_from_slice(&chunk[..]); } let bytes: Bytes = bytes.freeze(); let br_byte_length = bytes.len(); // check the corresponding algorithms are actually used // br should compresses better than deflate assert!(br_byte_length < deflate_bytes_len * 9 / 10); Ok(()) } } tower-http-0.4.4/src/compression/mod.rs000064400000000000000000000324151046102023000162300ustar 00000000000000//! Middleware that compresses response bodies. //! //! # Example //! //! Example showing how to respond with the compressed contents of a file. //! //! ```rust //! use bytes::{Bytes, BytesMut}; //! use http::{Request, Response, header::ACCEPT_ENCODING}; //! use http_body::Body as _; // for Body::data //! use hyper::Body; //! use std::convert::Infallible; //! use tokio::fs::{self, File}; //! use tokio_util::io::ReaderStream; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::{compression::CompressionLayer, BoxError}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), BoxError> { //! async fn handle(req: Request) -> Result, Infallible> { //! // Open the file. //! let file = File::open("Cargo.toml").await.expect("file missing"); //! // Convert the file into a `Stream`. //! let stream = ReaderStream::new(file); //! // Convert the `Stream` into a `Body`. //! let body = Body::wrap_stream(stream); //! // Create response. //! Ok(Response::new(body)) //! } //! //! let mut service = ServiceBuilder::new() //! // Compress responses based on the `Accept-Encoding` header. //! .layer(CompressionLayer::new()) //! .service_fn(handle); //! //! // Call the service. //! let request = Request::builder() //! .header(ACCEPT_ENCODING, "gzip") //! .body(Body::empty())?; //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! //! assert_eq!(response.headers()["content-encoding"], "gzip"); //! //! // Read the body //! let mut body = response.into_body(); //! let mut bytes = BytesMut::new(); //! while let Some(chunk) = body.data().await { //! let chunk = chunk?; //! bytes.extend_from_slice(&chunk[..]); //! } //! let bytes: Bytes = bytes.freeze(); //! //! // The compressed body should be smaller 🤞 //! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len(); //! assert!(bytes.len() < uncompressed_len); //! # //! # Ok(()) //! 
# } //! ``` //! pub mod predicate; mod body; mod future; mod layer; mod pin_project_cfg; mod service; #[doc(inline)] pub use self::{ body::CompressionBody, future::ResponseFuture, layer::CompressionLayer, predicate::{DefaultPredicate, Predicate}, service::Compression, }; pub use crate::compression_utils::CompressionLevel; #[cfg(test)] mod tests { use crate::compression::predicate::SizeAbove; use super::*; use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; use bytes::BytesMut; use flate2::read::GzDecoder; use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE}; use http_body::Body as _; use hyper::{Body, Error, Request, Response, Server}; use std::sync::{Arc, RwLock}; use std::{io::Read, net::SocketAddr}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::StreamReader; use tower::{make::Shared, service_fn, Service, ServiceExt}; // Compression filter allows every other request to be compressed #[derive(Clone)] struct Always; impl Predicate for Always { fn should_compress(&self, _: &http::Response) -> bool where B: http_body::Body, { true } } #[tokio::test] async fn gzip_works() { let svc = service_fn(handle); let mut svc = Compression::new(svc).compress_when(Always); // call the service let req = Request::builder() .header("accept-encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } let compressed_data = data.freeze().to_vec(); // decompress the body // doing this with flate2 as that is much easier than async-compression and blocking during // tests is fine let mut decoder = GzDecoder::new(&compressed_data[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert_eq!(decompressed, "Hello, World!"); } #[tokio::test] async fn zstd_works() { let svc = service_fn(handle); let mut svc = Compression::new(svc).compress_when(Always); // call the service let req = Request::builder() .header("accept-encoding", "zstd") .body(Body::empty()) .unwrap(); let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } let compressed_data = data.freeze().to_vec(); // decompress the body let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap(); let decompressed = String::from_utf8(decompressed).unwrap(); assert_eq!(decompressed, "Hello, World!"); } #[allow(dead_code)] async fn is_compatible_with_hyper() { let svc = service_fn(handle); let svc = Compression::new(svc); let make_service = Shared::new(svc); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let server = Server::bind(&addr).serve(make_service); server.await.unwrap(); } #[tokio::test] async fn no_recompress() { const DATA: &str = "Hello, World! 
I'm already compressed with br!"; let svc = service_fn(|_| async { let buf = { let mut buf = Vec::new(); let mut enc = BrotliEncoder::new(&mut buf); enc.write_all(DATA.as_bytes()).await?; enc.flush().await?; buf }; let resp = Response::builder() .header("content-encoding", "br") .body(Body::from(buf)) .unwrap(); Ok::<_, std::io::Error>(resp) }); let mut svc = Compression::new(svc); // call the service // // note: the accept-encoding doesn't match the content-encoding above, so that // we're able to see if the compression layer triggered or not let req = Request::builder() .header("accept-encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.ready().await.unwrap().call(req).await.unwrap(); // check we didn't recompress assert_eq!( res.headers() .get("content-encoding") .and_then(|h| h.to_str().ok()) .unwrap_or_default(), "br", ); // read the compressed body let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } // decompress the body let data = { let mut output_buf = Vec::new(); let mut decoder = BrotliDecoder::new(&mut output_buf); decoder .write_all(&data) .await .expect("couldn't brotli-decode"); decoder.flush().await.expect("couldn't flush"); output_buf }; assert_eq!(data, DATA.as_bytes()); } async fn handle(_req: Request) -> Result, Error> { Ok(Response::new(Body::from("Hello, World!"))) } #[tokio::test] async fn will_not_compress_if_filtered_out() { use predicate::Predicate; const DATA: &str = "Hello world uncompressed"; let svc_fn = service_fn(|_| async { let resp = Response::builder() // .header("content-encoding", "br") .body(Body::from(DATA.as_bytes())) .unwrap(); Ok::<_, std::io::Error>(resp) }); // Compression filter allows every other request to be compressed #[derive(Default, Clone)] struct EveryOtherResponse(Arc>); impl Predicate for EveryOtherResponse { fn should_compress(&self, _: &http::Response) -> bool where B: http_body::Body, { let mut guard = self.0.write().unwrap(); let should_compress = *guard % 2 != 0; *guard += 1; dbg!(should_compress) } } let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default()); let req = Request::builder() .header("accept-encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the uncompressed body let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } let still_uncompressed = String::from_utf8(data.to_vec()).unwrap(); assert_eq!(DATA, &still_uncompressed); // Compression filter will compress the next body let req = Request::builder() .header("accept-encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } assert!(String::from_utf8(data.to_vec()).is_err()); } #[tokio::test] async fn doesnt_compress_images() { async fn handle(_req: Request) -> Result, Error> { let mut res = Response::new(Body::from( "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), )); res.headers_mut() .insert(CONTENT_TYPE, "image/png".parse().unwrap()); Ok(res) } let svc = Compression::new(service_fn(handle)); let res = svc .oneshot( Request::builder() .header(ACCEPT_ENCODING, "gzip") 
.body(Body::empty()) .unwrap(), ) .await .unwrap(); assert!(res.headers().get(CONTENT_ENCODING).is_none()); } #[tokio::test] async fn does_compress_svg() { async fn handle(_req: Request) -> Result, Error> { let mut res = Response::new(Body::from( "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), )); res.headers_mut() .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap()); Ok(res) } let svc = Compression::new(service_fn(handle)); let res = svc .oneshot( Request::builder() .header(ACCEPT_ENCODING, "gzip") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(res.headers()[CONTENT_ENCODING], "gzip"); } #[tokio::test] async fn compress_with_quality() { const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!"; let level = CompressionLevel::Best; let svc = service_fn(|_| async { let resp = Response::builder() .body(Body::from(DATA.as_bytes())) .unwrap(); Ok::<_, std::io::Error>(resp) }); let mut svc = Compression::new(svc).quality(level); // call the service let req = Request::builder() .header("accept-encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } let compressed_data = data.freeze().to_vec(); // build the compressed body with the same quality level let compressed_with_level = { use async_compression::tokio::bufread::BrotliEncoder; let stream = Box::pin(futures::stream::once(async move { Ok::<_, std::io::Error>(DATA.as_bytes()) })); let reader = StreamReader::new(stream); let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression()); let mut buf = Vec::new(); enc.read_to_end(&mut buf).await.unwrap(); buf }; assert_eq!( compressed_data.as_slice(), compressed_with_level.as_slice(), "Compression level is not respected" ); } } tower-http-0.4.4/src/compression/pin_project_cfg.rs000064400000000000000000000071311046102023000206010ustar 00000000000000// Full credit to @tesaguri who posted this gist under CC0 1.0 Universal licence // https://gist.github.com/tesaguri/2a1c0790a48bbda3dd7f71c26d02a793 macro_rules! pin_project_cfg { ($(#[$($attr:tt)*])* $vis:vis enum $($rest:tt)+) => { pin_project_cfg! { @outer [$(#[$($attr)*])* $vis enum] $($rest)+ } }; // Accumulate type parameters and `where` clause. (@outer [$($accum:tt)*] $tt:tt $($rest:tt)+) => { pin_project_cfg! { @outer [$($accum)* $tt] $($rest)+ } }; (@outer [$($accum:tt)*] { $($body:tt)* }) => { pin_project_cfg! { @body #[cfg(all())] [$($accum)*] {} $($body)* } }; // Process a variant with `cfg`. ( @body #[cfg(all($($pred_accum:tt)*))] $outer:tt { $($accum:tt)* } #[cfg($($pred:tt)*)] $(#[$($attr:tt)*])* $variant:ident { $($body:tt)* }, $($rest:tt)* ) => { // Create two versions of the enum with `cfg($pred)` and `cfg(not($pred))`. pin_project_cfg! { @variant_body { $($body)* } {} #[cfg(all($($pred_accum)* $($pred)*,))] $outer { $($accum)* $(#[$($attr)*])* $variant } $($rest)* } pin_project_cfg! { @body #[cfg(all($($pred_accum)* not($($pred)*),))] $outer { $($accum)* } $($rest)* } }; // Process a variant without `cfg`. ( @body #[cfg(all($($pred_accum:tt)*))] $outer:tt { $($accum:tt)* } $(#[$($attr:tt)*])* $variant:ident { $($body:tt)* }, $($rest:tt)* ) => { pin_project_cfg! 
{ @variant_body { $($body)* } {} #[cfg(all($($pred_accum)*))] $outer { $($accum)* $(#[$($attr)*])* $variant } $($rest)* } }; // Process a variant field with `cfg`. ( @variant_body { #[cfg($($pred:tt)*)] $(#[$($attr:tt)*])* $field:ident: $ty:ty, $($rest:tt)* } { $($accum:tt)* } #[cfg(all($($pred_accum:tt)*))] $($outer:tt)* ) => { pin_project_cfg! { @variant_body {$($rest)*} { $($accum)* $(#[$($attr)*])* $field: $ty, } #[cfg(all($($pred_accum)* $($pred)*,))] $($outer)* } pin_project_cfg! { @variant_body { $($rest)* } { $($accum)* } #[cfg(all($($pred_accum)* not($($pred)*),))] $($outer)* } }; // Process a variant field without `cfg`. ( @variant_body { $(#[$($attr:tt)*])* $field:ident: $ty:ty, $($rest:tt)* } { $($accum:tt)* } $($outer:tt)* ) => { pin_project_cfg! { @variant_body {$($rest)*} { $($accum)* $(#[$($attr)*])* $field: $ty, } $($outer)* } }; ( @variant_body {} $body:tt #[cfg(all($($pred_accum:tt)*))] $outer:tt { $($accum:tt)* } $($rest:tt)* ) => { pin_project_cfg! { @body #[cfg(all($($pred_accum)*))] $outer { $($accum)* $body, } $($rest)* } }; ( @body #[$cfg:meta] [$($outer:tt)*] $body:tt ) => { #[$cfg] pin_project_lite::pin_project! { $($outer)* $body } }; } pub(crate) use pin_project_cfg; tower-http-0.4.4/src/compression/predicate.rs000064400000000000000000000157551046102023000174210ustar 00000000000000//! Predicates for disabling compression of responses. //! //! Predicates are applied with [`Compression::compress_when`] or //! [`CompressionLayer::compress_when`]. //! //! [`Compression::compress_when`]: super::Compression::compress_when //! [`CompressionLayer::compress_when`]: super::CompressionLayer::compress_when use http::{header, Extensions, HeaderMap, StatusCode, Version}; use http_body::Body; use std::{fmt, sync::Arc}; /// Predicate used to determine if a response should be compressed or not. pub trait Predicate: Clone { /// Should this response be compressed or not? fn should_compress(&self, response: &http::Response) -> bool where B: Body; /// Combine two predicates into one. /// /// The resulting predicate enables compression if both inner predicates do. fn and(self, other: Other) -> And where Self: Sized, Other: Predicate, { And { lhs: self, rhs: other, } } } impl Predicate for F where F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone, { fn should_compress(&self, response: &http::Response) -> bool where B: Body, { let status = response.status(); let version = response.version(); let headers = response.headers(); let extensions = response.extensions(); self(status, version, headers, extensions) } } impl Predicate for Option where T: Predicate, { fn should_compress(&self, response: &http::Response) -> bool where B: Body, { self.as_ref() .map(|inner| inner.should_compress(response)) .unwrap_or(true) } } /// Two predicates combined into one. /// /// Created with [`Predicate::and`] #[derive(Debug, Clone, Default, Copy)] pub struct And { lhs: Lhs, rhs: Rhs, } impl Predicate for And where Lhs: Predicate, Rhs: Predicate, { fn should_compress(&self, response: &http::Response) -> bool where B: Body, { self.lhs.should_compress(response) && self.rhs.should_compress(response) } } /// The default predicate used by [`Compression`] and [`CompressionLayer`]. /// /// This will compress responses unless: /// /// - They're gRPC, which has its own protocol specific compression scheme. /// - It's an image as determined by the `content-type` starting with `image/`. /// - The response is less than 32 bytes. 
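///
/// (The 32 byte cut-off is `SizeAbove::DEFAULT_MIN_SIZE`; responses whose size cannot be
/// determined up front are still compressed.)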
/// /// # Configuring the defaults /// /// `DefaultPredicate` doesn't support any configuration. Instead you can build your own predicate /// by combining types in this module: /// /// ```rust /// use tower_http::compression::predicate::{SizeAbove, NotForContentType, Predicate}; /// /// // slightly large min size than the default 32 /// let predicate = SizeAbove::new(256) /// // still don't compress gRPC /// .and(NotForContentType::GRPC) /// // still don't compress images /// .and(NotForContentType::IMAGES) /// // also don't compress JSON /// .and(NotForContentType::const_new("application/json")); /// ``` /// /// [`Compression`]: super::Compression /// [`CompressionLayer`]: super::CompressionLayer #[derive(Clone)] pub struct DefaultPredicate(And, NotForContentType>); impl DefaultPredicate { /// Create a new `DefaultPredicate`. pub fn new() -> Self { let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE) .and(NotForContentType::GRPC) .and(NotForContentType::IMAGES); Self(inner) } } impl Default for DefaultPredicate { fn default() -> Self { Self::new() } } impl Predicate for DefaultPredicate { fn should_compress(&self, response: &http::Response) -> bool where B: Body, { self.0.should_compress(response) } } /// [`Predicate`] that will only allow compression of responses above a certain size. #[derive(Clone, Copy, Debug)] pub struct SizeAbove(u16); impl SizeAbove { pub(crate) const DEFAULT_MIN_SIZE: u16 = 32; /// Create a new `SizeAbove` predicate that will only compress responses larger than /// `min_size_bytes`. /// /// The response will be compressed if the exact size cannot be determined through either the /// `content-length` header or [`Body::size_hint`]. pub const fn new(min_size_bytes: u16) -> Self { Self(min_size_bytes) } } impl Default for SizeAbove { fn default() -> Self { Self(Self::DEFAULT_MIN_SIZE) } } impl Predicate for SizeAbove { fn should_compress(&self, response: &http::Response) -> bool where B: Body, { let content_size = response.body().size_hint().exact().or_else(|| { response .headers() .get(header::CONTENT_LENGTH) .and_then(|h| h.to_str().ok()) .and_then(|val| val.parse().ok()) }); match content_size { Some(size) => size >= (self.0 as u64), _ => true, } } } /// Predicate that wont allow responses with a specific `content-type` to be compressed. #[derive(Clone, Debug)] pub struct NotForContentType { content_type: Str, exception: Option, } impl NotForContentType { /// Predicate that wont compress gRPC responses. pub const GRPC: Self = Self::const_new("application/grpc"); /// Predicate that wont compress images. pub const IMAGES: Self = Self { content_type: Str::Static("image/"), exception: Some(Str::Static("image/svg+xml")), }; /// Create a new `NotForContentType`. pub fn new(content_type: &str) -> Self { Self { content_type: Str::Shared(content_type.into()), exception: None, } } /// Create a new `NotForContentType` from a static string. 
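    ///
    /// A minimal usage sketch; `application/json` mirrors the example in the
    /// [`DefaultPredicate`] docs above:
    ///
    /// ```
    /// use tower_http::compression::predicate::NotForContentType;
    ///
    /// let no_json = NotForContentType::const_new("application/json");
    /// ```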
pub const fn const_new(content_type: &'static str) -> Self { Self { content_type: Str::Static(content_type), exception: None, } } } impl Predicate for NotForContentType { fn should_compress(&self, response: &http::Response) -> bool where B: Body, { if let Some(except) = &self.exception { if content_type(response) == except.as_str() { return true; } } !content_type(response).starts_with(self.content_type.as_str()) } } #[derive(Clone)] enum Str { Static(&'static str), Shared(Arc), } impl Str { fn as_str(&self) -> &str { match self { Str::Static(s) => s, Str::Shared(s) => s, } } } impl fmt::Debug for Str { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Static(inner) => inner.fmt(f), Self::Shared(inner) => inner.fmt(f), } } } fn content_type(response: &http::Response) -> &str { response .headers() .get(header::CONTENT_TYPE) .and_then(|h| h.to_str().ok()) .unwrap_or_default() } tower-http-0.4.4/src/compression/service.rs000064400000000000000000000133231046102023000171060ustar 00000000000000use super::{CompressionBody, CompressionLayer, ResponseFuture}; use crate::compression::predicate::{DefaultPredicate, Predicate}; use crate::compression::CompressionLevel; use crate::{compression_utils::AcceptEncoding, content_encoding::Encoding}; use http::{Request, Response}; use http_body::Body; use std::task::{Context, Poll}; use tower_service::Service; /// Compress response bodies of the underlying service. /// /// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the /// `Content-Encoding` header to responses. /// /// See the [module docs](crate::compression) for more details. #[derive(Clone, Copy)] pub struct Compression { pub(crate) inner: S, pub(crate) accept: AcceptEncoding, pub(crate) predicate: P, pub(crate) quality: CompressionLevel, } impl Compression { /// Creates a new `Compression` wrapping the `service`. pub fn new(service: S) -> Compression { Self { inner: service, accept: AcceptEncoding::default(), predicate: DefaultPredicate::default(), quality: CompressionLevel::default(), } } } impl Compression { define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `Compression` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer() -> CompressionLayer { CompressionLayer::new() } /// Sets whether to enable the gzip encoding. #[cfg(feature = "compression-gzip")] pub fn gzip(mut self, enable: bool) -> Self { self.accept.set_gzip(enable); self } /// Sets whether to enable the Deflate encoding. #[cfg(feature = "compression-deflate")] pub fn deflate(mut self, enable: bool) -> Self { self.accept.set_deflate(enable); self } /// Sets whether to enable the Brotli encoding. #[cfg(feature = "compression-br")] pub fn br(mut self, enable: bool) -> Self { self.accept.set_br(enable); self } /// Sets whether to enable the Zstd encoding. #[cfg(feature = "compression-zstd")] pub fn zstd(mut self, enable: bool) -> Self { self.accept.set_zstd(enable); self } /// Sets the compression quality. pub fn quality(mut self, quality: CompressionLevel) -> Self { self.quality = quality; self } /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. pub fn no_gzip(mut self) -> Self { self.accept.set_gzip(false); self } /// Disables the Deflate encoding. /// /// This method is available even if the `deflate` crate feature is disabled. pub fn no_deflate(mut self) -> Self { self.accept.set_deflate(false); self } /// Disables the Brotli encoding. 
/// /// This method is available even if the `br` crate feature is disabled. pub fn no_br(mut self) -> Self { self.accept.set_br(false); self } /// Disables the Zstd encoding. /// /// This method is available even if the `zstd` crate feature is disabled. pub fn no_zstd(mut self) -> Self { self.accept.set_zstd(false); self } /// Replace the current compression predicate. /// /// Predicates are used to determine whether a response should be compressed or not. /// /// The default predicate is [`DefaultPredicate`]. See its documentation for more /// details on which responses it wont compress. /// /// # Changing the compression predicate /// /// ``` /// use tower_http::compression::{ /// Compression, /// predicate::{Predicate, NotForContentType, DefaultPredicate}, /// }; /// use tower::util::service_fn; /// /// // Placeholder service_fn /// let service = service_fn(|_: ()| async { /// Ok::<_, std::io::Error>(http::Response::new(())) /// }); /// /// // build our custom compression predicate /// // its recommended to still include `DefaultPredicate` as part of /// // custom predicates /// let predicate = DefaultPredicate::new() /// // don't compress responses who's `content-type` starts with `application/json` /// .and(NotForContentType::new("application/json")); /// /// let service = Compression::new(service).compress_when(predicate); /// ``` /// /// See [`predicate`](super::predicate) for more utilities for building compression predicates. /// /// Responses that are already compressed (ie have a `content-encoding` header) will _never_ be /// recompressed, regardless what they predicate says. pub fn compress_when(self, predicate: C) -> Compression where C: Predicate, { Compression { inner: self.inner, accept: self.accept, predicate, quality: self.quality, } } } impl Service> for Compression where S: Service, Response = Response>, ResBody: Body, P: Predicate, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let encoding = Encoding::from_headers(req.headers(), self.accept); ResponseFuture { inner: self.inner.call(req), encoding, predicate: self.predicate.clone(), quality: self.quality, } } } tower-http-0.4.4/src/compression_utils.rs000064400000000000000000000260121046102023000166650ustar 00000000000000//! Types used by compression and decompression middleware. 
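//!
//! Everything in this module except [`CompressionLevel`] is `pub(crate)` plumbing shared by
//! the compression and decompression middleware: the `AcceptEncoding` toggles, the
//! `WrapBody`/`BodyIntoStream` body adapters, and the sentinel-error machinery used to
//! surface the underlying body's error through the `io::Error`-based `StreamReader`.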
use crate::{content_encoding::SupportedEncodings, BoxError}; use bytes::{Bytes, BytesMut}; use futures_core::Stream; use futures_util::ready; use http::HeaderValue; use http_body::Body; use pin_project_lite::pin_project; use std::{ io, pin::Pin, task::{Context, Poll}, }; use tokio::io::AsyncRead; use tokio_util::io::{poll_read_buf, StreamReader}; #[derive(Debug, Clone, Copy)] pub(crate) struct AcceptEncoding { pub(crate) gzip: bool, pub(crate) deflate: bool, pub(crate) br: bool, pub(crate) zstd: bool, } impl AcceptEncoding { #[allow(dead_code)] pub(crate) fn to_header_value(self) -> Option { let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { (true, true, true, false) => "gzip,deflate,br", (true, true, false, false) => "gzip,deflate", (true, false, true, false) => "gzip,br", (true, false, false, false) => "gzip", (false, true, true, false) => "deflate,br", (false, true, false, false) => "deflate", (false, false, true, false) => "br", (true, true, true, true) => "zstd,gzip,deflate,br", (true, true, false, true) => "zstd,gzip,deflate", (true, false, true, true) => "zstd,gzip,br", (true, false, false, true) => "zstd,gzip", (false, true, true, true) => "zstd,deflate,br", (false, true, false, true) => "zstd,deflate", (false, false, true, true) => "zstd,br", (false, false, false, true) => "zstd", (false, false, false, false) => return None, }; Some(HeaderValue::from_static(accept)) } #[allow(dead_code)] pub(crate) fn set_gzip(&mut self, enable: bool) { self.gzip = enable; } #[allow(dead_code)] pub(crate) fn set_deflate(&mut self, enable: bool) { self.deflate = enable; } #[allow(dead_code)] pub(crate) fn set_br(&mut self, enable: bool) { self.br = enable; } #[allow(dead_code)] pub(crate) fn set_zstd(&mut self, enable: bool) { self.zstd = enable; } } impl SupportedEncodings for AcceptEncoding { #[allow(dead_code)] fn gzip(&self) -> bool { #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))] { self.gzip } #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))] { false } } #[allow(dead_code)] fn deflate(&self) -> bool { #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))] { self.deflate } #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))] { false } } #[allow(dead_code)] fn br(&self) -> bool { #[cfg(any(feature = "decompression-br", feature = "compression-br"))] { self.br } #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))] { false } } #[allow(dead_code)] fn zstd(&self) -> bool { #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))] { self.zstd } #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))] { false } } } impl Default for AcceptEncoding { fn default() -> Self { AcceptEncoding { gzip: true, deflate: true, br: true, zstd: true, } } } /// A `Body` that has been converted into an `AsyncRead`. pub(crate) type AsyncReadBody = StreamReader, ::Error>, ::Data>; /// Trait for applying some decorator to an `AsyncRead` pub(crate) trait DecorateAsyncRead { type Input: AsyncRead; type Output: AsyncRead; /// Apply the decorator fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output; /// Get a pinned mutable reference to the original input. /// /// This is necessary to implement `Body::poll_trailers`. fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>; } pin_project! 
{ /// `Body` that has been decorated by an `AsyncRead` pub(crate) struct WrapBody { #[pin] pub(crate) read: M::Output, } } impl WrapBody { #[allow(dead_code)] pub(crate) fn new(body: B, quality: CompressionLevel) -> Self where B: Body, M: DecorateAsyncRead>, { // convert `Body` into a `Stream` let stream = BodyIntoStream::new(body); // an adapter that converts the error type into `io::Error` while storing the actual error // `StreamReader` requires the error type is `io::Error` let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream); // convert `Stream` into an `AsyncRead` let read = StreamReader::new(stream); // apply decorator to `AsyncRead` yielding another `AsyncRead` let read = M::apply(read, quality); Self { read } } } impl Body for WrapBody where B: Body, B::Error: Into, M: DecorateAsyncRead>, { type Data = Bytes; type Error = BoxError; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let mut this = self.project(); let mut buf = BytesMut::new(); let read = match ready!(poll_read_buf(this.read.as_mut(), cx, &mut buf)) { Ok(read) => read, Err(err) => { let body_error: Option = M::get_pin_mut(this.read) .get_pin_mut() .project() .error .take(); if let Some(body_error) = body_error { return Poll::Ready(Some(Err(body_error.into()))); } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) { // SENTINEL_ERROR_CODE only gets used when storing an underlying body error unreachable!() } else { return Poll::Ready(Some(Err(err.into()))); } } }; if read == 0 { Poll::Ready(None) } else { Poll::Ready(Some(Ok(buf.freeze()))) } } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { let this = self.project(); let body = M::get_pin_mut(this.read) .get_pin_mut() .get_pin_mut() .get_pin_mut(); body.poll_trailers(cx).map_err(Into::into) } } pin_project! { // When https://github.com/hyperium/http-body/pull/36 is merged we can remove this pub(crate) struct BodyIntoStream { #[pin] body: B, } } #[allow(dead_code)] impl BodyIntoStream { pub(crate) fn new(body: B) -> Self { Self { body } } /// Get a reference to the inner body pub(crate) fn get_ref(&self) -> &B { &self.body } /// Get a mutable reference to the inner body pub(crate) fn get_mut(&mut self) -> &mut B { &mut self.body } /// Get a pinned mutable reference to the inner body pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> { self.project().body } /// Consume `self`, returning the inner body pub(crate) fn into_inner(self) -> B { self.body } } impl Stream for BodyIntoStream where B: Body, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().body.poll_data(cx) } } pin_project! 
{ pub(crate) struct StreamErrorIntoIoError { #[pin] inner: S, error: Option, } } impl StreamErrorIntoIoError { pub(crate) fn new(inner: S) -> Self { Self { inner, error: None } } /// Get a reference to the inner body pub(crate) fn get_ref(&self) -> &S { &self.inner } /// Get a mutable reference to the inner inner pub(crate) fn get_mut(&mut self) -> &mut S { &mut self.inner } /// Get a pinned mutable reference to the inner inner pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { self.project().inner } /// Consume `self`, returning the inner inner pub(crate) fn into_inner(self) -> S { self.inner } } impl Stream for StreamErrorIntoIoError where S: Stream>, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); match ready!(this.inner.poll_next(cx)) { None => Poll::Ready(None), Some(Ok(value)) => Poll::Ready(Some(Ok(value))), Some(Err(err)) => { *this.error = Some(err); Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE)))) } } } } pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418; /// Level of compression data should be compressed with. #[non_exhaustive] #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum CompressionLevel { /// Fastest quality of compression, usually produces bigger size. Fastest, /// Best quality of compression, usually produces the smallest size. Best, /// Default quality of compression defined by the selected compression algorithm. Default, /// Precise quality based on the underlying compression algorithms' /// qualities. The interpretation of this depends on the algorithm chosen /// and the specific implementation backing it. /// Qualities are implicitly clamped to the algorithm's maximum. Precise(u32), } impl Default for CompressionLevel { fn default() -> Self { CompressionLevel::Default } } #[cfg(any( feature = "compression-br", feature = "compression-gzip", feature = "compression-deflate", feature = "compression-zstd" ))] use async_compression::Level as AsyncCompressionLevel; #[cfg(any( feature = "compression-br", feature = "compression-gzip", feature = "compression-deflate", feature = "compression-zstd" ))] impl CompressionLevel { pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel { use std::convert::TryInto; match self { CompressionLevel::Fastest => AsyncCompressionLevel::Fastest, CompressionLevel::Best => AsyncCompressionLevel::Best, CompressionLevel::Default => AsyncCompressionLevel::Default, CompressionLevel::Precise(quality) => { AsyncCompressionLevel::Precise(quality.try_into().unwrap_or(i32::MAX)) } } } } tower-http-0.4.4/src/content_encoding.rs000064400000000000000000000476621046102023000164420ustar 00000000000000pub(crate) trait SupportedEncodings: Copy { fn gzip(&self) -> bool; fn deflate(&self) -> bool; fn br(&self) -> bool; fn zstd(&self) -> bool; } // This enum's variants are ordered from least to most preferred. 
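// The derived `Ord` is what breaks q-value ties in `Encoding::preferred_encoding` below
// (it takes the max of `(qvalue, encoding)`), e.g. `Accept-Encoding: gzip,br` resolves to
// brotli because both entries carry the implicit `q=1`.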
#[derive(Copy, Clone, Debug, Ord, PartialOrd, PartialEq, Eq)] pub(crate) enum Encoding { #[allow(dead_code)] Identity, #[cfg(any(feature = "fs", feature = "compression-deflate"))] Deflate, #[cfg(any(feature = "fs", feature = "compression-gzip"))] Gzip, #[cfg(any(feature = "fs", feature = "compression-br"))] Brotli, #[cfg(any(feature = "fs", feature = "compression-zstd"))] Zstd, } impl Encoding { #[allow(dead_code)] fn to_str(self) -> &'static str { match self { #[cfg(any(feature = "fs", feature = "compression-gzip"))] Encoding::Gzip => "gzip", #[cfg(any(feature = "fs", feature = "compression-deflate"))] Encoding::Deflate => "deflate", #[cfg(any(feature = "fs", feature = "compression-br"))] Encoding::Brotli => "br", #[cfg(any(feature = "fs", feature = "compression-zstd"))] Encoding::Zstd => "zstd", Encoding::Identity => "identity", } } #[cfg(feature = "fs")] pub(crate) fn to_file_extension(self) -> Option<&'static std::ffi::OsStr> { match self { Encoding::Gzip => Some(std::ffi::OsStr::new(".gz")), Encoding::Deflate => Some(std::ffi::OsStr::new(".zz")), Encoding::Brotli => Some(std::ffi::OsStr::new(".br")), Encoding::Zstd => Some(std::ffi::OsStr::new(".zst")), Encoding::Identity => None, } } #[allow(dead_code)] pub(crate) fn into_header_value(self) -> http::HeaderValue { http::HeaderValue::from_static(self.to_str()) } #[cfg(any( feature = "compression-gzip", feature = "compression-br", feature = "compression-deflate", feature = "compression-zstd", feature = "fs", ))] fn parse(s: &str, _supported_encoding: impl SupportedEncodings) -> Option { #[cfg(any(feature = "fs", feature = "compression-gzip"))] if s.eq_ignore_ascii_case("gzip") && _supported_encoding.gzip() { return Some(Encoding::Gzip); } #[cfg(any(feature = "fs", feature = "compression-deflate"))] if s.eq_ignore_ascii_case("deflate") && _supported_encoding.deflate() { return Some(Encoding::Deflate); } #[cfg(any(feature = "fs", feature = "compression-br"))] if s.eq_ignore_ascii_case("br") && _supported_encoding.br() { return Some(Encoding::Brotli); } #[cfg(any(feature = "fs", feature = "compression-zstd"))] if s.eq_ignore_ascii_case("zstd") && _supported_encoding.zstd() { return Some(Encoding::Zstd); } if s.eq_ignore_ascii_case("identity") { return Some(Encoding::Identity); } None } #[cfg(any( feature = "compression-gzip", feature = "compression-br", feature = "compression-zstd", feature = "compression-deflate", ))] // based on https://github.com/http-rs/accept-encoding pub(crate) fn from_headers( headers: &http::HeaderMap, supported_encoding: impl SupportedEncodings, ) -> Self { Encoding::preferred_encoding(&encodings(headers, supported_encoding)) .unwrap_or(Encoding::Identity) } #[cfg(any( feature = "compression-gzip", feature = "compression-br", feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] pub(crate) fn preferred_encoding(accepted_encodings: &[(Encoding, QValue)]) -> Option { accepted_encodings .iter() .filter(|(_, qvalue)| qvalue.0 > 0) .max_by_key(|(encoding, qvalue)| (qvalue, encoding)) .map(|(encoding, _)| *encoding) } } // Allowed q-values are numbers between 0 and 1 with at most 3 digits in the fractional part. They // are presented here as an unsigned integer between 0 and 1000. 
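// For example `q=0.75` parses to `QValue(750)` and a missing q-value defaults to
// `QValue(1000)`; values above 1.000, values with more than three fractional digits, and
// otherwise malformed values are rejected, dropping that encoding from consideration.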
#[cfg(any( feature = "compression-gzip", feature = "compression-br", feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub(crate) struct QValue(u16); #[cfg(any( feature = "compression-gzip", feature = "compression-br", feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] impl QValue { #[inline] fn one() -> Self { Self(1000) } // Parse a q-value as specified in RFC 7231 section 5.3.1. fn parse(s: &str) -> Option { let mut c = s.chars(); // Parse "q=" (case-insensitively). match c.next() { Some('q') | Some('Q') => (), _ => return None, }; match c.next() { Some('=') => (), _ => return None, }; // Parse leading digit. Since valid q-values are between 0.000 and 1.000, only "0" and "1" // are allowed. let mut value = match c.next() { Some('0') => 0, Some('1') => 1000, _ => return None, }; // Parse optional decimal point. match c.next() { Some('.') => (), None => return Some(Self(value)), _ => return None, }; // Parse optional fractional digits. The value of each digit is multiplied by `factor`. // Since the q-value is represented as an integer between 0 and 1000, `factor` is `100` for // the first digit, `10` for the next, and `1` for the digit after that. let mut factor = 100; loop { match c.next() { Some(n @ '0'..='9') => { // If `factor` is less than `1`, three digits have already been parsed. A // q-value having more than 3 fractional digits is invalid. if factor < 1 { return None; } // Add the digit's value multiplied by `factor` to `value`. value += factor * (n as u16 - '0' as u16); } None => { // No more characters to parse. Check that the value representing the q-value is // in the valid range. return if value <= 1000 { Some(Self(value)) } else { None }; } _ => return None, }; factor /= 10; } } } #[cfg(any( feature = "compression-gzip", feature = "compression-br", feature = "compression-zstd", feature = "compression-deflate", feature = "fs", ))] // based on https://github.com/http-rs/accept-encoding pub(crate) fn encodings( headers: &http::HeaderMap, supported_encoding: impl SupportedEncodings, ) -> Vec<(Encoding, QValue)> { headers .get_all(http::header::ACCEPT_ENCODING) .iter() .filter_map(|hval| hval.to_str().ok()) .flat_map(|s| s.split(',')) .filter_map(|v| { let mut v = v.splitn(2, ';'); let encoding = match Encoding::parse(v.next().unwrap().trim(), supported_encoding) { Some(encoding) => encoding, None => return None, // ignore unknown encodings }; let qval = if let Some(qval) = v.next() { QValue::parse(qval.trim())? 
} else { QValue::one() }; Some((encoding, qval)) }) .collect::>() } #[cfg(all( test, feature = "compression-gzip", feature = "compression-deflate", feature = "compression-br", feature = "compression-zstd", ))] mod tests { use super::*; #[derive(Copy, Clone, Default)] struct SupportedEncodingsAll; impl SupportedEncodings for SupportedEncodingsAll { fn gzip(&self) -> bool { true } fn deflate(&self) -> bool { true } fn br(&self) -> bool { true } fn zstd(&self) -> bool { true } } #[test] fn no_accept_encoding_header() { let encoding = Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); } #[test] fn accept_encoding_header_single_encoding() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Gzip, encoding); } #[test] fn accept_encoding_header_two_encodings() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_three_encodings() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,deflate,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_two_encodings_with_one_qvalue() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_three_encodings_with_one_qvalue() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate,br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn two_accept_encoding_headers_with_one_qvalue() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5"), ); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn two_accept_encoding_headers_three_encodings_with_one_qvalue() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate"), ); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn three_accept_encoding_headers_with_one_qvalue() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5"), ); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("deflate"), ); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); let encoding = 
Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_two_encodings_with_two_qvalues() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;q=0.8"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.995,br;q=0.999"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_three_encodings_with_three_qvalues() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate;q=0.6,br;q=0.8"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,deflate;q=0.6,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.6,deflate;q=0.8,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Deflate, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.995,deflate;q=0.997,br;q=0.999"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_invalid_encdoing() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("invalid,gzip"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Gzip, encoding); } #[test] fn accept_encoding_header_with_qvalue_zero() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0."), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0,br;q=0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_with_uppercase_letters() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, 
http::HeaderValue::from_static("gZiP"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;Q=0.8"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_with_allowed_spaces() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static(" gzip\t; q=0.5 ,\tbr ;\tq=0.8\t"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Brotli, encoding); } #[test] fn accept_encoding_header_with_invalid_spaces() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q =0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q= 0.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); } #[test] fn accept_encoding_header_with_invalid_quvalues() { let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=-0.1"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=00.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5000"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=.5"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.01"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); headers.append( http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.001"), ); let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); assert_eq!(Encoding::Identity, encoding); } } tower-http-0.4.4/src/cors/allow_credentials.rs000064400000000000000000000055421046102023000175520ustar 00000000000000use std::{fmt, sync::Arc}; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, }; /// Holds configuration for how to set the [`Access-Control-Allow-Credentials`][mdn] header. /// /// See [`CorsLayer::allow_credentials`] for more details. 
/// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials #[derive(Clone, Default)] #[must_use] pub struct AllowCredentials(AllowCredentialsInner); impl AllowCredentials { /// Allow credentials for all requests /// /// See [`CorsLayer::allow_credentials`] for more details. /// /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials pub fn yes() -> Self { Self(AllowCredentialsInner::Yes) } /// Allow credentials for some requests, based on a given predicate /// /// The first argument to the predicate is the request origin. /// /// See [`CorsLayer::allow_credentials`] for more details. /// /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials pub fn predicate(f: F) -> Self where F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, { Self(AllowCredentialsInner::Predicate(Arc::new(f))) } pub(super) fn is_true(&self) -> bool { matches!(&self.0, AllowCredentialsInner::Yes) } pub(super) fn to_header( &self, origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { #[allow(clippy::declare_interior_mutable_const)] const TRUE: HeaderValue = HeaderValue::from_static("true"); let allow_creds = match &self.0 { AllowCredentialsInner::Yes => true, AllowCredentialsInner::No => false, AllowCredentialsInner::Predicate(c) => c(origin?, parts), }; allow_creds.then(|| (header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) } } impl From for AllowCredentials { fn from(v: bool) -> Self { match v { true => Self(AllowCredentialsInner::Yes), false => Self(AllowCredentialsInner::No), } } } impl fmt::Debug for AllowCredentials { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { AllowCredentialsInner::Yes => f.debug_tuple("Yes").finish(), AllowCredentialsInner::No => f.debug_tuple("No").finish(), AllowCredentialsInner::Predicate(_) => f.debug_tuple("Predicate").finish(), } } } #[derive(Clone)] enum AllowCredentialsInner { Yes, No, Predicate( Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, ), } impl Default for AllowCredentialsInner { fn default() -> Self { Self::No } } tower-http-0.4.4/src/cors/allow_headers.rs000064400000000000000000000064131046102023000166660ustar 00000000000000use std::{array, fmt}; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, }; use super::{separated_by_commas, Any, WILDCARD}; /// Holds configuration for how to set the [`Access-Control-Allow-Headers`][mdn] header. /// /// See [`CorsLayer::allow_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers #[derive(Clone, Default)] #[must_use] pub struct AllowHeaders(AllowHeadersInner); impl AllowHeaders { /// Allow any headers by sending a wildcard (`*`) /// /// See [`CorsLayer::allow_headers`] for more details. /// /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers pub fn any() -> Self { Self(AllowHeadersInner::Const(Some(WILDCARD))) } /// Set multiple allowed headers /// /// See [`CorsLayer::allow_headers`] for more details. 
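///
/// A short sketch (the header names here are just placeholders):
///
/// ```
/// use http::header::{ACCEPT, AUTHORIZATION};
/// use tower_http::cors::AllowHeaders;
///
/// // Only these request headers will be listed in `Access-Control-Allow-Headers`.
/// let allow_headers = AllowHeaders::list([AUTHORIZATION, ACCEPT]);
/// ```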
/// /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers pub fn list(headers: I) -> Self where I: IntoIterator, { Self(AllowHeadersInner::Const(separated_by_commas( headers.into_iter().map(Into::into), ))) } /// Allow any headers, by mirroring the preflight [`Access-Control-Request-Headers`][mdn] /// header. /// /// See [`CorsLayer::allow_headers`] for more details. /// /// [`CorsLayer::allow_headers`]: super::CorsLayer::allow_headers /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers pub fn mirror_request() -> Self { Self(AllowHeadersInner::MirrorRequest) } #[allow(clippy::borrow_interior_mutable_const)] pub(super) fn is_wildcard(&self) -> bool { matches!(&self.0, AllowHeadersInner::Const(Some(v)) if v == WILDCARD) } pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { let allow_headers = match &self.0 { AllowHeadersInner::Const(v) => v.clone()?, AllowHeadersInner::MirrorRequest => parts .headers .get(header::ACCESS_CONTROL_REQUEST_HEADERS)? .clone(), }; Some((header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers)) } } impl fmt::Debug for AllowHeaders { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0 { AllowHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), AllowHeadersInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), } } } impl From for AllowHeaders { fn from(_: Any) -> Self { Self::any() } } impl From<[HeaderName; N]> for AllowHeaders { fn from(arr: [HeaderName; N]) -> Self { #[allow(deprecated)] // Can be changed when MSRV >= 1.53 Self::list(array::IntoIter::new(arr)) } } impl From> for AllowHeaders { fn from(vec: Vec) -> Self { Self::list(vec) } } #[derive(Clone)] enum AllowHeadersInner { Const(Option), MirrorRequest, } impl Default for AllowHeadersInner { fn default() -> Self { Self::Const(None) } } tower-http-0.4.4/src/cors/allow_methods.rs000064400000000000000000000074201046102023000167150ustar 00000000000000use std::{array, fmt}; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, Method, }; use super::{separated_by_commas, Any, WILDCARD}; /// Holds configuration for how to set the [`Access-Control-Allow-Methods`][mdn] header. /// /// See [`CorsLayer::allow_methods`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods #[derive(Clone, Default)] #[must_use] pub struct AllowMethods(AllowMethodsInner); impl AllowMethods { /// Allow any method by sending a wildcard (`*`) /// /// See [`CorsLayer::allow_methods`] for more details. /// /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods pub fn any() -> Self { Self(AllowMethodsInner::Const(Some(WILDCARD))) } /// Set a single allowed method /// /// See [`CorsLayer::allow_methods`] for more details. /// /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods pub fn exact(method: Method) -> Self { Self(AllowMethodsInner::Const(Some( HeaderValue::from_str(method.as_str()).unwrap(), ))) } /// Set multiple allowed methods /// /// See [`CorsLayer::allow_methods`] for more details. 
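///
/// A short sketch (the chosen methods are just an example):
///
/// ```
/// use http::Method;
/// use tower_http::cors::AllowMethods;
///
/// // Advertise `GET` and `POST` in `Access-Control-Allow-Methods`.
/// let allow_methods = AllowMethods::list([Method::GET, Method::POST]);
/// ```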
/// /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods pub fn list(methods: I) -> Self where I: IntoIterator, { Self(AllowMethodsInner::Const(separated_by_commas( methods .into_iter() .map(|m| HeaderValue::from_str(m.as_str()).unwrap()), ))) } /// Allow any method, by mirroring the preflight [`Access-Control-Request-Method`][mdn] /// header. /// /// See [`CorsLayer::allow_methods`] for more details. /// /// [`CorsLayer::allow_methods`]: super::CorsLayer::allow_methods /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Method pub fn mirror_request() -> Self { Self(AllowMethodsInner::MirrorRequest) } #[allow(clippy::borrow_interior_mutable_const)] pub(super) fn is_wildcard(&self) -> bool { matches!(&self.0, AllowMethodsInner::Const(Some(v)) if v == WILDCARD) } pub(super) fn to_header(&self, parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { let allow_methods = match &self.0 { AllowMethodsInner::Const(v) => v.clone()?, AllowMethodsInner::MirrorRequest => parts .headers .get(header::ACCESS_CONTROL_REQUEST_METHOD)? .clone(), }; Some((header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods)) } } impl fmt::Debug for AllowMethods { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0 { AllowMethodsInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), AllowMethodsInner::MirrorRequest => f.debug_tuple("MirrorRequest").finish(), } } } impl From for AllowMethods { fn from(_: Any) -> Self { Self::any() } } impl From for AllowMethods { fn from(method: Method) -> Self { Self::exact(method) } } impl From<[Method; N]> for AllowMethods { fn from(arr: [Method; N]) -> Self { #[allow(deprecated)] // Can be changed when MSRV >= 1.53 Self::list(array::IntoIter::new(arr)) } } impl From> for AllowMethods { fn from(vec: Vec) -> Self { Self::list(vec) } } #[derive(Clone)] enum AllowMethodsInner { Const(Option), MirrorRequest, } impl Default for AllowMethodsInner { fn default() -> Self { Self::Const(None) } } tower-http-0.4.4/src/cors/allow_origin.rs000064400000000000000000000106701046102023000165420ustar 00000000000000use std::{array, fmt, sync::Arc}; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, }; use super::{Any, WILDCARD}; /// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header. /// /// See [`CorsLayer::allow_origin`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin #[derive(Clone, Default)] #[must_use] pub struct AllowOrigin(OriginInner); impl AllowOrigin { /// Allow any origin by sending a wildcard (`*`) /// /// See [`CorsLayer::allow_origin`] for more details. /// /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin pub fn any() -> Self { Self(OriginInner::Const(WILDCARD)) } /// Set a single allowed origin /// /// See [`CorsLayer::allow_origin`] for more details. /// /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin pub fn exact(origin: HeaderValue) -> Self { Self(OriginInner::Const(origin)) } /// Set multiple allowed origins /// /// See [`CorsLayer::allow_origin`] for more details. /// /// # Panics /// /// If the iterator contains a wildcard (`*`). 
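///
/// # Example
///
/// A short sketch (the origins are placeholders):
///
/// ```
/// use http::HeaderValue;
/// use tower_http::cors::AllowOrigin;
///
/// // Echo back the request origin only if it matches one of these two values.
/// let allow_origin = AllowOrigin::list([
///     HeaderValue::from_static("http://example.com"),
///     HeaderValue::from_static("http://api.example.com"),
/// ]);
/// ```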
/// /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin #[allow(clippy::borrow_interior_mutable_const)] pub fn list(origins: I) -> Self where I: IntoIterator, { let origins = origins.into_iter().collect::>(); if origins.iter().any(|o| o == WILDCARD) { panic!("Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. Use `AllowOrigin::any()` instead"); } else { Self(OriginInner::List(origins)) } } /// Set the allowed origins from a predicate /// /// See [`CorsLayer::allow_origin`] for more details. /// /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin pub fn predicate(f: F) -> Self where F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, { Self(OriginInner::Predicate(Arc::new(f))) } /// Allow any origin, by mirroring the request origin /// /// This is equivalent to /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate]. /// /// See [`CorsLayer::allow_origin`] for more details. /// /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin pub fn mirror_request() -> Self { Self::predicate(|_, _| true) } #[allow(clippy::borrow_interior_mutable_const)] pub(super) fn is_wildcard(&self) -> bool { matches!(&self.0, OriginInner::Const(v) if v == WILDCARD) } pub(super) fn to_header( &self, origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { let allow_origin = match &self.0 { OriginInner::Const(v) => v.clone(), OriginInner::List(l) => origin.filter(|o| l.contains(o))?.clone(), OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(), }; Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin)) } } impl fmt::Debug for AllowOrigin { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0 { OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(), OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(), } } } impl From for AllowOrigin { fn from(_: Any) -> Self { Self::any() } } impl From for AllowOrigin { fn from(val: HeaderValue) -> Self { Self::exact(val) } } impl From<[HeaderValue; N]> for AllowOrigin { fn from(arr: [HeaderValue; N]) -> Self { #[allow(deprecated)] // Can be changed when MSRV >= 1.53 Self::list(array::IntoIter::new(arr)) } } impl From> for AllowOrigin { fn from(vec: Vec) -> Self { Self::list(vec) } } #[derive(Clone)] enum OriginInner { Const(HeaderValue), List(Vec), Predicate( Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, ), } impl Default for OriginInner { fn default() -> Self { Self::List(Vec::new()) } } tower-http-0.4.4/src/cors/allow_private_network.rs000064400000000000000000000151411046102023000204740ustar 00000000000000use std::{fmt, sync::Arc}; use http::{ header::{HeaderName, HeaderValue}, request::Parts as RequestParts, }; /// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header. /// /// See [`CorsLayer::allow_private_network`] for more details. /// /// [wicg]: https://wicg.github.io/private-network-access/ /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network #[derive(Clone, Default)] #[must_use] pub struct AllowPrivateNetwork(AllowPrivateNetworkInner); impl AllowPrivateNetwork { /// Allow requests via a more private network than the one used to access the origin /// /// See [`CorsLayer::allow_private_network`] for more details. 
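///
/// A minimal sketch of enabling this unconditionally:
///
/// ```
/// use tower_http::cors::{AllowPrivateNetwork, CorsLayer};
///
/// // Answer `Access-Control-Request-Private-Network: true` preflights affirmatively.
/// let layer = CorsLayer::new().allow_private_network(AllowPrivateNetwork::yes());
/// ```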
/// /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network pub fn yes() -> Self { Self(AllowPrivateNetworkInner::Yes) } /// Allow requests via private network for some requests, based on a given predicate /// /// The first argument to the predicate is the request origin. /// /// See [`CorsLayer::allow_private_network`] for more details. /// /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network pub fn predicate(f: F) -> Self where F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, { Self(AllowPrivateNetworkInner::Predicate(Arc::new(f))) } pub(super) fn to_header( &self, origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { #[allow(clippy::declare_interior_mutable_const)] const REQUEST_PRIVATE_NETWORK: HeaderName = HeaderName::from_static("access-control-request-private-network"); #[allow(clippy::declare_interior_mutable_const)] const ALLOW_PRIVATE_NETWORK: HeaderName = HeaderName::from_static("access-control-allow-private-network"); const TRUE: HeaderValue = HeaderValue::from_static("true"); // Cheapest fallback: allow_private_network hasn't been set if let AllowPrivateNetworkInner::No = &self.0 { return None; } // Access-Control-Allow-Private-Network is only relevant if the request // has the Access-Control-Request-Private-Network header set, else skip if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) { return None; } let allow_private_network = match &self.0 { AllowPrivateNetworkInner::Yes => true, AllowPrivateNetworkInner::No => false, // unreachable, but not harmful AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts), }; allow_private_network.then(|| (ALLOW_PRIVATE_NETWORK, TRUE)) } } impl From for AllowPrivateNetwork { fn from(v: bool) -> Self { match v { true => Self(AllowPrivateNetworkInner::Yes), false => Self(AllowPrivateNetworkInner::No), } } } impl fmt::Debug for AllowPrivateNetwork { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0 { AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(), AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(), AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(), } } } #[derive(Clone)] enum AllowPrivateNetworkInner { Yes, No, Predicate( Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, ), } impl Default for AllowPrivateNetworkInner { fn default() -> Self { Self::No } } #[cfg(test)] mod tests { use super::AllowPrivateNetwork; use crate::cors::CorsLayer; use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response}; use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; use tower_service::Service; const REQUEST_PRIVATE_NETWORK: HeaderName = HeaderName::from_static("access-control-request-private-network"); const ALLOW_PRIVATE_NETWORK: HeaderName = HeaderName::from_static("access-control-allow-private-network"); const TRUE: HeaderValue = HeaderValue::from_static("true"); #[tokio::test] async fn cors_private_network_header_is_added_correctly() { let mut service = ServiceBuilder::new() .layer(CorsLayer::new().allow_private_network(true)) .service_fn(echo); let req = Request::builder() .header(REQUEST_PRIVATE_NETWORK, TRUE) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); let req = Request::builder().body(Body::empty()).unwrap(); let res = 
service.ready().await.unwrap().call(req).await.unwrap(); assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); } #[tokio::test] async fn cors_private_network_header_is_added_correctly_with_predicate() { let allow_private_network = AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| { parts.uri.path() == "/allow-private" && origin == "localhost" }); let mut service = ServiceBuilder::new() .layer(CorsLayer::new().allow_private_network(allow_private_network)) .service_fn(echo); let req = Request::builder() .header(ORIGIN, "localhost") .header(REQUEST_PRIVATE_NETWORK, TRUE) .uri("/allow-private") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); let req = Request::builder() .header(ORIGIN, "localhost") .header(REQUEST_PRIVATE_NETWORK, TRUE) .uri("/other") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(req).await.unwrap(); assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); let req = Request::builder() .header(ORIGIN, "not-localhost") .header(REQUEST_PRIVATE_NETWORK, TRUE) .uri("/allow-private") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(req).await.unwrap(); assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } } tower-http-0.4.4/src/cors/expose_headers.rs000064400000000000000000000051441046102023000170530ustar 00000000000000use std::{array, fmt}; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, }; use super::{separated_by_commas, Any, WILDCARD}; /// Holds configuration for how to set the [`Access-Control-Expose-Headers`][mdn] header. /// /// See [`CorsLayer::expose_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers #[derive(Clone, Default)] #[must_use] pub struct ExposeHeaders(ExposeHeadersInner); impl ExposeHeaders { /// Expose any / all headers by sending a wildcard (`*`) /// /// See [`CorsLayer::expose_headers`] for more details. /// /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers pub fn any() -> Self { Self(ExposeHeadersInner::Const(Some(WILDCARD))) } /// Set multiple exposed header names /// /// See [`CorsLayer::expose_headers`] for more details. 
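///
/// A short sketch (the header name is just an example):
///
/// ```
/// use http::header::CONTENT_ENCODING;
/// use tower_http::cors::ExposeHeaders;
///
/// // Let browser scripts read `Content-Encoding` from cross-origin responses.
/// let expose_headers = ExposeHeaders::list([CONTENT_ENCODING]);
/// ```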
/// /// [`CorsLayer::expose_headers`]: super::CorsLayer::expose_headers pub fn list(headers: I) -> Self where I: IntoIterator, { Self(ExposeHeadersInner::Const(separated_by_commas( headers.into_iter().map(Into::into), ))) } #[allow(clippy::borrow_interior_mutable_const)] pub(super) fn is_wildcard(&self) -> bool { matches!(&self.0, ExposeHeadersInner::Const(Some(v)) if v == WILDCARD) } pub(super) fn to_header(&self, _parts: &RequestParts) -> Option<(HeaderName, HeaderValue)> { let expose_headers = match &self.0 { ExposeHeadersInner::Const(v) => v.clone()?, }; Some((header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers)) } } impl fmt::Debug for ExposeHeaders { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0 { ExposeHeadersInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(), } } } impl From for ExposeHeaders { fn from(_: Any) -> Self { Self::any() } } impl From<[HeaderName; N]> for ExposeHeaders { fn from(arr: [HeaderName; N]) -> Self { #[allow(deprecated)] // Can be changed when MSRV >= 1.53 Self::list(array::IntoIter::new(arr)) } } impl From> for ExposeHeaders { fn from(vec: Vec) -> Self { Self::list(vec) } } #[derive(Clone)] enum ExposeHeadersInner { Const(Option), } impl Default for ExposeHeadersInner { fn default() -> Self { ExposeHeadersInner::Const(None) } } tower-http-0.4.4/src/cors/max_age.rs000064400000000000000000000040751046102023000154600ustar 00000000000000use std::{fmt, sync::Arc, time::Duration}; use http::{ header::{self, HeaderName, HeaderValue}, request::Parts as RequestParts, }; /// Holds configuration for how to set the [`Access-Control-Max-Age`][mdn] header. /// /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age #[derive(Clone, Default)] #[must_use] pub struct MaxAge(MaxAgeInner); impl MaxAge { /// Set a static max-age value /// /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. pub fn exact(max_age: Duration) -> Self { Self(MaxAgeInner::Exact(Some(max_age.as_secs().into()))) } /// Set the max-age based on the preflight request parts /// /// See [`CorsLayer::max_age`][super::CorsLayer::max_age] for more details. pub fn dynamic(f: F) -> Self where F: Fn(&HeaderValue, &RequestParts) -> Duration + Send + Sync + 'static, { Self(MaxAgeInner::Fn(Arc::new(f))) } pub(super) fn to_header( &self, origin: Option<&HeaderValue>, parts: &RequestParts, ) -> Option<(HeaderName, HeaderValue)> { let max_age = match &self.0 { MaxAgeInner::Exact(v) => v.clone()?, MaxAgeInner::Fn(c) => c(origin?, parts).as_secs().into(), }; Some((header::ACCESS_CONTROL_MAX_AGE, max_age)) } } impl fmt::Debug for MaxAge { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0 { MaxAgeInner::Exact(inner) => f.debug_tuple("Exact").field(inner).finish(), MaxAgeInner::Fn(_) => f.debug_tuple("Fn").finish(), } } } impl From for MaxAge { fn from(max_age: Duration) -> Self { Self::exact(max_age) } } #[derive(Clone)] enum MaxAgeInner { Exact(Option), Fn(Arc Fn(&'a HeaderValue, &'a RequestParts) -> Duration + Send + Sync + 'static>), } impl Default for MaxAgeInner { fn default() -> Self { Self::Exact(None) } } tower-http-0.4.4/src/cors/mod.rs000064400000000000000000000547311046102023000146420ustar 00000000000000//! Middleware which adds headers for [CORS][mdn]. //! //! # Example //! //! ``` //! use http::{Request, Response, Method, header}; //! use hyper::Body; //! use tower::{ServiceBuilder, ServiceExt, Service}; //! 
use tower_http::cors::{Any, CorsLayer}; //! use std::convert::Infallible; //! //! async fn handle(request: Request) -> Result, Infallible> { //! Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let cors = CorsLayer::new() //! // allow `GET` and `POST` when accessing the resource //! .allow_methods([Method::GET, Method::POST]) //! // allow requests from any origin //! .allow_origin(Any); //! //! let mut service = ServiceBuilder::new() //! .layer(cors) //! .service_fn(handle); //! //! let request = Request::builder() //! .header(header::ORIGIN, "https://example.com") //! .body(Body::empty()) //! .unwrap(); //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! //! assert_eq!( //! response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(), //! "*", //! ); //! # Ok(()) //! # } //! ``` //! //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #![allow(clippy::enum_variant_names)] use bytes::{BufMut, BytesMut}; use futures_core::ready; use http::{ header::{self, HeaderName}, HeaderMap, HeaderValue, Method, Request, Response, }; use pin_project_lite::pin_project; use std::{ array, future::Future, mem, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; mod allow_credentials; mod allow_headers; mod allow_methods; mod allow_origin; mod allow_private_network; mod expose_headers; mod max_age; mod vary; pub use self::{ allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, }; /// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. /// /// See the [module docs](crate::cors) for an example. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] #[must_use] pub struct CorsLayer { allow_credentials: AllowCredentials, allow_headers: AllowHeaders, allow_methods: AllowMethods, allow_origin: AllowOrigin, allow_private_network: AllowPrivateNetwork, expose_headers: ExposeHeaders, max_age: MaxAge, vary: Vary, } #[allow(clippy::declare_interior_mutable_const)] const WILDCARD: HeaderValue = HeaderValue::from_static("*"); impl CorsLayer { /// Create a new `CorsLayer`. /// /// No headers are sent by default. Use the builder methods to customize /// the behavior. /// /// You need to set at least an allowed origin for browsers to make /// successful cross-origin requests to your service. pub fn new() -> Self { Self { allow_credentials: Default::default(), allow_headers: Default::default(), allow_methods: Default::default(), allow_origin: Default::default(), allow_private_network: Default::default(), expose_headers: Default::default(), max_age: Default::default(), vary: Default::default(), } } /// A permissive configuration: /// /// - All request headers allowed. /// - All methods allowed. /// - All origins allowed. /// - All headers exposed. pub fn permissive() -> Self { Self::new() .allow_headers(Any) .allow_methods(Any) .allow_origin(Any) .expose_headers(Any) } /// A very permissive configuration: /// /// - **Credentials allowed.** /// - The method received in `Access-Control-Request-Method` is sent back /// as an allowed method. /// - The origin of the preflight request is sent back as an allowed origin. /// - The header names received in `Access-Control-Request-Headers` are sent /// back as allowed headers. 
/// - No headers are currently exposed, but this may change in the future. pub fn very_permissive() -> Self { Self::new() .allow_credentials(true) .allow_headers(AllowHeaders::mirror_request()) .allow_methods(AllowMethods::mirror_request()) .allow_origin(AllowOrigin::mirror_request()) } /// Set the [`Access-Control-Allow-Credentials`][mdn] header. /// /// ``` /// use tower_http::cors::CorsLayer; /// /// let layer = CorsLayer::new().allow_credentials(true); /// ``` /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials pub fn allow_credentials(mut self, allow_credentials: T) -> Self where T: Into, { self.allow_credentials = allow_credentials.into(); self } /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header. /// /// ``` /// use tower_http::cors::CorsLayer; /// use http::header::{AUTHORIZATION, ACCEPT}; /// /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]); /// ``` /// /// All headers can be allowed with /// /// ``` /// use tower_http::cors::{Any, CorsLayer}; /// /// let layer = CorsLayer::new().allow_headers(Any); /// ``` /// /// Note that multiple calls to this method will override any previous /// calls. /// /// Also note that `Access-Control-Allow-Headers` is required for requests that have /// `Access-Control-Request-Headers`. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers pub fn allow_headers(mut self, headers: T) -> Self where T: Into, { self.allow_headers = headers.into(); self } /// Set the value of the [`Access-Control-Max-Age`][mdn] header. /// /// ``` /// use std::time::Duration; /// use tower_http::cors::CorsLayer; /// /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10); /// ``` /// /// By default the header will not be set which disables caching and will /// require a preflight call for all requests. /// /// Note that each browser has a maximum internal value that takes /// precedence when the Access-Control-Max-Age is greater. For more details /// see [mdn]. /// /// If you need more flexibility, you can use supply a function which can /// dynamically decide the max-age based on the origin and other parts of /// each preflight request: /// /// ``` /// # struct MyServerConfig { cors_max_age: Duration } /// use std::time::Duration; /// /// use http::{request::Parts as RequestParts, HeaderValue}; /// use tower_http::cors::{CorsLayer, MaxAge}; /// /// let layer = CorsLayer::new().max_age(MaxAge::dynamic( /// |_origin: &HeaderValue, parts: &RequestParts| -> Duration { /// // Let's say you want to be able to reload your config at /// // runtime and have another middleware that always inserts /// // the current config into the request extensions /// let config = parts.extensions.get::().unwrap(); /// config.cors_max_age /// }, /// )); /// ``` /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age pub fn max_age(mut self, max_age: T) -> Self where T: Into, { self.max_age = max_age.into(); self } /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header. /// /// ``` /// use tower_http::cors::CorsLayer; /// use http::Method; /// /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]); /// ``` /// /// All methods can be allowed with /// /// ``` /// use tower_http::cors::{Any, CorsLayer}; /// /// let layer = CorsLayer::new().allow_methods(Any); /// ``` /// /// Note that multiple calls to this method will override any previous /// calls. 
/// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods pub fn allow_methods(mut self, methods: T) -> Self where T: Into, { self.allow_methods = methods.into(); self } /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header. /// /// ``` /// use http::HeaderValue; /// use tower_http::cors::CorsLayer; /// /// let layer = CorsLayer::new().allow_origin( /// "http://example.com".parse::().unwrap(), /// ); /// ``` /// /// Multiple origins can be allowed with /// /// ``` /// use tower_http::cors::CorsLayer; /// /// let origins = [ /// "http://example.com".parse().unwrap(), /// "http://api.example.com".parse().unwrap(), /// ]; /// /// let layer = CorsLayer::new().allow_origin(origins); /// ``` /// /// All origins can be allowed with /// /// ``` /// use tower_http::cors::{Any, CorsLayer}; /// /// let layer = CorsLayer::new().allow_origin(Any); /// ``` /// /// You can also use a closure /// /// ``` /// use tower_http::cors::{CorsLayer, AllowOrigin}; /// use http::{request::Parts as RequestParts, HeaderValue}; /// /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate( /// |origin: &HeaderValue, _request_parts: &RequestParts| { /// origin.as_bytes().ends_with(b".rust-lang.org") /// }, /// )); /// ``` /// /// Note that multiple calls to this method will override any previous /// calls. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin pub fn allow_origin(mut self, origin: T) -> Self where T: Into, { self.allow_origin = origin.into(); self } /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header. /// /// ``` /// use tower_http::cors::CorsLayer; /// use http::header::CONTENT_ENCODING; /// /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]); /// ``` /// /// All headers can be allowed with /// /// ``` /// use tower_http::cors::{Any, CorsLayer}; /// /// let layer = CorsLayer::new().expose_headers(Any); /// ``` /// /// Note that multiple calls to this method will override any previous /// calls. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers pub fn expose_headers(mut self, headers: T) -> Self where T: Into, { self.expose_headers = headers.into(); self } /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. /// /// ``` /// use tower_http::cors::CorsLayer; /// /// let layer = CorsLayer::new().allow_private_network(true); /// ``` /// /// [wicg]: https://wicg.github.io/private-network-access/ pub fn allow_private_network(mut self, allow_private_network: T) -> Self where T: Into, { self.allow_private_network = allow_private_network.into(); self } /// Set the value(s) of the [`Vary`][mdn] header. /// /// In contrast to the other headers, this one has a non-empty default of /// [`preflight_request_headers()`]. /// /// You only need to set this is you want to remove some of these defaults, /// or if you use a closure for one of the other headers and want to add a /// vary header accordingly. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary pub fn vary(mut self, headers: T) -> Self where T: Into, { self.vary = headers.into(); self } } /// Represents a wildcard value (`*`) used with some CORS headers such as /// [`CorsLayer::allow_methods`]. #[derive(Debug, Clone, Copy)] #[must_use] pub struct Any; /// Represents a wildcard value (`*`) used with some CORS headers such as /// [`CorsLayer::allow_methods`]. 
#[deprecated = "Use Any as a unit struct literal instead"] pub fn any() -> Any { Any } fn separated_by_commas(mut iter: I) -> Option where I: Iterator, { match iter.next() { Some(fst) => { let mut result = BytesMut::from(fst.as_bytes()); for val in iter { result.reserve(val.len() + 1); result.put_u8(b','); result.extend_from_slice(val.as_bytes()); } Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap()) } None => None, } } impl Default for CorsLayer { fn default() -> Self { Self::new() } } impl Layer for CorsLayer { type Service = Cors; fn layer(&self, inner: S) -> Self::Service { ensure_usable_cors_rules(self); Cors { inner, layer: self.clone(), } } } /// Middleware which adds headers for [CORS][mdn]. /// /// See the [module docs](crate::cors) for an example. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS #[derive(Debug, Clone)] #[must_use] pub struct Cors { inner: S, layer: CorsLayer, } impl Cors { /// Create a new `Cors`. /// /// See [`CorsLayer::new`] for more details. pub fn new(inner: S) -> Self { Self { inner, layer: CorsLayer::new(), } } /// A permissive configuration. /// /// See [`CorsLayer::permissive`] for more details. pub fn permissive(inner: S) -> Self { Self { inner, layer: CorsLayer::permissive(), } } /// A very permissive configuration. /// /// See [`CorsLayer::very_permissive`] for more details. pub fn very_permissive(inner: S) -> Self { Self { inner, layer: CorsLayer::very_permissive(), } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer() -> CorsLayer { CorsLayer::new() } /// Set the [`Access-Control-Allow-Credentials`][mdn] header. /// /// See [`CorsLayer::allow_credentials`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials pub fn allow_credentials(self, allow_credentials: T) -> Self where T: Into, { self.map_layer(|layer| layer.allow_credentials(allow_credentials)) } /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header. /// /// See [`CorsLayer::allow_headers`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers pub fn allow_headers(self, headers: T) -> Self where T: Into, { self.map_layer(|layer| layer.allow_headers(headers)) } /// Set the value of the [`Access-Control-Max-Age`][mdn] header. /// /// See [`CorsLayer::max_age`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age pub fn max_age(self, max_age: T) -> Self where T: Into, { self.map_layer(|layer| layer.max_age(max_age)) } /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header. /// /// See [`CorsLayer::allow_methods`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods pub fn allow_methods(self, methods: T) -> Self where T: Into, { self.map_layer(|layer| layer.allow_methods(methods)) } /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header. /// /// See [`CorsLayer::allow_origin`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin pub fn allow_origin(self, origin: T) -> Self where T: Into, { self.map_layer(|layer| layer.allow_origin(origin)) } /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header. /// /// See [`CorsLayer::expose_headers`] for more details. 
/// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers pub fn expose_headers(self, headers: T) -> Self where T: Into, { self.map_layer(|layer| layer.expose_headers(headers)) } /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. /// /// See [`CorsLayer::allow_private_network`] for more details. /// /// [wicg]: https://wicg.github.io/private-network-access/ pub fn allow_private_network(self, allow_private_network: T) -> Self where T: Into, { self.map_layer(|layer| layer.allow_private_network(allow_private_network)) } fn map_layer(mut self, f: F) -> Self where F: FnOnce(CorsLayer) -> CorsLayer, { self.layer = f(self.layer); self } } impl Service> for Cors where S: Service, Response = Response>, ResBody: Default, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { ensure_usable_cors_rules(&self.layer); self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let (parts, body) = req.into_parts(); let origin = parts.headers.get(&header::ORIGIN); let mut headers = HeaderMap::new(); // These headers are applied to both preflight and subsequent regular CORS requests: // https://fetch.spec.whatwg.org/#http-responses headers.extend(self.layer.allow_origin.to_header(origin, &parts)); headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); headers.extend(self.layer.allow_private_network.to_header(origin, &parts)); let mut vary_headers = self.layer.vary.values(); if let Some(first) = vary_headers.next() { let mut header = match headers.entry(header::VARY) { header::Entry::Occupied(_) => { unreachable!("no vary header inserted up to this point") } header::Entry::Vacant(v) => v.insert_entry(first), }; for val in vary_headers { header.append(val); } } // Return results immediately upon preflight request if parts.method == Method::OPTIONS { // These headers are applied only to preflight requests headers.extend(self.layer.allow_methods.to_header(&parts)); headers.extend(self.layer.allow_headers.to_header(&parts)); headers.extend(self.layer.max_age.to_header(origin, &parts)); ResponseFuture { inner: Kind::PreflightCall { headers }, } } else { // This header is applied only to non-preflight requests headers.extend(self.layer.expose_headers.to_header(&parts)); let req = Request::from_parts(parts, body); ResponseFuture { inner: Kind::CorsCall { future: self.inner.call(req), headers, }, } } } } pin_project! { /// Response future for [`Cors`]. pub struct ResponseFuture { #[pin] inner: Kind, } } pin_project! 
{ #[project = KindProj] enum Kind { CorsCall { #[pin] future: F, headers: HeaderMap, }, PreflightCall { headers: HeaderMap, }, } } impl Future for ResponseFuture where F: Future, E>>, B: Default, { type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().inner.project() { KindProj::CorsCall { future, headers } => { let mut response: Response = ready!(future.poll(cx))?; response.headers_mut().extend(headers.drain()); Poll::Ready(Ok(response)) } KindProj::PreflightCall { headers } => { let mut response = Response::new(B::default()); mem::swap(response.headers_mut(), headers); Poll::Ready(Ok(response)) } } } } fn ensure_usable_cors_rules(layer: &CorsLayer) { if layer.allow_credentials.is_true() { assert!( !layer.allow_headers.is_wildcard(), "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ with `Access-Control-Allow-Headers: *`" ); assert!( !layer.allow_methods.is_wildcard(), "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ with `Access-Control-Allow-Methods: *`" ); assert!( !layer.allow_origin.is_wildcard(), "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ with `Access-Control-Allow-Origin: *`" ); assert!( !layer.expose_headers.is_wildcard(), "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \ with `Access-Control-Expose-Headers: *`" ); } } /// Returns an iterator over the three request headers that may be involved in a CORS preflight request. /// /// This is the default set of header names returned in the `vary` header pub fn preflight_request_headers() -> impl Iterator { #[allow(deprecated)] // Can be changed when MSRV >= 1.53 array::IntoIter::new([ header::ORIGIN, header::ACCESS_CONTROL_REQUEST_METHOD, header::ACCESS_CONTROL_REQUEST_HEADERS, ]) } tower-http-0.4.4/src/cors/vary.rs000064400000000000000000000024541046102023000150370ustar 00000000000000use std::array; use http::{header::HeaderName, HeaderValue}; use super::preflight_request_headers; /// Holds configuration for how to set the [`Vary`][mdn] header. /// /// See [`CorsLayer::vary`] for more details. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary /// [`CorsLayer::vary`]: super::CorsLayer::vary #[derive(Clone, Debug)] pub struct Vary(Vec); impl Vary { /// Set the list of header names to return as vary header values /// /// See [`CorsLayer::vary`] for more details. 
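///
/// A short sketch that narrows the vary set down to a single header name
/// (purely illustrative):
///
/// ```
/// use http::header::ORIGIN;
/// use tower_http::cors::Vary;
///
/// // Only advertise `Origin` as a varying request header.
/// let vary = Vary::list([ORIGIN]);
/// ```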
/// /// [`CorsLayer::vary`]: super::CorsLayer::vary pub fn list(headers: I) -> Self where I: IntoIterator, { Self(headers.into_iter().map(Into::into).collect()) } pub(super) fn values(&self) -> impl Iterator + '_ { self.0.iter().cloned() } } impl Default for Vary { fn default() -> Self { Self::list(preflight_request_headers()) } } impl From<[HeaderName; N]> for Vary { fn from(arr: [HeaderName; N]) -> Self { #[allow(deprecated)] // Can be changed when MSRV >= 1.53 Self::list(array::IntoIter::new(arr)) } } impl From> for Vary { fn from(vec: Vec) -> Self { Self::list(vec) } } tower-http-0.4.4/src/decompression/body.rs000064400000000000000000000340171046102023000167170ustar 00000000000000#![allow(unused_imports)] use crate::compression_utils::CompressionLevel; use crate::{ compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody}, BoxError, }; #[cfg(feature = "decompression-br")] use async_compression::tokio::bufread::BrotliDecoder; #[cfg(feature = "decompression-gzip")] use async_compression::tokio::bufread::GzipDecoder; #[cfg(feature = "decompression-deflate")] use async_compression::tokio::bufread::ZlibDecoder; #[cfg(feature = "decompression-zstd")] use async_compression::tokio::bufread::ZstdDecoder; use bytes::{Buf, Bytes}; use futures_util::ready; use http::HeaderMap; use http_body::Body; use pin_project_lite::pin_project; use std::task::Context; use std::{io, marker::PhantomData, pin::Pin, task::Poll}; use tokio_util::io::StreamReader; pin_project! { /// Response body of [`RequestDecompression`] and [`Decompression`]. /// /// [`RequestDecompression`]: super::RequestDecompression /// [`Decompression`]: super::Decompression pub struct DecompressionBody where B: Body { #[pin] pub(crate) inner: BodyInner, } } impl Default for DecompressionBody where B: Body + Default, { fn default() -> Self { Self { inner: BodyInner::Identity { inner: B::default(), }, } } } impl DecompressionBody where B: Body, { pub(crate) fn new(inner: BodyInner) -> Self { Self { inner } } /// Get a reference to the inner body pub fn get_ref(&self) -> &B { match &self.inner { #[cfg(feature = "decompression-gzip")] BodyInner::Gzip { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "decompression-deflate")] BodyInner::Deflate { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "decompression-br")] BodyInner::Brotli { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), #[cfg(feature = "decompression-zstd")] BodyInner::Zstd { inner } => inner.read.get_ref().get_ref().get_ref().get_ref(), BodyInner::Identity { inner } => inner, // FIXME: Remove once possible; see https://github.com/rust-lang/rust/issues/51085 #[cfg(not(feature = "decompression-gzip"))] BodyInner::Gzip { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-deflate"))] BodyInner::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInner::Brotli { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-zstd"))] BodyInner::Zstd { inner } => match inner.0 {}, } } /// Get a mutable reference to the inner body pub fn get_mut(&mut self) -> &mut B { match &mut self.inner { #[cfg(feature = "decompression-gzip")] BodyInner::Gzip { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "decompression-deflate")] BodyInner::Deflate { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "decompression-br")] BodyInner::Brotli { inner } => 
inner.read.get_mut().get_mut().get_mut().get_mut(), #[cfg(feature = "decompression-zstd")] BodyInner::Zstd { inner } => inner.read.get_mut().get_mut().get_mut().get_mut(), BodyInner::Identity { inner } => inner, #[cfg(not(feature = "decompression-gzip"))] BodyInner::Gzip { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-deflate"))] BodyInner::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInner::Brotli { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-zstd"))] BodyInner::Zstd { inner } => match inner.0 {}, } } /// Get a pinned mutable reference to the inner body pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> { match self.project().inner.project() { #[cfg(feature = "decompression-gzip")] BodyInnerProj::Gzip { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), #[cfg(feature = "decompression-deflate")] BodyInnerProj::Deflate { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), #[cfg(feature = "decompression-br")] BodyInnerProj::Brotli { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), #[cfg(feature = "decompression-zstd")] BodyInnerProj::Zstd { inner } => inner .project() .read .get_pin_mut() .get_pin_mut() .get_pin_mut() .get_pin_mut(), BodyInnerProj::Identity { inner } => inner, #[cfg(not(feature = "decompression-gzip"))] BodyInnerProj::Gzip { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-deflate"))] BodyInnerProj::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInnerProj::Brotli { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-zstd"))] BodyInnerProj::Zstd { inner } => match inner.0 {}, } } /// Consume `self`, returning the inner body pub fn into_inner(self) -> B { match self.inner { #[cfg(feature = "decompression-gzip")] BodyInner::Gzip { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), #[cfg(feature = "decompression-deflate")] BodyInner::Deflate { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), #[cfg(feature = "decompression-br")] BodyInner::Brotli { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), #[cfg(feature = "decompression-zstd")] BodyInner::Zstd { inner } => inner .read .into_inner() .into_inner() .into_inner() .into_inner(), BodyInner::Identity { inner } => inner, #[cfg(not(feature = "decompression-gzip"))] BodyInner::Gzip { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-deflate"))] BodyInner::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInner::Brotli { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-zstd"))] BodyInner::Zstd { inner } => match inner.0 {}, } } } #[cfg(any( not(feature = "decompression-gzip"), not(feature = "decompression-deflate"), not(feature = "decompression-br"), not(feature = "decompression-zstd") ))] pub(crate) enum Never {} #[cfg(feature = "decompression-gzip")] type GzipBody = WrapBody>; #[cfg(not(feature = "decompression-gzip"))] type GzipBody = (Never, PhantomData); #[cfg(feature = "decompression-deflate")] type DeflateBody = WrapBody>; #[cfg(not(feature = "decompression-deflate"))] type DeflateBody = (Never, PhantomData); #[cfg(feature = "decompression-br")] type BrotliBody = WrapBody>; #[cfg(not(feature = "decompression-br"))] type BrotliBody = (Never, PhantomData); #[cfg(feature = "decompression-zstd")] type 
ZstdBody = WrapBody>; #[cfg(not(feature = "decompression-zstd"))] type ZstdBody = (Never, PhantomData); pin_project! { #[project = BodyInnerProj] pub(crate) enum BodyInner where B: Body, { Gzip { #[pin] inner: GzipBody, }, Deflate { #[pin] inner: DeflateBody, }, Brotli { #[pin] inner: BrotliBody, }, Zstd { #[pin] inner: ZstdBody, }, Identity { #[pin] inner: B, }, } } impl BodyInner { #[cfg(feature = "decompression-gzip")] pub(crate) fn gzip(inner: WrapBody>) -> Self { Self::Gzip { inner } } #[cfg(feature = "decompression-deflate")] pub(crate) fn deflate(inner: WrapBody>) -> Self { Self::Deflate { inner } } #[cfg(feature = "decompression-br")] pub(crate) fn brotli(inner: WrapBody>) -> Self { Self::Brotli { inner } } #[cfg(feature = "decompression-zstd")] pub(crate) fn zstd(inner: WrapBody>) -> Self { Self::Zstd { inner } } pub(crate) fn identity(inner: B) -> Self { Self::Identity { inner } } } impl Body for DecompressionBody where B: Body, B::Error: Into, { type Data = Bytes; type Error = BoxError; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { match self.project().inner.project() { #[cfg(feature = "decompression-gzip")] BodyInnerProj::Gzip { inner } => inner.poll_data(cx), #[cfg(feature = "decompression-deflate")] BodyInnerProj::Deflate { inner } => inner.poll_data(cx), #[cfg(feature = "decompression-br")] BodyInnerProj::Brotli { inner } => inner.poll_data(cx), #[cfg(feature = "decompression-zstd")] BodyInnerProj::Zstd { inner } => inner.poll_data(cx), BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { Some(Ok(mut buf)) => { let bytes = buf.copy_to_bytes(buf.remaining()); Poll::Ready(Some(Ok(bytes))) } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), }, #[cfg(not(feature = "decompression-gzip"))] BodyInnerProj::Gzip { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-deflate"))] BodyInnerProj::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInnerProj::Brotli { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-zstd"))] BodyInnerProj::Zstd { inner } => match inner.0 {}, } } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { match self.project().inner.project() { #[cfg(feature = "decompression-gzip")] BodyInnerProj::Gzip { inner } => inner.poll_trailers(cx), #[cfg(feature = "decompression-deflate")] BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), #[cfg(feature = "decompression-br")] BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), #[cfg(feature = "decompression-zstd")] BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), #[cfg(not(feature = "decompression-gzip"))] BodyInnerProj::Gzip { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-deflate"))] BodyInnerProj::Deflate { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-br"))] BodyInnerProj::Brotli { inner } => match inner.0 {}, #[cfg(not(feature = "decompression-zstd"))] BodyInnerProj::Zstd { inner } => match inner.0 {}, } } } #[cfg(feature = "decompression-gzip")] impl DecorateAsyncRead for GzipDecoder where B: Body, { type Input = AsyncReadBody; type Output = GzipDecoder; fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { let mut decoder = GzipDecoder::new(input); decoder.multiple_members(true); decoder } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } 
} #[cfg(feature = "decompression-deflate")] impl DecorateAsyncRead for ZlibDecoder where B: Body, { type Input = AsyncReadBody; type Output = ZlibDecoder; fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { ZlibDecoder::new(input) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } #[cfg(feature = "decompression-br")] impl DecorateAsyncRead for BrotliDecoder where B: Body, { type Input = AsyncReadBody; type Output = BrotliDecoder; fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { BrotliDecoder::new(input) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } #[cfg(feature = "decompression-zstd")] impl DecorateAsyncRead for ZstdDecoder where B: Body, { type Input = AsyncReadBody; type Output = ZstdDecoder; fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { ZstdDecoder::new(input) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { pinned.get_pin_mut() } } tower-http-0.4.4/src/decompression/future.rs000064400000000000000000000054731046102023000173000ustar 00000000000000#![allow(unused_imports)] use super::{body::BodyInner, DecompressionBody}; use crate::compression_utils::{AcceptEncoding, CompressionLevel, WrapBody}; use crate::content_encoding::SupportedEncodings; use futures_util::ready; use http::{header, Response}; use http_body::Body; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; pin_project! { /// Response future of [`Decompression`]. /// /// [`Decompression`]: super::Decompression #[derive(Debug)] pub struct ResponseFuture { #[pin] pub(crate) inner: F, pub(crate) accept: AcceptEncoding, } } impl Future for ResponseFuture where F: Future, E>>, B: Body, { type Output = Result>, E>; #[allow(unreachable_code, unused_mut, unused_variables)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = ready!(self.as_mut().project().inner.poll(cx)?); let (mut parts, body) = res.into_parts(); let res = if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) { let body = match entry.get().as_bytes() { #[cfg(feature = "decompression-gzip")] b"gzip" if self.accept.gzip() => DecompressionBody::new(BodyInner::gzip( WrapBody::new(body, CompressionLevel::default()), )), #[cfg(feature = "decompression-deflate")] b"deflate" if self.accept.deflate() => DecompressionBody::new( BodyInner::deflate(WrapBody::new(body, CompressionLevel::default())), ), #[cfg(feature = "decompression-br")] b"br" if self.accept.br() => DecompressionBody::new(BodyInner::brotli( WrapBody::new(body, CompressionLevel::default()), )), #[cfg(feature = "decompression-zstd")] b"zstd" if self.accept.zstd() => DecompressionBody::new(BodyInner::zstd( WrapBody::new(body, CompressionLevel::default()), )), _ => { return Poll::Ready(Ok(Response::from_parts( parts, DecompressionBody::new(BodyInner::identity(body)), ))) } }; entry.remove(); parts.headers.remove(header::CONTENT_LENGTH); Response::from_parts(parts, body) } else { Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body))) }; Poll::Ready(Ok(res)) } } tower-http-0.4.4/src/decompression/layer.rs000064400000000000000000000047541046102023000171030ustar 00000000000000use super::Decompression; use crate::compression_utils::AcceptEncoding; use tower_layer::Layer; /// Decompresses response bodies of the underlying service. 
/// /// This adds the `Accept-Encoding` header to requests and transparently decompresses response /// bodies based on the `Content-Encoding` header. /// /// See the [module docs](crate::decompression) for more details. #[derive(Debug, Default, Clone)] pub struct DecompressionLayer { accept: AcceptEncoding, } impl Layer for DecompressionLayer { type Service = Decompression; fn layer(&self, service: S) -> Self::Service { Decompression { inner: service, accept: self.accept, } } } impl DecompressionLayer { /// Creates a new `DecompressionLayer`. pub fn new() -> Self { Default::default() } /// Sets whether to request the gzip encoding. #[cfg(feature = "decompression-gzip")] pub fn gzip(mut self, enable: bool) -> Self { self.accept.set_gzip(enable); self } /// Sets whether to request the Deflate encoding. #[cfg(feature = "decompression-deflate")] pub fn deflate(mut self, enable: bool) -> Self { self.accept.set_deflate(enable); self } /// Sets whether to request the Brotli encoding. #[cfg(feature = "decompression-br")] pub fn br(mut self, enable: bool) -> Self { self.accept.set_br(enable); self } /// Sets whether to request the Zstd encoding. #[cfg(feature = "decompression-zstd")] pub fn zstd(mut self, enable: bool) -> Self { self.accept.set_zstd(enable); self } /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. pub fn no_gzip(mut self) -> Self { self.accept.set_gzip(false); self } /// Disables the Deflate encoding. /// /// This method is available even if the `deflate` crate feature is disabled. pub fn no_deflate(mut self) -> Self { self.accept.set_deflate(false); self } /// Disables the Brotli encoding. /// /// This method is available even if the `br` crate feature is disabled. pub fn no_br(mut self) -> Self { self.accept.set_br(false); self } /// Disables the Zstd encoding. /// /// This method is available even if the `zstd` crate feature is disabled. pub fn no_zstd(mut self) -> Self { self.accept.set_zstd(false); self } } tower-http-0.4.4/src/decompression/mod.rs000064400000000000000000000146751046102023000165510ustar 00000000000000//! Middleware that decompresses request and response bodies. //! //! # Examples //! //! #### Request //! ```rust //! use bytes::BytesMut; //! use flate2::{write::GzEncoder, Compression}; //! use http::{header, HeaderValue, Request, Response}; //! use http_body::Body as _; // for Body::data //! use hyper::Body; //! use std::{error::Error, io::Write}; //! use tower::{Service, ServiceBuilder, service_fn, ServiceExt}; //! use tower_http::{BoxError, decompression::{DecompressionBody, RequestDecompressionLayer}}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), BoxError> { //! // A request encoded with gzip coming from some HTTP client. //! let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); //! encoder.write_all(b"Hello?")?; //! let request = Request::builder() //! .header(header::CONTENT_ENCODING, "gzip") //! .body(Body::from(encoder.finish()?))?; //! //! // Our HTTP server //! let mut server = ServiceBuilder::new() //! // Automatically decompress request bodies. //! .layer(RequestDecompressionLayer::new()) //! .service(service_fn(handler)); //! //! // Send the request, with the gzip encoded body, to our server. //! let _response = server.ready().await?.call(request).await?; //! //! // Handler receives request whose body is decoded when read //! async fn handler(mut req: Request>) -> Result, BoxError>{ //! let mut data = BytesMut::new(); //! 
while let Some(chunk) = req.body_mut().data().await { //! let chunk = chunk?; //! data.extend_from_slice(&chunk[..]); //! } //! assert_eq!(data.freeze().to_vec(), b"Hello?"); //! Ok(Response::new(Body::from("Hello, World!"))) //! } //! # Ok(()) //! # } //! ``` //! //! #### Response //! ```rust //! use bytes::BytesMut; //! use http::{Request, Response}; //! use http_body::Body as _; // for Body::data //! use hyper::Body; //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::{compression::Compression, decompression::DecompressionLayer, BoxError}; //! # //! # #[tokio::main] //! # async fn main() -> Result<(), tower_http::BoxError> { //! # async fn handle(req: Request) -> Result, Infallible> { //! # let body = Body::from("Hello, World!"); //! # Ok(Response::new(body)) //! # } //! //! // Some opaque service that applies compression. //! let service = Compression::new(service_fn(handle)); //! //! // Our HTTP client. //! let mut client = ServiceBuilder::new() //! // Automatically decompress response bodies. //! .layer(DecompressionLayer::new()) //! .service(service); //! //! // Call the service. //! // //! // `DecompressionLayer` takes care of setting `Accept-Encoding`. //! let request = Request::new(Body::empty()); //! //! let response = client //! .ready() //! .await? //! .call(request) //! .await?; //! //! // Read the body //! let mut body = response.into_body(); //! let mut bytes = BytesMut::new(); //! while let Some(chunk) = body.data().await { //! let chunk = chunk?; //! bytes.extend_from_slice(&chunk[..]); //! } //! let body = String::from_utf8(bytes.to_vec()).map_err(Into::::into)?; //! //! assert_eq!(body, "Hello, World!"); //! # //! # Ok(()) //! # } //! ``` mod request; mod body; mod future; mod layer; mod service; pub use self::{ body::DecompressionBody, future::ResponseFuture, layer::DecompressionLayer, service::Decompression, }; pub use self::request::future::RequestDecompressionFuture; pub use self::request::layer::RequestDecompressionLayer; pub use self::request::service::RequestDecompression; #[cfg(test)] mod tests { use std::io::Write; use super::*; use crate::compression::Compression; use bytes::BytesMut; use flate2::write::GzEncoder; use http::Response; use http_body::Body as _; use hyper::{Body, Client, Error, Request}; use tower::{service_fn, Service, ServiceExt}; #[tokio::test] async fn works() { let mut client = Decompression::new(Compression::new(service_fn(handle))); let req = Request::builder() .header("accept-encoding", "gzip") .body(Body::empty()) .unwrap(); let res = client.ready().await.unwrap().call(req).await.unwrap(); // read the body, it will be decompressed automatically let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } let decompressed_data = String::from_utf8(data.freeze().to_vec()).unwrap(); assert_eq!(decompressed_data, "Hello, World!"); } #[tokio::test] async fn decompress_multi_gz() { let mut client = Decompression::new(service_fn(handle_multi_gz)); let req = Request::builder() .header("accept-encoding", "gzip") .body(Body::empty()) .unwrap(); let res = client.ready().await.unwrap().call(req).await.unwrap(); // read the body, it will be decompressed automatically let mut body = res.into_body(); let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } let decompressed_data = 
String::from_utf8(data.freeze().to_vec()).unwrap(); assert_eq!(decompressed_data, "Hello, World!"); } async fn handle(_req: Request) -> Result, Error> { Ok(Response::new(Body::from("Hello, World!"))) } async fn handle_multi_gz(_req: Request) -> Result, Error> { let mut buf = Vec::new(); let mut enc1 = GzEncoder::new(&mut buf, Default::default()); enc1.write_all(b"Hello, ").unwrap(); enc1.finish().unwrap(); let mut enc2 = GzEncoder::new(&mut buf, Default::default()); enc2.write_all(b"World!").unwrap(); enc2.finish().unwrap(); let mut res = Response::new(Body::from(buf)); res.headers_mut() .insert("content-encoding", "gzip".parse().unwrap()); Ok(res) } #[allow(dead_code)] async fn is_compatible_with_hyper() { let mut client = Decompression::new(Client::new()); let req = Request::new(Body::empty()); let _: Response> = client.ready().await.unwrap().call(req).await.unwrap(); } } tower-http-0.4.4/src/decompression/request/future.rs000064400000000000000000000050231046102023000207570ustar 00000000000000use crate::compression_utils::AcceptEncoding; use crate::BoxError; use bytes::Buf; use http::{header, HeaderValue, Response, StatusCode}; use http_body::{combinators::UnsyncBoxBody, Body, Empty}; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; use std::task::Context; use std::task::Poll; pin_project! { #[derive(Debug)] /// Response future of [`RequestDecompression`] pub struct RequestDecompressionFuture where F: Future, E>>, B: Body { #[pin] kind: Kind, } } pin_project! { #[derive(Debug)] #[project = StateProj] enum Kind where F: Future, E>>, B: Body { Inner { #[pin] fut: F }, Unsupported { #[pin] accept: AcceptEncoding }, } } impl RequestDecompressionFuture where F: Future, E>>, B: Body, { #[must_use] pub(super) fn unsupported_encoding(accept: AcceptEncoding) -> Self { Self { kind: Kind::Unsupported { accept }, } } #[must_use] pub(super) fn inner(fut: F) -> Self { Self { kind: Kind::Inner { fut }, } } } impl Future for RequestDecompressionFuture where F: Future, E>>, B: Body + Send + 'static, B::Data: Buf + 'static, B::Error: Into + 'static, E: Into, { type Output = Result>, BoxError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { StateProj::Inner { fut } => fut .poll(cx) .map_ok(|res| res.map(|body| body.map_err(Into::into).boxed_unsync())) .map_err(Into::into), StateProj::Unsupported { accept } => { let res = Response::builder() .header( header::ACCEPT_ENCODING, accept .to_header_value() .unwrap_or(HeaderValue::from_static("identity")), ) .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) .body(Empty::new().map_err(Into::into).boxed_unsync()) .unwrap(); Poll::Ready(Ok(res)) } } } } tower-http-0.4.4/src/decompression/request/layer.rs000064400000000000000000000063641046102023000205720ustar 00000000000000use super::service::RequestDecompression; use crate::compression_utils::AcceptEncoding; use tower_layer::Layer; /// Decompresses request bodies and calls its underlying service. /// /// Transparently decompresses request bodies based on the `Content-Encoding` header. /// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` /// status code will be returned with the accepted encodings in the `Accept-Encoding` header. /// /// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type`. But /// will call the underlying service with the unmodified request if the encoding is not supported. /// This is disabled by default. 
/// /// See the [module docs](crate::decompression) for more details. #[derive(Debug, Default, Clone)] pub struct RequestDecompressionLayer { accept: AcceptEncoding, pass_through_unaccepted: bool, } impl Layer for RequestDecompressionLayer { type Service = RequestDecompression; fn layer(&self, service: S) -> Self::Service { RequestDecompression { inner: service, accept: self.accept, pass_through_unaccepted: self.pass_through_unaccepted, } } } impl RequestDecompressionLayer { /// Creates a new `RequestDecompressionLayer`. pub fn new() -> Self { Default::default() } /// Sets whether to support gzip encoding. #[cfg(feature = "decompression-gzip")] pub fn gzip(mut self, enable: bool) -> Self { self.accept.set_gzip(enable); self } /// Sets whether to support Deflate encoding. #[cfg(feature = "decompression-deflate")] pub fn deflate(mut self, enable: bool) -> Self { self.accept.set_deflate(enable); self } /// Sets whether to support Brotli encoding. #[cfg(feature = "decompression-br")] pub fn br(mut self, enable: bool) -> Self { self.accept.set_br(enable); self } /// Sets whether to support Zstd encoding. #[cfg(feature = "decompression-zstd")] pub fn zstd(mut self, enable: bool) -> Self { self.accept.set_zstd(enable); self } /// Disables support for gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. pub fn no_gzip(mut self) -> Self { self.accept.set_gzip(false); self } /// Disables support for Deflate encoding. /// /// This method is available even if the `deflate` crate feature is disabled. pub fn no_deflate(mut self) -> Self { self.accept.set_deflate(false); self } /// Disables support for Brotli encoding. /// /// This method is available even if the `br` crate feature is disabled. pub fn no_br(mut self) -> Self { self.accept.set_br(false); self } /// Disables support for Zstd encoding. /// /// This method is available even if the `zstd` crate feature is disabled. pub fn no_zstd(mut self) -> Self { self.accept.set_zstd(false); self } /// Sets whether to pass through the request even when the encoding is not supported. 
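///
/// A minimal sketch of enabling pass-through; it only uses the layer shown in
/// this file, nothing else is assumed:
///
/// ```
/// use tower_http::decompression::RequestDecompressionLayer;
///
/// // Requests with an unsupported `Content-Encoding` are forwarded to the
/// // inner service unchanged instead of being answered with
/// // `415 Unsupported Media Type`.
/// let layer = RequestDecompressionLayer::new().pass_through_unaccepted(true);
/// ```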
pub fn pass_through_unaccepted(mut self, enable: bool) -> Self { self.pass_through_unaccepted = enable; self } } tower-http-0.4.4/src/decompression/request/mod.rs000064400000000000000000000073601046102023000202320ustar 00000000000000pub(super) mod future; pub(super) mod layer; pub(super) mod service; #[cfg(test)] mod tests { use super::service::RequestDecompression; use crate::decompression::DecompressionBody; use bytes::BytesMut; use flate2::{write::GzEncoder, Compression}; use http::{header, Response, StatusCode}; use http_body::Body as _; use hyper::{Body, Error, Request, Server}; use std::io::Write; use std::net::SocketAddr; use tower::{make::Shared, service_fn, Service, ServiceExt}; #[tokio::test] async fn decompress_accepted_encoding() { let req = request_gzip(); let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed)); let _ = svc.ready().await.unwrap().call(req).await.unwrap(); } #[tokio::test] async fn support_unencoded_body() { let req = Request::builder().body(Body::from("Hello?")).unwrap(); let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed)); let _ = svc.ready().await.unwrap().call(req).await.unwrap(); } #[tokio::test] async fn unaccepted_content_encoding_returns_unsupported_media_type() { let req = request_gzip(); let mut svc = RequestDecompression::new(service_fn(should_not_be_called)).gzip(false); let res = svc.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, res.status()); } #[tokio::test] async fn pass_through_unsupported_encoding_when_enabled() { let req = request_gzip(); let mut svc = RequestDecompression::new(service_fn(assert_request_is_passed_through)) .pass_through_unaccepted(true) .gzip(false); let _ = svc.ready().await.unwrap().call(req).await.unwrap(); } async fn assert_request_is_decompressed( req: Request>, ) -> Result, Error> { let (parts, mut body) = req.into_parts(); let body = read_body(&mut body).await; assert_eq!(body, b"Hello?"); assert!(!parts.headers.contains_key(header::CONTENT_ENCODING)); Ok(Response::new(Body::from("Hello, World!"))) } async fn assert_request_is_passed_through( req: Request>, ) -> Result, Error> { let (parts, mut body) = req.into_parts(); let body = read_body(&mut body).await; assert_ne!(body, b"Hello?"); assert!(parts.headers.contains_key(header::CONTENT_ENCODING)); Ok(Response::new(Body::empty())) } async fn should_not_be_called( _: Request>, ) -> Result, Error> { panic!("Inner service should not be called"); } fn request_gzip() -> Request { let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); encoder.write_all(b"Hello?").unwrap(); let body = encoder.finish().unwrap(); Request::builder() .header(header::CONTENT_ENCODING, "gzip") .body(Body::from(body)) .unwrap() } async fn read_body(body: &mut DecompressionBody) -> Vec { let mut data = BytesMut::new(); while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); data.extend_from_slice(&chunk[..]); } data.freeze().to_vec() } #[allow(dead_code)] async fn is_compatible_with_hyper() { let svc = service_fn(assert_request_is_decompressed); let svc = RequestDecompression::new(svc); let make_service = Shared::new(svc); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let server = Server::bind(&addr).serve(make_service); server.await.unwrap(); } } tower-http-0.4.4/src/decompression/request/service.rs000064400000000000000000000161101046102023000211040ustar 00000000000000use super::future::RequestDecompressionFuture as ResponseFuture; use 
super::layer::RequestDecompressionLayer; use crate::compression_utils::CompressionLevel; use crate::{ compression_utils::AcceptEncoding, decompression::body::BodyInner, decompression::DecompressionBody, BoxError, }; use bytes::Buf; use http::{header, Request, Response}; use http_body::{combinators::UnsyncBoxBody, Body}; use std::task::{Context, Poll}; use tower_service::Service; #[cfg(any( feature = "decompression-gzip", feature = "decompression-deflate", feature = "decompression-br", feature = "decompression-zstd", ))] use crate::content_encoding::SupportedEncodings; /// Decompresses request bodies and calls its underlying service. /// /// Transparently decompresses request bodies based on the `Content-Encoding` header. /// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` /// status code will be returned with the accepted encodings in the `Accept-Encoding` header. /// /// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type` but /// will call the underlying service with the unmodified request if the encoding is not supported. /// This is disabled by default. /// /// See the [module docs](crate::decompression) for more details. #[derive(Debug, Clone)] pub struct RequestDecompression { pub(super) inner: S, pub(super) accept: AcceptEncoding, pub(super) pass_through_unaccepted: bool, } impl Service> for RequestDecompression where S: Service>, Response = Response>, ReqBody: Body, ResBody: Body + Send + 'static, S::Error: Into, ::Error: Into, D: Buf + 'static, { type Response = Response>; type Error = BoxError; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx).map_err(Into::into) } fn call(&mut self, req: Request) -> Self::Future { let (mut parts, body) = req.into_parts(); let body = if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) { match entry.get().as_bytes() { #[cfg(feature = "decompression-gzip")] b"gzip" if self.accept.gzip() => { entry.remove(); parts.headers.remove(header::CONTENT_LENGTH); BodyInner::gzip(crate::compression_utils::WrapBody::new( body, CompressionLevel::default(), )) } #[cfg(feature = "decompression-deflate")] b"deflate" if self.accept.deflate() => { entry.remove(); parts.headers.remove(header::CONTENT_LENGTH); BodyInner::deflate(crate::compression_utils::WrapBody::new( body, CompressionLevel::default(), )) } #[cfg(feature = "decompression-br")] b"br" if self.accept.br() => { entry.remove(); parts.headers.remove(header::CONTENT_LENGTH); BodyInner::brotli(crate::compression_utils::WrapBody::new( body, CompressionLevel::default(), )) } #[cfg(feature = "decompression-zstd")] b"zstd" if self.accept.zstd() => { entry.remove(); parts.headers.remove(header::CONTENT_LENGTH); BodyInner::zstd(crate::compression_utils::WrapBody::new( body, CompressionLevel::default(), )) } b"identity" => BodyInner::identity(body), _ if self.pass_through_unaccepted => BodyInner::identity(body), _ => return ResponseFuture::unsupported_encoding(self.accept), } } else { BodyInner::identity(body) }; let body = DecompressionBody::new(body); let req = Request::from_parts(parts, body); ResponseFuture::inner(self.inner.call(req)) } } impl RequestDecompression { /// Creates a new `RequestDecompression` wrapping the `service`. 
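///
/// A minimal sketch; the handler and its body types are illustrative and
/// mirror the module docs rather than prescribe an API:
///
/// ```
/// use std::convert::Infallible;
/// use http::{Request, Response};
/// use hyper::Body;
/// use tower::service_fn;
/// use tower_http::decompression::{DecompressionBody, RequestDecompression};
///
/// // The wrapped handler always sees an already-decompressed request body.
/// async fn handler(_req: Request<DecompressionBody<Body>>) -> Result<Response<Body>, Infallible> {
///     Ok(Response::new(Body::empty()))
/// }
///
/// let svc = RequestDecompression::new(service_fn(handler));
/// ```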
pub fn new(service: S) -> Self { Self { inner: service, accept: AcceptEncoding::default(), pass_through_unaccepted: false, } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `RequestDecompression` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer() -> RequestDecompressionLayer { RequestDecompressionLayer::new() } /// Passes through the request even when the encoding is not supported. /// /// By default pass-through is disabled. pub fn pass_through_unaccepted(mut self, enabled: bool) -> Self { self.pass_through_unaccepted = enabled; self } /// Sets whether to support gzip encoding. #[cfg(feature = "decompression-gzip")] pub fn gzip(mut self, enable: bool) -> Self { self.accept.set_gzip(enable); self } /// Sets whether to support Deflate encoding. #[cfg(feature = "decompression-deflate")] pub fn deflate(mut self, enable: bool) -> Self { self.accept.set_deflate(enable); self } /// Sets whether to support Brotli encoding. #[cfg(feature = "decompression-br")] pub fn br(mut self, enable: bool) -> Self { self.accept.set_br(enable); self } /// Sets whether to support Zstd encoding. #[cfg(feature = "decompression-zstd")] pub fn zstd(mut self, enable: bool) -> Self { self.accept.set_zstd(enable); self } /// Disables support for gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. pub fn no_gzip(mut self) -> Self { self.accept.set_gzip(false); self } /// Disables support for Deflate encoding. /// /// This method is available even if the `deflate` crate feature is disabled. pub fn no_deflate(mut self) -> Self { self.accept.set_deflate(false); self } /// Disables support for Brotli encoding. /// /// This method is available even if the `br` crate feature is disabled. pub fn no_br(mut self) -> Self { self.accept.set_br(false); self } /// Disables support for Zstd encoding. /// /// This method is available even if the `zstd` crate feature is disabled. pub fn no_zstd(mut self) -> Self { self.accept.set_zstd(false); self } } tower-http-0.4.4/src/decompression/service.rs000064400000000000000000000071511046102023000174210ustar 00000000000000use super::{DecompressionBody, DecompressionLayer, ResponseFuture}; use crate::compression_utils::AcceptEncoding; use http::{ header::{self, ACCEPT_ENCODING}, Request, Response, }; use http_body::Body; use std::task::{Context, Poll}; use tower_service::Service; /// Decompresses response bodies of the underlying service. /// /// This adds the `Accept-Encoding` header to requests and transparently decompresses response /// bodies based on the `Content-Encoding` header. /// /// See the [module docs](crate::decompression) for more details. #[derive(Debug, Clone)] pub struct Decompression { pub(crate) inner: S, pub(crate) accept: AcceptEncoding, } impl Decompression { /// Creates a new `Decompression` wrapping the `service`. pub fn new(service: S) -> Self { Self { inner: service, accept: AcceptEncoding::default(), } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `Decompression` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer() -> DecompressionLayer { DecompressionLayer::new() } /// Sets whether to request the gzip encoding. #[cfg(feature = "decompression-gzip")] pub fn gzip(mut self, enable: bool) -> Self { self.accept.set_gzip(enable); self } /// Sets whether to request the Deflate encoding. 
#[cfg(feature = "decompression-deflate")] pub fn deflate(mut self, enable: bool) -> Self { self.accept.set_deflate(enable); self } /// Sets whether to request the Brotli encoding. #[cfg(feature = "decompression-br")] pub fn br(mut self, enable: bool) -> Self { self.accept.set_br(enable); self } /// Sets whether to request the Zstd encoding. #[cfg(feature = "decompression-zstd")] pub fn zstd(mut self, enable: bool) -> Self { self.accept.set_zstd(enable); self } /// Disables the gzip encoding. /// /// This method is available even if the `gzip` crate feature is disabled. pub fn no_gzip(mut self) -> Self { self.accept.set_gzip(false); self } /// Disables the Deflate encoding. /// /// This method is available even if the `deflate` crate feature is disabled. pub fn no_deflate(mut self) -> Self { self.accept.set_deflate(false); self } /// Disables the Brotli encoding. /// /// This method is available even if the `br` crate feature is disabled. pub fn no_br(mut self) -> Self { self.accept.set_br(false); self } /// Disables the Zstd encoding. /// /// This method is available even if the `zstd` crate feature is disabled. pub fn no_zstd(mut self) -> Self { self.accept.set_zstd(false); self } } impl Service> for Decompression where S: Service, Response = Response>, ResBody: Body, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) { if let Some(accept) = self.accept.to_header_value() { entry.insert(accept); } } ResponseFuture { inner: self.inner.call(req), accept: self.accept, } } } tower-http-0.4.4/src/follow_redirect/mod.rs000064400000000000000000000341341046102023000170520ustar 00000000000000//! Middleware for following redirections. //! //! # Overview //! //! The [`FollowRedirect`] middleware retries requests with the inner [`Service`] to follow HTTP //! redirections. //! //! The middleware tries to clone the original [`Request`] when making a redirected request. //! However, since [`Extensions`][http::Extensions] are `!Clone`, any extensions set by outer //! middleware will be discarded. Also, the request body cannot always be cloned. When the //! original body is known to be empty by [`Body::size_hint`], the middleware uses `Default` //! implementation of the body type to create a new request body. If you know that the body can be //! cloned in some way, you can tell the middleware to clone it by configuring a [`policy`]. //! //! # Examples //! //! ## Basic usage //! //! ``` //! use http::{Request, Response}; //! use hyper::Body; //! use tower::{Service, ServiceBuilder, ServiceExt}; //! use tower_http::follow_redirect::{FollowRedirectLayer, RequestUri}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), std::convert::Infallible> { //! # let http_client = tower::service_fn(|req: Request<_>| async move { //! # let dest = "https://www.rust-lang.org/"; //! # let mut res = http::Response::builder(); //! # if req.uri() != dest { //! # res = res //! # .status(http::StatusCode::MOVED_PERMANENTLY) //! # .header(http::header::LOCATION, dest); //! # } //! # Ok::<_, std::convert::Infallible>(res.body(Body::empty()).unwrap()) //! # }); //! let mut client = ServiceBuilder::new() //! .layer(FollowRedirectLayer::new()) //! .service(http_client); //! //! let request = Request::builder() //! .uri("https://rust-lang.org/") //! .body(Body::empty()) //! 
.unwrap(); //! //! let response = client.ready().await?.call(request).await?; //! // Get the final request URI. //! assert_eq!(response.extensions().get::().unwrap().0, "https://www.rust-lang.org/"); //! # Ok(()) //! # } //! ``` //! //! ## Customizing the `Policy` //! //! You can use a [`Policy`] value to customize how the middleware handles redirections. //! //! ``` //! use http::{Request, Response}; //! use hyper::Body; //! use tower::{Service, ServiceBuilder, ServiceExt}; //! use tower_http::follow_redirect::{ //! policy::{self, PolicyExt}, //! FollowRedirectLayer, //! }; //! //! #[derive(Debug)] //! enum MyError { //! Hyper(hyper::Error), //! TooManyRedirects, //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), MyError> { //! # let http_client = //! # tower::service_fn(|_: Request| async { Ok(Response::new(Body::empty())) }); //! let policy = policy::Limited::new(10) // Set the maximum number of redirections to 10. //! // Return an error when the limit was reached. //! .or::<_, (), _>(policy::redirect_fn(|_| Err(MyError::TooManyRedirects))) //! // Do not follow cross-origin redirections, and return the redirection responses as-is. //! .and::<_, (), _>(policy::SameOrigin::new()); //! //! let mut client = ServiceBuilder::new() //! .layer(FollowRedirectLayer::with_policy(policy)) //! .map_err(MyError::Hyper) //! .service(http_client); //! //! // ... //! # let _ = client.ready().await?.call(Request::default()).await?; //! # Ok(()) //! # } //! ``` pub mod policy; use self::policy::{Action, Attempt, Policy, Standard}; use futures_core::ready; use futures_util::future::Either; use http::{ header::LOCATION, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri, Version, }; use http_body::Body; use iri_string::types::{UriAbsoluteString, UriReferenceStr}; use pin_project_lite::pin_project; use std::{ convert::TryFrom, future::Future, mem, pin::Pin, str, task::{Context, Poll}, }; use tower::util::Oneshot; use tower_layer::Layer; use tower_service::Service; /// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses. /// /// See the [module docs](self) for more details. #[derive(Clone, Copy, Debug, Default)] pub struct FollowRedirectLayer
<P = Standard> { policy: P, } impl FollowRedirectLayer { /// Create a new [`FollowRedirectLayer`] with a [`Standard`] redirection policy. pub fn new() -> Self { Self::default() } } impl<P> FollowRedirectLayer<P> { /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`]. pub fn with_policy(policy: P) -> Self { FollowRedirectLayer { policy } } } impl<S, P> Layer<S> for FollowRedirectLayer<P>
where S: Clone, P: Clone, { type Service = FollowRedirect; fn layer(&self, inner: S) -> Self::Service { FollowRedirect::with_policy(inner, self.policy.clone()) } } /// Middleware that retries requests with a [`Service`] to follow redirection responses. /// /// See the [module docs](self) for more details. #[derive(Clone, Copy, Debug)] pub struct FollowRedirect { inner: S, policy: P, } impl FollowRedirect { /// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy. pub fn new(inner: S) -> Self { Self::with_policy(inner, Standard::default()) } /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer() -> FollowRedirectLayer { FollowRedirectLayer::new() } } impl FollowRedirect where P: Clone, { /// Create a new [`FollowRedirect`] with the given redirection [`Policy`]. pub fn with_policy(inner: S, policy: P) -> Self { FollowRedirect { inner, policy } } /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware /// with the given redirection [`Policy`]. /// /// [`Layer`]: tower_layer::Layer pub fn layer_with_policy(policy: P) -> FollowRedirectLayer
<P>
{ FollowRedirectLayer::with_policy(policy) } define_inner_service_accessors!(); } impl Service> for FollowRedirect where S: Service, Response = Response> + Clone, ReqBody: Body + Default, P: Policy + Clone, { type Response = Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { let service = self.inner.clone(); let mut service = mem::replace(&mut self.inner, service); let mut policy = self.policy.clone(); let mut body = BodyRepr::None; body.try_clone_from(req.body(), &policy); policy.on_request(&mut req); ResponseFuture { method: req.method().clone(), uri: req.uri().clone(), version: req.version(), headers: req.headers().clone(), body, future: Either::Left(service.call(req)), service, policy, } } } pin_project! { /// Response future for [`FollowRedirect`]. #[derive(Debug)] pub struct ResponseFuture where S: Service>, { #[pin] future: Either>>, service: S, policy: P, method: Method, uri: Uri, version: Version, headers: HeaderMap, body: BodyRepr, } } impl Future for ResponseFuture where S: Service, Response = Response> + Clone, ReqBody: Body + Default, P: Policy, { type Output = Result, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); let mut res = ready!(this.future.as_mut().poll(cx)?); res.extensions_mut().insert(RequestUri(this.uri.clone())); match res.status() { StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => { // User agents MAY change the request method from POST to GET // (RFC 7231 section 6.4.2. and 6.4.3.). if *this.method == Method::POST { *this.method = Method::GET; *this.body = BodyRepr::Empty; } } StatusCode::SEE_OTHER => { // A user agent can perform a GET or HEAD request (RFC 7231 section 6.4.4.). if *this.method != Method::HEAD { *this.method = Method::GET; } *this.body = BodyRepr::Empty; } StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {} _ => return Poll::Ready(Ok(res)), }; let body = if let Some(body) = this.body.take() { body } else { return Poll::Ready(Ok(res)); }; let location = res .headers() .get(&LOCATION) .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri)); let location = if let Some(loc) = location { loc } else { return Poll::Ready(Ok(res)); }; let attempt = Attempt { status: res.status(), location: &location, previous: this.uri, }; match this.policy.redirect(&attempt)? { Action::Follow => { *this.uri = location; this.body.try_clone_from(&body, &this.policy); let mut req = Request::new(body); *req.uri_mut() = this.uri.clone(); *req.method_mut() = this.method.clone(); *req.version_mut() = *this.version; *req.headers_mut() = this.headers.clone(); this.policy.on_request(&mut req); this.future .set(Either::Right(Oneshot::new(this.service.clone(), req))); cx.waker().wake_by_ref(); Poll::Pending } Action::Stop => Poll::Ready(Ok(res)), } } } /// Response [`Extensions`][http::Extensions] value that represents the effective request URI of /// a response returned by a [`FollowRedirect`] middleware. /// /// The value differs from the original request's effective URI if the middleware has followed /// redirections. 
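///
/// A small sketch of reading the final URI back out of a response; the helper
/// function is illustrative:
///
/// ```
/// use http::{Response, Uri};
/// use tower_http::follow_redirect::RequestUri;
///
/// // `FollowRedirect` inserts this extension into every response it returns.
/// fn effective_uri<B>(res: &Response<B>) -> Option<&Uri> {
///     res.extensions().get::<RequestUri>().map(|RequestUri(uri)| uri)
/// }
/// ```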
pub struct RequestUri(pub Uri); #[derive(Debug)] enum BodyRepr { Some(B), Empty, None, } impl BodyRepr where B: Body + Default, { fn take(&mut self) -> Option { match mem::replace(self, BodyRepr::None) { BodyRepr::Some(body) => Some(body), BodyRepr::Empty => { *self = BodyRepr::Empty; Some(B::default()) } BodyRepr::None => None, } } fn try_clone_from(&mut self, body: &B, policy: &P) where P: Policy, { match self { BodyRepr::Some(_) | BodyRepr::Empty => {} BodyRepr::None => { if let Some(body) = clone_body(policy, body) { *self = BodyRepr::Some(body); } } } } } fn clone_body(policy: &P, body: &B) -> Option where P: Policy, B: Body + Default, { if body.size_hint().exact() == Some(0) { Some(B::default()) } else { policy.clone_body(body) } } /// Try to resolve a URI reference `relative` against a base URI `base`. fn resolve_uri(relative: &str, base: &Uri) -> Option { let relative = UriReferenceStr::new(relative).ok()?; let base = UriAbsoluteString::try_from(base.to_string()).ok()?; let uri = relative.resolve_against(&base).to_string(); Uri::try_from(uri).ok() } #[cfg(test)] mod tests { use super::{policy::*, *}; use hyper::{header::LOCATION, Body}; use std::convert::Infallible; use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] async fn follows() { let svc = ServiceBuilder::new() .layer(FollowRedirectLayer::with_policy(Action::Follow)) .buffer(1) .service_fn(handle); let req = Request::builder() .uri("http://example.com/42") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(*res.body(), 0); assert_eq!( res.extensions().get::().unwrap().0, "http://example.com/0" ); } #[tokio::test] async fn stops() { let svc = ServiceBuilder::new() .layer(FollowRedirectLayer::with_policy(Action::Stop)) .buffer(1) .service_fn(handle); let req = Request::builder() .uri("http://example.com/42") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(*res.body(), 42); assert_eq!( res.extensions().get::().unwrap().0, "http://example.com/42" ); } #[tokio::test] async fn limited() { let svc = ServiceBuilder::new() .layer(FollowRedirectLayer::with_policy(Limited::new(10))) .buffer(1) .service_fn(handle); let req = Request::builder() .uri("http://example.com/42") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(*res.body(), 42 - 10); assert_eq!( res.extensions().get::().unwrap().0, "http://example.com/32" ); } /// A server with an endpoint `GET /{n}` which redirects to `/{n-1}` unless `n` equals zero, /// returning `n` as the response body. async fn handle(req: Request) -> Result, Infallible> { let n: u64 = req.uri().path()[1..].parse().unwrap(); let mut res = Response::builder(); if n > 0 { res = res .status(StatusCode::MOVED_PERMANENTLY) .header(LOCATION, format!("/{}", n - 1)); } Ok::<_, Infallible>(res.body(n).unwrap()) } } tower-http-0.4.4/src/follow_redirect/policy/and.rs000064400000000000000000000060251046102023000203320ustar 00000000000000use super::{Action, Attempt, Policy}; use http::Request; /// A redirection [`Policy`] that combines the results of two `Policy`s. /// /// See [`PolicyExt::and`][super::PolicyExt::and] for more details. 
#[derive(Clone, Copy, Debug, Default)] pub struct And { a: A, b: B, } impl And { pub(crate) fn new(a: A, b: B) -> Self where A: Policy, B: Policy, { And { a, b } } } impl Policy for And where A: Policy, B: Policy, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { match self.a.redirect(attempt) { Ok(Action::Follow) => self.b.redirect(attempt), a => a, } } fn on_request(&mut self, request: &mut Request) { self.a.on_request(request); self.b.on_request(request); } fn clone_body(&self, body: &Bd) -> Option { self.a.clone_body(body).or_else(|| self.b.clone_body(body)) } } #[cfg(test)] mod tests { use super::*; use http::Uri; struct Taint
<P> { policy: P, used: bool, } impl<P> Taint<P> { fn new(policy: P) -> Self { Taint { policy, used: false, } } } impl<B, E, P> Policy<B, E> for Taint<P>
where P: Policy, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { self.used = true; self.policy.redirect(attempt) } } #[test] fn redirect() { let attempt = Attempt { status: Default::default(), location: &Uri::from_static("*"), previous: &Uri::from_static("*"), }; let mut a = Taint::new(Action::Follow); let mut b = Taint::new(Action::Follow); let mut policy = And::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); assert!(a.used); assert!(b.used); let mut a = Taint::new(Action::Stop); let mut b = Taint::new(Action::Follow); let mut policy = And::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_stop()); assert!(a.used); assert!(!b.used); // short-circuiting let mut a = Taint::new(Action::Follow); let mut b = Taint::new(Action::Stop); let mut policy = And::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_stop()); assert!(a.used); assert!(b.used); let mut a = Taint::new(Action::Stop); let mut b = Taint::new(Action::Stop); let mut policy = And::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_stop()); assert!(a.used); assert!(!b.used); } } tower-http-0.4.4/src/follow_redirect/policy/clone_body_fn.rs000064400000000000000000000021011046102023000223570ustar 00000000000000use super::{Action, Attempt, Policy}; use std::fmt; /// A redirection [`Policy`] created from a closure. /// /// See [`clone_body_fn`] for more details. #[derive(Clone, Copy)] pub struct CloneBodyFn { f: F, } impl fmt::Debug for CloneBodyFn { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("CloneBodyFn") .field("f", &std::any::type_name::()) .finish() } } impl Policy for CloneBodyFn where F: Fn(&B) -> Option, { fn redirect(&mut self, _: &Attempt<'_>) -> Result { Ok(Action::Follow) } fn clone_body(&self, body: &B) -> Option { (self.f)(body) } } /// Create a new redirection [`Policy`] from a closure `F: Fn(&B) -> Option`. /// /// [`clone_body`][Policy::clone_body] method of the returned `Policy` delegates to the wrapped /// closure and [`redirect`][Policy::redirect] method always returns [`Action::Follow`]. pub fn clone_body_fn(f: F) -> CloneBodyFn where F: Fn(&B) -> Option, { CloneBodyFn { f } } tower-http-0.4.4/src/follow_redirect/policy/filter_credentials.rs000064400000000000000000000111131046102023000234240ustar 00000000000000use super::{eq_origin, Action, Attempt, Policy}; use http::{ header::{self, HeaderName}, Request, }; /// A redirection [`Policy`] that removes credentials from requests in redirections. #[derive(Clone, Debug)] pub struct FilterCredentials { block_cross_origin: bool, block_any: bool, remove_blocklisted: bool, remove_all: bool, blocked: bool, } const BLOCKLIST: &[HeaderName] = &[ header::AUTHORIZATION, header::COOKIE, header::PROXY_AUTHORIZATION, ]; impl FilterCredentials { /// Create a new [`FilterCredentials`] that removes blocklisted request headers in cross-origin /// redirections. pub fn new() -> Self { FilterCredentials { block_cross_origin: true, block_any: false, remove_blocklisted: true, remove_all: false, blocked: false, } } /// Configure `self` to mark cross-origin redirections as "blocked". pub fn block_cross_origin(mut self, enable: bool) -> Self { self.block_cross_origin = enable; self } /// Configure `self` to mark every redirection as "blocked". 
pub fn block_any(mut self) -> Self { self.block_any = true; self } /// Configure `self` to mark no redirections as "blocked". pub fn block_none(mut self) -> Self { self.block_any = false; self.block_cross_origin(false) } /// Configure `self` to remove blocklisted headers in "blocked" redirections. /// /// The blocklist includes the following headers: /// /// - `Authorization` /// - `Cookie` /// - `Proxy-Authorization` pub fn remove_blocklisted(mut self, enable: bool) -> Self { self.remove_blocklisted = enable; self } /// Configure `self` to remove all headers in "blocked" redirections. pub fn remove_all(mut self) -> Self { self.remove_all = true; self } /// Configure `self` to remove no headers in "blocked" redirections. pub fn remove_none(mut self) -> Self { self.remove_all = false; self.remove_blocklisted(false) } } impl Default for FilterCredentials { fn default() -> Self { Self::new() } } impl Policy for FilterCredentials { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { self.blocked = self.block_any || (self.block_cross_origin && !eq_origin(attempt.previous(), attempt.location())); Ok(Action::Follow) } fn on_request(&mut self, request: &mut Request) { if self.blocked { let headers = request.headers_mut(); if self.remove_all { headers.clear(); } else if self.remove_blocklisted { for key in BLOCKLIST { headers.remove(key); } } } } } #[cfg(test)] mod tests { use super::*; use http::Uri; #[test] fn works() { let mut policy = FilterCredentials::default(); let initial = Uri::from_static("http://example.com/old"); let same_origin = Uri::from_static("http://example.com/new"); let cross_origin = Uri::from_static("https://example.com/new"); let mut request = Request::builder() .uri(initial) .header(header::COOKIE, "42") .body(()) .unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); assert!(request.headers().contains_key(header::COOKIE)); let attempt = Attempt { status: Default::default(), location: &same_origin, previous: request.uri(), }; assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); let mut request = Request::builder() .uri(same_origin) .header(header::COOKIE, "42") .body(()) .unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); assert!(request.headers().contains_key(header::COOKIE)); let attempt = Attempt { status: Default::default(), location: &cross_origin, previous: request.uri(), }; assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); let mut request = Request::builder() .uri(cross_origin) .header(header::COOKIE, "42") .body(()) .unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); assert!(!request.headers().contains_key(header::COOKIE)); } } tower-http-0.4.4/src/follow_redirect/policy/limited.rs000064400000000000000000000041341046102023000212160ustar 00000000000000use super::{Action, Attempt, Policy}; /// A redirection [`Policy`] that limits the number of successive redirections. #[derive(Clone, Copy, Debug)] pub struct Limited { remaining: usize, } impl Limited { /// Create a new [`Limited`] with a limit of `max` redirections. pub fn new(max: usize) -> Self { Limited { remaining: max } } } impl Default for Limited { /// Returns the default [`Limited`] with a limit of `20` redirections. fn default() -> Self { // This is the (default) limit of Firefox and the Fetch API. 
// https://hg.mozilla.org/mozilla-central/file/6264f13d54a1caa4f5b60303617a819efd91b8ee/modules/libpref/init/all.js#l1371 // https://fetch.spec.whatwg.org/#http-redirect-fetch Limited::new(20) } } impl Policy for Limited { fn redirect(&mut self, _: &Attempt<'_>) -> Result { if self.remaining > 0 { self.remaining -= 1; Ok(Action::Follow) } else { Ok(Action::Stop) } } } #[cfg(test)] mod tests { use http::{Request, Uri}; use super::*; #[test] fn works() { let uri = Uri::from_static("https://example.com/"); let mut policy = Limited::new(2); for _ in 0..2 { let mut request = Request::builder().uri(uri.clone()).body(()).unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); let attempt = Attempt { status: Default::default(), location: &uri, previous: &uri, }; assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); } let mut request = Request::builder().uri(uri.clone()).body(()).unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); let attempt = Attempt { status: Default::default(), location: &uri, previous: &uri, }; assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_stop()); } } tower-http-0.4.4/src/follow_redirect/policy/mod.rs000064400000000000000000000207761046102023000203600ustar 00000000000000//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware. mod and; mod clone_body_fn; mod filter_credentials; mod limited; mod or; mod redirect_fn; mod same_origin; pub use self::{ and::And, clone_body_fn::{clone_body_fn, CloneBodyFn}, filter_credentials::FilterCredentials, limited::Limited, or::Or, redirect_fn::{redirect_fn, RedirectFn}, same_origin::SameOrigin, }; use http::{uri::Scheme, Request, StatusCode, Uri}; /// Trait for the policy on handling redirection responses. /// /// # Example /// /// Detecting a cyclic redirection: /// /// ``` /// use http::{Request, Uri}; /// use std::collections::HashSet; /// use tower_http::follow_redirect::policy::{Action, Attempt, Policy}; /// /// #[derive(Clone)] /// pub struct DetectCycle { /// uris: HashSet, /// } /// /// impl Policy for DetectCycle { /// fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { /// if self.uris.contains(attempt.location()) { /// Ok(Action::Stop) /// } else { /// self.uris.insert(attempt.previous().clone()); /// Ok(Action::Follow) /// } /// } /// } /// ``` pub trait Policy { /// Invoked when the service received a response with a redirection status code (`3xx`). /// /// This method returns an [`Action`] which indicates whether the service should follow /// the redirection. fn redirect(&mut self, attempt: &Attempt<'_>) -> Result; /// Invoked right before the service makes a request, regardless of whether it is redirected /// or not. /// /// This can for example be used to remove sensitive headers from the request /// or prepare the request in other ways. /// /// The default implementation does nothing. fn on_request(&mut self, _request: &mut Request) {} /// Try to clone a request body before the service makes a redirected request. /// /// If the request body cannot be cloned, return `None`. /// /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero, /// in which case `B::default()` will be used to create a new request body. /// /// The default implementation returns `None`. 
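///
/// A minimal sketch of a policy that clones a cheaply-cloneable body type;
/// the `CloneBytes` type is illustrative, not part of this crate:
///
/// ```
/// use bytes::Bytes;
/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
///
/// #[derive(Clone)]
/// struct CloneBytes;
///
/// impl<E> Policy<Bytes, E> for CloneBytes {
///     fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
///         Ok(Action::Follow)
///     }
///
///     // `Bytes` clones are reference counted, so handing the redirected
///     // request a copy of the body is cheap.
///     fn clone_body(&self, body: &Bytes) -> Option<Bytes> {
///         Some(body.clone())
///     }
/// }
/// ```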
fn clone_body(&self, _body: &B) -> Option { None } } impl Policy for &mut P where P: Policy + ?Sized, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { (**self).redirect(attempt) } fn on_request(&mut self, request: &mut Request) { (**self).on_request(request) } fn clone_body(&self, body: &B) -> Option { (**self).clone_body(body) } } impl Policy for Box
<P>
where P: Policy + ?Sized, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { (**self).redirect(attempt) } fn on_request(&mut self, request: &mut Request) { (**self).on_request(request) } fn clone_body(&self, body: &B) -> Option { (**self).clone_body(body) } } /// An extension trait for `Policy` that provides additional adapters. pub trait PolicyExt { /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return /// `Action::Follow`. /// /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body /// with both policies. /// /// # Example /// /// ``` /// use bytes::Bytes; /// use hyper::Body; /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt}; /// /// enum MyBody { /// Bytes(Bytes), /// Hyper(Body), /// } /// /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| { /// if let MyBody::Bytes(buf) = body { /// Some(MyBody::Bytes(buf.clone())) /// } else { /// None /// } /// })); /// ``` fn and(self, other: P) -> And where Self: Policy + Sized, P: Policy; /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns /// `Action::Follow`. /// /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body /// with both policies. /// /// # Example /// /// ``` /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt}; /// /// #[derive(Clone)] /// enum MyError { /// TooManyRedirects, /// // ... /// } /// /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects)); /// ``` fn or(self, other: P) -> Or where Self: Policy + Sized, P: Policy; } impl PolicyExt for T where T: ?Sized, { fn and(self, other: P) -> And where Self: Policy + Sized, P: Policy, { And::new(self, other) } fn or(self, other: P) -> Or where Self: Policy + Sized, P: Policy, { Or::new(self, other) } } /// A redirection [`Policy`] with a reasonable set of standard behavior. /// /// This policy limits the number of successive redirections ([`Limited`]) /// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]). pub type Standard = And; /// A type that holds information on a redirection attempt. pub struct Attempt<'a> { pub(crate) status: StatusCode, pub(crate) location: &'a Uri, pub(crate) previous: &'a Uri, } impl<'a> Attempt<'a> { /// Returns the redirection response. pub fn status(&self) -> StatusCode { self.status } /// Returns the destination URI of the redirection. pub fn location(&self) -> &'a Uri { self.location } /// Returns the URI of the original request. pub fn previous(&self) -> &'a Uri { self.previous } } /// A value returned by [`Policy::redirect`] which indicates the action /// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response. #[derive(Clone, Copy, Debug)] pub enum Action { /// Follow the redirection. Follow, /// Do not follow the redirection, and return the redirection response as-is. Stop, } impl Action { /// Returns `true` if the `Action` is a `Follow` value. pub fn is_follow(&self) -> bool { if let Action::Follow = self { true } else { false } } /// Returns `true` if the `Action` is a `Stop` value. 
pub fn is_stop(&self) -> bool { if let Action::Stop = self { true } else { false } } } impl Policy for Action { fn redirect(&mut self, _: &Attempt<'_>) -> Result { Ok(*self) } } impl Policy for Result where E: Clone, { fn redirect(&mut self, _: &Attempt<'_>) -> Result { self.clone() } } /// Compares the origins of two URIs as per RFC 6454 sections 4. through 5. fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool { let default_port = match (lhs.scheme(), rhs.scheme()) { (Some(l), Some(r)) if l == r => { if l == &Scheme::HTTP { 80 } else if l == &Scheme::HTTPS { 443 } else { return false; } } _ => return false, }; match (lhs.host(), rhs.host()) { (Some(l), Some(r)) if l == r => {} _ => return false, } lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port) } #[cfg(test)] mod tests { use super::*; #[test] fn eq_origin_works() { assert!(eq_origin( &Uri::from_static("https://example.com/1"), &Uri::from_static("https://example.com/2") )); assert!(eq_origin( &Uri::from_static("https://example.com:443/"), &Uri::from_static("https://example.com/") )); assert!(eq_origin( &Uri::from_static("https://example.com/"), &Uri::from_static("https://user@example.com/") )); assert!(!eq_origin( &Uri::from_static("https://example.com/"), &Uri::from_static("https://www.example.com/") )); assert!(!eq_origin( &Uri::from_static("https://example.com/"), &Uri::from_static("http://example.com/") )); } } tower-http-0.4.4/src/follow_redirect/policy/or.rs000064400000000000000000000060261046102023000202110ustar 00000000000000use super::{Action, Attempt, Policy}; use http::Request; /// A redirection [`Policy`] that combines the results of two `Policy`s. /// /// See [`PolicyExt::or`][super::PolicyExt::or] for more details. #[derive(Clone, Copy, Debug, Default)] pub struct Or { a: A, b: B, } impl Or { pub(crate) fn new(a: A, b: B) -> Self where A: Policy, B: Policy, { Or { a, b } } } impl Policy for Or where A: Policy, B: Policy, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { match self.a.redirect(attempt) { Ok(Action::Stop) | Err(_) => self.b.redirect(attempt), a => a, } } fn on_request(&mut self, request: &mut Request) { self.a.on_request(request); self.b.on_request(request); } fn clone_body(&self, body: &Bd) -> Option { self.a.clone_body(body).or_else(|| self.b.clone_body(body)) } } #[cfg(test)] mod tests { use super::*; use http::Uri; struct Taint
<P> { policy: P, used: bool, } impl<P> Taint<P> { fn new(policy: P) -> Self { Taint { policy, used: false, } } } impl<B, E, P> Policy<B, E> for Taint<P>
where P: Policy, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { self.used = true; self.policy.redirect(attempt) } } #[test] fn redirect() { let attempt = Attempt { status: Default::default(), location: &Uri::from_static("*"), previous: &Uri::from_static("*"), }; let mut a = Taint::new(Action::Follow); let mut b = Taint::new(Action::Follow); let mut policy = Or::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); assert!(a.used); assert!(!b.used); // short-circuiting let mut a = Taint::new(Action::Stop); let mut b = Taint::new(Action::Follow); let mut policy = Or::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); assert!(a.used); assert!(b.used); let mut a = Taint::new(Action::Follow); let mut b = Taint::new(Action::Stop); let mut policy = Or::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); assert!(a.used); assert!(!b.used); let mut a = Taint::new(Action::Stop); let mut b = Taint::new(Action::Stop); let mut policy = Or::new::<(), ()>(&mut a, &mut b); assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_stop()); assert!(a.used); assert!(b.used); } } tower-http-0.4.4/src/follow_redirect/policy/redirect_fn.rs000064400000000000000000000017371046102023000220610ustar 00000000000000use super::{Action, Attempt, Policy}; use std::fmt; /// A redirection [`Policy`] created from a closure. /// /// See [`redirect_fn`] for more details. #[derive(Clone, Copy)] pub struct RedirectFn { f: F, } impl fmt::Debug for RedirectFn { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RedirectFn") .field("f", &std::any::type_name::()) .finish() } } impl Policy for RedirectFn where F: FnMut(&Attempt<'_>) -> Result, { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { (self.f)(attempt) } } /// Create a new redirection [`Policy`] from a closure /// `F: FnMut(&Attempt<'_>) -> Result`. /// /// [`redirect`][Policy::redirect] method of the returned `Policy` delegates to /// the wrapped closure. pub fn redirect_fn(f: F) -> RedirectFn where F: FnMut(&Attempt<'_>) -> Result, { RedirectFn { f } } tower-http-0.4.4/src/follow_redirect/policy/same_origin.rs000064400000000000000000000036311046102023000220640ustar 00000000000000use super::{eq_origin, Action, Attempt, Policy}; use std::fmt; /// A redirection [`Policy`] that stops cross-origin redirections. #[derive(Clone, Copy, Default)] pub struct SameOrigin { _priv: (), } impl SameOrigin { /// Create a new [`SameOrigin`]. 
pub fn new() -> Self { Self::default() } } impl fmt::Debug for SameOrigin { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SameOrigin").finish() } } impl Policy for SameOrigin { fn redirect(&mut self, attempt: &Attempt<'_>) -> Result { if eq_origin(attempt.previous(), attempt.location()) { Ok(Action::Follow) } else { Ok(Action::Stop) } } } #[cfg(test)] mod tests { use super::*; use http::{Request, Uri}; #[test] fn works() { let mut policy = SameOrigin::default(); let initial = Uri::from_static("http://example.com/old"); let same_origin = Uri::from_static("http://example.com/new"); let cross_origin = Uri::from_static("https://example.com/new"); let mut request = Request::builder().uri(initial).body(()).unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); let attempt = Attempt { status: Default::default(), location: &same_origin, previous: request.uri(), }; assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_follow()); let mut request = Request::builder().uri(same_origin).body(()).unwrap(); Policy::<(), ()>::on_request(&mut policy, &mut request); let attempt = Attempt { status: Default::default(), location: &cross_origin, previous: request.uri(), }; assert!(Policy::<(), ()>::redirect(&mut policy, &attempt) .unwrap() .is_stop()); } } tower-http-0.4.4/src/lib.rs000064400000000000000000000260661046102023000136630ustar 00000000000000//! `async fn(HttpRequest) -> Result` //! //! # Overview //! //! tower-http is a library that provides HTTP-specific middleware and utilities built on top of //! [tower]. //! //! All middleware uses the [http] and [http-body] crates as the HTTP abstractions. That means //! they're compatible with any library or framework that also uses those crates, such as //! [hyper], [tonic], and [warp]. //! //! # Example server //! //! This example shows how to apply middleware from tower-http to a [`Service`] and then run //! that service using [hyper]. //! //! ```rust,no_run //! use tower_http::{ //! add_extension::AddExtensionLayer, //! compression::CompressionLayer, //! propagate_header::PropagateHeaderLayer, //! sensitive_headers::SetSensitiveRequestHeadersLayer, //! set_header::SetResponseHeaderLayer, //! trace::TraceLayer, //! validate_request::ValidateRequestHeaderLayer, //! }; //! use tower::{ServiceBuilder, service_fn, make::Shared}; //! use http::{Request, Response, header::{HeaderName, CONTENT_TYPE, AUTHORIZATION}}; //! use hyper::{Body, Error, server::Server, service::make_service_fn}; //! use std::{sync::Arc, net::SocketAddr, convert::Infallible, iter::once}; //! # struct DatabaseConnectionPool; //! # impl DatabaseConnectionPool { //! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool } //! # } //! # fn content_length_from_response(_: &http::Response) -> Option { None } //! # async fn update_in_flight_requests_metric(count: usize) {} //! //! // Our request handler. This is where we would implement the application logic //! // for responding to HTTP requests... //! async fn handler(request: Request) -> Result, Error> { //! // ... //! # todo!() //! } //! //! // Shared state across all request handlers --- in this case, a pool of database connections. //! struct State { //! pool: DatabaseConnectionPool, //! } //! //! #[tokio::main] //! async fn main() { //! // Construct the shared state. //! let state = State { //! pool: DatabaseConnectionPool::new(), //! }; //! //! // Use tower's `ServiceBuilder` API to build a stack of tower middleware //! // wrapping our request handler. //! 
let service = ServiceBuilder::new() //! // Mark the `Authorization` request header as sensitive so it doesn't show in logs //! .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION))) //! // High level logging of requests and responses //! .layer(TraceLayer::new_for_http()) //! // Share an `Arc` with all requests //! .layer(AddExtensionLayer::new(Arc::new(state))) //! // Compress responses //! .layer(CompressionLayer::new()) //! // Propagate `X-Request-Id`s from requests to responses //! .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id"))) //! // If the response has a known size set the `Content-Length` header //! .layer(SetResponseHeaderLayer::overriding(CONTENT_TYPE, content_length_from_response)) //! // Authorize requests using a token //! .layer(ValidateRequestHeaderLayer::bearer("passwordlol")) //! // Accept only application/json, application/* and */* in a request's ACCEPT header //! .layer(ValidateRequestHeaderLayer::accept("application/json")) //! // Wrap a `Service` in our middleware stack //! .service_fn(handler); //! //! // And run our service using `hyper` //! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); //! Server::bind(&addr) //! .serve(Shared::new(service)) //! .await //! .expect("server error"); //! } //! ``` //! //! Keep in mind that while this example uses [hyper], tower-http supports any HTTP //! client/server implementation that uses the [http] and [http-body] crates. //! //! # Example client //! //! tower-http middleware can also be applied to HTTP clients: //! //! ```rust,no_run //! use tower_http::{ //! decompression::DecompressionLayer, //! set_header::SetRequestHeaderLayer, //! trace::TraceLayer, //! classify::StatusInRangeAsFailures, //! }; //! use tower::{ServiceBuilder, Service, ServiceExt}; //! use hyper::Body; //! use http::{Request, HeaderValue, header::USER_AGENT}; //! //! #[tokio::main] //! async fn main() { //! let mut client = ServiceBuilder::new() //! // Add tracing and consider server errors and client //! // errors as failures. //! .layer(TraceLayer::new( //! StatusInRangeAsFailures::new(400..=599).into_make_classifier() //! )) //! // Set a `User-Agent` header on all requests. //! .layer(SetRequestHeaderLayer::overriding( //! USER_AGENT, //! HeaderValue::from_static("tower-http demo") //! )) //! // Decompress response bodies //! .layer(DecompressionLayer::new()) //! // Wrap a `hyper::Client` in our middleware stack. //! // This is possible because `hyper::Client` implements //! // `tower::Service`. //! .service(hyper::Client::new()); //! //! // Make a request //! let request = Request::builder() //! .uri("http://example.com") //! .body(Body::empty()) //! .unwrap(); //! //! let response = client //! .ready() //! .await //! .unwrap() //! .call(request) //! .await //! .unwrap(); //! } //! ``` //! //! # Feature Flags //! //! All middleware are disabled by default and can be enabled using [cargo features]. //! //! For example, to enable the [`Trace`] middleware, add the "trace" feature flag in //! your `Cargo.toml`: //! //! ```toml //! tower-http = { version = "0.1", features = ["trace"] } //! ``` //! //! You can use `"full"` to enable everything: //! //! ```toml //! tower-http = { version = "0.1", features = ["full"] } //! ``` //! //! # Getting Help //! //! If you're new to tower its [guides] might help. In the tower-http repo we also have a [number //! of examples][examples] showing how to put everything together. You're also welcome to ask in //! 
the [`#tower` Discord channel][chat] or open an [issue] with your question. //! //! [tower]: https://crates.io/crates/tower //! [http]: https://crates.io/crates/http //! [http-body]: https://crates.io/crates/http-body //! [hyper]: https://crates.io/crates/hyper //! [guides]: https://github.com/tower-rs/tower/tree/master/guides //! [tonic]: https://crates.io/crates/tonic //! [warp]: https://crates.io/crates/warp //! [cargo features]: https://doc.rust-lang.org/cargo/reference/features.html //! [`AddExtension`]: crate::add_extension::AddExtension //! [`Service`]: https://docs.rs/tower/latest/tower/trait.Service.html //! [chat]: https://discord.gg/tokio //! [issue]: https://github.com/tower-rs/tower-http/issues/new //! [`Trace`]: crate::trace::Trace //! [examples]: https://github.com/tower-rs/tower-http/tree/master/examples #![warn( clippy::all, clippy::dbg_macro, clippy::todo, clippy::empty_enum, clippy::enum_glob_use, clippy::pub_enum_variant_names, clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, clippy::needless_continue, clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, clippy::suboptimal_flops, clippy::lossy_float_literal, clippy::rest_pat_in_fully_bound_structs, clippy::fn_params_excessive_bools, clippy::exit, clippy::inefficient_to_string, clippy::linkedlist, clippy::macro_use_imports, clippy::option_option, clippy::verbose_file_reads, clippy::unnested_or_patterns, rust_2018_idioms, future_incompatible, nonstandard_style, missing_docs )] #![deny(unreachable_pub, private_in_public)] #![allow( elided_lifetimes_in_paths, // TODO: Remove this once the MSRV bumps to 1.42.0 or above. clippy::match_like_matches_macro, clippy::type_complexity )] #![forbid(unsafe_code)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] #[macro_use] pub(crate) mod macros; #[cfg(feature = "auth")] pub mod auth; #[cfg(feature = "set-header")] pub mod set_header; #[cfg(feature = "propagate-header")] pub mod propagate_header; #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd", ))] pub mod compression; #[cfg(feature = "add-extension")] pub mod add_extension; #[cfg(feature = "sensitive-headers")] pub mod sensitive_headers; #[cfg(any( feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd", ))] pub mod decompression; #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd", feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd", feature = "fs" // Used for serving precompressed static files as well ))] mod content_encoding; #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd", feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd", ))] mod compression_utils; #[cfg(any( feature = "compression-br", feature = "compression-deflate", feature = "compression-gzip", feature = "compression-zstd", feature = "decompression-br", feature = "decompression-deflate", feature = "decompression-gzip", feature = "decompression-zstd", ))] pub use compression_utils::CompressionLevel; 
#[cfg(feature = "map-response-body")] pub mod map_response_body; #[cfg(feature = "map-request-body")] pub mod map_request_body; #[cfg(feature = "trace")] pub mod trace; #[cfg(feature = "follow-redirect")] pub mod follow_redirect; #[cfg(feature = "limit")] pub mod limit; #[cfg(feature = "metrics")] pub mod metrics; #[cfg(feature = "cors")] pub mod cors; #[cfg(feature = "request-id")] pub mod request_id; #[cfg(feature = "catch-panic")] pub mod catch_panic; #[cfg(feature = "set-status")] pub mod set_status; #[cfg(feature = "timeout")] pub mod timeout; #[cfg(feature = "normalize-path")] pub mod normalize_path; pub mod classify; pub mod services; #[cfg(feature = "util")] mod builder; #[cfg(feature = "util")] #[doc(inline)] pub use self::builder::ServiceBuilderExt; #[cfg(feature = "validate-request")] pub mod validate_request; /// The latency unit used to report latencies by middleware. #[non_exhaustive] #[derive(Copy, Clone, Debug)] pub enum LatencyUnit { /// Use seconds. Seconds, /// Use milliseconds. Millis, /// Use microseconds. Micros, /// Use nanoseconds. Nanos, } /// Alias for a type-erased error type. pub type BoxError = Box; mod sealed { #[allow(unreachable_pub)] pub trait Sealed {} } tower-http-0.4.4/src/limit/body.rs000064400000000000000000000054021046102023000151570ustar 00000000000000use bytes::Bytes; use http::{HeaderMap, HeaderValue, Response, StatusCode}; use http_body::{Body, Full, SizeHint}; use pin_project_lite::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// Response body for [`RequestBodyLimit`]. /// /// [`RequestBodyLimit`]: super::RequestBodyLimit pub struct ResponseBody { #[pin] inner: ResponseBodyInner } } impl ResponseBody { fn payload_too_large() -> Self { Self { inner: ResponseBodyInner::PayloadTooLarge { body: Full::from(BODY), }, } } pub(crate) fn new(body: B) -> Self { Self { inner: ResponseBodyInner::Body { body }, } } } pin_project! 
{ #[project = BodyProj] enum ResponseBodyInner { PayloadTooLarge { #[pin] body: Full, }, Body { #[pin] body: B } } } impl Body for ResponseBody where B: Body, { type Data = Bytes; type Error = B::Error; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { match self.project().inner.project() { BodyProj::PayloadTooLarge { body } => body.poll_data(cx).map_err(|err| match err {}), BodyProj::Body { body } => body.poll_data(cx), } } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { match self.project().inner.project() { BodyProj::PayloadTooLarge { body } => { body.poll_trailers(cx).map_err(|err| match err {}) } BodyProj::Body { body } => body.poll_trailers(cx), } } fn is_end_stream(&self) -> bool { match &self.inner { ResponseBodyInner::PayloadTooLarge { body } => body.is_end_stream(), ResponseBodyInner::Body { body } => body.is_end_stream(), } } fn size_hint(&self) -> SizeHint { match &self.inner { ResponseBodyInner::PayloadTooLarge { body } => body.size_hint(), ResponseBodyInner::Body { body } => body.size_hint(), } } } const BODY: &[u8] = b"length limit exceeded"; pub(crate) fn create_error_response() -> Response> where B: Body, { let mut res = Response::new(ResponseBody::payload_too_large()); *res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; #[allow(clippy::declare_interior_mutable_const)] const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); res.headers_mut() .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); res } tower-http-0.4.4/src/limit/future.rs000064400000000000000000000027221046102023000155360ustar 00000000000000use super::body::create_error_response; use super::ResponseBody; use futures_core::ready; use http::Response; use http_body::Body; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { /// Response future for [`RequestBodyLimit`]. /// /// [`RequestBodyLimit`]: super::RequestBodyLimit pub struct ResponseFuture { #[pin] inner: ResponseFutureInner, } } impl ResponseFuture { pub(crate) fn payload_too_large() -> Self { Self { inner: ResponseFutureInner::PayloadTooLarge, } } pub(crate) fn new(future: F) -> Self { Self { inner: ResponseFutureInner::Future { future }, } } } pin_project! { #[project = ResFutProj] enum ResponseFutureInner { PayloadTooLarge, Future { #[pin] future: F, } } } impl Future for ResponseFuture where ResBody: Body, F: Future, E>>, { type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = match self.project().inner.project() { ResFutProj::PayloadTooLarge => create_error_response(), ResFutProj::Future { future } => ready!(future.poll(cx))?.map(ResponseBody::new), }; Poll::Ready(Ok(res)) } } tower-http-0.4.4/src/limit/layer.rs000064400000000000000000000015411046102023000153360ustar 00000000000000use super::RequestBodyLimit; use tower_layer::Layer; /// Layer that applies the [`RequestBodyLimit`] middleware that intercepts requests /// with body lengths greater than the configured limit and converts them into /// `413 Payload Too Large` responses. /// /// See the [module docs](crate::limit) for an example. /// /// [`RequestBodyLimit`]: super::RequestBodyLimit #[derive(Clone, Copy, Debug)] pub struct RequestBodyLimitLayer { limit: usize, } impl RequestBodyLimitLayer { /// Create a new `RequestBodyLimitLayer` with the given body length limit. 
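    ///
    /// # Example
    ///
    /// A short illustrative sketch (not from the original docs); `handle` is a
    /// placeholder handler whose request body is the `Limited` wrapper applied by
    /// this middleware.
    ///
    /// ```rust
    /// use tower::ServiceBuilder;
    /// use tower_http::limit::RequestBodyLimitLayer;
    ///
    /// # async fn handle(
    /// #     _req: http::Request<http_body::Limited<hyper::Body>>,
    /// # ) -> Result<http::Response<hyper::Body>, std::convert::Infallible> {
    /// #     Ok(http::Response::new(hyper::Body::empty()))
    /// # }
    /// // Requests with bodies larger than 4096 bytes are answered with
    /// // `413 Payload Too Large` instead of reaching `handle`.
    /// let svc = ServiceBuilder::new()
    ///     .layer(RequestBodyLimitLayer::new(4096))
    ///     .service_fn(handle);
    /// ```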
pub fn new(limit: usize) -> Self { Self { limit } } } impl Layer for RequestBodyLimitLayer { type Service = RequestBodyLimit; fn layer(&self, inner: S) -> Self::Service { RequestBodyLimit { inner, limit: self.limit, } } } tower-http-0.4.4/src/limit/mod.rs000064400000000000000000000121421046102023000150000ustar 00000000000000//! Middleware for limiting request bodies. //! //! This layer will also intercept requests with a `Content-Length` header //! larger than the allowable limit and return an immediate error response //! before reading any of the body. //! //! Note that payload length errors can be used by adversaries in an attempt //! to smuggle requests. When an incoming stream is dropped due to an //! over-sized payload, servers should close the connection or resynchronize //! by optimistically consuming some data in an attempt to reach the end of //! the current HTTP frame. If the incoming stream cannot be resynchronized, //! then the connection should be closed. If you're using [hyper] this is //! automatically handled for you. //! //! # Examples //! //! ## Limiting based on `Content-Length` //! //! If a `Content-Length` header is present and indicates a payload that is //! larger than the acceptable limit, then the underlying service will not //! be called and a `413 Payload Too Large` response will be generated. //! //! ```rust //! use bytes::Bytes; //! use std::convert::Infallible; //! use http::{Request, Response, StatusCode, HeaderValue, header::CONTENT_LENGTH}; //! use http_body::{Limited, LengthLimitError}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::limit::RequestBodyLimitLayer; //! use hyper::Body; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! async fn handle(req: Request>) -> Result, Infallible> { //! panic!("This will not be hit") //! } //! //! let mut svc = ServiceBuilder::new() //! // Limit incoming requests to 4096 bytes. //! .layer(RequestBodyLimitLayer::new(4096)) //! .service_fn(handle); //! //! // Call the service with a header that indicates the body is too large. //! let mut request = Request::builder() //! .header(CONTENT_LENGTH, HeaderValue::from_static("5000")) //! .body(Body::empty()) //! .unwrap(); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); //! # //! # Ok(()) //! # } //! ``` //! //! ## Limiting without known `Content-Length` //! //! If a `Content-Length` header is not present, then the body will be read //! until the configured limit has been reached. If the payload is larger than //! the limit, the [`http_body::Limited`] body will return an error. This //! error can be inspected to determine if it is a [`http_body::LengthLimitError`] //! and return an appropriate response in such case. //! //! Note that no error will be generated if the body is never read. Similarly, //! if the body _would be_ to large, but is never consumed beyond the length //! limit, then no error is generated, and handling of the remaining incoming //! data stream is left to the server implementation as described above. //! //! ```rust //! # use bytes::Bytes; //! # use std::convert::Infallible; //! # use http::{Request, Response, StatusCode}; //! # use http_body::{Limited, LengthLimitError}; //! # use tower::{Service, ServiceExt, ServiceBuilder, BoxError}; //! # use tower_http::limit::RequestBodyLimitLayer; //! # use hyper::Body; //! # //! # #[tokio::main] //! # async fn main() -> Result<(), BoxError> { //! 
async fn handle(req: Request>) -> Result, BoxError> { //! let data = match hyper::body::to_bytes(req.into_body()).await { //! Ok(data) => data, //! Err(err) => { //! if let Some(_) = err.downcast_ref::() { //! let mut resp = Response::new(Body::empty()); //! *resp.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; //! return Ok(resp); //! } else { //! return Err(err); //! } //! } //! }; //! //! Ok(Response::new(Body::empty())) //! } //! //! let mut svc = ServiceBuilder::new() //! // Limit incoming requests to 4096 bytes. //! .layer(RequestBodyLimitLayer::new(4096)) //! .service_fn(handle); //! //! // Call the service. //! let request = Request::new(Body::empty()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), StatusCode::OK); //! //! // Call the service with a body that is too large. //! let request = Request::new(Body::from(Bytes::from(vec![0u8; 4097]))); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); //! # //! # Ok(()) //! # } //! ``` //! //! ## Limiting without `Content-Length` //! //! If enforcement of body size limits is desired without preemptively //! handling requests with a `Content-Length` header indicating an over-sized //! request, consider using [`MapRequestBody`] to wrap the request body with //! [`http_body::Limited`] and checking for [`http_body::LengthLimitError`] //! like in the previous example. //! //! [`MapRequestBody`]: crate::map_request_body //! [hyper]: https://crates.io/crates/hyper mod body; mod future; mod layer; mod service; pub use body::ResponseBody; pub use future::ResponseFuture; pub use layer::RequestBodyLimitLayer; pub use service::RequestBodyLimit; tower-http-0.4.4/src/limit/service.rs000064400000000000000000000037651046102023000156740ustar 00000000000000use super::{RequestBodyLimitLayer, ResponseBody, ResponseFuture}; use http::{Request, Response}; use http_body::{Body, Limited}; use std::task::{Context, Poll}; use tower_service::Service; /// Middleware that intercepts requests with body lengths greater than the /// configured limit and converts them into `413 Payload Too Large` responses. /// /// See the [module docs](crate::limit) for an example. #[derive(Clone, Copy, Debug)] pub struct RequestBodyLimit { pub(crate) inner: S, pub(crate) limit: usize, } impl RequestBodyLimit { /// Create a new `RequestBodyLimit` with the given body length limit. pub fn new(inner: S, limit: usize) -> Self { Self { inner, limit } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `RequestBodyLimit` middleware. 
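    ///
    /// An illustrative note (not from the original docs): this is a convenience for
    /// `RequestBodyLimitLayer::new(limit)`, so the two forms below build the same
    /// middleware stack.
    ///
    /// ```rust
    /// use tower::ServiceBuilder;
    /// use tower_http::limit::{RequestBodyLimit, RequestBodyLimitLayer};
    ///
    /// # async fn handle(
    /// #     _req: http::Request<http_body::Limited<hyper::Body>>,
    /// # ) -> Result<http::Response<hyper::Body>, std::convert::Infallible> {
    /// #     Ok(http::Response::new(hyper::Body::empty()))
    /// # }
    /// let a = ServiceBuilder::new()
    ///     .layer(RequestBodyLimit::layer(4096))
    ///     .service_fn(handle);
    ///
    /// let b = ServiceBuilder::new()
    ///     .layer(RequestBodyLimitLayer::new(4096))
    ///     .service_fn(handle);
    /// ```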
/// /// [`Layer`]: tower_layer::Layer pub fn layer(limit: usize) -> RequestBodyLimitLayer { RequestBodyLimitLayer::new(limit) } } impl Service> for RequestBodyLimit where ResBody: Body, S: Service>, Response = Response>, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let content_length = req .headers() .get(http::header::CONTENT_LENGTH) .and_then(|value| value.to_str().ok()?.parse::().ok()); let body_limit = match content_length { Some(len) if len > self.limit => return ResponseFuture::payload_too_large(), Some(len) => self.limit.min(len), None => self.limit, }; let req = req.map(|body| Limited::new(body, body_limit)); ResponseFuture::new(self.inner.call(req)) } } tower-http-0.4.4/src/macros.rs000064400000000000000000000065451046102023000144010ustar 00000000000000#[allow(unused_macros)] macro_rules! define_inner_service_accessors { () => { /// Gets a reference to the underlying service. pub fn get_ref(&self) -> &S { &self.inner } /// Gets a mutable reference to the underlying service. pub fn get_mut(&mut self) -> &mut S { &mut self.inner } /// Consumes `self`, returning the underlying service. pub fn into_inner(self) -> S { self.inner } }; } #[allow(unused_macros)] macro_rules! opaque_body { ($(#[$m:meta])* pub type $name:ident = $actual:ty;) => { opaque_body! { $(#[$m])* pub type $name<> = $actual; } }; ($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => { pin_project_lite::pin_project! { $(#[$m])* pub struct $name<$($param),*> { #[pin] pub(crate) inner: $actual } } impl<$($param),*> $name<$($param),*> { pub(crate) fn new(inner: $actual) -> Self { Self { inner } } } impl<$($param),*> http_body::Body for $name<$($param),*> { type Data = <$actual as http_body::Body>::Data; type Error = <$actual as http_body::Body>::Error; #[inline] fn poll_data( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll>> { self.project().inner.poll_data(cx) } #[inline] fn poll_trailers( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll, Self::Error>> { self.project().inner.poll_trailers(cx) } #[inline] fn is_end_stream(&self) -> bool { http_body::Body::is_end_stream(&self.inner) } #[inline] fn size_hint(&self) -> http_body::SizeHint { http_body::Body::size_hint(&self.inner) } } }; } #[allow(unused_macros)] macro_rules! opaque_future { ($(#[$m:meta])* pub type $name:ident<$($param:ident),+> = $actual:ty;) => { pin_project_lite::pin_project! { $(#[$m])* pub struct $name<$($param),+> { #[pin] inner: $actual } } impl<$($param),+> $name<$($param),+> { pub(crate) fn new(inner: $actual) -> Self { Self { inner } } } impl<$($param),+> std::fmt::Debug for $name<$($param),+> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple(stringify!($name)).field(&format_args!("...")).finish() } } impl<$($param),+> std::future::Future for $name<$($param),+> where $actual: std::future::Future, { type Output = <$actual as std::future::Future>::Output; #[inline] fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { self.project().inner.poll(cx) } } } } tower-http-0.4.4/src/map_request_body.rs000064400000000000000000000113051046102023000164450ustar 00000000000000//! Apply a transformation to the request body. //! //! # Example //! //! ``` //! use bytes::Bytes; //! use http::{Request, Response}; //! 
use hyper::Body; //! use std::convert::Infallible; //! use std::{pin::Pin, task::{Context, Poll}}; //! use tower::{ServiceBuilder, service_fn, ServiceExt, Service}; //! use tower_http::map_request_body::MapRequestBodyLayer; //! use futures::ready; //! //! // A wrapper for a `hyper::Body` that prints the size of data chunks //! struct PrintChunkSizesBody { //! inner: Body, //! } //! //! impl PrintChunkSizesBody { //! fn new(inner: Body) -> Self { //! Self { inner } //! } //! } //! //! impl http_body::Body for PrintChunkSizesBody { //! type Data = Bytes; //! type Error = hyper::Error; //! //! fn poll_data( //! mut self: Pin<&mut Self>, //! cx: &mut Context<'_>, //! ) -> Poll>> { //! if let Some(chunk) = ready!(Pin::new(&mut self.inner).poll_data(cx)?) { //! println!("chunk size = {}", chunk.len()); //! Poll::Ready(Some(Ok(chunk))) //! } else { //! Poll::Ready(None) //! } //! } //! //! fn poll_trailers( //! mut self: Pin<&mut Self>, //! cx: &mut Context<'_>, //! ) -> Poll, Self::Error>> { //! Pin::new(&mut self.inner).poll_trailers(cx) //! } //! //! fn is_end_stream(&self) -> bool { //! self.inner.is_end_stream() //! } //! //! fn size_hint(&self) -> http_body::SizeHint { //! self.inner.size_hint() //! } //! } //! //! async fn handle(_: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut svc = ServiceBuilder::new() //! // Wrap response bodies in `PrintChunkSizesBody` //! .layer(MapRequestBodyLayer::new(PrintChunkSizesBody::new)) //! .service_fn(handle); //! //! // Call the service //! let request = Request::new(Body::empty()); //! //! svc.ready().await?.call(request).await?; //! # Ok(()) //! # } //! ``` use http::{Request, Response}; use std::{ fmt, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Apply a transformation to the request body. /// /// See the [module docs](crate::map_request_body) for an example. #[derive(Clone)] pub struct MapRequestBodyLayer { f: F, } impl MapRequestBodyLayer { /// Create a new [`MapRequestBodyLayer`]. /// /// `F` is expected to be a function that takes a body and returns another body. pub fn new(f: F) -> Self { Self { f } } } impl Layer for MapRequestBodyLayer where F: Clone, { type Service = MapRequestBody; fn layer(&self, inner: S) -> Self::Service { MapRequestBody::new(inner, self.f.clone()) } } impl fmt::Debug for MapRequestBodyLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapRequestBodyLayer") .field("f", &std::any::type_name::()) .finish() } } /// Apply a transformation to the request body. /// /// See the [module docs](crate::map_request_body) for an example. #[derive(Clone)] pub struct MapRequestBody { inner: S, f: F, } impl MapRequestBody { /// Create a new [`MapRequestBody`]. /// /// `F` is expected to be a function that takes a body and returns another body. pub fn new(service: S, f: F) -> Self { Self { inner: service, f } } /// Returns a new [`Layer`] that wraps services with a `MapRequestBodyLayer` middleware. 
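    ///
    /// A brief illustrative sketch (not from the original docs): wrapping every
    /// request body in `http_body::Limited`, along the lines suggested in the
    /// `limit` module docs.
    ///
    /// ```rust
    /// use http_body::Limited;
    /// use tower::ServiceBuilder;
    /// use tower_http::map_request_body::MapRequestBody;
    ///
    /// # async fn handle(
    /// #     _req: http::Request<Limited<hyper::Body>>,
    /// # ) -> Result<http::Response<hyper::Body>, std::convert::Infallible> {
    /// #     Ok(http::Response::new(hyper::Body::empty()))
    /// # }
    /// // Wrap incoming bodies so that reading more than 4096 bytes fails with a
    /// // `LengthLimitError`.
    /// let svc = ServiceBuilder::new()
    ///     .layer(MapRequestBody::layer(|body: hyper::Body| Limited::new(body, 4096)))
    ///     .service_fn(handle);
    /// ```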
/// /// [`Layer`]: tower_layer::Layer pub fn layer(f: F) -> MapRequestBodyLayer { MapRequestBodyLayer::new(f) } define_inner_service_accessors!(); } impl Service> for MapRequestBody where S: Service, Response = Response>, F: FnMut(ReqBody) -> NewReqBody, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let req = req.map(&mut self.f); self.inner.call(req) } } impl fmt::Debug for MapRequestBody where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapRequestBody") .field("inner", &self.inner) .field("f", &std::any::type_name::()) .finish() } } tower-http-0.4.4/src/map_response_body.rs000064400000000000000000000127461046102023000166250ustar 00000000000000//! Apply a transformation to the response body. //! //! # Example //! //! ``` //! use bytes::Bytes; //! use http::{Request, Response}; //! use hyper::Body; //! use std::convert::Infallible; //! use std::{pin::Pin, task::{Context, Poll}}; //! use tower::{ServiceBuilder, service_fn, ServiceExt, Service}; //! use tower_http::map_response_body::MapResponseBodyLayer; //! use futures::ready; //! //! // A wrapper for a `hyper::Body` that prints the size of data chunks //! struct PrintChunkSizesBody { //! inner: Body, //! } //! //! impl PrintChunkSizesBody { //! fn new(inner: Body) -> Self { //! Self { inner } //! } //! } //! //! impl http_body::Body for PrintChunkSizesBody { //! type Data = Bytes; //! type Error = hyper::Error; //! //! fn poll_data( //! mut self: Pin<&mut Self>, //! cx: &mut Context<'_>, //! ) -> Poll>> { //! if let Some(chunk) = ready!(Pin::new(&mut self.inner).poll_data(cx)?) { //! println!("chunk size = {}", chunk.len()); //! Poll::Ready(Some(Ok(chunk))) //! } else { //! Poll::Ready(None) //! } //! } //! //! fn poll_trailers( //! mut self: Pin<&mut Self>, //! cx: &mut Context<'_>, //! ) -> Poll, Self::Error>> { //! Pin::new(&mut self.inner).poll_trailers(cx) //! } //! //! fn is_end_stream(&self) -> bool { //! self.inner.is_end_stream() //! } //! //! fn size_hint(&self) -> http_body::SizeHint { //! self.inner.size_hint() //! } //! } //! //! async fn handle(_: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut svc = ServiceBuilder::new() //! // Wrap response bodies in `PrintChunkSizesBody` //! .layer(MapResponseBodyLayer::new(PrintChunkSizesBody::new)) //! .service_fn(handle); //! //! // Call the service //! let request = Request::new(Body::from("foobar")); //! //! svc.ready().await?.call(request).await?; //! # Ok(()) //! # } //! ``` use futures_core::ready; use http::{Request, Response}; use pin_project_lite::pin_project; use std::future::Future; use std::{ fmt, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Apply a transformation to the response body. /// /// See the [module docs](crate::map_response_body) for an example. #[derive(Clone)] pub struct MapResponseBodyLayer { f: F, } impl MapResponseBodyLayer { /// Create a new [`MapResponseBodyLayer`]. /// /// `F` is expected to be a function that takes a body and returns another body. 
pub fn new(f: F) -> Self { Self { f } } } impl Layer for MapResponseBodyLayer where F: Clone, { type Service = MapResponseBody; fn layer(&self, inner: S) -> Self::Service { MapResponseBody::new(inner, self.f.clone()) } } impl fmt::Debug for MapResponseBodyLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapResponseBodyLayer") .field("f", &std::any::type_name::()) .finish() } } /// Apply a transformation to the response body. /// /// See the [module docs](crate::map_response_body) for an example. #[derive(Clone)] pub struct MapResponseBody { inner: S, f: F, } impl MapResponseBody { /// Create a new [`MapResponseBody`]. /// /// `F` is expected to be a function that takes a body and returns another body. pub fn new(service: S, f: F) -> Self { Self { inner: service, f } } /// Returns a new [`Layer`] that wraps services with a `MapResponseBodyLayer` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(f: F) -> MapResponseBodyLayer { MapResponseBodyLayer::new(f) } define_inner_service_accessors!(); } impl Service> for MapResponseBody where S: Service, Response = Response>, F: FnMut(ResBody) -> NewResBody + Clone, { type Response = Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { ResponseFuture { inner: self.inner.call(req), f: self.f.clone(), } } } impl fmt::Debug for MapResponseBody where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapResponseBody") .field("inner", &self.inner) .field("f", &std::any::type_name::()) .finish() } } pin_project! { /// Response future for [`MapResponseBody`]. pub struct ResponseFuture { #[pin] inner: Fut, f: F, } } impl Future for ResponseFuture where Fut: Future, E>>, F: FnMut(ResBody) -> NewResBody, { type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let res = ready!(this.inner.poll(cx)?); Poll::Ready(Ok(res.map(this.f))) } } tower-http-0.4.4/src/metrics/in_flight_requests.rs000064400000000000000000000225641046102023000204600ustar 00000000000000//! Measure the number of in-flight requests. //! //! In-flight requests is the number of requests a service is currently processing. The processing //! of a request starts when it is received by the service (`tower::Service::call` is called) and //! is considered complete when the response body is consumed, dropped, or an error happens. //! //! # Example //! //! ``` //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::metrics::InFlightRequestsLayer; //! use http::{Request, Response}; //! use hyper::Body; //! use std::{time::Duration, convert::Infallible}; //! //! async fn handle(req: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! async fn update_in_flight_requests_metric(count: usize) { //! // ... //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! // Create a `Layer` with an associated counter. //! let (in_flight_requests_layer, counter) = InFlightRequestsLayer::pair(); //! //! // Spawn a task that will receive the number of in-flight requests every 10 seconds. //! tokio::spawn( //! counter.run_emitter(Duration::from_secs(10), |count| async move { //! update_in_flight_requests_metric(count).await; //! }), //! ); //! //! let mut service = ServiceBuilder::new() //! // Keep track of the number of in-flight requests. 
This will increment and decrement //! // `counter` automatically. //! .layer(in_flight_requests_layer) //! .service_fn(handle); //! //! // Call the service. //! let response = service //! .ready() //! .await? //! .call(Request::new(Body::empty())) //! .await?; //! # Ok(()) //! # } //! ``` use futures_util::ready; use http::{Request, Response}; use http_body::Body; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Arc, }, task::{Context, Poll}, time::Duration, }; use tower_layer::Layer; use tower_service::Service; /// Layer for applying [`InFlightRequests`] which counts the number of in-flight requests. /// /// See the [module docs](crate::metrics::in_flight_requests) for more details. #[derive(Clone, Debug)] pub struct InFlightRequestsLayer { counter: InFlightRequestsCounter, } impl InFlightRequestsLayer { /// Create a new `InFlightRequestsLayer` and its associated counter. pub fn pair() -> (Self, InFlightRequestsCounter) { let counter = InFlightRequestsCounter::new(); let layer = Self::new(counter.clone()); (layer, counter) } /// Create a new `InFlightRequestsLayer` that will update the given counter. pub fn new(counter: InFlightRequestsCounter) -> Self { Self { counter } } } impl Layer for InFlightRequestsLayer { type Service = InFlightRequests; fn layer(&self, inner: S) -> Self::Service { InFlightRequests { inner, counter: self.counter.clone(), } } } /// Middleware that counts the number of in-flight requests. /// /// See the [module docs](crate::metrics::in_flight_requests) for more details. #[derive(Clone, Debug)] pub struct InFlightRequests { inner: S, counter: InFlightRequestsCounter, } impl InFlightRequests { /// Create a new `InFlightRequests` and its associated counter. pub fn pair(inner: S) -> (Self, InFlightRequestsCounter) { let counter = InFlightRequestsCounter::new(); let service = Self::new(inner, counter.clone()); (service, counter) } /// Create a new `InFlightRequests` that will update the given counter. pub fn new(inner: S, counter: InFlightRequestsCounter) -> Self { Self { inner, counter } } define_inner_service_accessors!(); } /// An atomic counter that keeps track of the number of in-flight requests. /// /// This will normally combined with [`InFlightRequestsLayer`] or [`InFlightRequests`] which will /// update the counter as requests arrive. #[derive(Debug, Clone, Default)] pub struct InFlightRequestsCounter { count: Arc, } impl InFlightRequestsCounter { /// Create a new `InFlightRequestsCounter`. pub fn new() -> Self { Self::default() } /// Get the current number of in-flight requests. pub fn get(&self) -> usize { self.count.load(Ordering::Relaxed) } fn increment(&self) -> IncrementGuard { self.count.fetch_add(1, Ordering::Relaxed); IncrementGuard { count: self.count.clone(), } } /// Run a future every `interval` which receives the current number of in-flight requests. /// /// This can be used to send the current count to your metrics system. /// /// This function will loop forever so normally it is called with [`tokio::spawn`]: /// /// ```rust,no_run /// use tower_http::metrics::in_flight_requests::InFlightRequestsCounter; /// use std::time::Duration; /// /// let counter = InFlightRequestsCounter::new(); /// /// tokio::spawn( /// counter.run_emitter(Duration::from_secs(10), |count: usize| async move { /// // Send `count` to metrics system. 
/// }), /// ); /// ``` pub async fn run_emitter(mut self, interval: Duration, mut emit: F) where F: FnMut(usize) -> Fut + Send + 'static, Fut: Future + Send, { let mut interval = tokio::time::interval(interval); loop { // if all producers have gone away we don't need to emit anymore match Arc::try_unwrap(self.count) { Ok(_) => return, Err(shared_count) => { self = Self { count: shared_count, } } } interval.tick().await; emit(self.get()).await; } } } struct IncrementGuard { count: Arc, } impl Drop for IncrementGuard { fn drop(&mut self) { self.count.fetch_sub(1, Ordering::Relaxed); } } impl Service> for InFlightRequests where S: Service, Response = Response>, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let guard = self.counter.increment(); ResponseFuture { inner: self.inner.call(req), guard: Some(guard), } } } pin_project! { /// Response future for [`InFlightRequests`]. pub struct ResponseFuture { #[pin] inner: F, guard: Option, } } impl Future for ResponseFuture where F: Future, E>>, { type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let response = ready!(this.inner.poll(cx))?; let guard = this.guard.take().unwrap(); let response = response.map(move |body| ResponseBody { inner: body, guard }); Poll::Ready(Ok(response)) } } pin_project! { /// Response body for [`InFlightRequests`]. pub struct ResponseBody { #[pin] inner: B, guard: IncrementGuard, } } impl Body for ResponseBody where B: Body, { type Data = B::Data; type Error = B::Error; #[inline] fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { self.project().inner.poll_data(cx) } #[inline] fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { self.project().inner.poll_trailers(cx) } #[inline] fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } #[inline] fn size_hint(&self) -> http_body::SizeHint { self.inner.size_hint() } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use http::Request; use hyper::Body; use tower::{BoxError, ServiceBuilder}; #[tokio::test] async fn basic() { let (in_flight_requests_layer, counter) = InFlightRequestsLayer::pair(); let mut service = ServiceBuilder::new() .layer(in_flight_requests_layer) .service_fn(echo); assert_eq!(counter.get(), 0); // driving service to ready shouldn't increment the counter futures::future::poll_fn(|cx| service.poll_ready(cx)) .await .unwrap(); assert_eq!(counter.get(), 0); // creating the response future should increment the count let response_future = service.call(Request::new(Body::empty())); assert_eq!(counter.get(), 1); // count shouldn't decrement until the full body has been comsumed let response = response_future.await.unwrap(); assert_eq!(counter.get(), 1); let body = response.into_body(); hyper::body::to_bytes(body).await.unwrap(); assert_eq!(counter.get(), 0); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } } tower-http-0.4.4/src/metrics/mod.rs000064400000000000000000000005321046102023000153300ustar 00000000000000//! Middlewares for adding metrics to services. //! //! Supported metrics: //! //! - [In-flight requests][]: Measure the number of requests a service is currently processing. //! //! 
[In-flight requests]: in_flight_requests pub mod in_flight_requests; #[doc(inline)] pub use self::in_flight_requests::{InFlightRequests, InFlightRequestsLayer}; tower-http-0.4.4/src/normalize_path.rs000064400000000000000000000143561046102023000161300ustar 00000000000000//! Middleware that normalizes paths. //! //! Any trailing slashes from request paths will be removed. For example, a request with `/foo/` //! will be changed to `/foo` before reaching the inner service. //! //! # Example //! //! ``` //! use tower_http::normalize_path::NormalizePathLayer; //! use http::{Request, Response, StatusCode}; //! use hyper::Body; //! use std::{iter::once, convert::Infallible}; //! use tower::{ServiceBuilder, Service, ServiceExt}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! async fn handle(req: Request) -> Result, Infallible> { //! // `req.uri().path()` will not have trailing slashes //! # Ok(Response::new(Body::empty())) //! } //! //! let mut service = ServiceBuilder::new() //! // trim trailing slashes from paths //! .layer(NormalizePathLayer::trim_trailing_slash()) //! .service_fn(handle); //! //! // call the service //! let request = Request::builder() //! // `handle` will see `/foo` //! .uri("/foo/") //! .body(Body::empty())?; //! //! service.ready().await?.call(request).await?; //! # //! # Ok(()) //! # } //! ``` use http::{Request, Response, Uri}; use std::{ borrow::Cow, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`NormalizePath`] which normalizes paths. /// /// See the [module docs](self) for more details. #[derive(Debug, Copy, Clone)] pub struct NormalizePathLayer {} impl NormalizePathLayer { /// Create a new [`NormalizePathLayer`]. /// /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` /// will be changed to `/foo` before reaching the inner service. pub fn trim_trailing_slash() -> Self { NormalizePathLayer {} } } impl Layer for NormalizePathLayer { type Service = NormalizePath; fn layer(&self, inner: S) -> Self::Service { NormalizePath::trim_trailing_slash(inner) } } /// Middleware that normalizes paths. /// /// See the [module docs](self) for more details. #[derive(Debug, Copy, Clone)] pub struct NormalizePath { inner: S, } impl NormalizePath { /// Create a new [`NormalizePath`]. /// /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` /// will be changed to `/foo` before reaching the inner service. 
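    ///
    /// # Example
    ///
    /// A small illustrative sketch (not from the original docs) of wrapping a service
    /// directly instead of going through [`NormalizePathLayer`]; `inner` is a
    /// placeholder service.
    ///
    /// ```rust
    /// use tower_http::normalize_path::NormalizePath;
    ///
    /// # let inner = tower::service_fn(|_req: http::Request<hyper::Body>| async {
    /// #     Ok::<_, std::convert::Infallible>(http::Response::new(hyper::Body::empty()))
    /// # });
    /// // `svc` strips trailing slashes from request paths before calling `inner`.
    /// let svc = NormalizePath::trim_trailing_slash(inner);
    /// ```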
pub fn trim_trailing_slash(inner: S) -> Self { Self { inner } } define_inner_service_accessors!(); } impl Service> for NormalizePath where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { normalize_trailing_slash(req.uri_mut()); self.inner.call(req) } } fn normalize_trailing_slash(uri: &mut Uri) { if !uri.path().ends_with('/') && !uri.path().starts_with("//") { return; } let new_path = format!("/{}", uri.path().trim_matches('/')); let mut parts = uri.clone().into_parts(); let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query { let new_path_and_query = if let Some(query) = path_and_query.query() { Cow::Owned(format!("{}?{}", new_path, query)) } else { new_path.into() } .parse() .unwrap(); Some(new_path_and_query) } else { None }; parts.path_and_query = new_path_and_query; if let Ok(new_uri) = Uri::from_parts(parts) { *uri = new_uri; } } #[cfg(test)] mod tests { use super::*; use std::convert::Infallible; use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] async fn works() { async fn handle(request: Request<()>) -> Result, Infallible> { Ok(Response::new(request.uri().to_string())) } let mut svc = ServiceBuilder::new() .layer(NormalizePathLayer::trim_trailing_slash()) .service_fn(handle); let body = svc .ready() .await .unwrap() .call(Request::builder().uri("/foo/").body(()).unwrap()) .await .unwrap() .into_body(); assert_eq!(body, "/foo"); } #[test] fn is_noop_if_no_trailing_slash() { let mut uri = "/foo".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } #[test] fn maintains_query() { let mut uri = "/foo/?a=a".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn removes_multiple_trailing_slashes() { let mut uri = "/foo////".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } #[test] fn removes_multiple_trailing_slashes_even_with_query() { let mut uri = "/foo////?a=a".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn is_noop_on_index() { let mut uri = "/".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/"); } #[test] fn removes_multiple_trailing_slashes_on_index() { let mut uri = "////".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/"); } #[test] fn removes_multiple_trailing_slashes_on_index_even_with_query() { let mut uri = "////?a=a".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/?a=a"); } #[test] fn removes_multiple_preceding_slashes_even_with_query() { let mut uri = "///foo//?a=a".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn removes_multiple_preceding_slashes() { let mut uri = "///foo".parse::().unwrap(); normalize_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } } tower-http-0.4.4/src/propagate_header.rs000064400000000000000000000105331046102023000163770ustar 00000000000000//! Propagate a header from the request to the response. //! //! # Example //! //! ```rust //! use http::{Request, Response, header::HeaderName}; //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::propagate_header::PropagateHeaderLayer; //! use hyper::Body; //! //! # #[tokio::main] //! 
# async fn main() -> Result<(), Box> { //! async fn handle(req: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! let mut svc = ServiceBuilder::new() //! // This will copy `x-request-id` headers from requests onto responses. //! .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id"))) //! .service_fn(handle); //! //! // Call the service. //! let request = Request::builder() //! .header("x-request-id", "1337") //! .body(Body::empty())?; //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["x-request-id"], "1337"); //! # //! # Ok(()) //! # } //! ``` use futures_util::ready; use http::{header::HeaderName, HeaderValue, Request, Response}; use pin_project_lite::pin_project; use std::future::Future; use std::{ pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`PropagateHeader`] which propagates headers from requests to responses. /// /// If the header is present on the request it'll be applied to the response as well. This could /// for example be used to propagate headers such as `X-Request-Id`. /// /// See the [module docs](crate::propagate_header) for more details. #[derive(Clone, Debug)] pub struct PropagateHeaderLayer { header: HeaderName, } impl PropagateHeaderLayer { /// Create a new [`PropagateHeaderLayer`]. pub fn new(header: HeaderName) -> Self { Self { header } } } impl Layer for PropagateHeaderLayer { type Service = PropagateHeader; fn layer(&self, inner: S) -> Self::Service { PropagateHeader { inner, header: self.header.clone(), } } } /// Middleware that propagates headers from requests to responses. /// /// If the header is present on the request it'll be applied to the response as well. This could /// for example be used to propagate headers such as `X-Request-Id`. /// /// See the [module docs](crate::propagate_header) for more details. #[derive(Clone, Debug)] pub struct PropagateHeader { inner: S, header: HeaderName, } impl PropagateHeader { /// Create a new [`PropagateHeader`] that propagates the given header. pub fn new(inner: S, header: HeaderName) -> Self { Self { inner, header } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `PropagateHeader` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(header: HeaderName) -> PropagateHeaderLayer { PropagateHeaderLayer::new(header) } } impl Service> for PropagateHeader where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let value = req.headers().get(&self.header).cloned(); ResponseFuture { future: self.inner.call(req), header_and_value: Some(self.header.clone()).zip(value), } } } pin_project! { /// Response future for [`PropagateHeader`]. 
#[derive(Debug)] pub struct ResponseFuture { #[pin] future: F, header_and_value: Option<(HeaderName, HeaderValue)>, } } impl Future for ResponseFuture where F: Future, E>>, { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = ready!(this.future.poll(cx)?); if let Some((header, value)) = this.header_and_value.take() { res.headers_mut().insert(header, value); } Poll::Ready(Ok(res)) } } tower-http-0.4.4/src/request_id.rs000064400000000000000000000454531046102023000152620ustar 00000000000000//! Set and propagate request ids. //! //! # Example //! //! ``` //! use http::{Request, Response, header::HeaderName}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::request_id::{ //! SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, //! }; //! use hyper::Body; //! use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let handler = tower::service_fn(|request: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # //! // A `MakeRequestId` that increments an atomic counter //! #[derive(Clone, Default)] //! struct MyMakeRequestId { //! counter: Arc, //! } //! //! impl MakeRequestId for MyMakeRequestId { //! fn make_request_id(&mut self, request: &Request) -> Option { //! let request_id = self.counter //! .fetch_add(1, Ordering::SeqCst) //! .to_string() //! .parse() //! .unwrap(); //! //! Some(RequestId::new(request_id)) //! } //! } //! //! let x_request_id = HeaderName::from_static("x-request-id"); //! //! let mut svc = ServiceBuilder::new() //! // set `x-request-id` header on all requests //! .layer(SetRequestIdLayer::new( //! x_request_id.clone(), //! MyMakeRequestId::default(), //! )) //! // propagate `x-request-id` headers from request to response //! .layer(PropagateRequestIdLayer::new(x_request_id)) //! .service(handler); //! //! let request = Request::new(Body::empty()); //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["x-request-id"], "0"); //! # //! # Ok(()) //! # } //! ``` //! //! Additional convenience methods are available on [`ServiceBuilderExt`]: //! //! ``` //! use tower_http::ServiceBuilderExt; //! # use http::{Request, Response, header::HeaderName}; //! # use tower::{Service, ServiceExt, ServiceBuilder}; //! # use tower_http::request_id::{ //! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, //! # }; //! # use hyper::Body; //! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let handler = tower::service_fn(|request: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # #[derive(Clone, Default)] //! # struct MyMakeRequestId { //! # counter: Arc, //! # } //! # impl MakeRequestId for MyMakeRequestId { //! # fn make_request_id(&mut self, request: &Request) -> Option { //! # let request_id = self.counter //! # .fetch_add(1, Ordering::SeqCst) //! # .to_string() //! # .parse() //! # .unwrap(); //! # Some(RequestId::new(request_id)) //! # } //! # } //! //! let mut svc = ServiceBuilder::new() //! .set_x_request_id(MyMakeRequestId::default()) //! .propagate_x_request_id() //! .service(handler); //! //! let request = Request::new(Body::empty()); //! let response = svc.ready().await?.call(request).await?; //! //! 
assert_eq!(response.headers()["x-request-id"], "0"); //! # //! # Ok(()) //! # } //! ``` //! //! See [`SetRequestId`] and [`PropagateRequestId`] for more details. //! //! # Using `Trace` //! //! To have request ids show up correctly in logs produced by [`Trace`] you must apply the layers //! in this order: //! //! ``` //! use tower_http::{ //! ServiceBuilderExt, //! trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse}, //! }; //! # use http::{Request, Response, header::HeaderName}; //! # use tower::{Service, ServiceExt, ServiceBuilder}; //! # use tower_http::request_id::{ //! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, //! # }; //! # use hyper::Body; //! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let handler = tower::service_fn(|request: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # #[derive(Clone, Default)] //! # struct MyMakeRequestId { //! # counter: Arc, //! # } //! # impl MakeRequestId for MyMakeRequestId { //! # fn make_request_id(&mut self, request: &Request) -> Option { //! # let request_id = self.counter //! # .fetch_add(1, Ordering::SeqCst) //! # .to_string() //! # .parse() //! # .unwrap(); //! # Some(RequestId::new(request_id)) //! # } //! # } //! //! let svc = ServiceBuilder::new() //! // make sure to set request ids before the request reaches `TraceLayer` //! .set_x_request_id(MyMakeRequestId::default()) //! // log requests and responses //! .layer( //! TraceLayer::new_for_http() //! .make_span_with(DefaultMakeSpan::new().include_headers(true)) //! .on_response(DefaultOnResponse::new().include_headers(true)) //! ) //! // propagate the header to the response before the response reaches `TraceLayer` //! .propagate_x_request_id() //! .service(handler); //! # //! # Ok(()) //! # } //! ``` //! //! # Doesn't override existing headers //! //! [`SetRequestId`] and [`PropagateRequestId`] wont override request ids if its already present on //! requests or responses. Among other things, this allows other middleware to conditionally set //! request ids and use the middleware in this module as a fallback. //! //! [`ServiceBuilderExt`]: crate::ServiceBuilderExt //! [`Uuid`]: https://crates.io/crates/uuid //! [`Trace`]: crate::trace::Trace use http::{ header::{HeaderName, HeaderValue}, Request, Response, }; use pin_project_lite::pin_project; use std::task::{Context, Poll}; use std::{future::Future, pin::Pin}; use tower_layer::Layer; use tower_service::Service; use uuid::Uuid; pub(crate) const X_REQUEST_ID: &str = "x-request-id"; /// Trait for producing [`RequestId`]s. /// /// Used by [`SetRequestId`]. pub trait MakeRequestId { /// Try and produce a [`RequestId`] from the request. fn make_request_id(&mut self, request: &Request) -> Option; } /// An identifier for a request. #[derive(Debug, Clone)] pub struct RequestId(HeaderValue); impl RequestId { /// Create a new `RequestId` from a [`HeaderValue`]. pub fn new(header_value: HeaderValue) -> Self { Self(header_value) } /// Gets a reference to the underlying [`HeaderValue`]. pub fn header_value(&self) -> &HeaderValue { &self.0 } /// Consumes `self`, returning the underlying [`HeaderValue`]. pub fn into_header_value(self) -> HeaderValue { self.0 } } impl From for RequestId { fn from(value: HeaderValue) -> Self { Self::new(value) } } /// Set request id headers and extensions on requests. /// /// This layer applies the [`SetRequestId`] middleware. 
/// /// See the [module docs](self) and [`SetRequestId`] for more details. #[derive(Debug, Clone)] pub struct SetRequestIdLayer { header_name: HeaderName, make_request_id: M, } impl SetRequestIdLayer { /// Create a new `SetRequestIdLayer`. pub fn new(header_name: HeaderName, make_request_id: M) -> Self where M: MakeRequestId, { SetRequestIdLayer { header_name, make_request_id, } } /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name. pub fn x_request_id(make_request_id: M) -> Self where M: MakeRequestId, { SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id) } } impl Layer for SetRequestIdLayer where M: Clone + MakeRequestId, { type Service = SetRequestId; fn layer(&self, inner: S) -> Self::Service { SetRequestId::new( inner, self.header_name.clone(), self.make_request_id.clone(), ) } } /// Set request id headers and extensions on requests. /// /// See the [module docs](self) for an example. /// /// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a /// header with the same name, then the header will be inserted. /// /// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other /// services can access it. #[derive(Debug, Clone)] pub struct SetRequestId { inner: S, header_name: HeaderName, make_request_id: M, } impl SetRequestId { /// Create a new `SetRequestId`. pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self where M: MakeRequestId, { Self { inner, header_name, make_request_id, } } /// Create a new `SetRequestId` that uses `x-request-id` as the header name. pub fn x_request_id(inner: S, make_request_id: M) -> Self where M: MakeRequestId, { Self::new( inner, HeaderName::from_static(X_REQUEST_ID), make_request_id, ) } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware. pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer where M: MakeRequestId, { SetRequestIdLayer::new(header_name, make_request_id) } } impl Service> for SetRequestId where S: Service, Response = Response>, M: MakeRequestId, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { if let Some(request_id) = req.headers().get(&self.header_name) { if req.extensions().get::().is_none() { let request_id = request_id.clone(); req.extensions_mut().insert(RequestId::new(request_id)); } } else if let Some(request_id) = self.make_request_id.make_request_id(&req) { req.extensions_mut().insert(request_id.clone()); req.headers_mut() .insert(self.header_name.clone(), request_id.0); } self.inner.call(req) } } /// Propagate request ids from requests to responses. /// /// This layer applies the [`PropagateRequestId`] middleware. /// /// See the [module docs](self) and [`PropagateRequestId`] for more details. #[derive(Debug, Clone)] pub struct PropagateRequestIdLayer { header_name: HeaderName, } impl PropagateRequestIdLayer { /// Create a new `PropagateRequestIdLayer`. pub fn new(header_name: HeaderName) -> Self { PropagateRequestIdLayer { header_name } } /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name. 
pub fn x_request_id() -> Self { Self::new(HeaderName::from_static(X_REQUEST_ID)) } } impl Layer for PropagateRequestIdLayer { type Service = PropagateRequestId; fn layer(&self, inner: S) -> Self::Service { PropagateRequestId::new(inner, self.header_name.clone()) } } /// Propagate request ids from requests to responses. /// /// See the [module docs](self) for an example. /// /// If the request contains a matching header that header will be applied to responses. If a /// [`RequestId`] extension is also present it will be propagated as well. #[derive(Debug, Clone)] pub struct PropagateRequestId { inner: S, header_name: HeaderName, } impl PropagateRequestId { /// Create a new `PropagateRequestId`. pub fn new(inner: S, header_name: HeaderName) -> Self { Self { inner, header_name } } /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name. pub fn x_request_id(inner: S) -> Self { Self::new(inner, HeaderName::from_static(X_REQUEST_ID)) } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware. pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer { PropagateRequestIdLayer::new(header_name) } } impl Service> for PropagateRequestId where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = PropagateRequestIdResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let request_id = req .headers() .get(&self.header_name) .cloned() .map(RequestId::new); PropagateRequestIdResponseFuture { inner: self.inner.call(req), header_name: self.header_name.clone(), request_id, } } } pin_project! { /// Response future for [`PropagateRequestId`]. pub struct PropagateRequestIdResponseFuture { #[pin] inner: F, header_name: HeaderName, request_id: Option, } } impl Future for PropagateRequestIdResponseFuture where F: Future, E>>, { type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut response = futures_core::ready!(this.inner.poll(cx))?; if let Some(current_id) = response.headers().get(&*this.header_name) { if response.extensions().get::().is_none() { let current_id = current_id.clone(); response.extensions_mut().insert(RequestId::new(current_id)); } } else if let Some(request_id) = this.request_id.take() { response .headers_mut() .insert(this.header_name.clone(), request_id.0.clone()); response.extensions_mut().insert(request_id); } Poll::Ready(Ok(response)) } } /// A [`MakeRequestId`] that generates `UUID`s. 
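///
/// # Example
///
/// A small sketch; each request that lacks the header gets a new UUID v4:
///
/// ```
/// use tower_http::request_id::{SetRequestIdLayer, MakeRequestUuid};
///
/// let layer = SetRequestIdLayer::x_request_id(MakeRequestUuid);
/// ```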
#[derive(Clone, Copy, Default)] pub struct MakeRequestUuid; impl MakeRequestId for MakeRequestUuid { fn make_request_id(&mut self, _request: &Request) -> Option { let request_id = Uuid::new_v4().to_string().parse().unwrap(); Some(RequestId::new(request_id)) } } #[cfg(test)] mod tests { use crate::ServiceBuilderExt as _; use hyper::{Body, Response}; use std::{ convert::Infallible, sync::{ atomic::{AtomicU64, Ordering}, Arc, }, }; use tower::{ServiceBuilder, ServiceExt}; #[allow(unused_imports)] use super::*; #[tokio::test] async fn basic() { let svc = ServiceBuilder::new() .set_x_request_id(Counter::default()) .propagate_x_request_id() .service_fn(handler); // header on response let req = Request::builder().body(Body::empty()).unwrap(); let res = svc.clone().oneshot(req).await.unwrap(); assert_eq!(res.headers()["x-request-id"], "0"); let req = Request::builder().body(Body::empty()).unwrap(); let res = svc.clone().oneshot(req).await.unwrap(); assert_eq!(res.headers()["x-request-id"], "1"); // doesn't override if header is already there let req = Request::builder() .header("x-request-id", "foo") .body(Body::empty()) .unwrap(); let res = svc.clone().oneshot(req).await.unwrap(); assert_eq!(res.headers()["x-request-id"], "foo"); // extension propagated let req = Request::builder().body(Body::empty()).unwrap(); let res = svc.clone().oneshot(req).await.unwrap(); assert_eq!(res.extensions().get::().unwrap().0, "2"); } #[tokio::test] async fn other_middleware_setting_request_id() { let svc = ServiceBuilder::new() .override_request_header( HeaderName::from_static("x-request-id"), HeaderValue::from_str("foo").unwrap(), ) .set_x_request_id(Counter::default()) .map_request(|request: Request<_>| { // `set_x_request_id` should set the extension if its missing assert_eq!(request.extensions().get::().unwrap().0, "foo"); request }) .propagate_x_request_id() .service_fn(handler); let req = Request::builder() .header( "x-request-id", "this-will-be-overriden-by-override_request_header-middleware", ) .body(Body::empty()) .unwrap(); let res = svc.clone().oneshot(req).await.unwrap(); assert_eq!(res.headers()["x-request-id"], "foo"); assert_eq!(res.extensions().get::().unwrap().0, "foo"); } #[tokio::test] async fn other_middleware_setting_request_id_on_response() { let svc = ServiceBuilder::new() .set_x_request_id(Counter::default()) .propagate_x_request_id() .override_response_header( HeaderName::from_static("x-request-id"), HeaderValue::from_str("foo").unwrap(), ) .service_fn(handler); let req = Request::builder() .header("x-request-id", "foo") .body(Body::empty()) .unwrap(); let res = svc.clone().oneshot(req).await.unwrap(); assert_eq!(res.headers()["x-request-id"], "foo"); assert_eq!(res.extensions().get::().unwrap().0, "foo"); } #[derive(Clone, Default)] struct Counter(Arc); impl MakeRequestId for Counter { fn make_request_id(&mut self, _request: &Request) -> Option { let id = HeaderValue::from_str(&self.0.fetch_add(1, Ordering::SeqCst).to_string()).unwrap(); Some(RequestId::new(id)) } } async fn handler(_: Request) -> Result, Infallible> { Ok(Response::new(Body::empty())) } #[tokio::test] async fn uuid() { let svc = ServiceBuilder::new() .set_x_request_id(MakeRequestUuid) .propagate_x_request_id() .service_fn(handler); // header on response let req = Request::builder().body(Body::empty()).unwrap(); let mut res = svc.clone().oneshot(req).await.unwrap(); let id = res.headers_mut().remove("x-request-id").unwrap(); id.to_str().unwrap().parse::().unwrap(); } } 
tower-http-0.4.4/src/sensitive_headers.rs000064400000000000000000000331541046102023000166150ustar 00000000000000//! Middlewares that mark headers as [sensitive]. //! //! [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive //! //! # Example //! //! ``` //! use tower_http::sensitive_headers::SetSensitiveHeadersLayer; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use http::{Request, Response, header::AUTHORIZATION}; //! use hyper::Body; //! use std::{iter::once, convert::Infallible}; //! //! async fn handle(req: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut service = ServiceBuilder::new() //! // Mark the `Authorization` header as sensitive so it doesn't show in logs //! // //! // `SetSensitiveHeadersLayer` will mark the header as sensitive on both the //! // request and response. //! // //! // The middleware is constructed from an iterator of headers to easily mark //! // multiple headers at once. //! .layer(SetSensitiveHeadersLayer::new(once(AUTHORIZATION))) //! .service(service_fn(handle)); //! //! // Call the service. //! let response = service //! .ready() //! .await? //! .call(Request::new(Body::empty())) //! .await?; //! # Ok(()) //! # } //! ``` //! //! Its important to think about the order in which requests and responses arrive at your //! middleware. For example to hide headers both on requests and responses when using //! [`TraceLayer`] you have to apply [`SetSensitiveRequestHeadersLayer`] before [`TraceLayer`] //! and [`SetSensitiveResponseHeadersLayer`] afterwards. //! //! ``` //! use tower_http::{ //! trace::TraceLayer, //! sensitive_headers::{ //! SetSensitiveRequestHeadersLayer, //! SetSensitiveResponseHeadersLayer, //! }, //! }; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use http::header; //! use std::sync::Arc; //! # use http::{Request, Response}; //! # use hyper::Body; //! # use std::convert::Infallible; //! # async fn handle(req: Request) -> Result, Infallible> { //! # Ok(Response::new(Body::empty())) //! # } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let headers: Arc<[_]> = Arc::new([ //! header::AUTHORIZATION, //! header::PROXY_AUTHORIZATION, //! header::COOKIE, //! header::SET_COOKIE, //! ]); //! //! let service = ServiceBuilder::new() //! .layer(SetSensitiveRequestHeadersLayer::from_shared(Arc::clone(&headers))) //! .layer(TraceLayer::new_for_http()) //! .layer(SetSensitiveResponseHeadersLayer::from_shared(headers)) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` //! //! [`TraceLayer`]: crate::trace::TraceLayer use futures_util::ready; use http::{header::HeaderName, Request, Response}; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Mark headers as [sensitive] on both requests and responses. /// /// Produces [`SetSensitiveHeaders`] services. /// /// See the [module docs](crate::sensitive_headers) for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[derive(Clone, Debug)] pub struct SetSensitiveHeadersLayer { headers: Arc<[HeaderName]>, } impl SetSensitiveHeadersLayer { /// Create a new [`SetSensitiveHeadersLayer`]. 
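///
/// # Example
///
/// A minimal sketch marking `Authorization` as sensitive on both requests and responses:
///
/// ```
/// use tower_http::sensitive_headers::SetSensitiveHeadersLayer;
/// use http::header;
///
/// let layer = SetSensitiveHeadersLayer::new(vec![header::AUTHORIZATION]);
/// ```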
pub fn new(headers: I) -> Self where I: IntoIterator, { let headers = headers.into_iter().collect::>(); Self::from_shared(headers.into()) } /// Create a new [`SetSensitiveHeadersLayer`] from a shared slice of headers. pub fn from_shared(headers: Arc<[HeaderName]>) -> Self { Self { headers } } } impl Layer for SetSensitiveHeadersLayer { type Service = SetSensitiveHeaders; fn layer(&self, inner: S) -> Self::Service { SetSensitiveRequestHeaders::from_shared( SetSensitiveResponseHeaders::from_shared(inner, self.headers.clone()), self.headers.clone(), ) } } /// Mark headers as [sensitive] on both requests and responses. /// /// See the [module docs](crate::sensitive_headers) for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive pub type SetSensitiveHeaders = SetSensitiveRequestHeaders>; /// Mark request headers as [sensitive]. /// /// Produces [`SetSensitiveRequestHeaders`] services. /// /// See the [module docs](crate::sensitive_headers) for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[derive(Clone, Debug)] pub struct SetSensitiveRequestHeadersLayer { headers: Arc<[HeaderName]>, } impl SetSensitiveRequestHeadersLayer { /// Create a new [`SetSensitiveRequestHeadersLayer`]. pub fn new(headers: I) -> Self where I: IntoIterator, { let headers = headers.into_iter().collect::>(); Self::from_shared(headers.into()) } /// Create a new [`SetSensitiveRequestHeadersLayer`] from a shared slice of headers. pub fn from_shared(headers: Arc<[HeaderName]>) -> Self { Self { headers } } } impl Layer for SetSensitiveRequestHeadersLayer { type Service = SetSensitiveRequestHeaders; fn layer(&self, inner: S) -> Self::Service { SetSensitiveRequestHeaders { inner, headers: self.headers.clone(), } } } /// Mark request headers as [sensitive]. /// /// See the [module docs](crate::sensitive_headers) for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[derive(Clone, Debug)] pub struct SetSensitiveRequestHeaders { inner: S, headers: Arc<[HeaderName]>, } impl SetSensitiveRequestHeaders { /// Create a new [`SetSensitiveRequestHeaders`]. pub fn new(inner: S, headers: I) -> Self where I: IntoIterator, { let headers = headers.into_iter().collect::>(); Self::from_shared(inner, headers.into()) } /// Create a new [`SetSensitiveRequestHeaders`] from a shared slice of headers. pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self { Self { inner, headers } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `SetSensitiveRequestHeaders` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(headers: I) -> SetSensitiveRequestHeadersLayer where I: IntoIterator, { SetSensitiveRequestHeadersLayer::new(headers) } } impl Service> for SetSensitiveRequestHeaders where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { let headers = req.headers_mut(); for header in &*self.headers { if let http::header::Entry::Occupied(mut entry) = headers.entry(header) { for value in entry.iter_mut() { value.set_sensitive(true); } } } self.inner.call(req) } } /// Mark response headers as [sensitive]. /// /// Produces [`SetSensitiveResponseHeaders`] services. 
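///
/// A minimal sketch; see the module docs for how to order this around `TraceLayer`:
///
/// ```
/// use tower_http::sensitive_headers::SetSensitiveResponseHeadersLayer;
/// use http::header;
///
/// // Mark `Set-Cookie` as sensitive on outgoing responses.
/// let layer = SetSensitiveResponseHeadersLayer::new(vec![header::SET_COOKIE]);
/// ```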
/// /// See the [module docs](crate::sensitive_headers) for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[derive(Clone, Debug)] pub struct SetSensitiveResponseHeadersLayer { headers: Arc<[HeaderName]>, } impl SetSensitiveResponseHeadersLayer { /// Create a new [`SetSensitiveResponseHeadersLayer`]. pub fn new(headers: I) -> Self where I: IntoIterator, { let headers = headers.into_iter().collect::>(); Self::from_shared(headers.into()) } /// Create a new [`SetSensitiveResponseHeadersLayer`] from a shared slice of headers. pub fn from_shared(headers: Arc<[HeaderName]>) -> Self { Self { headers } } } impl Layer for SetSensitiveResponseHeadersLayer { type Service = SetSensitiveResponseHeaders; fn layer(&self, inner: S) -> Self::Service { SetSensitiveResponseHeaders { inner, headers: self.headers.clone(), } } } /// Mark response headers as [sensitive]. /// /// See the [module docs](crate::sensitive_headers) for more details. /// /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[derive(Clone, Debug)] pub struct SetSensitiveResponseHeaders { inner: S, headers: Arc<[HeaderName]>, } impl SetSensitiveResponseHeaders { /// Create a new [`SetSensitiveResponseHeaders`]. pub fn new(inner: S, headers: I) -> Self where I: IntoIterator, { let headers = headers.into_iter().collect::>(); Self::from_shared(inner, headers.into()) } /// Create a new [`SetSensitiveResponseHeaders`] from a shared slice of headers. pub fn from_shared(inner: S, headers: Arc<[HeaderName]>) -> Self { Self { inner, headers } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `SetSensitiveResponseHeaders` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(headers: I) -> SetSensitiveResponseHeadersLayer where I: IntoIterator, { SetSensitiveResponseHeadersLayer::new(headers) } } impl Service> for SetSensitiveResponseHeaders where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = SetSensitiveResponseHeadersResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { SetSensitiveResponseHeadersResponseFuture { future: self.inner.call(req), headers: self.headers.clone(), } } } pin_project! { /// Response future for [`SetSensitiveResponseHeaders`]. 
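///
/// Once the inner future resolves, the configured headers on the returned response are
/// marked as sensitive.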
#[derive(Debug)] pub struct SetSensitiveResponseHeadersResponseFuture { #[pin] future: F, headers: Arc<[HeaderName]>, } } impl Future for SetSensitiveResponseHeadersResponseFuture where F: Future, E>>, { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = ready!(this.future.poll(cx)?); let headers = res.headers_mut(); for header in &**this.headers { if let http::header::Entry::Occupied(mut entry) = headers.entry(header) { for value in entry.iter_mut() { value.set_sensitive(true); } } } Poll::Ready(Ok(res)) } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use http::header; use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] async fn multiple_value_header() { async fn response_set_cookie(req: http::Request<()>) -> Result, ()> { let mut iter = req.headers().get_all(header::COOKIE).iter().peekable(); assert!(iter.peek().is_some()); for value in iter { assert!(value.is_sensitive()) } let mut resp = http::Response::new(()); resp.headers_mut().append( header::CONTENT_TYPE, http::HeaderValue::from_static("text/html"), ); resp.headers_mut().append( header::SET_COOKIE, http::HeaderValue::from_static("cookie-1"), ); resp.headers_mut().append( header::SET_COOKIE, http::HeaderValue::from_static("cookie-2"), ); resp.headers_mut().append( header::SET_COOKIE, http::HeaderValue::from_static("cookie-3"), ); Ok(resp) } let mut service = ServiceBuilder::new() .layer(SetSensitiveRequestHeadersLayer::new(vec![header::COOKIE])) .layer(SetSensitiveResponseHeadersLayer::new(vec![ header::SET_COOKIE, ])) .service_fn(response_set_cookie); let mut req = http::Request::new(()); req.headers_mut() .append(header::COOKIE, http::HeaderValue::from_static("cookie+1")); req.headers_mut() .append(header::COOKIE, http::HeaderValue::from_static("cookie+2")); let resp = service.ready().await.unwrap().call(req).await.unwrap(); assert!(!resp .headers() .get(header::CONTENT_TYPE) .unwrap() .is_sensitive()); let mut iter = resp.headers().get_all(header::SET_COOKIE).iter().peekable(); assert!(iter.peek().is_some()); for value in iter { assert!(value.is_sensitive()) } } } tower-http-0.4.4/src/services/fs/mod.rs000064400000000000000000000037341046102023000161240ustar 00000000000000//! File system related services. use bytes::Bytes; use futures_util::Stream; use http::HeaderMap; use http_body::Body; use pin_project_lite::pin_project; use std::{ io, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncReadExt, Take}; use tokio_util::io::ReaderStream; mod serve_dir; mod serve_file; pub use self::{ serve_dir::{ future::ResponseFuture as ServeFileSystemResponseFuture, DefaultServeDirFallback, // The response body and future are used for both ServeDir and ServeFile ResponseBody as ServeFileSystemResponseBody, ServeDir, }, serve_file::ServeFile, }; pin_project! { // NOTE: This could potentially be upstreamed to `http-body`. /// Adapter that turns an [`impl AsyncRead`][tokio::io::AsyncRead] to an [`impl Body`][http_body::Body]. 
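///
/// Used internally by `ServeDir` and `ServeFile` to stream file contents in chunks with a
/// configurable capacity.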
#[derive(Debug)] pub struct AsyncReadBody { #[pin] reader: ReaderStream, } } impl AsyncReadBody where T: AsyncRead, { /// Create a new [`AsyncReadBody`] wrapping the given reader, /// with a specific read buffer capacity fn with_capacity(read: T, capacity: usize) -> Self { Self { reader: ReaderStream::with_capacity(read, capacity), } } fn with_capacity_limited( read: T, capacity: usize, max_read_bytes: u64, ) -> AsyncReadBody> { AsyncReadBody { reader: ReaderStream::with_capacity(read.take(max_read_bytes), capacity), } } } impl Body for AsyncReadBody where T: AsyncRead, { type Data = Bytes; type Error = io::Error; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { self.project().reader.poll_next(cx) } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Poll::Ready(Ok(None)) } } tower-http-0.4.4/src/services/fs/serve_dir/future.rs000064400000000000000000000261611046102023000206400ustar 00000000000000use super::{ open_file::{FileOpened, FileRequestExtent, OpenFileOutput}, DefaultServeDirFallback, ResponseBody, }; use crate::{content_encoding::Encoding, services::fs::AsyncReadBody, BoxError}; use bytes::Bytes; use futures_util::{ future::{BoxFuture, FutureExt, TryFutureExt}, ready, }; use http::{ header::{self, ALLOW}, HeaderValue, Request, Response, StatusCode, }; use http_body::{Body, Empty, Full}; use pin_project_lite::pin_project; use std::{ convert::Infallible, future::Future, io, pin::Pin, task::{Context, Poll}, }; use tower_service::Service; pin_project! { /// Response future of [`ServeDir::try_call()`][`super::ServeDir::try_call()`]. pub struct ResponseFuture { #[pin] pub(super) inner: ResponseFutureInner, } } impl ResponseFuture { pub(super) fn open_file_future( future: BoxFuture<'static, io::Result>, fallback_and_request: Option<(F, Request)>, ) -> Self { Self { inner: ResponseFutureInner::OpenFileFuture { future, fallback_and_request, }, } } pub(super) fn invalid_path(fallback_and_request: Option<(F, Request)>) -> Self { Self { inner: ResponseFutureInner::InvalidPath { fallback_and_request, }, } } pub(super) fn method_not_allowed() -> Self { Self { inner: ResponseFutureInner::MethodNotAllowed, } } } pin_project! 
{ #[project = ResponseFutureInnerProj] pub(super) enum ResponseFutureInner { OpenFileFuture { #[pin] future: BoxFuture<'static, io::Result>, fallback_and_request: Option<(F, Request)>, }, FallbackFuture { future: BoxFuture<'static, Result, Infallible>>, }, InvalidPath { fallback_and_request: Option<(F, Request)>, }, MethodNotAllowed, } } impl Future for ResponseFuture where F: Service, Response = Response, Error = Infallible> + Clone, F::Future: Send + 'static, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into>, { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { let mut this = self.as_mut().project(); let new_state = match this.inner.as_mut().project() { ResponseFutureInnerProj::OpenFileFuture { future: open_file_future, fallback_and_request, } => match ready!(open_file_future.poll(cx)) { Ok(OpenFileOutput::FileOpened(file_output)) => { break Poll::Ready(Ok(build_response(*file_output))); } Ok(OpenFileOutput::Redirect { location }) => { let mut res = response_with_status(StatusCode::TEMPORARY_REDIRECT); res.headers_mut().insert(http::header::LOCATION, location); break Poll::Ready(Ok(res)); } Ok(OpenFileOutput::FileNotFound) => { if let Some((mut fallback, request)) = fallback_and_request.take() { call_fallback(&mut fallback, request) } else { break Poll::Ready(Ok(not_found())); } } Ok(OpenFileOutput::PreconditionFailed) => { break Poll::Ready(Ok(response_with_status( StatusCode::PRECONDITION_FAILED, ))); } Ok(OpenFileOutput::NotModified) => { break Poll::Ready(Ok(response_with_status(StatusCode::NOT_MODIFIED))); } Err(err) => { #[cfg(unix)] // 20 = libc::ENOTDIR => "not a directory // when `io_error_more` landed, this can be changed // to checking for `io::ErrorKind::NotADirectory`. // https://github.com/rust-lang/rust/issues/86442 let error_is_not_a_directory = err.raw_os_error() == Some(20); #[cfg(not(unix))] let error_is_not_a_directory = false; if matches!( err.kind(), io::ErrorKind::NotFound | io::ErrorKind::PermissionDenied ) || error_is_not_a_directory { if let Some((mut fallback, request)) = fallback_and_request.take() { call_fallback(&mut fallback, request) } else { break Poll::Ready(Ok(not_found())); } } else { break Poll::Ready(Err(err)); } } }, ResponseFutureInnerProj::FallbackFuture { future } => { break Pin::new(future).poll(cx).map_err(|err| match err {}) } ResponseFutureInnerProj::InvalidPath { fallback_and_request, } => { if let Some((mut fallback, request)) = fallback_and_request.take() { call_fallback(&mut fallback, request) } else { break Poll::Ready(Ok(not_found())); } } ResponseFutureInnerProj::MethodNotAllowed => { let mut res = response_with_status(StatusCode::METHOD_NOT_ALLOWED); res.headers_mut() .insert(ALLOW, HeaderValue::from_static("GET,HEAD")); break Poll::Ready(Ok(res)); } }; this.inner.set(new_state); } } } fn response_with_status(status: StatusCode) -> Response { Response::builder() .status(status) .body(empty_body()) .unwrap() } fn not_found() -> Response { response_with_status(StatusCode::NOT_FOUND) } pub(super) fn call_fallback( fallback: &mut F, req: Request, ) -> ResponseFutureInner where F: Service, Response = Response, Error = Infallible> + Clone, F::Future: Send + 'static, FResBody: http_body::Body + Send + 'static, FResBody::Error: Into, { let future = fallback .call(req) .map_ok(|response| { response .map(|body| { body.map_err(|err| match err.into().downcast::() { Ok(err) => *err, Err(err) => io::Error::new(io::ErrorKind::Other, err), }) .boxed_unsync() }) .map(ResponseBody::new) }) 
.boxed(); ResponseFutureInner::FallbackFuture { future } } fn build_response(output: FileOpened) -> Response { let (maybe_file, size) = match output.extent { FileRequestExtent::Full(file, meta) => (Some(file), meta.len()), FileRequestExtent::Head(meta) => (None, meta.len()), }; let mut builder = Response::builder() .header(header::CONTENT_TYPE, output.mime_header_value) .header(header::ACCEPT_RANGES, "bytes"); if let Some(encoding) = output .maybe_encoding .filter(|encoding| *encoding != Encoding::Identity) { builder = builder.header(header::CONTENT_ENCODING, encoding.into_header_value()); } if let Some(last_modified) = output.last_modified { builder = builder.header(header::LAST_MODIFIED, last_modified.0.to_string()); } match output.maybe_range { Some(Ok(ranges)) => { if let Some(range) = ranges.first() { if ranges.len() > 1 { builder .header(header::CONTENT_RANGE, format!("bytes */{}", size)) .status(StatusCode::RANGE_NOT_SATISFIABLE) .body(body_from_bytes(Bytes::from( "Cannot serve multipart range requests", ))) .unwrap() } else { let body = if let Some(file) = maybe_file { let range_size = range.end() - range.start() + 1; ResponseBody::new( AsyncReadBody::with_capacity_limited( file, output.chunk_size, range_size, ) .boxed_unsync(), ) } else { empty_body() }; builder .header( header::CONTENT_RANGE, format!("bytes {}-{}/{}", range.start(), range.end(), size), ) .header(header::CONTENT_LENGTH, range.end() - range.start() + 1) .status(StatusCode::PARTIAL_CONTENT) .body(body) .unwrap() } } else { builder .header(header::CONTENT_RANGE, format!("bytes */{}", size)) .status(StatusCode::RANGE_NOT_SATISFIABLE) .body(body_from_bytes(Bytes::from( "No range found after parsing range header, please file an issue", ))) .unwrap() } } Some(Err(_)) => builder .header(header::CONTENT_RANGE, format!("bytes */{}", size)) .status(StatusCode::RANGE_NOT_SATISFIABLE) .body(empty_body()) .unwrap(), // Not a range request None => { let body = if let Some(file) = maybe_file { ResponseBody::new( AsyncReadBody::with_capacity(file, output.chunk_size).boxed_unsync(), ) } else { empty_body() }; builder .header(header::CONTENT_LENGTH, size.to_string()) .body(body) .unwrap() } } } fn body_from_bytes(bytes: Bytes) -> ResponseBody { let body = Full::from(bytes).map_err(|err| match err {}).boxed_unsync(); ResponseBody::new(body) } fn empty_body() -> ResponseBody { let body = Empty::new().map_err(|err| match err {}).boxed_unsync(); ResponseBody::new(body) } tower-http-0.4.4/src/services/fs/serve_dir/headers.rs000064400000000000000000000027711046102023000207420ustar 00000000000000use http::header::HeaderValue; use httpdate::HttpDate; use std::time::SystemTime; pub(super) struct LastModified(pub(super) HttpDate); impl From for LastModified { fn from(time: SystemTime) -> Self { LastModified(time.into()) } } pub(super) struct IfModifiedSince(HttpDate); impl IfModifiedSince { /// Check if the supplied time means the resource has been modified. pub(super) fn is_modified(&self, last_modified: &LastModified) -> bool { self.0 < last_modified.0 } /// convert a header value into a IfModifiedSince, invalid values are silentely ignored pub(super) fn from_header_value(value: &HeaderValue) -> Option { std::str::from_utf8(value.as_bytes()) .ok() .and_then(|value| httpdate::parse_http_date(value).ok()) .map(|time| IfModifiedSince(time.into())) } } pub(super) struct IfUnmodifiedSince(HttpDate); impl IfUnmodifiedSince { /// Check if the supplied time passes the precondtion. 
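/// That is, the precondition holds only if the resource has not been modified after the
/// supplied `If-Unmodified-Since` time.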
pub(super) fn precondition_passes(&self, last_modified: &LastModified) -> bool { self.0 >= last_modified.0 } /// Convert a header value into a IfModifiedSince, invalid values are silentely ignored pub(super) fn from_header_value(value: &HeaderValue) -> Option { std::str::from_utf8(value.as_bytes()) .ok() .and_then(|value| httpdate::parse_http_date(value).ok()) .map(|time| IfUnmodifiedSince(time.into())) } } tower-http-0.4.4/src/services/fs/serve_dir/mod.rs000064400000000000000000000502061046102023000201020ustar 00000000000000use self::future::ResponseFuture; use crate::{ content_encoding::{encodings, SupportedEncodings}, set_status::SetStatus, }; use bytes::Bytes; use futures_util::FutureExt; use http::{header, HeaderValue, Method, Request, Response, StatusCode}; use http_body::{combinators::UnsyncBoxBody, Body, Empty}; use percent_encoding::percent_decode; use std::{ convert::Infallible, io, path::{Component, Path, PathBuf}, task::{Context, Poll}, }; use tower_service::Service; pub(crate) mod future; mod headers; mod open_file; #[cfg(test)] mod tests; // default capacity 64KiB const DEFAULT_CAPACITY: usize = 65536; /// Service that serves files from a given directory and all its sub directories. /// /// The `Content-Type` will be guessed from the file extension. /// /// An empty response with status `404 Not Found` will be returned if: /// /// - The file doesn't exist /// - Any segment of the path contains `..` /// - Any segment of the path contains a backslash /// - On unix, any segment of the path referenced as directory is actually an /// existing file (`/file.html/something`) /// - We don't have necessary permissions to read the file /// /// # Example /// /// ``` /// use tower_http::services::ServeDir; /// /// // This will serve files in the "assets" directory and /// // its subdirectories /// let service = ServeDir::new("assets"); /// /// # async { /// // Run our service using `hyper` /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); /// hyper::Server::bind(&addr) /// .serve(tower::make::Shared::new(service)) /// .await /// .expect("server error"); /// # }; /// ``` #[derive(Clone, Debug)] pub struct ServeDir { base: PathBuf, buf_chunk_size: usize, precompressed_variants: Option, // This is used to specialise implementation for // single files variant: ServeVariant, fallback: Option, call_fallback_on_method_not_allowed: bool, } impl ServeDir { /// Create a new [`ServeDir`]. pub fn new

<P>(path: P) -> Self where P: AsRef<Path>, { let mut base = PathBuf::from("."); base.push(path.as_ref()); Self { base, buf_chunk_size: DEFAULT_CAPACITY, precompressed_variants: None, variant: ServeVariant::Directory { append_index_html_on_directories: true, }, fallback: None, call_fallback_on_method_not_allowed: false, } } pub(crate) fn new_single_file<P>
(path: P, mime: HeaderValue) -> Self where P: AsRef, { Self { base: path.as_ref().to_owned(), buf_chunk_size: DEFAULT_CAPACITY, precompressed_variants: None, variant: ServeVariant::SingleFile { mime }, fallback: None, call_fallback_on_method_not_allowed: false, } } } impl ServeDir { /// If the requested path is a directory append `index.html`. /// /// This is useful for static sites. /// /// Defaults to `true`. pub fn append_index_html_on_directories(mut self, append: bool) -> Self { match &mut self.variant { ServeVariant::Directory { append_index_html_on_directories, } => { *append_index_html_on_directories = append; self } ServeVariant::SingleFile { mime: _ } => self, } } /// Set a specific read buffer chunk size. /// /// The default capacity is 64kb. pub fn with_buf_chunk_size(mut self, chunk_size: usize) -> Self { self.buf_chunk_size = chunk_size; self } /// Informs the service that it should also look for a precompressed gzip /// version of _any_ file in the directory. /// /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, /// a client with an `Accept-Encoding` header that allows the gzip encoding /// will receive the file `dir/foo.txt.gz` instead of `dir/foo.txt`. /// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the directory. Different precompressed variants can be combined. pub fn precompressed_gzip(mut self) -> Self { self.precompressed_variants .get_or_insert(Default::default()) .gzip = true; self } /// Informs the service that it should also look for a precompressed brotli /// version of _any_ file in the directory. /// /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, /// a client with an `Accept-Encoding` header that allows the brotli encoding /// will receive the file `dir/foo.txt.br` instead of `dir/foo.txt`. /// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the directory. Different precompressed variants can be combined. pub fn precompressed_br(mut self) -> Self { self.precompressed_variants .get_or_insert(Default::default()) .br = true; self } /// Informs the service that it should also look for a precompressed deflate /// version of _any_ file in the directory. /// /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, /// a client with an `Accept-Encoding` header that allows the deflate encoding /// will receive the file `dir/foo.txt.zz` instead of `dir/foo.txt`. /// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the directory. Different precompressed variants can be combined. pub fn precompressed_deflate(mut self) -> Self { self.precompressed_variants .get_or_insert(Default::default()) .deflate = true; self } /// Informs the service that it should also look for a precompressed zstd /// version of _any_ file in the directory. 
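///
/// For example, a minimal sketch (the `assets` directory name is assumed); the exact
/// lookup behavior is described below:
///
/// ```
/// use tower_http::services::ServeDir;
///
/// let service = ServeDir::new("assets").precompressed_zstd();
/// ```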
/// /// Assuming the `dir` directory is being served and `dir/foo.txt` is requested, /// a client with an `Accept-Encoding` header that allows the zstd encoding /// will receive the file `dir/foo.txt.zst` instead of `dir/foo.txt`. /// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the directory. Different precompressed variants can be combined. pub fn precompressed_zstd(mut self) -> Self { self.precompressed_variants .get_or_insert(Default::default()) .zstd = true; self } /// Set the fallback service. /// /// This service will be called if there is no file at the path of the request. /// /// The status code returned by the fallback will not be altered. Use /// [`ServeDir::not_found_service`] to set a fallback and always respond with `404 Not Found`. /// /// # Example /// /// This can be used to respond with a different file: /// /// ```rust /// use tower_http::services::{ServeDir, ServeFile}; /// /// let service = ServeDir::new("assets") /// // respond with `not_found.html` for missing files /// .fallback(ServeFile::new("assets/not_found.html")); /// /// # async { /// // Run our service using `hyper` /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); /// hyper::Server::bind(&addr) /// .serve(tower::make::Shared::new(service)) /// .await /// .expect("server error"); /// # }; /// ``` pub fn fallback(self, new_fallback: F2) -> ServeDir { ServeDir { base: self.base, buf_chunk_size: self.buf_chunk_size, precompressed_variants: self.precompressed_variants, variant: self.variant, fallback: Some(new_fallback), call_fallback_on_method_not_allowed: self.call_fallback_on_method_not_allowed, } } /// Set the fallback service and override the fallback's status code to `404 Not Found`. /// /// This service will be called if there is no file at the path of the request. /// /// # Example /// /// This can be used to respond with a different file: /// /// ```rust /// use tower_http::services::{ServeDir, ServeFile}; /// /// let service = ServeDir::new("assets") /// // respond with `404 Not Found` and the contents of `not_found.html` for missing files /// .not_found_service(ServeFile::new("assets/not_found.html")); /// /// # async { /// // Run our service using `hyper` /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); /// hyper::Server::bind(&addr) /// .serve(tower::make::Shared::new(service)) /// .await /// .expect("server error"); /// # }; /// ``` /// /// Setups like this are often found in single page applications. pub fn not_found_service(self, new_fallback: F2) -> ServeDir> { self.fallback(SetStatus::new(new_fallback, StatusCode::NOT_FOUND)) } /// Customize whether or not to call the fallback for requests that aren't `GET` or `HEAD`. /// /// Defaults to not calling the fallback and instead returning `405 Method Not Allowed`. pub fn call_fallback_on_method_not_allowed(mut self, call_fallback: bool) -> Self { self.call_fallback_on_method_not_allowed = call_fallback; self } /// Call the service and get a future that contains any `std::io::Error` that might have /// happened. /// /// By default `>::call` will handle IO errors and convert them into /// responses. It does that by converting [`std::io::ErrorKind::NotFound`] and /// [`std::io::ErrorKind::PermissionDenied`] to `404 Not Found` and any other error to `500 /// Internal Server Error`. The error will also be logged with `tracing`. 
/// /// If you want to manually control how the error response is generated you can make a new /// service that wraps a `ServeDir` and calls `try_call` instead of `call`. /// /// # Example /// /// ``` /// use tower_http::services::ServeDir; /// use std::{io, convert::Infallible}; /// use http::{Request, Response, StatusCode}; /// use http_body::{combinators::UnsyncBoxBody, Body as _}; /// use hyper::Body; /// use bytes::Bytes; /// use tower::{service_fn, ServiceExt, BoxError}; /// /// async fn serve_dir( /// request: Request /// ) -> Result>, Infallible> { /// let mut service = ServeDir::new("assets"); /// /// // You only need to worry about backpressure, and thus call `ServiceExt::ready`, if /// // your adding a fallback to `ServeDir` that cares about backpressure. /// // /// // Its shown here for demonstration but you can do `service.try_call(request)` /// // otherwise /// let ready_service = match ServiceExt::>::ready(&mut service).await { /// Ok(ready_service) => ready_service, /// Err(infallible) => match infallible {}, /// }; /// /// match ready_service.try_call(request).await { /// Ok(response) => { /// Ok(response.map(|body| body.map_err(Into::into).boxed_unsync())) /// } /// Err(err) => { /// let body = Body::from("Something went wrong...") /// .map_err(Into::into) /// .boxed_unsync(); /// let response = Response::builder() /// .status(StatusCode::INTERNAL_SERVER_ERROR) /// .body(body) /// .unwrap(); /// Ok(response) /// } /// } /// } /// /// # async { /// // Run our service using `hyper` /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); /// hyper::Server::bind(&addr) /// .serve(tower::make::Shared::new(service_fn(serve_dir))) /// .await /// .expect("server error"); /// # }; /// ``` pub fn try_call( &mut self, req: Request, ) -> ResponseFuture where F: Service, Response = Response, Error = Infallible> + Clone, F::Future: Send + 'static, FResBody: http_body::Body + Send + 'static, FResBody::Error: Into>, { if req.method() != Method::GET && req.method() != Method::HEAD { if self.call_fallback_on_method_not_allowed { if let Some(fallback) = &mut self.fallback { return ResponseFuture { inner: future::call_fallback(fallback, req), }; } } else { return ResponseFuture::method_not_allowed(); } } // `ServeDir` doesn't care about the request body but the fallback might. 
So move out the // body and pass it to the fallback, leaving an empty body in its place // // this is necessary because we cannot clone bodies let (mut parts, body) = req.into_parts(); // same goes for extensions let extensions = std::mem::take(&mut parts.extensions); let req = Request::from_parts(parts, Empty::::new()); let fallback_and_request = self.fallback.as_mut().map(|fallback| { let mut fallback_req = Request::new(body); *fallback_req.method_mut() = req.method().clone(); *fallback_req.uri_mut() = req.uri().clone(); *fallback_req.headers_mut() = req.headers().clone(); *fallback_req.extensions_mut() = extensions; // get the ready fallback and leave a non-ready clone in its place let clone = fallback.clone(); let fallback = std::mem::replace(fallback, clone); (fallback, fallback_req) }); let path_to_file = match self .variant .build_and_validate_path(&self.base, req.uri().path()) { Some(path_to_file) => path_to_file, None => { return ResponseFuture::invalid_path(fallback_and_request); } }; let buf_chunk_size = self.buf_chunk_size; let range_header = req .headers() .get(header::RANGE) .and_then(|value| value.to_str().ok()) .map(|s| s.to_owned()); let negotiated_encodings = encodings( req.headers(), self.precompressed_variants.unwrap_or_default(), ); let variant = self.variant.clone(); let open_file_future = Box::pin(open_file::open_file( variant, path_to_file, req, negotiated_encodings, range_header, buf_chunk_size, )); ResponseFuture::open_file_future(open_file_future, fallback_and_request) } } impl Service> for ServeDir where F: Service, Response = Response, Error = Infallible> + Clone, F::Future: Send + 'static, FResBody: http_body::Body + Send + 'static, FResBody::Error: Into>, { type Response = Response; type Error = Infallible; type Future = InfallibleResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { if let Some(fallback) = &mut self.fallback { fallback.poll_ready(cx) } else { Poll::Ready(Ok(())) } } fn call(&mut self, req: Request) -> Self::Future { let future = self .try_call(req) .map(|result: Result<_, _>| -> Result<_, Infallible> { let response = result.unwrap_or_else(|err| { tracing::error!(error = %err, "Failed to read file"); let body = ResponseBody::new(Empty::new().map_err(|err| match err {}).boxed_unsync()); Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(body) .unwrap() }); Ok(response) } as _); InfallibleResponseFuture::new(future) } } opaque_future! { /// Response future of [`ServeDir`]. 
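///
/// Any `std::io::Error` produced while serving the file is logged and converted into an
/// empty `500 Internal Server Error` response, which is why this future is infallible.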
pub type InfallibleResponseFuture = futures_util::future::Map< ResponseFuture, fn(Result, io::Error>) -> Result, Infallible>, >; } // Allow the ServeDir service to be used in the ServeFile service // with almost no overhead #[derive(Clone, Debug)] enum ServeVariant { Directory { append_index_html_on_directories: bool, }, SingleFile { mime: HeaderValue, }, } impl ServeVariant { fn build_and_validate_path(&self, base_path: &Path, requested_path: &str) -> Option { match self { ServeVariant::Directory { append_index_html_on_directories: _, } => { let path = requested_path.trim_start_matches('/'); let path_decoded = percent_decode(path.as_ref()).decode_utf8().ok()?; let path_decoded = Path::new(&*path_decoded); let mut path_to_file = base_path.to_path_buf(); for component in path_decoded.components() { match component { Component::Normal(comp) => { // protect against paths like `/foo/c:/bar/baz` (#204) if Path::new(&comp) .components() .all(|c| matches!(c, Component::Normal(_))) { path_to_file.push(comp) } else { return None; } } Component::CurDir => {} Component::Prefix(_) | Component::RootDir | Component::ParentDir => { return None; } } } Some(path_to_file) } ServeVariant::SingleFile { mime: _ } => Some(base_path.to_path_buf()), } } } opaque_body! { /// Response body for [`ServeDir`] and [`ServeFile`][super::ServeFile]. #[derive(Default)] pub type ResponseBody = UnsyncBoxBody; } /// The default fallback service used with [`ServeDir`]. #[derive(Debug, Clone, Copy)] pub struct DefaultServeDirFallback(Infallible); impl Service> for DefaultServeDirFallback where ReqBody: Send + 'static, { type Response = Response; type Error = Infallible; type Future = InfallibleResponseFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { match self.0 {} } fn call(&mut self, _req: Request) -> Self::Future { match self.0 {} } } #[derive(Clone, Copy, Debug, Default)] struct PrecompressedVariants { gzip: bool, deflate: bool, br: bool, zstd: bool, } impl SupportedEncodings for PrecompressedVariants { fn gzip(&self) -> bool { self.gzip } fn deflate(&self) -> bool { self.deflate } fn br(&self) -> bool { self.br } fn zstd(&self) -> bool { self.zstd } } tower-http-0.4.4/src/services/fs/serve_dir/open_file.rs000064400000000000000000000253541046102023000212710ustar 00000000000000use super::{ headers::{IfModifiedSince, IfUnmodifiedSince, LastModified}, ServeVariant, }; use crate::content_encoding::{Encoding, QValue}; use bytes::Bytes; use http::{header, HeaderValue, Method, Request, Uri}; use http_body::Empty; use http_range_header::RangeUnsatisfiableError; use std::{ ffi::OsStr, fs::Metadata, io::{self, SeekFrom}, ops::RangeInclusive, path::{Path, PathBuf}, }; use tokio::{fs::File, io::AsyncSeekExt}; pub(super) enum OpenFileOutput { FileOpened(Box), Redirect { location: HeaderValue }, FileNotFound, PreconditionFailed, NotModified, } pub(super) struct FileOpened { pub(super) extent: FileRequestExtent, pub(super) chunk_size: usize, pub(super) mime_header_value: HeaderValue, pub(super) maybe_encoding: Option, pub(super) maybe_range: Option>, RangeUnsatisfiableError>>, pub(super) last_modified: Option, } pub(super) enum FileRequestExtent { Full(File, Metadata), Head(Metadata), } pub(super) async fn open_file( variant: ServeVariant, mut path_to_file: PathBuf, req: Request>, negotiated_encodings: Vec<(Encoding, QValue)>, range_header: Option, buf_chunk_size: usize, ) -> io::Result { let if_unmodified_since = req .headers() .get(header::IF_UNMODIFIED_SINCE) .and_then(IfUnmodifiedSince::from_header_value); let 
if_modified_since = req .headers() .get(header::IF_MODIFIED_SINCE) .and_then(IfModifiedSince::from_header_value); let mime = match variant { ServeVariant::Directory { append_index_html_on_directories, } => { // Might already at this point know a redirect or not found result should be // returned which corresponds to a Some(output). Otherwise the path might be // modified and proceed to the open file/metadata future. if let Some(output) = maybe_redirect_or_append_path( &mut path_to_file, req.uri(), append_index_html_on_directories, ) .await { return Ok(output); } mime_guess::from_path(&path_to_file) .first_raw() .map(HeaderValue::from_static) .unwrap_or_else(|| { HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap() }) } ServeVariant::SingleFile { mime } => mime, }; if req.method() == Method::HEAD { let (meta, maybe_encoding) = file_metadata_with_fallback(path_to_file, negotiated_encodings).await?; let last_modified = meta.modified().ok().map(LastModified::from); if let Some(output) = check_modified_headers( last_modified.as_ref(), if_unmodified_since, if_modified_since, ) { return Ok(output); } let maybe_range = try_parse_range(range_header.as_deref(), meta.len()); Ok(OpenFileOutput::FileOpened(Box::new(FileOpened { extent: FileRequestExtent::Head(meta), chunk_size: buf_chunk_size, mime_header_value: mime, maybe_encoding, maybe_range, last_modified, }))) } else { let (mut file, maybe_encoding) = open_file_with_fallback(path_to_file, negotiated_encodings).await?; let meta = file.metadata().await?; let last_modified = meta.modified().ok().map(LastModified::from); if let Some(output) = check_modified_headers( last_modified.as_ref(), if_unmodified_since, if_modified_since, ) { return Ok(output); } let maybe_range = try_parse_range(range_header.as_deref(), meta.len()); if let Some(Ok(ranges)) = maybe_range.as_ref() { // if there is any other amount of ranges than 1 we'll return an // unsatisfiable later as there isn't yet support for multipart ranges if ranges.len() == 1 { file.seek(SeekFrom::Start(*ranges[0].start())).await?; } } Ok(OpenFileOutput::FileOpened(Box::new(FileOpened { extent: FileRequestExtent::Full(file, meta), chunk_size: buf_chunk_size, mime_header_value: mime, maybe_encoding, maybe_range, last_modified, }))) } } fn check_modified_headers( modified: Option<&LastModified>, if_unmodified_since: Option, if_modified_since: Option, ) -> Option { if let Some(since) = if_unmodified_since { let precondition = modified .as_ref() .map(|time| since.precondition_passes(time)) .unwrap_or(false); if !precondition { return Some(OpenFileOutput::PreconditionFailed); } } if let Some(since) = if_modified_since { let unmodified = modified .as_ref() .map(|time| !since.is_modified(time)) // no last_modified means its always modified .unwrap_or(false); if unmodified { return Some(OpenFileOutput::NotModified); } } None } // Returns the preferred_encoding encoding and modifies the path extension // to the corresponding file extension for the encoding. 
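//
// For example, with gzip as the preferred negotiated encoding, an (illustrative) `foo.txt`
// path becomes `foo.txt.gz`.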
fn preferred_encoding( path: &mut PathBuf, negotiated_encoding: &[(Encoding, QValue)], ) -> Option { let preferred_encoding = Encoding::preferred_encoding(negotiated_encoding); if let Some(file_extension) = preferred_encoding.and_then(|encoding| encoding.to_file_extension()) { let new_extension = path .extension() .map(|extension| { let mut os_string = extension.to_os_string(); os_string.push(file_extension); os_string }) .unwrap_or_else(|| file_extension.to_os_string()); path.set_extension(new_extension); } preferred_encoding } // Attempts to open the file with any of the possible negotiated_encodings in the // preferred order. If none of the negotiated_encodings have a corresponding precompressed // file the uncompressed file is used as a fallback. async fn open_file_with_fallback( mut path: PathBuf, mut negotiated_encoding: Vec<(Encoding, QValue)>, ) -> io::Result<(File, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. let encoding = preferred_encoding(&mut path, &negotiated_encoding); match (File::open(&path).await, encoding) { (Ok(file), maybe_encoding) => break (file, maybe_encoding), (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => { // Remove the extension corresponding to a precompressed file (.gz, .br, .zz) // to reset the path before the next iteration. path.set_extension(OsStr::new("")); // Remove the encoding from the negotiated_encodings since the file doesn't exist negotiated_encoding .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); continue; } (Err(err), _) => return Err(err), }; }; Ok((file, encoding)) } // Attempts to get the file metadata with any of the possible negotiated_encodings in the // preferred order. If none of the negotiated_encodings have a corresponding precompressed // file the uncompressed file is used as a fallback. async fn file_metadata_with_fallback( mut path: PathBuf, mut negotiated_encoding: Vec<(Encoding, QValue)>, ) -> io::Result<(Metadata, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. let encoding = preferred_encoding(&mut path, &negotiated_encoding); match (tokio::fs::metadata(&path).await, encoding) { (Ok(file), maybe_encoding) => break (file, maybe_encoding), (Err(err), Some(encoding)) if err.kind() == io::ErrorKind::NotFound => { // Remove the extension corresponding to a precompressed file (.gz, .br, .zz) // to reset the path before the next iteration. 
path.set_extension(OsStr::new("")); // Remove the encoding from the negotiated_encodings since the file doesn't exist negotiated_encoding .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); continue; } (Err(err), _) => return Err(err), }; }; Ok((file, encoding)) } async fn maybe_redirect_or_append_path( path_to_file: &mut PathBuf, uri: &Uri, append_index_html_on_directories: bool, ) -> Option { if !uri.path().ends_with('/') { if is_dir(path_to_file).await { let location = HeaderValue::from_str(&append_slash_on_path(uri.clone()).to_string()).unwrap(); Some(OpenFileOutput::Redirect { location }) } else { None } } else if is_dir(path_to_file).await { if append_index_html_on_directories { path_to_file.push("index.html"); None } else { Some(OpenFileOutput::FileNotFound) } } else { None } } fn try_parse_range( maybe_range_ref: Option<&str>, file_size: u64, ) -> Option>, RangeUnsatisfiableError>> { maybe_range_ref.map(|header_value| { http_range_header::parse_range_header(header_value) .and_then(|first_pass| first_pass.validate(file_size)) }) } async fn is_dir(path_to_file: &Path) -> bool { tokio::fs::metadata(path_to_file) .await .map_or(false, |meta_data| meta_data.is_dir()) } fn append_slash_on_path(uri: Uri) -> Uri { let http::uri::Parts { scheme, authority, path_and_query, .. } = uri.into_parts(); let mut uri_builder = Uri::builder(); if let Some(scheme) = scheme { uri_builder = uri_builder.scheme(scheme); } if let Some(authority) = authority { uri_builder = uri_builder.authority(authority); } let uri_builder = if let Some(path_and_query) = path_and_query { if let Some(query) = path_and_query.query() { uri_builder.path_and_query(format!("{}/?{}", path_and_query.path(), query)) } else { uri_builder.path_and_query(format!("{}/", path_and_query.path())) } } else { uri_builder.path_and_query("/") }; uri_builder.build().unwrap() } tower-http-0.4.4/src/services/fs/serve_dir/tests.rs000064400000000000000000000555611046102023000204760ustar 00000000000000use crate::services::{ServeDir, ServeFile}; use brotli::BrotliDecompress; use bytes::Bytes; use flate2::bufread::{DeflateDecoder, GzDecoder}; use http::header::ALLOW; use http::{header, Method, Response}; use http::{Request, StatusCode}; use http_body::Body as HttpBody; use hyper::Body; use std::convert::Infallible; use std::io::{self, Read}; use tower::{service_fn, ServiceExt}; #[tokio::test] async fn basic() { let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], "text/markdown"); let body = body_into_text(res.into_body()).await; let contents = std::fs::read_to_string("../README.md").unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn basic_with_index() { let svc = ServeDir::new("../test-files"); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()[header::CONTENT_TYPE], "text/html"); let body = body_into_text(res.into_body()).await; assert_eq!(body, "HTML!\n"); } #[tokio::test] async fn head_request() { let svc = ServeDir::new("../test-files"); let req = Request::builder() .uri("/precompressed.txt") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "23"); let body = res.into_body().data().await; 
assert!(body.is_none()); } #[tokio::test] async fn precompresed_head_request() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let req = Request::builder() .uri("/precompressed.txt") .header("Accept-Encoding", "gzip") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); assert_eq!(res.headers()["content-length"], "59"); let body = res.into_body().data().await; assert!(body.is_none()); } #[tokio::test] async fn with_custom_chunk_size() { let svc = ServeDir::new("..").with_buf_chunk_size(1024 * 32); let req = Request::builder() .uri("/README.md") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], "text/markdown"); let body = body_into_text(res.into_body()).await; let contents = std::fs::read_to_string("../README.md").unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn precompressed_gzip() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let req = Request::builder() .uri("/precompressed.txt") .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn precompressed_br() { let svc = ServeDir::new("../test-files").precompressed_br(); let req = Request::builder() .uri("/precompressed.txt") .header("Accept-Encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn precompressed_deflate() { let svc = ServeDir::new("../test-files").precompressed_deflate(); let request = Request::builder() .uri("/precompressed.txt") .header("Accept-Encoding", "deflate,br") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "deflate"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = DeflateDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let request = Request::builder() .uri("/precompressed.txt") .header("Accept-Encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert!(res.headers().get("content-encoding").is_none()); let body = res.into_body().data().await.unwrap().unwrap(); let body = 
String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn only_precompressed_variant_existing() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let request = Request::builder() .uri("/only_gzipped.txt") .body(Body::empty()) .unwrap(); let res = svc.clone().oneshot(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); // Should reply with gzipped file if client supports it let request = Request::builder() .uri("/only_gzipped.txt") .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file\"")); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let request = Request::builder() .uri("/missing_precompressed.txt") .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); let body = res.into_body().data().await.unwrap().unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("Test file!")); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed_for_head_request() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let request = Request::builder() .uri("/missing_precompressed.txt") .header("Accept-Encoding", "gzip") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "11"); // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); assert!(res.into_body().data().await.is_none()); } #[tokio::test] async fn access_to_sub_dirs() { let svc = ServeDir::new(".."); let req = Request::builder() .uri("/tower-http/Cargo.toml") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], "text/x-toml"); let body = body_into_text(res.into_body()).await; let contents = std::fs::read_to_string("Cargo.toml").unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn not_found() { let svc = ServeDir::new(".."); let req = Request::builder() .uri("/not-found") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); assert!(res.headers().get(header::CONTENT_TYPE).is_none()); let body = body_into_text(res.into_body()).await; assert!(body.is_empty()); } #[cfg(unix)] #[tokio::test] async fn not_found_when_not_a_directory() { let svc = ServeDir::new("../test-files"); // `index.html` is a file, and we are trying to request // it as a directory. 
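// Such a request cannot resolve to a file, so `ServeDir` answers with a plain `404 Not Found`.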
let req = Request::builder() .uri("/index.html/some_file") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); // This should lead to a 404 assert_eq!(res.status(), StatusCode::NOT_FOUND); assert!(res.headers().get(header::CONTENT_TYPE).is_none()); let body = body_into_text(res.into_body()).await; assert!(body.is_empty()); } #[tokio::test] async fn not_found_precompressed() { let svc = ServeDir::new("../test-files").precompressed_gzip(); let req = Request::builder() .uri("/not-found") .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); assert!(res.headers().get(header::CONTENT_TYPE).is_none()); let body = body_into_text(res.into_body()).await; assert!(body.is_empty()); } #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found_for_head_request() { let svc = ServeDir::new("../test-files") .precompressed_gzip() .precompressed_br(); let req = Request::builder() .uri("/precompressed_br.txt") .header("Accept-Encoding", "gzip,br,deflate") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); assert_eq!(res.headers()["content-length"], "15"); assert!(res.into_body().data().await.is_none()); } #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found() { let svc = ServeDir::new("../test-files") .precompressed_gzip() .precompressed_br(); let req = Request::builder() .uri("/precompressed_br.txt") .header("Accept-Encoding", "gzip,br,deflate") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); assert!(decompressed.starts_with("Test file")); } #[tokio::test] async fn redirect_to_trailing_slash_on_dir() { let svc = ServeDir::new("."); let req = Request::builder().uri("/src").body(Body::empty()).unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); let location = &res.headers()[http::header::LOCATION]; assert_eq!(location, "/src/"); } #[tokio::test] async fn empty_directory_without_index() { let svc = ServeDir::new(".").append_index_html_on_directories(false); let req = Request::new(Body::empty()); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); assert!(res.headers().get(header::CONTENT_TYPE).is_none()); let body = body_into_text(res.into_body()).await; assert!(body.is_empty()); } async fn body_into_text(body: B) -> String where B: HttpBody + Unpin, B::Error: std::fmt::Debug, { let bytes = hyper::body::to_bytes(body).await.unwrap(); String::from_utf8(bytes.to_vec()).unwrap() } #[tokio::test] async fn access_cjk_percent_encoded_uri_path() { // percent encoding present of 你好世界.txt let cjk_filename_encoded = "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C.txt"; let svc = ServeDir::new("../test-files"); let req = Request::builder() .uri(format!("/{}", cjk_filename_encoded)) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], 
"text/plain"); } #[tokio::test] async fn access_space_percent_encoded_uri_path() { let encoded_filename = "filename%20with%20space.txt"; let svc = ServeDir::new("../test-files"); let req = Request::builder() .uri(format!("/{}", encoded_filename)) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], "text/plain"); } #[tokio::test] async fn read_partial_in_bounds() { let svc = ServeDir::new(".."); let bytes_start_incl = 9; let bytes_end_incl = 1023; let req = Request::builder() .uri("/README.md") .header( "Range", format!("bytes={}-{}", bytes_start_incl, bytes_end_incl), ) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); let file_contents = std::fs::read("../README.md").unwrap(); assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); assert_eq!( res.headers()["content-length"], (bytes_end_incl - bytes_start_incl + 1).to_string() ); assert!(res.headers()["content-range"] .to_str() .unwrap() .starts_with(&format!( "bytes {}-{}/{}", bytes_start_incl, bytes_end_incl, file_contents.len() ))); assert_eq!(res.headers()["content-type"], "text/markdown"); let body = hyper::body::to_bytes(res.into_body()).await.ok().unwrap(); let source = Bytes::from(file_contents[bytes_start_incl..=bytes_end_incl].to_vec()); assert_eq!(body, source); } #[tokio::test] #[ignore] // https://github.com/tower-rs/tower-http/commit/0c50afe28a3c9bec7aa4e1f620ce5a0a805b6103 // This commit on master fixes the issue so lets ignore it for now async fn read_partial_rejects_out_of_bounds_range() { let svc = ServeDir::new(".."); let bytes_start_incl = 0; let bytes_end_excl = 9999999; let requested_len = bytes_end_excl - bytes_start_incl; let req = Request::builder() .uri("/README.md") .header( "Range", format!("bytes={}-{}", bytes_start_incl, requested_len - 1), ) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); let file_contents = std::fs::read("../README.md").unwrap(); assert_eq!( res.headers()["content-range"], &format!("bytes */{}", file_contents.len()) ) } #[tokio::test] async fn read_partial_errs_on_garbage_header() { let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header("Range", "bad_format") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); let file_contents = std::fs::read("../README.md").unwrap(); assert_eq!( res.headers()["content-range"], &format!("bytes */{}", file_contents.len()) ) } #[tokio::test] async fn read_partial_errs_on_bad_range() { let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header("Range", "bytes=-1-15") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); let file_contents = std::fs::read("../README.md").unwrap(); assert_eq!( res.headers()["content-range"], &format!("bytes */{}", file_contents.len()) ) } #[tokio::test] async fn accept_encoding_identity() { let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header("Accept-Encoding", "identity") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); // Identity encoding should not be included in the response headers assert!(res.headers().get("content-encoding").is_none()); } #[tokio::test] async fn last_modified() { let 
svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let last_modified = res .headers() .get(header::LAST_MODIFIED) .expect("Missing last modified header!"); // -- If-Modified-Since let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header(header::IF_MODIFIED_SINCE, last_modified) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_MODIFIED); let body = res.into_body().data().await; assert!(body.is_none()); let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let readme_bytes = include_bytes!("../../../../../README.md"); let body = res.into_body().data().await.unwrap().unwrap(); assert_eq!(body.as_ref(), readme_bytes); // -- If-Unmodified-Since let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header(header::IF_UNMODIFIED_SINCE, last_modified) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = res.into_body().data().await.unwrap().unwrap(); assert_eq!(body.as_ref(), readme_bytes); let svc = ServeDir::new(".."); let req = Request::builder() .uri("/README.md") .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); let body = res.into_body().data().await; assert!(body.is_none()); } #[tokio::test] async fn with_fallback_svc() { async fn fallback(req: Request) -> Result, Infallible> { Ok(Response::new(Body::from(format!( "from fallback {}", req.uri().path() )))) } let svc = ServeDir::new("..").fallback(tower::service_fn(fallback)); let req = Request::builder() .uri("/doesnt-exist") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = body_into_text(res.into_body()).await; assert_eq!(body, "from fallback /doesnt-exist"); } #[tokio::test] async fn with_fallback_serve_file() { let svc = ServeDir::new("..").fallback(ServeFile::new("../README.md")); let req = Request::builder() .uri("/doesnt-exist") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["content-type"], "text/markdown"); let body = body_into_text(res.into_body()).await; let contents = std::fs::read_to_string("../README.md").unwrap(); assert_eq!(body, contents); } #[tokio::test] async fn method_not_allowed() { let svc = ServeDir::new(".."); let req = Request::builder() .method(Method::POST) .uri("/README.md") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD"); } #[tokio::test] async fn calling_fallback_on_not_allowed() { async fn fallback(req: Request) -> Result, Infallible> { Ok(Response::new(Body::from(format!( "from fallback {}", req.uri().path() )))) } let svc = ServeDir::new("..") .call_fallback_on_method_not_allowed(true) .fallback(tower::service_fn(fallback)); let req = Request::builder() .method(Method::POST) .uri("/doesnt-exist") .body(Body::empty()) 
.unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = body_into_text(res.into_body()).await; assert_eq!(body, "from fallback /doesnt-exist"); } #[tokio::test] async fn with_fallback_svc_and_not_append_index_html_on_directories() { async fn fallback(req: Request) -> Result, Infallible> { Ok(Response::new(Body::from(format!( "from fallback {}", req.uri().path() )))) } let svc = ServeDir::new("..") .append_index_html_on_directories(false) .fallback(tower::service_fn(fallback)); let req = Request::builder().uri("/").body(Body::empty()).unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = body_into_text(res.into_body()).await; assert_eq!(body, "from fallback /"); } // https://github.com/tower-rs/tower-http/issues/308 #[tokio::test] async fn calls_fallback_on_invalid_paths() { async fn fallback(_: T) -> Result, Infallible> { let mut res = Response::new(Body::empty()); res.headers_mut() .insert("from-fallback", "1".parse().unwrap()); Ok(res) } let svc = ServeDir::new("..").fallback(service_fn(fallback)); let req = Request::builder() .uri("/weird_%c3%28_path") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.headers()["from-fallback"], "1"); } tower-http-0.4.4/src/services/fs/serve_file.rs000064400000000000000000000462221046102023000174670ustar 00000000000000//! Service that serves a file. use super::ServeDir; use http::{HeaderValue, Request}; use mime::Mime; use std::{ path::Path, task::{Context, Poll}, }; use tower_service::Service; /// Service that serves a file. #[derive(Clone, Debug)] pub struct ServeFile(ServeDir); // Note that this is just a special case of ServeDir impl ServeFile { /// Create a new [`ServeFile`]. /// /// The `Content-Type` will be guessed from the file extension. pub fn new>(path: P) -> Self { let guess = mime_guess::from_path(path.as_ref()); let mime = guess .first_raw() .map(HeaderValue::from_static) .unwrap_or_else(|| { HeaderValue::from_str(mime::APPLICATION_OCTET_STREAM.as_ref()).unwrap() }); Self(ServeDir::new_single_file(path, mime)) } /// Create a new [`ServeFile`] with a specific mime type. /// /// # Panics /// /// Will panic if the mime type isn't a valid [header value]. /// /// [header value]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html pub fn new_with_mime>(path: P, mime: &Mime) -> Self { let mime = HeaderValue::from_str(mime.as_ref()).expect("mime isn't a valid header value"); Self(ServeDir::new_single_file(path, mime)) } /// Informs the service that it should also look for a precompressed gzip /// version of the file. /// /// If the client has an `Accept-Encoding` header that allows the gzip encoding, /// the file `foo.txt.gz` will be served instead of `foo.txt`. /// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the same directory. Different precompressed /// variants can be combined. pub fn precompressed_gzip(self) -> Self { Self(self.0.precompressed_gzip()) } /// Informs the service that it should also look for a precompressed brotli /// version of the file. /// /// If the client has an `Accept-Encoding` header that allows the brotli encoding, /// the file `foo.txt.br` will be served instead of `foo.txt`. 
/// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the same directory. Different precompressed /// variants can be combined. pub fn precompressed_br(self) -> Self { Self(self.0.precompressed_br()) } /// Informs the service that it should also look for a precompressed deflate /// version of the file. /// /// If the client has an `Accept-Encoding` header that allows the deflate encoding, /// the file `foo.txt.zz` will be served instead of `foo.txt`. /// If the precompressed file is not available, or the client doesn't support it, /// the uncompressed version will be served instead. /// Both the precompressed version and the uncompressed version are expected /// to be present in the same directory. Different precompressed /// variants can be combined. pub fn precompressed_deflate(self) -> Self { Self(self.0.precompressed_deflate()) } /// Set a specific read buffer chunk size. /// /// The default capacity is 64kb. pub fn with_buf_chunk_size(self, chunk_size: usize) -> Self { Self(self.0.with_buf_chunk_size(chunk_size)) } /// Call the service and get a future that contains any `std::io::Error` that might have /// happened. /// /// See [`ServeDir::try_call`] for more details. pub fn try_call( &mut self, req: Request, ) -> super::serve_dir::future::ResponseFuture where ReqBody: Send + 'static, { self.0.try_call(req) } } impl Service> for ServeFile where ReqBody: Send + 'static, { type Error = >>::Error; type Response = >>::Response; type Future = >>::Future; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[inline] fn call(&mut self, req: Request) -> Self::Future { self.0.call(req) } } #[cfg(test)] mod tests { use crate::services::ServeFile; use brotli::BrotliDecompress; use flate2::bufread::DeflateDecoder; use flate2::bufread::GzDecoder; use http::header; use http::Method; use http::{Request, StatusCode}; use http_body::Body as _; use hyper::Body; use mime::Mime; use std::io::Read; use std::str::FromStr; use tower::ServiceExt; #[tokio::test] async fn basic() { let svc = ServeFile::new("../README.md"); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/markdown"); let body = res.into_body().data().await.unwrap().unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("# Tower HTTP")); } #[tokio::test] async fn basic_with_mime() { let svc = ServeFile::new_with_mime("../README.md", &Mime::from_str("image/jpg").unwrap()); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); assert_eq!(res.headers()["content-type"], "image/jpg"); let body = res.into_body().data().await.unwrap().unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("# Tower HTTP")); } #[tokio::test] async fn head_request() { let svc = ServeFile::new("../test-files/precompressed.txt"); let mut request = Request::new(Body::empty()); *request.method_mut() = Method::HEAD; let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "23"); let body = res.into_body().data().await; assert!(body.is_none()); } #[tokio::test] async fn precompresed_head_request() { let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); let request = Request::builder() 
.header("Accept-Encoding", "gzip") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); assert_eq!(res.headers()["content-length"], "59"); let body = res.into_body().data().await; assert!(body.is_none()); } #[tokio::test] async fn precompressed_gzip() { let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert!(res.headers().get("content-encoding").is_none()); let body = res.into_body().data().await.unwrap().unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed() { let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); let body = res.into_body().data().await.unwrap().unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("Test file!")); } #[tokio::test] async fn missing_precompressed_variant_fallbacks_to_uncompressed_head_request() { let svc = ServeFile::new("../test-files/missing_precompressed.txt").precompressed_gzip(); let request = Request::builder() .header("Accept-Encoding", "gzip") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "11"); // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); let body = res.into_body().data().await; assert!(body.is_none()); } #[tokio::test] async fn only_precompressed_variant_existing() { let svc = ServeFile::new("../test-files/only_gzipped.txt").precompressed_gzip(); let request = Request::builder().body(Body::empty()).unwrap(); let res = svc.clone().oneshot(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); // Should reply with gzipped file if client supports it let request = Request::builder() .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], 
"gzip"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file\"")); } #[tokio::test] async fn precompressed_br() { let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_br(); let request = Request::builder() .header("Accept-Encoding", "gzip,br") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn precompressed_deflate() { let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_deflate(); let request = Request::builder() .header("Accept-Encoding", "deflate,br") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "deflate"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = DeflateDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn multi_precompressed() { let svc = ServeFile::new("../test-files/precompressed.txt") .precompressed_gzip() .precompressed_br(); let request = Request::builder() .header("Accept-Encoding", "gzip") .body(Body::empty()) .unwrap(); let res = svc.clone().oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); let request = Request::builder() .header("Accept-Encoding", "br") .body(Body::empty()) .unwrap(); let res = svc.clone().oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); assert!(decompressed.starts_with("\"This is a test file!\"")); } #[tokio::test] async fn with_custom_chunk_size() { let svc = ServeFile::new("../README.md").with_buf_chunk_size(1024 * 32); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/markdown"); let body = res.into_body().data().await.unwrap().unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("# Tower HTTP")); } #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found() { let svc = ServeFile::new("../test-files/precompressed_br.txt") .precompressed_gzip() .precompressed_deflate() .precompressed_br(); let request = Request::builder() .header("Accept-Encoding", "gzip,deflate,br") 
.body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); let body = res.into_body().data().await.unwrap().unwrap(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); assert!(decompressed.starts_with("Test file")); } #[tokio::test] async fn fallbacks_to_different_precompressed_variant_if_not_found_head_request() { let svc = ServeFile::new("../test-files/precompressed_br.txt") .precompressed_gzip() .precompressed_deflate() .precompressed_br(); let request = Request::builder() .header("Accept-Encoding", "gzip,deflate,br") .method(Method::HEAD) .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "15"); assert_eq!(res.headers()["content-encoding"], "br"); let body = res.into_body().data().await; assert!(body.is_none()); } #[tokio::test] async fn returns_404_if_file_doesnt_exist() { let svc = ServeFile::new("../this-doesnt-exist.md"); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); assert!(res.headers().get(header::CONTENT_TYPE).is_none()); } #[tokio::test] async fn returns_404_if_file_doesnt_exist_when_precompression_is_used() { let svc = ServeFile::new("../this-doesnt-exist.md").precompressed_deflate(); let request = Request::builder() .header("Accept-Encoding", "deflate") .body(Body::empty()) .unwrap(); let res = svc.oneshot(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); assert!(res.headers().get(header::CONTENT_TYPE).is_none()); } #[tokio::test] async fn last_modified() { let svc = ServeFile::new("../README.md"); let req = Request::builder().body(Body::empty()).unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let last_modified = res .headers() .get(header::LAST_MODIFIED) .expect("Missing last modified header!"); // -- If-Modified-Since let svc = ServeFile::new("../README.md"); let req = Request::builder() .header(header::IF_MODIFIED_SINCE, last_modified) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_MODIFIED); let body = res.into_body().data().await; assert!(body.is_none()); let svc = ServeFile::new("../README.md"); let req = Request::builder() .header(header::IF_MODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let readme_bytes = include_bytes!("../../../../README.md"); let body = res.into_body().data().await.unwrap().unwrap(); assert_eq!(body.as_ref(), readme_bytes); // -- If-Unmodified-Since let svc = ServeFile::new("../README.md"); let req = Request::builder() .header(header::IF_UNMODIFIED_SINCE, last_modified) .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = res.into_body().data().await.unwrap().unwrap(); assert_eq!(body.as_ref(), readme_bytes); let svc = ServeFile::new("../README.md"); let req = Request::builder() .header(header::IF_UNMODIFIED_SINCE, "Fri, 09 Aug 1996 14:21:40 GMT") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); let 
body = res.into_body().data().await; assert!(body.is_none()); } } tower-http-0.4.4/src/services/mod.rs000064400000000000000000000011141046102023000155020ustar 00000000000000//! [`Service`]s that return responses without wrapping other [`Service`]s. //! //! These kinds of services are also referred to as "leaf services" since they sit at the leaves of //! a [tree] of services. //! //! [`Service`]: https://docs.rs/tower/latest/tower/trait.Service.html //! [tree]: https://en.wikipedia.org/wiki/Tree_(data_structure) #[cfg(feature = "redirect")] pub mod redirect; #[cfg(feature = "redirect")] #[doc(inline)] pub use self::redirect::Redirect; #[cfg(feature = "fs")] pub mod fs; #[cfg(feature = "fs")] #[doc(inline)] pub use self::fs::{ServeDir, ServeFile}; tower-http-0.4.4/src/services/redirect.rs000064400000000000000000000107171046102023000165350ustar 00000000000000//! Service that redirects all requests. //! //! # Example //! //! Imagine that we run `example.com` and want to redirect all requests using `HTTP` to `HTTPS`. //! That can be done like so: //! //! ```rust //! use http::{Request, Uri, StatusCode}; //! use hyper::Body; //! use tower::{Service, ServiceExt}; //! use tower_http::services::Redirect; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let uri: Uri = "https://example.com/".parse().unwrap(); //! let mut service: Redirect = Redirect::permanent(uri); //! //! let request = Request::builder() //! .uri("http://example.com") //! .body(Body::empty()) //! .unwrap(); //! //! let response = service.oneshot(request).await?; //! //! assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT); //! assert_eq!(response.headers()["location"], "https://example.com/"); //! # //! # Ok(()) //! # } //! ``` use http::{header, HeaderValue, Response, StatusCode, Uri}; use std::{ convert::{Infallible, TryFrom}, fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower_service::Service; /// Service that redirects all requests. /// /// See the [module docs](crate::services::redirect) for more details. pub struct Redirect { status_code: StatusCode, location: HeaderValue, // Covariant over ResBody, no dropping of ResBody _marker: PhantomData ResBody>, } impl Redirect { /// Create a new [`Redirect`] that uses a [`307 Temporary Redirect`][mdn] status code. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/307 pub fn temporary(uri: Uri) -> Self { Self::with_status_code(StatusCode::TEMPORARY_REDIRECT, uri) } /// Create a new [`Redirect`] that uses a [`308 Permanent Redirect`][mdn] status code. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308 pub fn permanent(uri: Uri) -> Self { Self::with_status_code(StatusCode::PERMANENT_REDIRECT, uri) } /// Create a new [`Redirect`] that uses the given status code. /// /// # Panics /// /// - If `status_code` isn't a [redirection status code][mdn] (3xx). /// - If `uri` isn't a valid [`HeaderValue`]. 
/// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#redirection_messages pub fn with_status_code(status_code: StatusCode, uri: Uri) -> Self { assert!( status_code.is_redirection(), "not a redirection status code" ); Self { status_code, location: HeaderValue::try_from(uri.to_string()) .expect("URI isn't a valid header value"), _marker: PhantomData, } } } impl Service for Redirect where ResBody: Default, { type Response = Response; type Error = Infallible; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _req: R) -> Self::Future { ResponseFuture { status_code: self.status_code, location: Some(self.location.clone()), _marker: PhantomData, } } } impl fmt::Debug for Redirect { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Redirect") .field("status_code", &self.status_code) .field("location", &self.location) .finish() } } impl Clone for Redirect { fn clone(&self) -> Self { Self { status_code: self.status_code, location: self.location.clone(), _marker: PhantomData, } } } /// Response future of [`Redirect`]. #[derive(Debug)] pub struct ResponseFuture { location: Option, status_code: StatusCode, // Covariant over ResBody, no dropping of ResBody _marker: PhantomData ResBody>, } impl Future for ResponseFuture where ResBody: Default, { type Output = Result, Infallible>; fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { let mut res = Response::default(); *res.status_mut() = self.status_code; res.headers_mut() .insert(header::LOCATION, self.location.take().unwrap()); Poll::Ready(Ok(res)) } } tower-http-0.4.4/src/set_header/mod.rs000064400000000000000000000060151046102023000157670ustar 00000000000000//! Middleware for setting headers on requests and responses. //! //! See [request] and [response] for more details. use http::{header::HeaderName, HeaderMap, HeaderValue, Request, Response}; pub mod request; pub mod response; #[doc(inline)] pub use self::{ request::{SetRequestHeader, SetRequestHeaderLayer}, response::{SetResponseHeader, SetResponseHeaderLayer}, }; /// Trait for producing header values. /// /// Used by [`SetRequestHeader`] and [`SetResponseHeader`]. /// /// This trait is implemented for closures with the correct type signature. Typically users will /// not have to implement this trait for their own types. /// /// It is also implemented directly for [`HeaderValue`]. When a fixed header value should be added /// to all responses, it can be supplied directly to the middleware. pub trait MakeHeaderValue { /// Try to create a header value from the request or response. 
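///
/// Returning `None` leaves the message untouched: no header is inserted or appended for it.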
fn make_header_value(&mut self, message: &T) -> Option; } impl MakeHeaderValue for F where F: FnMut(&T) -> Option, { fn make_header_value(&mut self, message: &T) -> Option { self(message) } } impl MakeHeaderValue for HeaderValue { fn make_header_value(&mut self, _message: &T) -> Option { Some(self.clone()) } } impl MakeHeaderValue for Option { fn make_header_value(&mut self, _message: &T) -> Option { self.clone() } } #[derive(Debug, Clone, Copy)] enum InsertHeaderMode { Override, Append, IfNotPresent, } impl InsertHeaderMode { fn apply(self, header_name: &HeaderName, target: &mut T, make: &mut M) where T: Headers, M: MakeHeaderValue, { match self { InsertHeaderMode::Override => { if let Some(value) = make.make_header_value(target) { target.headers_mut().insert(header_name.clone(), value); } } InsertHeaderMode::IfNotPresent => { if !target.headers().contains_key(header_name) { if let Some(value) = make.make_header_value(target) { target.headers_mut().insert(header_name.clone(), value); } } } InsertHeaderMode::Append => { if let Some(value) = make.make_header_value(target) { target.headers_mut().append(header_name.clone(), value); } } } } } trait Headers { fn headers(&self) -> &HeaderMap; fn headers_mut(&mut self) -> &mut HeaderMap; } impl Headers for Request { fn headers(&self) -> &HeaderMap { Request::headers(self) } fn headers_mut(&mut self) -> &mut HeaderMap { Request::headers_mut(self) } } impl Headers for Response { fn headers(&self) -> &HeaderMap { Response::headers(self) } fn headers_mut(&mut self) -> &mut HeaderMap { Response::headers_mut(self) } } tower-http-0.4.4/src/set_header/request.rs000064400000000000000000000167711046102023000167120ustar 00000000000000//! Set a header on the request. //! //! The header value to be set may be provided as a fixed value when the //! middleware is constructed, or determined dynamically based on the request //! by a closure. See the [`MakeHeaderValue`] trait for details. //! //! # Example //! //! Setting a header from a fixed value provided when the middleware is constructed: //! //! ``` //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetRequestHeaderLayer; //! use hyper::Body; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let http_client = tower::service_fn(|_: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(Body::empty())) //! # }); //! # //! let mut svc = ServiceBuilder::new() //! .layer( //! // Layer that sets `User-Agent: my very cool app` on requests. //! // //! // `if_not_present` will only insert the header if it does not already //! // have a value. //! SetRequestHeaderLayer::if_not_present( //! header::USER_AGENT, //! HeaderValue::from_static("my very cool app"), //! ) //! ) //! .service(http_client); //! //! let request = Request::new(Body::empty()); //! //! let response = svc.ready().await?.call(request).await?; //! # //! # Ok(()) //! # } //! ``` //! //! Setting a header based on a value determined dynamically from the request: //! //! ``` //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetRequestHeaderLayer; //! use hyper::Body; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let http_client = tower::service_fn(|_: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(Body::empty())) //! # }); //! 
fn date_header_value() -> HeaderValue { //! // ... //! # HeaderValue::from_static("now") //! } //! //! let mut svc = ServiceBuilder::new() //! .layer( //! // Layer that sets `Date` to the current date and time. //! // //! // `overriding` will insert the header and override any previous values it //! // may have. //! SetRequestHeaderLayer::overriding( //! header::DATE, //! |request: &Request| { //! Some(date_header_value()) //! } //! ) //! ) //! .service(http_client); //! //! let request = Request::new(Body::empty()); //! //! let response = svc.ready().await?.call(request).await?; //! # //! # Ok(()) //! # } //! ``` use super::{InsertHeaderMode, MakeHeaderValue}; use http::{header::HeaderName, Request, Response}; use std::{ fmt, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`SetRequestHeader`] which adds a request header. /// /// See [`SetRequestHeader`] for more details. pub struct SetRequestHeaderLayer { header_name: HeaderName, make: M, mode: InsertHeaderMode, } impl fmt::Debug for SetRequestHeaderLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SetRequestHeaderLayer") .field("header_name", &self.header_name) .field("mode", &self.mode) .field("make", &std::any::type_name::()) .finish() } } impl SetRequestHeaderLayer { /// Create a new [`SetRequestHeaderLayer`]. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. pub fn overriding(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::Override) } /// Create a new [`SetRequestHeaderLayer`]. /// /// The new header is always added, preserving any existing values. If previous values exist, /// the header will have multiple values. pub fn appending(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::Append) } /// Create a new [`SetRequestHeaderLayer`]. /// /// If a previous value exists for the header, the new value is not inserted. pub fn if_not_present(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::IfNotPresent) } fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { Self { make, header_name, mode, } } } impl Layer for SetRequestHeaderLayer where M: Clone, { type Service = SetRequestHeader; fn layer(&self, inner: S) -> Self::Service { SetRequestHeader { inner, header_name: self.header_name.clone(), make: self.make.clone(), mode: self.mode, } } } impl Clone for SetRequestHeaderLayer where M: Clone, { fn clone(&self) -> Self { Self { make: self.make.clone(), header_name: self.header_name.clone(), mode: self.mode, } } } /// Middleware that sets a header on the request. #[derive(Clone)] pub struct SetRequestHeader { inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode, } impl SetRequestHeader { /// Create a new [`SetRequestHeader`]. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::Override) } /// Create a new [`SetRequestHeader`]. /// /// The new header is always added, preserving any existing values. If previous values exist, /// the header will have multiple values. pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::Append) } /// Create a new [`SetRequestHeader`]. 
/// /// If a previous value exists for the header, the new value is not inserted. pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) } fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { Self { inner, header_name, make, mode, } } define_inner_service_accessors!(); } impl fmt::Debug for SetRequestHeader where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SetRequestHeader") .field("inner", &self.inner) .field("header_name", &self.header_name) .field("mode", &self.mode) .field("make", &std::any::type_name::()) .finish() } } impl Service> for SetRequestHeader where S: Service, Response = Response>, M: MakeHeaderValue>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { self.mode.apply(&self.header_name, &mut req, &mut self.make); self.inner.call(req) } } tower-http-0.4.4/src/set_header/response.rs000064400000000000000000000303571046102023000170540ustar 00000000000000//! Set a header on the response. //! //! The header value to be set may be provided as a fixed value when the //! middleware is constructed, or determined dynamically based on the response //! by a closure. See the [`MakeHeaderValue`] trait for details. //! //! # Example //! //! Setting a header from a fixed value provided when the middleware is constructed: //! //! ``` //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetResponseHeaderLayer; //! use hyper::Body; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let render_html = tower::service_fn(|request: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # //! let mut svc = ServiceBuilder::new() //! .layer( //! // Layer that sets `Content-Type: text/html` on responses. //! // //! // `if_not_present` will only insert the header if it does not already //! // have a value. //! SetResponseHeaderLayer::if_not_present( //! header::CONTENT_TYPE, //! HeaderValue::from_static("text/html"), //! ) //! ) //! .service(render_html); //! //! let request = Request::new(Body::empty()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["content-type"], "text/html"); //! # //! # Ok(()) //! # } //! ``` //! //! Setting a header based on a value determined dynamically from the response: //! //! ``` //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetResponseHeaderLayer; //! use hyper::Body; //! use http_body::Body as _; // for `Body::size_hint` //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # let render_html = tower::service_fn(|request: Request| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890"))) //! # }); //! # //! let mut svc = ServiceBuilder::new() //! .layer( //! // Layer that sets `Content-Length` if the body has a known size. //! // Bodies with streaming responses wont have a known size. //! // //! // `overriding` will insert the header and override any previous values it //! // may have. //! SetResponseHeaderLayer::overriding( //! 
header::CONTENT_LENGTH, //! |response: &Response| { //! if let Some(size) = response.body().size_hint().exact() { //! // If the response body has a known size, returning `Some` will //! // set the `Content-Length` header to that value. //! Some(HeaderValue::from_str(&size.to_string()).unwrap()) //! } else { //! // If the response body doesn't have a known size, return `None` //! // to skip setting the header on this response. //! None //! } //! } //! ) //! ) //! .service(render_html); //! //! let request = Request::new(Body::empty()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["content-length"], "10"); //! # //! # Ok(()) //! # } //! ``` use super::{InsertHeaderMode, MakeHeaderValue}; use futures_util::ready; use http::{header::HeaderName, Request, Response}; use pin_project_lite::pin_project; use std::{ fmt, future::Future, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`SetResponseHeader`] which adds a response header. /// /// See [`SetResponseHeader`] for more details. pub struct SetResponseHeaderLayer { header_name: HeaderName, make: M, mode: InsertHeaderMode, } impl fmt::Debug for SetResponseHeaderLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SetResponseHeaderLayer") .field("header_name", &self.header_name) .field("mode", &self.mode) .field("make", &std::any::type_name::()) .finish() } } impl SetResponseHeaderLayer { /// Create a new [`SetResponseHeaderLayer`]. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. pub fn overriding(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::Override) } /// Create a new [`SetResponseHeaderLayer`]. /// /// The new header is always added, preserving any existing values. If previous values exist, /// the header will have multiple values. pub fn appending(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::Append) } /// Create a new [`SetResponseHeaderLayer`]. /// /// If a previous value exists for the header, the new value is not inserted. pub fn if_not_present(header_name: HeaderName, make: M) -> Self { Self::new(header_name, make, InsertHeaderMode::IfNotPresent) } fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { Self { make, header_name, mode, } } } impl Layer for SetResponseHeaderLayer where M: Clone, { type Service = SetResponseHeader; fn layer(&self, inner: S) -> Self::Service { SetResponseHeader { inner, header_name: self.header_name.clone(), make: self.make.clone(), mode: self.mode, } } } impl Clone for SetResponseHeaderLayer where M: Clone, { fn clone(&self) -> Self { Self { make: self.make.clone(), header_name: self.header_name.clone(), mode: self.mode, } } } /// Middleware that sets a header on the response. #[derive(Clone)] pub struct SetResponseHeader { inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode, } impl SetResponseHeader { /// Create a new [`SetResponseHeader`]. /// /// If a previous value exists for the same header, it is removed and replaced with the new /// header value. pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::Override) } /// Create a new [`SetResponseHeader`]. /// /// The new header is always added, preserving any existing values. If previous values exist, /// the header will have multiple values. 
pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::Append) } /// Create a new [`SetResponseHeader`]. /// /// If a previous value exists for the header, the new value is not inserted. pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) } fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { Self { inner, header_name, make, mode, } } define_inner_service_accessors!(); } impl fmt::Debug for SetResponseHeader where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SetResponseHeader") .field("inner", &self.inner) .field("header_name", &self.header_name) .field("mode", &self.mode) .field("make", &std::any::type_name::()) .finish() } } impl Service> for SetResponseHeader where S: Service, Response = Response>, M: MakeHeaderValue> + Clone, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { ResponseFuture { future: self.inner.call(req), header_name: self.header_name.clone(), make: self.make.clone(), mode: self.mode, } } } pin_project! { /// Response future for [`SetResponseHeader`]. #[derive(Debug)] pub struct ResponseFuture { #[pin] future: F, header_name: HeaderName, make: M, mode: InsertHeaderMode, } } impl Future for ResponseFuture where F: Future, E>>, M: MakeHeaderValue>, { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = ready!(this.future.poll(cx)?); this.mode.apply(this.header_name, &mut res, &mut *this.make); Poll::Ready(Ok(res)) } } #[cfg(test)] mod tests { use super::*; use http::{header, HeaderValue}; use hyper::Body; use std::convert::Infallible; use tower::{service_fn, ServiceExt}; #[tokio::test] async fn test_override_mode() { let svc = SetResponseHeader::overriding( service_fn(|_req: Request| async { let res = Response::builder() .header(header::CONTENT_TYPE, "good-content") .body(Body::empty()) .unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "text/html"); assert_eq!(values.next(), None); } #[tokio::test] async fn test_append_mode() { let svc = SetResponseHeader::appending( service_fn(|_req: Request| async { let res = Response::builder() .header(header::CONTENT_TYPE, "good-content") .body(Body::empty()) .unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "good-content"); assert_eq!(values.next().unwrap(), "text/html"); assert_eq!(values.next(), None); } #[tokio::test] async fn test_skip_if_present_mode() { let svc = SetResponseHeader::if_not_present( service_fn(|_req: Request| async { let res = Response::builder() .header(header::CONTENT_TYPE, "good-content") .body(Body::empty()) .unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let 
mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "good-content"); assert_eq!(values.next(), None); } #[tokio::test] async fn test_skip_if_present_mode_when_not_present() { let svc = SetResponseHeader::if_not_present( service_fn(|_req: Request| async { let res = Response::builder().body(Body::empty()).unwrap(); Ok::<_, Infallible>(res) }), header::CONTENT_TYPE, HeaderValue::from_static("text/html"), ); let res = svc.oneshot(Request::new(Body::empty())).await.unwrap(); let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); assert_eq!(values.next().unwrap(), "text/html"); assert_eq!(values.next(), None); } } tower-http-0.4.4/src/set_status.rs000064400000000000000000000071611046102023000153060ustar 00000000000000//! Middleware to override status codes. //! //! # Example //! //! ``` //! use tower_http::set_status::SetStatusLayer; //! use http::{Request, Response, StatusCode}; //! use hyper::Body; //! use std::{iter::once, convert::Infallible}; //! use tower::{ServiceBuilder, Service, ServiceExt}; //! //! async fn handle(req: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut service = ServiceBuilder::new() //! // change the status to `404 Not Found` regardless what the inner service returns //! .layer(SetStatusLayer::new(StatusCode::NOT_FOUND)) //! .service_fn(handle); //! //! // Call the service. //! let request = Request::builder().body(Body::empty())?; //! //! let response = service.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), StatusCode::NOT_FOUND); //! # //! # Ok(()) //! # } //! ``` use http::{Request, Response, StatusCode}; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`SetStatus`] which overrides the status codes. #[derive(Debug, Clone, Copy)] pub struct SetStatusLayer { status: StatusCode, } impl SetStatusLayer { /// Create a new [`SetStatusLayer`]. /// /// The response status code will be `status` regardless of what the inner service returns. pub fn new(status: StatusCode) -> Self { SetStatusLayer { status } } } impl Layer for SetStatusLayer { type Service = SetStatus; fn layer(&self, inner: S) -> Self::Service { SetStatus::new(inner, self.status) } } /// Middleware to override status codes. /// /// See the [module docs](self) for more details. #[derive(Debug, Clone, Copy)] pub struct SetStatus { inner: S, status: StatusCode, } impl SetStatus { /// Create a new [`SetStatus`]. /// /// The response status code will be `status` regardless of what the inner service returns. pub fn new(inner: S, status: StatusCode) -> Self { Self { status, inner } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `SetStatus` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(status: StatusCode) -> SetStatusLayer { SetStatusLayer::new(status) } } impl Service> for SetStatus where S: Service, Response = Response>, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { ResponseFuture { inner: self.inner.call(req), status: Some(self.status), } } } pin_project! { /// Response future for [`SetStatus`]. 
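///
/// Resolves to the inner service's response with its status code replaced by the configured one.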
pub struct ResponseFuture { #[pin] inner: F, status: Option, } } impl Future for ResponseFuture where F: Future, E>>, { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut response = futures_core::ready!(this.inner.poll(cx)?); *response.status_mut() = this.status.take().expect("future polled after completion"); Poll::Ready(Ok(response)) } } tower-http-0.4.4/src/timeout/body.rs000064400000000000000000000150361046102023000155330ustar 00000000000000use crate::BoxError; use futures_core::{ready, Future}; use http_body::Body; use pin_project_lite::pin_project; use std::{ pin::Pin, task::{Context, Poll}, time::Duration, }; use tokio::time::{sleep, Sleep}; pin_project! { /// Middleware that applies a timeout to request and response bodies. /// /// Wrapper around a [`http_body::Body`] to time out if data is not ready within the specified duration. /// /// Bodies must produce data at most within the specified timeout. /// If the body does not produce a requested data frame within the timeout period, it will return an error. /// /// # Differences from [`Timeout`][crate::timeout::Timeout] /// /// [`Timeout`][crate::timeout::Timeout] applies a timeout to the request future, not body. /// That timeout is not reset when bytes are handled, whether the request is active or not. /// Bodies are handled asynchronously outside of the tower stack's future and thus needs an additional timeout. /// /// This middleware will return a [`TimeoutError`]. /// /// # Example /// /// ``` /// use http::{Request, Response}; /// use hyper::Body; /// use std::time::Duration; /// use tower::ServiceBuilder; /// use tower_http::timeout::RequestBodyTimeoutLayer; /// /// async fn handle(_: Request) -> Result, std::convert::Infallible> { /// // ... /// # todo!() /// } /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// let svc = ServiceBuilder::new() /// // Timeout bodies after 30 seconds of inactivity /// .layer(RequestBodyTimeoutLayer::new(Duration::from_secs(30))) /// .service_fn(handle); /// # Ok(()) /// # } /// ``` pub struct TimeoutBody { timeout: Duration, // In http-body 1.0, `poll_*` will be merged into `poll_frame`. // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 #[pin] sleep_data: Option, #[pin] sleep_trailers: Option, #[pin] body: B, } } impl TimeoutBody { /// Creates a new [`TimeoutBody`]. pub fn new(timeout: Duration, body: B) -> Self { TimeoutBody { timeout, sleep_data: None, sleep_trailers: None, body, } } } impl Body for TimeoutBody where B: Body, B::Error: Into, { type Data = B::Data; type Error = Box; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let mut this = self.project(); // Start the `Sleep` if not active. let sleep_pinned = if let Some(some) = this.sleep_data.as_mut().as_pin_mut() { some } else { this.sleep_data.set(Some(sleep(*this.timeout))); this.sleep_data.as_mut().as_pin_mut().unwrap() }; // Error if the timeout has expired. if let Poll::Ready(()) = sleep_pinned.poll(cx) { return Poll::Ready(Some(Err(Box::new(TimeoutError(()))))); } // Check for body data. let data = ready!(this.body.poll_data(cx)); // Some data is ready. Reset the `Sleep`... 
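// Dropping the stored `Sleep` means a fresh timer is started on the next
// `poll_data` call, so the timeout bounds the delay between individual chunks
// rather than the total time spent streaming the body.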
this.sleep_data.set(None); Poll::Ready(data.transpose().map_err(Into::into).transpose()) } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { let mut this = self.project(); // In http-body 1.0, `poll_*` will be merged into `poll_frame`. // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 let sleep_pinned = if let Some(some) = this.sleep_trailers.as_mut().as_pin_mut() { some } else { this.sleep_trailers.set(Some(sleep(*this.timeout))); this.sleep_trailers.as_mut().as_pin_mut().unwrap() }; // Error if the timeout has expired. if let Poll::Ready(()) = sleep_pinned.poll(cx) { return Poll::Ready(Err(Box::new(TimeoutError(())))); } this.body.poll_trailers(cx).map_err(Into::into) } } /// Error for [`TimeoutBody`]. #[derive(Debug)] pub struct TimeoutError(()); impl std::error::Error for TimeoutError {} impl std::fmt::Display for TimeoutError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "data was not received within the designated timeout") } } #[cfg(test)] mod tests { use super::*; use bytes::Bytes; use pin_project_lite::pin_project; use std::{error::Error, fmt::Display}; #[derive(Debug)] struct MockError; impl Error for MockError {} impl Display for MockError { fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { todo!() } } pin_project! { struct MockBody { #[pin] sleep: Sleep } } impl Body for MockBody { type Data = Bytes; type Error = MockError; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let this = self.project(); this.sleep.poll(cx).map(|_| Some(Ok(vec![].into()))) } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { todo!() } } #[tokio::test] async fn test_body_available_within_timeout() { let mock_sleep = Duration::from_secs(1); let timeout_sleep = Duration::from_secs(2); let mock_body = MockBody { sleep: sleep(mock_sleep), }; let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); assert!(timeout_body.boxed().data().await.unwrap().is_ok()); } #[tokio::test] async fn test_body_unavailable_within_timeout_error() { let mock_sleep = Duration::from_secs(2); let timeout_sleep = Duration::from_secs(1); let mock_body = MockBody { sleep: sleep(mock_sleep), }; let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); assert!(timeout_body.boxed().data().await.unwrap().is_err()); } } tower-http-0.4.4/src/timeout/mod.rs000064400000000000000000000032161046102023000153520ustar 00000000000000//! Middleware that applies a timeout to requests. //! //! If the request does not complete within the specified timeout it will be aborted and a `408 //! Request Timeout` response will be sent. //! //! # Differences from `tower::timeout` //! //! tower's [`Timeout`](tower::timeout::Timeout) middleware uses an error to signal timeout, i.e. //! it changes the error type to [`BoxError`](tower::BoxError). For HTTP services that is rarely //! what you want as returning errors will terminate the connection without sending a response. //! //! This middleware won't change the error type and instead return a `408 Request Timeout` //! response. That means if your service's error type is [`Infallible`] it will still be //! [`Infallible`] after applying this middleware. //! //! # Example //! //! ``` //! use http::{Request, Response}; //! use hyper::Body; //! use std::{convert::Infallible, time::Duration}; //! use tower::ServiceBuilder; //! 
use tower_http::timeout::TimeoutLayer; //! //! async fn handle(_: Request) -> Result, Infallible> { //! // ... //! # Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let svc = ServiceBuilder::new() //! // Timeout requests after 30 seconds //! .layer(TimeoutLayer::new(Duration::from_secs(30))) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` //! //! [`Infallible`]: std::convert::Infallible mod body; mod service; pub use body::{TimeoutBody, TimeoutError}; pub use service::{ RequestBodyTimeout, RequestBodyTimeoutLayer, ResponseBodyTimeout, ResponseBodyTimeoutLayer, Timeout, TimeoutLayer, }; tower-http-0.4.4/src/timeout/service.rs000064400000000000000000000157451046102023000162450ustar 00000000000000use crate::timeout::body::TimeoutBody; use futures_core::ready; use http::{Request, Response, StatusCode}; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, time::Duration, }; use tokio::time::Sleep; use tower_layer::Layer; use tower_service::Service; /// Layer that applies the [`Timeout`] middleware which apply a timeout to requests. /// /// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct TimeoutLayer { timeout: Duration, } impl TimeoutLayer { /// Creates a new [`TimeoutLayer`]. pub fn new(timeout: Duration) -> Self { TimeoutLayer { timeout } } } impl Layer for TimeoutLayer { type Service = Timeout; fn layer(&self, inner: S) -> Self::Service { Timeout::new(inner, self.timeout) } } /// Middleware which apply a timeout to requests. /// /// If the request does not complete within the specified timeout it will be aborted and a `408 /// Request Timeout` response will be sent. /// /// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct Timeout { inner: S, timeout: Duration, } impl Timeout { /// Creates a new [`Timeout`]. pub fn new(inner: S, timeout: Duration) -> Self { Self { inner, timeout } } define_inner_service_accessors!(); /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(timeout: Duration) -> TimeoutLayer { TimeoutLayer::new(timeout) } } impl Service> for Timeout where S: Service, Response = Response>, ResBody: Default, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let sleep = tokio::time::sleep(self.timeout); ResponseFuture { inner: self.inner.call(req), sleep, } } } pin_project! { /// Response future for [`Timeout`]. pub struct ResponseFuture { #[pin] inner: F, #[pin] sleep: Sleep, } } impl Future for ResponseFuture where F: Future, E>>, B: Default, { type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if this.sleep.poll(cx).is_ready() { let mut res = Response::new(B::default()); *res.status_mut() = StatusCode::REQUEST_TIMEOUT; return Poll::Ready(Ok(res)); } this.inner.poll(cx) } } /// Applies a [`TimeoutBody`] to the request body. #[derive(Clone, Debug)] pub struct RequestBodyTimeoutLayer { timeout: Duration, } impl RequestBodyTimeoutLayer { /// Creates a new [`RequestBodyTimeoutLayer`]. 
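///
/// A minimal construction sketch; in practice the layer is applied via
/// `ServiceBuilder`, as shown in the [`TimeoutBody`] example:
///
/// ```
/// use std::time::Duration;
/// use tower_http::timeout::RequestBodyTimeoutLayer;
///
/// // Error if the request body stalls for more than 30 seconds between chunks.
/// let layer = RequestBodyTimeoutLayer::new(Duration::from_secs(30));
/// ```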
pub fn new(timeout: Duration) -> Self { Self { timeout } } } impl Layer for RequestBodyTimeoutLayer { type Service = RequestBodyTimeout; fn layer(&self, inner: S) -> Self::Service { RequestBodyTimeout::new(inner, self.timeout) } } /// Applies a [`TimeoutBody`] to the request body. #[derive(Clone, Debug)] pub struct RequestBodyTimeout { inner: S, timeout: Duration, } impl RequestBodyTimeout { /// Creates a new [`RequestBodyTimeout`]. pub fn new(service: S, timeout: Duration) -> Self { Self { inner: service, timeout, } } /// Returns a new [`Layer`] that wraps services with a [`RequestBodyTimeoutLayer`] middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(timeout: Duration) -> RequestBodyTimeoutLayer { RequestBodyTimeoutLayer::new(timeout) } define_inner_service_accessors!(); } impl Service> for RequestBodyTimeout where S: Service>>, S::Error: Into>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let req = req.map(|body| TimeoutBody::new(self.timeout, body)); self.inner.call(req) } } /// Applies a [`TimeoutBody`] to the response body. #[derive(Clone)] pub struct ResponseBodyTimeoutLayer { timeout: Duration, } impl ResponseBodyTimeoutLayer { /// Creates a new [`ResponseBodyTimeoutLayer`]. pub fn new(timeout: Duration) -> Self { Self { timeout } } } impl Layer for ResponseBodyTimeoutLayer { type Service = ResponseBodyTimeout; fn layer(&self, inner: S) -> Self::Service { ResponseBodyTimeout::new(inner, self.timeout) } } /// Applies a [`TimeoutBody`] to the response body. #[derive(Clone)] pub struct ResponseBodyTimeout { inner: S, timeout: Duration, } impl Service> for ResponseBodyTimeout where S: Service, Response = Response>, S::Error: Into>, { type Response = Response>; type Error = S::Error; type Future = ResponseBodyTimeoutFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { ResponseBodyTimeoutFuture { inner: self.inner.call(req), timeout: self.timeout, } } } impl ResponseBodyTimeout { /// Creates a new [`ResponseBodyTimeout`]. pub fn new(service: S, timeout: Duration) -> Self { Self { inner: service, timeout, } } /// Returns a new [`Layer`] that wraps services with a [`ResponseBodyTimeoutLayer`] middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(timeout: Duration) -> ResponseBodyTimeoutLayer { ResponseBodyTimeoutLayer::new(timeout) } define_inner_service_accessors!(); } pin_project! { /// Response future for [`ResponseBodyTimeout`]. pub struct ResponseBodyTimeoutFuture { #[pin] inner: Fut, timeout: Duration, } } impl Future for ResponseBodyTimeoutFuture where Fut: Future, E>>, { type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let timeout = self.timeout; let this = self.project(); let res = ready!(this.inner.poll(cx))?; Poll::Ready(Ok(res.map(|body| TimeoutBody::new(timeout, body)))) } } tower-http-0.4.4/src/trace/body.rs000064400000000000000000000075771046102023000151560ustar 00000000000000use super::{OnBodyChunk, OnEos, OnFailure}; use crate::classify::ClassifyEos; use futures_core::ready; use http::HeaderMap; use http_body::Body; use pin_project_lite::pin_project; use std::{ fmt, pin::Pin, task::{Context, Poll}, time::Instant, }; use tracing::Span; pin_project! { /// Response body for [`Trace`]. 
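/// Wraps the inner body and, as chunks and trailers are polled, runs the
/// `on_body_chunk`, `on_eos`, and `on_failure` callbacks inside the request's span.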
/// /// [`Trace`]: super::Trace pub struct ResponseBody { #[pin] pub(crate) inner: B, pub(crate) classify_eos: Option, pub(crate) on_eos: Option<(OnEos, Instant)>, pub(crate) on_body_chunk: OnBodyChunk, pub(crate) on_failure: Option, pub(crate) start: Instant, pub(crate) span: Span, } } impl Body for ResponseBody where B: Body, B::Error: fmt::Display + 'static, C: ClassifyEos, OnEosT: OnEos, OnBodyChunkT: OnBodyChunk, OnFailureT: OnFailure, { type Data = B::Data; type Error = B::Error; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let this = self.project(); let _guard = this.span.enter(); let result = if let Some(result) = ready!(this.inner.poll_data(cx)) { result } else { return Poll::Ready(None); }; let latency = this.start.elapsed(); *this.start = Instant::now(); match &result { Ok(chunk) => { this.on_body_chunk.on_body_chunk(chunk, latency, this.span); } Err(err) => { if let Some((classify_eos, mut on_failure)) = this.classify_eos.take().zip(this.on_failure.take()) { let failure_class = classify_eos.classify_error(err); on_failure.on_failure(failure_class, latency, this.span); } } } Poll::Ready(Some(result)) } fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { let this = self.project(); let _guard = this.span.enter(); let result = ready!(this.inner.poll_trailers(cx)); let latency = this.start.elapsed(); if let Some((classify_eos, mut on_failure)) = this.classify_eos.take().zip(this.on_failure.take()) { match &result { Ok(trailers) => { if let Err(failure_class) = classify_eos.classify_eos(trailers.as_ref()) { on_failure.on_failure(failure_class, latency, this.span); } if let Some((on_eos, stream_start)) = this.on_eos.take() { on_eos.on_eos(trailers.as_ref(), stream_start.elapsed(), this.span); } } Err(err) => { let failure_class = classify_eos.classify_error(err); on_failure.on_failure(failure_class, latency, this.span); } } } Poll::Ready(result) } fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } fn size_hint(&self) -> http_body::SizeHint { self.inner.size_hint() } } impl Default for ResponseBody { fn default() -> Self { Self { inner: Default::default(), classify_eos: Default::default(), on_eos: Default::default(), on_body_chunk: Default::default(), on_failure: Default::default(), start: Instant::now(), span: Span::current(), } } } tower-http-0.4.4/src/trace/future.rs000064400000000000000000000076351046102023000155260ustar 00000000000000use super::{OnBodyChunk, OnEos, OnFailure, OnResponse, ResponseBody}; use crate::classify::{ClassifiedResponse, ClassifyResponse}; use http::Response; use http_body::Body; use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, time::Instant, }; use tracing::Span; pin_project! { /// Response future for [`Trace`]. 
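/// Awaits the inner service, classifies the result, invokes `on_response` or
/// `on_failure`, and wraps the returned body in a [`ResponseBody`] so that the
/// streaming callbacks continue to run after this future resolves.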
/// /// [`Trace`]: super::Trace pub struct ResponseFuture { #[pin] pub(crate) inner: F, pub(crate) span: Span, pub(crate) classifier: Option, pub(crate) on_response: Option, pub(crate) on_body_chunk: Option, pub(crate) on_eos: Option, pub(crate) on_failure: Option, pub(crate) start: Instant, } } impl Future for ResponseFuture where Fut: Future, E>>, ResBody: Body, ResBody::Error: std::fmt::Display + 'static, E: std::fmt::Display + 'static, C: ClassifyResponse, OnResponseT: OnResponse, OnFailureT: OnFailure, OnBodyChunkT: OnBodyChunk, OnEosT: OnEos, { type Output = Result< Response>, E, >; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let _guard = this.span.enter(); let result = futures_util::ready!(this.inner.poll(cx)); let latency = this.start.elapsed(); let classifier = this.classifier.take().unwrap(); let on_eos = this.on_eos.take(); let on_body_chunk = this.on_body_chunk.take().unwrap(); let mut on_failure = this.on_failure.take().unwrap(); match result { Ok(res) => { let classification = classifier.classify_response(&res); let start = *this.start; this.on_response .take() .unwrap() .on_response(&res, latency, this.span); match classification { ClassifiedResponse::Ready(classification) => { if let Err(failure_class) = classification { on_failure.on_failure(failure_class, latency, this.span); } let span = this.span.clone(); let res = res.map(|body| ResponseBody { inner: body, classify_eos: None, on_eos: None, on_body_chunk, on_failure: Some(on_failure), start, span, }); Poll::Ready(Ok(res)) } ClassifiedResponse::RequiresEos(classify_eos) => { let span = this.span.clone(); let res = res.map(|body| ResponseBody { inner: body, classify_eos: Some(classify_eos), on_eos: on_eos.zip(Some(Instant::now())), on_body_chunk, on_failure: Some(on_failure), start, span, }); Poll::Ready(Ok(res)) } } } Err(err) => { let failure_class = classifier.classify_error(&err); on_failure.on_failure(failure_class, latency, this.span); Poll::Ready(Err(err)) } } } } tower-http-0.4.4/src/trace/layer.rs000064400000000000000000000200531046102023000153150ustar 00000000000000use super::{ DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, Trace, }; use crate::classify::{ GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier, }; use tower_layer::Layer; /// [`Layer`] that adds high level [tracing] to a [`Service`]. /// /// See the [module docs](crate::trace) for more details. /// /// [`Layer`]: tower_layer::Layer /// [tracing]: https://crates.io/crates/tracing /// [`Service`]: tower_service::Service #[derive(Debug, Copy, Clone)] pub struct TraceLayer< M, MakeSpan = DefaultMakeSpan, OnRequest = DefaultOnRequest, OnResponse = DefaultOnResponse, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure, > { pub(crate) make_classifier: M, pub(crate) make_span: MakeSpan, pub(crate) on_request: OnRequest, pub(crate) on_response: OnResponse, pub(crate) on_body_chunk: OnBodyChunk, pub(crate) on_eos: OnEos, pub(crate) on_failure: OnFailure, } impl TraceLayer { /// Create a new [`TraceLayer`] using the given [`MakeClassifier`]. 
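///
/// A small sketch using the built-in HTTP classifier (equivalent to
/// [`TraceLayer::new_for_http`]):
///
/// ```
/// use tower_http::classify::{ServerErrorsAsFailures, SharedClassifier};
/// use tower_http::trace::TraceLayer;
///
/// let layer = TraceLayer::new(SharedClassifier::new(ServerErrorsAsFailures::default()));
/// ```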
pub fn new(make_classifier: M) -> Self where M: MakeClassifier, { Self { make_classifier, make_span: DefaultMakeSpan::new(), on_failure: DefaultOnFailure::default(), on_request: DefaultOnRequest::default(), on_eos: DefaultOnEos::default(), on_body_chunk: DefaultOnBodyChunk::default(), on_response: DefaultOnResponse::default(), } } } impl TraceLayer { /// Customize what to do when a request is received. /// /// `NewOnRequest` is expected to implement [`OnRequest`]. /// /// [`OnRequest`]: super::OnRequest pub fn on_request( self, new_on_request: NewOnRequest, ) -> TraceLayer { TraceLayer { on_request: new_on_request, on_failure: self.on_failure, on_eos: self.on_eos, on_body_chunk: self.on_body_chunk, make_span: self.make_span, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize what to do when a response has been produced. /// /// `NewOnResponse` is expected to implement [`OnResponse`]. /// /// [`OnResponse`]: super::OnResponse pub fn on_response( self, new_on_response: NewOnResponse, ) -> TraceLayer { TraceLayer { on_response: new_on_response, on_request: self.on_request, on_eos: self.on_eos, on_body_chunk: self.on_body_chunk, on_failure: self.on_failure, make_span: self.make_span, make_classifier: self.make_classifier, } } /// Customize what to do when a body chunk has been sent. /// /// `NewOnBodyChunk` is expected to implement [`OnBodyChunk`]. /// /// [`OnBodyChunk`]: super::OnBodyChunk pub fn on_body_chunk( self, new_on_body_chunk: NewOnBodyChunk, ) -> TraceLayer { TraceLayer { on_body_chunk: new_on_body_chunk, on_eos: self.on_eos, on_failure: self.on_failure, on_request: self.on_request, make_span: self.make_span, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize what to do when a streaming response has closed. /// /// `NewOnEos` is expected to implement [`OnEos`]. /// /// [`OnEos`]: super::OnEos pub fn on_eos( self, new_on_eos: NewOnEos, ) -> TraceLayer { TraceLayer { on_eos: new_on_eos, on_body_chunk: self.on_body_chunk, on_failure: self.on_failure, on_request: self.on_request, make_span: self.make_span, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize what to do when a response has been classified as a failure. /// /// `NewOnFailure` is expected to implement [`OnFailure`]. /// /// [`OnFailure`]: super::OnFailure pub fn on_failure( self, new_on_failure: NewOnFailure, ) -> TraceLayer { TraceLayer { on_failure: new_on_failure, on_request: self.on_request, on_eos: self.on_eos, on_body_chunk: self.on_body_chunk, make_span: self.make_span, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize how to make [`Span`]s that all request handling will be wrapped in. /// /// `NewMakeSpan` is expected to implement [`MakeSpan`]. /// /// [`MakeSpan`]: super::MakeSpan /// [`Span`]: tracing::Span pub fn make_span_with( self, new_make_span: NewMakeSpan, ) -> TraceLayer { TraceLayer { make_span: new_make_span, on_request: self.on_request, on_failure: self.on_failure, on_body_chunk: self.on_body_chunk, on_eos: self.on_eos, on_response: self.on_response, make_classifier: self.make_classifier, } } } impl TraceLayer> { /// Create a new [`TraceLayer`] using [`ServerErrorsAsFailures`] which supports classifying /// regular HTTP responses based on the status code. 
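///
/// A minimal usage sketch:
///
/// ```
/// use tower_http::trace::TraceLayer;
///
/// // Server errors (5xx) will be classified as failures.
/// let layer = TraceLayer::new_for_http();
/// ```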
pub fn new_for_http() -> Self { Self { make_classifier: SharedClassifier::new(ServerErrorsAsFailures::default()), make_span: DefaultMakeSpan::new(), on_response: DefaultOnResponse::default(), on_request: DefaultOnRequest::default(), on_body_chunk: DefaultOnBodyChunk::default(), on_eos: DefaultOnEos::default(), on_failure: DefaultOnFailure::default(), } } } impl TraceLayer> { /// Create a new [`TraceLayer`] using [`GrpcErrorsAsFailures`] which supports classifying /// gRPC responses and streams based on the `grpc-status` header. pub fn new_for_grpc() -> Self { Self { make_classifier: SharedClassifier::new(GrpcErrorsAsFailures::default()), make_span: DefaultMakeSpan::new(), on_response: DefaultOnResponse::default(), on_request: DefaultOnRequest::default(), on_body_chunk: DefaultOnBodyChunk::default(), on_eos: DefaultOnEos::default(), on_failure: DefaultOnFailure::default(), } } } impl Layer for TraceLayer where M: Clone, MakeSpan: Clone, OnRequest: Clone, OnResponse: Clone, OnEos: Clone, OnBodyChunk: Clone, OnFailure: Clone, { type Service = Trace; fn layer(&self, inner: S) -> Self::Service { Trace { inner, make_classifier: self.make_classifier.clone(), make_span: self.make_span.clone(), on_request: self.on_request.clone(), on_eos: self.on_eos.clone(), on_body_chunk: self.on_body_chunk.clone(), on_response: self.on_response.clone(), on_failure: self.on_failure.clone(), } } } tower-http-0.4.4/src/trace/make_span.rs000064400000000000000000000060441046102023000161430ustar 00000000000000use http::Request; use tracing::{Level, Span}; use super::DEFAULT_MESSAGE_LEVEL; /// Trait used to generate [`Span`]s from requests. [`Trace`] wraps all request handling in this /// span. /// /// [`Span`]: tracing::Span /// [`Trace`]: super::Trace pub trait MakeSpan { /// Make a span from a request. fn make_span(&mut self, request: &Request) -> Span; } impl MakeSpan for Span { fn make_span(&mut self, _request: &Request) -> Span { self.clone() } } impl MakeSpan for F where F: FnMut(&Request) -> Span, { fn make_span(&mut self, request: &Request) -> Span { self(request) } } /// The default way [`Span`]s will be created for [`Trace`]. /// /// [`Span`]: tracing::Span /// [`Trace`]: super::Trace #[derive(Debug, Clone)] pub struct DefaultMakeSpan { level: Level, include_headers: bool, } impl DefaultMakeSpan { /// Create a new `DefaultMakeSpan`. pub fn new() -> Self { Self { level: DEFAULT_MESSAGE_LEVEL, include_headers: false, } } /// Set the [`Level`] used for the [tracing span]. /// /// Defaults to [`Level::DEBUG`]. /// /// [tracing span]: https://docs.rs/tracing/latest/tracing/#spans pub fn level(mut self, level: Level) -> Self { self.level = level; self } /// Include request headers on the [`Span`]. /// /// By default headers are not included. /// /// [`Span`]: tracing::Span pub fn include_headers(mut self, include_headers: bool) -> Self { self.include_headers = include_headers; self } } impl Default for DefaultMakeSpan { fn default() -> Self { Self::new() } } impl MakeSpan for DefaultMakeSpan { fn make_span(&mut self, request: &Request) -> Span { // This ugly macro is needed, unfortunately, because `tracing::span!` // required the level argument to be static. Meaning we can't just pass // `self.level`. macro_rules! 
make_span { ($level:expr) => { if self.include_headers { tracing::span!( $level, "request", method = %request.method(), uri = %request.uri(), version = ?request.version(), headers = ?request.headers(), ) } else { tracing::span!( $level, "request", method = %request.method(), uri = %request.uri(), version = ?request.version(), ) } } } match self.level { Level::ERROR => make_span!(Level::ERROR), Level::WARN => make_span!(Level::WARN), Level::INFO => make_span!(Level::INFO), Level::DEBUG => make_span!(Level::DEBUG), Level::TRACE => make_span!(Level::TRACE), } } } tower-http-0.4.4/src/trace/mod.rs000064400000000000000000000525501046102023000147670ustar 00000000000000//! Middleware that adds high level [tracing] to a [`Service`]. //! //! # Example //! //! Adding tracing to your service can be as simple as: //! //! ```rust //! use http::{Request, Response}; //! use hyper::Body; //! use tower::{ServiceBuilder, ServiceExt, Service}; //! use tower_http::trace::TraceLayer; //! use std::convert::Infallible; //! //! async fn handle(request: Request) -> Result, Infallible> { //! Ok(Response::new(Body::from("foo"))) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! // Setup tracing //! tracing_subscriber::fmt::init(); //! //! let mut service = ServiceBuilder::new() //! .layer(TraceLayer::new_for_http()) //! .service_fn(handle); //! //! let request = Request::new(Body::from("foo")); //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! # Ok(()) //! # } //! ``` //! //! If you run this application with `RUST_LOG=tower_http=trace cargo run` you should see logs like: //! //! ```text //! Mar 05 20:50:28.523 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_request: started processing request //! Mar 05 20:50:28.524 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_response: finished processing request latency=1 ms status=200 //! ``` //! //! # Customization //! //! [`Trace`] comes with good defaults but also supports customizing many aspects of the output. //! //! The default behaviour supports some customization: //! //! ```rust //! use http::{Request, Response, HeaderMap, StatusCode}; //! use hyper::Body; //! use bytes::Bytes; //! use tower::ServiceBuilder; //! use tracing::Level; //! use tower_http::{ //! LatencyUnit, //! trace::{TraceLayer, DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse}, //! }; //! use std::time::Duration; //! # use tower::{ServiceExt, Service}; //! # use std::convert::Infallible; //! //! # async fn handle(request: Request) -> Result, Infallible> { //! # Ok(Response::new(Body::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # tracing_subscriber::fmt::init(); //! # //! let service = ServiceBuilder::new() //! .layer( //! TraceLayer::new_for_http() //! .make_span_with( //! DefaultMakeSpan::new().include_headers(true) //! ) //! .on_request( //! DefaultOnRequest::new().level(Level::INFO) //! ) //! .on_response( //! DefaultOnResponse::new() //! .level(Level::INFO) //! .latency_unit(LatencyUnit::Micros) //! ) //! // on so on for `on_eos`, `on_body_chunk`, and `on_failure` //! ) //! .service_fn(handle); //! # let mut service = service; //! # let response = service //! # .ready() //! # .await? //! # .call(Request::new(Body::from("foo"))) //! # .await?; //! # Ok(()) //! # } //! ``` //! //! However for maximum control you can provide callbacks: //! //! ```rust //! use http::{Request, Response, HeaderMap, StatusCode}; //! use hyper::Body; //! use bytes::Bytes; //! 
use tower::ServiceBuilder; //! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; //! use std::time::Duration; //! use tracing::Span; //! # use tower::{ServiceExt, Service}; //! # use std::convert::Infallible; //! //! # async fn handle(request: Request) -> Result, Infallible> { //! # Ok(Response::new(Body::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # tracing_subscriber::fmt::init(); //! # //! let service = ServiceBuilder::new() //! .layer( //! TraceLayer::new_for_http() //! .make_span_with(|request: &Request| { //! tracing::debug_span!("http-request") //! }) //! .on_request(|request: &Request, _span: &Span| { //! tracing::debug!("started {} {}", request.method(), request.uri().path()) //! }) //! .on_response(|response: &Response, latency: Duration, _span: &Span| { //! tracing::debug!("response generated in {:?}", latency) //! }) //! .on_body_chunk(|chunk: &Bytes, latency: Duration, _span: &Span| { //! tracing::debug!("sending {} bytes", chunk.len()) //! }) //! .on_eos(|trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span| { //! tracing::debug!("stream closed after {:?}", stream_duration) //! }) //! .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| { //! tracing::debug!("something went wrong") //! }) //! ) //! .service_fn(handle); //! # let mut service = service; //! # let response = service //! # .ready() //! # .await? //! # .call(Request::new(Body::from("foo"))) //! # .await?; //! # Ok(()) //! # } //! ``` //! //! ## Disabling something //! //! Setting the behaviour to `()` will be disable that particular step: //! //! ```rust //! use http::StatusCode; //! use tower::ServiceBuilder; //! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; //! use std::time::Duration; //! use tracing::Span; //! # use tower::{ServiceExt, Service}; //! # use hyper::Body; //! # use http::{Response, Request}; //! # use std::convert::Infallible; //! //! # async fn handle(request: Request) -> Result, Infallible> { //! # Ok(Response::new(Body::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # tracing_subscriber::fmt::init(); //! # //! let service = ServiceBuilder::new() //! .layer( //! // This configuration will only emit events on failures //! TraceLayer::new_for_http() //! .on_request(()) //! .on_response(()) //! .on_body_chunk(()) //! .on_eos(()) //! .on_failure(|error: ServerErrorsFailureClass, latency: Duration, _span: &Span| { //! tracing::debug!("something went wrong") //! }) //! ) //! .service_fn(handle); //! # let mut service = service; //! # let response = service //! # .ready() //! # .await? //! # .call(Request::new(Body::from("foo"))) //! # .await?; //! # Ok(()) //! # } //! ``` //! //! # When the callbacks are called //! //! ### `on_request` //! //! The `on_request` callback is called when the request arrives at the //! middleware in [`Service::call`] just prior to passing the request to the //! inner service. //! //! ### `on_response` //! //! The `on_response` callback is called when the inner service's response //! future completes with `Ok(response)` regardless if the response is //! classified as a success or a failure. //! //! For example if you're using [`ServerErrorsAsFailures`] as your classifier //! and the inner service responds with `500 Internal Server Error` then the //! `on_response` callback is still called. `on_failure` would _also_ be called //! in this case since the response was classified as a failure. //! 
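//! A short sketch of that interaction, assuming a hypothetical `handle` service that
//! always responds with `500 Internal Server Error` (both callbacks observe the same
//! response):
//!
//! ```rust
//! use http::{Request, Response, StatusCode};
//! use hyper::Body;
//! use std::{convert::Infallible, time::Duration};
//! use tower::ServiceBuilder;
//! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
//! use tracing::Span;
//!
//! async fn handle(_: Request<Body>) -> Result<Response<Body>, Infallible> {
//!     let mut res = Response::new(Body::empty());
//!     *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
//!     Ok(res)
//! }
//!
//! let service = ServiceBuilder::new()
//!     .layer(
//!         TraceLayer::new_for_http()
//!             // Called for every response, including this 500.
//!             .on_response(|res: &Response<Body>, _latency: Duration, _span: &Span| {
//!                 tracing::debug!("status = {}", res.status())
//!             })
//!             // Also called, because 5xx responses are classified as failures.
//!             .on_failure(|_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
//!                 tracing::debug!("response classified as failure")
//!             }),
//!     )
//!     .service_fn(handle);
//! # let _ = service;
//! ```
//!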
//! ### `on_body_chunk` //! //! The `on_body_chunk` callback is called when the response body produces a new //! chunk, that is when [`Body::poll_data`] returns `Poll::Ready(Some(Ok(chunk)))`. //! //! `on_body_chunk` is called even if the chunk is empty. //! //! ### `on_eos` //! //! The `on_eos` callback is called when a streaming response body ends, that is //! when [`Body::poll_trailers`] returns `Poll::Ready(Ok(trailers))`. //! //! `on_eos` is called even if the trailers produced are `None`. //! //! ### `on_failure` //! //! The `on_failure` callback is called when: //! //! - The inner [`Service`]'s response future resolves to an error. //! - A response is classified as a failure. //! - [`Body::poll_data`] returns an error. //! - [`Body::poll_trailers`] returns an error. //! - An end-of-stream is classified as a failure. //! //! # Recording fields on the span //! //! All callbacks receive a reference to the [tracing] [`Span`], corresponding to this request, //! produced by the closure passed to [`TraceLayer::make_span_with`]. It can be used to [record //! field values][record] that weren't known when the span was created. //! //! ```rust //! use http::{Request, Response, HeaderMap, StatusCode}; //! use hyper::Body; //! use bytes::Bytes; //! use tower::ServiceBuilder; //! use tower_http::trace::TraceLayer; //! use tracing::Span; //! use std::time::Duration; //! # use std::convert::Infallible; //! //! # async fn handle(request: Request) -> Result, Infallible> { //! # Ok(Response::new(Body::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # tracing_subscriber::fmt::init(); //! # //! let service = ServiceBuilder::new() //! .layer( //! TraceLayer::new_for_http() //! .make_span_with(|request: &Request| { //! tracing::debug_span!( //! "http-request", //! status_code = tracing::field::Empty, //! ) //! }) //! .on_response(|response: &Response, _latency: Duration, span: &Span| { //! span.record("status_code", &tracing::field::display(response.status())); //! //! tracing::debug!("response generated") //! }) //! ) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` //! //! # Providing classifiers //! //! Tracing requires determining if a response is a success or failure. [`MakeClassifier`] is used //! to create a classifier for the incoming request. See the docs for [`MakeClassifier`] and //! [`ClassifyResponse`] for more details on classification. //! //! A [`MakeClassifier`] can be provided when creating a [`TraceLayer`]: //! //! ```rust //! use http::{Request, Response}; //! use hyper::Body; //! use tower::ServiceBuilder; //! use tower_http::{ //! trace::TraceLayer, //! classify::{ //! MakeClassifier, ClassifyResponse, ClassifiedResponse, NeverClassifyEos, //! SharedClassifier, //! }, //! }; //! use std::convert::Infallible; //! //! # async fn handle(request: Request) -> Result, Infallible> { //! # Ok(Response::new(Body::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! # tracing_subscriber::fmt::init(); //! # //! // Our `MakeClassifier` that always crates `MyClassifier` classifiers. //! #[derive(Copy, Clone)] //! struct MyMakeClassify; //! //! impl MakeClassifier for MyMakeClassify { //! type Classifier = MyClassifier; //! type FailureClass = &'static str; //! type ClassifyEos = NeverClassifyEos<&'static str>; //! //! fn make_classifier(&self, req: &Request) -> Self::Classifier { //! MyClassifier //! } //! } //! //! // A classifier that classifies failures as `"something went wrong..."`. //! 
#[derive(Copy, Clone)] //! struct MyClassifier; //! //! impl ClassifyResponse for MyClassifier { //! type FailureClass = &'static str; //! type ClassifyEos = NeverClassifyEos<&'static str>; //! //! fn classify_response( //! self, //! res: &Response //! ) -> ClassifiedResponse { //! // Classify based on the status code. //! if res.status().is_server_error() { //! ClassifiedResponse::Ready(Err("something went wrong...")) //! } else { //! ClassifiedResponse::Ready(Ok(())) //! } //! } //! //! fn classify_error(self, error: &E) -> Self::FailureClass //! where //! E: std::fmt::Display + 'static, //! { //! "something went wrong..." //! } //! } //! //! let service = ServiceBuilder::new() //! // Create a trace layer that uses our classifier. //! .layer(TraceLayer::new(MyMakeClassify)) //! .service_fn(handle); //! //! // Since `MyClassifier` is `Clone` we can also use `SharedClassifier` //! // to avoid having to define a separate `MakeClassifier`. //! let service = ServiceBuilder::new() //! .layer(TraceLayer::new(SharedClassifier::new(MyClassifier))) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` //! //! [`TraceLayer`] comes with convenience methods for using common classifiers: //! //! - [`TraceLayer::new_for_http`] classifies based on the status code. It doesn't consider //! streaming responses. //! - [`TraceLayer::new_for_grpc`] classifies based on the gRPC protocol and supports streaming //! responses. //! //! [tracing]: https://crates.io/crates/tracing //! [`Service`]: tower_service::Service //! [`Service::call`]: tower_service::Service::call //! [`MakeClassifier`]: crate::classify::MakeClassifier //! [`ClassifyResponse`]: crate::classify::ClassifyResponse //! [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record //! [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with //! [`Span`]: tracing::Span //! [`ServerErrorsAsFailures`]: crate::classify::ServerErrorsAsFailures //! [`Body::poll_trailers`]: http_body::Body::poll_trailers //! [`Body::poll_data`]: http_body::Body::poll_data use std::{fmt, time::Duration}; use tracing::Level; pub use self::{ body::ResponseBody, future::ResponseFuture, layer::TraceLayer, make_span::{DefaultMakeSpan, MakeSpan}, on_body_chunk::{DefaultOnBodyChunk, OnBodyChunk}, on_eos::{DefaultOnEos, OnEos}, on_failure::{DefaultOnFailure, OnFailure}, on_request::{DefaultOnRequest, OnRequest}, on_response::{DefaultOnResponse, OnResponse}, service::Trace, }; use crate::LatencyUnit; macro_rules! event_dynamic_lvl { ( $(target: $target:expr,)? $(parent: $parent:expr,)? $lvl:expr, $($tt:tt)* ) => { match $lvl { tracing::Level::ERROR => { tracing::event!( $(target: $target,)? $(parent: $parent,)? tracing::Level::ERROR, $($tt)* ); } tracing::Level::WARN => { tracing::event!( $(target: $target,)? $(parent: $parent,)? tracing::Level::WARN, $($tt)* ); } tracing::Level::INFO => { tracing::event!( $(target: $target,)? $(parent: $parent,)? tracing::Level::INFO, $($tt)* ); } tracing::Level::DEBUG => { tracing::event!( $(target: $target,)? $(parent: $parent,)? tracing::Level::DEBUG, $($tt)* ); } tracing::Level::TRACE => { tracing::event!( $(target: $target,)? $(parent: $parent,)? 
tracing::Level::TRACE, $($tt)* ); } } }; } mod body; mod future; mod layer; mod make_span; mod on_body_chunk; mod on_eos; mod on_failure; mod on_request; mod on_response; mod service; const DEFAULT_MESSAGE_LEVEL: Level = Level::DEBUG; const DEFAULT_ERROR_LEVEL: Level = Level::ERROR; struct Latency { unit: LatencyUnit, duration: Duration, } impl fmt::Display for Latency { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.unit { LatencyUnit::Seconds => write!(f, "{} s", self.duration.as_secs_f64()), LatencyUnit::Millis => write!(f, "{} ms", self.duration.as_millis()), LatencyUnit::Micros => write!(f, "{} μs", self.duration.as_micros()), LatencyUnit::Nanos => write!(f, "{} ns", self.duration.as_nanos()), } } } #[cfg(test)] mod tests { use super::*; use crate::classify::ServerErrorsFailureClass; use bytes::Bytes; use http::{HeaderMap, Request, Response}; use hyper::Body; use once_cell::sync::Lazy; use std::{ sync::atomic::{AtomicU32, Ordering}, time::Duration, }; use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; use tracing::Span; #[tokio::test] async fn unary_request() { static ON_REQUEST_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_RESPONSE_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_BODY_CHUNK_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_EOS: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_FAILURE: Lazy = Lazy::new(|| AtomicU32::new(0)); let trace_layer = TraceLayer::new_for_http() .make_span_with(|_req: &Request| { tracing::info_span!("test-span", foo = tracing::field::Empty) }) .on_request(|_req: &Request, span: &Span| { span.record("foo", &42); ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_response(|_res: &Response, _latency: Duration, _span: &Span| { ON_RESPONSE_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| { ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_eos( |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { ON_EOS.fetch_add(1, Ordering::SeqCst); }, ) .on_failure( |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { ON_FAILURE.fetch_add(1, Ordering::SeqCst); }, ); let mut svc = ServiceBuilder::new().layer(trace_layer).service_fn(echo); let res = svc .ready() .await .unwrap() .call(Request::new(Body::from("foobar"))) .await .unwrap(); assert_eq!(1, ON_REQUEST_COUNT.load(Ordering::SeqCst), "request"); assert_eq!(1, ON_RESPONSE_COUNT.load(Ordering::SeqCst), "request"); assert_eq!(0, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); hyper::body::to_bytes(res.into_body()).await.unwrap(); assert_eq!(1, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); } #[tokio::test] async fn streaming_response() { static ON_REQUEST_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_RESPONSE_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_BODY_CHUNK_COUNT: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_EOS: Lazy = Lazy::new(|| AtomicU32::new(0)); static ON_FAILURE: Lazy = Lazy::new(|| AtomicU32::new(0)); let trace_layer = TraceLayer::new_for_http() .on_request(|_req: &Request, _span: &Span| { ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_response(|_res: &Response, _latency: Duration, _span: &Span| { ON_RESPONSE_COUNT.fetch_add(1, Ordering::SeqCst); }) 
.on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| { ON_BODY_CHUNK_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_eos( |_trailers: Option<&HeaderMap>, _latency: Duration, _span: &Span| { ON_EOS.fetch_add(1, Ordering::SeqCst); }, ) .on_failure( |_class: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { ON_FAILURE.fetch_add(1, Ordering::SeqCst); }, ); let mut svc = ServiceBuilder::new() .layer(trace_layer) .service_fn(streaming_body); let res = svc .ready() .await .unwrap() .call(Request::new(Body::empty())) .await .unwrap(); assert_eq!(1, ON_REQUEST_COUNT.load(Ordering::SeqCst), "request"); assert_eq!(1, ON_RESPONSE_COUNT.load(Ordering::SeqCst), "request"); assert_eq!(0, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); hyper::body::to_bytes(res.into_body()).await.unwrap(); assert_eq!(3, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } async fn streaming_body(_req: Request) -> Result, BoxError> { use futures::stream::iter; let stream = iter(vec![ Ok::<_, BoxError>(Bytes::from("one")), Ok::<_, BoxError>(Bytes::from("two")), Ok::<_, BoxError>(Bytes::from("three")), ]); let body = Body::wrap_stream(stream); Ok(Response::new(body)) } } tower-http-0.4.4/src/trace/on_body_chunk.rs000064400000000000000000000037501046102023000170270ustar 00000000000000use std::time::Duration; use tracing::Span; /// Trait used to tell [`Trace`] what to do when a body chunk has been sent. /// /// See the [module docs](../trace/index.html#on_body_chunk) for details on exactly when the /// `on_body_chunk` callback is called. /// /// [`Trace`]: super::Trace pub trait OnBodyChunk { /// Do the thing. /// /// `latency` is the duration since the response was sent or since the last body chunk as sent. /// /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] /// that weren't known when the span was created. /// /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// /// If you're using [hyper] as your server `B` will most likely be [`Bytes`]. /// /// [hyper]: https://hyper.rs /// [`Bytes`]: https://docs.rs/bytes/latest/bytes/struct.Bytes.html /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with fn on_body_chunk(&mut self, chunk: &B, latency: Duration, span: &Span); } impl OnBodyChunk for F where F: FnMut(&B, Duration, &Span), { fn on_body_chunk(&mut self, chunk: &B, latency: Duration, span: &Span) { self(chunk, latency, span) } } impl OnBodyChunk for () { #[inline] fn on_body_chunk(&mut self, _: &B, _: Duration, _: &Span) {} } /// The default [`OnBodyChunk`] implementation used by [`Trace`]. /// /// Simply does nothing. /// /// [`Trace`]: super::Trace #[derive(Debug, Default, Clone)] pub struct DefaultOnBodyChunk { _priv: (), } impl DefaultOnBodyChunk { /// Create a new `DefaultOnBodyChunk`. 
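/// The returned value does nothing when a body chunk is sent; supply a closure or a
/// custom [`OnBodyChunk`] implementation instead if you want per-chunk events.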
pub fn new() -> Self { Self { _priv: () } } } impl OnBodyChunk for DefaultOnBodyChunk { #[inline] fn on_body_chunk(&mut self, _: &B, _: Duration, _: &Span) {} } tower-http-0.4.4/src/trace/on_eos.rs000064400000000000000000000067031046102023000154710ustar 00000000000000use super::{Latency, DEFAULT_MESSAGE_LEVEL}; use crate::{classify::grpc_errors_as_failures::ParsedGrpcStatus, LatencyUnit}; use http::header::HeaderMap; use std::time::Duration; use tracing::{Level, Span}; /// Trait used to tell [`Trace`] what to do when a stream closes. /// /// See the [module docs](../trace/index.html#on_eos) for details on exactly when the `on_eos` /// callback is called. /// /// [`Trace`]: super::Trace pub trait OnEos { /// Do the thing. /// /// `stream_duration` is the duration since the response was sent. /// /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] /// that weren't known when the span was created. /// /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span); } impl OnEos for () { #[inline] fn on_eos(self, _: Option<&HeaderMap>, _: Duration, _: &Span) {} } impl OnEos for F where F: FnOnce(Option<&HeaderMap>, Duration, &Span), { fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, span: &Span) { self(trailers, stream_duration, span) } } /// The default [`OnEos`] implementation used by [`Trace`]. /// /// [`Trace`]: super::Trace #[derive(Clone, Debug)] pub struct DefaultOnEos { level: Level, latency_unit: LatencyUnit, } impl Default for DefaultOnEos { fn default() -> Self { Self { level: DEFAULT_MESSAGE_LEVEL, latency_unit: LatencyUnit::Millis, } } } impl DefaultOnEos { /// Create a new [`DefaultOnEos`]. pub fn new() -> Self { Self::default() } /// Set the [`Level`] used for [tracing events]. /// /// Defaults to [`Level::DEBUG`]. /// /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events /// [`Level::DEBUG`]: https://docs.rs/tracing/latest/tracing/struct.Level.html#associatedconstant.DEBUG pub fn level(mut self, level: Level) -> Self { self.level = level; self } /// Set the [`LatencyUnit`] latencies will be reported in. /// /// Defaults to [`LatencyUnit::Millis`]. 
pub fn latency_unit(mut self, latency_unit: LatencyUnit) -> Self { self.latency_unit = latency_unit; self } } impl OnEos for DefaultOnEos { fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span) { let stream_duration = Latency { unit: self.latency_unit, duration: stream_duration, }; let status = trailers.and_then(|trailers| { match crate::classify::grpc_errors_as_failures::classify_grpc_metadata( trailers, crate::classify::GrpcCode::Ok.into_bitmask(), ) { ParsedGrpcStatus::Success | ParsedGrpcStatus::HeaderNotString | ParsedGrpcStatus::HeaderNotInt => Some(0), ParsedGrpcStatus::NonSuccess(status) => Some(status.get()), ParsedGrpcStatus::GrpcStatusHeaderMissing => None, } }); event_dynamic_lvl!(self.level, %stream_duration, status, "end of stream"); } } tower-http-0.4.4/src/trace/on_failure.rs000064400000000000000000000057411046102023000163330ustar 00000000000000use super::{Latency, DEFAULT_ERROR_LEVEL}; use crate::LatencyUnit; use std::{fmt, time::Duration}; use tracing::{Level, Span}; /// Trait used to tell [`Trace`] what to do when a request fails. /// /// See the [module docs](../trace/index.html#on_failure) for details on exactly when the /// `on_failure` callback is called. /// /// [`Trace`]: super::Trace pub trait OnFailure { /// Do the thing. /// /// `latency` is the duration since the request was received. /// /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] /// that weren't known when the span was created. /// /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, span: &Span); } impl OnFailure for () { #[inline] fn on_failure(&mut self, _: FailureClass, _: Duration, _: &Span) {} } impl OnFailure for F where F: FnMut(FailureClass, Duration, &Span), { fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, span: &Span) { self(failure_classification, latency, span) } } /// The default [`OnFailure`] implementation used by [`Trace`]. /// /// [`Trace`]: super::Trace #[derive(Clone, Debug)] pub struct DefaultOnFailure { level: Level, latency_unit: LatencyUnit, } impl Default for DefaultOnFailure { fn default() -> Self { Self { level: DEFAULT_ERROR_LEVEL, latency_unit: LatencyUnit::Millis, } } } impl DefaultOnFailure { /// Create a new `DefaultOnFailure`. pub fn new() -> Self { Self::default() } /// Set the [`Level`] used for [tracing events]. /// /// Defaults to [`Level::ERROR`]. /// /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events pub fn level(mut self, level: Level) -> Self { self.level = level; self } /// Set the [`LatencyUnit`] latencies will be reported in. /// /// Defaults to [`LatencyUnit::Millis`]. 
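///
/// For example, to report failure latencies in microseconds:
///
/// ```
/// use tower_http::{trace::DefaultOnFailure, LatencyUnit};
///
/// let on_failure = DefaultOnFailure::new().latency_unit(LatencyUnit::Micros);
/// ```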
pub fn latency_unit(mut self, latency_unit: LatencyUnit) -> Self { self.latency_unit = latency_unit; self } } impl OnFailure for DefaultOnFailure where FailureClass: fmt::Display, { fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, _: &Span) { let latency = Latency { unit: self.latency_unit, duration: latency, }; event_dynamic_lvl!( self.level, classification = %failure_classification, %latency, "response failed" ); } } tower-http-0.4.4/src/trace/on_request.rs000064400000000000000000000046311046102023000163710ustar 00000000000000use super::DEFAULT_MESSAGE_LEVEL; use http::Request; use tracing::Level; use tracing::Span; /// Trait used to tell [`Trace`] what to do when a request is received. /// /// See the [module docs](../trace/index.html#on_request) for details on exactly when the /// `on_request` callback is called. /// /// [`Trace`]: super::Trace pub trait OnRequest { /// Do the thing. /// /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure /// passed to [`TraceLayer::make_span_with`]. It can be used to [record field values][record] /// that weren't known when the span was created. /// /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with fn on_request(&mut self, request: &Request, span: &Span); } impl OnRequest for () { #[inline] fn on_request(&mut self, _: &Request, _: &Span) {} } impl OnRequest for F where F: FnMut(&Request, &Span), { fn on_request(&mut self, request: &Request, span: &Span) { self(request, span) } } /// The default [`OnRequest`] implementation used by [`Trace`]. /// /// [`Trace`]: super::Trace #[derive(Clone, Debug)] pub struct DefaultOnRequest { level: Level, } impl Default for DefaultOnRequest { fn default() -> Self { Self { level: DEFAULT_MESSAGE_LEVEL, } } } impl DefaultOnRequest { /// Create a new `DefaultOnRequest`. pub fn new() -> Self { Self::default() } /// Set the [`Level`] used for [tracing events]. /// /// Please note that while this will set the level for the tracing events /// themselves, it might cause them to lack expected information, like /// request method or path. You can address this using /// [`DefaultMakeSpan::level`]. /// /// Defaults to [`Level::DEBUG`]. /// /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events /// [`DefaultMakeSpan::level`]: crate::trace::DefaultMakeSpan::level pub fn level(mut self, level: Level) -> Self { self.level = level; self } } impl OnRequest for DefaultOnRequest { fn on_request(&mut self, _: &Request, _: &Span) { event_dynamic_lvl!(self.level, "started processing request"); } } tower-http-0.4.4/src/trace/on_response.rs000064400000000000000000000122771046102023000165440ustar 00000000000000use super::{Latency, DEFAULT_MESSAGE_LEVEL}; use crate::LatencyUnit; use http::Response; use std::time::Duration; use tracing::Level; use tracing::Span; /// Trait used to tell [`Trace`] what to do when a response has been produced. /// /// See the [module docs](../trace/index.html#on_response) for details on exactly when the /// `on_response` callback is called. /// /// [`Trace`]: super::Trace pub trait OnResponse { /// Do the thing. /// /// `latency` is the duration since the request was received. /// /// `span` is the `tracing` [`Span`], corresponding to this request, produced by the closure /// passed to [`TraceLayer::make_span_with`]. 
It can be used to [record field values][record] /// that weren't known when the span was created. /// /// [`Span`]: https://docs.rs/tracing/latest/tracing/span/index.html /// [record]: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#method.record /// [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with fn on_response(self, response: &Response, latency: Duration, span: &Span); } impl OnResponse for () { #[inline] fn on_response(self, _: &Response, _: Duration, _: &Span) {} } impl OnResponse for F where F: FnOnce(&Response, Duration, &Span), { fn on_response(self, response: &Response, latency: Duration, span: &Span) { self(response, latency, span) } } /// The default [`OnResponse`] implementation used by [`Trace`]. /// /// [`Trace`]: super::Trace #[derive(Clone, Debug)] pub struct DefaultOnResponse { level: Level, latency_unit: LatencyUnit, include_headers: bool, } impl Default for DefaultOnResponse { fn default() -> Self { Self { level: DEFAULT_MESSAGE_LEVEL, latency_unit: LatencyUnit::Millis, include_headers: false, } } } impl DefaultOnResponse { /// Create a new `DefaultOnResponse`. pub fn new() -> Self { Self::default() } /// Set the [`Level`] used for [tracing events]. /// /// Please note that while this will set the level for the tracing events /// themselves, it might cause them to lack expected information, like /// request method or path. You can address this using /// [`DefaultMakeSpan::level`]. /// /// Defaults to [`Level::DEBUG`]. /// /// [tracing events]: https://docs.rs/tracing/latest/tracing/#events /// [`DefaultMakeSpan::level`]: crate::trace::DefaultMakeSpan::level pub fn level(mut self, level: Level) -> Self { self.level = level; self } /// Set the [`LatencyUnit`] latencies will be reported in. /// /// Defaults to [`LatencyUnit::Millis`]. pub fn latency_unit(mut self, latency_unit: LatencyUnit) -> Self { self.latency_unit = latency_unit; self } /// Include response headers on the [`Event`]. /// /// By default headers are not included. /// /// [`Event`]: tracing::Event pub fn include_headers(mut self, include_headers: bool) -> Self { self.include_headers = include_headers; self } } impl OnResponse for DefaultOnResponse { fn on_response(self, response: &Response, latency: Duration, _: &Span) { let latency = Latency { unit: self.latency_unit, duration: latency, }; let response_headers = self .include_headers .then(|| tracing::field::debug(response.headers())); event_dynamic_lvl!( self.level, %latency, status = status(response), response_headers, "finished processing request" ); } } fn status(res: &Response) -> Option { use crate::classify::grpc_errors_as_failures::ParsedGrpcStatus; // gRPC-over-HTTP2 uses the "application/grpc[+format]" content type, and gRPC-Web uses // "application/grpc-web[+format]" or "application/grpc-web-text[+format]", where "format" is // the message format, e.g. +proto, +json. // // So, valid grpc content types include (but are not limited to): // - application/grpc // - application/grpc+proto // - application/grpc-web+proto // - application/grpc-web-text+proto // // For simplicity, we simply check that the content type starts with "application/grpc". 
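// gRPC encodes failures in the `grpc-status` header/trailer rather than the HTTP
// status code, so for gRPC responses we report the parsed gRPC code and otherwise
// fall back to the HTTP status below.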
let is_grpc = res .headers() .get(http::header::CONTENT_TYPE) .map_or(false, |value| { value.as_bytes().starts_with("application/grpc".as_bytes()) }); if is_grpc { match crate::classify::grpc_errors_as_failures::classify_grpc_metadata( res.headers(), crate::classify::GrpcCode::Ok.into_bitmask(), ) { ParsedGrpcStatus::Success | ParsedGrpcStatus::HeaderNotString | ParsedGrpcStatus::HeaderNotInt => Some(0), ParsedGrpcStatus::NonSuccess(status) => Some(status.get()), // if `grpc-status` is missing then its a streaming response and there is no status // _yet_, so its neither success nor error ParsedGrpcStatus::GrpcStatusHeaderMissing => None, } } else { Some(res.status().as_u16().into()) } } tower-http-0.4.4/src/trace/service.rs000064400000000000000000000242431046102023000156460ustar 00000000000000use super::{ DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, MakeSpan, OnBodyChunk, OnEos, OnFailure, OnRequest, OnResponse, ResponseBody, ResponseFuture, TraceLayer, }; use crate::classify::{ GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier, }; use http::{Request, Response}; use http_body::Body; use std::{ fmt, task::{Context, Poll}, time::Instant, }; use tower_service::Service; /// Middleware that adds high level [tracing] to a [`Service`]. /// /// See the [module docs](crate::trace) for an example. /// /// [tracing]: https://crates.io/crates/tracing /// [`Service`]: tower_service::Service #[derive(Debug, Clone, Copy)] pub struct Trace< S, M, MakeSpan = DefaultMakeSpan, OnRequest = DefaultOnRequest, OnResponse = DefaultOnResponse, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure, > { pub(crate) inner: S, pub(crate) make_classifier: M, pub(crate) make_span: MakeSpan, pub(crate) on_request: OnRequest, pub(crate) on_response: OnResponse, pub(crate) on_body_chunk: OnBodyChunk, pub(crate) on_eos: OnEos, pub(crate) on_failure: OnFailure, } impl Trace { /// Create a new [`Trace`] using the given [`MakeClassifier`]. pub fn new(inner: S, make_classifier: M) -> Self where M: MakeClassifier, { Self { inner, make_classifier, make_span: DefaultMakeSpan::new(), on_request: DefaultOnRequest::default(), on_response: DefaultOnResponse::default(), on_body_chunk: DefaultOnBodyChunk::default(), on_eos: DefaultOnEos::default(), on_failure: DefaultOnFailure::default(), } } /// Returns a new [`Layer`] that wraps services with a [`TraceLayer`] middleware. /// /// [`Layer`]: tower_layer::Layer pub fn layer(make_classifier: M) -> TraceLayer where M: MakeClassifier, { TraceLayer::new(make_classifier) } } impl Trace { define_inner_service_accessors!(); /// Customize what to do when a request is received. /// /// `NewOnRequest` is expected to implement [`OnRequest`]. /// /// [`OnRequest`]: super::OnRequest pub fn on_request( self, new_on_request: NewOnRequest, ) -> Trace { Trace { on_request: new_on_request, inner: self.inner, on_failure: self.on_failure, on_eos: self.on_eos, on_body_chunk: self.on_body_chunk, make_span: self.make_span, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize what to do when a response has been produced. /// /// `NewOnResponse` is expected to implement [`OnResponse`]. 
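    ///
    /// A minimal sketch (marked `ignore` since it assumes some `svc` service is already in
    /// scope) of swapping in a closure-based `OnResponse`:
    ///
    /// ```ignore
    /// use std::time::Duration;
    /// use tower_http::trace::Trace;
    /// use tracing::Span;
    ///
    /// let traced = Trace::new_for_http(svc).on_response(
    ///     |response: &http::Response<_>, latency: Duration, _span: &Span| {
    ///         tracing::debug!(status = response.status().as_u16(), ?latency, "response ready");
    ///     },
    /// );
    /// ```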
/// /// [`OnResponse`]: super::OnResponse pub fn on_response( self, new_on_response: NewOnResponse, ) -> Trace { Trace { on_response: new_on_response, inner: self.inner, on_request: self.on_request, on_failure: self.on_failure, on_body_chunk: self.on_body_chunk, on_eos: self.on_eos, make_span: self.make_span, make_classifier: self.make_classifier, } } /// Customize what to do when a body chunk has been sent. /// /// `NewOnBodyChunk` is expected to implement [`OnBodyChunk`]. /// /// [`OnBodyChunk`]: super::OnBodyChunk pub fn on_body_chunk( self, new_on_body_chunk: NewOnBodyChunk, ) -> Trace { Trace { on_body_chunk: new_on_body_chunk, on_eos: self.on_eos, make_span: self.make_span, inner: self.inner, on_failure: self.on_failure, on_request: self.on_request, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize what to do when a streaming response has closed. /// /// `NewOnEos` is expected to implement [`OnEos`]. /// /// [`OnEos`]: super::OnEos pub fn on_eos( self, new_on_eos: NewOnEos, ) -> Trace { Trace { on_eos: new_on_eos, make_span: self.make_span, inner: self.inner, on_failure: self.on_failure, on_request: self.on_request, on_body_chunk: self.on_body_chunk, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize what to do when a response has been classified as a failure. /// /// `NewOnFailure` is expected to implement [`OnFailure`]. /// /// [`OnFailure`]: super::OnFailure pub fn on_failure( self, new_on_failure: NewOnFailure, ) -> Trace { Trace { on_failure: new_on_failure, inner: self.inner, make_span: self.make_span, on_body_chunk: self.on_body_chunk, on_request: self.on_request, on_eos: self.on_eos, on_response: self.on_response, make_classifier: self.make_classifier, } } /// Customize how to make [`Span`]s that all request handling will be wrapped in. /// /// `NewMakeSpan` is expected to implement [`MakeSpan`]. /// /// [`MakeSpan`]: super::MakeSpan /// [`Span`]: tracing::Span pub fn make_span_with( self, new_make_span: NewMakeSpan, ) -> Trace { Trace { make_span: new_make_span, inner: self.inner, on_failure: self.on_failure, on_request: self.on_request, on_body_chunk: self.on_body_chunk, on_response: self.on_response, on_eos: self.on_eos, make_classifier: self.make_classifier, } } } impl Trace< S, SharedClassifier, DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, > { /// Create a new [`Trace`] using [`ServerErrorsAsFailures`] which supports classifying /// regular HTTP responses based on the status code. pub fn new_for_http(inner: S) -> Self { Self { inner, make_classifier: SharedClassifier::new(ServerErrorsAsFailures::default()), make_span: DefaultMakeSpan::new(), on_request: DefaultOnRequest::default(), on_response: DefaultOnResponse::default(), on_body_chunk: DefaultOnBodyChunk::default(), on_eos: DefaultOnEos::default(), on_failure: DefaultOnFailure::default(), } } } impl Trace< S, SharedClassifier, DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, > { /// Create a new [`Trace`] using [`GrpcErrorsAsFailures`] which supports classifying /// gRPC responses and streams based on the `grpc-status` header. 
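    ///
    /// A minimal sketch (marked `ignore`; it assumes some gRPC `svc` service is in scope, and
    /// tweaks the failure log level purely as an illustration):
    ///
    /// ```ignore
    /// use tower_http::trace::{DefaultOnFailure, Trace};
    /// use tracing::Level;
    ///
    /// // Responses and trailers are classified via the `grpc-status` header.
    /// let traced = Trace::new_for_grpc(svc)
    ///     .on_failure(DefaultOnFailure::new().level(Level::WARN));
    /// ```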
pub fn new_for_grpc(inner: S) -> Self { Self { inner, make_classifier: SharedClassifier::new(GrpcErrorsAsFailures::default()), make_span: DefaultMakeSpan::new(), on_request: DefaultOnRequest::default(), on_response: DefaultOnResponse::default(), on_body_chunk: DefaultOnBodyChunk::default(), on_eos: DefaultOnEos::default(), on_failure: DefaultOnFailure::default(), } } } impl< S, ReqBody, ResBody, M, OnRequestT, OnResponseT, OnFailureT, OnBodyChunkT, OnEosT, MakeSpanT, > Service> for Trace where S: Service, Response = Response>, ReqBody: Body, ResBody: Body, ResBody::Error: fmt::Display + 'static, S::Error: fmt::Display + 'static, M: MakeClassifier, M::Classifier: Clone, MakeSpanT: MakeSpan, OnRequestT: OnRequest, OnResponseT: OnResponse + Clone, OnBodyChunkT: OnBodyChunk + Clone, OnEosT: OnEos + Clone, OnFailureT: OnFailure + Clone, { type Response = Response>; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let start = Instant::now(); let span = self.make_span.make_span(&req); let classifier = self.make_classifier.make_classifier(&req); let future = { let _guard = span.enter(); self.on_request.on_request(&req, &span); self.inner.call(req) }; ResponseFuture { inner: future, span, classifier: Some(classifier), on_response: Some(self.on_response.clone()), on_body_chunk: Some(self.on_body_chunk.clone()), on_eos: Some(self.on_eos.clone()), on_failure: Some(self.on_failure.clone()), start, } } } tower-http-0.4.4/src/validate_request.rs000064400000000000000000000417221046102023000164520ustar 00000000000000//! Middleware that validates requests. //! //! # Example //! //! ``` //! use tower_http::validate_request::ValidateRequestHeaderLayer; //! use hyper::{Request, Response, Body, Error}; //! use http::{StatusCode, header::ACCEPT}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! //! async fn handle(request: Request) -> Result, Error> { //! Ok(Response::new(Body::empty())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut service = ServiceBuilder::new() //! // Require the `Accept` header to be `application/json`, `*/*` or `application/*` //! .layer(ValidateRequestHeaderLayer::accept("application/json")) //! .service_fn(handle); //! //! // Requests with the correct value are allowed through //! let request = Request::builder() //! .header(ACCEPT, "application/json") //! .body(Body::empty()) //! .unwrap(); //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! //! assert_eq!(StatusCode::OK, response.status()); //! //! // Requests with an invalid value get a `406 Not Acceptable` response //! let request = Request::builder() //! .header(ACCEPT, "text/strings") //! .body(Body::empty()) //! .unwrap(); //! //! let response = service //! .ready() //! .await? //! .call(request) //! .await?; //! //! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status()); //! # Ok(()) //! # } //! ``` //! //! Custom validation can be made by implementing [`ValidateRequest`]: //! //! ``` //! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; //! use hyper::{Request, Response, Body, Error}; //! use http::{StatusCode, header::ACCEPT}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! //! #[derive(Clone, Copy)] //! pub struct MyHeader { /* ... */ } //! //! impl ValidateRequest for MyHeader { //! type ResponseBody = Body; //! //! fn validate( //! 
&mut self, //! request: &mut Request, //! ) -> Result<(), Response> { //! // validate the request... //! # unimplemented!() //! } //! } //! //! async fn handle(request: Request) -> Result, Error> { //! Ok(Response::new(Body::empty())) //! } //! //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let service = ServiceBuilder::new() //! // Validate requests using `MyHeader` //! .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ })) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` //! //! Or using a closure: //! //! ``` //! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; //! use hyper::{Request, Response, Body, Error}; //! use http::{StatusCode, header::ACCEPT}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! //! async fn handle(request: Request) -> Result, Error> { //! # todo!(); //! // ... //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let service = ServiceBuilder::new() //! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request| { //! // Validate the request //! # Ok::<_, Response>(()) //! })) //! .service_fn(handle); //! # Ok(()) //! # } //! ``` use http::{header, Request, Response, StatusCode}; use http_body::Body; use mime::{Mime, MimeIter}; use pin_project_lite::pin_project; use std::{ fmt, future::Future, marker::PhantomData, pin::Pin, sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Layer that applies [`ValidateRequestHeader`] which validates all requests. /// /// See the [module docs](crate::validate_request) for an example. #[derive(Debug, Clone)] pub struct ValidateRequestHeaderLayer { validate: T, } impl ValidateRequestHeaderLayer> { /// Validate requests have the required Accept header. /// /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, /// as configured. /// /// # Panics /// /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` /// See `AcceptHeader::new` for when this method panics. /// /// # Example /// /// ``` /// use hyper::Body; /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer}; /// /// let layer = ValidateRequestHeaderLayer::>::accept("application/json"); /// ``` /// /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept pub fn accept(value: &str) -> Self where ResBody: Body + Default, { Self::custom(AcceptHeader::new(value)) } } impl ValidateRequestHeaderLayer { /// Validate requests using a custom method. pub fn custom(validate: T) -> ValidateRequestHeaderLayer { Self { validate } } } impl Layer for ValidateRequestHeaderLayer where T: Clone, { type Service = ValidateRequestHeader; fn layer(&self, inner: S) -> Self::Service { ValidateRequestHeader::new(inner, self.validate.clone()) } } /// Middleware that validates requests. /// /// See the [module docs](crate::validate_request) for an example. #[derive(Clone, Debug)] pub struct ValidateRequestHeader { inner: S, validate: T, } impl ValidateRequestHeader { fn new(inner: S, validate: T) -> Self { Self::custom(inner, validate) } define_inner_service_accessors!(); } impl ValidateRequestHeader> { /// Validate requests have the required Accept header. /// /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`, /// as configured. /// /// # Panics /// /// See `AcceptHeader::new` for when this method panics. 
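    ///
    /// A minimal sketch (marked `ignore`; `svc` is assumed to be some inner `Service`):
    ///
    /// ```ignore
    /// use hyper::Body;
    /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeader};
    ///
    /// // Requests whose `Accept` header cannot accept JSON get a `406 Not Acceptable`.
    /// let svc = ValidateRequestHeader::<_, AcceptHeader<Body>>::accept(svc, "application/json");
    /// ```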
pub fn accept(inner: S, value: &str) -> Self where ResBody: Body + Default, { Self::custom(inner, AcceptHeader::new(value)) } } impl ValidateRequestHeader { /// Validate requests using a custom method. pub fn custom(inner: S, validate: T) -> ValidateRequestHeader { Self { inner, validate } } } impl Service> for ValidateRequestHeader where V: ValidateRequest, S: Service, Response = Response>, { type Response = Response; type Error = S::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { match self.validate.validate(&mut req) { Ok(_) => ResponseFuture::future(self.inner.call(req)), Err(res) => ResponseFuture::invalid_header_value(res), } } } pin_project! { /// Response future for [`ValidateRequestHeader`]. pub struct ResponseFuture { #[pin] kind: Kind, } } impl ResponseFuture { fn future(future: F) -> Self { Self { kind: Kind::Future { future }, } } fn invalid_header_value(res: Response) -> Self { Self { kind: Kind::Error { response: Some(res), }, } } } pin_project! { #[project = KindProj] enum Kind { Future { #[pin] future: F, }, Error { response: Option>, }, } } impl Future for ResponseFuture where F: Future, E>>, { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { KindProj::Future { future } => future.poll(cx), KindProj::Error { response } => { let response = response.take().expect("future polled after completion"); Poll::Ready(Ok(response)) } } } } /// Trait for validating requests. pub trait ValidateRequest { /// The body type used for responses to unvalidated requests. type ResponseBody; /// Validate the request. /// /// If `Ok(())` is returned then the request is allowed through, otherwise not. fn validate(&mut self, request: &mut Request) -> Result<(), Response>; } impl ValidateRequest for F where F: FnMut(&mut Request) -> Result<(), Response>, { type ResponseBody = ResBody; fn validate(&mut self, request: &mut Request) -> Result<(), Response> { self(request) } } /// Type that performs validation of the Accept header. pub struct AcceptHeader { header_value: Arc, _ty: PhantomData ResBody>, } impl AcceptHeader { /// Create a new `AcceptHeader`. 
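    ///
    /// The value is parsed as a [`mime::Mime`], so it has to be a full `type/subtype` pair. A
    /// rough illustration of the underlying parse (this is the `mime` crate, not part of this
    /// middleware's API):
    ///
    /// ```ignore
    /// let ok: mime::Mime = "application/json".parse().unwrap();
    /// let err = "application".parse::<mime::Mime>(); // no subtype
    /// assert!(err.is_err());
    /// ```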
/// /// # Panics /// /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json` fn new(header_value: &str) -> Self where ResBody: Body + Default, { Self { header_value: Arc::new( header_value .parse::() .expect("value is not a valid header value"), ), _ty: PhantomData, } } } impl Clone for AcceptHeader { fn clone(&self) -> Self { Self { header_value: self.header_value.clone(), _ty: PhantomData, } } } impl fmt::Debug for AcceptHeader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("AcceptHeader") .field("header_value", &self.header_value) .finish() } } impl ValidateRequest for AcceptHeader where ResBody: Body + Default, { type ResponseBody = ResBody; fn validate(&mut self, req: &mut Request) -> Result<(), Response> { if !req.headers().contains_key(header::ACCEPT) { return Ok(()); } if req .headers() .get_all(header::ACCEPT) .into_iter() .filter_map(|header| header.to_str().ok()) .any(|h| { MimeIter::new(&h) .map(|mim| { if let Ok(mim) = mim { let typ = self.header_value.type_(); let subtype = self.header_value.subtype(); match (mim.type_(), mim.subtype()) { (t, s) if t == typ && s == subtype => true, (t, mime::STAR) if t == typ => true, (mime::STAR, mime::STAR) => true, _ => false, } } else { false } }) .reduce(|acc, mim| acc || mim) .unwrap_or(false) }) { return Ok(()); } let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::NOT_ACCEPTABLE; Err(res) } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use http::{header, StatusCode}; use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; #[tokio::test] async fn valid_accept_header() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "application/json") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn valid_accept_header_accept_all_json() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "application/*") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn valid_accept_header_accept_all() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "*/*") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn invalid_accept_header() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "invalid") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); } #[tokio::test] async fn not_accepted_accept_header_subtype() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "application/strings") .body(Body::empty()) .unwrap(); let res = 
service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); } #[tokio::test] async fn not_accepted_accept_header() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "text/strings") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); } #[tokio::test] async fn accepted_multiple_header_value() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "text/strings") .header(header::ACCEPT, "invalid, application/json") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn accepted_inner_header_value() { let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/json")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, "text/strings, invalid, application/json") .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn accepted_header_with_quotes_valid() { let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*"; let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("application/xml")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, value) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn accepted_header_with_quotes_invalid() { let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\""; let mut service = ServiceBuilder::new() .layer(ValidateRequestHeaderLayer::accept("text/html")) .service_fn(echo); let request = Request::get("/") .header(header::ACCEPT, value) .body(Body::empty()) .unwrap(); let res = service.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); } async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } }