pax_global_header00006660000000000000000000000064144627653360014532gustar00rootroot0000000000000052 comment=3b92cd7593a900d3c79c2aeb411f90be052a9a5c axum-axum-v0.6.20/000077500000000000000000000000001446276533600137275ustar00rootroot00000000000000axum-axum-v0.6.20/.github/000077500000000000000000000000001446276533600152675ustar00rootroot00000000000000axum-axum-v0.6.20/.github/ISSUE_TEMPLATE/000077500000000000000000000000001446276533600174525ustar00rootroot00000000000000axum-axum-v0.6.20/.github/ISSUE_TEMPLATE/bug_report.md000066400000000000000000000016241446276533600221470ustar00rootroot00000000000000--- name: ๐Ÿ› Bug Report about: If something isn't working as expected ๐Ÿค”. --- - [ ] I have looked for existing issues (including closed) about this ## Bug Report ### Version ### Platform ### Crates ### Description axum-axum-v0.6.20/.github/ISSUE_TEMPLATE/config.yml000066400000000000000000000004411446276533600214410ustar00rootroot00000000000000contact_links: - name: ๐Ÿ™ Q&A (GitHub Discussions) url: https://github.com/tokio-rs/axum/discussions/categories/q-a about: Q&A all around axum usage - name: ๐Ÿ’ฌ Tokio Discord url: https://discord.gg/tokio about: Community chat for Tokio (axum channel is under libs) axum-axum-v0.6.20/.github/ISSUE_TEMPLATE/feature_request.md000066400000000000000000000011141446276533600231740ustar00rootroot00000000000000--- name: ๐Ÿ’ก Feature Request about: I have a suggestion (and may want to implement it ๐Ÿ™‚)! 
--- - [ ] I have looked for existing issues (including closed) about this ## Feature Request ### Motivation ### Proposal ### Alternatives axum-axum-v0.6.20/.github/PULL_REQUEST_TEMPLATE.md000066400000000000000000000010751446276533600210730ustar00rootroot00000000000000 ## Motivation ## Solution axum-axum-v0.6.20/.github/workflows/000077500000000000000000000000001446276533600173245ustar00rootroot00000000000000axum-axum-v0.6.20/.github/workflows/CI.yml000066400000000000000000000137371446276533600203550ustar00rootroot00000000000000name: CI env: CARGO_TERM_COLOR: always MSRV: '1.63' on: push: branches: - main pull_request: {} jobs: check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@beta with: components: clippy, rustfmt - uses: Swatinem/rust-cache@v2 - name: Check run: cargo clippy --workspace --all-targets --all-features -- -D warnings - name: rustfmt run: cargo fmt --all --check check-docs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - name: cargo doc env: RUSTDOCFLAGS: "-D rustdoc::all -A rustdoc::private-doc-tests" run: cargo doc --all-features --no-deps cargo-hack: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - name: Install cargo-hack run: | curl -LsSf https://github.com/taiki-e/cargo-hack/releases/latest/download/cargo-hack-x86_64-unknown-linux-gnu.tar.gz | tar xzf - -C ~/.cargo/bin - name: cargo hack check run: cargo hack check --each-feature --no-dev-deps --all cargo-public-api-crates: runs-on: ubuntu-latest strategy: matrix: crate: [axum, axum-core, axum-extra, axum-macros] steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@nightly - uses: Swatinem/rust-cache@v2 - name: Install cargo-public-api-crates run: | cargo install --git 
https://github.com/davidpdrsn/cargo-public-api-crates - name: cargo public-api-crates check run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml check test-versions: needs: check runs-on: ubuntu-latest strategy: matrix: rust: [stable, beta] steps: - uses: actions/checkout@v3 - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust }} - uses: Swatinem/rust-cache@v2 - name: Run tests run: cargo test --workspace --all-features --all-targets # some examples doesn't support our MSRV so we only test axum itself on our MSRV test-nightly: needs: check runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Get rust-toolchain version id: rust-toolchain run: echo "version=$(cat axum-macros/rust-toolchain)" >> $GITHUB_OUTPUT - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ steps.rust-toolchain.outputs.version }} - uses: Swatinem/rust-cache@v2 - name: Run nightly tests working-directory: axum-macros run: cargo test # some examples doesn't support our MSRV (such as async-graphql) # so we only test axum itself on our MSRV test-msrv: needs: check runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ env.MSRV }} - name: "install Rust nightly" uses: dtolnay/rust-toolchain@nightly - uses: Swatinem/rust-cache@v2 - name: Select minimal version run: cargo +nightly update -Z minimal-versions - name: Fix up Cargo.lock run: cargo +nightly update -p crc32fast --precise 1.1.1 - name: Run tests run: > cargo +${{ env.MSRV }} test -p axum -p axum-extra -p axum-core --all-features --all-targets --locked # the compiler errors are different on our MSRV which makes # the trybuild tests in axum-macros fail, so just run the doc # tests - name: Run axum-macros doc tests run: > cargo +${{ env.MSRV }} test -p axum-macros --doc --all-features --locked test-docs: needs: check runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: 
dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - name: Run doc tests run: cargo test --all-features --doc deny-check: name: cargo-deny check runs-on: ubuntu-latest continue-on-error: ${{ matrix.checks == 'advisories' }} strategy: matrix: checks: - advisories - bans licenses sources steps: - uses: actions/checkout@v3 - uses: EmbarkStudios/cargo-deny-action@v1 with: command: check ${{ matrix.checks }} arguments: --all-features --manifest-path axum/Cargo.toml armv5te-unknown-linux-musleabi: needs: check runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable with: target: armv5te-unknown-linux-musleabi - uses: Swatinem/rust-cache@v2 - name: Check env: # Clang has native cross-compilation support CC: clang run: > cargo check --all-targets --all-features -p axum -p axum-core -p axum-extra -p axum-macros --target armv5te-unknown-linux-musleabi wasm32-unknown-unknown: needs: check runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable with: target: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 - name: Check run: > cargo check --manifest-path ./examples/simple-router-wasm/Cargo.toml --target wasm32-unknown-unknown dependencies-are-sorted: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@beta - uses: Swatinem/rust-cache@v2 - name: Install cargo-sort run: | cargo install cargo-sort # Work around cargo-sort not honoring workspace.exclude - name: Remove non-crate folder run: rm -rf examples/async-graphql - name: Check dependency tables run: | cargo sort --workspace --grouped --check axum-axum-v0.6.20/.github/workflows/patch.toml000066400000000000000000000002041446276533600213140ustar00rootroot00000000000000# Patch dependencies to run all tests against versions of the crate in the # repository. 
[patch.crates-io] axum = { path = "axum" } axum-axum-v0.6.20/.gitignore000066400000000000000000000000341446276533600157140ustar00rootroot00000000000000target Cargo.lock .DS_Store axum-axum-v0.6.20/CHANGELOG.md000066400000000000000000000001571446276533600155430ustar00rootroot00000000000000axum's changelog has moved and now lives [here](https://github.com/tokio-rs/axum/blob/main/axum/CHANGELOG.md). axum-axum-v0.6.20/CONTRIBUTING.md000066400000000000000000000303651446276533600161670ustar00rootroot00000000000000# Contributing to axum ๐ŸŽˆ Thanks for your help improving the project! We are so happy to have you! There are opportunities to contribute to `axum` at any level. It doesn't matter if you are just getting started with Rust or are the most weathered expert, we can use your help. **No contribution is too small and all contributions are valued.** This guide will help you get started. **Do not let this guide intimidate you**. It should be considered a map to help you navigate the process. Don't know where to start? Check [issues labeled with "E-help-wanted"](https://github.com/tokio-rs/axum/issues?q=is%3Aopen+is%3Aissue+label%3AE-help-wanted) or ["E-easy"](https://github.com/tokio-rs/axum/issues?q=is%3Aopen+is%3Aissue+label%3AE-easy). You may also get help with contributing in the [`axum` Discord channel][discord], please join us! [discord]: https://discord.gg/tokio ## Conduct The `axum` project adheres to the [Rust Code of Conduct][coc]. This describes the _minimum_ behavior expected from all contributors. [coc]: https://github.com/rust-lang/rust/blob/master/CODE_OF_CONDUCT.md ## Contributing in Issues For any issue, there are fundamentally three ways an individual can contribute: 1. By opening the issue for discussion: For instance, if you believe that you have uncovered a bug in a `axum` crate, creating a new issue in the tokio-rs/axum [issue tracker][issues] is the way to report it. 2. 
By helping to triage the issue: This can be done by providing supporting details (a test case that demonstrates a bug), providing suggestions on how to address the issue, or ensuring that the issue is tagged correctly. 3. By helping to resolve the issue: Typically this is done either in the form of demonstrating that the issue reported is not a problem after all, or more often, by opening a Pull Request that changes some bit of something in axum in a concrete and reviewable manner. **Anybody can participate in any stage of contribution**. We urge you to participate in the discussion around bugs and participate in reviewing PRs. [issues]: https://github.com/tokio-rs/axum/issues ### Asking for General Help If you have reviewed existing documentation and still have questions or are having problems, you can open an issue asking for help. In exchange for receiving help, we ask that you contribute back a documentation PR that helps others avoid the problems that you encountered. ### Submitting a Bug Report When opening a new issue in the `axum` issue tracker, users will be presented with a [basic template][template] that should be filled in. If you believe that you have uncovered a bug, please fill out this form, following the template to the best of your ability. Do not worry if you cannot answer every detail, just fill in what you can. The two most important pieces of information we need in order to properly evaluate the report are a description of the behavior you are seeing and a simple test case we can use to recreate the problem on our own. If we cannot recreate the issue, it becomes harder for us to fix. See [How to create a Minimal, Complete, and Verifiable example][mcve]. [mcve]: https://stackoverflow.com/help/mcve [template]: .github/ISSUE_TEMPLATE/bug_report.md ### Triaging a Bug Report Once an issue has been opened, it is not uncommon for there to be discussion around it.
Some contributors may have differing opinions about the issue, including whether the behavior being seen is a bug or a feature. This discussion is part of the process and should be kept focused, helpful, and professional. Short, clipped responses—that provide neither additional context nor supporting detail—are not helpful or professional. To many, such responses are simply annoying and unfriendly. Contributors are encouraged to help one another make forward progress as much as possible, empowering one another to solve issues collaboratively. If you choose to comment on an issue that you feel either is not a problem that needs to be fixed, or if you encounter information in an issue that you feel is incorrect, explain why you feel that way with additional supporting context, and be willing to be convinced that you may be wrong. By doing so, we can often reach the correct outcome much faster. ### Resolving a Bug Report In the majority of cases, issues are resolved by opening a Pull Request. The process for opening and reviewing a Pull Request is similar to that of opening and triaging issues, but carries with it a necessary review and approval workflow that ensures that the proposed changes meet the minimal quality. ## Pull Requests Pull Requests are the way concrete changes are made to the code, documentation, and dependencies in the `axum` repository. Even tiny pull requests (e.g., one character pull request fixing a typo in API documentation) are greatly appreciated. Before making a large change, it is usually a good idea to first open an issue describing the change to solicit feedback and guidance. This will increase the likelihood of the PR getting merged. ### Tests If the change being proposed alters code (as opposed to only documentation for example), it is either adding new functionality to a crate or it is fixing existing, broken functionality.
In both of these cases, the pull request should include one or more tests to ensure that the crate does not regress in the future. #### Documentation tests Ideally, every API has at least one [documentation test] that demonstrates how to use the API. Documentation tests are run with `cargo test --doc`. This ensures that the example is correct and provides additional test coverage. The trick to documentation tests is striking a balance between being succinct for a reader to understand and actually testing the API. In Rust documentation, lines that start with `/// #` are removed when the documentation is generated. They are only there to get the test to run. ### Commits It is a recommended best practice to keep your changes as logically grouped as possible within individual commits. There is no limit to the number of commits any single Pull Request may have, and many contributors find it easier to review changes that are split across multiple commits. Note that multiple commits often get squashed when they are landed (see the notes about [commit squashing]). #### Commit message guidelines A good commit message should describe what changed and why. 1. The first line should: * Contain a short description of the change (preferably 50 characters or less, and no more than 72 characters) 2. Keep the second line blank. 3. Wrap all other lines at 72 columns (except for long URLs). 4. If your patch fixes an open issue, you can add a reference to it at the end of the log. Use the `Fixes: #` prefix and the issue number. For other references use `Refs: #`. `Refs` may include multiple issues, separated by a comma. Examples: - `Fixes: #1337` - `Refs: #1234, #42` ### Opening the Pull Request From within GitHub, opening a new Pull Request will present you with a [template] that should be filled out. Please try to do your best at filling out the details, but feel free to skip parts if you're not sure what to put. 
[template]: .github/PULL_REQUEST_TEMPLATE.md ### Discuss and update You will probably get feedback or requests for changes to your Pull Request. This is a big part of the submission process so don't be discouraged! Some contributors may sign off on the Pull Request right away, others may have more detailed comments or feedback. This is a necessary part of the process in order to evaluate whether the changes are correct and necessary. **Any community member can review a PR and you might get conflicting feedback**. Keep an eye out for comments from code owners to provide guidance on conflicting feedback. **Once the PR is open, do not rebase the commits**. See [Commit Squashing] for more details. ### Commit Squashing In most cases, **do not squash commits that you add to your Pull Request during the review process**. When the commits in your Pull Request land, they may be squashed into one commit per logical change. Metadata will be added to the commit message (including links to the Pull Request, links to relevant issues, and the names of the reviewers). The commit history of your Pull Request, however, will stay intact on the Pull Request page. ## Reviewing Pull Requests **Any Tokio, Hyperium, and Tower, Axum community member is welcome to review any pull request**. All contributors who choose to review and provide feedback on Pull Requests have a responsibility to both the project and the individual making the contribution. Reviews and feedback must be helpful, insightful, and geared towards improving the contribution as opposed to simply blocking it. If there are reasons why you feel the PR should not land, explain what those are. Do not expect to be able to block a Pull Request from advancing simply because you say "No" without giving an explanation. Be open to having your mind changed. Be open to working with the contributor to make the Pull Request better. 
Reviews that are dismissive or disrespectful of the contributor or any other reviewers are strictly counter to the Code of Conduct. When reviewing a Pull Request, the primary goals are for the codebase to improve and for the person submitting the request to succeed. **Even if a Pull Request does not land, the submitters should come away from the experience feeling like their effort was not wasted or unappreciated**. Every Pull Request from a new contributor is an opportunity to grow the community. ### Review a bit at a time. Do not overwhelm new contributors. It is tempting to micro-optimize and make everything about relative performance, perfect grammar, or exact style matches. Do not succumb to that temptation. Focus first on the most significant aspects of the change: 1. Does this change make sense for axum? 2. Does this change make axum better, even if only incrementally? 3. Are there clear bugs or larger scale issues that need attending to? 4. Is the commit message readable and correct? If it contains a breaking change is it clear enough? Note that only **incremental** improvement is needed to land a PR. This means that the PR does not need to be perfect, only better than the status quo. Follow up PRs may be opened to continue iterating. When changes are necessary, *request* them, do not *demand* them, and **do not assume that the submitter already knows how to add a test or run a benchmark**. Specific performance optimization techniques, coding styles and conventions change over time. The first impression you give to a new contributor never does. Nits (requests for small changes that are not essential) are fine, but try to avoid stalling the Pull Request. Most nits can typically be fixed by the Axum Collaborator landing the Pull Request but they can also be an opportunity for the contributor to learn a bit more about the project. It is always good to clearly indicate nits when you comment: e.g. `Nit: change foo() to bar(). 
But this is not blocking.` If your comments were addressed but were not folded automatically after new commits or if they proved to be mistaken, please [hide them][hiding-a-comment] with the appropriate reason to keep the conversation flow concise and relevant. ### Be aware of the person behind the code Be aware that *how* you communicate requests and reviews in your feedback can have a significant impact on the success of the Pull Request. Yes, we may land a particular change that makes `axum` better, but the individual might just not want to have anything to do with `axum` ever again. The goal is not just having good code. ### Abandoned or Stalled Pull Requests If a Pull Request appears to be abandoned or stalled, it is polite to first check with the contributor to see if they intend to continue the work before checking if they would mind if you took it over (especially if it just has nits left). When doing so, it is courteous to give the original contributor credit for the work they started (either by preserving their name and email address in the commit log, or by using an `Author: ` meta-data tag in the commit).
[hiding-a-comment]: https://help.github.com/articles/managing-disruptive-comments/#hiding-a-comment [documentation test]: https://doc.rust-lang.org/rustdoc/documentation-tests.html [keep-a-changelog]: https://github.com/olivierlacan/keep-a-changelog/blob/master/CHANGELOG.md axum-axum-v0.6.20/Cargo.toml000066400000000000000000000004151446276533600156570ustar00rootroot00000000000000[workspace] members = ["axum", "axum-*", "examples/*"] # Only check / build main crates by default (check all with `--workspace`) default-members = ["axum", "axum-*"] # Example has been deleted, but README.md remains exclude = ["examples/async-graphql"] resolver = "2" axum-axum-v0.6.20/ECOSYSTEM.md000066400000000000000000000223271446276533600156320ustar00rootroot00000000000000# Community Projects If your project isn't listed here and you would like it to be, please feel free to create a PR. ## Community maintained axum ecosystem - [axum-server](https://crates.io/crates/axum-server): axum-server is a hyper server implementation designed to be used with axum. - [axum-typed-websockets](https://crates.io/crates/axum-typed-websockets): `axum::extract::ws` with type safe messages. - [tower-cookies](https://crates.io/crates/tower-cookies): Cookie manager middleware - [axum-flash](https://crates.io/crates/axum-flash): One-time notifications (aka flash messages) for axum. - [axum-msgpack](https://crates.io/crates/axum-msgpack): MessagePack Extractors for axum. - [axum-sqlx-tx](https://crates.io/crates/axum-sqlx-tx): Request-bound [SQLx](https://github.com/launchbadge/sqlx#readme) transactions with automatic commit/rollback based on response. - [aliri_axum](https://docs.rs/aliri_axum) and [aliri_tower](https://docs.rs/aliri_tower): JWT validation middleware and OAuth2 scopes enforcing extractors. - [ezsockets](https://github.com/gbaranski/ezsockets): Easy to use WebSocket library that integrates with Axum. 
- [axum_session](https://github.com/AscendingCreations/AxumSessions): Database persistent sessions like pythons flask_sessionstore for Axum. - [axum_session_auth](https://github.com/AscendingCreations/AxumSessionsAuth): Persistant session based user login with rights management for Axum. - [axum-auth](https://crates.io/crates/axum-auth): High-level http auth extractors for axum. - [axum-keycloak-auth](https://github.com/lpotthast/axum-keycloak-auth): Protect axum routes with a JWT emitted by Keycloak. - [shuttle](https://github.com/getsynth/shuttle): A serverless platform built for Rust. Now with axum support. - [axum-tungstenite](https://github.com/davidpdrsn/axum-tungstenite): WebSocket connections for axum directly using tungstenite - [axum-jrpc](https://github.com/0xdeafbeef/axum-jrpc): Json-rpc extractor for axum - [axum-tracing-opentelemetry](https://crates.io/crates/axum-tracing-opentelemetry): Middlewares and tools to integrate axum + tracing + opentelemetry - [svelte-axum-project](https://github.com/jbertovic/svelte-axum-project): Template and example for Svelte frontend app with Axum as backend - [axum-streams](https://github.com/abdolence/axum-streams-rs): Streaming HTTP body with different formats: JSON, CSV, Protobuf. - [axum-template](https://github.com/Altair-Bueno/axum-template): Layers, extractors and template engine wrappers for axum based Web MVC applications - [axum-guard-logic](https://github.com/sjud/axum_guard_logic): Use AND/OR logic to extract types and check their values against `Service` inputs. - [axum-casbin-auth](https://github.com/casbin-rs/axum-casbin-auth): Casbin access control middleware for axum framework - [aide](https://docs.rs/aide): Code-first Open API documentation generator with [axum integration](https://docs.rs/aide/latest/aide/axum/index.html). - [axum-jsonschema](https://docs.rs/axum-jsonschema/): A `Json` extractor that does JSON schema validation of requests. 
- [axum-sessions](https://docs.rs/axum-sessions): Cookie-based sessions for axum via async-session. - [axum-login](https://docs.rs/axum-login): Session-based user authentication for axum. - [axum-csrf-sync-pattern](https://crates.io/crates/axum-csrf-sync-pattern): A middleware implementing CSRF STP for AJAX backends and API endpoints. - [axum-otel-metrics](https://github.com/ttys3/axum-otel-metrics/): A axum OpenTelemetry Metrics middleware with prometheus exporter supported. - [jwt-authorizer](https://crates.io/crates/jwt-authorizer): JWT authorization layer for axum (oidc discovery, validation options, claims extraction, etc.) - [axum-typed-multipart](https://crates.io/crates/axum_typed_multipart): Type safe wrapper for `axum::extract::Multipart`. - [tower-governor](https://crates.io/crates/tower_governor): A Tower service and layer that provides a rate-limiting backend by [governor](https://crates.io/crates/governor) - [axum-restful](https://github.com/gongzhengyang/axum-restful): A restful framework based on axum and sea-orm, inspired by django-rest-framework. - [springtime-web-axum](https://crates.io/crates/springtime-web-axum): A web framework built on Springtime and axum, leveraging dependency injection for easy app development. ## Project showcase - [HomeDisk](https://github.com/MedzikUser/HomeDisk): โ˜๏ธ Fast, lightweight and Open Source local cloud for your data. - [Houseflow](https://github.com/gbaranski/houseflow): House automation platform written in Rust. - [JWT Auth](https://github.com/Z4RX/axum_jwt_example): JWT auth service for educational purposes. - [ROAPI](https://github.com/roapi/roapi): Create full-fledged APIs for static datasets without writing a single line of code. - [notify.run](https://github.com/notify-run/notify-run-rs): HTTP-to-WebPush relay for sending desktop/mobile notifications to yourself, written in Rust. 
- [turbo.fish](https://turbo.fish/) ([repository](https://github.com/jplatte/turbo.fish)): Find out for yourself ๐Ÿ˜‰ - [Book Management](https://github.com/lz1998/axum-book-management): CRUD system of book-management with ORM and JWT for educational purposes. - [realworld-axum-sqlx](https://github.com/launchbadge/realworld-axum-sqlx): A Rust implementation of the [Realworld] demo app spec using Axum and [SQLx]. See https://github.com/davidpdrsn/realworld-axum-sqlx for a fork with up to date dependencies. - [Rustapi](https://github.com/ndelvalle/rustapi): RESTful API template using MongoDB - [Jotsy](https://github.com/ohsayan/jotsy): Self-hosted notes app powered by Skytable, Axum and Tokio - [Svix](https://www.svix.com) ([repository](https://github.com/svix/svix-webhooks)): Enterprise-ready webhook service - [emojied](https://emojied.net) ([repository](https://github.com/sekunho/emojied)): Shorten URLs to emojis! - [CLOMonitor](https://clomonitor.io) ([repository](https://github.com/cncf/clomonitor)): Checks open source projects repositories to verify they meet certain best practices. - [Pinging.net](https://www.pinging.net) ([repository](https://github.com/benhansenslc/pinging)): A new way to check and monitor your internet connection. - [wastebin](https://github.com/matze/wastebin): A minimalist pastebin service. - [sandbox_axum_observability](https://github.com/davidB/sandbox_axum_observability) A Sandbox/showcase project to experiment axum and observability (tracing, opentelemetry, jaeger, grafana tempo,...) - [axum_admin](https://github.com/lingdu1234/axum_admin): An admin panel built with **axum**, Sea-orm and Vue 3. - [rgit](https://git.inept.dev/~doyle/rgit.git/about): A blazingly fast Git repository browser, compatible with- and heavily inspired by cgit. 
- [Petclinic](https://github.com/danipardo/petclinic): A port of Spring Framework's Petclinic showcase project to Axum - [axum-middleware-example](https://github.com/casbin-rs/axum-middleware-example): A authorization application using Axum-web, Casbin and Diesel, with JWT support. - [circleci-hook](https://github.com/DavidS/circleci-hook): Translate CircleCI WebHooks to OpenTelemetry traces to improve your test insights. Add detail with otel-cli to capture individual commands. Use the TRACEPARENT integration to add details from your tests. - [lishuuro.org](https://github.com/uros-5/backend-lishuuro): Small chess variant server that uses Rust as backend(Axum framework). - [freedit](https://github.com/freedit-org/freedit): A forum powered by rust. - [axum-http-auth-example](https://github.com/i0n/axum-http-auth-example): Axum http auth example using postgres and redis. - [Deaftone](https://github.com/Deaftone/Deaftone): Lightweight music server. With a clean and simple API - [dropit](https://github.com/scotow/dropit): Temporary file hosting. - [cobrust](https://github.com/scotow/cobrust): Multiplayer web based snake game. - [meta-cross](https://github.com/scotow/meta-cross): Tweaked version of Tic-Tac-Toe. - [httq](https://github.com/scotow/httq) HTTP to MQTT trivial proxy. [Realworld]: https://github.com/gothinkster/realworld [SQLx]: https://github.com/launchbadge/sqlx ## Tutorials - [Rust on Nails](https://rust-on-nails.com/): A full stack architecture for Rust web applications (uses Axum) - [axum-tutorial] ([website][axum-tutorial-website]): Axum web framework tutorial for beginners. - [demo-rust-axum]: Demo of Rust and axum web framework - [Introduction to axum (talk)]: Talk about axum from the Copenhagen Rust Meetup. - [Getting Started with Axum]: Axum tutorial, GET, POST endpoints and serving files. 
- [Using Rust, Axum, PostgreSQL, and Tokio to build a Blog] - [Introduction to axum]: YouTube playlist - [Rust Axum Full Course]: YouTube video [axum-tutorial]: https://github.com/programatik29/axum-tutorial [axum-tutorial-website]: https://programatik29.github.io/axum-tutorial/ [demo-rust-axum]: https://github.com/joelparkerhenderson/demo-rust-axum [Introduction to axum (talk)]: https://www.youtube.com/watch?v=ETdmhh7OQpA [Getting Started with Axum]: https://carlosmv.hashnode.dev/getting-started-with-axum-rust [Using Rust, Axum, PostgreSQL, and Tokio to build a Blog]: https://spacedimp.com/blog/using-rust-axum-postgresql-and-tokio-to-build-a-blog/ [Introduction to axum]: https://www.youtube.com/playlist?list=PLrmY5pVcnuE-_CP7XZ_44HN-mDrLQV4nS [Rust Axum Full Course]: https://www.youtube.com/watch?v=XZtlD_m59sM axum-axum-v0.6.20/README.md000077700000000000000000000000001446276533600174322axum/README.mdustar00rootroot00000000000000axum-axum-v0.6.20/axum-core/000077500000000000000000000000001446276533600156275ustar00rootroot00000000000000axum-axum-v0.6.20/axum-core/CHANGELOG.md000066400000000000000000000172761446276533600174550ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). # Unreleased - None. # 0.3.4 (11. April, 2023) - Changes to private APIs. # 0.3.3 (03. March, 2023) - **fixed:** Add `#[must_use]` attributes to types that do nothing unless used ([#1809]) [#1809]: https://github.com/tokio-rs/axum/pull/1809 # 0.3.2 (20. January, 2023) - **added:** Implement `IntoResponse` for `&'static [u8; N]` and `[u8; N]` ([#1690]) [#1690]: https://github.com/tokio-rs/axum/pull/1690 # 0.3.1 (9. 
January, 2023) - **added:** Add `body_text` and `status` methods to built-in rejections ([#1612]) [#1612]: https://github.com/tokio-rs/axum/pull/1612 # 0.3.0 (25. November, 2022) - **added:** Added new `FromRequestParts` trait. See axum's changelog for more details ([#1272]) - **breaking:** `FromRequest` has been reworked and `RequestParts` has been removed. See axum's changelog for more details ([#1272]) - **breaking:** `BodyAlreadyExtracted` has been removed ([#1272]) - **breaking:** `AppendHeaders` now works on any `impl IntoIterator` ([#1495]) [#1272]: https://github.com/tokio-rs/axum/pull/1272 [#1495]: https://github.com/tokio-rs/axum/pull/1495
0.3.0 Pre-Releases # 0.3.0-rc.3 (8. November, 2022) - **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) - **added:** Add `Error::into_inner` for converting `Error` to `BoxError` without allocating ([#1476]) - **breaking:** `AppendHeaders` now works on any `impl IntoIterator` ([#1495]) [#1397]: https://github.com/tokio-rs/axum/pull/1397 [#1476]: https://github.com/tokio-rs/axum/pull/1476 [#1495]: https://github.com/tokio-rs/axum/pull/1495 # 0.3.0-rc.2 (10. September, 2022) - **breaking:** Added default limit to how much data `Bytes::from_request` will consume. Previously it would attempt to consume the entire request body without checking its length. This meant if a malicious peer sent a large (or infinite) request body your server might run out of memory and crash. The default limit is at 2 MB and can be disabled by adding the new `DefaultBodyLimit::disable()` middleware. See its documentation for more details. This also applies to `String` which used `Bytes::from_request` internally. ([#1346]) [#1346]: https://github.com/tokio-rs/axum/pull/1346 # 0.3.0-rc.1 (23. August, 2022) - **breaking:** `FromRequest` has been reworked and `RequestParts` has been removed. See axum's changelog for more details ([#1272]) - **added:** Added new `FromRequestParts` trait. See axum's changelog for more details ([#1272]) - **breaking:** `BodyAlreadyExtracted` has been removed ([#1272]) [#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1272]: https://github.com/tokio-rs/axum/pull/1272
# 0.2.8 (10. September, 2022) - **breaking:** Added default limit to how much data `Bytes::from_request` will consume. Previously it would attempt to consume the entire request body without checking its length. This meant if a malicious peer sent an large (or infinite) request body your server might run out of memory and crash. The default limit is at 2 MB and can be disabled by adding the new `DefaultBodyLimit::disable()` middleware. See its documentation for more details. This also applies to `String` which used `Bytes::from_request` internally. ([#1346]) [#1346]: https://github.com/tokio-rs/axum/pull/1346 # 0.2.7 (10. July, 2022) - **fix:** Fix typos in `RequestParts` docs ([#1147]) [#1147]: https://github.com/tokio-rs/axum/pull/1147 # 0.2.6 (18. June, 2022) - **change:** axum-core's MSRV is now 1.56 ([#1098]) [#1098]: https://github.com/tokio-rs/axum/pull/1098 # 0.2.5 (08. June, 2022) - **added:** Automatically handle `http_body::LengthLimitError` in `FailedToBufferBody` and map such errors to `413 Payload Too Large` ([#1048]) - **fixed:** Use `impl IntoResponse` less in docs ([#1049]) [#1048]: https://github.com/tokio-rs/axum/pull/1048 [#1049]: https://github.com/tokio-rs/axum/pull/1049 # 0.2.4 (02. May, 2022) - **added:** Implement `IntoResponse` and `IntoResponseParts` for `http::Extensions` ([#975]) - **added:** Implement `IntoResponse` for `(http::response::Parts, impl IntoResponse)` ([#950]) - **added:** Implement `IntoResponse` for `(http::response::Response<()>, impl IntoResponse)` ([#950]) - **added:** Implement `IntoResponse for (Parts | Request<()>, $(impl IntoResponseParts)+, impl IntoResponse)` ([#980]) [#950]: https://github.com/tokio-rs/axum/pull/950 [#975]: https://github.com/tokio-rs/axum/pull/975 [#980]: https://github.com/tokio-rs/axum/pull/980 # 0.2.3 (25. 
April, 2022) - **added:** Add `response::ErrorResponse` and `response::Result` for `IntoResponse`-based error handling ([#921]) [#921]: https://github.com/tokio-rs/axum/pull/921 # 0.2.2 (19. April, 2022) - **added:** Add `AppendHeaders` for appending headers to a response rather than overriding them ([#927]) [#927]: https://github.com/tokio-rs/axum/pull/927 # 0.2.1 (03. April, 2022) - **added:** Add `RequestParts::extract` which allows applying an extractor as a method call ([#897]) [#897]: https://github.com/tokio-rs/axum/pull/897 # 0.2.0 (31. March, 2022) - **added:** Add `IntoResponseParts` trait which allows defining custom response types for adding headers or extensions to responses ([#797]) - **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead `HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the headers you need ([#698]) This includes these breaking changes: - `RequestParts::take_headers` has been removed. - `RequestParts::headers` returns `&HeaderMap`. - `RequestParts::headers_mut` returns `&mut HeaderMap`. - `HeadersAlreadyExtracted` has been removed. - The `HeadersAlreadyExtracted` variant has been removed from these rejections: - `RequestAlreadyExtracted` - `RequestPartsAlreadyExtracted` - `>::Rejection` has been changed to `std::convert::Infallible`. - **breaking:** `axum::http::Extensions` is no longer an extractor (ie it doesn't implement `FromRequest`). The `axum::extract::Extension` extractor is _not_ impacted by this and works the same. This change makes it harder to accidentally remove all extensions which would result in confusing errors elsewhere ([#699]) This includes these breaking changes: - `RequestParts::take_extensions` has been removed. - `RequestParts::extensions` returns `&Extensions`. - `RequestParts::extensions_mut` returns `&mut Extensions`. 
- `RequestAlreadyExtracted` has been removed. - `::Rejection` is now `BodyAlreadyExtracted`. - `::Rejection` is now `Infallible`. - `ExtensionsAlreadyExtracted` has been removed. - **breaking:** `RequestParts::body_mut` now returns `&mut Option` so the body can be swapped ([#869]) [#698]: https://github.com/tokio-rs/axum/pull/698 [#699]: https://github.com/tokio-rs/axum/pull/699 [#797]: https://github.com/tokio-rs/axum/pull/797 [#869]: https://github.com/tokio-rs/axum/pull/869 # 0.1.2 (22. February, 2022) - **added:** Implement `IntoResponse` for `bytes::BytesMut` and `bytes::Chain` ([#767]) [#767]: https://github.com/tokio-rs/axum/pull/767 # 0.1.1 (06. December, 2021) - **added:** `axum_core::response::Response` now exists as a shorthand for writing `Response` ([#590]) [#590]: https://github.com/tokio-rs/axum/pull/590 # 0.1.0 (02. December, 2021) - Initial release. axum-axum-v0.6.20/axum-core/Cargo.toml000066400000000000000000000026641446276533600175670ustar00rootroot00000000000000[package] categories = ["asynchronous", "network-programming", "web-programming"] description = "Core types and traits for axum" edition = "2021" rust-version = "1.56" homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" name = "axum-core" readme = "README.md" repository = "https://github.com/tokio-rs/axum" version = "0.3.4" # remember to also bump the version that axum and axum-extra depend on [features] tracing = ["dep:tracing"] # Required for intra-doc links to resolve correctly __private_docs = ["dep:tower-http"] [dependencies] async-trait = "0.1.67" bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2.7" http-body = "0.4.5" mime = "0.3.16" tower-layer = "0.3" tower-service = "0.3" # optional dependencies tower-http = { version = "0.4", optional = true, features = ["limit"] } tracing = { version = "0.1.37", default-features = false, optional = true } [build-dependencies] 
rustversion = "1.0.9" [dev-dependencies] axum = { path = "../axum", version = "0.6.0", features = ["headers"] } futures-util = { version = "0.3", default-features = false, features = ["alloc"] } hyper = "0.14.24" tokio = { version = "1.25.0", features = ["macros"] } tower-http = { version = "0.4", features = ["limit"] } [package.metadata.cargo-public-api-crates] allowed = [ "futures_core", "http", "bytes", "http_body", "tower_layer", ] axum-axum-v0.6.20/axum-core/LICENSE000066400000000000000000000020411446276533600166310ustar00rootroot00000000000000Copyright 2021 Axum Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
axum-axum-v0.6.20/axum-core/README.md000066400000000000000000000026421446276533600171120ustar00rootroot00000000000000# axum-core [![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum-core/actions/workflows/CI.yml) [![Crates.io](https://img.shields.io/crates/v/axum-core)](https://crates.io/crates/axum-core) [![Documentation](https://docs.rs/axum-core/badge.svg)](https://docs.rs/axum-core) Core types and traits for axum. More information about this crate can be found in the [crate documentation][docs]. ## Safety This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. ## Minimum supported Rust version axum-core's MSRV is 1.56. ## Getting Help You're also welcome to ask in the [Discord channel][chat] or open an [issue] with your question. ## Contributing ๐ŸŽˆ Thanks for your help improving the project! We are so happy to have you! We have a [contributing guide][contributing] to help you get involved in the `axum` project. ## License This project is licensed under the [MIT license][license]. ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in `axum` by you, shall be licensed as MIT, without any additional terms or conditions. 
[`axum`]: https://crates.io/crates/axum [chat]: https://discord.gg/tokio [contributing]: /CONTRIBUTING.md [docs]: https://docs.rs/axum-core [license]: /axum-core/LICENSE [issue]: https://github.com/tokio-rs/axum/issues/new axum-axum-v0.6.20/axum-core/build.rs000066400000000000000000000002111446276533600172660ustar00rootroot00000000000000#[rustversion::nightly] fn main() { println!("cargo:rustc-cfg=nightly_error_messages"); } #[rustversion::not(nightly)] fn main() {} axum-axum-v0.6.20/axum-core/src/000077500000000000000000000000001446276533600164165ustar00rootroot00000000000000axum-axum-v0.6.20/axum-core/src/body.rs000066400000000000000000000057401446276533600177270ustar00rootroot00000000000000//! HTTP body utilities. use crate::{BoxError, Error}; use bytes::Bytes; use bytes::{Buf, BufMut}; use http_body::Body; /// A boxed [`Body`] trait object. /// /// This is used in axum as the response body type for applications. It's /// necessary to unify multiple response bodies types into one. pub type BoxBody = http_body::combinators::UnsyncBoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. 
pub fn boxed(body: B) -> BoxBody where B: http_body::Body + Send + 'static, B::Error: Into, { try_downcast(body).unwrap_or_else(|body| body.map_err(Error::new).boxed_unsync()) } pub(crate) fn try_downcast(k: K) -> Result where T: 'static, K: Send + 'static, { let mut k = Some(k); if let Some(k) = ::downcast_mut::>(&mut k) { Ok(k.take().unwrap()) } else { Err(k.unwrap()) } } // copied from hyper under the following license: // Copyright (c) 2014-2021 Sean McArthur // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. pub(crate) async fn to_bytes(body: T) -> Result where T: Body, { futures_util::pin_mut!(body); // If there's only 1 chunk, we can just return Buf::to_bytes() let mut first = if let Some(buf) = body.data().await { buf? } else { return Ok(Bytes::new()); }; let second = if let Some(buf) = body.data().await { buf? } else { return Ok(first.copy_to_bytes(first.remaining())); }; // With more than 1 buf, we gotta flatten into a Vec first. 
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; let mut vec = Vec::with_capacity(cap); vec.put(first); vec.put(second); while let Some(buf) = body.data().await { vec.put(buf?); } Ok(vec.into()) } #[test] fn test_try_downcast() { assert_eq!(try_downcast::(5_u32), Err(5_u32)); assert_eq!(try_downcast::(5_i32), Ok(5_i32)); } axum-axum-v0.6.20/axum-core/src/error.rs000066400000000000000000000013451446276533600201200ustar00rootroot00000000000000use crate::BoxError; use std::{error::Error as StdError, fmt}; /// Errors that can happen when using axum. #[derive(Debug)] pub struct Error { inner: BoxError, } impl Error { /// Create a new `Error` from a boxable error. pub fn new(error: impl Into) -> Self { Self { inner: error.into(), } } /// Convert an `Error` back into the underlying boxed trait object. pub fn into_inner(self) -> BoxError { self.inner } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.inner.fmt(f) } } impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { Some(&*self.inner) } } axum-axum-v0.6.20/axum-core/src/ext_traits/000077500000000000000000000000001446276533600206045ustar00rootroot00000000000000axum-axum-v0.6.20/axum-core/src/ext_traits/mod.rs000066400000000000000000000024441446276533600217350ustar00rootroot00000000000000pub(crate) mod request; pub(crate) mod request_parts; #[cfg(test)] mod tests { use std::convert::Infallible; use crate::extract::{FromRef, FromRequestParts}; use async_trait::async_trait; use http::request::Parts; #[derive(Debug, Default, Clone, Copy)] pub(crate) struct State(pub(crate) S); #[async_trait] impl FromRequestParts for State where InnerState: FromRef, OuterState: Send + Sync, { type Rejection = Infallible; async fn from_request_parts( _parts: &mut Parts, state: &OuterState, ) -> Result { let inner_state = InnerState::from_ref(state); Ok(Self(inner_state)) } } // some extractor that requires the state, such as 
`SignedCookieJar` pub(crate) struct RequiresState(pub(crate) String); #[async_trait] impl FromRequestParts for RequiresState where S: Send + Sync, String: FromRef, { type Rejection = Infallible; async fn from_request_parts( _parts: &mut Parts, state: &S, ) -> Result { Ok(Self(String::from_ref(state))) } } } axum-axum-v0.6.20/axum-core/src/ext_traits/request.rs000066400000000000000000000327331446276533600226520ustar00rootroot00000000000000use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; use http_body::Limited; mod sealed { pub trait Sealed {} impl Sealed for http::Request {} } /// Extension trait that adds additional methods to [`Request`]. pub trait RequestExt: sealed::Sealed + Sized { /// Apply an extractor to this `Request`. /// /// This is just a convenience for `E::from_request(req, &())`. /// /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting /// the body and don't want to consume the request. 
/// /// # Example /// /// ``` /// use axum::{ /// async_trait, /// extract::FromRequest, /// http::{header::CONTENT_TYPE, Request, StatusCode}, /// response::{IntoResponse, Response}, /// Form, Json, RequestExt, /// }; /// /// struct FormOrJson(T); /// /// #[async_trait] /// impl FromRequest for FormOrJson /// where /// Json: FromRequest<(), B>, /// Form: FromRequest<(), B>, /// T: 'static, /// B: Send + 'static, /// S: Send + Sync, /// { /// type Rejection = Response; /// /// async fn from_request(req: Request, _state: &S) -> Result { /// let content_type = req /// .headers() /// .get(CONTENT_TYPE) /// .and_then(|value| value.to_str().ok()) /// .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?; /// /// if content_type.starts_with("application/json") { /// let Json(payload) = req /// .extract::, _>() /// .await /// .map_err(|err| err.into_response())?; /// /// Ok(Self(payload)) /// } else if content_type.starts_with("application/x-www-form-urlencoded") { /// let Form(payload) = req /// .extract::, _>() /// .await /// .map_err(|err| err.into_response())?; /// /// Ok(Self(payload)) /// } else { /// Err(StatusCode::BAD_REQUEST.into_response()) /// } /// } /// } /// ``` fn extract(self) -> BoxFuture<'static, Result> where E: FromRequest<(), B, M> + 'static, M: 'static; /// Apply an extractor that requires some state to this `Request`. /// /// This is just a convenience for `E::from_request(req, state)`. /// /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not /// extracting the body and don't want to consume the request. 
/// /// # Example /// /// ``` /// use axum::{ /// async_trait, /// extract::{FromRef, FromRequest}, /// http::Request, /// RequestExt, /// }; /// /// struct MyExtractor { /// requires_state: RequiresState, /// } /// /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// String: FromRef, /// S: Send + Sync, /// B: Send + 'static, /// { /// type Rejection = std::convert::Infallible; /// /// async fn from_request(req: Request, state: &S) -> Result { /// let requires_state = req.extract_with_state::(state).await?; /// /// Ok(Self { requires_state }) /// } /// } /// /// // some extractor that consumes the request body and requires state /// struct RequiresState { /* ... */ } /// /// #[async_trait] /// impl FromRequest for RequiresState /// where /// String: FromRef, /// S: Send + Sync, /// B: Send + 'static, /// { /// // ... /// # type Rejection = std::convert::Infallible; /// # async fn from_request(req: Request, _state: &S) -> Result { /// # todo!() /// # } /// } /// ``` fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> where E: FromRequest + 'static, S: Send + Sync; /// Apply a parts extractor to this `Request`. /// /// This is just a convenience for `E::from_request_parts(parts, state)`. 
/// /// # Example /// /// ``` /// use axum::{ /// async_trait, /// extract::FromRequest, /// headers::{authorization::Bearer, Authorization}, /// http::Request, /// response::{IntoResponse, Response}, /// Json, RequestExt, TypedHeader, /// }; /// /// struct MyExtractor { /// bearer_token: String, /// payload: T, /// } /// /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// B: Send + 'static, /// S: Send + Sync, /// Json: FromRequest<(), B>, /// T: 'static, /// { /// type Rejection = Response; /// /// async fn from_request(mut req: Request, _state: &S) -> Result { /// let TypedHeader(auth_header) = req /// .extract_parts::>>() /// .await /// .map_err(|err| err.into_response())?; /// /// let Json(payload) = req /// .extract::, _>() /// .await /// .map_err(|err| err.into_response())?; /// /// Ok(Self { /// bearer_token: auth_header.token().to_owned(), /// payload, /// }) /// } /// } /// ``` fn extract_parts(&mut self) -> BoxFuture<'_, Result> where E: FromRequestParts<()> + 'static; /// Apply a parts extractor that requires some state to this `Request`. /// /// This is just a convenience for `E::from_request_parts(parts, state)`. 
/// /// # Example /// /// ``` /// use axum::{ /// async_trait, /// extract::{FromRef, FromRequest, FromRequestParts}, /// http::{request::Parts, Request}, /// response::{IntoResponse, Response}, /// Json, RequestExt, /// }; /// /// struct MyExtractor { /// requires_state: RequiresState, /// payload: T, /// } /// /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// String: FromRef, /// Json: FromRequest<(), B>, /// T: 'static, /// S: Send + Sync, /// B: Send + 'static, /// { /// type Rejection = Response; /// /// async fn from_request(mut req: Request, state: &S) -> Result { /// let requires_state = req /// .extract_parts_with_state::(state) /// .await /// .map_err(|err| err.into_response())?; /// /// let Json(payload) = req /// .extract::, _>() /// .await /// .map_err(|err| err.into_response())?; /// /// Ok(Self { /// requires_state, /// payload, /// }) /// } /// } /// /// struct RequiresState {} /// /// #[async_trait] /// impl FromRequestParts for RequiresState /// where /// String: FromRef, /// S: Send + Sync, /// { /// // ... /// # type Rejection = std::convert::Infallible; /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// # todo!() /// # } /// } /// ``` fn extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result> where E: FromRequestParts + 'static, S: Send + Sync; /// Apply the [default body limit](crate::extract::DefaultBodyLimit). /// /// If it is disabled, return the request as-is in `Err`. fn with_limited_body(self) -> Result>, Request>; /// Consumes the request, returning the body wrapped in [`Limited`] if a /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the /// default limit is disabled. 
fn into_limited_body(self) -> Result, B>; } impl RequestExt for Request where B: Send + 'static, { fn extract(self) -> BoxFuture<'static, Result> where E: FromRequest<(), B, M> + 'static, M: 'static, { self.extract_with_state(&()) } fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> where E: FromRequest + 'static, S: Send + Sync, { E::from_request(self, state) } fn extract_parts(&mut self) -> BoxFuture<'_, Result> where E: FromRequestParts<()> + 'static, { self.extract_parts_with_state(&()) } fn extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result> where E: FromRequestParts + 'static, S: Send + Sync, { let mut req = Request::new(()); *req.version_mut() = self.version(); *req.method_mut() = self.method().clone(); *req.uri_mut() = self.uri().clone(); *req.headers_mut() = std::mem::take(self.headers_mut()); *req.extensions_mut() = std::mem::take(self.extensions_mut()); let (mut parts, _) = req.into_parts(); Box::pin(async move { let result = E::from_request_parts(&mut parts, state).await; *self.version_mut() = parts.version; *self.method_mut() = parts.method.clone(); *self.uri_mut() = parts.uri.clone(); *self.headers_mut() = std::mem::take(&mut parts.headers); *self.extensions_mut() = std::mem::take(&mut parts.extensions); result }) } fn with_limited_body(self) -> Result>, Request> { // update docs in `axum-core/src/extract/default_body_limit.rs` and // `axum/src/docs/extract.md` if this changes const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb match self.extensions().get::().copied() { Some(DefaultBodyLimitKind::Disable) => Err(self), Some(DefaultBodyLimitKind::Limit(limit)) => { Ok(self.map(|b| http_body::Limited::new(b, limit))) } None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))), } } fn into_limited_body(self) -> Result, B> { self.with_limited_body() .map(Request::into_body) .map_err(Request::into_body) } } #[cfg(test)] mod tests { use super::*; use crate::{ ext_traits::tests::{RequiresState, 
State}, extract::FromRef, }; use async_trait::async_trait; use http::Method; use hyper::Body; #[tokio::test] async fn extract_without_state() { let req = Request::new(()); let method: Method = req.extract().await.unwrap(); assert_eq!(method, Method::GET); } #[tokio::test] async fn extract_body_without_state() { let req = Request::new(Body::from("foobar")); let body: String = req.extract().await.unwrap(); assert_eq!(body, "foobar"); } #[tokio::test] async fn extract_with_state() { let req = Request::new(()); let state = "state".to_owned(); let State(extracted_state): State = req.extract_with_state(&state).await.unwrap(); assert_eq!(extracted_state, state); } #[tokio::test] async fn extract_parts_without_state() { let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); let method: Method = req.extract_parts().await.unwrap(); assert_eq!(method, Method::GET); assert_eq!(req.headers()["x-foo"], "foo"); } #[tokio::test] async fn extract_parts_with_state() { let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); let state = "state".to_owned(); let State(extracted_state): State = req.extract_parts_with_state(&state).await.unwrap(); assert_eq!(extracted_state, state); assert_eq!(req.headers()["x-foo"], "foo"); } // this stuff just needs to compile #[allow(dead_code)] struct WorksForCustomExtractor { method: Method, from_state: String, body: String, } #[async_trait] impl FromRequest for WorksForCustomExtractor where S: Send + Sync, B: Send + 'static, String: FromRef + FromRequest<(), B>, { type Rejection = >::Rejection; async fn from_request(mut req: Request, state: &S) -> Result { let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap(); let method = req.extract_parts().await.unwrap(); let body = req.extract().await?; Ok(Self { method, from_state, body, }) } } } axum-axum-v0.6.20/axum-core/src/ext_traits/request_parts.rs000066400000000000000000000130771446276533600240630ustar00rootroot00000000000000use 
crate::extract::FromRequestParts; use futures_util::future::BoxFuture; use http::request::Parts; mod sealed { pub trait Sealed {} impl Sealed for http::request::Parts {} } /// Extension trait that adds additional methods to [`Parts`]. pub trait RequestPartsExt: sealed::Sealed + Sized { /// Apply an extractor to this `Parts`. /// /// This is just a convenience for `E::from_request_parts(parts, &())`. /// /// # Example /// /// ``` /// use axum::{ /// extract::{Query, TypedHeader, FromRequestParts}, /// response::{Response, IntoResponse}, /// headers::UserAgent, /// http::request::Parts, /// RequestPartsExt, /// async_trait, /// }; /// use std::collections::HashMap; /// /// struct MyExtractor { /// user_agent: String, /// query_params: HashMap, /// } /// /// #[async_trait] /// impl FromRequestParts for MyExtractor /// where /// S: Send + Sync, /// { /// type Rejection = Response; /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// let user_agent = parts /// .extract::>() /// .await /// .map(|user_agent| user_agent.as_str().to_owned()) /// .map_err(|err| err.into_response())?; /// /// let query_params = parts /// .extract::>>() /// .await /// .map(|Query(params)| params) /// .map_err(|err| err.into_response())?; /// /// Ok(MyExtractor { user_agent, query_params }) /// } /// } /// ``` fn extract(&mut self) -> BoxFuture<'_, Result> where E: FromRequestParts<()> + 'static; /// Apply an extractor that requires some state to this `Parts`. /// /// This is just a convenience for `E::from_request_parts(parts, state)`. 
/// /// # Example /// /// ``` /// use axum::{ /// extract::{FromRef, FromRequestParts}, /// response::{Response, IntoResponse}, /// http::request::Parts, /// RequestPartsExt, /// async_trait, /// }; /// /// struct MyExtractor { /// requires_state: RequiresState, /// } /// /// #[async_trait] /// impl FromRequestParts for MyExtractor /// where /// String: FromRef, /// S: Send + Sync, /// { /// type Rejection = std::convert::Infallible; /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// let requires_state = parts /// .extract_with_state::(state) /// .await?; /// /// Ok(MyExtractor { requires_state }) /// } /// } /// /// struct RequiresState { /* ... */ } /// /// // some extractor that requires a `String` in the state /// #[async_trait] /// impl FromRequestParts for RequiresState /// where /// String: FromRef, /// S: Send + Sync, /// { /// // ... /// # type Rejection = std::convert::Infallible; /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// # unimplemented!() /// # } /// } /// ``` fn extract_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result> where E: FromRequestParts + 'static, S: Send + Sync; } impl RequestPartsExt for Parts { fn extract(&mut self) -> BoxFuture<'_, Result> where E: FromRequestParts<()> + 'static, { self.extract_with_state(&()) } fn extract_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result> where E: FromRequestParts + 'static, S: Send + Sync, { E::from_request_parts(self, state) } } #[cfg(test)] mod tests { use std::convert::Infallible; use super::*; use crate::{ ext_traits::tests::{RequiresState, State}, extract::FromRef, }; use async_trait::async_trait; use http::{Method, Request}; #[tokio::test] async fn extract_without_state() { let (mut parts, _) = Request::new(()).into_parts(); let method: Method = parts.extract().await.unwrap(); assert_eq!(method, Method::GET); } #[tokio::test] async fn extract_with_state() { let (mut parts, _) = 
Request::new(()).into_parts(); let state = "state".to_owned(); let State(extracted_state): State = parts .extract_with_state::, String>(&state) .await .unwrap(); assert_eq!(extracted_state, state); } // this stuff just needs to compile #[allow(dead_code)] struct WorksForCustomExtractor { method: Method, from_state: String, } #[async_trait] impl FromRequestParts for WorksForCustomExtractor where S: Send + Sync, String: FromRef, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let RequiresState(from_state) = parts.extract_with_state(state).await?; let method = parts.extract().await?; Ok(Self { method, from_state }) } } } axum-axum-v0.6.20/axum-core/src/extract/000077500000000000000000000000001446276533600200705ustar00rootroot00000000000000axum-axum-v0.6.20/axum-core/src/extract/default_body_limit.rs000066400000000000000000000147401446276533600243030ustar00rootroot00000000000000use self::private::DefaultBodyLimitService; use tower_layer::Layer; /// Layer for configuring the default request body limit. /// /// For security reasons, [`Bytes`] will, by default, not accept bodies larger than 2MB. This also /// applies to extractors that uses [`Bytes`] internally such as `String`, [`Json`], and [`Form`]. /// /// This middleware provides ways to configure that. /// /// Note that if an extractor consumes the body directly with [`Body::data`], or similar, the /// default limit is _not_ applied. /// /// # Difference between `DefaultBodyLimit` and [`RequestBodyLimit`] /// /// `DefaultBodyLimit` and [`RequestBodyLimit`] serve similar functions but in different ways. /// /// `DefaultBodyLimit` is local in that it only applies to [`FromRequest`] implementations that /// explicitly apply it (or call another extractor that does). 
You can apply the limit with /// [`RequestExt::with_limited_body`] or [`RequestExt::into_limited_body`] /// /// [`RequestBodyLimit`] is applied globally to all requests, regardless of which extractors are /// used or how the body is consumed. /// /// `DefaultBodyLimit` is also easier to integrate into an existing setup since it doesn't change /// the request body type: /// /// ``` /// use axum::{ /// Router, /// routing::post, /// body::Body, /// extract::{DefaultBodyLimit, RawBody}, /// http::Request, /// }; /// /// let app = Router::new() /// .route( /// "/", /// // even with `DefaultBodyLimit` the request body is still just `Body` /// post(|request: Request| async {}), /// ) /// .layer(DefaultBodyLimit::max(1024)); /// # let _: Router<(), _> = app; /// ``` /// /// ``` /// use axum::{Router, routing::post, body::Body, extract::RawBody, http::Request}; /// use tower_http::limit::RequestBodyLimitLayer; /// use http_body::Limited; /// /// let app = Router::new() /// .route( /// "/", /// // `RequestBodyLimitLayer` changes the request body type to `Limited` /// // extracting a different body type wont work /// post(|request: Request>| async {}), /// ) /// .layer(RequestBodyLimitLayer::new(1024)); /// # let _: Router<(), _> = app; /// ``` /// /// In general using `DefaultBodyLimit` is recommended but if you need to use third party /// extractors and want to sure a limit is also applied there then [`RequestBodyLimit`] should be /// used. 
/// /// [`Body::data`]: http_body::Body::data /// [`Bytes`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.6.0/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.6.0/axum/struct.Form.html /// [`FromRequest`]: crate::extract::FromRequest /// [`RequestBodyLimit`]: tower_http::limit::RequestBodyLimit /// [`RequestExt::with_limited_body`]: crate::RequestExt::with_limited_body /// [`RequestExt::into_limited_body`]: crate::RequestExt::into_limited_body #[derive(Debug, Clone)] #[must_use] pub struct DefaultBodyLimit { kind: DefaultBodyLimitKind, } #[derive(Debug, Clone, Copy)] pub(crate) enum DefaultBodyLimitKind { Disable, Limit(usize), } impl DefaultBodyLimit { /// Disable the default request body limit. /// /// This must be used to receive bodies larger than the default limit of 2MB using [`Bytes`] or /// an extractor built on it such as `String`, [`Json`], [`Form`]. /// /// Note that if you're accepting data from untrusted remotes it is recommend to add your own /// limit such as [`tower_http::limit`]. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// body::{Bytes, Body}, /// extract::DefaultBodyLimit, /// }; /// use tower_http::limit::RequestBodyLimitLayer; /// use http_body::Limited; /// /// let app: Router<(), Limited> = Router::new() /// .route("/", get(|body: Bytes| async {})) /// // Disable the default limit /// .layer(DefaultBodyLimit::disable()) /// // Set a different limit /// .layer(RequestBodyLimitLayer::new(10 * 1000 * 1000)); /// ``` /// /// [`Bytes`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.6.0/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.6.0/axum/struct.Form.html pub fn disable() -> Self { Self { kind: DefaultBodyLimitKind::Disable, } } /// Set the default request body limit. /// /// By default the limit of request body sizes that [`Bytes::from_request`] (and other /// extractors built on top of it such as `String`, [`Json`], and [`Form`]) is 2MB. 
This method /// can be used to change that limit. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// body::{Bytes, Body}, /// extract::DefaultBodyLimit, /// }; /// use tower_http::limit::RequestBodyLimitLayer; /// use http_body::Limited; /// /// let app: Router<(), Limited> = Router::new() /// .route("/", get(|body: Bytes| async {})) /// // Replace the default of 2MB with 1024 bytes. /// .layer(DefaultBodyLimit::max(1024)); /// ``` /// /// [`Bytes::from_request`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.6.0/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.6.0/axum/struct.Form.html pub fn max(limit: usize) -> Self { Self { kind: DefaultBodyLimitKind::Limit(limit), } } } impl Layer for DefaultBodyLimit { type Service = DefaultBodyLimitService; fn layer(&self, inner: S) -> Self::Service { DefaultBodyLimitService { inner, kind: self.kind, } } } mod private { use super::DefaultBodyLimitKind; use http::Request; use std::task::Context; use tower_service::Service; #[derive(Debug, Clone, Copy)] pub struct DefaultBodyLimitService { pub(super) inner: S, pub(super) kind: DefaultBodyLimitKind, } impl Service> for DefaultBodyLimitService where S: Service>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll> { self.inner.poll_ready(cx) } #[inline] fn call(&mut self, mut req: Request) -> Self::Future { req.extensions_mut().insert(self.kind); self.inner.call(req) } } } axum-axum-v0.6.20/axum-core/src/extract/from_ref.rs000066400000000000000000000015131446276533600222350ustar00rootroot00000000000000/// Used to do reference-to-value conversions thus not consuming the input value. /// /// This is mainly used with [`State`] to extract "substates" from a reference to main application /// state. /// /// See [`State`] for more details on how library authors should use this trait. 
/// /// This trait can be derived using `#[derive(FromRef)]`. /// /// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html // NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is // defined in axum. That allows crate authors to use it when implementing extractors. pub trait FromRef { /// Converts to this type from a reference to the input type. fn from_ref(input: &T) -> Self; } impl FromRef for T where T: Clone, { fn from_ref(input: &T) -> Self { input.clone() } } axum-axum-v0.6.20/axum-core/src/extract/mod.rs000066400000000000000000000133071446276533600212210ustar00rootroot00000000000000//! Types and traits for extracting data from requests. //! //! See [`axum::extract`] for more details. //! //! [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html use crate::response::IntoResponse; use async_trait::async_trait; use http::{request::Parts, Request}; use std::convert::Infallible; pub mod rejection; mod default_body_limit; mod from_ref; mod request_parts; mod tuple; pub(crate) use self::default_body_limit::DefaultBodyLimitKind; pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef}; mod private { #[derive(Debug, Clone, Copy)] pub enum ViaParts {} #[derive(Debug, Clone, Copy)] pub enum ViaRequest {} } /// Types that can be created from request parts. /// /// Extractors that implement `FromRequestParts` cannot consume the request body and can thus be /// run in any order for handlers. /// /// If your extractor needs to consume the request body then you should implement [`FromRequest`] /// and not [`FromRequestParts`]. /// /// See [`axum::extract`] for more general docs about extractors. /// /// [`axum::extract`]: https://docs.rs/axum/0.6.0/axum/extract/index.html #[async_trait] #[cfg_attr( nightly_error_messages, rustc_on_unimplemented( note = "Function argument is not a valid axum extractor. 
\nSee `https://docs.rs/axum/latest/axum/extract/index.html` for details", ) )] pub trait FromRequestParts: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. async fn from_request_parts(parts: &mut Parts, state: &S) -> Result; } /// Types that can be created from requests. /// /// Extractors that implement `FromRequest` can consume the request body and can thus only be run /// once for handlers. /// /// If your extractor doesn't need to consume the request body then you should implement /// [`FromRequestParts`] and not [`FromRequest`]. /// /// See [`axum::extract`] for more general docs about extractors. /// /// # What is the `B` type parameter? /// /// `FromRequest` is generic over the request body (the `B` in /// [`http::Request`]). This is to allow `FromRequest` to be usable with any /// type of request body. This is necessary because some middleware change the /// request body, for example to add timeouts. /// /// If you're writing your own `FromRequest` that wont be used outside your /// application, and not using any middleware that changes the request body, you /// can most likely use `axum::body::Body`. /// /// If you're writing a library that's intended for others to use, it's recommended /// to keep the generic type parameter: /// /// ```rust /// use axum::{ /// async_trait, /// extract::FromRequest, /// http::{self, Request}, /// }; /// /// struct MyExtractor; /// /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// // these bounds are required by `async_trait` /// B: Send + 'static, /// S: Send + Sync, /// { /// type Rejection = http::StatusCode; /// /// async fn from_request(req: Request, state: &S) -> Result { /// // ... /// # unimplemented!() /// } /// } /// ``` /// /// This ensures your extractor is as flexible as possible. 
/// /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/0.6.0/axum/extract/index.html #[async_trait] #[cfg_attr( nightly_error_messages, rustc_on_unimplemented( note = "Function argument is not a valid axum extractor. \nSee `https://docs.rs/axum/latest/axum/extract/index.html` for details", ) )] pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. async fn from_request(req: Request, state: &S) -> Result; } #[async_trait] impl FromRequest for T where B: Send + 'static, S: Send + Sync, T: FromRequestParts, { type Rejection = >::Rejection; async fn from_request(req: Request, state: &S) -> Result { let (mut parts, _) = req.into_parts(); Self::from_request_parts(&mut parts, state).await } } #[async_trait] impl FromRequestParts for Option where T: FromRequestParts, S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts( parts: &mut Parts, state: &S, ) -> Result, Self::Rejection> { Ok(T::from_request_parts(parts, state).await.ok()) } } #[async_trait] impl FromRequest for Option where T: FromRequest, B: Send + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request(req: Request, state: &S) -> Result, Self::Rejection> { Ok(T::from_request(req, state).await.ok()) } } #[async_trait] impl FromRequestParts for Result where T: FromRequestParts, S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { Ok(T::from_request_parts(parts, state).await) } } #[async_trait] impl FromRequest for Result where T: FromRequest, B: Send + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request(req: Request, state: &S) -> Result { Ok(T::from_request(req, state).await) } } 
axum-axum-v0.6.20/axum-core/src/extract/rejection.rs000066400000000000000000000041031446276533600224160ustar00rootroot00000000000000//! Rejection response types. use crate::__composite_rejection as composite_rejection; use crate::__define_rejection as define_rejection; use crate::BoxError; composite_rejection! { /// Rejection type for extractors that buffer the request body. Used if the /// request body cannot be buffered due to an error. pub enum FailedToBufferBody { LengthLimitError, UnknownBodyError, } } impl FailedToBufferBody { pub(crate) fn from_err(err: E) -> Self where E: Into, { match err.into().downcast::() { Ok(err) => Self::LengthLimitError(LengthLimitError::from_err(err)), Err(err) => Self::UnknownBodyError(UnknownBodyError::from_err(err)), } } } define_rejection! { #[status = PAYLOAD_TOO_LARGE] #[body = "Failed to buffer the request body"] /// Encountered some other error when buffering the body. /// /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or /// otherwise wrapping request bodies in [`http_body::Limited`]. pub struct LengthLimitError(Error); } define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to buffer the request body"] /// Encountered an unknown error when buffering the body. pub struct UnknownBodyError(Error); } define_rejection! { #[status = BAD_REQUEST] #[body = "Request body didn't contain valid UTF-8"] /// Rejection type used when buffering the request into a [`String`] if the /// body doesn't contain valid UTF-8. pub struct InvalidUtf8(Error); } composite_rejection! { /// Rejection used for [`Bytes`](bytes::Bytes). /// /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor /// can fail. pub enum BytesRejection { FailedToBufferBody, } } composite_rejection! { /// Rejection used for [`String`]. /// /// Contains one variant for each way the [`String`] extractor can fail. 
pub enum StringRejection { FailedToBufferBody, InvalidUtf8, } } axum-axum-v0.6.20/axum-core/src/extract/request_parts.rs000066400000000000000000000064361446276533600233500ustar00rootroot00000000000000use super::{rejection::*, FromRequest, FromRequestParts}; use crate::{BoxError, RequestExt}; use async_trait::async_trait; use bytes::Bytes; use http::{request::Parts, HeaderMap, Method, Request, Uri, Version}; use std::convert::Infallible; #[async_trait] impl FromRequest for Request where B: Send, S: Send + Sync, { type Rejection = Infallible; async fn from_request(req: Request, _: &S) -> Result { Ok(req) } } #[async_trait] impl FromRequestParts for Method where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { Ok(parts.method.clone()) } } #[async_trait] impl FromRequestParts for Uri where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { Ok(parts.uri.clone()) } } #[async_trait] impl FromRequestParts for Version where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { Ok(parts.version) } } /// Clone the headers from the request. /// /// Prefer using [`TypedHeader`] to extract only the headers you need. 
/// /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] impl FromRequestParts for HeaderMap where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { Ok(parts.headers.clone()) } } #[async_trait] impl FromRequest for Bytes where B: http_body::Body + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = BytesRejection; async fn from_request(req: Request, _: &S) -> Result { let bytes = match req.into_limited_body() { Ok(limited_body) => crate::body::to_bytes(limited_body) .await .map_err(FailedToBufferBody::from_err)?, Err(unlimited_body) => crate::body::to_bytes(unlimited_body) .await .map_err(FailedToBufferBody::from_err)?, }; Ok(bytes) } } #[async_trait] impl FromRequest for String where B: http_body::Body + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = StringRejection; async fn from_request(req: Request, state: &S) -> Result { let bytes = Bytes::from_request(req, state) .await .map_err(|err| match err { BytesRejection::FailedToBufferBody(inner) => { StringRejection::FailedToBufferBody(inner) } })?; let string = std::str::from_utf8(&bytes) .map_err(InvalidUtf8::from_err)? 
.to_owned(); Ok(string) } } #[async_trait] impl FromRequest for Parts where B: Send + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request(req: Request, _: &S) -> Result { Ok(req.into_parts().0) } } axum-axum-v0.6.20/axum-core/src/extract/tuple.rs000066400000000000000000000066151446276533600215770ustar00rootroot00000000000000use super::{FromRequest, FromRequestParts}; use crate::response::{IntoResponse, Response}; use async_trait::async_trait; use http::request::{Parts, Request}; use std::convert::Infallible; #[async_trait] impl FromRequestParts for () where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(_: &mut Parts, _: &S) -> Result<(), Self::Rejection> { Ok(()) } } macro_rules! impl_from_request { ( [$($ty:ident),*], $last:ident ) => { #[async_trait] #[allow(non_snake_case, unused_mut, unused_variables)] impl FromRequestParts for ($($ty,)* $last,) where $( $ty: FromRequestParts + Send, )* $last: FromRequestParts + Send, S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { $( let $ty = $ty::from_request_parts(parts, state) .await .map_err(|err| err.into_response())?; )* let $last = $last::from_request_parts(parts, state) .await .map_err(|err| err.into_response())?; Ok(($($ty,)* $last,)) } } // This impl must not be generic over M, otherwise it would conflict with the blanket // implementation of `FromRequest` for `T: FromRequestParts`. 
#[async_trait] #[allow(non_snake_case, unused_mut, unused_variables)] impl FromRequest for ($($ty,)* $last,) where $( $ty: FromRequestParts + Send, )* $last: FromRequest + Send, B: Send + 'static, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let (mut parts, body) = req.into_parts(); $( let $ty = $ty::from_request_parts(&mut parts, state).await.map_err(|err| err.into_response())?; )* let req = Request::from_parts(parts, body); let $last = $last::from_request(req, state).await.map_err(|err| err.into_response())?; Ok(($($ty,)* $last,)) } } }; } all_the_tuples!(impl_from_request); #[cfg(test)] mod tests { use bytes::Bytes; use http::Method; use crate::extract::{FromRequest, FromRequestParts}; fn assert_from_request() where T: FromRequest<(), http_body::Full, M>, { } fn assert_from_request_parts>() {} #[test] fn unit() { assert_from_request_parts::<()>(); assert_from_request::<_, ()>(); } #[test] fn tuple_of_one() { assert_from_request_parts::<(Method,)>(); assert_from_request::<_, (Method,)>(); assert_from_request::<_, (Bytes,)>(); } #[test] fn tuple_of_two() { assert_from_request_parts::<((), ())>(); assert_from_request::<_, ((), ())>(); assert_from_request::<_, (Method, Bytes)>(); } #[test] fn nested_tuple() { assert_from_request_parts::<(((Method,),),)>(); assert_from_request::<_, ((((Bytes,),),),)>(); } } axum-axum-v0.6.20/axum-core/src/lib.rs000066400000000000000000000037321446276533600175370ustar00rootroot00000000000000#![cfg_attr(nightly_error_messages, feature(rustc_attrs))] //! Core types and traits for [`axum`]. //! //! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations //! should depend on the [`axum-core`] crate, instead of `axum` if possible. //! //! [`FromRequest`]: crate::extract::FromRequest //! [`IntoResponse`]: crate::response::IntoResponse //! [`axum`]: https://crates.io/crates/axum //! 
[`axum-core`]: http://crates.io/crates/axum-core #![warn( clippy::all, clippy::dbg_macro, clippy::todo, clippy::empty_enum, clippy::enum_glob_use, clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, clippy::needless_continue, clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, clippy::suboptimal_flops, clippy::lossy_float_literal, clippy::rest_pat_in_fully_bound_structs, clippy::fn_params_excessive_bools, clippy::exit, clippy::inefficient_to_string, clippy::linkedlist, clippy::macro_use_imports, clippy::option_option, clippy::verbose_file_reads, clippy::unnested_or_patterns, clippy::str_to_string, rust_2018_idioms, future_incompatible, nonstandard_style, missing_debug_implementations, missing_docs )] #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] #![forbid(unsafe_code)] #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] #[macro_use] pub(crate) mod macros; mod error; mod ext_traits; pub use self::error::Error; pub mod body; pub mod extract; pub mod response; /// Alias for a type-erased error type. pub type BoxError = Box; pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; axum-axum-v0.6.20/axum-core/src/macros.rs000066400000000000000000000232341446276533600202540ustar00rootroot00000000000000/// Private API. #[doc(hidden)] #[macro_export] macro_rules! __log_rejection { ( rejection_type = $ty:ident, body_text = $body_text:expr, status = $status:expr, ) => { #[cfg(feature = "tracing")] { tracing::event!( target: "axum::rejection", tracing::Level::TRACE, status = $status.as_u16(), body = $body_text, rejection_type = std::any::type_name::<$ty>(), "rejecting request", ); } }; } /// Private API. #[doc(hidden)] #[macro_export] macro_rules! 
__define_rejection { ( #[status = $status:ident] #[body = $body:expr] $(#[$m:meta])* pub struct $name:ident; ) => { $(#[$m])* #[derive(Debug)] #[non_exhaustive] pub struct $name; impl $crate::response::IntoResponse for $name { fn into_response(self) -> $crate::response::Response { $crate::__log_rejection!( rejection_type = $name, body_text = $body, status = http::StatusCode::$status, ); (self.status(), $body).into_response() } } impl $name { /// Get the response body text used for this rejection. pub fn body_text(&self) -> String { $body.into() } /// Get the status code used for this rejection. pub fn status(&self) -> http::StatusCode { http::StatusCode::$status } } impl std::fmt::Display for $name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", $body) } } impl std::error::Error for $name {} impl Default for $name { fn default() -> Self { Self } } }; ( #[status = $status:ident] #[body = $body:expr] $(#[$m:meta])* pub struct $name:ident (Error); ) => { $(#[$m])* #[derive(Debug)] pub struct $name(pub(crate) $crate::Error); impl $name { pub(crate) fn from_err(err: E) -> Self where E: Into<$crate::BoxError>, { Self($crate::Error::new(err)) } } impl $crate::response::IntoResponse for $name { fn into_response(self) -> $crate::response::Response { $crate::__log_rejection!( rejection_type = $name, body_text = self.body_text(), status = http::StatusCode::$status, ); (self.status(), self.body_text()).into_response() } } impl $name { /// Get the response body text used for this rejection. pub fn body_text(&self) -> String { format!(concat!($body, ": {}"), self.0).into() } /// Get the status code used for this rejection. 
pub fn status(&self) -> http::StatusCode { http::StatusCode::$status } } impl std::fmt::Display for $name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", $body) } } impl std::error::Error for $name { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.0) } } }; } /// Private API. #[doc(hidden)] #[macro_export] macro_rules! __composite_rejection { ( $(#[$m:meta])* pub enum $name:ident { $($variant:ident),+ $(,)? } ) => { $(#[$m])* #[derive(Debug)] #[non_exhaustive] pub enum $name { $( #[allow(missing_docs)] $variant($variant) ),+ } impl $crate::response::IntoResponse for $name { fn into_response(self) -> $crate::response::Response { match self { $( Self::$variant(inner) => inner.into_response(), )+ } } } impl $name { /// Get the response body text used for this rejection. pub fn body_text(&self) -> String { match self { $( Self::$variant(inner) => inner.body_text(), )+ } } /// Get the status code used for this rejection. pub fn status(&self) -> http::StatusCode { match self { $( Self::$variant(inner) => inner.status(), )+ } } } $( impl From<$variant> for $name { fn from(inner: $variant) -> Self { Self::$variant(inner) } } )+ impl std::fmt::Display for $name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { $( Self::$variant(inner) => write!(f, "{}", inner), )+ } } } impl std::error::Error for $name { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { $( Self::$variant(inner) => inner.source(), )+ } } } }; } #[cfg(test)] mod composite_rejection_tests { use self::defs::*; use crate::Error; use std::error::Error as _; #[allow(dead_code, unreachable_pub)] mod defs { use crate::{__composite_rejection, __define_rejection}; __define_rejection! { #[status = BAD_REQUEST] #[body = "error message 1"] pub struct Inner1; } __define_rejection! { #[status = BAD_REQUEST] #[body = "error message 2"] pub struct Inner2(Error); } __composite_rejection! 
{ pub enum Outer { Inner1, Inner2 } } } /// The implementation of `.source()` on `Outer` should defer straight to the implementation /// on its inner type instead of returning the inner type itself, because the `Display` /// implementation on `Outer` already forwards to the inner type and so it would result in two /// errors in the chain `Display`ing the same thing. #[test] fn source_gives_inner_source() { let rejection = Outer::Inner1(Inner1); assert!(rejection.source().is_none()); let msg = "hello world"; let rejection = Outer::Inner2(Inner2(Error::new(msg))); assert_eq!(rejection.source().unwrap().to_string(), msg); } } #[rustfmt::skip] macro_rules! all_the_tuples { ($name:ident) => { $name!([], T1); $name!([T1], T2); $name!([T1, T2], T3); $name!([T1, T2, T3], T4); $name!([T1, T2, T3, T4], T5); $name!([T1, T2, T3, T4, T5], T6); $name!([T1, T2, T3, T4, T5, T6], T7); $name!([T1, T2, T3, T4, T5, T6, T7], T8); $name!([T1, T2, T3, T4, T5, T6, T7, T8], T9); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], T14); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], T15); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], T16); }; } macro_rules! 
all_the_tuples_no_last_special_case { ($name:ident) => { $name!(T1); $name!(T1, T2); $name!(T1, T2, T3); $name!(T1, T2, T3, T4); $name!(T1, T2, T3, T4, T5); $name!(T1, T2, T3, T4, T5, T6); $name!(T1, T2, T3, T4, T5, T6, T7); $name!(T1, T2, T3, T4, T5, T6, T7, T8); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); }; } /// Private API. #[doc(hidden)] #[macro_export] macro_rules! __impl_deref { ($ident:ident) => { impl std::ops::Deref for $ident { type Target = T; #[inline] fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for $ident { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } }; ($ident:ident: $ty:ty) => { impl std::ops::Deref for $ident { type Target = $ty; #[inline] fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for $ident { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } }; } axum-axum-v0.6.20/axum-core/src/response/000077500000000000000000000000001446276533600202545ustar00rootroot00000000000000axum-axum-v0.6.20/axum-core/src/response/append_headers.rs000066400000000000000000000037301446276533600235670ustar00rootroot00000000000000use super::{IntoResponse, IntoResponseParts, Response, ResponseParts, TryIntoHeaderError}; use http::header::{HeaderName, HeaderValue}; use std::fmt; /// Append headers to a response. /// /// Returning something like `[("content-type", "foo=bar")]` from a handler will override any /// existing `content-type` headers. 
If instead you want to append headers, use `AppendHeaders`: /// /// ```rust /// use axum::{ /// response::{AppendHeaders, IntoResponse}, /// http::header::SET_COOKIE, /// }; /// /// async fn handler() -> impl IntoResponse { /// // something that sets the `set-cookie` header /// let set_some_cookies = /* ... */ /// # axum::http::HeaderMap::new(); /// /// ( /// set_some_cookies, /// // append two `set-cookie` headers to the response /// // without overriding the ones added by `set_some_cookies` /// AppendHeaders([ /// (SET_COOKIE, "foo=bar"), /// (SET_COOKIE, "baz=qux"), /// ]) /// ) /// } /// ``` #[derive(Debug)] #[must_use] pub struct AppendHeaders(pub I); impl IntoResponse for AppendHeaders where I: IntoIterator, K: TryInto, K::Error: fmt::Display, V: TryInto, V::Error: fmt::Display, { fn into_response(self) -> Response { (self, ()).into_response() } } impl IntoResponseParts for AppendHeaders where I: IntoIterator, K: TryInto, K::Error: fmt::Display, V: TryInto, V::Error: fmt::Display, { type Error = TryIntoHeaderError; fn into_response_parts(self, mut res: ResponseParts) -> Result { for (key, value) in self.0 { let key = key.try_into().map_err(TryIntoHeaderError::key)?; let value = value.try_into().map_err(TryIntoHeaderError::value)?; res.headers_mut().append(key, value); } Ok(res) } } axum-axum-v0.6.20/axum-core/src/response/into_response.rs000066400000000000000000000335161446276533600235210ustar00rootroot00000000000000use super::{IntoResponseParts, Response, ResponseParts}; use crate::{body, BoxError}; use bytes::{buf::Chain, Buf, Bytes, BytesMut}; use http::{ header::{self, HeaderMap, HeaderName, HeaderValue}, Extensions, StatusCode, }; use http_body::{ combinators::{MapData, MapErr}, Empty, Full, SizeHint, }; use std::{ borrow::Cow, convert::Infallible, fmt, pin::Pin, task::{Context, Poll}, }; /// Trait for generating responses. /// /// Types that implement `IntoResponse` can be returned from handlers. 
/// /// # Implementing `IntoResponse` /// /// You generally shouldn't have to implement `IntoResponse` manually, as axum /// provides implementations for many common types. /// /// However it might be necessary if you have a custom error type that you want /// to return from handlers: /// /// ```rust /// use axum::{ /// Router, /// body::{self, Bytes}, /// routing::get, /// http::StatusCode, /// response::{IntoResponse, Response}, /// }; /// /// enum MyError { /// SomethingWentWrong, /// SomethingElseWentWrong, /// } /// /// impl IntoResponse for MyError { /// fn into_response(self) -> Response { /// let body = match self { /// MyError::SomethingWentWrong => "something went wrong", /// MyError::SomethingElseWentWrong => "something else went wrong", /// }; /// /// // its often easiest to implement `IntoResponse` by calling other implementations /// (StatusCode::INTERNAL_SERVER_ERROR, body).into_response() /// } /// } /// /// // `Result` can now be returned from handlers /// let app = Router::new().route("/", get(handler)); /// /// async fn handler() -> Result<(), MyError> { /// Err(MyError::SomethingWentWrong) /// } /// # async { /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Or if you have a custom body type you'll also need to implement /// `IntoResponse` for it: /// /// ```rust /// use axum::{ /// body, /// routing::get, /// response::{IntoResponse, Response}, /// Router, /// }; /// use http_body::Body; /// use http::HeaderMap; /// use bytes::Bytes; /// use std::{ /// convert::Infallible, /// task::{Poll, Context}, /// pin::Pin, /// }; /// /// struct MyBody; /// /// // First implement `Body` for `MyBody`. This could for example use /// // some custom streaming protocol. /// impl Body for MyBody { /// type Data = Bytes; /// type Error = Infallible; /// /// fn poll_data( /// self: Pin<&mut Self>, /// cx: &mut Context<'_> /// ) -> Poll>> { /// # unimplemented!() /// // ... 
/// } /// /// fn poll_trailers( /// self: Pin<&mut Self>, /// cx: &mut Context<'_> /// ) -> Poll, Self::Error>> { /// # unimplemented!() /// // ... /// } /// } /// /// // Now we can implement `IntoResponse` directly for `MyBody` /// impl IntoResponse for MyBody { /// fn into_response(self) -> Response { /// Response::new(body::boxed(self)) /// } /// } /// /// // `MyBody` can now be returned from handlers. /// let app = Router::new().route("/", get(|| async { MyBody })); /// # async { /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` pub trait IntoResponse { /// Create a response. fn into_response(self) -> Response; } impl IntoResponse for StatusCode { fn into_response(self) -> Response { let mut res = ().into_response(); *res.status_mut() = self; res } } impl IntoResponse for () { fn into_response(self) -> Response { Empty::new().into_response() } } impl IntoResponse for Infallible { fn into_response(self) -> Response { match self {} } } impl IntoResponse for Result where T: IntoResponse, E: IntoResponse, { fn into_response(self) -> Response { match self { Ok(value) => value.into_response(), Err(err) => err.into_response(), } } } impl IntoResponse for Response where B: http_body::Body + Send + 'static, B::Error: Into, { fn into_response(self) -> Response { self.map(body::boxed) } } impl IntoResponse for http::response::Parts { fn into_response(self) -> Response { Response::from_parts(self, body::boxed(Empty::new())) } } impl IntoResponse for Full { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } impl IntoResponse for Empty { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } impl IntoResponse for http_body::combinators::BoxBody where E: Into + 'static, { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } impl IntoResponse for http_body::combinators::UnsyncBoxBody where E: Into + 'static, { fn into_response(self) -> Response { 
Response::new(body::boxed(self)) } } impl IntoResponse for MapData where B: http_body::Body + Send + 'static, F: FnMut(B::Data) -> Bytes + Send + 'static, B::Error: Into, { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } impl IntoResponse for MapErr where B: http_body::Body + Send + 'static, F: FnMut(B::Error) -> E + Send + 'static, E: Into, { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } impl IntoResponse for &'static str { fn into_response(self) -> Response { Cow::Borrowed(self).into_response() } } impl IntoResponse for String { fn into_response(self) -> Response { Cow::<'static, str>::Owned(self).into_response() } } impl IntoResponse for Box { fn into_response(self) -> Response { String::from(self).into_response() } } impl IntoResponse for Cow<'static, str> { fn into_response(self) -> Response { let mut res = Full::from(self).into_response(); res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), ); res } } impl IntoResponse for Bytes { fn into_response(self) -> Response { let mut res = Full::from(self).into_response(); res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), ); res } } impl IntoResponse for BytesMut { fn into_response(self) -> Response { self.freeze().into_response() } } impl IntoResponse for Chain where T: Buf + Unpin + Send + 'static, U: Buf + Unpin + Send + 'static, { fn into_response(self) -> Response { let (first, second) = self.into_inner(); let mut res = Response::new(body::boxed(BytesChainBody { first: Some(first), second: Some(second), })); res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), ); res } } struct BytesChainBody { first: Option, second: Option, } impl http_body::Body for BytesChainBody where T: Buf + Unpin, U: Buf + Unpin, { type Data = Bytes; type Error = Infallible; fn poll_data( mut self: Pin<&mut 
Self>, _cx: &mut Context<'_>, ) -> Poll>> { if let Some(mut buf) = self.first.take() { let bytes = buf.copy_to_bytes(buf.remaining()); return Poll::Ready(Some(Ok(bytes))); } if let Some(mut buf) = self.second.take() { let bytes = buf.copy_to_bytes(buf.remaining()); return Poll::Ready(Some(Ok(bytes))); } Poll::Ready(None) } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Poll::Ready(Ok(None)) } fn is_end_stream(&self) -> bool { self.first.is_none() && self.second.is_none() } fn size_hint(&self) -> SizeHint { match (self.first.as_ref(), self.second.as_ref()) { (Some(first), Some(second)) => { let total_size = first.remaining() + second.remaining(); SizeHint::with_exact(total_size as u64) } (Some(buf), None) => SizeHint::with_exact(buf.remaining() as u64), (None, Some(buf)) => SizeHint::with_exact(buf.remaining() as u64), (None, None) => SizeHint::with_exact(0), } } } impl IntoResponse for &'static [u8] { fn into_response(self) -> Response { Cow::Borrowed(self).into_response() } } impl IntoResponse for &'static [u8; N] { fn into_response(self) -> Response { self.as_slice().into_response() } } impl IntoResponse for [u8; N] { fn into_response(self) -> Response { self.to_vec().into_response() } } impl IntoResponse for Vec { fn into_response(self) -> Response { Cow::<'static, [u8]>::Owned(self).into_response() } } impl IntoResponse for Box<[u8]> { fn into_response(self) -> Response { Vec::from(self).into_response() } } impl IntoResponse for Cow<'static, [u8]> { fn into_response(self) -> Response { let mut res = Full::from(self).into_response(); res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), ); res } } impl IntoResponse for (StatusCode, R) where R: IntoResponse, { fn into_response(self) -> Response { let mut res = self.1.into_response(); *res.status_mut() = self.0; res } } impl IntoResponse for HeaderMap { fn into_response(self) -> Response { let mut res = 
().into_response(); *res.headers_mut() = self; res } } impl IntoResponse for Extensions { fn into_response(self) -> Response { let mut res = ().into_response(); *res.extensions_mut() = self; res } } impl IntoResponse for [(K, V); N] where K: TryInto, K::Error: fmt::Display, V: TryInto, V::Error: fmt::Display, { fn into_response(self) -> Response { (self, ()).into_response() } } impl IntoResponse for (http::response::Parts, R) where R: IntoResponse, { fn into_response(self) -> Response { let (parts, res) = self; (parts.status, parts.headers, parts.extensions, res).into_response() } } impl IntoResponse for (http::response::Response<()>, R) where R: IntoResponse, { fn into_response(self) -> Response { let (template, res) = self; let (parts, ()) = template.into_parts(); (parts, res).into_response() } } macro_rules! impl_into_response { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] impl IntoResponse for ($($ty),*, R) where $( $ty: IntoResponseParts, )* R: IntoResponse, { fn into_response(self) -> Response { let ($($ty),*, res) = self; let res = res.into_response(); let parts = ResponseParts { res }; $( let parts = match $ty.into_response_parts(parts) { Ok(parts) => parts, Err(err) => { return err.into_response(); } }; )* parts.res } } #[allow(non_snake_case)] impl IntoResponse for (StatusCode, $($ty),*, R) where $( $ty: IntoResponseParts, )* R: IntoResponse, { fn into_response(self) -> Response { let (status, $($ty),*, res) = self; let res = res.into_response(); let parts = ResponseParts { res }; $( let parts = match $ty.into_response_parts(parts) { Ok(parts) => parts, Err(err) => { return err.into_response(); } }; )* (status, parts.res).into_response() } } #[allow(non_snake_case)] impl IntoResponse for (http::response::Parts, $($ty),*, R) where $( $ty: IntoResponseParts, )* R: IntoResponse, { fn into_response(self) -> Response { let (outer_parts, $($ty),*, res) = self; let res = res.into_response(); let parts = ResponseParts { res }; $( let parts = match 
$ty.into_response_parts(parts) { Ok(parts) => parts, Err(err) => { return err.into_response(); } }; )* (outer_parts, parts.res).into_response() } } #[allow(non_snake_case)] impl IntoResponse for (http::response::Response<()>, $($ty),*, R) where $( $ty: IntoResponseParts, )* R: IntoResponse, { fn into_response(self) -> Response { let (template, $($ty),*, res) = self; let (parts, ()) = template.into_parts(); (parts, $($ty),*, res).into_response() } } } } all_the_tuples_no_last_special_case!(impl_into_response); axum-axum-v0.6.20/axum-core/src/response/into_response_parts.rs000066400000000000000000000164571446276533600247370ustar00rootroot00000000000000use super::{IntoResponse, Response}; use http::{ header::{HeaderMap, HeaderName, HeaderValue}, Extensions, StatusCode, }; use std::{convert::Infallible, fmt}; /// Trait for adding headers and extensions to a response. /// /// # Example /// /// ```rust /// use axum::{ /// response::{ResponseParts, IntoResponse, IntoResponseParts, Response}, /// http::{StatusCode, header::{HeaderName, HeaderValue}}, /// }; /// /// // Hypothetical helper type for setting a single header /// struct SetHeader<'a>(&'a str, &'a str); /// /// impl<'a> IntoResponseParts for SetHeader<'a> { /// type Error = (StatusCode, String); /// /// fn into_response_parts(self, mut res: ResponseParts) -> Result { /// match (self.0.parse::(), self.1.parse::()) { /// (Ok(name), Ok(value)) => { /// res.headers_mut().insert(name, value); /// }, /// (Err(_), _) => { /// return Err(( /// StatusCode::INTERNAL_SERVER_ERROR, /// format!("Invalid header name {}", self.0), /// )); /// }, /// (_, Err(_)) => { /// return Err(( /// StatusCode::INTERNAL_SERVER_ERROR, /// format!("Invalid header value {}", self.1), /// )); /// }, /// } /// /// Ok(res) /// } /// } /// /// // Its also recommended to implement `IntoResponse` so `SetHeader` can be used on its own as /// // the response /// impl<'a> IntoResponse for SetHeader<'a> { /// fn into_response(self) -> Response { /// // 
This gives an empty response with the header /// (self, ()).into_response() /// } /// } /// /// // We can now return `SetHeader` in responses /// // /// // Note that returning `impl IntoResponse` might be easier if the response has many parts to /// // it. The return type is written out here for clarity. /// async fn handler() -> (SetHeader<'static>, SetHeader<'static>, &'static str) { /// ( /// SetHeader("server", "axum"), /// SetHeader("x-foo", "custom"), /// "body", /// ) /// } /// /// // Or on its own as the whole response /// async fn other_handler() -> SetHeader<'static> { /// SetHeader("x-foo", "custom") /// } /// ``` pub trait IntoResponseParts { /// The type returned in the event of an error. /// /// This can be used to fallibly convert types into headers or extensions. type Error: IntoResponse; /// Set parts of the response fn into_response_parts(self, res: ResponseParts) -> Result; } impl IntoResponseParts for Option where T: IntoResponseParts, { type Error = T::Error; fn into_response_parts(self, res: ResponseParts) -> Result { if let Some(inner) = self { inner.into_response_parts(res) } else { Ok(res) } } } /// Parts of a response. /// /// Used with [`IntoResponseParts`]. #[derive(Debug)] pub struct ResponseParts { pub(crate) res: Response, } impl ResponseParts { /// Gets a reference to the response headers. pub fn headers(&self) -> &HeaderMap { self.res.headers() } /// Gets a mutable reference to the response headers. pub fn headers_mut(&mut self) -> &mut HeaderMap { self.res.headers_mut() } /// Gets a reference to the response extensions. pub fn extensions(&self) -> &Extensions { self.res.extensions() } /// Gets a mutable reference to the response extensions. 
pub fn extensions_mut(&mut self) -> &mut Extensions { self.res.extensions_mut() } } impl IntoResponseParts for HeaderMap { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { res.headers_mut().extend(self); Ok(res) } } impl IntoResponseParts for [(K, V); N] where K: TryInto, K::Error: fmt::Display, V: TryInto, V::Error: fmt::Display, { type Error = TryIntoHeaderError; fn into_response_parts(self, mut res: ResponseParts) -> Result { for (key, value) in self { let key = key.try_into().map_err(TryIntoHeaderError::key)?; let value = value.try_into().map_err(TryIntoHeaderError::value)?; res.headers_mut().insert(key, value); } Ok(res) } } /// Error returned if converting a value to a header fails. #[derive(Debug)] pub struct TryIntoHeaderError { kind: TryIntoHeaderErrorKind, } impl TryIntoHeaderError { pub(super) fn key(err: K) -> Self { Self { kind: TryIntoHeaderErrorKind::Key(err), } } pub(super) fn value(err: V) -> Self { Self { kind: TryIntoHeaderErrorKind::Value(err), } } } #[derive(Debug)] enum TryIntoHeaderErrorKind { Key(K), Value(V), } impl IntoResponse for TryIntoHeaderError where K: fmt::Display, V: fmt::Display, { fn into_response(self) -> Response { match self.kind { TryIntoHeaderErrorKind::Key(inner) => { (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response() } TryIntoHeaderErrorKind::Value(inner) => { (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response() } } } } impl fmt::Display for TryIntoHeaderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.kind { TryIntoHeaderErrorKind::Key(_) => write!(f, "failed to convert key to a header name"), TryIntoHeaderErrorKind::Value(_) => { write!(f, "failed to convert value to a header value") } } } } impl std::error::Error for TryIntoHeaderError where K: std::error::Error + 'static, V: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match &self.kind { 
TryIntoHeaderErrorKind::Key(inner) => Some(inner), TryIntoHeaderErrorKind::Value(inner) => Some(inner), } } } macro_rules! impl_into_response_parts { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] impl<$($ty,)*> IntoResponseParts for ($($ty,)*) where $( $ty: IntoResponseParts, )* { type Error = Response; fn into_response_parts(self, res: ResponseParts) -> Result { let ($($ty,)*) = self; $( let res = match $ty.into_response_parts(res) { Ok(res) => res, Err(err) => { return Err(err.into_response()); } }; )* Ok(res) } } } } all_the_tuples_no_last_special_case!(impl_into_response_parts); impl IntoResponseParts for Extensions { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { res.extensions_mut().extend(self); Ok(res) } } axum-axum-v0.6.20/axum-core/src/response/mod.rs000066400000000000000000000063411446276533600214050ustar00rootroot00000000000000//! Types and traits for generating responses. //! //! See [`axum::response`] for more details. //! //! [`axum::response`]: https://docs.rs/axum/latest/axum/response/index.html use crate::body::BoxBody; mod append_headers; mod into_response; mod into_response_parts; pub use self::{ append_headers::AppendHeaders, into_response::IntoResponse, into_response_parts::{IntoResponseParts, ResponseParts, TryIntoHeaderError}, }; /// Type alias for [`http::Response`] whose body type defaults to [`BoxBody`], the most common body /// type used with axum. pub type Response = http::Response; /// An [`IntoResponse`]-based result type that uses [`ErrorResponse`] as the error type. /// /// All types which implement [`IntoResponse`] can be converted to an [`ErrorResponse`]. This makes /// it useful as a general purpose error type for functions which combine multiple distinct error /// types that all implement [`IntoResponse`]. 
/// /// # Example /// /// ``` /// use axum::{ /// response::{IntoResponse, Response}, /// http::StatusCode, /// }; /// /// // two fallible functions with different error types /// fn try_something() -> Result<(), ErrorA> { /// // ... /// # unimplemented!() /// } /// /// fn try_something_else() -> Result<(), ErrorB> { /// // ... /// # unimplemented!() /// } /// /// // each error type implements `IntoResponse` /// struct ErrorA; /// /// impl IntoResponse for ErrorA { /// fn into_response(self) -> Response { /// // ... /// # unimplemented!() /// } /// } /// /// enum ErrorB { /// SomethingWentWrong, /// } /// /// impl IntoResponse for ErrorB { /// fn into_response(self) -> Response { /// // ... /// # unimplemented!() /// } /// } /// /// // we can combine them using `axum::response::Result` and still use `?` /// async fn handler() -> axum::response::Result<&'static str> { /// // the errors are automatically converted to `ErrorResponse` /// try_something()?; /// try_something_else()?; /// /// Ok("it worked!") /// } /// ``` /// /// # As a replacement for `std::result::Result` /// /// Since `axum::response::Result` has a default error type you only have to specify the `Ok` type: /// /// ``` /// use axum::{ /// response::{IntoResponse, Response, Result}, /// http::StatusCode, /// }; /// /// // `Result` automatically uses `ErrorResponse` as the error type. /// async fn handler() -> Result<&'static str> { /// try_something()?; /// /// Ok("it worked!") /// } /// /// // You can still specify the error even if you've imported `axum::response::Result` /// fn try_something() -> Result<(), StatusCode> { /// // ... /// # unimplemented!() /// } /// ``` pub type Result = std::result::Result; impl IntoResponse for Result where T: IntoResponse, { fn into_response(self) -> Response { match self { Ok(ok) => ok.into_response(), Err(err) => err.0, } } } /// An [`IntoResponse`]-based error type /// /// See [`Result`] for more details. 
#[derive(Debug)] pub struct ErrorResponse(Response); impl From for ErrorResponse where T: IntoResponse, { fn from(value: T) -> Self { Self(value.into_response()) } } axum-axum-v0.6.20/axum-extra/000077500000000000000000000000001446276533600160225ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/CHANGELOG.md000066400000000000000000000321461446276533600176410ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog], and this project adheres to [Semantic Versioning]. # Unreleased - None. # 0.7.7 (03. August, 2023) - **added:** `Clone` implementation for `ErasedJson` ([#2142]) [#2142]: https://github.com/tokio-rs/axum/pull/2142 # 0.7.6 (02. August, 2023) - **fixed:** Remove unused dependency ([#2135]) [#2135]: https://github.com/tokio-rs/axum/pull/2135 # 0.7.5 (17. July, 2023) - **fixed:** Remove explicit auto deref from `PrivateCookieJar` example ([#2028]) [#2028]: https://github.com/tokio-rs/axum/pull/2028 [#2028]: https://github.com/tokio-rs/axum/pull/2028 # 0.7.4 (18. April, 2023) - **added:** Add `Html` response type ([#1921]) - **added:** Add `Css` response type ([#1921]) - **added:** Add `JavaScript` response type ([#1921]) - **added:** Add `Wasm` response type ([#1921]) [#1921]: https://github.com/tokio-rs/axum/pull/1921 # 0.7.3 (11. April, 2023) - **added:** Implement `Deref` and `DerefMut` for built-in extractors ([#1922]) - **added:** Add `OptionalPath` extractor ([#1889]) [#1889]: https://github.com/tokio-rs/axum/pull/1889 [#1922]: https://github.com/tokio-rs/axum/pull/1922 # 0.7.2 (22. March, 2023) - **added:** Implement `IntoResponse` for `MultipartError` ([#1861]) [#1861]: https://github.com/tokio-rs/axum/pull/1861 # 0.7.1 (13. March, 2023) - Updated to latest `axum-macros` # 0.7.0 (03. March, 2023) - **breaking:** Remove the `spa` feature which should have been removed in 0.6.0 ([#1802]) - **added:** Add `Multipart`. 
This is similar to `axum::extract::Multipart` except that it enforces field exclusivity at runtime instead of compile time, as this improves usability ([#1692]) - **added:** Implement `Clone` for `CookieJar`, `PrivateCookieJar` and `SignedCookieJar` ([#1808]) - **fixed:** Add `#[must_use]` attributes to types that do nothing unless used ([#1809]) [#1692]: https://github.com/tokio-rs/axum/pull/1692 [#1802]: https://github.com/tokio-rs/axum/pull/1802 [#1808]: https://github.com/tokio-rs/axum/pull/1808 [#1809]: https://github.com/tokio-rs/axum/pull/1809 # 0.6.0 (24. February, 2023) - **breaking:** Change casing of `ProtoBuf` to `Protobuf` ([#1595]) - **breaking:** `SpaRouter` has been removed. Use `ServeDir` and `ServeFile` from `tower-http` instead: ```rust // before Router::new().merge(SpaRouter::new("/assets", "dist")); // with ServeDir Router::new().nest_service("/assets", ServeDir::new("dist")); // before with `index_file` Router::new().merge(SpaRouter::new("/assets", "dist").index_file("index.html")); // with ServeDir + ServeFile Router::new().nest_service( "/assets", ServeDir::new("dist").not_found_service(ServeFile::new("dist/index.html")), ); ``` See the [static-file-server-example] for more examples ([#1784]) [#1595]: https://github.com/tokio-rs/axum/pull/1595 [#1784]: https://github.com/tokio-rs/axum/pull/1784 [static-file-server-example]: https://github.com/tokio-rs/axum/blob/main/examples/static-file-server/src/main.rs # 0.5.0 (12. February, 2023) - **added:** Add `option_layer` for converting an `Option<Layer>` into a `Layer` ([#1696]) - **added:** Implement `Layer` and `Service` for `Either` ([#1696]) - **added:** Add `TypedPath::with_query_params` ([#1744]) - **breaking:** Update to [`cookie`] 0.17 ([#1747]) [#1696]: https://github.com/tokio-rs/axum/pull/1696 [#1744]: https://github.com/tokio-rs/axum/pull/1744 [#1747]: https://github.com/tokio-rs/axum/pull/1747 [`cookie`]: https://crates.io/crates/cookie # 0.4.2 (02.
December, 2022) - **fixed:** Bug fixes for `RouterExt::{route_with_tsr, route_service_with_tsr}` ([#1608]): - Redirects to the correct URI if the route contains path parameters - Keeps query parameters when redirecting - Improved error message when adding a route for `/` [#1608]: https://github.com/tokio-rs/axum/pull/1608 # 0.4.1 (29. November, 2022) - **fixed:** Fix wrong `From` impl for `Resource` ([#1589]) [#1589]: https://github.com/tokio-rs/axum/pull/1589 # 0.4.0 (25. November, 2022) - **added:** Add `RouterExt::route_with_tsr` for adding routes with an additional "trailing slash redirect" route ([#1119]) - **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170]) - **added:** Add Protocol Buffer extractor and response ([#1239]) - **added:** Add `Either*` types for combining extractors and responses into a single type ([#1263]) - **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262]) - **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and `SignedCookieJar` so they're easier to use in custom middleware - **changed:** For methods that accept some `S: Service`, the bounds have been relaxed so the return type can be any type that implements `IntoResponse` rather than being a literal `Response` - **change:** axum-extra's MSRV is now 1.60 ([#1239]) - **breaking:** `Form` has a new rejection type ([#1496]) - **breaking:** `Query` has a new rejection type ([#1496]) - **breaking:** `Resource::nest` and `Resource::nest_collection` have been removed.
You can instead convert the `Resource` into a `Router` and add additional routes as necessary ([#1086]) - **breaking:** `SignedCookieJar` and `PrivateCookieJar` now extract the keys from the router's state rather than from extensions - **breaking:** `Resource` has a new `S` type param which represents the state ([#1155]) - **breaking:** `RouterExt::route_with_tsr` now only accepts `MethodRouter`s ([#1155]) - **added:** `RouterExt::route_service_with_tsr` for routing to any `Service` ([#1155]) [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1119]: https://github.com/tokio-rs/axum/pull/1119 [#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1170]: https://github.com/tokio-rs/axum/pull/1170 [#1214]: https://github.com/tokio-rs/axum/pull/1214 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1262]: https://github.com/tokio-rs/axum/pull/1262 [#1263]: https://github.com/tokio-rs/axum/pull/1263 [#1496]: https://github.com/tokio-rs/axum/pull/1496
<details> <summary>0.4.0 Pre-Releases</summary> # 0.4.0-rc.3 (19. November, 2022) - **breaking:** Depend on axum 0.6.0-rc.5 and axum-macros 0.3.0-rc.3 # 0.4.0-rc.2 (8. November, 2022) - **breaking:** `Form` has a new rejection type ([#1496]) - **breaking:** `Query` has a new rejection type ([#1496]) [#1496]: https://github.com/tokio-rs/axum/pull/1496 # 0.4.0-rc.1 (23. August, 2022) - **added:** Add `RouterExt::route_with_tsr` for adding routes with an additional "trailing slash redirect" route ([#1119]) - **breaking:** `Resource::nest` and `Resource::nest_collection` have been removed. You can instead convert the `Resource` into a `Router` and add additional routes as necessary ([#1086]) - **changed:** For methods that accept some `S: Service`, the bounds have been relaxed so the response type must implement `IntoResponse` rather than being a literal `Response` - **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170]) - **change:** axum-extra's MSRV is now 1.60 ([#1239]) - **breaking:** `SignedCookieJar` and `PrivateCookieJar` now extract the keys from the router's state rather than from extensions - **added:** Add Protocol Buffer extractor and response ([#1239]) - **added:** Add `Either*` types for combining extractors and responses into a single type ([#1263]) - **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262]) - **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and `SignedCookieJar` so they're easier to use in custom middleware - **breaking:** `Resource` has a new `S` type param which represents the state ([#1155]) - **breaking:** `RouterExt::route_with_tsr` now only accepts `MethodRouter`s ([#1155]) - **added:** `RouterExt::route_service_with_tsr` for routing to any `Service` ([#1155]) [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1119]: https://github.com/tokio-rs/axum/pull/1119 [#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1170]: https://github.com/tokio-rs/axum/pull/1170 [#1214]:
https://github.com/tokio-rs/axum/pull/1214 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1262]: https://github.com/tokio-rs/axum/pull/1262 [#1263]: https://github.com/tokio-rs/axum/pull/1263 </details>
# 0.3.7 (09. August, 2022) - **fixed:** Depend on axum 0.5.15 which contains a fix for an accidental breaking change. # 0.3.6 (02. July, 2022) - **fixed:** Fix feature labels missing in generated docs ([#1137]) [#1137]: https://github.com/tokio-rs/axum/pull/1137 # 0.3.5 (27. June, 2022) - **added:** Add `JsonLines` for streaming newline delimited JSON ([#1093]) - **change:** axum-extra's MSRV is now 1.56 ([#1098]) [#1093]: https://github.com/tokio-rs/axum/pull/1093 [#1098]: https://github.com/tokio-rs/axum/pull/1098 # 0.3.4 (08. June, 2022) - **fixed:** Use `impl IntoResponse` less in docs ([#1049]) - **added:** Add `AsyncReadBody` for creating a body from a `tokio::io::AsyncRead` ([#1072]) [#1049]: https://github.com/tokio-rs/axum/pull/1049 [#1072]: https://github.com/tokio-rs/axum/pull/1072 # 0.3.3 (18. May, 2022) - **added:** Add `extract::Query` which supports multi-value items ([#1041]) - **added:** Support customizing rejections for `#[derive(TypedPath)]` ([#1012]) [#1041]: https://github.com/tokio-rs/axum/pull/1041 [#1012]: https://github.com/tokio-rs/axum/pull/1012 # 0.3.2 (15. May, 2022) - **added:** Add `extract::Form` which supports multi-value items ([#1031]) [#1031]: https://github.com/tokio-rs/axum/pull/1031 # 0.3.1 (10. May, 2022) - **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001]) - **fixed:** Support wildcards in typed paths ([#1003]) - **added:** Support using a custom rejection type for `#[derive(TypedPath)]` instead of `PathRejection` ([#1012]) [#1001]: https://github.com/tokio-rs/axum/pull/1001 [#1003]: https://github.com/tokio-rs/axum/pull/1003 [#1012]: https://github.com/tokio-rs/axum/pull/1012 # 0.3.0 (27. 
April, 2022) - **fixed:** Don't depend on axum with default features enabled ([#913]) - **breaking:** Private and signed cookies now require enabling the `cookie-private` and `cookie-signed` features respectively ([#949]) - **changed:** Update to tower-http 0.3 ([#965]) [#913]: https://github.com/tokio-rs/axum/pull/913 [#949]: https://github.com/tokio-rs/axum/pull/949 [#965]: https://github.com/tokio-rs/axum/pull/965 # 0.2.1 (03. April, 2022) - **added:** Re-export `SameSite` and `Expiration` from the `cookie` crate ([#898]) - **added:** Add `PrivateCookieJar` for managing private cookies ([#900]) - **added:** Add `SpaRouter` for routing setups commonly used for single-page applications ([#904]) - **fixed:** Fix `SignedCookieJar` when using custom key types ([#899]) [#898]: https://github.com/tokio-rs/axum/pull/898 [#899]: https://github.com/tokio-rs/axum/pull/899 [#900]: https://github.com/tokio-rs/axum/pull/900 [#904]: https://github.com/tokio-rs/axum/pull/904 # 0.2.0 (31. March, 2022) - **added:** Add `TypedPath::to_uri` for converting the path into a `Uri` ([#790]) - **added:** Extractors and responses for dealing with cookies. See `extract::cookies` for more details ([#816]) - **breaking:** `CachedRejection` has been removed ([#699]) - **breaking:** `<Cached<T> as FromRequest>::Rejection` is now `T::Rejection`. ([#699]) - **breaking:** `middleware::from_fn` has been removed from axum-extra and moved into the main axum crate ([#719]) - **breaking:** `HasRoutes` has been removed. `Router::merge` now accepts `Into<Router>` ([#819]) - **breaking:** `RouterExt::with` method has been removed. Use `Router::merge` instead. It works identically ([#819]) [#699]: https://github.com/tokio-rs/axum/pull/699 [#719]: https://github.com/tokio-rs/axum/pull/719 [#790]: https://github.com/tokio-rs/axum/pull/790 [#816]: https://github.com/tokio-rs/axum/pull/816 [#819]: https://github.com/tokio-rs/axum/pull/819 # 0.1.5 (1.
March, 2022) - **added:** Add `TypedPath::to_uri` for converting the path into a `Uri` ([#790]) [#790]: https://github.com/tokio-rs/axum/pull/790 # 0.1.4 (22. February, 2022) - **fix:** Depend on the right versions of axum and axum-macros ([#782]) [#782]: https://github.com/tokio-rs/axum/pull/782 # 0.1.3 (22. February, 2022) - **added:** Add type safe routing. See `axum_extra::routing::typed` for more details ([#756]) - **fix:** Depend on tower with `default_features = false` ([#666]) - **change:** `middleware::from_fn` has been deprecated and moved into the main axum crate ([#719]) [#666]: https://github.com/tokio-rs/axum/pull/666 [#719]: https://github.com/tokio-rs/axum/pull/719 [#756]: https://github.com/tokio-rs/axum/pull/756 # 0.1.2 (13. January, 2022) - **fix:** Depend on tower with `default_features = false` ([#666]) # 0.1.1 (27. December, 2021) - Add `middleware::from_fn` for creating middleware from async functions ([#656]) - Add support for returning pretty JSON response in `response::ErasedJson` ([#662]) [#656]: https://github.com/tokio-rs/axum/pull/656 [#662]: https://github.com/tokio-rs/axum/pull/662 # 0.1.0 (02. December, 2021) - Initial release. 
[Keep a Changelog]: https://keepachangelog.com/en/1.0.0/ [Semantic Versioning]: https://semver.org/spec/v2.0.0.html axum-axum-v0.6.20/axum-extra/Cargo.toml000066400000000000000000000057421446276533600177620ustar00rootroot00000000000000[package] categories = ["asynchronous", "network-programming", "web-programming"] description = "Extra utilities for axum" edition = "2021" rust-version = "1.63" homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" name = "axum-extra" readme = "README.md" repository = "https://github.com/tokio-rs/axum" version = "0.7.7" [features] default = [] async-read-body = ["dep:tokio-util", "tokio-util?/io"] cookie = ["dep:cookie"] cookie-private = ["cookie", "cookie?/private"] cookie-signed = ["cookie", "cookie?/signed"] cookie-key-expansion = ["cookie", "cookie?/key-expansion"] erased-json = ["dep:serde_json"] form = ["dep:serde_html_form"] json-lines = [ "dep:serde_json", "dep:tokio-util", "dep:tokio-stream", "tokio-util?/io", "tokio-stream?/io-util" ] multipart = ["dep:multer"] protobuf = ["dep:prost"] query = ["dep:serde_html_form"] typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"] [dependencies] axum = { path = "../axum", version = "0.6.13", default-features = false } axum-core = { path = "../axum-core", version = "0.3.4" } bytes = "1.1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2" http-body = "0.4.4" mime = "0.3" pin-project-lite = "0.2" serde = "1.0" tokio = "1.19" tower = { version = "0.4", default_features = false, features = ["util"] } tower-layer = "0.3" tower-service = "0.3" # optional dependencies axum-macros = { path = "../axum-macros", version = "0.3.8", optional = true } cookie = { package = "cookie", version = "0.17", features = ["percent-encode"], optional = true } form_urlencoded = { version = "1.1.0", optional = true } multer = { version = "2.0.0", optional = true } 
percent-encoding = { version = "2.1", optional = true } prost = { version = "0.11", optional = true } serde_html_form = { version = "0.2.0", optional = true } serde_json = { version = "1.0.71", optional = true } tokio-stream = { version = "0.1.9", optional = true } tokio-util = { version = "0.7", optional = true } [dev-dependencies] axum = { path = "../axum", version = "0.6.0", features = ["headers"] } axum-macros = { path = "../axum-macros", version = "0.3.8", features = ["__private"] } http-body = "0.4.4" hyper = "0.14" reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.71" tokio = { version = "1.14", features = ["full"] } tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.4", features = ["map-response-body", "timeout"] } [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [package.metadata.cargo-public-api-crates] allowed = [ "axum", "axum_core", "axum_macros", "bytes", "cookie", "futures_core", "futures_util", "http", "http_body", "hyper", "prost", "serde", "tokio", "tower_layer", "tower_service", ] axum-axum-v0.6.20/axum-extra/LICENSE000066400000000000000000000020411446276533600170240ustar00rootroot00000000000000Copyright 2021 Axum Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. axum-axum-v0.6.20/axum-extra/README.md000066400000000000000000000026511446276533600173050ustar00rootroot00000000000000# axum-extra [![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum-extra/actions/workflows/CI.yml) [![Crates.io](https://img.shields.io/crates/v/axum-extra)](https://crates.io/crates/axum-extra) [![Documentation](https://docs.rs/axum-extra/badge.svg)](https://docs.rs/axum-extra) Extra utilities for [`axum`]. More information about this crate can be found in the [crate documentation][docs]. ## Safety This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. ## Minimum supported Rust version axum-extra's MSRV is 1.63. ## Getting Help You're also welcome to ask in the [Discord channel][chat] or open an [issue] with your question. ## Contributing ๐ŸŽˆ Thanks for your help improving the project! We are so happy to have you! We have a [contributing guide][contributing] to help you get involved in the `axum` project. ## License This project is licensed under the [MIT license][license]. ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in `axum` by you, shall be licensed as MIT, without any additional terms or conditions. 
[`axum`]: https://crates.io/crates/axum [chat]: https://discord.gg/tokio [contributing]: /CONTRIBUTING.md [docs]: https://docs.rs/axum-extra [license]: /axum-extra/LICENSE [issue]: https://github.com/tokio-rs/axum/issues/new axum-axum-v0.6.20/axum-extra/src/000077500000000000000000000000001446276533600166115ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/src/body/000077500000000000000000000000001446276533600175465ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/src/body/async_read_body.rs000066400000000000000000000045511446276533600232460ustar00rootroot00000000000000use axum::{ body::{self, Bytes, HttpBody, StreamBody}, http::HeaderMap, response::{IntoResponse, Response}, Error, }; use pin_project_lite::pin_project; use std::{ pin::Pin, task::{Context, Poll}, }; use tokio::io::AsyncRead; use tokio_util::io::ReaderStream; pin_project! { /// An [`HttpBody`] created from an [`AsyncRead`]. /// /// # Example /// /// `AsyncReadBody` can be used to stream the contents of a file: /// /// ```rust /// use axum::{ /// Router, /// routing::get, /// http::{StatusCode, header::CONTENT_TYPE}, /// response::{Response, IntoResponse}, /// }; /// use axum_extra::body::AsyncReadBody; /// use tokio::fs::File; /// /// async fn cargo_toml() -> Result { /// let file = File::open("Cargo.toml") /// .await /// .map_err(|err| { /// (StatusCode::NOT_FOUND, format!("File not found: {}", err)) /// })?; /// /// let headers = [(CONTENT_TYPE, "text/x-toml")]; /// let body = AsyncReadBody::new(file); /// Ok((headers, body).into_response()) /// } /// /// let app = Router::new().route("/Cargo.toml", get(cargo_toml)); /// # let _: Router = app; /// ``` #[cfg(feature = "async-read-body")] #[derive(Debug)] #[must_use] pub struct AsyncReadBody { #[pin] read: StreamBody>, } } impl AsyncReadBody { /// Create a new `AsyncReadBody`. 
pub fn new(read: R) -> Self where R: AsyncRead + Send + 'static, { Self { read: StreamBody::new(ReaderStream::new(read)), } } } impl HttpBody for AsyncReadBody where R: AsyncRead + Send + 'static, { type Data = Bytes; type Error = Error; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { self.project().read.poll_data(cx) } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Poll::Ready(Ok(None)) } } impl IntoResponse for AsyncReadBody where R: AsyncRead + Send + 'static, { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } axum-axum-v0.6.20/axum-extra/src/body/mod.rs000066400000000000000000000002441446276533600206730ustar00rootroot00000000000000//! Additional bodies. #[cfg(feature = "async-read-body")] mod async_read_body; #[cfg(feature = "async-read-body")] pub use self::async_read_body::AsyncReadBody; axum-axum-v0.6.20/axum-extra/src/either.rs000077500000000000000000000204311446276533600204420ustar00rootroot00000000000000//! `Either*` types for combining extractors or responses into a single type. //! //! # As an extractor //! //! ``` //! use axum_extra::either::Either3; //! use axum::{ //! body::Bytes, //! Router, //! async_trait, //! routing::get, //! extract::FromRequestParts, //! }; //! //! // extractors for checking permissions //! struct AdminPermissions {} //! //! #[async_trait] //! impl FromRequestParts for AdminPermissions //! where //! S: Send + Sync, //! { //! // check for admin permissions... //! # type Rejection = (); //! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result { //! # todo!() //! # } //! } //! //! struct User {} //! //! #[async_trait] //! impl FromRequestParts for User //! where //! S: Send + Sync, //! { //! // check for a logged in user... //! # type Rejection = (); //! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result { //! # todo!() //! # } //! } //! //! async fn handler( //! 
body: Either3, //! ) { //! match body { //! Either3::E1(admin) => { /* ... */ } //! Either3::E2(user) => { /* ... */ } //! Either3::E3(guest) => { /* ... */ } //! } //! } //! # //! # let _: axum::routing::MethodRouter = axum::routing::get(handler); //! ``` //! //! Note that if all the inner extractors reject the request, the rejection from the last //! extractor will be returned. For the example above that would be [`BytesRejection`]. //! //! # As a response //! //! ``` //! use axum_extra::either::Either3; //! use axum::{Json, http::StatusCode, response::IntoResponse}; //! use serde_json::{Value, json}; //! //! async fn handler() -> Either3, &'static str, StatusCode> { //! if something() { //! Either3::E1(Json(json!({ "data": "..." }))) //! } else if something_else() { //! Either3::E2("foobar") //! } else { //! Either3::E3(StatusCode::NOT_FOUND) //! } //! } //! //! fn something() -> bool { //! // ... //! # false //! } //! //! fn something_else() -> bool { //! // ... //! # false //! } //! # //! # let _: axum::routing::MethodRouter = axum::routing::get(handler); //! ``` //! //! The general recommendation is to use [`IntoResponse::into_response`] to return different response //! types, but if you need to preserve the exact type then `Either*` works as well. //! //! [`BytesRejection`]: axum::extract::rejection::BytesRejection //! [`IntoResponse::into_response`]: https://docs.rs/axum/0.5/axum/response/index.html#returning-different-response-types use std::task::{Context, Poll}; use axum::{ async_trait, extract::FromRequestParts, response::{IntoResponse, Response}, }; use http::request::Parts; use tower_layer::Layer; use tower_service::Service; /// Combines two extractors or responses into a single type. /// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), } /// Combines three extractors or responses into a single type. 
/// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either3 { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), #[allow(missing_docs)] E3(E3), } /// Combines four extractors or responses into a single type. /// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either4 { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), #[allow(missing_docs)] E3(E3), #[allow(missing_docs)] E4(E4), } /// Combines five extractors or responses into a single type. /// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either5 { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), #[allow(missing_docs)] E3(E3), #[allow(missing_docs)] E4(E4), #[allow(missing_docs)] E5(E5), } /// Combines six extractors or responses into a single type. /// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either6 { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), #[allow(missing_docs)] E3(E3), #[allow(missing_docs)] E4(E4), #[allow(missing_docs)] E5(E5), #[allow(missing_docs)] E6(E6), } /// Combines seven extractors or responses into a single type. /// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either7 { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), #[allow(missing_docs)] E3(E3), #[allow(missing_docs)] E4(E4), #[allow(missing_docs)] E5(E5), #[allow(missing_docs)] E6(E6), #[allow(missing_docs)] E7(E7), } /// Combines eight extractors or responses into a single type. /// /// See the [module docs](self) for examples. #[derive(Debug, Clone)] #[must_use] pub enum Either8 { #[allow(missing_docs)] E1(E1), #[allow(missing_docs)] E2(E2), #[allow(missing_docs)] E3(E3), #[allow(missing_docs)] E4(E4), #[allow(missing_docs)] E5(E5), #[allow(missing_docs)] E6(E6), #[allow(missing_docs)] E7(E7), #[allow(missing_docs)] E8(E8), } macro_rules! 
impl_traits_for_either { ( $either:ident => [$($ident:ident),* $(,)?], $last:ident $(,)? ) => { #[async_trait] impl FromRequestParts for $either<$($ident),*, $last> where $($ident: FromRequestParts),*, $last: FromRequestParts, S: Send + Sync, { type Rejection = $last::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { $( if let Ok(value) = FromRequestParts::from_request_parts(parts, state).await { return Ok(Self::$ident(value)); } )* FromRequestParts::from_request_parts(parts, state).await.map(Self::$last) } } impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last> where $($ident: IntoResponse),*, $last: IntoResponse, { fn into_response(self) -> Response { match self { $( Self::$ident(value) => value.into_response(), )* Self::$last(value) => value.into_response(), } } } }; } impl_traits_for_either!(Either => [E1], E2); impl_traits_for_either!(Either3 => [E1, E2], E3); impl_traits_for_either!(Either4 => [E1, E2, E3], E4); impl_traits_for_either!(Either5 => [E1, E2, E3, E4], E5); impl_traits_for_either!(Either6 => [E1, E2, E3, E4, E5], E6); impl_traits_for_either!(Either7 => [E1, E2, E3, E4, E5, E6], E7); impl_traits_for_either!(Either8 => [E1, E2, E3, E4, E5, E6, E7], E8); impl Layer for Either where E1: Layer, E2: Layer, { type Service = Either; fn layer(&self, inner: S) -> Self::Service { match self { Either::E1(layer) => Either::E1(layer.layer(inner)), Either::E2(layer) => Either::E2(layer.layer(inner)), } } } impl Service for Either where E1: Service, E2: Service, { type Response = E1::Response; type Error = E1::Error; type Future = futures_util::future::Either; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match self { Either::E1(inner) => inner.poll_ready(cx), Either::E2(inner) => inner.poll_ready(cx), } } fn call(&mut self, req: R) -> Self::Future { match self { Either::E1(inner) => futures_util::future::Either::Left(inner.call(req)), Either::E2(inner) => 
futures_util::future::Either::Right(inner.call(req)), } } } axum-axum-v0.6.20/axum-extra/src/extract/000077500000000000000000000000001446276533600202635ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/src/extract/cached.rs000066400000000000000000000114751446276533600220500ustar00rootroot00000000000000use axum::{ async_trait, extract::{Extension, FromRequestParts}, }; use http::request::Parts; /// Cache results of other extractors. /// /// `Cached` wraps another extractor and caches its result in [request extensions]. /// /// This is useful if you have a tree of extractors that share common sub-extractors that /// you only want to run once, perhaps because they're expensive. /// /// The cache purely type based so you can only cache one value of each type. The cache is also /// local to the current request and not reused across requests. /// /// # Example /// /// ```rust /// use axum_extra::extract::Cached; /// use axum::{ /// async_trait, /// extract::FromRequestParts, /// body::BoxBody, /// response::{IntoResponse, Response}, /// http::{StatusCode, request::Parts}, /// }; /// /// #[derive(Clone)] /// struct Session { /* ... */ } /// /// #[async_trait] /// impl FromRequestParts for Session /// where /// S: Send + Sync, /// { /// type Rejection = (StatusCode, String); /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // load session... /// # unimplemented!() /// } /// } /// /// struct CurrentUser { /* ... 
*/ } /// /// #[async_trait] /// impl FromRequestParts for CurrentUser /// where /// S: Send + Sync, /// { /// type Rejection = Response; /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // loading a `CurrentUser` requires first loading the `Session` /// // /// // by using `Cached` we avoid extracting the session more than /// // once, in case other extractors for the same request also loads the session /// let session: Session = Cached::::from_request_parts(parts, state) /// .await /// .map_err(|err| err.into_response())? /// .0; /// /// // load user from session... /// # unimplemented!() /// } /// } /// /// // handler that extracts the current user and the session /// // /// // the session will only be loaded once, even though `CurrentUser` /// // also loads it /// async fn handler( /// current_user: CurrentUser, /// // we have to use `Cached` here otherwise the /// // cached session would not be used /// Cached(session): Cached, /// ) { /// // ... /// } /// ``` /// /// [request extensions]: http::Extensions #[derive(Debug, Clone, Default)] pub struct Cached(pub T); #[derive(Clone)] struct CachedEntry(T); #[async_trait] impl FromRequestParts for Cached where S: Send + Sync, T: FromRequestParts + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { match Extension::>::from_request_parts(parts, state).await { Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Err(_) => { let value = T::from_request_parts(parts, state).await?; parts.extensions.insert(CachedEntry(value.clone())); Ok(Self(value)) } } } } axum_core::__impl_deref!(Cached); #[cfg(test)] mod tests { use super::*; use axum::{extract::FromRequestParts, http::Request, routing::get, Router}; use http::request::Parts; use std::{ convert::Infallible, sync::atomic::{AtomicU32, Ordering}, time::Instant, }; #[tokio::test] async fn works() { static COUNTER: AtomicU32 = AtomicU32::new(0); 
#[derive(Clone, Debug, PartialEq, Eq)] struct Extractor(Instant); #[async_trait] impl FromRequestParts for Extractor where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts( _parts: &mut Parts, _state: &S, ) -> Result { COUNTER.fetch_add(1, Ordering::SeqCst); Ok(Self(Instant::now())) } } let (mut parts, _) = Request::new(()).into_parts(); let first = Cached::::from_request_parts(&mut parts, &()) .await .unwrap() .0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); let second = Cached::::from_request_parts(&mut parts, &()) .await .unwrap() .0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); assert_eq!(first, second); } // Not a #[test], we just want to know this compiles async fn _last_handler_argument() { async fn handler(_: http::Method, _: Cached) {} let _r: Router = Router::new().route("/", get(handler)); } } axum-axum-v0.6.20/axum-extra/src/extract/cookie/000077500000000000000000000000001446276533600215345ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/src/extract/cookie/mod.rs000066400000000000000000000262441446276533600226710ustar00rootroot00000000000000//! Cookie parsing and cookie jar management. //! //! See [`CookieJar`], [`SignedCookieJar`], and [`PrivateCookieJar`] for more details. use axum::{ async_trait, extract::FromRequestParts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use http::{ header::{COOKIE, SET_COOKIE}, request::Parts, HeaderMap, }; use std::convert::Infallible; #[cfg(feature = "cookie-private")] mod private; #[cfg(feature = "cookie-signed")] mod signed; #[cfg(feature = "cookie-private")] pub use self::private::PrivateCookieJar; #[cfg(feature = "cookie-signed")] pub use self::signed::SignedCookieJar; pub use cookie::{Cookie, Expiration, SameSite}; #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))] pub use cookie::Key; /// Extractor that grabs cookies from the request and manages the jar. 
/// /// Note that methods like [`CookieJar::add`], [`CookieJar::remove`], etc updates the [`CookieJar`] /// and returns it. This value _must_ be returned from the handler as part of the response for the /// changes to be propagated. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// routing::{post, get}, /// extract::TypedHeader, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, /// }; /// use axum_extra::extract::cookie::{CookieJar, Cookie}; /// /// async fn create_session( /// TypedHeader(auth): TypedHeader>, /// jar: CookieJar, /// ) -> Result<(CookieJar, Redirect), StatusCode> { /// if let Some(session_id) = authorize_and_create_session(auth.token()).await { /// Ok(( /// // the updated jar must be returned for the changes /// // to be included in the response /// jar.add(Cookie::new("session_id", session_id)), /// Redirect::to("/me"), /// )) /// } else { /// Err(StatusCode::UNAUTHORIZED) /// } /// } /// /// async fn me(jar: CookieJar) -> Result<(), StatusCode> { /// if let Some(session_id) = jar.get("session_id") { /// // fetch and render user... /// # Ok(()) /// } else { /// Err(StatusCode::UNAUTHORIZED) /// } /// } /// /// async fn authorize_and_create_session(token: &str) -> Option { /// // authorize the user and create a session... 
/// # todo!() /// } /// /// let app = Router::new() /// .route("/sessions", post(create_session)) /// .route("/me", get(me)); /// # let app: Router = app; /// ``` #[derive(Debug, Default, Clone)] pub struct CookieJar { jar: cookie::CookieJar, } #[async_trait] impl FromRequestParts for CookieJar where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { Ok(Self::from_headers(&parts.headers)) } } fn cookies_from_request(headers: &HeaderMap) -> impl Iterator> + '_ { headers .get_all(COOKIE) .into_iter() .filter_map(|value| value.to_str().ok()) .flat_map(|value| value.split(';')) .filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok()) } impl CookieJar { /// Create a new `CookieJar` from a map of request headers. /// /// The cookies in `headers` will be added to the jar. /// /// This is intended to be used in middleware and other places where it might be difficult to /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn from_headers(headers: &HeaderMap) -> Self { let mut jar = cookie::CookieJar::new(); for cookie in cookies_from_request(headers) { jar.add_original(cookie); } Self { jar } } /// Create a new empty `CookieJar`. /// /// This is inteded to be used in middleware and other places where it might be difficult to /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`]. /// /// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for /// CookieJar`. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn new() -> Self { Self::default() } /// Get a cookie from the jar. 
/// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::CookieJar; /// use axum::response::IntoResponse; /// /// async fn handle(jar: CookieJar) { /// let value: Option = jar /// .get("foo") /// .map(|cookie| cookie.value().to_owned()); /// } /// ``` pub fn get(&self, name: &str) -> Option<&Cookie<'static>> { self.jar.get(name) } /// Remove a cookie from the jar. /// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::{CookieJar, Cookie}; /// use axum::response::IntoResponse; /// /// async fn handle(jar: CookieJar) -> CookieJar { /// jar.remove(Cookie::named("foo")) /// } /// ``` #[must_use] pub fn remove(mut self, cookie: Cookie<'static>) -> Self { self.jar.remove(cookie); self } /// Add a cookie to the jar. /// /// The value will automatically be percent-encoded. /// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::{CookieJar, Cookie}; /// use axum::response::IntoResponse; /// /// async fn handle(jar: CookieJar) -> CookieJar { /// jar.add(Cookie::new("foo", "bar")) /// } /// ``` #[must_use] #[allow(clippy::should_implement_trait)] pub fn add(mut self, cookie: Cookie<'static>) -> Self { self.jar.add(cookie); self } /// Get an iterator over all cookies in the jar. pub fn iter(&self) -> impl Iterator> { self.jar.iter() } } impl IntoResponseParts for CookieJar { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { set_cookies(self.jar, res.headers_mut()); Ok(res) } } impl IntoResponse for CookieJar { fn into_response(self) -> Response { (self, ()).into_response() } } fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) { for cookie in jar.delta() { if let Ok(header_value) = cookie.encoded().to_string().parse() { headers.append(SET_COOKIE, header_value); } } // we don't need to call `jar.reset_delta()` because `into_response_parts` consumes the cookie // jar so it cannot be called multiple times. 
} #[cfg(test)] mod tests { use super::*; use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router}; use tower::ServiceExt; macro_rules! cookie_test { ($name:ident, $jar:ty) => { #[tokio::test] async fn $name() { async fn set_cookie(jar: $jar) -> impl IntoResponse { jar.add(Cookie::new("key", "value")) } async fn get_cookie(jar: $jar) -> impl IntoResponse { jar.get("key").unwrap().value().to_owned() } async fn remove_cookie(jar: $jar) -> impl IntoResponse { jar.remove(Cookie::named("key")) } let state = AppState { key: Key::generate(), custom_key: CustomKey(Key::generate()), }; let app = Router::<_, Body>::new() .route("/set", get(set_cookie)) .route("/get", get(get_cookie)) .route("/remove", get(remove_cookie)) .with_state(state); let res = app .clone() .oneshot(Request::builder().uri("/set").body(Body::empty()).unwrap()) .await .unwrap(); let cookie_value = res.headers()["set-cookie"].to_str().unwrap(); let res = app .clone() .oneshot( Request::builder() .uri("/get") .header("cookie", cookie_value) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let body = body_text(res).await; assert_eq!(body, "value"); let res = app .clone() .oneshot( Request::builder() .uri("/remove") .header("cookie", cookie_value) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert!(res.headers()["set-cookie"] .to_str() .unwrap() .contains("key=;")); } }; } cookie_test!(plaintext_cookies, CookieJar); #[cfg(feature = "cookie-signed")] cookie_test!(signed_cookies, SignedCookieJar); #[cfg(feature = "cookie-signed")] cookie_test!(signed_cookies_with_custom_key, SignedCookieJar); #[cfg(feature = "cookie-private")] cookie_test!(private_cookies, PrivateCookieJar); #[cfg(feature = "cookie-private")] cookie_test!(private_cookies_with_custom_key, PrivateCookieJar); #[derive(Clone)] struct AppState { key: Key, custom_key: CustomKey, } impl FromRef for Key { fn from_ref(state: &AppState) -> Key { state.key.clone() } } impl FromRef for CustomKey { fn from_ref(state: 
&AppState) -> CustomKey { state.custom_key.clone() } } #[derive(Clone)] struct CustomKey(Key); impl From for Key { fn from(custom: CustomKey) -> Self { custom.0 } } #[cfg(feature = "cookie-signed")] #[tokio::test] async fn signed_cannot_access_invalid_cookies() { async fn get_cookie(jar: SignedCookieJar) -> impl IntoResponse { format!("{:?}", jar.get("key")) } let state = AppState { key: Key::generate(), custom_key: CustomKey(Key::generate()), }; let app = Router::<_, Body>::new() .route("/get", get(get_cookie)) .with_state(state); let res = app .clone() .oneshot( Request::builder() .uri("/get") .header("cookie", "key=value") .body(Body::empty()) .unwrap(), ) .await .unwrap(); let body = body_text(res).await; assert_eq!(body, "None"); } async fn body_text(body: B) -> String where B: axum::body::HttpBody, B::Error: std::fmt::Debug, { let bytes = hyper::body::to_bytes(body).await.unwrap(); String::from_utf8(bytes.to_vec()).unwrap() } } axum-axum-v0.6.20/axum-extra/src/extract/cookie/private.rs000066400000000000000000000215271446276533600235630ustar00rootroot00000000000000use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ async_trait, extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::PrivateJar; use http::{request::Parts, HeaderMap}; use std::{convert::Infallible, fmt, marker::PhantomData}; /// Extractor that grabs private cookies from the request and manages the jar. /// /// All cookies will be private and encrypted with a [`Key`]. This makes it suitable for storing /// private data. /// /// Note that methods like [`PrivateCookieJar::add`], [`PrivateCookieJar::remove`], etc updates the /// [`PrivateCookieJar`] and returns it. This value _must_ be returned from the handler as part of /// the response for the changes to be propagated. 
/// /// # Example /// /// ```rust /// use axum::{ /// Router, /// routing::{post, get}, /// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, /// }; /// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie, Key}; /// /// async fn set_secret( /// jar: PrivateCookieJar, /// ) -> (PrivateCookieJar, Redirect) { /// let updated_jar = jar.add(Cookie::new("secret", "secret-data")); /// (updated_jar, Redirect::to("/get")) /// } /// /// async fn get_secret(jar: PrivateCookieJar) { /// if let Some(data) = jar.get("secret") { /// // ... /// } /// } /// /// // our application state /// #[derive(Clone)] /// struct AppState { /// // that holds the key used to sign cookies /// key: Key, /// } /// /// // this impl tells `SignedCookieJar` how to access the key from our state /// impl FromRef for Key { /// fn from_ref(state: &AppState) -> Self { /// state.key.clone() /// } /// } /// /// let state = AppState { /// // Generate a secure key /// // /// // You probably don't wanna generate a new one each time the app starts though /// key: Key::generate(), /// }; /// /// let app = Router::new() /// .route("/set", post(set_secret)) /// .route("/get", get(get_secret)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` /// /// If you have been using `Arc` you cannot implement `FromRef> for Key`. 
/// You can use a new type instead: /// /// ```rust /// # use axum::extract::FromRef; /// # use axum_extra::extract::cookie::{PrivateCookieJar, Cookie, Key}; /// use std::sync::Arc; /// use std::ops::Deref; /// /// #[derive(Clone)] /// struct AppState(Arc); /// /// // deref so you can still access the inner fields easily /// impl Deref for AppState { /// type Target = InnerState; /// /// fn deref(&self) -> &Self::Target { /// &self.0 /// } /// } /// /// struct InnerState { /// key: Key /// } /// /// impl FromRef for Key { /// fn from_ref(state: &AppState) -> Self { /// state.0.key.clone() /// } /// } /// ``` pub struct PrivateCookieJar { jar: cookie::CookieJar, key: Key, // The key used to extract the key. Allows users to use multiple keys for different // jars. Maybe a library wants its own key. _marker: PhantomData, } impl fmt::Debug for PrivateCookieJar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PrivateCookieJar") .field("jar", &self.jar) .field("key", &"REDACTED") .finish() } } #[async_trait] impl FromRequestParts for PrivateCookieJar where S: Send + Sync, K: FromRef + Into, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let k = K::from_ref(state); let key = k.into(); let PrivateCookieJar { jar, key, _marker: _, } = PrivateCookieJar::from_headers(&parts.headers, key); Ok(PrivateCookieJar { jar, key, _marker: PhantomData, }) } } impl PrivateCookieJar { /// Create a new `PrivateCookieJar` from a map of request headers. /// /// The valid cookies in `headers` will be added to the jar. /// /// This is intended to be used in middleware and other where places it might be difficult to /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`]. 
/// /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { let mut jar = cookie::CookieJar::new(); let mut private_jar = jar.private_mut(&key); for cookie in cookies_from_request(headers) { if let Some(cookie) = private_jar.decrypt(cookie) { private_jar.add_original(cookie); } } Self { jar, key, _marker: PhantomData, } } /// Create a new empty `PrivateCookieJarIter`. /// /// This is inteded to be used in middleware and other places where it might be difficult to /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn new(key: Key) -> Self { Self { jar: Default::default(), key, _marker: PhantomData, } } } impl PrivateCookieJar { /// Get a cookie from the jar. /// /// If the cookie exists and can be decrypted then it is returned in plaintext. /// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::PrivateCookieJar; /// use axum::response::IntoResponse; /// /// async fn handle(jar: PrivateCookieJar) { /// let value: Option = jar /// .get("foo") /// .map(|cookie| cookie.value().to_owned()); /// } /// ``` pub fn get(&self, name: &str) -> Option> { self.private_jar().get(name) } /// Remove a cookie from the jar. /// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie}; /// use axum::response::IntoResponse; /// /// async fn handle(jar: PrivateCookieJar) -> PrivateCookieJar { /// jar.remove(Cookie::named("foo")) /// } /// ``` #[must_use] pub fn remove(mut self, cookie: Cookie<'static>) -> Self { self.private_jar_mut().remove(cookie); self } /// Add a cookie to the jar. /// /// The value will automatically be percent-encoded. 
/// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie}; /// use axum::response::IntoResponse; /// /// async fn handle(jar: PrivateCookieJar) -> PrivateCookieJar { /// jar.add(Cookie::new("foo", "bar")) /// } /// ``` #[must_use] #[allow(clippy::should_implement_trait)] pub fn add(mut self, cookie: Cookie<'static>) -> Self { self.private_jar_mut().add(cookie); self } /// Authenticates and decrypts `cookie`, returning the plaintext version if decryption succeeds /// or `None` otherwise. pub fn decrypt(&self, cookie: Cookie<'static>) -> Option> { self.private_jar().decrypt(cookie) } /// Get an iterator over all cookies in the jar. /// /// Only cookies with valid authenticity and integrity are yielded by the iterator. pub fn iter(&self) -> impl Iterator> + '_ { PrivateCookieJarIter { jar: self, iter: self.jar.iter(), } } fn private_jar(&self) -> PrivateJar<&'_ cookie::CookieJar> { self.jar.private(&self.key) } fn private_jar_mut(&mut self) -> PrivateJar<&'_ mut cookie::CookieJar> { self.jar.private_mut(&self.key) } } impl IntoResponseParts for PrivateCookieJar { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { set_cookies(self.jar, res.headers_mut()); Ok(res) } } impl IntoResponse for PrivateCookieJar { fn into_response(self) -> Response { (self, ()).into_response() } } struct PrivateCookieJarIter<'a, K> { jar: &'a PrivateCookieJar, iter: cookie::Iter<'a>, } impl<'a, K> Iterator for PrivateCookieJarIter<'a, K> { type Item = Cookie<'static>; fn next(&mut self) -> Option { loop { let cookie = self.iter.next()?; if let Some(cookie) = self.jar.get(cookie.name()) { return Some(cookie); } } } } impl Clone for PrivateCookieJar { fn clone(&self) -> Self { Self { jar: self.jar.clone(), key: self.key.clone(), _marker: self._marker, } } } axum-axum-v0.6.20/axum-extra/src/extract/cookie/signed.rs000066400000000000000000000230321446276533600233530ustar00rootroot00000000000000use 
super::{cookies_from_request, set_cookies}; use axum::{ async_trait, extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::SignedJar; use cookie::{Cookie, Key}; use http::{request::Parts, HeaderMap}; use std::{convert::Infallible, fmt, marker::PhantomData}; /// Extractor that grabs signed cookies from the request and manages the jar. /// /// All cookies will be signed and verified with a [`Key`]. Do not use this to store private data /// as the values are still transmitted in plaintext. /// /// Note that methods like [`SignedCookieJar::add`], [`SignedCookieJar::remove`], etc updates the /// [`SignedCookieJar`] and returns it. This value _must_ be returned from the handler as part of /// the response for the changes to be propagated. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// routing::{post, get}, /// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, /// }; /// use axum_extra::extract::cookie::{SignedCookieJar, Cookie, Key}; /// /// async fn create_session( /// TypedHeader(auth): TypedHeader>, /// jar: SignedCookieJar, /// ) -> Result<(SignedCookieJar, Redirect), StatusCode> { /// if let Some(session_id) = authorize_and_create_session(auth.token()).await { /// Ok(( /// // the updated jar must be returned for the changes /// // to be included in the response /// jar.add(Cookie::new("session_id", session_id)), /// Redirect::to("/me"), /// )) /// } else { /// Err(StatusCode::UNAUTHORIZED) /// } /// } /// /// async fn me(jar: SignedCookieJar) -> Result<(), StatusCode> { /// if let Some(session_id) = jar.get("session_id") { /// // fetch and render user... /// # Ok(()) /// } else { /// Err(StatusCode::UNAUTHORIZED) /// } /// } /// /// async fn authorize_and_create_session(token: &str) -> Option { /// // authorize the user and create a session... 
/// # todo!() /// } /// /// // our application state /// #[derive(Clone)] /// struct AppState { /// // that holds the key used to sign cookies /// key: Key, /// } /// /// // this impl tells `SignedCookieJar` how to access the key from our state /// impl FromRef for Key { /// fn from_ref(state: &AppState) -> Self { /// state.key.clone() /// } /// } /// /// let state = AppState { /// // Generate a secure key /// // /// // You probably don't wanna generate a new one each time the app starts though /// key: Key::generate(), /// }; /// /// let app = Router::new() /// .route("/sessions", post(create_session)) /// .route("/me", get(me)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` /// If you have been using `Arc` you cannot implement `FromRef> for Key`. /// You can use a new type instead: /// /// ```rust /// # use axum::extract::FromRef; /// # use axum_extra::extract::cookie::{PrivateCookieJar, Cookie, Key}; /// use std::sync::Arc; /// use std::ops::Deref; /// /// #[derive(Clone)] /// struct AppState(Arc); /// /// // deref so you can still access the inner fields easily /// impl Deref for AppState { /// type Target = InnerState; /// /// fn deref(&self) -> &Self::Target { /// &*self.0 /// } /// } /// /// struct InnerState { /// key: Key /// } /// /// impl FromRef for Key { /// fn from_ref(state: &AppState) -> Self { /// state.0.key.clone() /// } /// } /// ``` pub struct SignedCookieJar { jar: cookie::CookieJar, key: Key, // The key used to extract the key. Allows users to use multiple keys for different // jars. Maybe a library wants its own key. 
_marker: PhantomData, } impl fmt::Debug for SignedCookieJar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SignedCookieJar") .field("jar", &self.jar) .field("key", &"REDACTED") .finish() } } #[async_trait] impl FromRequestParts for SignedCookieJar where S: Send + Sync, K: FromRef + Into, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let k = K::from_ref(state); let key = k.into(); let SignedCookieJar { jar, key, _marker: _, } = SignedCookieJar::from_headers(&parts.headers, key); Ok(SignedCookieJar { jar, key, _marker: PhantomData, }) } } impl SignedCookieJar { /// Create a new `SignedCookieJar` from a map of request headers. /// /// The valid cookies in `headers` will be added to the jar. /// /// This is intended to be used in middleware and other places where it might be difficult to /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { let mut jar = cookie::CookieJar::new(); let mut signed_jar = jar.signed_mut(&key); for cookie in cookies_from_request(headers) { if let Some(cookie) = signed_jar.verify(cookie) { signed_jar.add_original(cookie); } } Self { jar, key, _marker: PhantomData, } } /// Create a new empty `SignedCookieJar`. /// /// This is inteded to be used in middleware and other places where it might be difficult to /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`]. /// /// [`FromRequestParts`]: axum::extract::FromRequestParts pub fn new(key: Key) -> Self { Self { jar: Default::default(), key, _marker: PhantomData, } } } impl SignedCookieJar { /// Get a cookie from the jar. /// /// If the cookie exists and its authenticity and integrity can be verified then it is returned /// in plaintext. 
/// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::SignedCookieJar; /// use axum::response::IntoResponse; /// /// async fn handle(jar: SignedCookieJar) { /// let value: Option = jar /// .get("foo") /// .map(|cookie| cookie.value().to_owned()); /// } /// ``` pub fn get(&self, name: &str) -> Option> { self.signed_jar().get(name) } /// Remove a cookie from the jar. /// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::{SignedCookieJar, Cookie}; /// use axum::response::IntoResponse; /// /// async fn handle(jar: SignedCookieJar) -> SignedCookieJar { /// jar.remove(Cookie::named("foo")) /// } /// ``` #[must_use] pub fn remove(mut self, cookie: Cookie<'static>) -> Self { self.signed_jar_mut().remove(cookie); self } /// Add a cookie to the jar. /// /// The value will automatically be percent-encoded. /// /// # Example /// /// ```rust /// use axum_extra::extract::cookie::{SignedCookieJar, Cookie}; /// use axum::response::IntoResponse; /// /// async fn handle(jar: SignedCookieJar) -> SignedCookieJar { /// jar.add(Cookie::new("foo", "bar")) /// } /// ``` #[must_use] #[allow(clippy::should_implement_trait)] pub fn add(mut self, cookie: Cookie<'static>) -> Self { self.signed_jar_mut().add(cookie); self } /// Verifies the authenticity and integrity of `cookie`, returning the plaintext version if /// verification succeeds or `None` otherwise. pub fn verify(&self, cookie: Cookie<'static>) -> Option> { self.signed_jar().verify(cookie) } /// Get an iterator over all cookies in the jar. /// /// Only cookies with valid authenticity and integrity are yielded by the iterator. 
pub fn iter(&self) -> impl Iterator> + '_ { SignedCookieJarIter { jar: self, iter: self.jar.iter(), } } fn signed_jar(&self) -> SignedJar<&'_ cookie::CookieJar> { self.jar.signed(&self.key) } fn signed_jar_mut(&mut self) -> SignedJar<&'_ mut cookie::CookieJar> { self.jar.signed_mut(&self.key) } } impl IntoResponseParts for SignedCookieJar { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { set_cookies(self.jar, res.headers_mut()); Ok(res) } } impl IntoResponse for SignedCookieJar { fn into_response(self) -> Response { (self, ()).into_response() } } struct SignedCookieJarIter<'a, K> { jar: &'a SignedCookieJar, iter: cookie::Iter<'a>, } impl<'a, K> Iterator for SignedCookieJarIter<'a, K> { type Item = Cookie<'static>; fn next(&mut self) -> Option { loop { let cookie = self.iter.next()?; if let Some(cookie) = self.jar.get(cookie.name()) { return Some(cookie); } } } } impl Clone for SignedCookieJar { fn clone(&self) -> Self { Self { jar: self.jar.clone(), key: self.key.clone(), _marker: self._marker, } } } axum-axum-v0.6.20/axum-extra/src/extract/form.rs000066400000000000000000000077541446276533600216110ustar00rootroot00000000000000use axum::{ async_trait, body::HttpBody, extract::{rejection::RawFormRejection, FromRequest, RawForm}, response::{IntoResponse, Response}, BoxError, Error, RequestExt, }; use http::{Request, StatusCode}; use serde::de::DeserializeOwned; use std::fmt; /// Extractor that deserializes `application/x-www-form-urlencoded` requests /// into some type. /// /// `T` is expected to implement [`serde::Deserialize`]. /// /// # Differences from `axum::extract::Form` /// /// This extractor uses [`serde_html_form`] under-the-hood which supports multi-value items. These /// are sent by multiple `` attributes of the same name (e.g. checkboxes) and `` attributes of the same name (e.g. 
checkboxes) and `| async move { input.foo.to_owned() }), ); let input = Input { foo: "bar".to_owned(), }; let client = TestClient::new(app); let res = client.post("/").body(input.encode_to_vec()).send().await; let body = res.text().await; assert_eq!(body, "bar"); } #[tokio::test] async fn prost_decode_error() { #[derive(prost::Message)] struct Input { #[prost(string, tag = "1")] foo: String, } #[derive(prost::Message)] struct Expected { #[prost(int32, tag = "1")] test: i32, } let app = Router::new().route("/", post(|_: Protobuf| async {})); let input = Input { foo: "bar".to_owned(), }; let client = TestClient::new(app); let res = client.post("/").body(input.encode_to_vec()).send().await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); } #[tokio::test] async fn encode_body() { #[derive(prost::Message)] struct Input { #[prost(string, tag = "1")] foo: String, } #[derive(prost::Message)] struct Output { #[prost(string, tag = "1")] result: String, } let app = Router::new().route( "/", post(|input: Protobuf| async move { let output = Output { result: input.foo.to_owned(), }; Protobuf(output) }), ); let input = Input { foo: "bar".to_owned(), }; let client = TestClient::new(app); let res = client.post("/").body(input.encode_to_vec()).send().await; assert_eq!( res.headers()["content-type"], mime::APPLICATION_OCTET_STREAM.as_ref() ); let body = res.bytes().await; let output = Output::decode(body).unwrap(); assert_eq!(output.result, "bar"); } } axum-axum-v0.6.20/axum-extra/src/response/000077500000000000000000000000001446276533600204475ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/src/response/erased_json.rs000066400000000000000000000042431446276533600233140ustar00rootroot00000000000000use std::sync::Arc; use axum::{ http::{header, HeaderValue, StatusCode}, response::{IntoResponse, Response}, }; use bytes::{BufMut, Bytes, BytesMut}; use serde::Serialize; /// A response type that holds a JSON in serialized form. 
/// /// This allows returning a borrowing type from a handler, or returning different response /// types as JSON from different branches inside a handler. /// /// # Example /// /// ```rust /// # use axum::{response::IntoResponse}; /// # use axum_extra::response::ErasedJson; /// async fn handler() -> ErasedJson { /// # let condition = true; /// # let foo = (); /// # let bar = vec![()]; /// // ... /// /// if condition { /// ErasedJson::new(&foo) /// } else { /// ErasedJson::new(&bar) /// } /// } /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "erased-json")))] #[derive(Clone, Debug)] #[must_use] pub struct ErasedJson(Result>); impl ErasedJson { /// Create an `ErasedJson` by serializing a value with the compact formatter. pub fn new(val: T) -> Self { let mut bytes = BytesMut::with_capacity(128); let result = match serde_json::to_writer((&mut bytes).writer(), &val) { Ok(()) => Ok(bytes.freeze()), Err(e) => Err(Arc::new(e)), }; Self(result) } /// Create an `ErasedJson` by serializing a value with the pretty formatter. pub fn pretty(val: T) -> Self { let mut bytes = BytesMut::with_capacity(128); let result = match serde_json::to_writer_pretty((&mut bytes).writer(), &val) { Ok(()) => Ok(bytes.freeze()), Err(e) => Err(Arc::new(e)), }; Self(result) } } impl IntoResponse for ErasedJson { fn into_response(self) -> Response { match self.0 { Ok(bytes) => ( [( header::CONTENT_TYPE, HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), )], bytes, ) .into_response(), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(), } } } axum-axum-v0.6.20/axum-extra/src/response/mod.rs000066400000000000000000000036461446276533600216050ustar00rootroot00000000000000//! Additional types for generating responses. #[cfg(feature = "erased-json")] mod erased_json; #[cfg(feature = "erased-json")] pub use erased_json::ErasedJson; #[cfg(feature = "json-lines")] #[doc(no_inline)] pub use crate::json_lines::JsonLines; macro_rules! 
mime_response { ( $(#[$m:meta])* $ident:ident, $mime:ident, ) => { mime_response! { $(#[$m])* $ident, mime::$mime.as_ref(), } }; ( $(#[$m:meta])* $ident:ident, $mime:expr, ) => { $(#[$m])* #[derive(Clone, Copy, Debug)] #[must_use] pub struct $ident(pub T); impl axum::response::IntoResponse for $ident where T: axum::response::IntoResponse, { fn into_response(self) -> axum::response::Response { ( [( http::header::CONTENT_TYPE, http::HeaderValue::from_static($mime), )], self.0, ) .into_response() } } impl From for $ident { fn from(inner: T) -> Self { Self(inner) } } }; } mime_response! { /// A HTML response. /// /// Will automatically get `Content-Type: text/html; charset=utf-8`. Html, TEXT_HTML_UTF_8, } mime_response! { /// A JavaScript response. /// /// Will automatically get `Content-Type: application/javascript; charset=utf-8`. JavaScript, APPLICATION_JAVASCRIPT_UTF_8, } mime_response! { /// A CSS response. /// /// Will automatically get `Content-Type: text/css; charset=utf-8`. Css, TEXT_CSS_UTF_8, } mime_response! { /// A WASM response. /// /// Will automatically get `Content-Type: application/wasm`. Wasm, "application/wasm", } axum-axum-v0.6.20/axum-extra/src/routing/000077500000000000000000000000001446276533600203005ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/src/routing/mod.rs000066400000000000000000000321351446276533600214310ustar00rootroot00000000000000//! Additional types for defining routes. 
use axum::{ http::Request, response::{IntoResponse, Redirect, Response}, routing::{any, MethodRouter}, Router, }; use http::{uri::PathAndQuery, StatusCode, Uri}; use std::{borrow::Cow, convert::Infallible}; use tower_service::Service; mod resource; #[cfg(feature = "typed-routing")] mod typed; pub use self::resource::Resource; #[cfg(feature = "typed-routing")] pub use self::typed::WithQueryParams; #[cfg(feature = "typed-routing")] pub use axum_macros::TypedPath; #[cfg(feature = "typed-routing")] pub use self::typed::{SecondElementIs, TypedPath}; /// Extension trait that adds additional methods to [`Router`]. pub trait RouterExt: sealed::Sealed { /// Add a typed `GET` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `DELETE` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_delete(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `HEAD` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_head(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `OPTIONS` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_options(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `PATCH` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_patch(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `POST` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_post(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `PUT` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_put(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add a typed `TRACE` route to the router. /// /// The path will be inferred from the first argument to the handler function which must /// implement [`TypedPath`]. /// /// See [`TypedPath`] for more details and examples. #[cfg(feature = "typed-routing")] fn typed_trace(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath; /// Add another route to the router with an additional "trailing slash redirect" route. /// /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a /// route for `/foo/` that redirects to `/foo`. /// /// If you add a route _with_ a trailing slash, such as `/bar/`, this method will also add a /// route for `/bar` that redirects to `/bar/`. /// /// This is similar to what axum 0.5.x did by default, except this explicitly adds another /// route, so trying to add a `/foo/` route after calling `.route_with_tsr("/foo", /* ... */)` /// will result in a panic due to route overlap. /// /// # Example /// /// ``` /// use axum::{Router, routing::get}; /// use axum_extra::routing::RouterExt; /// /// let app = Router::new() /// // `/foo/` will redirect to `/foo` /// .route_with_tsr("/foo", get(|| async {})) /// // `/bar` will redirect to `/bar/` /// .route_with_tsr("/bar/", get(|| async {})); /// # let _: Router = app; /// ``` fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where Self: Sized; /// Add another route to the router with an additional "trailing slash redirect" route. /// /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`]. fn route_service_with_tsr(self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized; } impl RouterExt for Router where B: axum::body::HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::get(handler)) } #[cfg(feature = "typed-routing")] fn typed_delete(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::delete(handler)) } #[cfg(feature = "typed-routing")] fn typed_head(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::head(handler)) } #[cfg(feature = "typed-routing")] fn typed_options(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::options(handler)) } #[cfg(feature = "typed-routing")] fn typed_patch(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::patch(handler)) } #[cfg(feature = "typed-routing")] fn typed_post(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::post(handler)) } #[cfg(feature = "typed-routing")] fn typed_put(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::put(handler)) } #[cfg(feature = "typed-routing")] fn typed_trace(self, handler: H) -> Self where H: axum::handler::Handler, T: SecondElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::trace(handler)) } #[track_caller] fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self where Self: Sized, { validate_tsr_path(path); self = self.route(path, method_router); add_tsr_redirect_route(self, path) } #[track_caller] fn route_service_with_tsr(mut self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized, { validate_tsr_path(path); self = self.route_service(path, service); add_tsr_redirect_route(self, path) } } #[track_caller] fn validate_tsr_path(path: &str) { if path == "/" { panic!("Cannot add a trailing slash redirect route for `/`") } } fn add_tsr_redirect_route(router: Router, path: &str) -> Router where B: axum::body::HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { async fn redirect_handler(uri: Uri) -> Response { let new_uri = map_path(uri, |path| { path.strip_suffix('/') .map(Cow::Borrowed) .unwrap_or_else(|| Cow::Owned(format!("{path}/"))) }); if let Some(new_uri) = new_uri { Redirect::permanent(&new_uri.to_string()).into_response() } else { StatusCode::BAD_REQUEST.into_response() } } if let Some(path_without_trailing_slash) = path.strip_suffix('/') { router.route(path_without_trailing_slash, any(redirect_handler)) } else { router.route(&format!("{path}/"), any(redirect_handler)) } } /// Map the path of a `Uri`. /// /// Returns `None` if the `Uri` cannot be put back together with the new path. fn map_path(original_uri: Uri, f: F) -> Option where F: FnOnce(&str) -> Cow<'_, str>, { let mut parts = original_uri.into_parts(); let path_and_query = parts.path_and_query.as_ref()?; let new_path = f(path_and_query.path()); let new_path_and_query = if let Some(query) = &path_and_query.query() { format!("{new_path}?{query}").parse::().ok()? } else { new_path.parse::().ok()? 
}; parts.path_and_query = Some(new_path_and_query); Uri::from_parts(parts).ok() } mod sealed { pub trait Sealed {} impl Sealed for axum::Router {} } #[cfg(test)] mod tests { use super::*; use crate::test_helpers::*; use axum::{extract::Path, http::StatusCode, routing::get}; #[tokio::test] async fn test_tsr() { let app = Router::new() .route_with_tsr("/foo", get(|| async {})) .route_with_tsr("/bar/", get(|| async {})); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/foo/").send().await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/foo"); let res = client.get("/bar/").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/bar/"); } #[tokio::test] async fn tsr_with_params() { let app = Router::new() .route_with_tsr( "/a/:a", get(|Path(param): Path| async move { param }), ) .route_with_tsr( "/b/:b/", get(|Path(param): Path| async move { param }), ); let client = TestClient::new(app); let res = client.get("/a/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "foo"); let res = client.get("/a/foo/").send().await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/a/foo"); let res = client.get("/b/foo/").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "foo"); let res = client.get("/b/foo").send().await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/b/foo/"); } #[tokio::test] async fn tsr_maintains_query_params() { let app = Router::new().route_with_tsr("/foo", get(|| async {})); let client = TestClient::new(app); let res = client.get("/foo/?a=a").send().await; assert_eq!(res.status(), 
StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/foo?a=a"); } #[test] #[should_panic = "Cannot add a trailing slash redirect route for `/`"] fn tsr_at_root() { let _: Router = Router::new().route_with_tsr("/", get(|| async move {})); } } axum-axum-v0.6.20/axum-extra/src/routing/resource.rs000066400000000000000000000143341446276533600225020ustar00rootroot00000000000000use axum::{ body::Body, handler::Handler, routing::{delete, get, on, post, MethodFilter, MethodRouter}, Router, }; /// A resource which defines a set of conventional CRUD routes. /// /// # Example /// /// ```rust /// use axum::{Router, routing::get, extract::Path}; /// use axum_extra::routing::{RouterExt, Resource}; /// /// let users = Resource::named("users") /// // Define a route for `GET /users` /// .index(|| async {}) /// // `POST /users` /// .create(|| async {}) /// // `GET /users/new` /// .new(|| async {}) /// // `GET /users/:users_id` /// .show(|Path(user_id): Path| async {}) /// // `GET /users/:users_id/edit` /// .edit(|Path(user_id): Path| async {}) /// // `PUT or PATCH /users/:users_id` /// .update(|Path(user_id): Path| async {}) /// // `DELETE /users/:users_id` /// .destroy(|Path(user_id): Path| async {}); /// /// let app = Router::new().merge(users); /// # let _: Router = app; /// ``` #[derive(Debug)] #[must_use] pub struct Resource { pub(crate) name: String, pub(crate) router: Router, } impl Resource where B: axum::body::HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { /// Create a `Resource` with the given name. /// /// All routes will be nested at `/{resource_name}`. pub fn named(resource_name: &str) -> Self { Self { name: resource_name.to_owned(), router: Router::new(), } } /// Add a handler at `GET /{resource_name}`. pub fn index(self, handler: H) -> Self where H: Handler, T: 'static, { let path = self.index_create_path(); self.route(&path, get(handler)) } /// Add a handler at `POST /{resource_name}`. 
pub fn create(self, handler: H) -> Self where H: Handler, T: 'static, { let path = self.index_create_path(); self.route(&path, post(handler)) } /// Add a handler at `GET /{resource_name}/new`. pub fn new(self, handler: H) -> Self where H: Handler, T: 'static, { let path = format!("/{}/new", self.name); self.route(&path, get(handler)) } /// Add a handler at `GET /{resource_name}/:{resource_name}_id`. pub fn show(self, handler: H) -> Self where H: Handler, T: 'static, { let path = self.show_update_destroy_path(); self.route(&path, get(handler)) } /// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`. pub fn edit(self, handler: H) -> Self where H: Handler, T: 'static, { let path = format!("/{0}/:{0}_id/edit", self.name); self.route(&path, get(handler)) } /// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`. pub fn update(self, handler: H) -> Self where H: Handler, T: 'static, { let path = self.show_update_destroy_path(); self.route(&path, on(MethodFilter::PUT | MethodFilter::PATCH, handler)) } /// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`. 
pub fn destroy(self, handler: H) -> Self where H: Handler, T: 'static, { let path = self.show_update_destroy_path(); self.route(&path, delete(handler)) } fn index_create_path(&self) -> String { format!("/{}", self.name) } fn show_update_destroy_path(&self) -> String { format!("/{0}/:{0}_id", self.name) } fn route(mut self, path: &str, method_router: MethodRouter) -> Self { self.router = self.router.route(path, method_router); self } } impl From> for Router { fn from(resource: Resource) -> Self { resource.router } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use axum::{extract::Path, http::Method, Router}; use http::Request; use tower::{Service, ServiceExt}; #[tokio::test] async fn works() { let users = Resource::named("users") .index(|| async { "users#index" }) .create(|| async { "users#create" }) .new(|| async { "users#new" }) .show(|Path(id): Path| async move { format!("users#show id={id}") }) .edit(|Path(id): Path| async move { format!("users#edit id={id}") }) .update(|Path(id): Path| async move { format!("users#update id={id}") }) .destroy(|Path(id): Path| async move { format!("users#destroy id={id}") }); let mut app = Router::new().merge(users); assert_eq!( call_route(&mut app, Method::GET, "/users").await, "users#index" ); assert_eq!( call_route(&mut app, Method::POST, "/users").await, "users#create" ); assert_eq!( call_route(&mut app, Method::GET, "/users/new").await, "users#new" ); assert_eq!( call_route(&mut app, Method::GET, "/users/1").await, "users#show id=1" ); assert_eq!( call_route(&mut app, Method::GET, "/users/1/edit").await, "users#edit id=1" ); assert_eq!( call_route(&mut app, Method::PATCH, "/users/1").await, "users#update id=1" ); assert_eq!( call_route(&mut app, Method::PUT, "/users/1").await, "users#update id=1" ); assert_eq!( call_route(&mut app, Method::DELETE, "/users/1").await, "users#destroy id=1" ); } async fn call_route(app: &mut Router, method: Method, uri: &str) -> String { let res = app .ready() .await .unwrap() 
.call( Request::builder() .method(method) .uri(uri) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let bytes = hyper::body::to_bytes(res).await.unwrap(); String::from_utf8(bytes.to_vec()).unwrap() } } axum-axum-v0.6.20/axum-extra/src/routing/typed.rs000066400000000000000000000302751446276533600220020ustar00rootroot00000000000000use std::{any::type_name, fmt}; use super::sealed::Sealed; use http::Uri; use serde::Serialize; /// A type safe path. /// /// This is used to statically connect a path to its corresponding handler using /// [`RouterExt::typed_get`], [`RouterExt::typed_post`], etc. /// /// # Example /// /// ```rust /// use serde::Deserialize; /// use axum::{Router, extract::Json}; /// use axum_extra::routing::{ /// TypedPath, /// RouterExt, // for `Router::typed_*` /// }; /// /// // A type safe route with `/users/:id` as its associated path. /// #[derive(TypedPath, Deserialize)] /// #[typed_path("/users/:id")] /// struct UsersMember { /// id: u32, /// } /// /// // A regular handler function that takes `UsersMember` as the first argument /// // and thus creates a typed connection between this handler and the `/users/:id` path. /// // /// // The `TypedPath` must be the first argument to the function. /// async fn users_show( /// UsersMember { id }: UsersMember, /// ) { /// // ... /// } /// /// let app = Router::new() /// // Add our typed route to the router. /// // /// // The path will be inferred to `/users/:id` since `users_show`'s /// // first argument is `UsersMember` which implements `TypedPath` /// .typed_get(users_show) /// .typed_post(users_create) /// .typed_delete(users_destroy); /// /// #[derive(TypedPath)] /// #[typed_path("/users")] /// struct UsersCollection; /// /// #[derive(Deserialize)] /// struct UsersCreatePayload { /* ... */ } /// /// async fn users_create( /// _: UsersCollection, /// // Our handlers can accept other extractors. /// Json(payload): Json, /// ) { /// // ... /// } /// /// async fn users_destroy(_: UsersCollection) { /* ... 
*/ } /// /// # /// # let app: Router = app; /// ``` /// /// # Using `#[derive(TypedPath)]` /// /// While `TypedPath` can be implemented manually, it's _highly_ recommended to derive it: /// /// ``` /// use serde::Deserialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath, Deserialize)] /// #[typed_path("/users/:id")] /// struct UsersMember { /// id: u32, /// } /// ``` /// /// The macro expands to: /// /// - A `TypedPath` implementation. /// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`], /// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must /// also implement [`serde::Deserialize`], unless it's a unit struct. /// - A [`Display`] implementation that interpolates the captures. This can be used to, among other /// things, create links to known paths and have them verified statically. Note that the /// [`Display`] implementation for each field must return something that's compatible with its /// [`Deserialize`] implementation. /// /// Additionally the macro will verify the captures in the path matches the fields of the struct. 
/// For example this fails to compile since the struct doesn't have a `team_id` field: /// /// ```compile_fail /// use serde::Deserialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath, Deserialize)] /// #[typed_path("/users/:id/teams/:team_id")] /// struct UsersMember { /// id: u32, /// } /// ``` /// /// Unit and tuple structs are also supported: /// /// ``` /// use serde::Deserialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath)] /// #[typed_path("/users")] /// struct UsersCollection; /// /// #[derive(TypedPath, Deserialize)] /// #[typed_path("/users/:id")] /// struct UsersMember(u32); /// ``` /// /// ## Percent encoding /// /// The generated [`Display`] implementation will automatically percent-encode the arguments: /// /// ``` /// use serde::Deserialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath, Deserialize)] /// #[typed_path("/users/:id")] /// struct UsersMember { /// id: String, /// } /// /// assert_eq!( /// UsersMember { /// id: "foo bar".to_string(), /// }.to_string(), /// "/users/foo%20bar", /// ); /// ``` /// /// ## Customizing the rejection /// /// By default the rejection used in the [`FromRequest`] implemetation will be [`PathRejection`]. /// /// That can be customized using `#[typed_path("...", rejection(YourType))]`: /// /// ``` /// use serde::Deserialize; /// use axum_extra::routing::TypedPath; /// use axum::{ /// response::{IntoResponse, Response}, /// extract::rejection::PathRejection, /// }; /// /// #[derive(TypedPath, Deserialize)] /// #[typed_path("/users/:id", rejection(UsersMemberRejection))] /// struct UsersMember { /// id: String, /// } /// /// struct UsersMemberRejection; /// /// // Your rejection type must implement `From`. /// // /// // Here you can grab whatever details from the inner rejection /// // that you need. /// impl From for UsersMemberRejection { /// fn from(rejection: PathRejection) -> Self { /// # UsersMemberRejection /// // ... 
/// } /// } /// /// // Your rejection must implement `IntoResponse`, like all rejections. /// impl IntoResponse for UsersMemberRejection { /// fn into_response(self) -> Response { /// # ().into_response() /// // ... /// } /// } /// ``` /// /// The `From` requirement only applies if your typed path is a struct with named /// fields or a tuple struct. For unit structs your rejection type must implement `Default`: /// /// ``` /// use axum_extra::routing::TypedPath; /// use axum::response::{IntoResponse, Response}; /// /// #[derive(TypedPath)] /// #[typed_path("/users", rejection(UsersCollectionRejection))] /// struct UsersCollection; /// /// #[derive(Default)] /// struct UsersCollectionRejection; /// /// impl IntoResponse for UsersCollectionRejection { /// fn into_response(self) -> Response { /// # ().into_response() /// // ... /// } /// } /// ``` /// /// [`FromRequest`]: axum::extract::FromRequest /// [`RouterExt::typed_get`]: super::RouterExt::typed_get /// [`RouterExt::typed_post`]: super::RouterExt::typed_post /// [`Path`]: axum::extract::Path /// [`Display`]: std::fmt::Display /// [`Deserialize`]: serde::Deserialize /// [`PathRejection`]: axum::extract::rejection::PathRejection pub trait TypedPath: std::fmt::Display { /// The path with optional captures such as `/users/:id`. const PATH: &'static str; /// Convert the path into a `Uri`. /// /// # Panics /// /// The default implementation parses the required [`Display`] implementation. If that fails it /// will panic. /// /// Using `#[derive(TypedPath)]` will never result in a panic since it percent-encodes /// arguments. /// /// [`Display`]: std::fmt::Display fn to_uri(&self) -> Uri { self.to_string().parse().unwrap() } /// Add query parameters to a path. 
/// /// # Example /// /// ``` /// use axum_extra::routing::TypedPath; /// use serde::Serialize; /// /// #[derive(TypedPath)] /// #[typed_path("/users")] /// struct Users; /// /// #[derive(Serialize)] /// struct Pagination { /// page: u32, /// per_page: u32, /// } /// /// let path = Users.with_query_params(Pagination { /// page: 1, /// per_page: 10, /// }); /// /// assert_eq!(path.to_uri(), "/users?&page=1&per_page=10"); /// ``` /// /// # Panics /// /// If `params` doesn't support being serialized as query params [`WithQueryParams`]'s [`Display`] /// implementation will panic, and thus [`WithQueryParams::to_uri`] will also panic. /// /// [`WithQueryParams::to_uri`]: TypedPath::to_uri /// [`Display`]: std::fmt::Display fn with_query_params(self, params: T) -> WithQueryParams where T: Serialize, Self: Sized, { WithQueryParams { path: self, params } } } /// A [`TypedPath`] with query params. /// /// See [`TypedPath::with_query_params`] for more details. #[derive(Debug, Clone, Copy)] pub struct WithQueryParams { path: P, params: T, } impl fmt::Display for WithQueryParams where P: TypedPath, T: Serialize, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut out = self.path.to_string(); if !out.contains('?') { out.push('?'); } let mut urlencoder = form_urlencoded::Serializer::new(&mut out); self.params .serialize(serde_html_form::ser::Serializer::new(&mut urlencoder)) .unwrap_or_else(|err| { panic!( "failed to URL encode value of type `{}`: {}", type_name::(), err ) }); f.write_str(&out)?; Ok(()) } } impl TypedPath for WithQueryParams where P: TypedPath, T: Serialize, { const PATH: &'static str = P::PATH; } /// Utility trait used with [`RouterExt`] to ensure the second element of a tuple type is a /// given type. /// /// If you see it in type errors its most likely because the second argument to your handler doesn't /// implement [`TypedPath`]. /// /// You normally shouldn't have to use this trait directly. 
/// /// It is sealed such that it cannot be implemented outside this crate. /// /// [`RouterExt`]: super::RouterExt pub trait SecondElementIs

: Sealed {} macro_rules! impl_second_element_is { ( $($ty:ident),* $(,)? ) => { impl SecondElementIs

for (M, P, $($ty,)*) where P: TypedPath {} impl Sealed for (M, P, $($ty,)*) where P: TypedPath {} impl SecondElementIs

for (M, Option

, $($ty,)*) where P: TypedPath {} impl Sealed for (M, Option

, $($ty,)*) where P: TypedPath {} impl SecondElementIs

for (M, Result, $($ty,)*) where P: TypedPath {} impl Sealed for (M, Result, $($ty,)*) where P: TypedPath {} }; } impl_second_element_is!(); impl_second_element_is!(T1); impl_second_element_is!(T1, T2); impl_second_element_is!(T1, T2, T3); impl_second_element_is!(T1, T2, T3, T4); impl_second_element_is!(T1, T2, T3, T4, T5); impl_second_element_is!(T1, T2, T3, T4, T5, T6); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); #[cfg(test)] mod tests { use super::*; use crate::routing::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/users/:id")] struct UsersShow { id: i32, } #[derive(Serialize)] struct Params { foo: &'static str, bar: i32, baz: bool, } #[test] fn with_params() { let path = UsersShow { id: 1 }.with_query_params(Params { foo: "foo", bar: 123, baz: true, }); let uri = path.to_uri(); // according to [the spec] starting the params with `?&` is allowed specifically: // // > If bytes is the empty byte sequence, then continue. 
// // [the spec]: https://url.spec.whatwg.org/#urlencoded-parsing assert_eq!(uri, "/users/1?&foo=foo&bar=123&baz=true"); } #[test] fn with_params_called_multiple_times() { let path = UsersShow { id: 1 } .with_query_params(Params { foo: "foo", bar: 123, baz: true, }) .with_query_params([("qux", 1337)]); let uri = path.to_uri(); assert_eq!(uri, "/users/1?&foo=foo&bar=123&baz=true&qux=1337"); } } axum-axum-v0.6.20/axum-extra/test_files/000077500000000000000000000000001446276533600201635ustar00rootroot00000000000000axum-axum-v0.6.20/axum-extra/test_files/index.html000066400000000000000000000000271446276533600221570ustar00rootroot00000000000000

Hello, World!

axum-axum-v0.6.20/axum-extra/test_files/index_2.html000066400000000000000000000000371446276533600224010ustar00rootroot00000000000000Hello, World! axum-axum-v0.6.20/axum-extra/test_files/script.js000066400000000000000000000000221446276533600220170ustar00rootroot00000000000000console.log('hi') axum-axum-v0.6.20/axum-macros/000077500000000000000000000000001446276533600161635ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/CHANGELOG.md000066400000000000000000000137201446276533600177770ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). # Unreleased - None. # 0.3.8 (17. July, 2023) - **fixed:** Allow unreachable code in `#[debug_handler]` ([#2014]) [#2014]: https://github.com/tokio-rs/axum/pull/2014 # 0.3.7 (22. March, 2023) - **change:** Update to syn 2.0 ([#1862]) - **fixed:** Give better error if generics are used with `#[derive(FromRef)]` ([#1874]) [#1862]: https://github.com/tokio-rs/axum/pull/1862 [#1874]: https://github.com/tokio-rs/axum/pull/1874 # 0.3.6 (13. March, 2023) - **fixed:** Improve `#[debug_handler]` message for known generic request-consuming extractors ([#1826]) [#1826]: https://github.com/tokio-rs/axum/pull/1826 # 0.3.5 (03. March, 2023) - **fixed:** In `#[debug_handler]` provide specific errors about `FromRequest` extractors not being the last argument ([#1797]) [#1797]: https://github.com/tokio-rs/axum/pull/1797 # 0.3.4 (12. February, 2022) - **fixed:** Fix `#[derive(FromRef)]` with `Copy` fields generating clippy warnings ([#1749]) [#1749]: https://github.com/tokio-rs/axum/pull/1749 # 0.3.3 (11. February, 2022) - **fixed:** Fix `#[debug_handler]` sometimes giving wrong borrow related suggestions ([#1710]) [#1710]: https://github.com/tokio-rs/axum/pull/1710 # 0.3.2 (22. January, 2022) - No public API changes. 
# 0.3.1 (9. January, 2022) - **fixed:** Fix warnings for cloning references in generated code ([#1676]) [#1676]: https://github.com/tokio-rs/axum/pull/1676 # 0.3.0 (25. November, 2022) - **added:** Add `#[derive(FromRequestParts)]` for deriving an implementation of `FromRequestParts`, similarly to `#[derive(FromRequest)]` ([#1305]) - **added:** Add `#[derive(FromRef)]` ([#1430]) - **added:** Add `#[from_ref(skip)]` to skip implementing `FromRef` for individual fields ([#1537]) - **added:** Support using a different rejection for `#[derive(FromRequest)]` with `#[from_request(rejection(MyRejection))]` ([#1256]) - **change:** axum-macro's MSRV is now 1.60 ([#1239]) - **breaking:** `#[derive(FromRequest)]` will no longer generate a rejection enum but instead generate `type Rejection = axum::response::Response`. Use the new `#[from_request(rejection(MyRejection))]` attribute to change this. The `rejection_derive` attribute has also been removed ([#1272]) [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1256]: https://github.com/tokio-rs/axum/pull/1256 [#1272]: https://github.com/tokio-rs/axum/pull/1272 [#1305]: https://github.com/tokio-rs/axum/pull/1305 [#1430]: https://github.com/tokio-rs/axum/pull/1430 [#1537]: https://github.com/tokio-rs/axum/pull/1537
0.3.0 Pre-Releases # 0.3.0-rc.3 (18. November, 2022) - **added:** Add `#[from_ref(skip)]` to skip implementing `FromRef` for individual fields ([#1537]) [#1537]: https://github.com/tokio-rs/axum/pull/1537 # 0.3.0-rc.2 (8. November, 2022) - **added:** Add `#[derive(FromRef)]` ([#1430]) [#1430]: https://github.com/tokio-rs/axum/pull/1430 # 0.3.0-rc.1 (23. August, 2022) - **change:** axum-macro's MSRV is now 1.60 ([#1239]) - **added:** Support using a different rejection for `#[derive(FromRequest)]` with `#[from_request(rejection(MyRejection))]` ([#1256]) - **breaking:** `#[derive(FromRequest)]` will no longer generate a rejection enum but instead generate `type Rejection = axum::response::Response`. Use the new `#[from_request(rejection(MyRejection))]` attribute to change this. The `rejection_derive` attribute has also been removed ([#1272]) - **added:** Add `#[derive(FromRequestParts)]` for deriving an implementation of `FromRequestParts`, similarly to `#[derive(FromRequest)]` ([#1305]) [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1256]: https://github.com/tokio-rs/axum/pull/1256 [#1272]: https://github.com/tokio-rs/axum/pull/1272 [#1305]: https://github.com/tokio-rs/axum/pull/1305
# 0.2.3 (27. June, 2022) - **change:** axum-macros's MSRV is now 1.56 ([#1098]) - **fixed:** Silence "unnecessary use of `to_string`" lint for `#[derive(TypedPath)]` ([#1117]) [#1098]: https://github.com/tokio-rs/axum/pull/1098 [#1117]: https://github.com/tokio-rs/axum/pull/1117 # 0.2.2 (18. May, 2022) - **added:** In `debug_handler`, check if `Request` is used as non-final extractor ([#1035]) - **added:** In `debug_handler`, check if multiple `Path` extractors are used ([#1035]) - **added:** In `debug_handler`, check if multiple body extractors are used ([#1036]) - **added:** Support customizing rejections for `#[derive(TypedPath)]` ([#1012]) [#1035]: https://github.com/tokio-rs/axum/pull/1035 [#1036]: https://github.com/tokio-rs/axum/pull/1036 [#1012]: https://github.com/tokio-rs/axum/pull/1012 # 0.2.1 (10. May, 2022) - **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001]) - **fixed:** Support wildcards in typed paths ([#1003]) - **added:** Support `#[derive(FromRequest)]` on enums using `#[from_request(via(OtherExtractor))]` ([#1009]) - **added:** Support using a custom rejection type for `#[derive(TypedPath)]` instead of `PathRejection` ([#1012]) [#1001]: https://github.com/tokio-rs/axum/pull/1001 [#1003]: https://github.com/tokio-rs/axum/pull/1003 [#1009]: https://github.com/tokio-rs/axum/pull/1009 [#1012]: https://github.com/tokio-rs/axum/pull/1012 # 0.2.0 (31. March, 2022) - **breaking:** Routes are now required to start with `/`. Previously empty routes or routes such as `:foo` would be accepted but most likely result in bugs ([#823]) [#823]: https://github.com/tokio-rs/axum/pull/823 # 0.1.2 (1. March 2022) - **fixed:** Use fully qualified `Result` type ([#796]) [#796]: https://github.com/tokio-rs/axum/pull/796 # 0.1.1 (22. February 2022) - Add `#[derive(TypedPath)]` for use with axum-extra's new "type safe" routing API ([#756]) [#756]: https://github.com/tokio-rs/axum/pull/756 # 0.1.0 (31. 
January, 2022) - Initial release. axum-axum-v0.6.20/axum-macros/Cargo.toml000066400000000000000000000022171446276533600201150ustar00rootroot00000000000000[package] categories = ["asynchronous", "network-programming", "web-programming"] description = "Macros for axum" edition = "2021" rust-version = "1.60" homepage = "https://github.com/tokio-rs/axum" keywords = ["axum"] license = "MIT" name = "axum-macros" readme = "README.md" repository = "https://github.com/tokio-rs/axum" version = "0.3.8" # remember to also bump the version that axum and axum-extra depends on [features] default = [] __private = ["syn/visit-mut"] [lib] proc-macro = true [dependencies] heck = "0.4" proc-macro2 = "1.0" quote = "1.0" syn = { version = "2.0", features = [ "full", "parsing", # needed for `Hash` impls "extra-traits", ] } [dev-dependencies] axum = { path = "../axum", version = "0.6.0", features = ["headers", "macros"] } axum-extra = { path = "../axum-extra", version = "0.7.0", features = ["typed-routing", "cookie-private"] } rustversion = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" syn = { version = "2.0", features = ["full", "extra-traits"] } tokio = { version = "1.25.0", features = ["full"] } trybuild = "1.0.63" [package.metadata.cargo-public-api-crates] allowed = [] axum-axum-v0.6.20/axum-macros/LICENSE000066400000000000000000000020411446276533600171650ustar00rootroot00000000000000Copyright 2021 Axum Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 
Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. axum-axum-v0.6.20/axum-macros/README.md000066400000000000000000000026511446276533600174460ustar00rootroot00000000000000# axum-macros [![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum-macros/actions/workflows/CI.yml) [![Crates.io](https://img.shields.io/crates/v/axum-macros)](https://crates.io/crates/axum-macros) [![Documentation](https://docs.rs/axum-macros/badge.svg)](https://docs.rs/axum-macros) Macros for [`axum`]. More information about this crate can be found in the [crate documentation][docs]. ## Safety This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. ## Minimum supported Rust version axum-macros's MSRV is 1.63. ## Getting Help You're also welcome to ask in the [Discord channel][chat] or open an [issue] with your question. ## Contributing ๐ŸŽˆ Thanks for your help improving the project! We are so happy to have you! We have a [contributing guide][contributing] to help you get involved in the `axum` project. ## License This project is licensed under the [MIT license][license]. ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in `axum` by you, shall be licensed as MIT, without any additional terms or conditions. 
[`axum`]: https://crates.io/crates/axum [chat]: https://discord.gg/tokio [contributing]: /CONTRIBUTING.md [docs]: https://docs.rs/axum-macros [license]: /axum-macros/LICENSE [issue]: https://github.com/tokio-rs/axum/issues/new axum-axum-v0.6.20/axum-macros/rust-toolchain000066400000000000000000000000231446276533600210540ustar00rootroot00000000000000nightly-2022-11-18 axum-axum-v0.6.20/axum-macros/src/000077500000000000000000000000001446276533600167525ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/src/attr_parsing.rs000066400000000000000000000050001446276533600220100ustar00rootroot00000000000000use quote::ToTokens; use syn::{ parse::{Parse, ParseStream}, Token, }; pub(crate) fn parse_parenthesized_attribute( input: ParseStream, out: &mut Option<(K, T)>, ) -> syn::Result<()> where K: Parse + ToTokens, T: Parse, { let kw = input.parse()?; let content; syn::parenthesized!(content in input); let inner = content.parse()?; if out.is_some() { let kw_name = std::any::type_name::().split("::").last().unwrap(); let msg = format!("`{kw_name}` specified more than once"); return Err(syn::Error::new_spanned(kw, msg)); } *out = Some((kw, inner)); Ok(()) } pub(crate) fn parse_assignment_attribute( input: ParseStream, out: &mut Option<(K, T)>, ) -> syn::Result<()> where K: Parse + ToTokens, T: Parse, { let kw = input.parse()?; input.parse::()?; let inner = input.parse()?; if out.is_some() { let kw_name = std::any::type_name::().split("::").last().unwrap(); let msg = format!("`{kw_name}` specified more than once"); return Err(syn::Error::new_spanned(kw, msg)); } *out = Some((kw, inner)); Ok(()) } pub(crate) trait Combine: Sized { fn combine(self, other: Self) -> syn::Result; } pub(crate) fn parse_attrs(ident: &str, attrs: &[syn::Attribute]) -> syn::Result where T: Combine + Default + Parse, { attrs .iter() .filter(|attr| attr.meta.path().is_ident(ident)) .map(|attr| attr.parse_args::()) .try_fold(T::default(), |out, next| out.combine(next?)) } pub(crate) fn 
combine_attribute(a: &mut Option<(K, T)>, b: Option<(K, T)>) -> syn::Result<()> where K: ToTokens, { if let Some((kw, inner)) = b { if a.is_some() { let kw_name = std::any::type_name::().split("::").last().unwrap(); let msg = format!("`{kw_name}` specified more than once"); return Err(syn::Error::new_spanned(kw, msg)); } *a = Some((kw, inner)); } Ok(()) } pub(crate) fn combine_unary_attribute(a: &mut Option, b: Option) -> syn::Result<()> where K: ToTokens, { if let Some(kw) = b { if a.is_some() { let kw_name = std::any::type_name::().split("::").last().unwrap(); let msg = format!("`{kw_name}` specified more than once"); return Err(syn::Error::new_spanned(kw, msg)); } *a = Some(kw); } Ok(()) } pub(crate) fn second(tuple: (T, K)) -> K { tuple.1 } axum-axum-v0.6.20/axum-macros/src/axum_test.rs000066400000000000000000000022561446276533600213360ustar00rootroot00000000000000use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::{parse::Parse, parse_quote, visit_mut::VisitMut, ItemFn}; pub(crate) fn expand(_attr: Attrs, mut item_fn: ItemFn) -> TokenStream { item_fn.attrs.push(parse_quote!(#[tokio::test])); let nest_service_fn = replace_nest_with_nest_service(item_fn.clone()); quote! 
{ #item_fn #nest_service_fn } } pub(crate) struct Attrs; impl Parse for Attrs { fn parse(_input: syn::parse::ParseStream) -> syn::Result { Ok(Self) } } fn replace_nest_with_nest_service(mut item_fn: ItemFn) -> Option { item_fn.sig.ident = format_ident!("{}_with_nest_service", item_fn.sig.ident); let mut visitor = NestToNestService::default(); syn::visit_mut::visit_item_fn_mut(&mut visitor, &mut item_fn); (visitor.count > 0).then(|| item_fn) } #[derive(Default)] struct NestToNestService { count: usize, } impl VisitMut for NestToNestService { fn visit_expr_method_call_mut(&mut self, i: &mut syn::ExprMethodCall) { if i.method == "nest" && i.args.len() == 2 { i.method = parse_quote!(nest_service); self.count += 1; } } } axum-axum-v0.6.20/axum-macros/src/debug_handler.rs000066400000000000000000000436711446276533600221160ustar00rootroot00000000000000use std::collections::HashSet; use crate::{ attr_parsing::{parse_assignment_attribute, second}, with_position::{Position, WithPosition}, }; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type}; pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { let Attrs { body_ty, state_ty } = attr; let body_ty = body_ty .map(second) .unwrap_or_else(|| parse_quote!(axum::body::Body)); let mut state_ty = state_ty.map(second); let check_extractor_count = check_extractor_count(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); // If the function is generic, we can't reliably check its inputs or whether the future it // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. 
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { let mut err = None; if state_ty.is_none() { let state_types_from_args = state_types_from_args(&item_fn); #[allow(clippy::comparison_chain)] if state_types_from_args.len() == 1 { state_ty = state_types_from_args.into_iter().next(); } else if state_types_from_args.len() > 1 { err = Some( syn::Error::new( Span::call_site(), "can't infer state type, please add set it explicitly, as in \ `#[debug_handler(state = MyStateType)]`", ) .into_compile_error(), ); } } err.unwrap_or_else(|| { let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); let check_input_order = check_input_order(&item_fn); let check_future_send = check_future_send(&item_fn); if let Some(check_input_order) = check_input_order { quote! { #check_input_order #check_future_send } } else { let check_inputs_impls_from_request = check_inputs_impls_from_request(&item_fn, &body_ty, state_ty); quote! { #check_inputs_impls_from_request #check_future_send } } }) } else { syn::Error::new_spanned( &item_fn.sig.generics, "`#[axum_macros::debug_handler]` doesn't support generic functions", ) .into_compile_error() }; quote! 
{ #item_fn #check_extractor_count #check_path_extractor #check_output_impls_into_response #check_inputs_and_future_send } } mod kw { syn::custom_keyword!(body); syn::custom_keyword!(state); } pub(crate) struct Attrs { body_ty: Option<(kw::body, Type)>, state_ty: Option<(kw::state, Type)>, } impl Parse for Attrs { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut body_ty = None; let mut state_ty = None; while !input.is_empty() { let lh = input.lookahead1(); if lh.peek(kw::body) { parse_assignment_attribute(input, &mut body_ty)?; } else if lh.peek(kw::state) { parse_assignment_attribute(input, &mut state_ty)?; } else { return Err(lh.error()); } let _ = input.parse::(); } Ok(Self { body_ty, state_ty }) } } fn check_extractor_count(item_fn: &ItemFn) -> Option { let max_extractors = 16; if item_fn.sig.inputs.len() <= max_extractors { None } else { let error_message = format!( "Handlers cannot take more than {max_extractors} arguments. \ Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", ); let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error(); Some(error) } } fn extractor_idents(item_fn: &ItemFn) -> impl Iterator { item_fn .sig .inputs .iter() .enumerate() .filter_map(|(idx, fn_arg)| match fn_arg { FnArg::Receiver(_) => None, FnArg::Typed(pat_type) => { if let Type::Path(type_path) = &*pat_type.ty { type_path .path .segments .last() .map(|segment| (idx, fn_arg, &segment.ident)) } else { None } } }) } fn check_path_extractor(item_fn: &ItemFn) -> TokenStream { let path_extractors = extractor_idents(item_fn) .filter(|(_, _, ident)| *ident == "Path") .collect::>(); if path_extractors.len() > 1 { path_extractors .into_iter() .map(|(_, arg, _)| { syn::Error::new_spanned( arg, "Multiple parameters must be extracted with a tuple \ `Path<(_, _)>` or a struct `Path`, not by applying \ multiple `Path<_>` extractors", ) .to_compile_error() }) .collect() } else { quote! 
{} } } fn is_self_pat_type(typed: &syn::PatType) -> bool { let ident = if let syn::Pat::Ident(ident) = &*typed.pat { &ident.ident } else { return false; }; ident == "self" } fn check_inputs_impls_from_request( item_fn: &ItemFn, body_ty: &Type, state_ty: Type, ) -> TokenStream { let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); WithPosition::new(item_fn.sig.inputs.iter()) .enumerate() .map(|(idx, arg)| { let must_impl_from_request_parts = match &arg { Position::First(_) | Position::Middle(_) => true, Position::Last(_) | Position::Only(_) => false, }; let arg = arg.into_inner(); let (span, ty) = match arg { FnArg::Receiver(receiver) => { if receiver.reference.is_some() { return syn::Error::new_spanned( receiver, "Handlers must only take owned values", ) .into_compile_error(); } let span = receiver.span(); (span, syn::parse_quote!(Self)) } FnArg::Typed(typed) => { let ty = &typed.ty; let span = ty.span(); if is_self_pat_type(typed) { (span, syn::parse_quote!(Self)) } else { (span, ty.clone()) } } }; let consumes_request = request_consuming_type_name(&ty).is_some(); let check_fn = format_ident!( "__axum_macros_check_{}_{}_from_request_check", item_fn.sig.ident, idx, span = span, ); let call_check_fn = format_ident!( "__axum_macros_check_{}_{}_from_request_call_check", item_fn.sig.ident, idx, span = span, ); let call_check_fn_body = if takes_self { quote_spanned! {span=> Self::#check_fn(); } } else { quote_spanned! {span=> #check_fn(); } }; let check_fn_generics = if must_impl_from_request_parts || consumes_request { quote! {} } else { quote! { } }; let from_request_bound = if must_impl_from_request_parts { quote_spanned! {span=> #ty: ::axum::extract::FromRequestParts<#state_ty> + Send } } else if consumes_request { quote_spanned! {span=> #ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send } } else { quote_spanned! 
{span=> #ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send } }; quote_spanned! {span=> #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] fn #check_fn #check_fn_generics() where #from_request_bound, {} // we have to call the function to actually trigger a compile error // since the function is generic, just defining it is not enough #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] fn #call_check_fn() { #call_check_fn_body } } }) .collect::() } fn check_input_order(item_fn: &ItemFn) -> Option { let types_that_consume_the_request = item_fn .sig .inputs .iter() .enumerate() .filter_map(|(idx, arg)| { let ty = match arg { FnArg::Typed(pat_type) => &*pat_type.ty, FnArg::Receiver(_) => return None, }; let type_name = request_consuming_type_name(ty)?; Some((idx, type_name, ty.span())) }) .collect::>(); if types_that_consume_the_request.is_empty() { return None; }; // exactly one type that consumes the request if types_that_consume_the_request.len() == 1 { // and that is not the last if types_that_consume_the_request[0].0 != item_fn.sig.inputs.len() - 1 { let (_idx, type_name, span) = &types_that_consume_the_request[0]; let error = format!( "`{type_name}` consumes the request body and thus must be \ the last argument to the handler function" ); return Some(quote_spanned! {*span=> compile_error!(#error); }); } else { return None; } } if types_that_consume_the_request.len() == 2 { let (_, first, _) = &types_that_consume_the_request[0]; let (_, second, _) = &types_that_consume_the_request[1]; let error = format!( "Can't have two extractors that consume the request body. \ `{first}` and `{second}` both do that.", ); let span = item_fn.sig.inputs.span(); Some(quote_spanned! 
{span=> compile_error!(#error); }) } else { let types = WithPosition::new(types_that_consume_the_request.into_iter()) .map(|pos| match pos { Position::First((_, type_name, _)) | Position::Middle((_, type_name, _)) => { format!("`{type_name}`, ") } Position::Last((_, type_name, _)) => format!("and `{type_name}`"), Position::Only(_) => unreachable!(), }) .collect::(); let error = format!( "Can't have more than one extractor that consume the request body. \ {types} all do that.", ); let span = item_fn.sig.inputs.span(); Some(quote_spanned! {span=> compile_error!(#error); }) } } fn request_consuming_type_name(ty: &Type) -> Option<&'static str> { let path = match ty { Type::Path(type_path) => &type_path.path, _ => return None, }; let ident = match path.segments.last() { Some(path_segment) => &path_segment.ident, None => return None, }; let type_name = match &*ident.to_string() { "Json" => "Json<_>", "BodyStream" => "BodyStream", "RawBody" => "RawBody<_>", "RawForm" => "RawForm", "Multipart" => "Multipart", "Protobuf" => "Protobuf", "JsonLines" => "JsonLines<_>", "Form" => "Form<_>", "Request" => "Request<_>", "Bytes" => "Bytes", "String" => "String", "Parts" => "Parts", _ => return None, }; Some(type_name) } fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { let ty = match &item_fn.sig.output { syn::ReturnType::Default => return quote! {}, syn::ReturnType::Type(_, ty) => ty, }; let span = ty.span(); let declare_inputs = item_fn .sig .inputs .iter() .filter_map(|arg| match arg { FnArg::Receiver(_) => None, FnArg::Typed(pat_ty) => { let pat = &pat_ty.pat; let ty = &pat_ty.ty; Some(quote! { let #pat: #ty = panic!(); }) } }) .collect::(); let block = &item_fn.block; let make_value_name = format_ident!( "__axum_macros_check_{}_into_response_make_value", item_fn.sig.ident ); let make = if item_fn.sig.asyncness.is_some() { quote_spanned! 
{span=> #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] async fn #make_value_name() -> #ty { #declare_inputs #block } } } else { quote_spanned! {span=> #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] fn #make_value_name() -> #ty { #declare_inputs #block } } }; let name = format_ident!("__axum_macros_check_{}_into_response", item_fn.sig.ident); if let Some(receiver) = self_receiver(item_fn) { quote_spanned! {span=> #make #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] async fn #name() { let value = #receiver #make_value_name().await; fn check(_: T) where T: ::axum::response::IntoResponse {} check(value); } } } else { quote_spanned! {span=> #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] async fn #name() { #make let value = #make_value_name().await; fn check(_: T) where T: ::axum::response::IntoResponse {} check(value); } } } } fn check_future_send(item_fn: &ItemFn) -> TokenStream { if item_fn.sig.asyncness.is_none() { match &item_fn.sig.output { syn::ReturnType::Default => { return syn::Error::new_spanned( item_fn.sig.fn_token, "Handlers must be `async fn`s", ) .into_compile_error(); } syn::ReturnType::Type(_, ty) => ty, }; } let span = item_fn.sig.ident.span(); let handler_name = &item_fn.sig.ident; let args = item_fn.sig.inputs.iter().map(|_| { quote_spanned! {span=> panic!() } }); let name = format_ident!("__axum_macros_check_{}_future", item_fn.sig.ident); let do_check = quote! { fn check(_: T) where T: ::std::future::Future + Send {} check(future); }; if let Some(receiver) = self_receiver(item_fn) { quote! { #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] fn #name() { let future = #receiver #handler_name(#(#args),*); #do_check } } } else { quote! 
{ #[allow(warnings)] #[allow(unreachable_code)] #[doc(hidden)] fn #name() { #item_fn let future = #handler_name(#(#args),*); #do_check } } } } fn self_receiver(item_fn: &ItemFn) -> Option { let takes_self = item_fn.sig.inputs.iter().any(|arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); if takes_self { return Some(quote! { Self:: }); } if let syn::ReturnType::Type(_, ty) = &item_fn.sig.output { if let syn::Type::Path(path) = &**ty { let segments = &path.path.segments; if segments.len() == 1 { if let Some(last) = segments.last() { match &last.arguments { syn::PathArguments::None if last.ident == "Self" => { return Some(quote! { Self:: }); } _ => {} } } } } } None } /// Given a signature like /// /// ```skip /// #[debug_handler] /// async fn handler( /// _: axum::extract::State, /// _: State, /// ) {} /// ``` /// /// This will extract `AppState`. /// /// Returns `None` if there are no `State` args or multiple of different types. fn state_types_from_args(item_fn: &ItemFn) -> HashSet { let types = item_fn .sig .inputs .iter() .filter_map(|input| match input { FnArg::Receiver(_) => None, FnArg::Typed(pat_type) => Some(pat_type), }) .map(|pat_type| &*pat_type.ty); crate::infer_state_types(types).collect() } #[test] fn ui() { crate::run_ui_tests("debug_handler"); } axum-axum-v0.6.20/axum-macros/src/from_ref.rs000066400000000000000000000050101446276533600211130ustar00rootroot00000000000000use proc_macro2::{Ident, TokenStream}; use quote::quote_spanned; use syn::{ parse::{Parse, ParseStream}, spanned::Spanned, Field, ItemStruct, Token, Type, }; use crate::attr_parsing::{combine_unary_attribute, parse_attrs, Combine}; pub(crate) fn expand(item: ItemStruct) -> syn::Result { if !item.generics.params.is_empty() { return Err(syn::Error::new_spanned( item.generics, "`#[derive(FromRef)]` doesn't support generics", )); } let tokens = item .fields .iter() .enumerate() .map(|(idx, field)| expand_field(&item.ident, idx, field)) 
.collect(); Ok(tokens) } fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream { let FieldAttrs { skip } = match parse_attrs("from_ref", &field.attrs) { Ok(attrs) => attrs, Err(err) => return err.into_compile_error(), }; if skip.is_some() { return TokenStream::default(); } let field_ty = &field.ty; let span = field.ty.span(); let body = if let Some(field_ident) = &field.ident { if matches!(field_ty, Type::Reference(_)) { quote_spanned! {span=> state.#field_ident } } else { quote_spanned! {span=> state.#field_ident.clone() } } } else { let idx = syn::Index { index: idx as _, span: field.span(), }; quote_spanned! {span=> state.#idx.clone() } }; quote_spanned! {span=> #[allow(clippy::clone_on_copy)] impl ::axum::extract::FromRef<#state> for #field_ty { fn from_ref(state: &#state) -> Self { #body } } } } mod kw { syn::custom_keyword!(skip); } #[derive(Default)] pub(super) struct FieldAttrs { pub(super) skip: Option, } impl Parse for FieldAttrs { fn parse(input: ParseStream) -> syn::Result { let mut skip = None; while !input.is_empty() { let lh = input.lookahead1(); if lh.peek(kw::skip) { skip = Some(input.parse()?); } else { return Err(lh.error()); } let _ = input.parse::(); } Ok(Self { skip }) } } impl Combine for FieldAttrs { fn combine(mut self, other: Self) -> syn::Result { let Self { skip } = other; combine_unary_attribute(&mut self.skip, skip)?; Ok(self) } } #[test] fn ui() { crate::run_ui_tests("from_ref"); } axum-axum-v0.6.20/axum-macros/src/from_request.rs000066400000000000000000001075301446276533600220410ustar00rootroot00000000000000use self::attr::FromRequestContainerAttrs; use crate::{ attr_parsing::{parse_attrs, second}, from_request::attr::FromRequestFieldAttrs, }; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned, ToTokens}; use std::{collections::HashSet, fmt, iter}; use syn::{ parse_quote, punctuated::Punctuated, spanned::Spanned, Fields, Ident, Path, Token, Type, }; mod attr; #[derive(Clone, Copy)] pub(crate) 
enum Trait { FromRequest, FromRequestParts, } impl Trait { fn body_type(&self) -> impl Iterator { match self { Trait::FromRequest => Some(parse_quote!(B)).into_iter(), Trait::FromRequestParts => None.into_iter(), } } fn via_marker_type(&self) -> Option { match self { Trait::FromRequest => Some(parse_quote!(M)), Trait::FromRequestParts => None, } } } impl fmt::Display for Trait { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Trait::FromRequest => f.write_str("FromRequest"), Trait::FromRequestParts => f.write_str("FromRequestParts"), } } } #[derive(Debug)] enum State { Custom(syn::Type), Default(syn::Type), CannotInfer, } impl State { /// ```not_rust /// impl A for B {} /// ^ this type /// ``` fn impl_generics(&self) -> impl Iterator { match self { State::Default(inner) => Some(inner.clone()), State::Custom(_) => None, State::CannotInfer => Some(parse_quote!(S)), } .into_iter() } /// ```not_rust /// impl A for B {} /// ^ this type /// ``` fn trait_generics(&self) -> impl Iterator { match self { State::Default(inner) => iter::once(inner.clone()), State::Custom(inner) => iter::once(inner.clone()), State::CannotInfer => iter::once(parse_quote!(S)), } } fn bounds(&self) -> TokenStream { match self { State::Custom(_) => quote! {}, State::Default(inner) => quote! { #inner: ::std::marker::Send + ::std::marker::Sync, }, State::CannotInfer => quote! { S: ::std::marker::Send + ::std::marker::Sync, }, } } } impl ToTokens for State { fn to_tokens(&self, tokens: &mut TokenStream) { match self { State::Custom(inner) => inner.to_tokens(tokens), State::Default(inner) => inner.to_tokens(tokens), State::CannotInfer => quote! 
{ S }.to_tokens(tokens), } } } pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { match item { syn::Item::Struct(item) => { let syn::ItemStruct { attrs, ident, generics, fields, semi_token: _, vis: _, struct_token: _, } = item; let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?; let FromRequestContainerAttrs { via, rejection, state, } = parse_attrs("from_request", &attrs)?; let state = match state { Some((_, state)) => State::Custom(state), None => { let mut inferred_state_types: HashSet<_> = infer_state_type_from_field_types(&fields) .chain(infer_state_type_from_field_attributes(&fields)) .collect(); if let Some((_, via)) = &via { inferred_state_types.extend(state_from_via(&ident, via)); } match inferred_state_types.len() { 0 => State::Default(syn::parse_quote!(S)), 1 => State::Custom(inferred_state_types.iter().next().unwrap().to_owned()), _ => State::CannotInfer, } } }; let trait_impl = match (via.map(second), rejection.map(second)) { (Some(via), rejection) => impl_struct_by_extracting_all_at_once( ident, fields, via, rejection, generic_ident, &state, tr, )?, (None, rejection) => { error_on_generic_ident(generic_ident, tr)?; impl_struct_by_extracting_each_field(ident, fields, rejection, &state, tr)? } }; if let State::CannotInfer = state { let attr_name = match tr { Trait::FromRequest => "from_request", Trait::FromRequestParts => "from_request_parts", }; let compile_error = syn::Error::new( Span::call_site(), format_args!( "can't infer state type, please add \ `#[{attr_name}(state = MyStateType)]` attribute", ), ) .into_compile_error(); Ok(quote! 
{ #trait_impl #compile_error }) } else { Ok(trait_impl) } } syn::Item::Enum(item) => { let syn::ItemEnum { attrs, vis: _, enum_token: _, ident, generics, brace_token: _, variants, } = item; let generics_error = format!("`#[derive({tr})] on enums don't support generics"); if !generics.params.is_empty() { return Err(syn::Error::new_spanned(generics, generics_error)); } if let Some(where_clause) = generics.where_clause { return Err(syn::Error::new_spanned(where_clause, generics_error)); } let FromRequestContainerAttrs { via, rejection, state, } = parse_attrs("from_request", &attrs)?; let state = match state { Some((_, state)) => State::Custom(state), None => (|| { let via = via.as_ref().map(|(_, via)| via)?; state_from_via(&ident, via).map(State::Custom) })() .unwrap_or_else(|| State::Default(syn::parse_quote!(S))), }; match (via.map(second), rejection) { (Some(via), rejection) => impl_enum_by_extracting_all_at_once( ident, variants, via, rejection.map(second), state, tr, ), (None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned( rejection_kw, "cannot use `rejection` without `via`", )), (None, _) => Err(syn::Error::new( Span::call_site(), "missing `#[from_request(via(...))]`", )), } } _ => Err(syn::Error::new_spanned(item, "expected `struct` or `enum`")), } } fn parse_single_generic_type_on_struct( generics: syn::Generics, fields: &syn::Fields, tr: Trait, ) -> syn::Result> { if let Some(where_clause) = generics.where_clause { return Err(syn::Error::new_spanned( where_clause, format_args!("#[derive({tr})] doesn't support structs with `where` clauses"), )); } match generics.params.len() { 0 => Ok(None), 1 => { let param = generics.params.first().unwrap(); let ty_ident = match param { syn::GenericParam::Type(ty) => &ty.ident, syn::GenericParam::Lifetime(lifetime) => { return Err(syn::Error::new_spanned( lifetime, format_args!( "#[derive({tr})] doesn't support structs \ that are generic over lifetimes" ), )); } syn::GenericParam::Const(konst) => { return 
Err(syn::Error::new_spanned( konst, format_args!( "#[derive({tr})] doesn't support structs \ that have const generics" ), )); } }; match fields { syn::Fields::Named(fields_named) => { return Err(syn::Error::new_spanned( fields_named, format_args!( "#[derive({tr})] doesn't support named fields \ for generic structs. Use a tuple struct instead" ), )); } syn::Fields::Unnamed(fields_unnamed) => { if fields_unnamed.unnamed.len() != 1 { return Err(syn::Error::new_spanned( fields_unnamed, format_args!( "#[derive({tr})] only supports generics on \ tuple structs that have exactly one field" ), )); } let field = fields_unnamed.unnamed.first().unwrap(); if let syn::Type::Path(type_path) = &field.ty { if type_path .path .get_ident() .map_or(true, |field_type_ident| field_type_ident != ty_ident) { return Err(syn::Error::new_spanned( type_path, format_args!( "#[derive({tr})] only supports generics on \ tuple structs that have exactly one field of the generic type" ), )); } } else { return Err(syn::Error::new_spanned(&field.ty, "Expected type path")); } } syn::Fields::Unit => return Ok(None), } Ok(Some(ty_ident.clone())) } _ => Err(syn::Error::new_spanned( generics, format_args!("#[derive({tr})] only supports 0 or 1 generic type parameters"), )), } } fn error_on_generic_ident(generic_ident: Option, tr: Trait) -> syn::Result<()> { if let Some(generic_ident) = generic_ident { Err(syn::Error::new_spanned( generic_ident, format_args!( "#[derive({tr})] only supports generics when used with #[from_request(via)]" ), )) } else { Ok(()) } } fn impl_struct_by_extracting_each_field( ident: syn::Ident, fields: syn::Fields, rejection: Option, state: &State, tr: Trait, ) -> syn::Result { let trait_fn_body = match state { State::CannotInfer => quote! { ::std::unimplemented!() }, _ => { let extract_fields = extract_fields(&fields, &rejection, tr)?; quote! 
{ ::std::result::Result::Ok(Self { #(#extract_fields)* }) } } }; let rejection_ident = if let Some(rejection) = rejection { quote!(#rejection) } else if has_no_fields(&fields) { quote!(::std::convert::Infallible) } else { quote!(::axum::response::Response) }; let impl_generics = tr .body_type() .chain(state.impl_generics()) .collect::>(); let trait_generics = state .trait_generics() .chain(tr.body_type()) .collect::>(); let state_bounds = state.bounds(); Ok(match tr { Trait::FromRequest => quote! { #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, #state_bounds { type Rejection = #rejection_ident; async fn from_request( mut req: ::axum::http::Request, state: &#state, ) -> ::std::result::Result { #trait_fn_body } } }, Trait::FromRequestParts => quote! { #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where #state_bounds { type Rejection = #rejection_ident; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { #trait_fn_body } } }, }) } fn has_no_fields(fields: &syn::Fields) -> bool { match fields { syn::Fields::Named(fields) => fields.named.is_empty(), syn::Fields::Unnamed(fields) => fields.unnamed.is_empty(), syn::Fields::Unit => true, } } fn extract_fields( fields: &syn::Fields, rejection: &Option, tr: Trait, ) -> syn::Result> { fn member(field: &syn::Field, index: usize) -> TokenStream { match &field.ident { Some(ident) => quote! { #ident }, _ => { let member = syn::Member::Unnamed(syn::Index { index: index as u32, span: field.span(), }); quote! 
{ #member } } } } fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { if let Some((_, path)) = via { let span = path.span(); quote_spanned! {span=> |#path(inner)| inner } } else { quote_spanned! {ty_span=> ::std::convert::identity } } } let mut fields_iter = fields.iter(); let last = match tr { // Use FromRequestParts for all elements except the last Trait::FromRequest => fields_iter.next_back(), // Use FromRequestParts for all elements Trait::FromRequestParts => None, }; let mut res: Vec<_> = fields_iter .enumerate() .map(|(index, field)| { let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?; let member = member(field, index); let ty_span = field.ty.span(); let into_inner = into_inner(via, ty_span); if peel_option(&field.ty).is_some() { let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = ::axum::extract::FromRequestParts::from_request_parts( &mut parts, state, ) .await .ok() .map(#into_inner); req = ::axum::http::Request::from_parts(parts, body); value }, } } Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { ::axum::extract::FromRequestParts::from_request_parts( parts, state, ) .await .ok() .map(#into_inner) }, } } }; Ok(tokens) } else if peel_result_ok(&field.ty).is_some() { let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = ::axum::extract::FromRequestParts::from_request_parts( &mut parts, state, ) .await .map(#into_inner); req = ::axum::http::Request::from_parts(parts, body); value }, } } Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { ::axum::extract::FromRequestParts::from_request_parts( parts, state, ) .await .map(#into_inner) }, } } }; Ok(tokens) } else { let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { quote! 
{ ::axum::response::IntoResponse::into_response } }; let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = ::axum::extract::FromRequestParts::from_request_parts( &mut parts, state, ) .await .map(#into_inner) .map_err(#map_err)?; req = ::axum::http::Request::from_parts(parts, body); value }, } } Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { ::axum::extract::FromRequestParts::from_request_parts( parts, state, ) .await .map(#into_inner) .map_err(#map_err)? }, } } }; Ok(tokens) } }) .collect::>()?; // Handle the last element, if deriving FromRequest if let Some(field) = last { let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?; let member = member(field, fields.len() - 1); let ty_span = field.ty.span(); let into_inner = into_inner(via, ty_span); let item = if peel_option(&field.ty).is_some() { quote_spanned! {ty_span=> #member: { ::axum::extract::FromRequest::from_request(req, state) .await .ok() .map(#into_inner) }, } } else if peel_result_ok(&field.ty).is_some() { quote_spanned! {ty_span=> #member: { ::axum::extract::FromRequest::from_request(req, state) .await .map(#into_inner) }, } } else { let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { quote! { ::axum::response::IntoResponse::into_response } }; quote_spanned! {ty_span=> #member: { ::axum::extract::FromRequest::from_request(req, state) .await .map(#into_inner) .map_err(#map_err)? 
}, } }; res.push(item); } Ok(res) } fn peel_option(ty: &syn::Type) -> Option<&syn::Type> { let type_path = if let syn::Type::Path(type_path) = ty { type_path } else { return None; }; let segment = type_path.path.segments.last()?; if segment.ident != "Option" { return None; } let args = match &segment.arguments { syn::PathArguments::AngleBracketed(args) => args, syn::PathArguments::Parenthesized(_) | syn::PathArguments::None => return None, }; let ty = if args.args.len() == 1 { args.args.last().unwrap() } else { return None; }; if let syn::GenericArgument::Type(ty) = ty { Some(ty) } else { None } } fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> { let type_path = if let syn::Type::Path(type_path) = ty { type_path } else { return None; }; let segment = type_path.path.segments.last()?; if segment.ident != "Result" { return None; } let args = match &segment.arguments { syn::PathArguments::AngleBracketed(args) => args, syn::PathArguments::Parenthesized(_) | syn::PathArguments::None => return None, }; let ty = if args.args.len() == 2 { args.args.first().unwrap() } else { return None; }; if let syn::GenericArgument::Type(ty) = ty { Some(ty) } else { None } } fn impl_struct_by_extracting_all_at_once( ident: syn::Ident, fields: syn::Fields, via_path: syn::Path, rejection: Option, generic_ident: Option, state: &State, tr: Trait, ) -> syn::Result { let fields = match fields { syn::Fields::Named(fields) => fields.named.into_iter(), syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(), syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(), }; for field in fields { let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?; if let Some((via, _)) = via { return Err(syn::Error::new_spanned( via, "`#[from_request(via(...))]` on a field cannot be used \ together with `#[from_request(...)]` on the container", )); } } let path_span = via_path.span(); let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { let 
rejection = quote! { #rejection }; let map_err = quote! { ::std::convert::From::from }; (rejection, map_err) } else { let rejection = quote! { ::axum::response::Response }; let map_err = quote! { ::axum::response::IntoResponse::into_response }; (rejection, map_err) }; // for something like // // ``` // #[derive(Clone, Default, FromRequest)] // #[from_request(via(State))] // struct AppState {} // ``` // // we need to implement `impl FromRequest` but only for // - `#[derive(FromRequest)]`, not `#[derive(FromRequestParts)]` // - `State`, not other extractors // // honestly not sure why but the tests all pass let via_marker_type = if path_ident_is_state(&via_path) { tr.via_marker_type() } else { None }; let impl_generics = tr .body_type() .chain(via_marker_type.clone()) .chain(state.impl_generics()) .chain(generic_ident.is_some().then(|| parse_quote!(T))) .collect::>(); let trait_generics = state .trait_generics() .chain(tr.body_type()) .chain(via_marker_type) .collect::>(); let ident_generics = generic_ident .is_some() .then(|| quote! { }) .unwrap_or_default(); let rejection_bound = rejection.as_ref().map(|rejection| { match (tr, generic_ident.is_some()) { (Trait::FromRequest, true) => { quote! { #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequest<#trait_generics>>::Rejection>, } }, (Trait::FromRequest, false) => { quote! { #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequest<#trait_generics>>::Rejection>, } }, (Trait::FromRequestParts, true) => { quote! { #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>, } }, (Trait::FromRequestParts, false) => { quote! { #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>, } } } }).unwrap_or_default(); let via_type_generics = if generic_ident.is_some() { quote! { T } } else { quote! { Self } }; let value_to_self = if generic_ident.is_some() { quote! 
{ #ident(value) } } else { quote! { value } }; let state_bounds = state.bounds(); let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics where #via_path<#via_type_generics>: ::axum::extract::FromRequest<#trait_generics>, #rejection_bound B: ::std::marker::Send + 'static, #state_bounds { type Rejection = #associated_rejection_type; async fn from_request( req: ::axum::http::Request, state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequest::from_request(req, state) .await .map(|#via_path(value)| #value_to_self) .map_err(#map_err) } } } } Trait::FromRequestParts => { quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics where #via_path<#via_type_generics>: ::axum::extract::FromRequestParts<#trait_generics>, #rejection_bound #state_bounds { type Rejection = #associated_rejection_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequestParts::from_request_parts(parts, state) .await .map(|#via_path(value)| #value_to_self) .map_err(#map_err) } } } } }; Ok(tokens) } fn impl_enum_by_extracting_all_at_once( ident: syn::Ident, variants: Punctuated, path: syn::Path, rejection: Option, state: State, tr: Trait, ) -> syn::Result { for variant in variants { let FromRequestFieldAttrs { via } = parse_attrs("from_request", &variant.attrs)?; if let Some((via, _)) = via { return Err(syn::Error::new_spanned( via, "`#[from_request(via(...))]` cannot be used on variants", )); } let fields = match variant.fields { syn::Fields::Named(fields) => fields.named.into_iter(), syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(), syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(), }; for 
field in fields { let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?; if let Some((via, _)) = via { return Err(syn::Error::new_spanned( via, "`#[from_request(via(...))]` cannot be used inside variants", )); } } } let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { let rejection = quote! { #rejection }; let map_err = quote! { ::std::convert::From::from }; (rejection, map_err) } else { let rejection = quote! { ::axum::response::Response }; let map_err = quote! { ::axum::response::IntoResponse::into_response }; (rejection, map_err) }; let path_span = path.span(); let impl_generics = tr .body_type() .chain(state.impl_generics()) .collect::>(); let trait_generics = state .trait_generics() .chain(tr.body_type()) .collect::>(); let state_bounds = state.bounds(); let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, #state_bounds { type Rejection = #associated_rejection_type; async fn from_request( req: ::axum::http::Request, state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequest::from_request(req, state) .await .map(|#path(inner)| inner) .map_err(#map_err) } } } } Trait::FromRequestParts => { quote_spanned! 
{path_span=> #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where #state_bounds { type Rejection = #associated_rejection_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequestParts::from_request_parts(parts, state) .await .map(|#path(inner)| inner) .map_err(#map_err) } } } } }; Ok(tokens) } /// For a struct like /// /// ```skip /// struct Extractor { /// state: State, /// } /// ``` /// /// We can infer the state type to be `AppState` because it appears inside a `State` fn infer_state_type_from_field_types(fields: &Fields) -> impl Iterator + '_ { match fields { Fields::Named(fields_named) => Box::new(crate::infer_state_types( fields_named.named.iter().map(|field| &field.ty), )) as Box>, Fields::Unnamed(fields_unnamed) => Box::new(crate::infer_state_types( fields_unnamed.unnamed.iter().map(|field| &field.ty), )), Fields::Unit => Box::new(iter::empty()), } } /// For a struct like /// /// ```skip /// struct Extractor { /// #[from_request(via(State))] /// state: AppState, /// } /// ``` /// /// We can infer the state type to be `AppState` because it has `via(State)` and thus can be /// extracted with `State` fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator + '_ { match fields { Fields::Named(fields_named) => { Box::new(fields_named.named.iter().filter_map(|field| { // TODO(david): its a little wasteful to parse the attributes again here // ideally we should parse things once and pass the data down let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs).ok()?; let (_, via_path) = via?; path_ident_is_state(&via_path).then(|| field.ty.clone()) })) as Box> } Fields::Unnamed(fields_unnamed) => { Box::new(fields_unnamed.unnamed.iter().filter_map(|field| { // TODO(david): its a little wasteful to parse the attributes again here // ideally we should 
parse things once and pass the data down let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs).ok()?; let (_, via_path) = via?; path_ident_is_state(&via_path).then(|| field.ty.clone()) })) } Fields::Unit => Box::new(iter::empty()), } } fn path_ident_is_state(path: &Path) -> bool { if let Some(last_segment) = path.segments.last() { last_segment.ident == "State" } else { false } } fn state_from_via(ident: &Ident, via: &Path) -> Option { path_ident_is_state(via).then(|| parse_quote!(#ident)) } #[test] fn ui() { crate::run_ui_tests("from_request"); } /// For some reason the compiler error for this is different locally and on CI. No idea why... So /// we don't use trybuild for this test. /// /// ```compile_fail /// #[derive(axum_macros::FromRequest)] /// struct Extractor { /// thing: bool, /// } /// ``` #[allow(dead_code)] fn test_field_doesnt_impl_from_request() {} axum-axum-v0.6.20/axum-macros/src/from_request/000077500000000000000000000000001446276533600214655ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/src/from_request/attr.rs000066400000000000000000000047141446276533600230130ustar00rootroot00000000000000use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, Combine}; use syn::{ parse::{Parse, ParseStream}, Token, }; pub(crate) mod kw { syn::custom_keyword!(via); syn::custom_keyword!(rejection); syn::custom_keyword!(state); } #[derive(Default)] pub(super) struct FromRequestContainerAttrs { pub(super) via: Option<(kw::via, syn::Path)>, pub(super) rejection: Option<(kw::rejection, syn::Path)>, pub(super) state: Option<(kw::state, syn::Type)>, } impl Parse for FromRequestContainerAttrs { fn parse(input: ParseStream) -> syn::Result { let mut via = None; let mut rejection = None; let mut state = None; while !input.is_empty() { let lh = input.lookahead1(); if lh.peek(kw::via) { parse_parenthesized_attribute(input, &mut via)?; } else if lh.peek(kw::rejection) { parse_parenthesized_attribute(input, &mut rejection)?; 
} else if lh.peek(kw::state) { parse_parenthesized_attribute(input, &mut state)?; } else { return Err(lh.error()); } let _ = input.parse::(); } Ok(Self { via, rejection, state, }) } } impl Combine for FromRequestContainerAttrs { fn combine(mut self, other: Self) -> syn::Result { let Self { via, rejection, state, } = other; combine_attribute(&mut self.via, via)?; combine_attribute(&mut self.rejection, rejection)?; combine_attribute(&mut self.state, state)?; Ok(self) } } #[derive(Default)] pub(super) struct FromRequestFieldAttrs { pub(super) via: Option<(kw::via, syn::Path)>, } impl Parse for FromRequestFieldAttrs { fn parse(input: ParseStream) -> syn::Result { let mut via = None; while !input.is_empty() { let lh = input.lookahead1(); if lh.peek(kw::via) { parse_parenthesized_attribute(input, &mut via)?; } else { return Err(lh.error()); } let _ = input.parse::(); } Ok(Self { via }) } } impl Combine for FromRequestFieldAttrs { fn combine(mut self, other: Self) -> syn::Result { let Self { via } = other; combine_attribute(&mut self.via, via)?; Ok(self) } } axum-axum-v0.6.20/axum-macros/src/lib.rs000066400000000000000000000523511446276533600200740ustar00rootroot00000000000000//! Macros for [`axum`]. //! //! 
[`axum`]: https://crates.io/crates/axum #![warn( clippy::all, clippy::dbg_macro, clippy::todo, clippy::empty_enum, clippy::enum_glob_use, clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, clippy::needless_continue, clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, clippy::suboptimal_flops, clippy::lossy_float_literal, clippy::rest_pat_in_fully_bound_structs, clippy::fn_params_excessive_bools, clippy::exit, clippy::inefficient_to_string, clippy::linkedlist, clippy::macro_use_imports, clippy::option_option, clippy::verbose_file_reads, clippy::unnested_or_patterns, clippy::str_to_string, rust_2018_idioms, future_incompatible, nonstandard_style, missing_debug_implementations, missing_docs )] #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] #![forbid(unsafe_code)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse::Parse, Type}; mod attr_parsing; #[cfg(feature = "__private")] mod axum_test; mod debug_handler; mod from_ref; mod from_request; mod typed_path; mod with_position; use from_request::Trait::{FromRequest, FromRequestParts}; /// Derive an implementation of [`FromRequest`]. /// /// Supports generating two kinds of implementations: /// 1. One that extracts each field individually. /// 2. Another that extracts the whole type at once via another extractor. 
/// /// # Each field individually /// /// By default `#[derive(FromRequest)]` will call `FromRequest::from_request` for each field: /// /// ``` /// use axum_macros::FromRequest; /// use axum::{ /// extract::{Extension, TypedHeader}, /// headers::ContentType, /// body::Bytes, /// }; /// /// #[derive(FromRequest)] /// struct MyExtractor { /// state: Extension, /// content_type: TypedHeader, /// request_body: Bytes, /// } /// /// #[derive(Clone)] /// struct State { /// // ... /// } /// /// async fn handler(extractor: MyExtractor) {} /// ``` /// /// This requires that each field is an extractor (i.e. implements [`FromRequest`]). /// /// Note that only the last field can consume the request body. Therefore this doesn't compile: /// /// ```compile_fail /// use axum_macros::FromRequest; /// use axum::body::Bytes; /// /// #[derive(FromRequest)] /// struct MyExtractor { /// // only the last field can implement `FromRequest` /// // other fields must only implement `FromRequestParts` /// bytes: Bytes, /// string: String, /// } /// ``` /// /// ## Extracting via another extractor /// /// You can use `#[from_request(via(...))]` to extract a field via another extractor, meaning the /// field itself doesn't need to implement `FromRequest`: /// /// ``` /// use axum_macros::FromRequest; /// use axum::{ /// extract::{Extension, TypedHeader}, /// headers::ContentType, /// body::Bytes, /// }; /// /// #[derive(FromRequest)] /// struct MyExtractor { /// // This will extracted via `Extension::::from_request` /// #[from_request(via(Extension))] /// state: State, /// // and this via `TypedHeader::::from_request` /// #[from_request(via(TypedHeader))] /// content_type: ContentType, /// // Can still be combined with other extractors /// request_body: Bytes, /// } /// /// #[derive(Clone)] /// struct State { /// // ... 
/// } /// /// async fn handler(extractor: MyExtractor) {} /// ``` /// /// Note this requires the via extractor to be a generic newtype struct (a tuple struct with /// exactly one public field) that implements `FromRequest`: /// /// ``` /// pub struct ViaExtractor(pub T); /// /// // impl FromRequest for ViaExtractor { ... } /// ``` /// /// More complex via extractors are not supported and require writing a manual implementation. /// /// ## Optional fields /// /// `#[from_request(via(...))]` supports `Option<_>` and `Result<_, _>` to make fields optional: /// /// ``` /// use axum_macros::FromRequest; /// use axum::{ /// extract::{TypedHeader, rejection::TypedHeaderRejection}, /// headers::{ContentType, UserAgent}, /// }; /// /// #[derive(FromRequest)] /// struct MyExtractor { /// // This will extracted via `Option::>::from_request` /// #[from_request(via(TypedHeader))] /// content_type: Option, /// // This will extracted via /// // `Result::, TypedHeaderRejection>::from_request` /// #[from_request(via(TypedHeader))] /// user_agent: Result, /// } /// /// async fn handler(extractor: MyExtractor) {} /// ``` /// /// ## The rejection /// /// By default [`axum::response::Response`] will be used as the rejection. You can also use your own /// rejection type with `#[from_request(rejection(YourType))]`: /// /// ``` /// use axum::{ /// extract::{ /// rejection::{ExtensionRejection, StringRejection}, /// FromRequest, /// }, /// Extension, /// response::{Response, IntoResponse}, /// }; /// /// #[derive(FromRequest)] /// #[from_request(rejection(MyRejection))] /// struct MyExtractor { /// state: Extension, /// body: String, /// } /// /// struct MyRejection(Response); /// /// // This tells axum how to convert `Extension`'s rejections into `MyRejection` /// impl From for MyRejection { /// fn from(rejection: ExtensionRejection) -> Self { /// // ... 
/// # todo!() /// } /// } /// /// // This tells axum how to convert `String`'s rejections into `MyRejection` /// impl From for MyRejection { /// fn from(rejection: StringRejection) -> Self { /// // ... /// # todo!() /// } /// } /// /// // All rejections must implement `IntoResponse` /// impl IntoResponse for MyRejection { /// fn into_response(self) -> Response { /// self.0 /// } /// } /// ``` /// /// # The whole type at once /// /// By using `#[from_request(via(...))]` on the container you can extract the whole type at once, /// instead of each field individually: /// /// ``` /// use axum_macros::FromRequest; /// use axum::extract::Extension; /// /// // This will extracted via `Extension::::from_request` /// #[derive(Clone, FromRequest)] /// #[from_request(via(Extension))] /// struct State { /// // ... /// } /// /// async fn handler(state: State) {} /// ``` /// /// The rejection will be the "via extractors"'s rejection. For the previous example that would be /// [`axum::extract::rejection::ExtensionRejection`]. /// /// You can use a different rejection type with `#[from_request(rejection(YourType))]`: /// /// ``` /// use axum_macros::FromRequest; /// use axum::{ /// extract::{Extension, rejection::ExtensionRejection}, /// response::{IntoResponse, Response}, /// Json, /// http::StatusCode, /// }; /// use serde_json::json; /// /// // This will extracted via `Extension::::from_request` /// #[derive(Clone, FromRequest)] /// #[from_request( /// via(Extension), /// // Use your own rejection type /// rejection(MyRejection), /// )] /// struct State { /// // ... /// } /// /// struct MyRejection(Response); /// /// // This tells axum how to convert `Extension`'s rejections into `MyRejection` /// impl From for MyRejection { /// fn from(rejection: ExtensionRejection) -> Self { /// let response = ( /// StatusCode::INTERNAL_SERVER_ERROR, /// Json(json!({ "error": "Something went wrong..." 
})), /// ).into_response(); /// /// MyRejection(response) /// } /// } /// /// // All rejections must implement `IntoResponse` /// impl IntoResponse for MyRejection { /// fn into_response(self) -> Response { /// self.0 /// } /// } /// /// async fn handler(state: State) {} /// ``` /// /// This allows you to wrap other extractors and easily customize the rejection: /// /// ``` /// use axum_macros::FromRequest; /// use axum::{ /// extract::{Extension, rejection::JsonRejection}, /// response::{IntoResponse, Response}, /// http::StatusCode, /// }; /// use serde_json::json; /// use serde::Deserialize; /// /// // create an extractor that internally uses `axum::Json` but has a custom rejection /// #[derive(FromRequest)] /// #[from_request(via(axum::Json), rejection(MyRejection))] /// struct MyJson(T); /// /// struct MyRejection(Response); /// /// impl From for MyRejection { /// fn from(rejection: JsonRejection) -> Self { /// let response = ( /// StatusCode::INTERNAL_SERVER_ERROR, /// axum::Json(json!({ "error": rejection.to_string() })), /// ).into_response(); /// /// MyRejection(response) /// } /// } /// /// impl IntoResponse for MyRejection { /// fn into_response(self) -> Response { /// self.0 /// } /// } /// /// #[derive(Deserialize)] /// struct Payload {} /// /// async fn handler( /// // make sure to use `MyJson` and not `axum::Json` /// MyJson(payload): MyJson, /// ) {} /// ``` /// /// # Known limitations /// /// Generics are only supported on tuple structs with exactly on field. 
Thus this doesn't work /// /// ```compile_fail /// #[derive(axum_macros::FromRequest)] /// struct MyExtractor { /// thing: Option, /// } /// ``` /// /// [`FromRequest`]: https://docs.rs/axum/latest/axum/extract/trait.FromRequest.html /// [`axum::response::Response`]: https://docs.rs/axum/0.6/axum/response/type.Response.html /// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html #[proc_macro_derive(FromRequest, attributes(from_request))] pub fn derive_from_request(item: TokenStream) -> TokenStream { expand_with(item, |item| from_request::expand(item, FromRequest)) } /// Derive an implementation of [`FromRequestParts`]. /// /// This works similarly to `#[derive(FromRequest)]` except it uses [`FromRequestParts`]. All the /// same options are supported. /// /// # Example /// /// ``` /// use axum_macros::FromRequestParts; /// use axum::{ /// extract::{Query, TypedHeader}, /// headers::ContentType, /// }; /// use std::collections::HashMap; /// /// #[derive(FromRequestParts)] /// struct MyExtractor { /// #[from_request(via(Query))] /// query_params: HashMap, /// content_type: TypedHeader, /// } /// /// async fn handler(extractor: MyExtractor) {} /// ``` /// /// # Cannot extract the body /// /// [`FromRequestParts`] cannot extract the request body: /// /// ```compile_fail /// use axum_macros::FromRequestParts; /// /// #[derive(FromRequestParts)] /// struct MyExtractor { /// body: String, /// } /// ``` /// /// Use `#[derive(FromRequest)]` for that. /// /// [`FromRequestParts`]: https://docs.rs/axum/0.6/axum/extract/trait.FromRequestParts.html #[proc_macro_derive(FromRequestParts, attributes(from_request))] pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { expand_with(item, |item| from_request::expand(item, FromRequestParts)) } /// Generates better error messages when applied handler functions. /// /// While using [`axum`], you can get long error messages for simple mistakes. 
For example: /// /// ```compile_fail /// use axum::{routing::get, Router}; /// /// #[tokio::main] /// async fn main() { /// let app = Router::new().route("/", get(handler)); /// /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve(app.into_make_service()) /// .await /// .unwrap(); /// } /// /// fn handler() -> &'static str { /// "Hello, world" /// } /// ``` /// /// You will get a long error message about function not implementing [`Handler`] trait. But why /// does this function not implement it? To figure it out, the [`debug_handler`] macro can be used. /// /// ```compile_fail /// # use axum::{routing::get, Router}; /// # use axum_macros::debug_handler; /// # /// # #[tokio::main] /// # async fn main() { /// # let app = Router::new().route("/", get(handler)); /// # /// # axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// # .serve(app.into_make_service()) /// # .await /// # .unwrap(); /// # } /// # /// #[debug_handler] /// fn handler() -> &'static str { /// "Hello, world" /// } /// ``` /// /// ```text /// error: handlers must be async functions /// --> main.rs:xx:1 /// | /// xx | fn handler() -> &'static str { /// | ^^ /// ``` /// /// As the error message says, handler function needs to be async. /// /// ``` /// use axum::{routing::get, Router, debug_handler}; /// /// #[tokio::main] /// async fn main() { /// # async { /// let app = Router::new().route("/", get(handler)); /// /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve(app.into_make_service()) /// .await /// .unwrap(); /// # }; /// } /// /// #[debug_handler] /// async fn handler() -> &'static str { /// "Hello, world" /// } /// ``` /// /// # Changing request body type /// /// By default `#[debug_handler]` assumes your request body type is `axum::body::Body`. This will /// work for most extractors but, for example, it wont work for `Request`, /// which only implements `FromRequest` and _not_ `FromRequest`. 
/// /// To work around that the request body type can be customized like so: /// /// ``` /// use axum::{body::BoxBody, http::Request, debug_handler}; /// /// #[debug_handler(body = BoxBody)] /// async fn handler(request: Request) {} /// ``` /// /// # Changing state type /// /// By default `#[debug_handler]` assumes your state type is `()` unless your handler has a /// [`axum::extract::State`] argument: /// /// ``` /// use axum::{debug_handler, extract::State}; /// /// #[debug_handler] /// async fn handler( /// // this makes `#[debug_handler]` use `AppState` /// State(state): State, /// ) {} /// /// #[derive(Clone)] /// struct AppState {} /// ``` /// /// If your handler takes multiple [`axum::extract::State`] arguments or you need to otherwise /// customize the state type you can set it with `#[debug_handler(state = ...)]`: /// /// ``` /// use axum::{debug_handler, extract::{State, FromRef}}; /// /// #[debug_handler(state = AppState)] /// async fn handler( /// State(app_state): State, /// State(inner_state): State, /// ) {} /// /// #[derive(Clone)] /// struct AppState { /// inner: InnerState, /// } /// /// #[derive(Clone)] /// struct InnerState {} /// /// impl FromRef for InnerState { /// fn from_ref(state: &AppState) -> Self { /// state.inner.clone() /// } /// } /// ``` /// /// # Limitations /// /// This macro does not work for functions in an `impl` block that don't have a `self` parameter: /// /// ```compile_fail /// use axum::{debug_handler, extract::Path}; /// /// struct App {} /// /// impl App { /// #[debug_handler] /// async fn handler(Path(_): Path) {} /// } /// ``` /// /// This will yield an error similar to this: /// /// ```text /// error[E0425]: cannot find function `__axum_macros_check_handler_0_from_request_check` in this scope // --> src/main.rs:xx:xx // | // xx | pub async fn handler(Path(_): Path) {} // | ^^^^ not found in this scope /// ``` /// /// # Performance /// /// This macro has no effect when compiled with the release profile. (eg. 
`cargo build --release`) /// /// [`axum`]: https://docs.rs/axum/latest /// [`Handler`]: https://docs.rs/axum/latest/axum/handler/trait.Handler.html /// [`axum::extract::State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html /// [`debug_handler`]: macro@debug_handler #[proc_macro_attribute] pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream { #[cfg(not(debug_assertions))] return input; #[cfg(debug_assertions)] return expand_attr_with(_attr, input, debug_handler::expand); } /// Private API: Do no use this! /// /// Attribute macro to be placed on test functions that'll generate two functions: /// /// 1. One identical to the function it was placed on. /// 2. One where calls to `Router::nest` has been replaced with `Router::nest_service` /// /// This makes it easy to that `nest` and `nest_service` behaves in the same way, without having to /// manually write identical tests for both methods. #[cfg(feature = "__private")] #[proc_macro_attribute] #[doc(hidden)] pub fn __private_axum_test(_attr: TokenStream, input: TokenStream) -> TokenStream { expand_attr_with(_attr, input, axum_test::expand) } /// Derive an implementation of [`axum_extra::routing::TypedPath`]. /// /// See that trait for more details. /// /// [`axum_extra::routing::TypedPath`]: https://docs.rs/axum-extra/latest/axum_extra/routing/trait.TypedPath.html #[proc_macro_derive(TypedPath, attributes(typed_path))] pub fn derive_typed_path(input: TokenStream) -> TokenStream { expand_with(input, typed_path::expand) } /// Derive an implementation of [`FromRef`] for each field in a struct. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// extract::{State, FromRef}, /// }; /// /// # /// # type AuthToken = String; /// # type DatabasePool = (); /// # /// // This will implement `FromRef` for each field in the struct. 
/// #[derive(FromRef, Clone)] /// struct AppState { /// auth_token: AuthToken, /// database_pool: DatabasePool, /// // fields can also be skipped /// #[from_ref(skip)] /// api_token: String, /// } /// /// // So those types can be extracted via `State` /// async fn handler(State(auth_token): State) {} /// /// async fn other_handler(State(database_pool): State) {} /// /// # let auth_token = Default::default(); /// # let database_pool = Default::default(); /// let state = AppState { /// auth_token, /// database_pool, /// api_token: "secret".to_owned(), /// }; /// /// let app = Router::new() /// .route("/", get(handler).post(other_handler)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` /// /// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html #[proc_macro_derive(FromRef, attributes(from_ref))] pub fn derive_from_ref(item: TokenStream) -> TokenStream { expand_with(item, from_ref::expand) } fn expand_with(input: TokenStream, f: F) -> TokenStream where F: FnOnce(I) -> syn::Result, I: Parse, K: ToTokens, { expand(syn::parse(input).and_then(f)) } fn expand_attr_with(attr: TokenStream, input: TokenStream, f: F) -> TokenStream where F: FnOnce(A, I) -> K, A: Parse, I: Parse, K: ToTokens, { let expand_result = (|| { let attr = syn::parse(attr)?; let input = syn::parse(input)?; Ok(f(attr, input)) })(); expand(expand_result) } fn expand(result: syn::Result) -> TokenStream where T: ToTokens, { match result { Ok(tokens) => { let tokens = (quote! 
{ #tokens }).into(); if std::env::var_os("AXUM_MACROS_DEBUG").is_some() { eprintln!("{tokens}"); } tokens } Err(err) => err.into_compile_error().into(), } } fn infer_state_types<'a, I>(types: I) -> impl Iterator + 'a where I: Iterator + 'a, { types .filter_map(|ty| { if let Type::Path(path) = ty { Some(&path.path) } else { None } }) .filter_map(|path| { if let Some(last_segment) = path.segments.last() { if last_segment.ident != "State" { return None; } match &last_segment.arguments { syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => { Some(args.args.first().unwrap()) } _ => None, } } else { None } }) .filter_map(|generic_arg| { if let syn::GenericArgument::Type(ty) = generic_arg { Some(ty) } else { None } }) .cloned() } #[cfg(test)] fn run_ui_tests(directory: &str) { #[rustversion::nightly] fn go(directory: &str) { let t = trybuild::TestCases::new(); if let Ok(mut path) = std::env::var("AXUM_TEST_ONLY") { if let Some(path_without_prefix) = path.strip_prefix("axum-macros/") { path = path_without_prefix.to_owned(); } if !path.contains(&format!("/{directory}/")) { return; } if path.contains("/fail/") { t.compile_fail(path); } else if path.contains("/pass/") { t.pass(path); } else { panic!() } } else { t.compile_fail(format!("tests/{directory}/fail/*.rs")); t.pass(format!("tests/{directory}/pass/*.rs")); } } #[rustversion::not(nightly)] fn go(_directory: &str) {} go(directory); } axum-axum-v0.6.20/axum-macros/src/typed_path.rs000066400000000000000000000306361446276533600214710ustar00rootroot00000000000000use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, ItemStruct, LitStr, Token}; use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, second, Combine}; pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { let ItemStruct { attrs, ident, generics, fields, .. 
} = &item_struct; if !generics.params.is_empty() || generics.where_clause.is_some() { return Err(syn::Error::new_spanned( generics, "`#[derive(TypedPath)]` doesn't support generics", )); } let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", attrs)?; let path = path.ok_or_else(|| { syn::Error::new( Span::call_site(), "Missing path: `#[typed_path(\"/foo/bar\")]`", ) })?; let rejection = rejection.map(second); match fields { syn::Fields::Named(_) => { let segments = parse_path(&path)?; Ok(expand_named_fields(ident, path, &segments, rejection)) } syn::Fields::Unnamed(fields) => { let segments = parse_path(&path)?; expand_unnamed_fields(fields, ident, path, &segments, rejection) } syn::Fields::Unit => expand_unit_fields(ident, path, rejection), } } mod kw { syn::custom_keyword!(rejection); } #[derive(Default)] struct Attrs { path: Option, rejection: Option<(kw::rejection, syn::Path)>, } impl Parse for Attrs { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut path = None; let mut rejection = None; while !input.is_empty() { let lh = input.lookahead1(); if lh.peek(LitStr) { path = Some(input.parse()?); } else if lh.peek(kw::rejection) { parse_parenthesized_attribute(input, &mut rejection)?; } else { return Err(lh.error()); } let _ = input.parse::(); } Ok(Self { path, rejection }) } } impl Combine for Attrs { fn combine(mut self, other: Self) -> syn::Result { let Self { path, rejection } = other; if let Some(path) = path { if self.path.is_some() { return Err(syn::Error::new_spanned( path, "path specified more than once", )); } self.path = Some(path); } combine_attribute(&mut self.rejection, rejection)?; Ok(self) } } fn expand_named_fields( ident: &syn::Ident, path: LitStr, segments: &[Segment], rejection: Option, ) -> TokenStream { let format_str = format_str_from_path(segments); let captures = captures_from_path(segments); let typed_path_impl = quote_spanned! 
{path.span()=> #[automatically_derived] impl ::axum_extra::routing::TypedPath for #ident { const PATH: &'static str = #path; } }; let display_impl = quote_spanned! {path.span()=> #[automatically_derived] impl ::std::fmt::Display for #ident { #[allow(clippy::unnecessary_to_owned)] fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { let Self { #(#captures,)* } = self; write!( f, #format_str, #( #captures = ::axum_extra::__private::utf8_percent_encode( &#captures.to_string(), ::axum_extra::__private::PATH_SEGMENT, ) ),* ) } } }; let rejection_assoc_type = rejection_assoc_type(&rejection); let map_err_rejection = map_err_rejection(&rejection); let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where S: Send + Sync, { type Rejection = #rejection_assoc_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, state: &S, ) -> ::std::result::Result { ::axum::extract::Path::from_request_parts(parts, state) .await .map(|path| path.0) #map_err_rejection } } }; quote! { #typed_path_impl #display_impl #from_request_impl } } fn expand_unnamed_fields( fields: &syn::FieldsUnnamed, ident: &syn::Ident, path: LitStr, segments: &[Segment], rejection: Option, ) -> syn::Result { let num_captures = segments .iter() .filter(|segment| match segment { Segment::Capture(_, _) => true, Segment::Static(_) => false, }) .count(); let num_fields = fields.unnamed.len(); if num_fields != num_captures { return Err(syn::Error::new_spanned( fields, format!( "Mismatch in number of captures and fields. 
Path has {} but struct has {}", simple_pluralize(num_captures, "capture"), simple_pluralize(num_fields, "field"), ), )); } let destructure_self = segments .iter() .filter_map(|segment| match segment { Segment::Capture(capture, _) => Some(capture), Segment::Static(_) => None, }) .enumerate() .map(|(idx, capture)| { let idx = syn::Index { index: idx as _, span: Span::call_site(), }; let capture = format_ident!("{}", capture, span = path.span()); quote_spanned! {path.span()=> #idx: #capture, } }); let format_str = format_str_from_path(segments); let captures = captures_from_path(segments); let typed_path_impl = quote_spanned! {path.span()=> #[automatically_derived] impl ::axum_extra::routing::TypedPath for #ident { const PATH: &'static str = #path; } }; let display_impl = quote_spanned! {path.span()=> #[automatically_derived] impl ::std::fmt::Display for #ident { #[allow(clippy::unnecessary_to_owned)] fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { let Self { #(#destructure_self)* } = self; write!( f, #format_str, #( #captures = ::axum_extra::__private::utf8_percent_encode( &#captures.to_string(), ::axum_extra::__private::PATH_SEGMENT, ) ),* ) } } }; let rejection_assoc_type = rejection_assoc_type(&rejection); let map_err_rejection = map_err_rejection(&rejection); let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where S: Send + Sync, { type Rejection = #rejection_assoc_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, state: &S, ) -> ::std::result::Result { ::axum::extract::Path::from_request_parts(parts, state) .await .map(|path| path.0) #map_err_rejection } } }; Ok(quote! 
{ #typed_path_impl #display_impl #from_request_impl }) } fn simple_pluralize(count: usize, word: &str) -> String { if count == 1 { format!("{count} {word}") } else { format!("{count} {word}s") } } fn expand_unit_fields( ident: &syn::Ident, path: LitStr, rejection: Option, ) -> syn::Result { for segment in parse_path(&path)? { match segment { Segment::Capture(_, span) => { return Err(syn::Error::new( span, "Typed paths for unit structs cannot contain captures", )); } Segment::Static(_) => {} } } let typed_path_impl = quote_spanned! {path.span()=> #[automatically_derived] impl ::axum_extra::routing::TypedPath for #ident { const PATH: &'static str = #path; } }; let display_impl = quote_spanned! {path.span()=> #[automatically_derived] impl ::std::fmt::Display for #ident { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { write!(f, #path) } } }; let rejection_assoc_type = if let Some(rejection) = &rejection { quote! { #rejection } } else { quote! { ::axum::http::StatusCode } }; let create_rejection = if let Some(rejection) = &rejection { quote! { Err(<#rejection as ::std::default::Default>::default()) } } else { quote! { Err(::axum::http::StatusCode::NOT_FOUND) } }; let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where S: Send + Sync, { type Rejection = #rejection_assoc_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, _state: &S, ) -> ::std::result::Result { if parts.uri.path() == ::PATH { Ok(Self) } else { #create_rejection } } } }; Ok(quote! 
{ #typed_path_impl #display_impl #from_request_impl }) } fn format_str_from_path(segments: &[Segment]) -> String { segments .iter() .map(|segment| match segment { Segment::Capture(capture, _) => format!("{{{capture}}}"), Segment::Static(segment) => segment.to_owned(), }) .collect::>() .join("/") } fn captures_from_path(segments: &[Segment]) -> Vec { segments .iter() .filter_map(|segment| match segment { Segment::Capture(capture, span) => Some(format_ident!("{}", capture, span = *span)), Segment::Static(_) => None, }) .collect::>() } fn parse_path(path: &LitStr) -> syn::Result> { let value = path.value(); if value.is_empty() { return Err(syn::Error::new_spanned( path, "paths must start with a `/`. Use \"/\" for root routes", )); } else if !path.value().starts_with('/') { return Err(syn::Error::new_spanned(path, "paths must start with a `/`")); } path.value() .split('/') .map(|segment| { if let Some(capture) = segment .strip_prefix(':') .or_else(|| segment.strip_prefix('*')) { Ok(Segment::Capture(capture.to_owned(), path.span())) } else { Ok(Segment::Static(segment.to_owned())) } }) .collect() } enum Segment { Capture(String, Span), Static(String), } fn path_rejection() -> TokenStream { quote! { <::axum::extract::Path as ::axum::extract::FromRequestParts>::Rejection } } fn rejection_assoc_type(rejection: &Option) -> TokenStream { match rejection { Some(rejection) => quote! { #rejection }, None => path_rejection(), } } fn map_err_rejection(rejection: &Option) -> TokenStream { rejection .as_ref() .map(|rejection| { let path_rejection = path_rejection(); quote! 
{ .map_err(|rejection| { <#rejection as ::std::convert::From<#path_rejection>>::from(rejection) }) } }) .unwrap_or_default() } #[test] fn ui() { crate::run_ui_tests("typed_path"); } axum-axum-v0.6.20/axum-macros/src/with_position.rs000066400000000000000000000067101446276533600222230ustar00rootroot00000000000000// this is copied from itertools under the following license // // Copyright (c) 2015 // // Permission is hereby granted, free of charge, to any // person obtaining a copy of this software and associated // documentation files (the "Software"), to deal in the // Software without restriction, including without // limitation the rights to use, copy, modify, merge, // publish, distribute, sublicense, and/or sell copies of // the Software, and to permit persons to whom the Software // is furnished to do so, subject to the following // conditions: // // The above copyright notice and this permission notice // shall be included in all copies or substantial portions // of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF // ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED // TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A // PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT // SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
use std::iter::{Fuse, FusedIterator, Peekable}; pub(crate) struct WithPosition where I: Iterator, { handled_first: bool, peekable: Peekable>, } impl WithPosition where I: Iterator, { pub(crate) fn new(iter: I) -> WithPosition { WithPosition { handled_first: false, peekable: iter.fuse().peekable(), } } } impl Clone for WithPosition where I: Clone + Iterator, I::Item: Clone, { fn clone(&self) -> Self { Self { handled_first: self.handled_first, peekable: self.peekable.clone(), } } } #[derive(Copy, Clone, Debug, PartialEq)] pub(crate) enum Position { First(T), Middle(T), Last(T), Only(T), } impl Position { pub(crate) fn into_inner(self) -> T { match self { Position::First(x) | Position::Middle(x) | Position::Last(x) | Position::Only(x) => x, } } } impl Iterator for WithPosition { type Item = Position; fn next(&mut self) -> Option { match self.peekable.next() { Some(item) => { if !self.handled_first { // Haven't seen the first item yet, and there is one to give. self.handled_first = true; // Peek to see if this is also the last item, // in which case tag it as `Only`. match self.peekable.peek() { Some(_) => Some(Position::First(item)), None => Some(Position::Only(item)), } } else { // Have seen the first item, and there's something left. // Peek to see if this is the last item. match self.peekable.peek() { Some(_) => Some(Position::Middle(item)), None => Some(Position::Last(item)), } } } // Iterator is finished. 
None => None, } } fn size_hint(&self) -> (usize, Option) { self.peekable.size_hint() } } impl ExactSizeIterator for WithPosition where I: ExactSizeIterator {} impl FusedIterator for WithPosition {} axum-axum-v0.6.20/axum-macros/tests/000077500000000000000000000000001446276533600173255ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/debug_handler/000077500000000000000000000000001446276533600221105ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/000077500000000000000000000000001446276533600230235ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/.gitkeep000066400000000000000000000000001446276533600244420ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/argument_not_extractor.rs000066400000000000000000000001371446276533600301670ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] async fn handler(foo: bool) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr000066400000000000000000000025101446276533600310430ustar00rootroot00000000000000error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} | ^^^^ the trait `FromRequestParts<()>` is not implemented for `bool` | = note: Function argument is not a valid axum extractor. 
See `https://docs.rs/axum/latest/axum/extract/index.html` for details = help: the following other types implement trait `FromRequestParts`: <() as FromRequestParts> <(T1, T2) as FromRequestParts> <(T1, T2, T3) as FromRequestParts> <(T1, T2, T3, T4) as FromRequestParts> <(T1, T2, T3, T4, T5) as FromRequestParts> <(T1, T2, T3, T4, T5, T6) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> and 26 others = note: required for `bool` to implement `FromRequest<(), Body, axum_core::extract::private::ViaParts>` note: required by a bound in `__axum_macros_check_handler_0_from_request_check` --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} | ^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check` axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/duplicate_args.rs000066400000000000000000000002701446276533600263560ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler(body = BoxBody, body = BoxBody)] async fn handler() {} #[debug_handler(state = (), state = ())] async fn handler_2() {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/duplicate_args.stderr000066400000000000000000000005631446276533600272420ustar00rootroot00000000000000error: `body` specified more than once --> tests/debug_handler/fail/duplicate_args.rs:3:33 | 3 | #[debug_handler(body = BoxBody, body = BoxBody)] | ^^^^ error: `state` specified more than once --> tests/debug_handler/fail/duplicate_args.rs:6:29 | 6 | #[debug_handler(state = (), state = ())] | ^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/extract_self_mut.rs000066400000000000000000000007061446276533600267440ustar00rootroot00000000000000use axum::{ async_trait, extract::FromRequest, http::Request, }; use axum_macros::debug_handler; struct A; #[async_trait] impl FromRequest for A where B: Send + 'static, S: Send + Sync, { type Rejection = (); 
async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } impl A { #[debug_handler] async fn handler(&mut self) {} } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr000066400000000000000000000002661446276533600276240ustar00rootroot00000000000000error: Handlers must only take owned values --> tests/debug_handler/fail/extract_self_mut.rs:25:22 | 25 | async fn handler(&mut self) {} | ^^^^^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/extract_self_ref.rs000066400000000000000000000007021446276533600267070ustar00rootroot00000000000000use axum::{ async_trait, extract::FromRequest, http::Request, }; use axum_macros::debug_handler; struct A; #[async_trait] impl FromRequest for A where B: Send + 'static, S: Send + Sync, { type Rejection = (); async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } impl A { #[debug_handler] async fn handler(&self) {} } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr000066400000000000000000000002561446276533600275720ustar00rootroot00000000000000error: Handlers must only take owned values --> tests/debug_handler/fail/extract_self_ref.rs:25:22 | 25 | async fn handler(&self) {} | ^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/generics.rs000066400000000000000000000001431446276533600251660ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] async fn handler(extract: T) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/generics.stderr000066400000000000000000000002731446276533600260510ustar00rootroot00000000000000error: `#[axum_macros::debug_handler]` doesn't support generic functions --> tests/debug_handler/fail/generics.rs:4:17 | 4 | async fn handler(extract: T) {} | ^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/invalid_attrs.rs000066400000000000000000000001331446276533600262310ustar00rootroot00000000000000use 
axum_macros::debug_handler; #[debug_handler(foo)] async fn handler() {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/invalid_attrs.stderr000066400000000000000000000002141446276533600271100ustar00rootroot00000000000000error: expected `body` or `state` --> tests/debug_handler/fail/invalid_attrs.rs:3:17 | 3 | #[debug_handler(foo)] | ^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/json_not_deserialize.rs000066400000000000000000000002111446276533600275740ustar00rootroot00000000000000use axum::Json; use axum_macros::debug_handler; struct Struct {} #[debug_handler] async fn handler(foo: Json) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr000066400000000000000000000015731446276533600304670ustar00rootroot00000000000000error[E0277]: the trait bound `for<'de> Struct: serde::de::Deserialize<'de>` is not satisfied --> tests/debug_handler/fail/json_not_deserialize.rs:7:23 | 7 | async fn handler(foo: Json) {} | ^^^^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `Struct` | = help: the following other types implement trait `serde::de::Deserialize<'de>`: &'a [u8] &'a serde_json::raw::RawValue &'a std::path::Path &'a str () (T0, T1) (T0, T1, T2) (T0, T1, T2, T3) and $N others = note: required for `Struct` to implement `serde::de::DeserializeOwned` = note: required for `Json` to implement `FromRequest<(), Body>` = help: see issue #48214 = help: add `#![feature(trivial_bounds)]` to the crate attributes to enable axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/multiple_paths.rs000066400000000000000000000002171446276533600264230ustar00rootroot00000000000000use axum::extract::Path; use axum_macros::debug_handler; #[debug_handler] async fn handler(_: Path, _: Path) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/multiple_paths.stderr000066400000000000000000000011601446276533600273000ustar00rootroot00000000000000error: Multiple parameters must 
be extracted with a tuple `Path<(_, _)>` or a struct `Path`, not by applying multiple `Path<_>` extractors --> tests/debug_handler/fail/multiple_paths.rs:5:18 | 5 | async fn handler(_: Path, _: Path) {} | ^^^^^^^^^^^^^^^ error: Multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`, not by applying multiple `Path<_>` extractors --> tests/debug_handler/fail/multiple_paths.rs:5:35 | 5 | async fn handler(_: Path, _: Path) {} | ^^^^^^^^^^^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/multiple_request_consumers.rs000066400000000000000000000003731446276533600310750ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::{Json, body::Bytes, http::{Method, Uri}}; #[debug_handler] async fn one(_: Json<()>, _: String, _: Uri) {} #[debug_handler] async fn two(_: Json<()>, _: Method, _: Bytes, _: Uri, _: String) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/multiple_request_consumers.stderr000066400000000000000000000011311446276533600317450ustar00rootroot00000000000000error: Can't have two extractors that consume the request body. `Json<_>` and `String` both do that. --> tests/debug_handler/fail/multiple_request_consumers.rs:5:14 | 5 | async fn one(_: Json<()>, _: String, _: Uri) {} | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Can't have more than one extractor that consume the request body. `Json<_>`, `Bytes`, and `String` all do that. 
--> tests/debug_handler/fail/multiple_request_consumers.rs:8:14 | 8 | async fn two(_: Json<()>, _: Method, _: Bytes, _: Uri, _: String) {} | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/not_a_function.rs000066400000000000000000000001121446276533600263700ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] struct A; fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/not_a_function.stderr000066400000000000000000000001461446276533600272560ustar00rootroot00000000000000error: expected `fn` --> tests/debug_handler/fail/not_a_function.rs:4:1 | 4 | struct A; | ^^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/not_async.rs000066400000000000000000000001201446276533600253570ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] fn handler() {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/not_async.stderr000066400000000000000000000001621446276533600262440ustar00rootroot00000000000000error: Handlers must be `async fn`s --> tests/debug_handler/fail/not_async.rs:4:1 | 4 | fn handler() {} | ^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/not_send.rs000066400000000000000000000002161446276533600252010ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] async fn handler() { let rc = std::rc::Rc::new(()); async {}.await; } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/not_send.stderr000066400000000000000000000016321446276533600260630ustar00rootroot00000000000000error: future cannot be sent between threads safely --> tests/debug_handler/fail/not_send.rs:3:1 | 3 | #[debug_handler] | ^^^^^^^^^^^^^^^^ future returned by `handler` is not `Send` | = help: within `impl Future`, the trait `Send` is not implemented for `Rc<()>` note: future is not `Send` as this value is used across an await --> tests/debug_handler/fail/not_send.rs:6:13 | 5 | let rc = std::rc::Rc::new(()); | -- 
has type `Rc<()>` which is not `Send` 6 | async {}.await; | ^^^^^^ await occurs here, with `rc` maybe used later 7 | } | - `rc` is later dropped here note: required by a bound in `check` --> tests/debug_handler/fail/not_send.rs:3:1 | 3 | #[debug_handler] | ^^^^^^^^^^^^^^^^ required by this bound in `check` = note: this error originates in the attribute macro `debug_handler` (in Nightly builds, run with -Z macro-backtrace for more info) axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/too_many_extractors.rs000066400000000000000000000005211446276533600274720ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::http::Uri; #[debug_handler] async fn handler( e1: Uri, e2: Uri, e3: Uri, e4: Uri, e5: Uri, e6: Uri, e7: Uri, e8: Uri, e9: Uri, e10: Uri, e11: Uri, e12: Uri, e13: Uri, e14: Uri, e15: Uri, e16: Uri, e17: Uri, ) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/too_many_extractors.stderr000066400000000000000000000005131446276533600303520ustar00rootroot00000000000000error: Handlers cannot take more than 16 arguments. Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors --> tests/debug_handler/fail/too_many_extractors.rs:6:5 | 6 | / e1: Uri, 7 | | e2: Uri, 8 | | e3: Uri, 9 | | e4: Uri, ... 
| 21 | | e16: Uri, 22 | | e17: Uri, | |_____________^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/wrong_order.rs000066400000000000000000000002671446276533600257250ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::{Json, http::Uri}; #[debug_handler] async fn one(_: Json<()>, _: Uri) {} #[debug_handler] async fn two(_: String, _: Uri) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/wrong_order.stderr000066400000000000000000000007001446276533600265740ustar00rootroot00000000000000error: `Json<_>` consumes the request body and thus must be the last argument to the handler function --> tests/debug_handler/fail/wrong_order.rs:5:17 | 5 | async fn one(_: Json<()>, _: Uri) {} | ^^^^^^^^ error: `String` consumes the request body and thus must be the last argument to the handler function --> tests/debug_handler/fail/wrong_order.rs:8:17 | 8 | async fn two(_: String, _: Uri) {} | ^^^^^^ axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/wrong_return_type.rs000066400000000000000000000001511446276533600271620ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] async fn handler() -> bool { false } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr000066400000000000000000000015711446276533600300500ustar00rootroot00000000000000error[E0277]: the trait bound `bool: IntoResponse` is not satisfied --> tests/debug_handler/fail/wrong_return_type.rs:4:23 | 4 | async fn handler() -> bool { | ^^^^ the trait `IntoResponse` is not implemented for `bool` | = help: the following other types implement trait `IntoResponse`: &'static [u8; N] &'static [u8] &'static str () (Response<()>, R) (Response<()>, T1, R) (Response<()>, T1, T2, R) (Response<()>, T1, T2, T3, R) and $N others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/wrong_return_type.rs:4:23 | 4 | async fn handler() -> bool { | ^^^^ required 
by this bound in `__axum_macros_check_handler_into_response::{closure#0}::check` axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/000077500000000000000000000000001446276533600230565ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/associated_fn_without_self.rs000066400000000000000000000001641446276533600310230ustar00rootroot00000000000000use axum_macros::debug_handler; struct A; impl A { #[debug_handler] async fn handler() {} } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/deny_unreachable_code.rs000066400000000000000000000002151446276533600277040ustar00rootroot00000000000000#![deny(unreachable_code)] use axum::extract::Path; #[axum_macros::debug_handler] async fn handler(Path(_): Path) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/different_request_body_type.rs000066400000000000000000000004511446276533600312200ustar00rootroot00000000000000use axum::{body::BoxBody, http::Request}; use axum_macros::debug_handler; #[debug_handler(body = BoxBody)] async fn handler(_: Request) {} #[debug_handler(body = axum::body::BoxBody,)] async fn handler_with_trailing_comma_and_type_path(_: Request) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/impl_future.rs000066400000000000000000000002231446276533600257540ustar00rootroot00000000000000use axum_macros::debug_handler; use std::future::Future; #[debug_handler] fn handler() -> impl Future { async {} } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/impl_into_response.rs000066400000000000000000000002301446276533600273270ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::response::IntoResponse; #[debug_handler] async fn handler() -> impl IntoResponse { "hi!" 
} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/infer_state.rs000066400000000000000000000010561446276533600257310ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::extract::State; #[debug_handler] async fn handler(_: State) {} #[debug_handler] async fn handler_2(_: axum::extract::State) {} #[debug_handler] async fn handler_3( _: axum::extract::State, _: axum::extract::State, ) {} #[debug_handler] async fn handler_4( _: State, _: State, ) {} #[debug_handler] async fn handler_5( _: axum::extract::State, _: State, ) {} #[derive(Clone)] struct AppState; fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/multiple_extractors.rs000066400000000000000000000002341446276533600275340ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::http::{Method, Uri}; #[debug_handler] async fn handler(_one: Method, _two: Uri, _three: String) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/mut_extractor.rs000066400000000000000000000002121446276533600263170ustar00rootroot00000000000000use axum_macros::debug_handler; #[debug_handler] async fn handler(mut foo: String) -> String { foo += "bar"; foo } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/ready.rs000066400000000000000000000002151446276533600245260ustar00rootroot00000000000000use axum_macros::debug_handler; use std::future::{Ready, ready}; #[debug_handler] fn handler() -> Ready<()> { ready(()) } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/request_last.rs000066400000000000000000000002671446276533600261440ustar00rootroot00000000000000use axum::{body::Body, extract::Extension, http::Request}; use axum_macros::debug_handler; #[debug_handler] async fn handler(_: Extension, _: Request) {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs000066400000000000000000000054601446276533600307370ustar00rootroot00000000000000use axum::{async_trait, 
extract::FromRequestParts, http::request::Parts, response::IntoResponse}; use axum_macros::debug_handler; fn main() {} #[debug_handler] fn concrete_future() -> std::future::Ready> { std::future::ready(Ok(())) } #[debug_handler] fn impl_future() -> impl std::future::Future> { std::future::ready(Ok(())) } // === no args === #[debug_handler] async fn handler_no_arg_one() -> Result { Ok(()) } #[debug_handler] async fn handler_no_arg_two() -> Result<(), impl IntoResponse> { Err(()) } #[debug_handler] async fn handler_no_arg_three() -> Result { Ok::<_, ()>(()) } #[debug_handler] async fn handler_no_arg_four() -> Result { Err::<(), _>(()) } // === args === #[debug_handler] async fn handler_one(foo: String) -> Result { dbg!(foo); Ok(()) } #[debug_handler] async fn handler_two(foo: String) -> Result<(), impl IntoResponse> { dbg!(foo); Err(()) } #[debug_handler] async fn handler_three(foo: String) -> Result { dbg!(foo); Ok::<_, ()>(()) } #[debug_handler] async fn handler_four(foo: String) -> Result { dbg!(foo); Err::<(), _>(()) } // === no args with receiver === struct A; impl A { #[debug_handler] async fn handler_no_arg_one(self) -> Result { Ok(()) } #[debug_handler] async fn handler_no_arg_two(self) -> Result<(), impl IntoResponse> { Err(()) } #[debug_handler] async fn handler_no_arg_three(self) -> Result { Ok::<_, ()>(()) } #[debug_handler] async fn handler_no_arg_four(self) -> Result { Err::<(), _>(()) } } // === args with receiver === impl A { #[debug_handler] async fn handler_one(self, foo: String) -> Result { dbg!(foo); Ok(()) } #[debug_handler] async fn handler_two(self, foo: String) -> Result<(), impl IntoResponse> { dbg!(foo); Err(()) } #[debug_handler] async fn handler_three(self, foo: String) -> Result { dbg!(foo); Ok::<_, ()>(()) } #[debug_handler] async fn handler_four(self, foo: String) -> Result { dbg!(foo); Err::<(), _>(()) } } #[async_trait] impl FromRequestParts for A where S: Send + Sync, { type Rejection = (); async fn from_request_parts(_parts: &mut 
Parts, _state: &S) -> Result { unimplemented!() } } axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/returns_self.rs000066400000000000000000000004251446276533600261400ustar00rootroot00000000000000use axum::response::{IntoResponse, Response}; use axum_macros::debug_handler; struct A; impl A { #[debug_handler] async fn handler() -> Self { A } } impl IntoResponse for A { fn into_response(self) -> Response { todo!() } } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/self_receiver.rs000066400000000000000000000014251446276533600262430ustar00rootroot00000000000000use axum::{ async_trait, extract::FromRequest, http::Request, }; use axum_macros::debug_handler; struct A; #[async_trait] impl FromRequest for A where B: Send + 'static, S: Send + Sync, { type Rejection = (); async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } #[async_trait] impl FromRequest for Box where B: Send + 'static, S: Send + Sync, { type Rejection = (); async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } impl A { #[debug_handler] async fn handler(self) {} #[debug_handler] async fn handler_with_qualified_self(self: Box) {} } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/set_state.rs000066400000000000000000000010131446276533600254120ustar00rootroot00000000000000use axum_macros::debug_handler; use axum::extract::{FromRef, FromRequest}; use axum::async_trait; use axum::http::Request; #[debug_handler(state = AppState)] async fn handler(_: A) {} #[derive(Clone)] struct AppState; struct A; #[async_trait] impl FromRequest for A where B: Send + 'static, S: Send + Sync, AppState: FromRef, { type Rejection = (); async fn from_request(_req: Request, _state: &S) -> Result { unimplemented!() } } fn main() {} axum-axum-v0.6.20/axum-macros/tests/debug_handler/pass/state_and_body.rs000066400000000000000000000003741446276533600264070ustar00rootroot00000000000000use axum_macros::debug_handler; use 
axum::{body::BoxBody, extract::State, http::Request}; #[debug_handler(state = AppState, body = BoxBody)] async fn handler(_: State, _: Request) {} #[derive(Clone)] struct AppState; fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_ref/000077500000000000000000000000001446276533600211245ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/from_ref/fail/000077500000000000000000000000001446276533600220375ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/from_ref/fail/generics.rs000066400000000000000000000001501446276533600242000ustar00rootroot00000000000000use axum::extract::FromRef; #[derive(Clone, FromRef)] struct AppState { foo: T, } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_ref/fail/generics.stderr000066400000000000000000000002231446276533600250600ustar00rootroot00000000000000error: `#[derive(FromRef)]` doesn't support generics --> tests/from_ref/fail/generics.rs:4:16 | 4 | struct AppState { | ^^^ axum-axum-v0.6.20/axum-macros/tests/from_ref/pass/000077500000000000000000000000001446276533600220725ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/from_ref/pass/basic.rs000066400000000000000000000007311446276533600235220ustar00rootroot00000000000000use axum::{Router, routing::get, extract::{State, FromRef}}; // This will implement `FromRef` for each field in the struct. 
#[derive(Clone, FromRef)] struct AppState { auth_token: String, } // So those types can be extracted via `State` async fn handler(_: State) {} fn main() { let state = AppState { auth_token: Default::default(), }; let _: axum::Router = Router::new() .route("/", get(handler)) .with_state(state); } axum-axum-v0.6.20/axum-macros/tests/from_ref/pass/reference-types.rs000066400000000000000000000002021446276533600255320ustar00rootroot00000000000000#![deny(noop_method_call)] use axum_macros::FromRef; #[derive(FromRef)] struct State { inner: &'static str, } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_ref/pass/skip.rs000066400000000000000000000002361446276533600234070ustar00rootroot00000000000000use axum_macros::FromRef; #[derive(Clone, FromRef)] struct AppState { auth_token: String, #[from_ref(skip)] also_string: String, } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/000077500000000000000000000000001446276533600220405ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/from_request/fail/000077500000000000000000000000001446276533600227535ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/from_request/fail/double_via_attr.rs000066400000000000000000000002711446276533600264640ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State); #[derive(Clone)] struct State; fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/double_via_attr.stderr000066400000000000000000000003661446276533600273500ustar00rootroot00000000000000error: `via` specified more than once --> tests/from_request/fail/double_via_attr.rs:4:55 | 4 | struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State); | ^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs000066400000000000000000000003231446276533600324530ustar00rootroot00000000000000use axum_macros::FromRequest; 
#[derive(FromRequest, Clone)] #[from_request(via(axum::Extension))] enum Extractor { Foo { #[from_request(via(axum::Extension))] foo: (), } } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.stderr000066400000000000000000000003371446276533600333370ustar00rootroot00000000000000error: `#[from_request(via(...))]` cannot be used inside variants --> tests/from_request/fail/enum_from_request_ident_in_variant.rs:7:24 | 7 | #[from_request(via(axum::Extension))] | ^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/enum_from_request_on_variant.rs000066400000000000000000000002671446276533600313050ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest, Clone)] #[from_request(via(axum::Extension))] enum Extractor { #[from_request(via(axum::Extension))] Foo, } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/enum_from_request_on_variant.stderr000066400000000000000000000003151446276533600321560ustar00rootroot00000000000000error: `#[from_request(via(...))]` cannot be used on variants --> tests/from_request/fail/enum_from_request_on_variant.rs:6:20 | 6 | #[from_request(via(axum::Extension))] | ^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/enum_no_via.rs000066400000000000000000000001351446276533600256170ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest, Clone)] enum Extractor {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/enum_no_via.stderr000066400000000000000000000004401446276533600264750ustar00rootroot00000000000000error: missing `#[from_request(via(...))]` --> tests/from_request/fail/enum_no_via.rs:3:10 | 3 | #[derive(FromRequest, Clone)] | ^^^^^^^^^^^ | = note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace for more info) 
axum-axum-v0.6.20/axum-macros/tests/from_request/fail/generic.rs000066400000000000000000000001441446276533600247340ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor(Option); fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/generic.stderr000066400000000000000000000003551446276533600256170ustar00rootroot00000000000000error: #[derive(FromRequest)] only supports generics on tuple structs that have exactly one field of the generic type --> tests/from_request/fail/generic.rs:4:21 | 4 | struct Extractor(Option); | ^^^^^^^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/generic_without_via.rs000066400000000000000000000003511446276533600273560ustar00rootroot00000000000000use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { Router::<(), Body>::new().route("/", get(foo)); } axum-axum-v0.6.20/axum-macros/tests/from_request/fail/generic_without_via.stderr000066400000000000000000000025711446276533600302430ustar00rootroot00000000000000error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)] --> tests/from_request/fail/generic_without_via.rs:5:18 | 5 | struct Extractor(T); | ^ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied --> tests/from_request/fail/generic_without_via.rs:10:46 | 10 | Router::<(), Body>::new().route("/", get(foo)); | --- ^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(Extractor<()>) -> impl Future {foo}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: as Handler> as Handler<(), S, B>> note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | | top_level_handler_fn!(get, GET); | 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get` = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) axum-axum-v0.6.20/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs000066400000000000000000000004111446276533600314150ustar00rootroot00000000000000use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] #[from_request(rejection(Foo))] struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { Router::<(), Body>::new().route("/", get(foo)); } axum-axum-v0.6.20/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr000066400000000000000000000026151446276533600323040ustar00rootroot00000000000000error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)] --> tests/from_request/fail/generic_without_via_rejection.rs:6:18 | 6 | struct Extractor(T); | ^ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied --> tests/from_request/fail/generic_without_via_rejection.rs:11:46 | 11 | Router::<(), Body>::new().route("/", get(foo)); | --- ^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(Extractor<()>) -> impl Future {foo}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: as Handler> as Handler<(), S, B>> note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | | top_level_handler_fn!(get, GET); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get` = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) 
axum-axum-v0.6.20/axum-macros/tests/from_request/fail/not_enum_or_struct.rs000066400000000000000000000001271446276533600272510ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] union Extractor {} fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/not_enum_or_struct.stderr000066400000000000000000000004341446276533600301310ustar00rootroot00000000000000error: expected `struct` or `enum` --> tests/from_request/fail/not_enum_or_struct.rs:4:1 | 4 | union Extractor {} | ^^^^^^^^^^^^^^^^^^ error: unions cannot have zero fields --> tests/from_request/fail/not_enum_or_struct.rs:4:1 | 4 | union Extractor {} | ^^^^^^^^^^^^^^^^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.rs000066400000000000000000000012631446276533600333460ustar00rootroot00000000000000use axum::{ extract::rejection::ExtensionRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyExtractor) {} async fn handler_result(_: Result) {} #[derive(FromRequest, Clone)] #[from_request(rejection(MyRejection))] enum MyExtractor {} struct MyRejection {} impl From for MyRejection { fn from(_: ExtensionRejection) -> Self { todo!() } } impl IntoResponse for MyRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr000066400000000000000000000054571446276533600342360ustar00rootroot00000000000000error: cannot use `rejection` without `via` --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:18:16 | 18 | #[from_request(rejection(MyRejection))] | ^^^^^^^^^ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future {handler}: Handler<_, _, _>` is not satisfied --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50 | 10 | let _: Router = 
Router::new().route("/", get(handler).post(handler_result)); | --- ^^^^^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(MyExtractor) -> impl Future {handler}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: as Handler> as Handler<(), S, B>> note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | | top_level_handler_fn!(get, GET); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get` = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `fn(Result) -> impl Future {handler_result}: Handler<_, _, _>` is not satisfied --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:64 | 10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); | ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(Result) -> impl Future {handler_result}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: as Handler> as Handler<(), S, B>> note: required by a bound in `MethodRouter::::post` --> $WORKSPACE/axum/src/routing/method_routing.rs | | chained_handler_fn!(post, POST); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::::post` = note: this error originates in the macro `chained_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) axum-axum-v0.6.20/axum-macros/tests/from_request/fail/parts_extracting_body.rs000066400000000000000000000003541446276533600277210ustar00rootroot00000000000000use axum::{extract::FromRequestParts, response::Response}; #[derive(FromRequestParts)] struct Extractor { body: String, } 
fn assert_from_request() where Extractor: FromRequestParts<(), Rejection = Response>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/parts_extracting_body.stderr000066400000000000000000000016471446276533600306060ustar00rootroot00000000000000error[E0277]: the trait bound `String: FromRequestParts` is not satisfied --> tests/from_request/fail/parts_extracting_body.rs:5:11 | 5 | body: String, | ^^^^^^ the trait `FromRequestParts` is not implemented for `String` | = note: Function argument is not a valid axum extractor. See `https://docs.rs/axum/latest/axum/extract/index.html` for details = help: the following other types implement trait `FromRequestParts`: <() as FromRequestParts> <(T1, T2) as FromRequestParts> <(T1, T2, T3) as FromRequestParts> <(T1, T2, T3, T4) as FromRequestParts> <(T1, T2, T3, T4, T5) as FromRequestParts> <(T1, T2, T3, T4, T5, T6) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> and 27 others axum-axum-v0.6.20/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs000066400000000000000000000006261446276533600330150ustar00rootroot00000000000000use axum_macros::FromRequest; use axum::extract::State; #[derive(FromRequest)] struct Extractor { inner_state: State, other_state: State, } #[derive(Clone)] struct AppState {} #[derive(Clone)] struct OtherState {} fn assert_from_request() where Extractor: axum::extract::FromRequest, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr000066400000000000000000000005421446276533600336710ustar00rootroot00000000000000error: can't infer state type, please add `#[from_request(state = MyStateType)]` attribute --> tests/from_request/fail/state_infer_multiple_different_types.rs:4:10 | 4 | #[derive(FromRequest)] | ^^^^^^^^^^^ | = note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace 
for more info) axum-axum-v0.6.20/axum-macros/tests/from_request/fail/unknown_attr_container.rs000066400000000000000000000001531446276533600301130ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] #[from_request(foo)] struct Extractor; fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/unknown_attr_container.stderr000066400000000000000000000002441446276533600307730ustar00rootroot00000000000000error: expected one of: `via`, `rejection`, `state` --> tests/from_request/fail/unknown_attr_container.rs:4:16 | 4 | #[from_request(foo)] | ^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/unknown_attr_field.rs000066400000000000000000000001631446276533600272150ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor(#[from_request(foo)] String); fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/unknown_attr_field.stderr000066400000000000000000000002551446276533600300760ustar00rootroot00000000000000error: expected `via` --> tests/from_request/fail/unknown_attr_field.rs:4:33 | 4 | struct Extractor(#[from_request(foo)] String); | ^^^ axum-axum-v0.6.20/axum-macros/tests/from_request/fail/via_on_container_and_field.rs000066400000000000000000000003111446276533600306160ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] #[from_request(via(axum::Extension))] struct Extractor(#[from_request(via(axum::Extension))] State); #[derive(Clone)] struct State; fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr000066400000000000000000000004421446276533600315020ustar00rootroot00000000000000error: `#[from_request(via(...))]` on a field cannot be used together with `#[from_request(...)]` on the container --> tests/from_request/fail/via_on_container_and_field.rs:5:33 | 5 | struct Extractor(#[from_request(via(axum::Extension))] State); | ^^^ 
axum-axum-v0.6.20/axum-macros/tests/from_request/pass/000077500000000000000000000000001446276533600230065ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/from_request/pass/container.rs000066400000000000000000000005441446276533600253410ustar00rootroot00000000000000use axum::{ body::Body, extract::{FromRequest, Json}, response::Response, }; use serde::Deserialize; #[derive(Deserialize, FromRequest)] #[from_request(via(Json))] struct Extractor { one: i32, two: String, three: bool, } fn assert_from_request() where Extractor: FromRequest<(), Body, Rejection = Response>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/container_parts.rs000066400000000000000000000005111446276533600265440ustar00rootroot00000000000000use axum::{ extract::{FromRequestParts, Extension}, response::Response, }; #[derive(Clone, FromRequestParts)] #[from_request(via(Extension))] struct Extractor { one: i32, two: String, three: bool, } fn assert_from_request() where Extractor: FromRequestParts<(), Rejection = Response>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/empty_named.rs000066400000000000000000000003431446276533600256560ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor {} fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/empty_named_parts.rs000066400000000000000000000003401446276533600270640ustar00rootroot00000000000000use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor {} fn assert_from_request() where Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/empty_tuple.rs000066400000000000000000000003431446276533600257230ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] 
struct Extractor(); fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/empty_tuple_parts.rs000066400000000000000000000003401446276533600271310ustar00rootroot00000000000000use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor(); fn assert_from_request() where Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/enum_via.rs000066400000000000000000000004121446276533600251540ustar00rootroot00000000000000use axum::{body::Body, routing::get, Extension, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] #[from_request(via(Extension))] enum Extractor {} async fn foo(_: Extractor) {} fn main() { Router::<(), Body>::new().route("/", get(foo)); } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/enum_via_parts.rs000066400000000000000000000004241446276533600263700ustar00rootroot00000000000000use axum::{body::Body, routing::get, Extension, Router}; use axum_macros::FromRequestParts; #[derive(FromRequestParts, Clone)] #[from_request(via(Extension))] enum Extractor {} async fn foo(_: Extractor) {} fn main() { Router::<(), Body>::new().route("/", get(foo)); } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/named.rs000066400000000000000000000010671446276533600244440ustar00rootroot00000000000000use axum::{ body::Body, extract::{FromRequest, TypedHeader, rejection::TypedHeaderRejection}, response::Response, headers::{self, UserAgent}, }; #[derive(FromRequest)] struct Extractor { uri: axum::http::Uri, user_agent: TypedHeader, content_type: TypedHeader, etag: Option>, host: Result, TypedHeaderRejection>, body: String, } fn assert_from_request() where Extractor: FromRequest<(), Body, Rejection = Response>, { } fn main() {} 
axum-axum-v0.6.20/axum-macros/tests/from_request/pass/named_parts.rs000066400000000000000000000010361446276533600256510ustar00rootroot00000000000000use axum::{ extract::{rejection::TypedHeaderRejection, FromRequestParts, TypedHeader}, headers::{self, UserAgent}, response::Response, }; #[derive(FromRequestParts)] struct Extractor { uri: axum::http::Uri, user_agent: TypedHeader, content_type: TypedHeader, etag: Option>, host: Result, TypedHeaderRejection>, } fn assert_from_request() where Extractor: FromRequestParts<(), Rejection = Response>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/named_via.rs000066400000000000000000000013471446276533600253040ustar00rootroot00000000000000use axum::{ body::Body, response::Response, extract::{ rejection::TypedHeaderRejection, Extension, FromRequest, TypedHeader, }, headers::{self, UserAgent}, }; #[derive(FromRequest)] struct Extractor { #[from_request(via(Extension))] state: State, #[from_request(via(TypedHeader))] user_agent: UserAgent, #[from_request(via(TypedHeader))] content_type: headers::ContentType, #[from_request(via(TypedHeader))] etag: Option, #[from_request(via(TypedHeader))] host: Result, } fn assert_from_request() where Extractor: FromRequest<(), Body, Rejection = Response>, { } #[derive(Clone)] struct State; fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/named_via_parts.rs000066400000000000000000000013401446276533600265060ustar00rootroot00000000000000use axum::{ response::Response, extract::{ rejection::TypedHeaderRejection, Extension, FromRequestParts, TypedHeader, }, headers::{self, UserAgent}, }; #[derive(FromRequestParts)] struct Extractor { #[from_request(via(Extension))] state: State, #[from_request(via(TypedHeader))] user_agent: UserAgent, #[from_request(via(TypedHeader))] content_type: headers::ContentType, #[from_request(via(TypedHeader))] etag: Option, #[from_request(via(TypedHeader))] host: Result, } fn assert_from_request() where Extractor: 
FromRequestParts<(), Rejection = Response>, { } #[derive(Clone)] struct State; fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection.rs000066400000000000000000000024531446276533600272410ustar00rootroot00000000000000use axum::{ async_trait, extract::{rejection::ExtensionRejection, FromRequest}, http::{StatusCode, Request}, response::{IntoResponse, Response}, routing::get, Extension, Router, }; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyExtractor) {} async fn handler_result(_: Result) {} #[derive(FromRequest)] #[from_request(rejection(MyRejection))] struct MyExtractor { one: Extension, #[from_request(via(Extension))] two: String, three: OtherExtractor, } struct OtherExtractor; #[async_trait] impl FromRequest for OtherExtractor where B: Send + 'static, S: Send + Sync, { // this rejection doesn't implement `Display` and `Error` type Rejection = (StatusCode, String); async fn from_request(_req: Request, _state: &S) -> Result { todo!() } } struct MyRejection {} impl From for MyRejection { fn from(_: ExtensionRejection) -> Self { todo!() } } impl From<(StatusCode, String)> for MyRejection { fn from(_: (StatusCode, String)) -> Self { todo!() } } impl IntoResponse for MyRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs000066400000000000000000000014721446276533600316070ustar00rootroot00000000000000use axum::{ extract::rejection::JsonRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequest; use std::collections::HashMap; use serde::Deserialize; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyJson) {} async fn handler_result(_: Result) {} #[derive(FromRequest, Deserialize)] #[from_request( via(axum::extract::Json), rejection(MyJsonRejection), )] #[serde(transparent)] 
struct MyJson(HashMap); struct MyJsonRejection {} impl From for MyJsonRejection { fn from(_: JsonRejection) -> Self { todo!() } } impl IntoResponse for MyJsonRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs000066400000000000000000000015201446276533600330120ustar00rootroot00000000000000use axum::{ extract::rejection::QueryRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequestParts; use std::collections::HashMap; use serde::Deserialize; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyQuery) {} async fn handler_result(_: Result) {} #[derive(FromRequestParts, Deserialize)] #[from_request( via(axum::extract::Query), rejection(MyQueryRejection), )] #[serde(transparent)] struct MyQuery(HashMap); struct MyQueryRejection {} impl From for MyQueryRejection { fn from(_: QueryRejection) -> Self { todo!() } } impl IntoResponse for MyQueryRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_parts.rs000066400000000000000000000024541446276533600304530ustar00rootroot00000000000000use axum::{ async_trait, extract::{rejection::ExtensionRejection, FromRequestParts}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, Extension, Router, }; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyExtractor) {} async fn handler_result(_: Result) {} #[derive(FromRequestParts)] #[from_request(rejection(MyRejection))] struct MyExtractor { one: Extension, #[from_request(via(Extension))] two: String, three: OtherExtractor, } struct OtherExtractor; #[async_trait] impl FromRequestParts for OtherExtractor where S: Send + Sync, { // this rejection doesn't implement `Display` and `Error` type Rejection = (StatusCode, 
String); async fn from_request_parts(_parts: &mut Parts, _state: &S) -> Result { todo!() } } struct MyRejection {} impl From for MyRejection { fn from(_: ExtensionRejection) -> Self { todo!() } } impl From<(StatusCode, String)> for MyRejection { fn from(_: (StatusCode, String)) -> Self { todo!() } } impl IntoResponse for MyRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_with_via_on_enum.rs000066400000000000000000000013111446276533600326430ustar00rootroot00000000000000use axum::{ extract::rejection::ExtensionRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyExtractor) {} async fn handler_result(_: Result) {} #[derive(FromRequest, Clone)] #[from_request(via(axum::Extension), rejection(MyRejection))] enum MyExtractor {} struct MyRejection {} impl From for MyRejection { fn from(_: ExtensionRejection) -> Self { todo!() } } impl IntoResponse for MyRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_with_via_on_enum_parts.rs000066400000000000000000000013231446276533600340570ustar00rootroot00000000000000use axum::{ extract::rejection::ExtensionRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequestParts; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } async fn handler(_: MyExtractor) {} async fn handler_result(_: Result) {} #[derive(FromRequestParts, Clone)] #[from_request(via(axum::Extension), rejection(MyRejection))] enum MyExtractor {} struct MyRejection {} impl From for MyRejection { fn from(_: ExtensionRejection) -> Self { todo!() } } impl IntoResponse for MyRejection { fn into_response(self) -> Response { todo!() } } 
axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct.rs000066400000000000000000000014301446276533600332250ustar00rootroot00000000000000use axum::{ extract::rejection::JsonRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequest; use serde::Deserialize; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } #[derive(Deserialize)] struct Payload {} async fn handler(_: MyJson) {} async fn handler_result(_: Result, MyJsonRejection>) {} #[derive(FromRequest)] #[from_request( via(axum::Json), rejection(MyJsonRejection), )] struct MyJson(T); struct MyJsonRejection {} impl From for MyJsonRejection { fn from(_: JsonRejection) -> Self { todo!() } } impl IntoResponse for MyJsonRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs000066400000000000000000000014671446276533600344500ustar00rootroot00000000000000use axum::{ extract::rejection::QueryRejection, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequestParts; use serde::Deserialize; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); } #[derive(Deserialize)] struct Payload {} async fn handler(_: MyQuery) {} async fn handler_result(_: Result, MyQueryRejection>) {} #[derive(FromRequestParts)] #[from_request( via(axum::extract::Query), rejection(MyQueryRejection), )] struct MyQuery(T); struct MyQueryRejection {} impl From for MyQueryRejection { fn from(_: QueryRejection) -> Self { todo!() } } impl IntoResponse for MyQueryRejection { fn into_response(self) -> Response { todo!() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_cookie.rs000066400000000000000000000010231446276533600260210ustar00rootroot00000000000000use axum_macros::FromRequest; use axum::extract::FromRef; use 
axum_extra::extract::cookie::{PrivateCookieJar, Key}; #[derive(FromRequest)] #[from_request(state(AppState))] struct Extractor { cookies: PrivateCookieJar, } struct AppState { key: Key, } impl FromRef for Key { fn from_ref(input: &AppState) -> Self { input.key.clone() } } fn assert_from_request() where Extractor: axum::extract::FromRequest, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_enum_via.rs000066400000000000000000000012441446276533600263600ustar00rootroot00000000000000use axum::{ extract::{FromRef, State}, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() .route("/a", get(|_: AppState| async {})) .route("/b", get(|_: InnerState| async {})) .with_state(AppState::default()); } #[derive(Clone, FromRequest)] #[from_request(via(State))] enum AppState { One, } impl Default for AppState { fn default() -> AppState { Self::One } } #[derive(FromRequest)] #[from_request(via(State), state(AppState))] enum InnerState {} impl FromRef for InnerState { fn from_ref(_: &AppState) -> Self { todo!("๐Ÿคท") } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_enum_via_parts.rs000066400000000000000000000013641446276533600275740ustar00rootroot00000000000000use axum::{ extract::{FromRef, State}, routing::get, Router, }; use axum_macros::FromRequestParts; fn main() { let _: axum::Router = Router::new() .route("/a", get(|_: AppState| async {})) .route("/b", get(|_: InnerState| async {})) .route("/c", get(|_: AppState, _: InnerState| async {})) .with_state(AppState::default()); } #[derive(Clone, FromRequestParts)] #[from_request(via(State))] enum AppState { One, } impl Default for AppState { fn default() -> AppState { Self::One } } #[derive(FromRequestParts)] #[from_request(via(State), state(AppState))] enum InnerState {} impl FromRef for InnerState { fn from_ref(_: &AppState) -> Self { todo!("๐Ÿคท") } } 
axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_explicit.rs000066400000000000000000000014601446276533600263760ustar00rootroot00000000000000use axum_macros::FromRequest; use axum::{ extract::{FromRef, State}, Router, routing::get, }; fn main() { let _: axum::Router = Router::new() .route("/b", get(|_: Extractor| async {})) .with_state(AppState::default()); } #[derive(FromRequest)] #[from_request(state(AppState))] struct Extractor { app_state: State, one: State, two: State, other_extractor: String, } #[derive(Clone, Default)] struct AppState { one: One, two: Two, } #[derive(Clone, Default)] struct One {} impl FromRef for One { fn from_ref(input: &AppState) -> Self { input.one.clone() } } #[derive(Clone, Default)] struct Two {} impl FromRef for Two { fn from_ref(input: &AppState) -> Self { input.two.clone() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_explicit_parts.rs000066400000000000000000000012721446276533600276100ustar00rootroot00000000000000use axum_macros::FromRequestParts; use axum::{ extract::{FromRef, State, Query}, Router, routing::get, }; use std::collections::HashMap; fn main() { let _: axum::Router = Router::new() .route("/b", get(|_: Extractor| async {})) .with_state(AppState::default()); } #[derive(FromRequestParts)] #[from_request(state(AppState))] struct Extractor { inner_state: State, other: Query>, } #[derive(Default, Clone)] struct AppState { inner: InnerState, } #[derive(Clone, Default)] struct InnerState {} impl FromRef for InnerState { fn from_ref(input: &AppState) -> Self { input.inner.clone() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_field_explicit.rs000066400000000000000000000012461446276533600275430ustar00rootroot00000000000000use axum::{ extract::{State, FromRef}, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() .route("/", get(|_: Extractor| async {})) .with_state(AppState::default()); } #[derive(FromRequest)] 
#[from_request(state(AppState))] struct Extractor { #[from_request(via(State))] state: AppState, #[from_request(via(State))] inner: InnerState, } #[derive(Clone, Default)] struct AppState { inner: InnerState, } #[derive(Clone, Default)] struct InnerState {} impl FromRef for InnerState { fn from_ref(input: &AppState) -> Self { input.inner.clone() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_field_infer.rs000066400000000000000000000006021446276533600270200ustar00rootroot00000000000000use axum::{ extract::State, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() .route("/", get(|_: Extractor| async {})) .with_state(AppState::default()); } #[derive(FromRequest)] struct Extractor { #[from_request(via(State))] state: AppState, } #[derive(Clone, Default)] struct AppState {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_infer.rs000066400000000000000000000005131446276533600256560ustar00rootroot00000000000000use axum_macros::FromRequest; use axum::extract::State; #[derive(FromRequest)] struct Extractor { inner_state: State, } #[derive(Clone)] struct AppState {} fn assert_from_request() where Extractor: axum::extract::FromRequest, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_infer_multiple.rs000066400000000000000000000005621446276533600275750ustar00rootroot00000000000000use axum_macros::FromRequest; use axum::extract::State; #[derive(FromRequest)] struct Extractor { inner_state: State, also_inner_state: State, } #[derive(Clone)] struct AppState {} fn assert_from_request() where Extractor: axum::extract::FromRequest, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_infer_parts.rs000066400000000000000000000005101446276533600270640ustar00rootroot00000000000000use axum_macros::FromRequestParts; use axum::extract::State; #[derive(FromRequestParts)] struct Extractor { inner_state: State, } #[derive(Clone)] struct AppState {} fn 
assert_from_request() where Extractor: axum::extract::FromRequestParts, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_via.rs000066400000000000000000000012421446276533600253320ustar00rootroot00000000000000use axum::{ extract::{FromRef, State}, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() .route("/b", get(|_: (), _: AppState| async {})) .route("/c", get(|_: (), _: InnerState| async {})) .with_state(AppState::default()); } #[derive(Clone, Default, FromRequest)] #[from_request(via(State), state(AppState))] struct AppState { inner: InnerState, } #[derive(Clone, Default, FromRequest)] #[from_request(via(State), state(AppState))] struct InnerState {} impl FromRef for InnerState { fn from_ref(input: &AppState) -> Self { input.inner.clone() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_via_infer.rs000066400000000000000000000006601446276533600265200ustar00rootroot00000000000000use axum::{ extract::State, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() .route("/b", get(|_: AppState| async {})) .with_state(AppState::default()); } // if we're extract "via" `State` and not specifying state // assume `AppState` is the state #[derive(Clone, Default, FromRequest)] #[from_request(via(State))] struct AppState {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_via_parts.rs000066400000000000000000000013641446276533600265500ustar00rootroot00000000000000use axum::{ extract::{FromRef, State}, routing::get, Router, }; use axum_macros::FromRequestParts; fn main() { let _: axum::Router = Router::new() .route("/a", get(|_: AppState, _: InnerState, _: String| async {})) .route("/b", get(|_: AppState, _: String| async {})) .route("/c", get(|_: InnerState, _: String| async {})) .with_state(AppState::default()); } #[derive(Clone, Default, FromRequestParts)] #[from_request(via(State))] struct AppState { inner: InnerState, } 
#[derive(Clone, Default, FromRequestParts)] #[from_request(via(State), state(AppState))] struct InnerState {} impl FromRef for InnerState { fn from_ref(input: &AppState) -> Self { input.inner.clone() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/state_with_rejection.rs000066400000000000000000000013351446276533600275730ustar00rootroot00000000000000use std::convert::Infallible; use axum::{ extract::State, response::{IntoResponse, Response}, routing::get, Router, }; use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() .route("/a", get(|_: Extractor| async {})) .with_state(AppState::default()); } #[derive(Clone, Default, FromRequest)] #[from_request(rejection(MyRejection))] struct Extractor { state: State, } #[derive(Clone, Default)] struct AppState {} struct MyRejection {} impl From for MyRejection { fn from(err: Infallible) -> Self { match err {} } } impl IntoResponse for MyRejection { fn into_response(self) -> Response { ().into_response() } } axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple.rs000066400000000000000000000003321446276533600245030ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor(axum::http::HeaderMap, String); fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_parts.rs000066400000000000000000000003431446276533600257160ustar00rootroot00000000000000use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor(axum::http::HeaderMap, axum::http::Method); fn assert_from_request() where Extractor: axum::extract::FromRequestParts<()>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs000066400000000000000000000005201446276533600277430ustar00rootroot00000000000000use axum::extract::Query; use axum_macros::FromRequest; use serde::Deserialize; #[derive(FromRequest)] struct Extractor( 
Query, axum::extract::Json, ); #[derive(Deserialize)] struct Payload {} fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs000066400000000000000000000005151446276533600311600ustar00rootroot00000000000000use axum::extract::Query; use axum_macros::FromRequestParts; use serde::Deserialize; #[derive(FromRequestParts)] struct Extractor( Query, axum::extract::Path, ); #[derive(Deserialize)] struct Payload {} fn assert_from_request() where Extractor: axum::extract::FromRequestParts<()>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs000066400000000000000000000006561446276533600306140ustar00rootroot00000000000000use axum::extract::Query; use axum::response::Response; use axum_macros::FromRequest; use serde::Deserialize; #[derive(FromRequest)] struct Extractor( #[from_request(via(Query))] Payload, #[from_request(via(axum::extract::Json))] Payload, ); #[derive(Deserialize)] struct Payload {} fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = Response>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_same_type_twice_via_parts.rs000066400000000000000000000006531446276533600320220ustar00rootroot00000000000000use axum::extract::Query; use axum::response::Response; use axum_macros::FromRequestParts; use serde::Deserialize; #[derive(FromRequestParts)] struct Extractor( #[from_request(via(Query))] Payload, #[from_request(via(axum::extract::Path))] Payload, ); #[derive(Deserialize)] struct Payload {} fn assert_from_request() where Extractor: axum::extract::FromRequestParts<(), Rejection = Response>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_via.rs000066400000000000000000000004271446276533600253470ustar00rootroot00000000000000use axum::Extension; use axum_macros::FromRequest; 
#[derive(FromRequest)] struct Extractor(#[from_request(via(Extension))] State); #[derive(Clone)] struct State; fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/tuple_via_parts.rs000066400000000000000000000004241446276533600265550ustar00rootroot00000000000000use axum::Extension; use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor(#[from_request(via(Extension))] State); #[derive(Clone)] struct State; fn assert_from_request() where Extractor: axum::extract::FromRequestParts<()>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/unit.rs000066400000000000000000000003411446276533600243310ustar00rootroot00000000000000use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor; fn assert_from_request() where Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/from_request/pass/unit_parts.rs000066400000000000000000000003361446276533600255460ustar00rootroot00000000000000use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor; fn assert_from_request() where Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>, { } fn main() {} axum-axum-v0.6.20/axum-macros/tests/typed_path/000077500000000000000000000000001446276533600214665ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/000077500000000000000000000000001446276533600224015ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/missing_capture.rs000066400000000000000000000002341446276533600261420ustar00rootroot00000000000000use axum_macros::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/users")] struct MyPath { id: u32, } fn main() {} 
axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/missing_capture.stderr000066400000000000000000000007041446276533600270230ustar00rootroot00000000000000error[E0027]: pattern does not mention field `id` --> tests/typed_path/fail/missing_capture.rs:5:14 | 5 | #[typed_path("/users")] | ^^^^^^^^ missing field `id` | help: include the missing field in the pattern | 5 | #[typed_path("/users" { id })] | ++++++ help: if you don't care about this missing field, you can explicitly ignore it | 5 | #[typed_path("/users" { .. })] | ++++++ axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/missing_field.rs000066400000000000000000000002231446276533600255600ustar00rootroot00000000000000use axum_macros::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/users/:id")] struct MyPath {} fn main() { } axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/missing_field.stderr000066400000000000000000000003331446276533600264410ustar00rootroot00000000000000error[E0026]: struct `MyPath` does not have a field named `id` --> tests/typed_path/fail/missing_field.rs:5:14 | 5 | #[typed_path("/users/:id")] | ^^^^^^^^^^^^ struct `MyPath` does not have this field axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/not_deserialize.rs000066400000000000000000000001731446276533600261300ustar00rootroot00000000000000use axum_macros::TypedPath; #[derive(TypedPath)] #[typed_path("/users/:id")] struct MyPath { id: u32, } fn main() {} axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/not_deserialize.stderr000066400000000000000000000015641446276533600270140ustar00rootroot00000000000000error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is not satisfied --> tests/typed_path/fail/not_deserialize.rs:3:10 | 3 | #[derive(TypedPath)] | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath` | = help: the following other types implement trait `serde::de::Deserialize<'de>`: &'a [u8] &'a serde_json::raw::RawValue &'a 
std::path::Path &'a str () (T0, T1) (T0, T1, T2) (T0, T1, T2, T3) and $N others = note: required for `MyPath` to implement `serde::de::DeserializeOwned` = note: required for `axum::extract::Path` to implement `FromRequestParts` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/route_not_starting_with_slash.rs000066400000000000000000000001511446276533600311220ustar00rootroot00000000000000use axum_extra::routing::TypedPath; #[derive(TypedPath)] #[typed_path("")] struct MyPath; fn main() {} axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/route_not_starting_with_slash.stderr000066400000000000000000000002531446276533600320040ustar00rootroot00000000000000error: paths must start with a `/`. Use "/" for root routes --> tests/typed_path/fail/route_not_starting_with_slash.rs:4:14 | 4 | #[typed_path("")] | ^^ axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs000066400000000000000000000001551446276533600332160ustar00rootroot00000000000000use axum_extra::routing::TypedPath; #[derive(TypedPath)] #[typed_path(":foo")] struct MyPath; fn main() {} axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.stderr000066400000000000000000000002441446276533600340740ustar00rootroot00000000000000error: paths must start with a `/` --> tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs:4:14 | 4 | #[typed_path(":foo")] | ^^^^^^ axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/unit_with_capture.rs000066400000000000000000000002201446276533600264760ustar00rootroot00000000000000use axum_macros::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/users/:id")] struct MyPath; fn main() {} 
axum-axum-v0.6.20/axum-macros/tests/typed_path/fail/unit_with_capture.stderr000066400000000000000000000002631446276533600273640ustar00rootroot00000000000000error: Typed paths for unit structs cannot contain captures --> tests/typed_path/fail/unit_with_capture.rs:5:14 | 5 | #[typed_path("/users/:id")] | ^^^^^^^^^^^^ axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/000077500000000000000000000000001446276533600224345ustar00rootroot00000000000000axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/customize_rejection.rs000066400000000000000000000020741446276533600270710ustar00rootroot00000000000000use axum::{ extract::rejection::PathRejection, response::{IntoResponse, Response}, }; use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/:foo", rejection(MyRejection))] struct MyPathNamed { foo: String, } #[derive(TypedPath, Deserialize)] #[typed_path("/", rejection(MyRejection))] struct MyPathUnit; #[derive(TypedPath, Deserialize)] #[typed_path("/:foo", rejection(MyRejection))] struct MyPathUnnamed(String); struct MyRejection; impl IntoResponse for MyRejection { fn into_response(self) -> Response { ().into_response() } } impl From for MyRejection { fn from(_: PathRejection) -> Self { Self } } impl Default for MyRejection { fn default() -> Self { Self } } fn main() { axum::Router::<(), axum::body::Body>::new() .typed_get(|_: Result| async {}) .typed_post(|_: Result| async {}) .typed_put(|_: Result| async {}); } axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/into_uri.rs000066400000000000000000000006711446276533600246360ustar00rootroot00000000000000use axum_extra::routing::TypedPath; use axum::http::Uri; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/:id")] struct Named { id: u32, } #[derive(TypedPath, Deserialize)] #[typed_path("/:id")] struct Unnamed(u32); #[derive(TypedPath, Deserialize)] #[typed_path("/")] struct Unit; fn main() { let _: Uri = Named { id: 1 }.to_uri(); 
let _: Uri = Unnamed(1).to_uri(); let _: Uri = Unit.to_uri(); } axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/named_fields_struct.rs000066400000000000000000000010651446276533600270220ustar00rootroot00000000000000use axum_extra::routing::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/users/:user_id/teams/:team_id")] struct MyPath { user_id: u32, team_id: u32, } fn main() { axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!( format!( "{}", MyPath { user_id: 1, team_id: 2 } ), "/users/1/teams/2" ); } axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/option_result.rs000066400000000000000000000012611446276533600257100ustar00rootroot00000000000000use axum_extra::routing::{TypedPath, RouterExt}; use axum::{extract::rejection::PathRejection, http::StatusCode}; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/users/:id")] struct UsersShow { id: String, } async fn option_handler(_: Option) {} async fn result_handler(_: Result) {} #[derive(TypedPath, Deserialize)] #[typed_path("/users")] struct UsersIndex; async fn result_handler_unit_struct(_: Result) {} fn main() { axum::Router::<(), axum::body::Body>::new() .typed_get(option_handler) .typed_post(result_handler) .typed_post(result_handler_unit_struct); } axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/tuple_struct.rs000066400000000000000000000007211446276533600255370ustar00rootroot00000000000000use axum_extra::routing::TypedPath; use serde::Deserialize; pub type Result = std::result::Result; #[derive(TypedPath, Deserialize)] #[typed_path("/users/:user_id/teams/:team_id")] struct MyPath(u32, u32); fn main() { axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2"); } 
axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/unit_struct.rs000066400000000000000000000004711446276533600253670ustar00rootroot00000000000000use axum_extra::routing::TypedPath; #[derive(TypedPath)] #[typed_path("/users")] struct MyPath; fn main() { axum::Router::<(), axum::body::Body>::new() .route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users"); assert_eq!(format!("{}", MyPath), "/users"); } axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/url_encoding.rs000066400000000000000000000010371446276533600254530ustar00rootroot00000000000000use axum_extra::routing::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/:param")] struct Named { param: String, } #[derive(TypedPath, Deserialize)] #[typed_path("/:param")] struct Unnamed(String); fn main() { assert_eq!( format!( "{}", Named { param: "a b".to_string() } ), "/a%20b" ); assert_eq!( format!( "{}", Unnamed("a b".to_string()), ), "/a%20b" ); } axum-axum-v0.6.20/axum-macros/tests/typed_path/pass/wildcards.rs000066400000000000000000000004101446276533600247510ustar00rootroot00000000000000use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; #[derive(TypedPath, Deserialize)] #[typed_path("/*rest")] struct MyPath { rest: String, } fn main() { axum::Router::<(), axum::body::Body>::new().typed_get(|_: MyPath| async {}); } axum-axum-v0.6.20/axum/000077500000000000000000000000001446276533600147015ustar00rootroot00000000000000axum-axum-v0.6.20/axum/CHANGELOG.md000066400000000000000000002345631446276533600165270ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). # Unreleased - None. # 0.6.20 (03. 
August, 2023) - **added:** `WebSocketUpgrade::write_buffer_size` and `WebSocketUpgrade::max_write_buffer_size` - **changed:** Deprecate `WebSocketUpgrade::max_send_queue` - **change:** Update tokio-tungstenite to 0.20 - **added:** Implement `Handler` for `T: IntoResponse` ([#2140]) [#2140]: https://github.com/tokio-rs/axum/pull/2140 # 0.6.19 (17. July, 2023) - **added:** Add `axum::extract::Query::try_from_uri` ([#2058]) - **added:** Implement `IntoResponse` for `Box` and `Box<[u8]>` ([#2035]) - **fixed:** Fix bugs around merging routers with nested fallbacks ([#2096]) - **fixed:** Fix `.source()` of composite rejections ([#2030]) - **fixed:** Allow unreachable code in `#[debug_handler]` ([#2014]) - **change:** Update tokio-tungstenite to 0.19 ([#2021]) - **change:** axum's MSRV is now 1.63 ([#2021]) [#2014]: https://github.com/tokio-rs/axum/pull/2014 [#2021]: https://github.com/tokio-rs/axum/pull/2021 [#2030]: https://github.com/tokio-rs/axum/pull/2030 [#2035]: https://github.com/tokio-rs/axum/pull/2035 [#2058]: https://github.com/tokio-rs/axum/pull/2058 [#2096]: https://github.com/tokio-rs/axum/pull/2096 # 0.6.18 (30. April, 2023) - **fixed:** Don't remove the `Sec-WebSocket-Key` header in `WebSocketUpgrade` ([#1972]) [#1972]: https://github.com/tokio-rs/axum/pull/1972 # 0.6.17 (25. April, 2023) - **fixed:** Fix fallbacks causing a panic on `CONNECT` requests ([#1958]) [#1958]: https://github.com/tokio-rs/axum/pull/1958 # 0.6.16 (18. April, 2023) - **fixed:** Don't allow extracting `MatchedPath` in fallbacks ([#1934]) - **fixed:** Fix panic if `Router` with something nested at `/` was used as a fallback ([#1934]) - **added:** Document that `Router::new().fallback(...)` isn't optimal ([#1940]) [#1934]: https://github.com/tokio-rs/axum/pull/1934 [#1940]: https://github.com/tokio-rs/axum/pull/1940 # 0.6.15 (12. April, 2023) - **fixed:** Removed additional leftover debug messages ([#1927]) [#1927]: https://github.com/tokio-rs/axum/pull/1927 # 0.6.14 (11. 
April, 2023) - **fixed:** Removed leftover "path_router hit" debug message ([#1925]) [#1925]: https://github.com/tokio-rs/axum/pull/1925 # 0.6.13 (11. April, 2023) - **added:** Log rejections from built-in extractors with the `axum::rejection=trace` target ([#1890]) - **fixed:** Fixed performance regression with `Router::nest` introduced in 0.6.0. `nest` now flattens the routes which performs better ([#1711]) - **fixed:** Extracting `MatchedPath` in nested handlers now gives the full matched path, including the nested path ([#1711]) - **added:** Implement `Deref` and `DerefMut` for built-in extractors ([#1922]) [#1711]: https://github.com/tokio-rs/axum/pull/1711 [#1890]: https://github.com/tokio-rs/axum/pull/1890 [#1922]: https://github.com/tokio-rs/axum/pull/1922 # 0.6.12 (22. March, 2023) - **added:** Implement `IntoResponse` for `MultipartError` ([#1861]) - **fixed:** More clearly document what wildcards matches ([#1873]) [#1861]: https://github.com/tokio-rs/axum/pull/1861 [#1873]: https://github.com/tokio-rs/axum/pull/1873 # 0.6.11 (13. March, 2023) - **fixed:** Don't require `S: Debug` for `impl Debug for Router` ([#1836]) - **fixed:** Clone state a bit less when handling requests ([#1837]) - **fixed:** Unpin itoa dependency ([#1815]) [#1815]: https://github.com/tokio-rs/axum/pull/1815 [#1836]: https://github.com/tokio-rs/axum/pull/1836 [#1837]: https://github.com/tokio-rs/axum/pull/1837 # 0.6.10 (03. March, 2023) - **fixed:** Add `#[must_use]` attributes to types that do nothing unless used ([#1809]) - **fixed:** Gracefully handle missing headers in the `TypedHeader` extractor ([#1810]) - **fixed:** Fix routing issues when loading a `Router` via a dynamic library ([#1806]) [#1806]: https://github.com/tokio-rs/axum/pull/1806 [#1809]: https://github.com/tokio-rs/axum/pull/1809 [#1810]: https://github.com/tokio-rs/axum/pull/1810 # 0.6.9 (24. February, 2023) - **changed:** Update to tower-http 0.4. 
axum is still compatible with tower-http 0.3 ([#1783]) [#1783]: https://github.com/tokio-rs/axum/pull/1783 # 0.6.8 (24. February, 2023) - **fixed:** Fix `Allow` missing from routers with middleware ([#1773]) - **added:** Add `KeepAlive::event` for customizing the event sent for SSE keep alive ([#1729]) [#1729]: https://github.com/tokio-rs/axum/pull/1729 [#1773]: https://github.com/tokio-rs/axum/pull/1773 # 0.6.7 (17. February, 2023) - **added:** Add `FormRejection::FailedToDeserializeFormBody` which is returned if the request body couldn't be deserialized into the target type, as opposed to `FailedToDeserializeForm` which is only for query parameters ([#1683]) - **added:** Add `MockConnectInfo` for setting `ConnectInfo` during tests ([#1767]) [#1683]: https://github.com/tokio-rs/axum/pull/1683 [#1767]: https://github.com/tokio-rs/axum/pull/1767 # 0.6.6 (12. February, 2023) - **fixed:** Enable passing `MethodRouter` to `Router::fallback` ([#1730]) [#1730]: https://github.com/tokio-rs/axum/pull/1730 # 0.6.5 (11. February, 2023) - **fixed:** Fix `#[debug_handler]` sometimes giving wrong borrow related suggestions ([#1710]) - Document gotchas related to using `impl IntoResponse` as the return type from handler functions ([#1736]) [#1710]: https://github.com/tokio-rs/axum/pull/1710 [#1736]: https://github.com/tokio-rs/axum/pull/1736 # 0.6.4 (22. January, 2023) - Depend on axum-macros 0.3.2 # 0.6.3 (20. January, 2023) - **added:** Implement `IntoResponse` for `&'static [u8; N]` and `[u8; N]` ([#1690]) - **fixed:** Make `Path` support types using `serde::Deserializer::deserialize_any` ([#1693]) - **added:** Add `RawPathParams` ([#1713]) - **added:** Implement `Clone` and `Service` for `axum::middleware::Next` ([#1712]) - **fixed:** Document required tokio features to run "Hello, World!" 
example ([#1715]) [#1690]: https://github.com/tokio-rs/axum/pull/1690 [#1693]: https://github.com/tokio-rs/axum/pull/1693 [#1712]: https://github.com/tokio-rs/axum/pull/1712 [#1713]: https://github.com/tokio-rs/axum/pull/1713 [#1715]: https://github.com/tokio-rs/axum/pull/1715 # 0.6.2 (9. January, 2023) - **added:** Add `body_text` and `status` methods to built-in rejections ([#1612]) - **added:** Enable the `runtime` feature of `hyper` when using `tokio` ([#1671]) [#1612]: https://github.com/tokio-rs/axum/pull/1612 [#1671]: https://github.com/tokio-rs/axum/pull/1671 # 0.6.1 (29. November, 2022) - **added:** Expand the docs for `Router::with_state` ([#1580]) [#1580]: https://github.com/tokio-rs/axum/pull/1580 # 0.6.0 (25. November, 2022) ## Routing - **fixed:** Nested routers are now allowed to have fallbacks ([#1521]): ```rust let api_router = Router::new() .route("/users", get(|| { ... })) .fallback(api_fallback); let app = Router::new() // this would panic in 0.5 but in 0.6 it just works // // requests starting with `/api` but not handled by `api_router` // will go to `api_fallback` .nest("/api", api_router); ``` The outer router's fallback will still apply if a nested router doesn't have its own fallback: ```rust // this time without a fallback let api_router = Router::new().route("/users", get(|| { ... })); let app = Router::new() .nest("/api", api_router) // `api_router` will inherit this fallback .fallback(app_fallback); ``` - **breaking:** The request `/foo/` no longer matches `/foo/*rest`. 
If you want to match `/foo/` you have to add a route specifically for that ([#1086]) For example: ```rust use axum::{Router, routing::get, extract::Path}; let app = Router::new() // this will match `/foo/bar/baz` .route("/foo/*rest", get(handler)) // this will match `/foo/` .route("/foo/", get(handler)) // if you want `/foo` to match you must also add an explicit route for it .route("/foo", get(handler)); async fn handler( // use an `Option` because `/foo/` and `/foo` don't have any path params params: Option>, ) {} ``` - **breaking:** Path params for wildcard routes no longer include the prefix `/`. e.g. `/foo.js` will match `/*filepath` with a value of `foo.js`, _not_ `/foo.js` ([#1086]) For example: ```rust use axum::{Router, routing::get, extract::Path}; let app = Router::new().route("/foo/*rest", get(handler)); async fn handler( Path(params): Path, ) { // for the request `/foo/bar/baz` the value of `params` will be `bar/baz` // // on 0.5 it would be `/bar/baz` } ``` - **fixed:** Routes like `/foo` and `/*rest` are no longer considered overlapping. `/foo` will take priority ([#1086]) For example: ```rust use axum::{Router, routing::get}; let app = Router::new() // this used to not be allowed but now just works .route("/foo/*rest", get(foo)) .route("/foo/bar", get(bar)); async fn foo() {} async fn bar() {} ``` - **breaking:** Automatic trailing slash redirects have been removed. 
Previously if you added a route for `/foo`, axum would redirect calls to `/foo/` to `/foo` (or vice versa for `/foo/`): ```rust use axum::{Router, routing::get}; let app = Router::new() // a request to `GET /foo/` will now get `404 Not Found` // whereas in 0.5 axum would redirect to `/foo` // // same goes the other way if you had the route `/foo/` // axum will no longer redirect from `/foo` to `/foo/` .route("/foo", get(handler)); async fn handler() {} ``` Either explicitly add routes for `/foo` and `/foo/` or use `axum_extra::routing::RouterExt::route_with_tsr` if you want the old behavior ([#1119]) - **breaking:** `Router::fallback` now only accepts `Handler`s (similarly to what `get`, `post`, etc. accept). Use the new `Router::fallback_service` for setting any `Service` as the fallback ([#1155]) This fallback on 0.5: ```rust use axum::{Router, handler::Handler}; let app = Router::new().fallback(fallback.into_service()); async fn fallback() {} ``` Becomes this in 0.6 ```rust use axum::Router; let app = Router::new().fallback(fallback); async fn fallback() {} ``` - **breaking:** It is no longer supported to `nest` twice at the same path, i.e. `.nest("/foo", a).nest("/foo", b)` will panic. Instead use `.nest("/foo", a.merge(b))` - **breaking:** It is no longer supported to `nest` a router and add a route at the same path, such as `.nest("/a", _).route("/a", _)`. Instead use `.nest("/a/", _).route("/a", _)`. - **changed:** `Router::nest` now only accepts `Router`s, the general-purpose `Service` nesting method has been renamed to `nest_service` ([#1368]) - **breaking:** Allow `Error: Into` for `Route::{layer, route_layer}` ([#924]) - **breaking:** `MethodRouter` now panics on overlapping routes ([#1102]) - **breaking:** `Router::route` now only accepts `MethodRouter`s created with `get`, `post`, etc. 
Use the new `Router::route_service` for routing to any `Service`s ([#1155]) - **breaking:** Adding a `.route_layer` onto a `Router` or `MethodRouter` without any routes will now result in a panic. Previously, this just did nothing. [#1327] - **breaking:** `RouterService` has been removed since `Router` now implements `Service` when the state is `()`. Use `Router::with_state` to provide the state and get a `Router<()>`. Note that `RouterService` only existed in the pre-releases, not 0.5 ([#1552]) ## Extractors - **added:** Added new type safe `State` extractor. This can be used with `Router::with_state` and gives compile errors for missing states, whereas `Extension` would result in runtime errors ([#1155]) We recommend migrating from `Extension` to `State` for sharing application state since that is more type safe and faster. That is done by using `Router::with_state` and `State`. This setup in 0.5 ```rust use axum::{routing::get, Extension, Router}; let app = Router::new() .route("/", get(handler)) .layer(Extension(AppState {})); async fn handler(Extension(app_state): Extension) {} #[derive(Clone)] struct AppState {} ``` Becomes this in 0.6 using `State`: ```rust use axum::{routing::get, extract::State, Router}; let app = Router::new() .route("/", get(handler)) .with_state(AppState {}); async fn handler(State(app_state): State) {} #[derive(Clone)] struct AppState {} ``` If you have multiple extensions, you can use fields on `AppState` and implement `FromRef`: ```rust use axum::{extract::{State, FromRef}, routing::get, Router}; let state = AppState { client: HttpClient {}, database: Database {}, }; let app = Router::new().route("/", get(handler)).with_state(state); async fn handler( State(client): State, State(database): State, ) {} // the derive requires enabling the "macros" feature #[derive(Clone, FromRef)] struct AppState { client: HttpClient, database: Database, } #[derive(Clone)] struct HttpClient {} #[derive(Clone)] struct Database {} ``` - **breaking:** It 
is now only possible for one extractor per handler to consume the request body. In 0.5 doing so would result in runtime errors but in 0.6 it is a compile error ([#1272]) axum enforces this by only allowing the _last_ extractor to consume the request. For example: ```rust use axum::{Json, http::HeaderMap}; // This wont compile on 0.6 because both `Json` and `String` need to consume // the request body. You can use either `Json` or `String`, but not both. async fn handler_1( json: Json, string: String, ) {} // This won't work either since `Json` is not the last extractor. async fn handler_2( json: Json, headers: HeaderMap, ) {} // This works! async fn handler_3( headers: HeaderMap, json: Json, ) {} ``` This is done by reworking the `FromRequest` trait and introducing a new `FromRequestParts` trait. If your extractor needs to consume the request body then you should implement `FromRequest`, otherwise implement `FromRequestParts`. This extractor in 0.5: ```rust struct MyExtractor { /* ... */ } #[async_trait] impl FromRequest for MyExtractor where B: Send, { type Rejection = StatusCode; async fn from_request(req: &mut RequestParts) -> Result { // ... } } ``` Becomes this in 0.6: ```rust use axum::{ extract::{FromRequest, FromRequestParts}, http::{StatusCode, Request, request::Parts}, async_trait, }; struct MyExtractor { /* ... */ } // implement `FromRequestParts` if you don't need to consume the request body #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, { type Rejection = StatusCode; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // ... } } // implement `FromRequest` if you do need to consume the request body #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, B: Send + 'static, { type Rejection = StatusCode; async fn from_request(req: Request, state: &S) -> Result { // ... 
} } ``` For an example of how to write an extractor that accepts different `Content-Types` see the [`parse-body-based-on-content-type`] example. - **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from [`axum-macros`] behind the `macros` feature ([#1352]) - **added:** Add `RequestExt` and `RequestPartsExt` which adds convenience methods for running extractors to `http::Request` and `http::request::Parts` ([#1301]) - **added**: `JsonRejection` now displays the path at which a deserialization error occurred ([#1371]) - **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487]) - **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString` rejections, instead of `422 Unprocessable Entity` ([#1387]) - **changed**: The inner error of a `JsonRejection` is now `serde_path_to_error::Error`. Previously it was `serde_json::Error` ([#1371]) - **changed:** The default body limit now applies to the `Multipart` extractor ([#1420]) - **breaking:** `ContentLengthLimit` has been removed. Use `DefaultBodyLimit` instead ([#1400]) - **breaking:** `RequestParts` has been removed as part of the `FromRequest` rework ([#1272]) - **breaking:** `BodyAlreadyExtracted` has been removed ([#1272]) - **breaking:** The following types or traits have a new `S` type param which represents the state ([#1155]): - `Router`, defaults to `()` - `MethodRouter`, defaults to `()` - `FromRequest`, no default - `Handler`, no default - **breaking:** `MatchedPath` can now no longer be extracted in middleware for nested routes. In previous versions it returned invalid data when extracted from a middleware applied to a nested router. 
`MatchedPath` can still be extracted from handlers and middleware that aren't on nested routers ([#1462]) - **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to `FormRejection::FailedToDeserializeForm` ([#1496]) ## Middleware - **added:** Support running extractors on `middleware::from_fn` functions ([#1088]) - **added**: Add `middleware::from_fn_with_state` to enable running extractors that require state ([#1342]) - **added:** Add `middleware::from_extractor_with_state` ([#1396]) - **added:** Add `map_request`, `map_request_with_state` for transforming the request with an async function ([#1408]) - **added:** Add `map_response`, `map_response_with_state` for transforming the response with an async function ([#1414]) - **added:** Support any middleware response that implements `IntoResponse` ([#1152]) - **breaking:** Remove `extractor_middleware` which was previously deprecated. Use `axum::middleware::from_extractor` instead ([#1077]) - **breaking:** Require middleware added with `Handler::layer` to have `Infallible` as the error type ([#1152]) ## Misc - **added:** Support compiling to WASM. 
See the `simple-router-wasm` example for more details ([#1382]) - **added:** Add `ServiceExt` with methods for turning any `Service` into a `MakeService` similarly to `Router::into_make_service` ([#1302]) - **added:** String and binary `From` impls have been added to `extract::ws::Message` to be more inline with `tungstenite` ([#1421]) - **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430]) - **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529]) - **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do when upgrading a connection fails ([#1539]) - **fixed:** Annotate panicking functions with `#[track_caller]` so the error message points to where the user added the invalid route, rather than somewhere internally in axum ([#1248]) - **changed:** axum's MSRV is now 1.60 ([#1239]) - **changed:** For methods that accept some `S: Service`, the bounds have been relaxed so the response type must implement `IntoResponse` rather than being a literal `Response` - **breaking:** New `tokio` default feature needed for WASM support. 
If you don't need WASM support but have `default_features = false` for other reasons you likely need to re-enable the `tokio` feature ([#1382]) - **breaking:** `handler::{WithState, IntoService}` are merged into one type, named `HandlerService` ([#1418]) [#924]: https://github.com/tokio-rs/axum/pull/924 [#1077]: https://github.com/tokio-rs/axum/pull/1077 [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1088]: https://github.com/tokio-rs/axum/pull/1088 [#1102]: https://github.com/tokio-rs/axum/pull/1102 [#1119]: https://github.com/tokio-rs/axum/pull/1119 [#1152]: https://github.com/tokio-rs/axum/pull/1152 [#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1248]: https://github.com/tokio-rs/axum/pull/1248 [#1272]: https://github.com/tokio-rs/axum/pull/1272 [#1301]: https://github.com/tokio-rs/axum/pull/1301 [#1302]: https://github.com/tokio-rs/axum/pull/1302 [#1327]: https://github.com/tokio-rs/axum/pull/1327 [#1342]: https://github.com/tokio-rs/axum/pull/1342 [#1346]: https://github.com/tokio-rs/axum/pull/1346 [#1352]: https://github.com/tokio-rs/axum/pull/1352 [#1368]: https://github.com/tokio-rs/axum/pull/1368 [#1371]: https://github.com/tokio-rs/axum/pull/1371 [#1382]: https://github.com/tokio-rs/axum/pull/1382 [#1387]: https://github.com/tokio-rs/axum/pull/1387 [#1389]: https://github.com/tokio-rs/axum/pull/1389 [#1396]: https://github.com/tokio-rs/axum/pull/1396 [#1397]: https://github.com/tokio-rs/axum/pull/1397 [#1400]: https://github.com/tokio-rs/axum/pull/1400 [#1408]: https://github.com/tokio-rs/axum/pull/1408 [#1414]: https://github.com/tokio-rs/axum/pull/1414 [#1418]: https://github.com/tokio-rs/axum/pull/1418 [#1420]: https://github.com/tokio-rs/axum/pull/1420 [#1421]: https://github.com/tokio-rs/axum/pull/1421 [#1430]: https://github.com/tokio-rs/axum/pull/1430 [#1462]: https://github.com/tokio-rs/axum/pull/1462 [#1487]: https://github.com/tokio-rs/axum/pull/1487 [#1496]: 
https://github.com/tokio-rs/axum/pull/1496 [#1521]: https://github.com/tokio-rs/axum/pull/1521 [#1529]: https://github.com/tokio-rs/axum/pull/1529 [#1532]: https://github.com/tokio-rs/axum/pull/1532 [#1539]: https://github.com/tokio-rs/axum/pull/1539 [#1552]: https://github.com/tokio-rs/axum/pull/1552 [`axum-macros`]: https://docs.rs/axum-macros/latest/axum_macros/ [`parse-body-based-on-content-type`]: https://github.com/tokio-rs/axum/blob/main/examples/parse-body-based-on-content-type/src/main.rs
0.6.0 Pre-Releases # 0.6.0-rc.5 (18. November, 2022) - **breaking:** `Router::with_state` is no longer a constructor. It is instead used to convert the router into a `RouterService` ([#1532]) This nested router on 0.6.0-rc.4 ```rust Router::with_state(state).route(...); ``` Becomes this in 0.6.0-rc.5 ```rust Router::new().route(...).with_state(state); ``` - **breaking:**: `Router::inherit_state` has been removed. Use `Router::with_state` instead ([#1532]) - **breaking:**: `Router::nest` and `Router::merge` now only supports nesting routers that use the same state type as the router they're being merged into. Use `FromRef` for substates ([#1532]) - **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529]) - **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521]) - **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do when upgrading a connection fails ([#1539]) [#1521]: https://github.com/tokio-rs/axum/pull/1521 [#1529]: https://github.com/tokio-rs/axum/pull/1529 [#1532]: https://github.com/tokio-rs/axum/pull/1532 [#1539]: https://github.com/tokio-rs/axum/pull/1539 # 0.6.0-rc.4 (9. November, 2022) - **changed**: The inner error of a `JsonRejection` is now `serde_path_to_error::Error`. 
Previously it was `serde_json::Error` ([#1371]) - **added**: `JsonRejection` now displays the path at which a deserialization error occurred ([#1371]) - **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389]) - **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString` rejections, instead of `422 Unprocessable Entity` ([#1387]) - **added:** Add `middleware::from_extractor_with_state` ([#1396]) - **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) - **added:** Add `map_request`, `map_request_with_state` for transforming the request with an async function ([#1408]) - **added:** Add `map_response`, `map_response_with_state` for transforming the response with an async function ([#1414]) - **breaking:** `ContentLengthLimit` has been removed. Use `DefaultBodyLimit` instead ([#1400]) - **changed:** `Router` no longer implements `Service`, call `.into_service()` on it to obtain a `RouterService` that does ([#1368]) - **added:** Add `Router::inherit_state`, which creates a `Router` with an arbitrary state type without actually supplying the state; such a `Router` can't be turned into a service directly (`.into_service()` will panic), but can be nested or merged into a `Router` with the same state type ([#1368]) - **changed:** `Router::nest` now only accepts `Router`s, the general-purpose `Service` nesting method has been renamed to `nest_service` ([#1368]) - **added:** Support compiling to WASM. See the `simple-router-wasm` example for more details ([#1382]) - **breaking:** New `tokio` default feature needed for WASM support. 
If you don't need WASM support but have `default_features = false` for other reasons you likely need to re-enable the `tokio` feature ([#1382]) - **breaking:** `handler::{WithState, IntoService}` are merged into one type, named `HandlerService` ([#1418]) - **changed:** The default body limit now applies to the `Multipart` extractor ([#1420]) - **added:** String and binary `From` impls have been added to `extract::ws::Message` to be more inline with `tungstenite` ([#1421]) - **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430]) - **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from [`axum-macros`] behind the `macros` feature ([#1352]) - **breaking:** `MatchedPath` can now no longer be extracted in middleware for nested routes ([#1462]) - **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487]) - **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to `FormRejection::FailedToDeserializeForm` ([#1496]) [#1352]: https://github.com/tokio-rs/axum/pull/1352 [#1368]: https://github.com/tokio-rs/axum/pull/1368 [#1371]: https://github.com/tokio-rs/axum/pull/1371 [#1382]: https://github.com/tokio-rs/axum/pull/1382 [#1387]: https://github.com/tokio-rs/axum/pull/1387 [#1389]: https://github.com/tokio-rs/axum/pull/1389 [#1396]: https://github.com/tokio-rs/axum/pull/1396 [#1397]: https://github.com/tokio-rs/axum/pull/1397 [#1400]: https://github.com/tokio-rs/axum/pull/1400 [#1408]: https://github.com/tokio-rs/axum/pull/1408 [#1414]: https://github.com/tokio-rs/axum/pull/1414 [#1418]: https://github.com/tokio-rs/axum/pull/1418 [#1420]: https://github.com/tokio-rs/axum/pull/1420 [#1421]: https://github.com/tokio-rs/axum/pull/1421 [#1430]: https://github.com/tokio-rs/axum/pull/1430 [#1462]: https://github.com/tokio-rs/axum/pull/1462 [#1487]: https://github.com/tokio-rs/axum/pull/1487 [#1496]: https://github.com/tokio-rs/axum/pull/1496 # 0.6.0-rc.3 (8. 
November, 2022) Yanked, as it didn't compile in release mode. # 0.6.0-rc.2 (10. September, 2022) ## Security - **breaking:** Added default limit to how much data `Bytes::from_request` will consume. Previously it would attempt to consume the entire request body without checking its length. This meant if a malicious peer sent a large (or infinite) request body your server might run out of memory and crash. The default limit is 2 MB and can be disabled by adding the new `DefaultBodyLimit::disable()` middleware. See its documentation for more details. This also applies to these extractors which used `Bytes::from_request` internally: - `Form` - `Json` - `String` ([#1346]) ## Routing - **breaking:** Adding a `.route_layer` onto a `Router` or `MethodRouter` without any routes will now result in a panic. Previously, this just did nothing. ([#1327]) [`axum-macros`]: https://docs.rs/axum-macros/latest/axum_macros/ ## Middleware - **added:** Add `middleware::from_fn_with_state` and `middleware::from_fn_with_state_arc` to enable running extractors that require state ([#1342]) [#1327]: https://github.com/tokio-rs/axum/pull/1327 [#1342]: https://github.com/tokio-rs/axum/pull/1342 [#1346]: https://github.com/tokio-rs/axum/pull/1346 # 0.6.0-rc.1 (23. August, 2022) ## Routing - **breaking:** Nested `Router`s will no longer delegate to the outer `Router`'s fallback.
Instead you must explicitly set a fallback on the inner `Router` ([#1086]) This nested router on 0.5: ```rust use axum::{Router, handler::Handler}; let api_routes = Router::new(); let app = Router::new() .nest("/api", api_routes) .fallback(fallback.into_service()); async fn fallback() {} ``` Becomes this in 0.6: ```rust use axum::Router; let api_routes = Router::new() // we have to explicitly set the fallback here // since nested routers no longer delegate to the outer // router's fallback .fallback(fallback); let app = Router::new() .nest("/api", api_routes) .fallback(fallback); async fn fallback() {} ``` - **breaking:** The request `/foo/` no longer matches `/foo/*rest`. If you want to match `/foo/` you have to add a route specifically for that ([#1086]) For example: ```rust use axum::{Router, routing::get, extract::Path}; let app = Router::new() // this will match `/foo/bar/baz` .route("/foo/*rest", get(handler)) // this will match `/foo/` .route("/foo/", get(handler)) // if you want `/foo` to match you must also add an explicit route for it .route("/foo", get(handler)); async fn handler( // use an `Option` because `/foo/` and `/foo` don't have any path params params: Option<Path<String>>, ) {} ``` - **breaking:** Path params for wildcard routes no longer include the prefix `/`. e.g. `/foo.js` will match `/*filepath` with a value of `foo.js`, _not_ `/foo.js` ([#1086]) For example: ```rust use axum::{Router, routing::get, extract::Path}; let app = Router::new().route("/foo/*rest", get(handler)); async fn handler( Path(params): Path<String>, ) { // for the request `/foo/bar/baz` the value of `params` will be `bar/baz` // // on 0.5 it would be `/bar/baz` } ``` - **fixed:** Routes like `/foo` and `/*rest` are no longer considered overlapping.
`/foo` will take priority ([#1086]) For example: ```rust use axum::{Router, routing::get}; let app = Router::new() // this used to not be allowed but now just works .route("/foo/*rest", get(foo)) .route("/foo/bar", get(bar)); async fn foo() {} async fn bar() {} ``` - **breaking:** Trailing slash redirects have been removed. Previously if you added a route for `/foo`, axum would redirect calls to `/foo/` to `/foo` (or vice versa for `/foo/`). That is no longer supported and such requests will now be sent to the fallback. Consider using `axum_extra::routing::RouterExt::route_with_tsr` if you want the old behavior ([#1119]) For example: ```rust use axum::{Router, routing::get}; let app = Router::new() // a request to `GET /foo/` will now get `404 Not Found` // whereas in 0.5 axum would redirect to `/foo` // // same goes the other way if you had the route `/foo/` // axum will no longer redirect from `/foo` to `/foo/` .route("/foo", get(handler)); async fn handler() {} ``` - **breaking:** `Router::fallback` now only accepts `Handler`s (similarly to what `get`, `post`, etc. accept). Use the new `Router::fallback_service` for setting any `Service` as the fallback ([#1155]) This fallback on 0.5: ```rust use axum::{Router, handler::Handler}; let app = Router::new().fallback(fallback.into_service()); async fn fallback() {} ``` Becomes this in 0.6: ```rust use axum::Router; let app = Router::new().fallback(fallback); async fn fallback() {} ``` - **breaking:** Allow `Error: Into<Infallible>` for `Route::{layer, route_layer}` ([#924]) - **breaking:** `MethodRouter` now panics on overlapping routes ([#1102]) - **breaking:** `Router::route` now only accepts `MethodRouter`s created with `get`, `post`, etc. Use the new `Router::route_service` for routing to any `Service`s ([#1155]) ## Extractors - **added:** Added new type safe `State` extractor.
This can be used with `Router::with_state` and gives compile errors for missing states, whereas `Extension` would result in runtime errors ([#1155]) We recommend migrating from `Extension` to `State` since that is more type safe and faster. That is done by using `Router::with_state` and `State`. This setup in 0.5: ```rust use axum::{routing::get, Extension, Router}; let app = Router::new() .route("/", get(handler)) .layer(Extension(AppState {})); async fn handler(Extension(app_state): Extension<AppState>) {} #[derive(Clone)] struct AppState {} ``` Becomes this in 0.6 using `State`: ```rust use axum::{routing::get, extract::State, Router}; let app = Router::with_state(AppState {}) .route("/", get(handler)); async fn handler(State(app_state): State<AppState>) {} #[derive(Clone)] struct AppState {} ``` If you have multiple extensions you can use fields on `AppState` and implement `FromRef`: ```rust use axum::{extract::{State, FromRef}, routing::get, Router}; let state = AppState { client: HttpClient {}, database: Database {}, }; let app = Router::with_state(state).route("/", get(handler)); async fn handler( State(client): State<HttpClient>, State(database): State<Database>, ) {} #[derive(Clone)] struct AppState { client: HttpClient, database: Database, } #[derive(Clone)] struct HttpClient {} impl FromRef<AppState> for HttpClient { fn from_ref(state: &AppState) -> Self { state.client.clone() } } #[derive(Clone)] struct Database {} impl FromRef<AppState> for Database { fn from_ref(state: &AppState) -> Self { state.database.clone() } } ``` - **breaking:** It is now only possible for one extractor per handler to consume the request body. In 0.5 doing so would result in runtime errors but in 0.6 it is a compile error ([#1272]) axum enforces this by only allowing the _last_ extractor to consume the request. For example: ```rust use axum::{Json, http::HeaderMap}; // This won't compile on 0.6 because both `Json` and `String` need to consume // the request body. You can use either `Json` or `String`, but not both.
async fn handler_1( json: Json<serde_json::Value>, string: String, ) {} // This won't work either since `Json` is not the last extractor. async fn handler_2( json: Json<serde_json::Value>, headers: HeaderMap, ) {} // This works! async fn handler_3( headers: HeaderMap, json: Json<serde_json::Value>, ) {} ``` This is done by reworking the `FromRequest` trait and introducing a new `FromRequestParts` trait. If your extractor needs to consume the request body then you should implement `FromRequest`, otherwise implement `FromRequestParts`. This extractor in 0.5: ```rust struct MyExtractor { /* ... */ } #[async_trait] impl<B> FromRequest<B> for MyExtractor where B: Send, { type Rejection = StatusCode; async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { // ... } } ``` Becomes this in 0.6: ```rust use axum::{ extract::{FromRequest, FromRequestParts}, http::{StatusCode, Request, request::Parts}, async_trait, }; struct MyExtractor { /* ... */ } // implement `FromRequestParts` if you don't need to consume the request body #[async_trait] impl<S> FromRequestParts<S> for MyExtractor where S: Send + Sync, { type Rejection = StatusCode; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { // ... } } // implement `FromRequest` if you do need to consume the request body #[async_trait] impl<S, B> FromRequest<S, B> for MyExtractor where S: Send + Sync, B: Send + 'static, { type Rejection = StatusCode; async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { // ...
} } ``` - **breaking:** `RequestParts` has been removed as part of the `FromRequest` rework ([#1272]) - **breaking:** `BodyAlreadyExtracted` has been removed ([#1272]) - **breaking:** The following types or traits have a new `S` type param which represents the state ([#1155]): - `Router`, defaults to `()` - `MethodRouter`, defaults to `()` - `FromRequest`, no default - `Handler`, no default - **added:** Add `RequestExt` and `RequestPartsExt` which adds convenience methods for running extractors to `http::Request` and `http::request::Parts` ([#1301]) ## Middleware - **breaking:** Remove `extractor_middleware` which was previously deprecated. Use `axum::middleware::from_extractor` instead ([#1077]) - **added:** Support running extractors on `middleware::from_fn` functions ([#1088]) - **added:** Support any middleware response that implements `IntoResponse` ([#1152]) - **breaking:** Require middleware added with `Handler::layer` to have `Infallible` as the error type ([#1152]) ## Misc - **changed:** axum's MSRV is now 1.60 ([#1239]) - **changed:** For methods that accept some `S: Service`, the bounds have been relaxed so the response type must implement `IntoResponse` rather than being a literal `Response` - **fixed:** Annotate panicking functions with `#[track_caller]` so the error message points to where the user added the invalid route, rather than somewhere internally in axum ([#1248]) - **added:** Add `ServiceExt` with methods for turning any `Service` into a `MakeService` similarly to `Router::into_make_service` ([#1302]) [#1077]: https://github.com/tokio-rs/axum/pull/1077 [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1088]: https://github.com/tokio-rs/axum/pull/1088 [#1102]: https://github.com/tokio-rs/axum/pull/1102 [#1119]: https://github.com/tokio-rs/axum/pull/1119 [#1152]: https://github.com/tokio-rs/axum/pull/1152 [#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1248]: 
https://github.com/tokio-rs/axum/pull/1248 [#1272]: https://github.com/tokio-rs/axum/pull/1272 [#1301]: https://github.com/tokio-rs/axum/pull/1301 [#1302]: https://github.com/tokio-rs/axum/pull/1302 [#924]: https://github.com/tokio-rs/axum/pull/924
# 0.5.16 (10. September, 2022) ## Security - **breaking:** Added default limit to how much data `Bytes::from_request` will consume. Previously it would attempt to consume the entire request body without checking its length. This meant if a malicious peer sent a large (or infinite) request body your server might run out of memory and crash. The default limit is 2 MB and can be disabled by adding the new `DefaultBodyLimit::disable()` middleware. See its documentation for more details. This also applies to these extractors which used `Bytes::from_request` internally: - `Form` - `Json` - `String` ([#1346]) [#1346]: https://github.com/tokio-rs/axum/pull/1346 # 0.5.15 (9. August, 2022) - **fixed:** Don't expose internal type names in `QueryRejection` response ([#1171]) - **fixed:** Improve performance of JSON serialization ([#1178]) - **fixed:** Improve build times by generating less IR ([#1192]) [#1171]: https://github.com/tokio-rs/axum/pull/1171 [#1178]: https://github.com/tokio-rs/axum/pull/1178 [#1192]: https://github.com/tokio-rs/axum/pull/1192 # 0.5.14 (25. July, 2022) Yanked, as it contained an accidental breaking change. # 0.5.13 (15. July, 2022) - **fixed:** If `WebSocketUpgrade` cannot upgrade the connection it will return a `WebSocketUpgradeRejection::ConnectionNotUpgradable` rejection ([#1135]) - **changed:** `WebSocketUpgradeRejection` has a new variant `ConnectionNotUpgradable` ([#1135]) [#1135]: https://github.com/tokio-rs/axum/pull/1135 # 0.5.12 (10. July, 2022) - **added:** Added `debug_handler` which is an attribute macro that improves type errors when applied to a handler function. It is re-exported from `axum-macros` ([#1144]) [#1144]: https://github.com/tokio-rs/axum/pull/1144 # 0.5.11 (02.
July, 2022) - **added:** Implement `TryFrom<http::Method>` for `MethodFilter` and use new `NoMatchingMethodFilter` error in case of failure ([#1130]) - **added:** Document how to run extractors from middleware ([#1140]) [#1130]: https://github.com/tokio-rs/axum/pull/1130 [#1140]: https://github.com/tokio-rs/axum/pull/1140 # 0.5.10 (28. June, 2022) - **fixed:** Make `Router` cheaper to clone ([#1123]) - **fixed:** Fix possible panic when doing trailing slash redirect ([#1124]) [#1123]: https://github.com/tokio-rs/axum/pull/1123 [#1124]: https://github.com/tokio-rs/axum/pull/1124 # 0.5.9 (20. June, 2022) - **fixed:** Fix compile error when the `headers` feature is enabled and the `form` feature is disabled ([#1107]) [#1107]: https://github.com/tokio-rs/axum/pull/1107 # 0.5.8 (18. June, 2022) - **added:** Support resolving host name via `Forwarded` header in `Host` extractor ([#1078]) - **added:** Implement `IntoResponse` for `Form` ([#1095]) - **changed:** axum's MSRV is now 1.56 ([#1098]) [#1078]: https://github.com/tokio-rs/axum/pull/1078 [#1095]: https://github.com/tokio-rs/axum/pull/1095 [#1098]: https://github.com/tokio-rs/axum/pull/1098 # 0.5.7 (08. June, 2022) - **added:** Implement `Default` for `Extension` ([#1043]) - **fixed:** Support deserializing `Vec<(String, String)>` in `extract::Path<_>` to get a vector of key/value pairs ([#1059]) - **added:** Add `extract::ws::close_code` which contains constants for close codes ([#1067]) - **fixed:** Use `impl IntoResponse` less in docs ([#1049]) [#1043]: https://github.com/tokio-rs/axum/pull/1043 [#1049]: https://github.com/tokio-rs/axum/pull/1049 [#1059]: https://github.com/tokio-rs/axum/pull/1059 [#1067]: https://github.com/tokio-rs/axum/pull/1067 # 0.5.6 (15. May, 2022) - **added:** Add `WebSocket::protocol` to return the selected WebSocket subprotocol, if there is one.
([#1022]) - **fixed:** Improve error message for `PathRejection::WrongNumberOfParameters` to hint at using `Path<(String, String)>` or `Path<SomeStruct>` ([#1023]) - **fixed:** `PathRejection::WrongNumberOfParameters` now uses `500 Internal Server Error` since it's a programmer error and not a client error ([#1023]) - **fixed:** Fix `InvalidFormContentType` mentioning the wrong content type [#1022]: https://github.com/tokio-rs/axum/pull/1022 [#1023]: https://github.com/tokio-rs/axum/pull/1023 # 0.5.5 (10. May, 2022) - **fixed:** Correctly handle `GET`, `HEAD`, and `OPTIONS` requests in `ContentLengthLimit`. Requests with these methods are now accepted if they _do not_ have a `Content-Length` header, and the request body will not be checked. If they do have a `Content-Length` header they'll be rejected. This allows `ContentLengthLimit` to be used as middleware around several routes, including `GET` routes ([#989]) - **added:** Add `MethodRouter::{into_make_service, into_make_service_with_connect_info}` ([#1010]) [#989]: https://github.com/tokio-rs/axum/pull/989 [#1010]: https://github.com/tokio-rs/axum/pull/1010 # 0.5.4 (26. April, 2022) - **added:** Add `response::ErrorResponse` and `response::Result` for `IntoResponse`-based error handling ([#921]) - **added:** Add `middleware::from_extractor` and deprecate `extract::extractor_middleware` ([#957]) - **changed:** Update to tower-http 0.3 ([#965]) [#921]: https://github.com/tokio-rs/axum/pull/921 [#957]: https://github.com/tokio-rs/axum/pull/957 [#965]: https://github.com/tokio-rs/axum/pull/965 # 0.5.3 (19.
April, 2022) - **added:** Add `AppendHeaders` for appending headers to a response rather than overriding them ([#927]) - **added:** Add `axum::extract::multipart::Field::chunk` method for streaming a single chunk from the field ([#901]) - **fixed:** Fix trailing slash redirection with query parameters ([#936]) [#901]: https://github.com/tokio-rs/axum/pull/901 [#927]: https://github.com/tokio-rs/axum/pull/927 [#936]: https://github.com/tokio-rs/axum/pull/936 # 0.5.2 (19. April, 2022) Yanked, as it contained an accidental breaking change. # 0.5.1 (03. April, 2022) - **added:** Add `RequestParts::extract` which allows applying an extractor as a method call ([#897]) [#897]: https://github.com/tokio-rs/axum/pull/897 # 0.5.0 (31. March, 2022) - **added:** Document sharing state between handler and middleware ([#783]) - **added:** `Extension<_>` can now be used in tuples for building responses, and will set an extension on the response ([#797]) - **added:** `extract::Host` for extracting the hostname of a request ([#827]) - **added:** Add `IntoResponseParts` trait which allows defining custom response types for adding headers or extensions to responses ([#797]) - **added:** `TypedHeader` implements the new `IntoResponseParts` trait so they can be returned from handlers as parts of a response ([#797]) - **changed:** `Router::merge` now accepts `Into<Router>` ([#819]) - **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>` as field values. - **breaking:** `sse::Event` now panics if a setter method is called twice instead of silently overwriting old values. - **breaking:** Require `Output = ()` on `WebSocketStream::on_upgrade` ([#644]) - **breaking:** Make `TypedHeaderRejectionReason` `#[non_exhaustive]` ([#665]) - **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead `HeaderMap` will clone the headers.
You should prefer to use `TypedHeader` to extract only the headers you need ([#698]) This includes these breaking changes: - `RequestParts::take_headers` has been removed. - `RequestParts::headers` returns `&HeaderMap`. - `RequestParts::headers_mut` returns `&mut HeaderMap`. - `HeadersAlreadyExtracted` has been removed. - The `HeadersAlreadyExtracted` variant has been removed from these rejections: - `RequestAlreadyExtracted` - `RequestPartsAlreadyExtracted` - `JsonRejection` - `FormRejection` - `ContentLengthLimitRejection` - `WebSocketUpgradeRejection` - `<HeaderMap as FromRequest<_>>::Rejection` has been changed to `std::convert::Infallible`. - **breaking:** `axum::http::Extensions` is no longer an extractor (i.e. it doesn't implement `FromRequest`). The `axum::extract::Extension` extractor is _not_ impacted by this and works the same. This change makes it harder to accidentally remove all extensions which would result in confusing errors elsewhere ([#699]) This includes these breaking changes: - `RequestParts::take_extensions` has been removed. - `RequestParts::extensions` returns `&Extensions`. - `RequestParts::extensions_mut` returns `&mut Extensions`. - `RequestAlreadyExtracted` has been removed. - `<Request as FromRequest>::Rejection` is now `BodyAlreadyExtracted`. - `<http::request::Parts as FromRequest>::Rejection` is now `Infallible`. - `ExtensionsAlreadyExtracted` has been removed. - The `ExtensionsAlreadyExtracted` variant has been removed from these rejections: - `ExtensionRejection` - `PathRejection` - `MatchedPathRejection` - `WebSocketUpgradeRejection` - **breaking:** `Redirect::found` has been removed ([#800]) - **breaking:** `AddExtensionLayer` has been removed. Use `Extension` instead. It now implements `tower::Layer` ([#807]) - **breaking:** `AddExtension` has been moved from the root module to `middleware` - **breaking:** `.nest("/foo/", Router::new().route("/bar", _))` now does the right thing and results in a route at `/foo/bar` instead of `/foo//bar` ([#824]) - **breaking:** Routes are now required to start with `/`.
Previously routes such as `:foo` would be accepted but most likely result in bugs ([#823]) - **breaking:** `Headers` has been removed. Arrays of tuples directly implement `IntoResponseParts` so `([("x-foo", "foo")], response)` now works ([#797]) - **breaking:** `InvalidJsonBody` has been replaced with `JsonDataError` to clearly signal that the request body was syntactically valid JSON but couldn't be deserialized into the target type - **breaking:** `Handler` is no longer an `#[async_trait]` but instead has an associated `Future` type. That allows users to build their own `Handler` types without paying the cost of `#[async_trait]` ([#879]) - **changed:** New `JsonSyntaxError` variant added to `JsonRejection`. This is returned when the request body contains syntactically invalid JSON - **fixed:** Correctly set the `Content-Length` header for responses to `HEAD` requests ([#734]) - **fixed:** Fix wrong `content-length` for `HEAD` requests to endpoints that return chunked responses ([#755]) - **fixed:** Fixed several routing bugs related to nested "opaque" tower services (i.e.
non-`Router` services) ([#841] and [#842]) - **changed:** Update to tokio-tungstenite 0.17 ([#791]) - **breaking:** `Redirect::{to, temporary, permanent}` now accept `&str` instead of `Uri` ([#889]) - **breaking:** Remove second type parameter from `Router::into_make_service_with_connect_info` and `Handler::into_make_service_with_connect_info` to support `MakeService`s that accept multiple targets ([#892]) [#644]: https://github.com/tokio-rs/axum/pull/644 [#665]: https://github.com/tokio-rs/axum/pull/665 [#698]: https://github.com/tokio-rs/axum/pull/698 [#699]: https://github.com/tokio-rs/axum/pull/699 [#734]: https://github.com/tokio-rs/axum/pull/734 [#755]: https://github.com/tokio-rs/axum/pull/755 [#783]: https://github.com/tokio-rs/axum/pull/783 [#791]: https://github.com/tokio-rs/axum/pull/791 [#797]: https://github.com/tokio-rs/axum/pull/797 [#800]: https://github.com/tokio-rs/axum/pull/800 [#807]: https://github.com/tokio-rs/axum/pull/807 [#819]: https://github.com/tokio-rs/axum/pull/819 [#823]: https://github.com/tokio-rs/axum/pull/823 [#824]: https://github.com/tokio-rs/axum/pull/824 [#827]: https://github.com/tokio-rs/axum/pull/827 [#841]: https://github.com/tokio-rs/axum/pull/841 [#842]: https://github.com/tokio-rs/axum/pull/842 [#879]: https://github.com/tokio-rs/axum/pull/879 [#889]: https://github.com/tokio-rs/axum/pull/889 [#892]: https://github.com/tokio-rs/axum/pull/892 # 0.4.8 (2. March, 2022) - Use correct path for `AddExtensionLayer` and `AddExtension::layer` deprecation notes ([#812]) [#812]: https://github.com/tokio-rs/axum/pull/812 # 0.4.7 (1. March, 2022) - **added:** Implement `tower::Layer` for `Extension` ([#801]) - **changed:** Deprecate `AddExtensionLayer`. Use `Extension` instead ([#805]) [#801]: https://github.com/tokio-rs/axum/pull/801 [#805]: https://github.com/tokio-rs/axum/pull/805 # 0.4.6 (22. February, 2022) - **added:** `middleware::from_fn` for creating middleware from async functions. 
This previously lived in axum-extra but has been moved to axum ([#719]) - **fixed:** Set `Allow` header when responding with `405 Method Not Allowed` ([#733]) [#719]: https://github.com/tokio-rs/axum/pull/719 [#733]: https://github.com/tokio-rs/axum/pull/733 # 0.4.5 (31. January, 2022) - Reference [axum-macros] instead of [axum-debug]. The latter has been superseded by axum-macros and is deprecated ([#738]) [#738]: https://github.com/tokio-rs/axum/pull/738 [axum-debug]: https://docs.rs/axum-debug [axum-macros]: https://docs.rs/axum-macros # 0.4.4 (13. January, 2022) - **fixed:** Fix using incorrect path prefix when nesting `Router`s at `/` ([#691]) - **fixed:** Make `nest("", service)` work and mean the same as `nest("/", service)` ([#691]) - **fixed:** Replace response code `301` with `308` for trailing slash redirects. Also deprecates `Redirect::found` (`302`) in favor of `Redirect::temporary` (`307`) or `Redirect::to` (`303`). This is to prevent clients from changing non-`GET` requests to `GET` requests ([#682]) [#691]: https://github.com/tokio-rs/axum/pull/691 [#682]: https://github.com/tokio-rs/axum/pull/682 # 0.4.3 (21. December, 2021) - **added:** `axum::AddExtension::layer` ([#607]) - **added:** Re-export the headers crate when the headers feature is active ([#630]) - **fixed:** `sse::Event` will no longer drop the leading space of data, event ID and name values that have it ([#600]) - **fixed:** `sse::Event` is more strict about what field values it supports, disallowing any SSE events that break the specification (such as field values containing carriage returns) ([#599]) - **fixed:** Improve documentation of `sse::Event` ([#601]) - **fixed:** Make `Path` fail with `ExtensionsAlreadyExtracted` if another extractor (such as `Request`) has previously taken the request extensions. Thus `PathRejection` now contains a variant with `ExtensionsAlreadyExtracted`. 
This is not a breaking change since `PathRejection` is marked as `#[non_exhaustive]` ([#619]) - **fixed:** Fix misleading error message for `PathRejection` if extensions had previously been extracted ([#619]) - **fixed:** Use `AtomicU32` internally, rather than `AtomicU64`, to improve portability ([#616]) [#599]: https://github.com/tokio-rs/axum/pull/599 [#600]: https://github.com/tokio-rs/axum/pull/600 [#601]: https://github.com/tokio-rs/axum/pull/601 [#607]: https://github.com/tokio-rs/axum/pull/607 [#616]: https://github.com/tokio-rs/axum/pull/616 [#619]: https://github.com/tokio-rs/axum/pull/619 [#630]: https://github.com/tokio-rs/axum/pull/630 # 0.4.2 (06. December, 2021) - **fixed:** Depend on the correct version of `axum-core` ([#592]) [#592]: https://github.com/tokio-rs/axum/pull/592 # 0.4.1 (06. December, 2021) - **added:** `axum::response::Response` now exists as a shorthand for writing `Response<BoxBody>` ([#590]) [#590]: https://github.com/tokio-rs/axum/pull/590 # 0.4.0 (02. December, 2021) - **breaking:** New `MethodRouter` that works similarly to `Router`: - Route to handlers and services with the same type - Add middleware to some routes more easily with `MethodRouter::layer` and `MethodRouter::route_layer`. - Merge method routers with `MethodRouter::merge` - Customize response for unsupported methods with `MethodRouter::fallback` - **breaking:** The default for the type parameter in `FromRequest` and `RequestParts` has been removed. Use `FromRequest<Body>` and `RequestParts<Body>` to get the previous behavior ([#564]) - **added:** `FromRequest` and `IntoResponse` are now defined in a new crate called `axum-core`. This crate is intended for library authors to depend on, rather than `axum` itself, if possible. `axum-core` has a smaller API and will thus receive fewer breaking changes.
`FromRequest` and `IntoResponse` are re-exported from `axum` in the same location so nothing is changed for `axum` users ([#564]) - **breaking:** The previously deprecated `axum::body::box_body` function has been removed. Use `axum::body::boxed` instead. - **fixed:** Adding the same route with different methods now works, i.e. `.route("/", get(_)).route("/", post(_))`. - **breaking:** `routing::handler_method_router` and `routing::service_method_router` have been removed in favor of `routing::{get, get_service, ..., MethodRouter}`. - **breaking:** `HandleErrorExt` has been removed in favor of `MethodRouter::handle_error`. - **breaking:** `HandleErrorLayer` now requires the handler function to be `async` ([#534]) - **added:** `HandleErrorLayer` now supports running extractors. - **breaking:** The `Handler` trait is now defined as `Handler<T, B = Body>`. That is, the type parameters have been swapped and `B` defaults to `axum::body::Body` ([#527]) - **breaking:** `Router::merge` will panic if both routers have fallbacks. Previously the left side fallback would be silently discarded ([#529]) - **breaking:** `Router::nest` will panic if the nested router has a fallback. Previously it would be silently discarded ([#529]) - Update WebSockets to use tokio-tungstenite 0.16 ([#525]) - **added:** Default to return `charset=utf-8` for text content type ([#554]) - **breaking:** The `Body` and `BodyError` associated types on the `IntoResponse` trait have been removed - instead, `.into_response()` will now always return `Response<BoxBody>` ([#571]) - **breaking:** `PathParamsRejection` has been renamed to `PathRejection` and its variants renamed to `FailedToDeserializePathParams` and `MissingPathParams`.
This makes it more consistent with the rest of axum ([#574]) - **added:** `Path`'s rejection type now provides data about exactly which part of the path couldn't be deserialized ([#574]) [#525]: https://github.com/tokio-rs/axum/pull/525 [#527]: https://github.com/tokio-rs/axum/pull/527 [#529]: https://github.com/tokio-rs/axum/pull/529 [#534]: https://github.com/tokio-rs/axum/pull/534 [#554]: https://github.com/tokio-rs/axum/pull/554 [#564]: https://github.com/tokio-rs/axum/pull/564 [#571]: https://github.com/tokio-rs/axum/pull/571 [#574]: https://github.com/tokio-rs/axum/pull/574 # 0.3.4 (13. November, 2021) - **changed:** `box_body` has been renamed to `boxed`. `box_body` still exists but is deprecated ([#530]) [#530]: https://github.com/tokio-rs/axum/pull/530 # 0.3.3 (13. November, 2021) - Implement `FromRequest` for [`http::request::Parts`] so it can be used as an extractor ([#489]) - Implement `IntoResponse` for `http::response::Parts` ([#490]) [#489]: https://github.com/tokio-rs/axum/pull/489 [#490]: https://github.com/tokio-rs/axum/pull/490 [`http::request::Parts`]: https://docs.rs/http/latest/http/request/struct.Parts.html # 0.3.2 (08. November, 2021) - **added:** Add `Router::route_layer` for applying middleware that will only run on requests that match a route. This is useful for middleware that return early, such as authorization ([#474]) [#474]: https://github.com/tokio-rs/axum/pull/474 # 0.3.1 (06. November, 2021) - **fixed:** Implement `Clone` for `IntoMakeServiceWithConnectInfo` ([#471]) [#471]: https://github.com/tokio-rs/axum/pull/471 # 0.3.0 (02.
November, 2021) - Overall: - **fixed:** All known compile time issues are resolved, including those with `boxed` and those introduced by Rust 1.56 ([#404]) - **breaking:** The router's type is now always `Router` regardless of how many routes or middleware are applied ([#404]) This means router types are all always nameable: ```rust fn my_routes() -> Router { Router::new().route( "/users", post(|| async { "Hello, World!" }), ) } ``` - **breaking:** Added feature flags for HTTP1 and JSON. This enables removing a few dependencies if your app only uses HTTP2 or doesn't use JSON. This is only a breaking change if you depend on axum with `default_features = false`. ([#286]) - **breaking:** `Route::boxed` and `BoxRoute` have been removed as they're no longer necessary ([#404]) - **breaking:** `Nested`, `Or` types are now private. They no longer had to be public because `Router` is internally boxed ([#404]) - **breaking:** Remove `routing::Layered` as it didn't actually do anything and thus wasn't necessary - **breaking:** Vendor `AddExtensionLayer` and `AddExtension` to reduce public dependencies - **breaking:** `body::BoxBody` is now a type alias for `http_body::combinators::UnsyncBoxBody` and thus is no longer `Sync`. This is because bodies are streams and requiring streams to be `Sync` is unnecessary. - **added:** Implement `IntoResponse` for `http_body::combinators::UnsyncBoxBody`. - **added:** Add `Handler::into_make_service` for serving a handler without a `Router`. - **added:** Add `Handler::into_make_service_with_connect_info` for serving a handler without a `Router`, and storing info about the incoming connection. - **breaking:** axum's minimum supported rust version is now 1.56 - Routing: - Big internal refactoring of routing leading to several improvements ([#363]) - **added:** Wildcard routes like `.route("/api/users/*rest", service)` are now supported. - **fixed:** The order routes are added in no longer matters. 
- **fixed:** Adding a conflicting route will now cause a panic instead of silently making a route unreachable. - **fixed:** Route matching is faster as the number of routes increases. - **breaking:** Handlers for multiple HTTP methods must be added in the same `Router::route` call. So `.route("/", get(get_handler).post(post_handler))` and _not_ `.route("/", get(get_handler)).route("/", post(post_handler))`. - **fixed:** Correctly handle trailing slashes in routes: - If a route with a trailing slash exists and a request without a trailing slash is received, axum will send a 301 redirection to the route with the trailing slash. - Or vice versa if a route without a trailing slash exists and a request with a trailing slash is received. - This can be overridden by explicitly defining two routes: One with and one without a trailing slash. - **breaking:** Method routing for handlers has been moved from `axum::handler` to `axum::routing`. So `axum::handler::get` now lives at `axum::routing::get` ([#405]) - **breaking:** Method routing for services has been moved from `axum::service` to `axum::routing::service_method_routing`. So `axum::service::get` now lives at `axum::routing::service_method_routing::get`, etc. ([#405]) - **breaking:** `Router::or` has been renamed to `Router::merge` and will now panic on overlapping routes. It now only accepts `Router`s and not general `Service`s. Use `Router::fallback` for adding fallback routes ([#408]) - **added:** `Router::fallback` for adding handlers for requests that didn't match any routes. `Router::fallback` must be used instead of `nest("/", _)` ([#408]) - **breaking:** `EmptyRouter` has been renamed to `MethodNotAllowed` as it's only used in method routers and not in path routers (`Router`) - **breaking:** Remove support for routing based on the `CONNECT` method.
An example of combining axum with and HTTP proxy can be found [here][proxy] ([#428]) - Extractors: - **fixed:** Expand accepted content types for JSON requests ([#378]) - **fixed:** Support deserializing `i128` and `u128` in `extract::Path` - **breaking:** Automatically do percent decoding in `extract::Path` ([#272]) - **breaking:** Change `Connected::connect_info` to return `Self` and remove the associated type `ConnectInfo` ([#396]) - **added:** Add `extract::MatchedPath` for accessing path in router that matched the request ([#412]) - Error handling: - **breaking:** Simplify error handling model ([#402]): - All services part of the router are now required to be infallible. - Error handling utilities have been moved to an `error_handling` module. - `Router::check_infallible` has been removed since routers are always infallible with the error handling changes. - Error handling closures must now handle all errors and thus always return something that implements `IntoResponse`. With these changes handling errors from fallible middleware is done like so: ```rust,no_run use axum::{ routing::get, http::StatusCode, error_handling::HandleErrorLayer, response::IntoResponse, Router, BoxError, }; use tower::ServiceBuilder; use std::time::Duration; let middleware_stack = ServiceBuilder::new() // Handle errors from middleware // // This middleware most be added above any fallible // ones if you're using `ServiceBuilder`, due to how ordering works .layer(HandleErrorLayer::new(handle_error)) // Return an error after 30 seconds .timeout(Duration::from_secs(30)); let app = Router::new() .route("/", get(|| async { /* ... 
*/ })) .layer(middleware_stack); fn handle_error(_error: BoxError) -> impl IntoResponse { StatusCode::REQUEST_TIMEOUT } ``` And handling errors from fallible leaf services is done like so: ```rust use axum::{ Router, service, body::Body, routing::service_method_routing::get, response::IntoResponse, http::{Request, Response}, error_handling::HandleErrorExt, // for `.handle_error` }; use std::{io, convert::Infallible}; use tower::service_fn; let app = Router::new() .route( "/", get(service_fn(|_req: Request| async { let contents = tokio::fs::read_to_string("some_file").await?; Ok::<_, io::Error>(Response::new(Body::from(contents))) })) .handle_error(handle_io_error), ); fn handle_io_error(error: io::Error) -> impl IntoResponse { // ... } ``` - Misc: - `InvalidWebsocketVersionHeader` has been renamed to `InvalidWebSocketVersionHeader` ([#416]) - `WebsocketKeyHeaderMissing` has been renamed to `WebSocketKeyHeaderMissing` ([#416]) [#339]: https://github.com/tokio-rs/axum/pull/339 [#286]: https://github.com/tokio-rs/axum/pull/286 [#272]: https://github.com/tokio-rs/axum/pull/272 [#378]: https://github.com/tokio-rs/axum/pull/378 [#363]: https://github.com/tokio-rs/axum/pull/363 [#396]: https://github.com/tokio-rs/axum/pull/396 [#402]: https://github.com/tokio-rs/axum/pull/402 [#404]: https://github.com/tokio-rs/axum/pull/404 [#405]: https://github.com/tokio-rs/axum/pull/405 [#408]: https://github.com/tokio-rs/axum/pull/408 [#412]: https://github.com/tokio-rs/axum/pull/412 [#416]: https://github.com/tokio-rs/axum/pull/416 [#428]: https://github.com/tokio-rs/axum/pull/428 [proxy]: https://github.com/tokio-rs/axum/blob/main/examples/http-proxy/src/main.rs # 0.2.8 (07. October, 2021) - Document debugging handler type errors with "axum-debug" ([#372]) [#372]: https://github.com/tokio-rs/axum/pull/372 # 0.2.7 (06. October, 2021) - Bump minimum version of async-trait ([#370]) [#370]: https://github.com/tokio-rs/axum/pull/370 # 0.2.6 (02. 
October, 2021) - Clarify that `handler::any` and `service::any` only accepts standard HTTP methods ([#337]) - Document how to customize error responses from extractors ([#359]) [#337]: https://github.com/tokio-rs/axum/pull/337 [#359]: https://github.com/tokio-rs/axum/pull/359 # 0.2.5 (18. September, 2021) - Add accessors for `TypedHeaderRejection` fields ([#317]) - Improve docs for extractors ([#327]) [#317]: https://github.com/tokio-rs/axum/pull/317 [#327]: https://github.com/tokio-rs/axum/pull/327 # 0.2.4 (10. September, 2021) - Document using `StreamExt::split` with `WebSocket` ([#291]) - Document adding middleware to multiple groups of routes ([#293]) [#291]: https://github.com/tokio-rs/axum/pull/291 [#293]: https://github.com/tokio-rs/axum/pull/293 # 0.2.3 (26. August, 2021) - **fixed:** Fix accidental breaking change introduced by internal refactor. `BoxRoute` used to be `Sync` but was accidental made `!Sync` ([#273](https://github.com/tokio-rs/axum/pull/273)) # 0.2.2 (26. August, 2021) - **fixed:** Fix URI captures matching empty segments. This means requests with URI `/` will no longer be matched by `/:key` ([#264](https://github.com/tokio-rs/axum/pull/264)) - **fixed:** Remove needless trait bounds from `Router::boxed` ([#269](https://github.com/tokio-rs/axum/pull/269)) # 0.2.1 (24. August, 2021) - **added:** Add `Redirect::to` constructor ([#255](https://github.com/tokio-rs/axum/pull/255)) - **added:** Document how to implement `IntoResponse` for custom error type ([#258](https://github.com/tokio-rs/axum/pull/258)) # 0.2.0 (23. August, 2021) - Overall: - **fixed:** Overall compile time improvements. If you're having issues with compile time please file an issue! ([#184](https://github.com/tokio-rs/axum/pull/184)) ([#198](https://github.com/tokio-rs/axum/pull/198)) ([#220](https://github.com/tokio-rs/axum/pull/220)) - **changed:** Remove `prelude`. 
Explicit imports are now required ([#195](https://github.com/tokio-rs/axum/pull/195)) - Routing: - **added:** Add dedicated `Router` to replace the `RoutingDsl` trait ([#214](https://github.com/tokio-rs/axum/pull/214)) - **added:** Add `Router::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108)) - **fixed:** Support matching different HTTP methods for the same route that aren't defined together. So `Router::new().route("/", get(...)).route("/", post(...))` now accepts both `GET` and `POST`. Previously only `POST` would be accepted ([#224](https://github.com/tokio-rs/axum/pull/224)) - **fixed:** `get` routes will now also be called for `HEAD` requests but will always have the response body removed ([#129](https://github.com/tokio-rs/axum/pull/129)) - **changed:** Replace `axum::route(...)` with `axum::Router::new().route(...)`. This means there is now only one way to create a new router. Same goes for `axum::routing::nest`. ([#215](https://github.com/tokio-rs/axum/pull/215)) - **changed:** Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags) ([#158](https://github.com/tokio-rs/axum/pull/158)) - **changed:** Move `handle_error` from `ServiceExt` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160)) With these changes this app using 0.1: ```rust use axum::{extract::Extension, prelude::*, routing::BoxRoute, AddExtensionLayer}; let app = route("/", get(|| async { "hi" })) .nest("/api", api_routes()) .layer(AddExtensionLayer::new(state)); fn api_routes() -> BoxRoute { route( "/users", post(|Extension(state): Extension| async { "hi from nested" }), ) .boxed() } ``` Becomes this in 0.2: ```rust use axum::{ extract::Extension, handler::{get, post}, routing::BoxRoute, Router, }; let app = Router::new() .route("/", get(|| async { "hi" })) .nest("/api", api_routes()); fn api_routes() -> Router { Router::new() .route( "/users", post(|Extension(state): Extension| async { "hi from nested" }), ) 
.boxed() } ``` - Extractors: - **added:** Make `FromRequest` default to being generic over `body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146)) - **added:** Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153)) - **added:** Add `OriginalUri` for extracting original request URI in nested services ([#197](https://github.com/tokio-rs/axum/pull/197)) - **added:** Implement `FromRequest` for `http::Extensions` ([#169](https://github.com/tokio-rs/axum/pull/169)) - **added:** Make `RequestParts::{new, try_into_request}` public so extractors can be used outside axum ([#194](https://github.com/tokio-rs/axum/pull/194)) - **added:** Implement `FromRequest` for `axum::body::Body` ([#241](https://github.com/tokio-rs/axum/pull/241)) - **changed:** Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead ([#154](https://github.com/tokio-rs/axum/pull/154)) - **changed:** `extractor_middleware` now requires `RequestBody: Default` ([#167](https://github.com/tokio-rs/axum/pull/167)) - **changed:** Convert `RequestAlreadyExtracted` to an enum with each possible error variant ([#167](https://github.com/tokio-rs/axum/pull/167)) - **changed:** `extract::BodyStream` is no longer generic over the request body ([#234](https://github.com/tokio-rs/axum/pull/234)) - **changed:** `extract::Body` has been renamed to `extract::RawBody` to avoid conflicting with `body::Body` ([#233](https://github.com/tokio-rs/axum/pull/233)) - **changed:** `RequestParts` changes ([#153](https://github.com/tokio-rs/axum/pull/153)) - `method` new returns an `&http::Method` - `method_mut` new returns an `&mut http::Method` - `take_method` has been removed - `uri` new returns an `&http::Uri` - `uri_mut` new returns an `&mut http::Uri` - `take_uri` has been removed - **changed:** Remove several rejection types that were no longer used ([#153](https://github.com/tokio-rs/axum/pull/153)) 
([#154](https://github.com/tokio-rs/axum/pull/154)) - Responses: - **added:** Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193)) - **added:** Add `Redirect` response ([#192](https://github.com/tokio-rs/axum/pull/192)) - **added:** Add `body::StreamBody` for easily responding with a stream of byte chunks ([#237](https://github.com/tokio-rs/axum/pull/237)) - **changed:** Add associated `Body` and `BodyError` types to `IntoResponse`. This is required for returning responses with bodies other than `hyper::Body` from handlers. See the docs for advice on how to implement `IntoResponse` ([#86](https://github.com/tokio-rs/axum/pull/86)) - **changed:** `tower::util::Either` no longer implements `IntoResponse` ([#229](https://github.com/tokio-rs/axum/pull/229)) This `IntoResponse` from 0.1: ```rust use axum::{http::Response, prelude::*, response::IntoResponse}; struct MyResponse; impl IntoResponse for MyResponse { fn into_response(self) -> Response { Response::new(Body::empty()) } } ``` Becomes this in 0.2: ```rust use axum::{body::Body, http::Response, response::IntoResponse}; struct MyResponse; impl IntoResponse for MyResponse { type Body = Body; type BodyError = ::Error; fn into_response(self) -> Response { Response::new(Body::empty()) } } ``` - SSE: - **added:** Add `response::sse::Sse`. This implements SSE using a response rather than a service ([#98](https://github.com/tokio-rs/axum/pull/98)) - **changed:** Remove `axum::sse`. 
It has been replaced by `axum::response::sse` ([#98](https://github.com/tokio-rs/axum/pull/98)) Handler using SSE in 0.1: ```rust use axum::{ prelude::*, sse::{sse, Event}, }; use std::convert::Infallible; let app = route( "/", sse(|| async { let stream = futures::stream::iter(vec![Ok::<_, Infallible>( Event::default().data("hi there!"), )]); Ok::<_, Infallible>(stream) }), ); ``` Becomes this in 0.2: ```rust use axum::{ handler::get, response::sse::{Event, Sse}, Router, }; use std::convert::Infallible; let app = Router::new().route( "/", get(|| async { let stream = futures::stream::iter(vec![Ok::<_, Infallible>( Event::default().data("hi there!"), )]); Sse::new(stream) }), ); ``` - WebSockets: - **changed:** Change WebSocket API to use an extractor plus a response ([#121](https://github.com/tokio-rs/axum/pull/121)) - **changed:** Make WebSocket `Message` an enum ([#116](https://github.com/tokio-rs/axum/pull/116)) - **changed:** `WebSocket` now uses `Error` as its error type ([#150](https://github.com/tokio-rs/axum/pull/150)) Handler using WebSockets in 0.1: ```rust use axum::{ prelude::*, ws::{ws, WebSocket}, }; let app = route( "/", ws(|socket: WebSocket| async move { // do stuff with socket }), ); ``` Becomes this in 0.2: ```rust use axum::{ extract::ws::{WebSocket, WebSocketUpgrade}, handler::get, Router, }; let app = Router::new().route( "/", get(|ws: WebSocketUpgrade| async move { ws.on_upgrade(|socket: WebSocket| async move { // do stuff with socket }) }), ); ``` - Misc - **added:** Add default feature `tower-log` which exposes `tower`'s `log` feature. 
([#218](https://github.com/tokio-rs/axum/pull/218)) - **changed:** Replace `body::BoxStdError` with `axum::Error`, which supports downcasting ([#150](https://github.com/tokio-rs/axum/pull/150)) - **changed:** `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108)) - **changed:** `Router::check_infallible` now returns a `CheckInfallible` service. This is to improve compile times ([#198](https://github.com/tokio-rs/axum/pull/198)) - **changed:** `Router::into_make_service` now returns `routing::IntoMakeService` rather than `tower::make::Shared` ([#229](https://github.com/tokio-rs/axum/pull/229)) - **changed:** All usage of `tower::BoxError` has been replaced with `axum::BoxError` ([#229](https://github.com/tokio-rs/axum/pull/229)) - **changed:** Several response future types have been moved into dedicated `future` modules ([#133](https://github.com/tokio-rs/axum/pull/133)) - **changed:** `EmptyRouter`, `ExtractorMiddleware`, `ExtractorMiddlewareLayer`, and `QueryStringMissing` no longer implement `Copy` ([#132](https://github.com/tokio-rs/axum/pull/132)) - **changed:** `service::OnMethod`, `handler::OnMethod`, and `routing::Nested` have new response future types ([#157](https://github.com/tokio-rs/axum/pull/157)) # 0.1.3 (06. August, 2021) - Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91)) - Add support for WebSocket protocol negotiation ([#83](https://github.com/tokio-rs/axum/pull/83)) - Use `pin-project-lite` instead of `pin-project` ([#95](https://github.com/tokio-rs/axum/pull/95)) - Re-export `http` crate and `hyper::Server` ([#110](https://github.com/tokio-rs/axum/pull/110)) - Fix `Query` and `Form` extractors giving bad request error when query string is empty. ([#117](https://github.com/tokio-rs/axum/pull/117)) - Add `Path` extractor. 
([#124](https://github.com/tokio-rs/axum/pull/124)) - Fixed the implementation of `IntoResponse` of `(HeaderMap, T)` and `(StatusCode, HeaderMap, T)` would ignore headers from `T` ([#137](https://github.com/tokio-rs/axum/pull/137)) - Deprecate `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead ([#138](https://github.com/tokio-rs/axum/pull/138)) # 0.1.2 (01. August, 2021) - Implement `Stream` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52)) - Implement `Sink` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52)) - Implement `Deref` most extractors ([#56](https://github.com/tokio-rs/axum/pull/56)) - Return `405 Method Not Allowed` for unsupported method for route ([#63](https://github.com/tokio-rs/axum/pull/63)) - Add extractor for remote connection info ([#55](https://github.com/tokio-rs/axum/pull/55)) - Improve error message of `MissingExtension` rejections ([#72](https://github.com/tokio-rs/axum/pull/72)) - Improve documentation for routing ([#71](https://github.com/tokio-rs/axum/pull/71)) - Clarify required response body type when routing to `tower::Service`s ([#69](https://github.com/tokio-rs/axum/pull/69)) - Add `axum::body::box_body` to converting an `http_body::Body` to `axum::body::BoxBody` ([#69](https://github.com/tokio-rs/axum/pull/69)) - Add `axum::sse` for Server-Sent Events ([#75](https://github.com/tokio-rs/axum/pull/75)) - Mention required dependencies in docs ([#77](https://github.com/tokio-rs/axum/pull/77)) - Fix WebSockets failing on Firefox ([#76](https://github.com/tokio-rs/axum/pull/76)) # 0.1.1 (30. July, 2021) - Misc readme fixes. # 0.1.0 (30. July, 2021) - Initial release. 
axum-axum-v0.6.20/axum/Cargo.toml000066400000000000000000000123071446276533600166340ustar00rootroot00000000000000[package] name = "axum" version = "0.6.20" categories = ["asynchronous", "network-programming", "web-programming::http-server"] description = "Web framework that focuses on ergonomics and modularity" edition = "2021" rust-version = "1.63" homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" readme = "README.md" repository = "https://github.com/tokio-rs/axum" [features] default = ["form", "http1", "json", "matched-path", "original-uri", "query", "tokio", "tower-log"] form = ["dep:serde_urlencoded"] http1 = ["hyper/http1"] http2 = ["hyper/http2"] json = ["dep:serde_json", "dep:serde_path_to_error"] macros = ["dep:axum-macros"] matched-path = [] multipart = ["dep:multer"] original-uri = [] query = ["dep:serde_urlencoded"] tokio = ["dep:tokio", "hyper/server", "hyper/tcp", "hyper/runtime", "tower/make"] tower-log = ["tower/log"] tracing = ["dep:tracing", "axum-core/tracing"] ws = ["tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"] # Required for intra-doc links to resolve correctly __private_docs = ["tower/full", "dep:tower-http"] [dependencies] async-trait = "0.1.67" axum-core = { path = "../axum-core", version = "0.3.4" } bitflags = "1.0" bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2.9" http-body = "0.4.4" hyper = { version = "0.14.24", features = ["stream"] } itoa = "1.0.5" matchit = "0.7" memchr = "2.4.1" mime = "0.3.16" percent-encoding = "2.1" pin-project-lite = "0.2.7" serde = "1.0" sync_wrapper = "0.1.1" tower = { version = "0.4.13", default-features = false, features = ["util"] } tower-layer = "0.3.2" tower-service = "0.3" # optional dependencies axum-macros = { path = "../axum-macros", version = "0.3.8", optional = true } base64 = { version = "0.21.0", optional = true } headers = { version = "0.3.7", optional = true } multer = { 
version = "2.0.0", optional = true } serde_json = { version = "1.0", features = ["raw_value"], optional = true } serde_path_to_error = { version = "0.1.8", optional = true } serde_urlencoded = { version = "0.7", optional = true } sha1 = { version = "0.10", optional = true } tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true } tokio-tungstenite = { version = "0.20", optional = true } tracing = { version = "0.1", default-features = false, optional = true } [dependencies.tower-http] version = "0.4" optional = true features = [ # all tower-http features except (de)?compression-zstd which doesn't # build on `--target armv5te-unknown-linux-musleabi` "add-extension", "auth", "catch-panic", "compression-br", "compression-deflate", "compression-gzip", "cors", "decompression-br", "decompression-deflate", "decompression-gzip", "follow-redirect", "fs", "limit", "map-request-body", "map-response-body", "metrics", "normalize-path", "propagate-header", "redirect", "request-id", "sensitive-headers", "set-header", "set-status", "timeout", "trace", "util", "validate-request", ] [build-dependencies] rustversion = "1.0.9" [dev-dependencies] anyhow = "1.0" axum-macros = { path = "../axum-macros", version = "0.3.8", features = ["__private"] } quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.11.14", default-features = false, features = ["json", "stream", "multipart"] } rustversion = "1.0.9" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json"] } uuid = { version = "1.0", features = ["serde", "v4"] } [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies.tower] package = "tower" version = "0.4.10" 
features = [ "util", "timeout", "limit", "load-shed", "steer", "filter", ] [dev-dependencies.tower-http] version = "0.4" features = [ # all tower-http features except (de)?compression-zstd which doesn't # build on `--target armv5te-unknown-linux-musleabi` "add-extension", "auth", "catch-panic", "compression-br", "compression-deflate", "compression-gzip", "cors", "decompression-br", "decompression-deflate", "decompression-gzip", "follow-redirect", "fs", "limit", "map-request-body", "map-response-body", "metrics", "normalize-path", "propagate-header", "redirect", "request-id", "sensitive-headers", "set-header", "set-status", "timeout", "trace", "util", "validate-request", ] [package.metadata.playground] features = [ "http1", "http2", "json", "multipart", "ws", ] [package.metadata.cargo-public-api-crates] allowed = [ "async_trait", "axum_core", "axum_macros", "bytes", "futures_core", "futures_sink", "futures_util", "headers", "headers_core", "http", "http_body", "hyper", "serde", "serde_json", "tower_layer", "tower_service", ] [[bench]] name = "benches" harness = false axum-axum-v0.6.20/axum/LICENSE000066400000000000000000000020451446276533600157070ustar00rootroot00000000000000Copyright (c) 2019 Axum Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. axum-axum-v0.6.20/axum/README.md000066400000000000000000000125611446276533600161650ustar00rootroot00000000000000# axum `axum` is a web application framework that focuses on ergonomics and modularity. [![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum/actions/workflows/CI.yml) [![Crates.io](https://img.shields.io/crates/v/axum)](https://crates.io/crates/axum) [![Documentation](https://docs.rs/axum/badge.svg)](https://docs.rs/axum) More information about this crate can be found in the [crate documentation][docs]. ## High level features - Route requests to handlers with a macro free API. - Declaratively parse requests using extractors. - Simple and predictable error handling model. - Generate responses with minimal boilerplate. - Take full advantage of the [`tower`] and [`tower-http`] ecosystem of middleware, services, and utilities. In particular the last point is what sets `axum` apart from other frameworks. `axum` doesn't have its own middleware system but instead uses [`tower::Service`]. This means `axum` gets timeouts, tracing, compression, authorization, and more, for free. It also enables you to share middleware with applications written using [`hyper`] or [`tonic`]. 
## Usage example ```rust use axum::{ routing::{get, post}, http::StatusCode, response::IntoResponse, Json, Router, }; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; #[tokio::main] async fn main() { // initialize tracing tracing_subscriber::fmt::init(); // build our application with a route let app = Router::new() // `GET /` goes to `root` .route("/", get(root)) // `POST /users` goes to `create_user` .route("/users", post(create_user)); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // basic handler that responds with a static string async fn root() -> &'static str { "Hello, World!" } async fn create_user( // this argument tells axum to parse the request body // as JSON into a `CreateUser` type Json(payload): Json, ) -> (StatusCode, Json) { // insert your application logic here let user = User { id: 1337, username: payload.username, }; // this will be converted into a JSON response // with a status code of `201 Created` (StatusCode::CREATED, Json(user)) } // the input to our `create_user` handler #[derive(Deserialize)] struct CreateUser { username: String, } // the output to our `create_user` handler #[derive(Serialize)] struct User { id: u64, username: String, } ``` You can find this [example][readme-example] as well as other example projects in the [example directory][examples]. See the [crate documentation][docs] for way more examples. ## Performance `axum` is a relatively thin layer on top of [`hyper`] and adds very little overhead. So `axum`'s performance is comparable to [`hyper`]. You can find benchmarks [here](https://github.com/programatik29/rust-web-benchmarks) and [here](https://web-frameworks-benchmark.netlify.app/result?l=rust). ## Safety This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. 
## Minimum supported Rust version axum's MSRV is 1.63. ## Examples The [examples] folder contains various examples of how to use `axum`. The [docs] also provide lots of code snippets and examples. For full-fledged examples, check out community-maintained [showcases] or [tutorials]. ## Getting Help In the `axum`'s repo we also have a [number of examples][examples] showing how to put everything together. Community-maintained [showcases] and [tutorials] also demonstrate how to use `axum` for real-world applications. You're also welcome to ask in the [Discord channel][chat] or open a [discussion] with your question. ## Community projects See [here][ecosystem] for a list of community maintained crates and projects built with `axum`. ## Contributing ๐ŸŽˆ Thanks for your help improving the project! We are so happy to have you! We have a [contributing guide][contributing] to help you get involved in the `axum` project. ## License This project is licensed under the [MIT license][license]. ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in `axum` by you, shall be licensed as MIT, without any additional terms or conditions. 
[readme-example]: https://github.com/tokio-rs/axum/tree/main/examples/readme [examples]: https://github.com/tokio-rs/axum/tree/main/examples [docs]: https://docs.rs/axum [`tower`]: https://crates.io/crates/tower [`hyper`]: https://crates.io/crates/hyper [`tower-http`]: https://crates.io/crates/tower-http [`tonic`]: https://crates.io/crates/tonic [contributing]: https://github.com/tokio-rs/axum/blob/main/CONTRIBUTING.md [chat]: https://discord.gg/tokio [discussion]: https://github.com/tokio-rs/axum/discussions/new?category=q-a [`tower::Service`]: https://docs.rs/tower/latest/tower/trait.Service.html [ecosystem]: https://github.com/tokio-rs/axum/blob/main/ECOSYSTEM.md [showcases]: https://github.com/tokio-rs/axum/blob/main/ECOSYSTEM.md#project-showcase [tutorials]: https://github.com/tokio-rs/axum/blob/main/ECOSYSTEM.md#tutorials [license]: https://github.com/tokio-rs/axum/blob/main/axum/LICENSE axum-axum-v0.6.20/axum/benches/000077500000000000000000000000001446276533600163105ustar00rootroot00000000000000axum-axum-v0.6.20/axum/benches/benches.rs000066400000000000000000000152121446276533600202660ustar00rootroot00000000000000use axum::{ extract::State, routing::{get, post}, Extension, Json, Router, Server, }; use hyper::server::conn::AddrIncoming; use serde::{Deserialize, Serialize}; use std::{ io::BufRead, process::{Command, Stdio}, }; fn main() { if on_ci() { install_rewrk(); } else { ensure_rewrk_is_installed(); } benchmark("minimal").run(Router::new); benchmark("basic") .path("/a/b/c") .run(|| Router::new().route("/a/b/c", get(|| async { "Hello, World!" }))); benchmark("basic-merge").path("/a/b/c").run(|| { let inner = Router::new().route("/a/b/c", get(|| async { "Hello, World!" })); Router::new().merge(inner) }); benchmark("basic-nest").path("/a/b/c").run(|| { let c = Router::new().route("/c", get(|| async { "Hello, World!" 
})); let b = Router::new().nest("/b", c); Router::new().nest("/a", b) }); benchmark("routing").path("/foo/bar/baz").run(|| { let mut app = Router::new(); for a in 0..10 { for b in 0..10 { for c in 0..10 { app = app.route(&format!("/foo-{a}/bar-{b}/baz-{c}"), get(|| async {})); } } } app.route("/foo/bar/baz", get(|| async {})) }); benchmark("receive-json") .method("post") .headers(&[("content-type", "application/json")]) .body(r#"{"n": 123, "s": "hi there", "b": false}"#) .run(|| Router::new().route("/", post(|_: Json| async {}))); benchmark("send-json").run(|| { Router::new().route( "/", get(|| async { Json(Payload { n: 123, s: "hi there".to_owned(), b: false, }) }), ) }); let state = AppState { _string: "aaaaaaaaaaaaaaaaaa".to_owned(), _vec: Vec::from([ "aaaaaaaaaaaaaaaaaa".to_owned(), "bbbbbbbbbbbbbbbbbb".to_owned(), "cccccccccccccccccc".to_owned(), ]), }; benchmark("extension").run(|| { Router::new() .route("/", get(|_: Extension| async {})) .layer(Extension(state.clone())) }); benchmark("state").run(|| { Router::new() .route("/", get(|_: State| async {})) .with_state(state.clone()) }); } #[derive(Clone)] struct AppState { _string: String, _vec: Vec, } #[derive(Deserialize, Serialize)] struct Payload { n: u32, s: String, b: bool, } fn benchmark(name: &'static str) -> BenchmarkBuilder { BenchmarkBuilder { name, path: None, method: None, headers: None, body: None, } } struct BenchmarkBuilder { name: &'static str, path: Option<&'static str>, method: Option<&'static str>, headers: Option<&'static [(&'static str, &'static str)]>, body: Option<&'static str>, } macro_rules! 
config_method { ($name:ident, $ty:ty) => { fn $name(mut self, $name: $ty) -> Self { self.$name = Some($name); self } }; } impl BenchmarkBuilder { config_method!(path, &'static str); config_method!(method, &'static str); config_method!(headers, &'static [(&'static str, &'static str)]); config_method!(body, &'static str); fn run(self, f: F) where F: FnOnce() -> Router<()>, { // support only running some benchmarks with // ``` // cargo bench -- routing send-json // ``` let args = std::env::args().collect::>(); if args.len() != 1 { let names = &args[1..args.len() - 1]; if !names.is_empty() && !names.contains(&self.name.to_owned()) { return; } } let app = f(); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap(); let listener = rt .block_on(tokio::net::TcpListener::bind("0.0.0.0:0")) .unwrap(); let addr = listener.local_addr().unwrap(); std::thread::spawn(move || { rt.block_on(async move { let incoming = AddrIncoming::from_listener(listener).unwrap(); Server::builder(incoming) .serve(app.into_make_service()) .await .unwrap(); }); }); let mut cmd = Command::new("rewrk"); cmd.stdout(Stdio::piped()); cmd.arg("--host"); cmd.arg(format!("http://{}{}", addr, self.path.unwrap_or(""))); cmd.args(["--connections", "10"]); cmd.args(["--threads", "10"]); if on_ci() { // don't slow down CI by running the benchmarks for too long // but do run them for a bit cmd.args(["--duration", "1s"]); } else { cmd.args(["--duration", "10s"]); } if let Some(method) = self.method { cmd.args(["--method", method]); } for (key, value) in self.headers.into_iter().flatten() { cmd.arg("--header"); cmd.arg(format!("{key}: {value}")); } if let Some(body) = self.body { cmd.args(["--body", body]); } eprintln!("Running {:?} benchmark", self.name); // indent output from `rewrk` so its easier to read when running multiple benchmarks let mut child = cmd.spawn().unwrap(); let stdout = child.stdout.take().unwrap(); let stdout = std::io::BufReader::new(stdout); for line in 
stdout.lines() { let line = line.unwrap(); println!(" {line}"); } let status = child.wait().unwrap(); if !status.success() { eprintln!("`rewrk` command failed"); std::process::exit(status.code().unwrap()); } } } fn install_rewrk() { println!("installing rewrk"); let mut cmd = Command::new("cargo"); cmd.args([ "install", "rewrk", "--git", "https://github.com/ChillFish8/rewrk.git", ]); let status = cmd .status() .unwrap_or_else(|_| panic!("failed to install rewrk")); if !status.success() { panic!("failed to install rewrk"); } } fn ensure_rewrk_is_installed() { let mut cmd = Command::new("rewrk"); cmd.arg("--help"); cmd.stdout(Stdio::null()); cmd.stderr(Stdio::null()); cmd.status().unwrap_or_else(|_| { panic!("rewrk is not installed. See https://github.com/lnx-search/rewrk") }); } fn on_ci() -> bool { std::env::var("GITHUB_ACTIONS").is_ok() } axum-axum-v0.6.20/axum/build.rs000066400000000000000000000002111446276533600163400ustar00rootroot00000000000000#[rustversion::nightly] fn main() { println!("cargo:rustc-cfg=nightly_error_messages"); } #[rustversion::not(nightly)] fn main() {} axum-axum-v0.6.20/axum/src/000077500000000000000000000000001446276533600154705ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/body/000077500000000000000000000000001446276533600164255ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/body/mod.rs000066400000000000000000000004541446276533600175550ustar00rootroot00000000000000//! HTTP body utilities. 
mod stream_body; pub use self::stream_body::StreamBody; #[doc(no_inline)] pub use http_body::{Body as HttpBody, Empty, Full}; #[doc(no_inline)] pub use hyper::body::Body; #[doc(no_inline)] pub use bytes::Bytes; #[doc(inline)] pub use axum_core::body::{boxed, BoxBody}; axum-axum-v0.6.20/axum/src/body/stream_body.rs000066400000000000000000000072471446276533600213150ustar00rootroot00000000000000use crate::{ body::{self, Bytes, HttpBody}, response::{IntoResponse, Response}, BoxError, Error, }; use futures_util::{ ready, stream::{self, TryStream}, }; use http::HeaderMap; use pin_project_lite::pin_project; use std::{ fmt, pin::Pin, task::{Context, Poll}, }; use sync_wrapper::SyncWrapper; pin_project! { /// An [`http_body::Body`] created from a [`Stream`]. /// /// The purpose of this type is to be used in responses. If you want to /// extract the request body as a stream consider using /// [`BodyStream`](crate::extract::BodyStream). /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// body::StreamBody, /// response::IntoResponse, /// }; /// use futures_util::stream::{self, Stream}; /// use std::io; /// /// async fn handler() -> StreamBody>> { /// let chunks: Vec> = vec![ /// Ok("Hello,"), /// Ok(" "), /// Ok("world!"), /// ]; /// let stream = stream::iter(chunks); /// StreamBody::new(stream) /// } /// /// let app = Router::new().route("/", get(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// [`Stream`]: futures_util::stream::Stream #[must_use] pub struct StreamBody { #[pin] stream: SyncWrapper, } } impl From for StreamBody where S: TryStream + Send + 'static, S::Ok: Into, S::Error: Into, { fn from(stream: S) -> Self { Self::new(stream) } } impl StreamBody { /// Create a new `StreamBody` from a [`Stream`]. 
/// /// [`Stream`]: futures_util::stream::Stream pub fn new(stream: S) -> Self where S: TryStream + Send + 'static, S::Ok: Into, S::Error: Into, { Self { stream: SyncWrapper::new(stream), } } } impl IntoResponse for StreamBody where S: TryStream + Send + 'static, S::Ok: Into, S::Error: Into, { fn into_response(self) -> Response { Response::new(body::boxed(self)) } } impl Default for StreamBody>> { fn default() -> Self { Self::new(stream::empty()) } } impl fmt::Debug for StreamBody { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("StreamBody").finish() } } impl HttpBody for StreamBody where S: TryStream, S::Ok: Into, S::Error: Into, { type Data = Bytes; type Error = Error; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let stream = self.project().stream.get_pin_mut(); match ready!(stream.try_poll_next(cx)) { Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))), Some(Err(err)) => Poll::Ready(Some(Err(Error::new(err)))), None => Poll::Ready(None), } } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Poll::Ready(Ok(None)) } } #[test] fn stream_body_traits() { use futures_util::stream::Empty; type EmptyStream = StreamBody>>; crate::test_helpers::assert_send::(); crate::test_helpers::assert_sync::(); crate::test_helpers::assert_unpin::(); } axum-axum-v0.6.20/axum/src/boxed.rs000066400000000000000000000113551446276533600171440ustar00rootroot00000000000000use std::{convert::Infallible, fmt}; use http::Request; use tower::Service; use crate::{ body::HttpBody, handler::Handler, routing::{future::RouteFuture, Route}, Router, }; pub(crate) struct BoxedIntoRoute(Box>); impl BoxedIntoRoute where S: Clone + Send + Sync + 'static, B: Send + 'static, { pub(crate) fn from_handler(handler: H) -> Self where H: Handler, T: 'static, B: HttpBody, { Self(Box::new(MakeErasedHandler { handler, into_route: |handler, state| Route::new(Handler::with_state(handler, state)), })) } } impl BoxedIntoRoute 
{ pub(crate) fn map(self, f: F) -> BoxedIntoRoute where S: 'static, B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static, { BoxedIntoRoute(Box::new(Map { inner: self.0, layer: Box::new(f), })) } pub(crate) fn into_route(self, state: S) -> Route { self.0.into_route(state) } } impl Clone for BoxedIntoRoute { fn clone(&self) -> Self { Self(self.0.clone_box()) } } impl fmt::Debug for BoxedIntoRoute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("BoxedIntoRoute").finish() } } pub(crate) trait ErasedIntoRoute: Send { fn clone_box(&self) -> Box>; fn into_route(self: Box, state: S) -> Route; fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture; } pub(crate) struct MakeErasedHandler { pub(crate) handler: H, pub(crate) into_route: fn(H, S) -> Route, } impl ErasedIntoRoute for MakeErasedHandler where H: Clone + Send + 'static, S: 'static, B: HttpBody + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) } fn into_route(self: Box, state: S) -> Route { (self.into_route)(self.handler, state) } fn call_with_state( self: Box, request: Request, state: S, ) -> RouteFuture { self.into_route(state).call(request) } } impl Clone for MakeErasedHandler where H: Clone, { fn clone(&self) -> Self { Self { handler: self.handler.clone(), into_route: self.into_route, } } } pub(crate) struct MakeErasedRouter { pub(crate) router: Router, pub(crate) into_route: fn(Router, S) -> Route, } impl ErasedIntoRoute for MakeErasedRouter where S: Clone + Send + Sync + 'static, B: HttpBody + Send + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) } fn into_route(self: Box, state: S) -> Route { (self.into_route)(self.router, state) } fn call_with_state( mut self: Box, request: Request, state: S, ) -> RouteFuture { self.router.call_with_state(request, state) } } impl Clone for MakeErasedRouter where S: Clone, { fn clone(&self) -> Self { Self { router: self.router.clone(), 
into_route: self.into_route, } } } pub(crate) struct Map { pub(crate) inner: Box>, pub(crate) layer: Box>, } impl ErasedIntoRoute for Map where S: 'static, B: 'static, E: 'static, B2: HttpBody + 'static, E2: 'static, { fn clone_box(&self) -> Box> { Box::new(Self { inner: self.inner.clone_box(), layer: self.layer.clone_box(), }) } fn into_route(self: Box, state: S) -> Route { (self.layer)(self.inner.into_route(state)) } fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture { (self.layer)(self.inner.into_route(state)).call(request) } } pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send { fn clone_box(&self) -> Box>; } impl LayerFn for F where F: FnOnce(Route) -> Route + Clone + Send + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) } } axum-axum-v0.6.20/axum/src/docs/000077500000000000000000000000001446276533600164205ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/docs/debugging_handler_type_errors.md000066400000000000000000000027441446276533600250360ustar00rootroot00000000000000## Debugging handler type errors For a function to be used as a handler it must implement the [`Handler`] trait. axum provides blanket implementations for functions that: - Are `async fn`s. - Take no more than 16 arguments that all implement [`FromRequest`]. - Returns something that implements [`IntoResponse`]. - If a closure is used it must implement `Clone + Send` and be `'static`. - Returns a future that is `Send`. The most common way to accidentally make a future `!Send` is to hold a `!Send` type across an await. Unfortunately Rust gives poor error messages if you try to use a function that doesn't quite match what's required by [`Handler`]. 
You might get an error like this: ```not_rust error[E0277]: the trait bound `fn(bool) -> impl Future {handler}: Handler<_, _>` is not satisfied --> src/main.rs:13:44 | 13 | let app = Router::new().route("/", get(handler)); | ^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(bool) -> impl Future {handler}` | ::: axum/src/handler/mod.rs:116:8 | 116 | H: Handler, | ------------- required by this bound in `axum::routing::get` ``` This error doesn't tell you _why_ your function doesn't implement [`Handler`]. It's possible to improve the error with the [`debug_handler`] proc-macro from the [axum-macros] crate. [axum-macros]: https://docs.rs/axum-macros [`debug_handler`]: https://docs.rs/axum-macros/latest/axum_macros/attr.debug_handler.html axum-axum-v0.6.20/axum/src/docs/error_handling.md000066400000000000000000000137341446276533600217470ustar00rootroot00000000000000Error handling model and utilities # Table of contents - [axum's error handling model](#axums-error-handling-model) - [Routing to fallible services](#routing-to-fallible-services) - [Applying fallible middleware](#applying-fallible-middleware) - [Running extractors for error handling](#running-extractors-for-error-handling) # axum's error handling model axum is based on [`tower::Service`] which bundles errors through its associated `Error` type. If you have a [`Service`] that produces an error and that error makes it all the way up to hyper, the connection will be terminated _without_ sending a response. This is generally not desirable so axum makes sure you always produce a response by relying on the type system. axum does this by requiring all services have [`Infallible`] as their error type. `Infallible` is the error type for errors that can never happen. This means if you define a handler like: ```rust use axum::http::StatusCode; async fn handler() -> Result { # todo!() // ... } ``` While it looks like it might fail with a `StatusCode` this actually isn't an "error". 
If this handler returns `Err(some_status_code)` that will still be converted into a [`Response`] and sent back to the client. This is done through `StatusCode`'s [`IntoResponse`] implementation. It doesn't matter whether you return `Err(StatusCode::NOT_FOUND)` or `Err(StatusCode::INTERNAL_SERVER_ERROR)`. These are not considered errors in axum. Instead of a direct `StatusCode`, it makes sense to use intermediate error type that can ultimately be converted to `Response`. This allows using `?` operator in handlers. See those examples: * [`anyhow-error-response`][anyhow] for generic boxed errors * [`error-handling-and-dependency-injection`][ehdi] for application-specific detailed errors [anyhow]: https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs [ehdi]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling-and-dependency-injection/src/main.rs This also applies to extractors. If an extractor doesn't match the request the request will be rejected and a response will be returned without calling your handler. See [`extract`](crate::extract) to learn more about handling extractor failures. # Routing to fallible services You generally don't have to think about errors if you're only using async functions as handlers. However if you're embedding general `Service`s or applying middleware, which might produce errors you have to tell axum how to convert those errors into responses. ```rust use axum::{ Router, body::Body, http::{Request, Response, StatusCode}, error_handling::HandleError, }; async fn thing_that_might_fail() -> Result<(), anyhow::Error> { # Ok(()) // ... } // this service might fail with `anyhow::Error` let some_fallible_service = tower::service_fn(|_req| async { thing_that_might_fail().await?; Ok::<_, anyhow::Error>(Response::new(Body::empty())) }); let app = Router::new().route_service( "/", // we cannot route to `some_fallible_service` directly since it might fail. 
// we have to use `handle_error` which converts its errors into responses // and changes its error type from `anyhow::Error` to `Infallible`. HandleError::new(some_fallible_service, handle_anyhow_error), ); // handle errors by converting them into something that implements // `IntoResponse` async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", err), ) } # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Applying fallible middleware Similarly axum requires you to handle errors from middleware. That is done with [`HandleErrorLayer`]: ```rust use axum::{ Router, BoxError, routing::get, http::StatusCode, error_handling::HandleErrorLayer, }; use std::time::Duration; use tower::ServiceBuilder; let app = Router::new() .route("/", get(|| async {})) .layer( ServiceBuilder::new() // `timeout` will produce an error if the handler takes // too long so we must handle those .layer(HandleErrorLayer::new(handle_timeout_error)) .timeout(Duration::from_secs(30)) ); async fn handle_timeout_error(err: BoxError) -> (StatusCode, String) { if err.is::() { ( StatusCode::REQUEST_TIMEOUT, "Request took too long".to_string(), ) } else { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", err), ) } } # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Running extractors for error handling `HandleErrorLayer` also supports running extractors: ```rust use axum::{ Router, BoxError, routing::get, http::{StatusCode, Method, Uri}, error_handling::HandleErrorLayer, }; use std::time::Duration; use tower::ServiceBuilder; let app = Router::new() .route("/", get(|| async {})) .layer( ServiceBuilder::new() // `timeout` will produce an error if the handler takes // too long so we must handle those .layer(HandleErrorLayer::new(handle_timeout_error)) 
.timeout(Duration::from_secs(30)) ); async fn handle_timeout_error( // `Method` and `Uri` are extractors so they can be used here method: Method, uri: Uri, // the last argument must be the error itself err: BoxError, ) -> (StatusCode, String) { ( StatusCode::INTERNAL_SERVER_ERROR, format!("`{} {}` failed with {}", method, uri, err), ) } # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` [`tower::Service`]: `tower::Service` [`Infallible`]: std::convert::Infallible [`Response`]: crate::response::Response [`IntoResponse`]: crate::response::IntoResponse axum-axum-v0.6.20/axum/src/docs/extract.md000066400000000000000000000567041446276533600204300ustar00rootroot00000000000000Types and traits for extracting data from requests. # Table of contents - [Intro](#intro) - [Common extractors](#common-extractors) - [Applying multiple extractors](#applying-multiple-extractors) - [The order of extractors](#the-order-of-extractors) - [Optional extractors](#optional-extractors) - [Customizing extractor responses](#customizing-extractor-responses) - [Accessing inner errors](#accessing-inner-errors) - [Defining custom extractors](#defining-custom-extractors) - [Accessing other extractors in `FromRequest` or `FromRequestParts` implementations](#accessing-other-extractors-in-fromrequest-or-fromrequestparts-implementations) - [Request body limits](#request-body-limits) - [Request body extractors](#request-body-extractors) - [Running extractors from middleware](#running-extractors-from-middleware) - [Wrapping extractors](#wrapping-extractors) - [Logging rejections](#logging-rejections) # Intro A handler function is an async function that takes any number of "extractors" as arguments. An extractor is a type that implements [`FromRequest`](crate::extract::FromRequest) or [`FromRequestParts`](crate::extract::FromRequestParts). 
For example, [`Json`] is an extractor that consumes the request body and deserializes it as JSON into some target type: ```rust,no_run use axum::{ extract::Json, routing::post, handler::Handler, Router, }; use serde::Deserialize; #[derive(Deserialize)] struct CreateUser { email: String, password: String, } async fn create_user(Json(payload): Json) { // ... } let app = Router::new().route("/users", post(create_user)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Common extractors Some commonly used extractors are: ```rust,no_run use axum::{ extract::{Json, TypedHeader, Path, Extension, Query}, routing::post, headers::UserAgent, http::{Request, header::HeaderMap}, body::{Bytes, Body}, Router, }; use serde_json::Value; use std::collections::HashMap; // `Path` gives you the path parameters and deserializes them. See its docs for // more details async fn path(Path(user_id): Path) {} // `Query` gives you the query parameters and deserializes them. async fn query(Query(params): Query>) {} // `HeaderMap` gives you all the headers async fn headers(headers: HeaderMap) {} // `TypedHeader` can be used to extract a single header // note this requires you've enabled axum's `headers` feature async fn user_agent(TypedHeader(user_agent): TypedHeader) {} // `String` consumes the request body and ensures it is valid utf-8 async fn string(body: String) {} // `Bytes` gives you the raw request body async fn bytes(body: Bytes) {} // We've already seen `Json` for parsing the request body as json async fn json(Json(payload): Json) {} // `Request` gives you the whole request for maximum control async fn request(request: Request) {} // `Extension` extracts data from "request extensions" // This is commonly used to share state with handlers async fn extension(Extension(state): Extension) {} #[derive(Clone)] struct State { /* ... 
*/ } let app = Router::new() .route("/path/:user_id", post(path)) .route("/query", post(query)) .route("/user_agent", post(user_agent)) .route("/headers", post(headers)) .route("/string", post(string)) .route("/bytes", post(bytes)) .route("/json", post(json)) .route("/request", post(request)) .route("/extension", post(extension)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Applying multiple extractors You can also apply multiple extractors: ```rust,no_run use axum::{ extract::{Path, Query}, routing::get, Router, }; use uuid::Uuid; use serde::Deserialize; let app = Router::new().route("/users/:id/things", get(get_user_things)); #[derive(Deserialize)] struct Pagination { page: usize, per_page: usize, } impl Default for Pagination { fn default() -> Self { Self { page: 1, per_page: 30 } } } async fn get_user_things( Path(user_id): Path, pagination: Option>, ) { let Query(pagination) = pagination.unwrap_or_default(); // ... } # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # The order of extractors Extractors always run in the order of the function parameters that is from left to right. The request body is an asynchronous stream that can only be consumed once. Therefore you can only have one extractor that consumes the request body. axum enforces this by requiring such extractors to be the _last_ argument your handler takes. For example ```rust use axum::{extract::State, http::{Method, HeaderMap}}; # # #[derive(Clone)] # struct AppState { # } async fn handler( // `Method` and `HeaderMap` don't consume the request body so they can // put anywhere in the argument list (but before `body`) method: Method, headers: HeaderMap, // `State` is also an extractor so it needs to be before `body` State(state): State, // `String` consumes the request body and thus must be the last extractor body: String, ) { // ... 
} # # let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` We get a compile error if `String` isn't the last extractor: ```rust,compile_fail use axum::http::Method; async fn handler( // this doesn't work since `String` must be the last argument body: String, method: Method, ) { // ... } # # let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` This also means you cannot consume the request body twice: ```rust,compile_fail use axum::Json; use serde::Deserialize; #[derive(Deserialize)] struct Payload {} async fn handler( // `String` and `Json` both consume the request body // so they cannot both be used string_body: String, json_body: Json, ) { // ... } # # let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` axum enforces this by requiring the last extractor implements [`FromRequest`] and all others implement [`FromRequestParts`]. # Optional extractors All extractors defined in axum will reject the request if it doesn't match. If you wish to make an extractor optional you can wrap it in `Option`: ```rust,no_run use axum::{ extract::Json, routing::post, Router, }; use serde_json::Value; async fn create_user(payload: Option>) { if let Some(payload) = payload { // We got a valid JSON payload } else { // Payload wasn't valid JSON } } let app = Router::new().route("/users", post(create_user)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` Wrapping extractors in `Result` makes them optional and gives you the reason the extraction failed: ```rust,no_run use axum::{ extract::{Json, rejection::JsonRejection}, routing::post, Router, }; use serde_json::Value; async fn create_user(payload: Result, JsonRejection>) { match payload { Ok(payload) => { // We got a valid JSON payload } Err(JsonRejection::MissingJsonContentType(_)) => { // Request didn't have `Content-Type: application/json` // header } Err(JsonRejection::JsonDataError(_)) => { // Couldn't deserialize the 
body into the target type } Err(JsonRejection::JsonSyntaxError(_)) => { // Syntax error in the body } Err(JsonRejection::BytesRejection(_)) => { // Failed to extract the request body } Err(_) => { // `JsonRejection` is marked `#[non_exhaustive]` so match must // include a catch-all case. } } } let app = Router::new().route("/users", post(create_user)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Customizing extractor responses If an extractor fails it will return a response with the error and your handler will not be called. To customize the error response you have a two options: 1. Use `Result` as your extractor like shown in ["Optional extractors"](#optional-extractors). This works well if you're only using the extractor in a single handler. 2. Create your own extractor that in its [`FromRequest`] implemention calls one of axum's built in extractors but returns a different response for rejections. See the [customize-extractor-error] example for more details. # Accessing inner errors axum's built-in extractors don't directly expose the inner error. This gives us more flexibility and allows us to change internal implementations without breaking the public API. For example that means while [`Json`] is implemented using [`serde_json`] it doesn't directly expose the [`serde_json::Error`] thats contained in [`JsonRejection::JsonDataError`]. 
However it is still possible to access via methods from [`std::error::Error`]: ```rust use std::error::Error; use axum::{ extract::{Json, rejection::JsonRejection}, response::IntoResponse, http::StatusCode, }; use serde_json::{json, Value}; async fn handler( result: Result, JsonRejection>, ) -> Result, (StatusCode, String)> { match result { // if the client sent valid JSON then we're good Ok(Json(payload)) => Ok(Json(json!({ "payload": payload }))), Err(err) => match err { JsonRejection::JsonDataError(err) => { Err(serde_json_error_response(err)) } JsonRejection::JsonSyntaxError(err) => { Err(serde_json_error_response(err)) } // handle other rejections from the `Json` extractor JsonRejection::MissingJsonContentType(_) => Err(( StatusCode::BAD_REQUEST, "Missing `Content-Type: application/json` header".to_string(), )), JsonRejection::BytesRejection(_) => Err(( StatusCode::INTERNAL_SERVER_ERROR, "Failed to buffer request body".to_string(), )), // we must provide a catch-all case since `JsonRejection` is marked // `#[non_exhaustive]` _ => Err(( StatusCode::INTERNAL_SERVER_ERROR, "Unknown error".to_string(), )), }, } } // attempt to extract the inner `serde_path_to_error::Error`, // if that succeeds we can provide a more specific error. // // `Json` uses `serde_path_to_error` so the error will be wrapped in `serde_path_to_error::Error`. 
fn serde_json_error_response(err: E) -> (StatusCode, String) where E: Error + 'static, { if let Some(err) = find_error_source::>(&err) { let serde_json_err = err.inner(); ( StatusCode::BAD_REQUEST, format!( "Invalid JSON at line {} column {}", serde_json_err.line(), serde_json_err.column() ), ) } else { (StatusCode::BAD_REQUEST, "Unknown error".to_string()) } } // attempt to downcast `err` into a `T` and if that fails recursively try and // downcast `err`'s source fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T> where T: Error + 'static, { if let Some(err) = err.downcast_ref::() { Some(err) } else if let Some(source) = err.source() { find_error_source(source) } else { None } } # # #[tokio::main] # async fn main() { # use axum::extract::FromRequest; # # let req = axum::http::Request::builder() # .header("content-type", "application/json") # .body(axum::body::Body::from("{")) # .unwrap(); # # let err = match Json::::from_request(req, &()).await.unwrap_err() { # JsonRejection::JsonSyntaxError(err) => err, # _ => panic!(), # }; # # let (_, body) = serde_json_error_response(err); # assert_eq!(body, "Invalid JSON at line 1 column 1"); # } ``` Note that while this approach works it might break in the future if axum changes its implementation to use a different error type internally. Such changes might happen without major breaking versions. # Defining custom extractors You can also define your own extractors by implementing either [`FromRequestParts`] or [`FromRequest`]. 
## Implementing `FromRequestParts` Implement `FromRequestParts` if your extractor doesn't need access to the request body: ```rust,no_run use axum::{ async_trait, extract::FromRequestParts, routing::get, Router, http::{ StatusCode, header::{HeaderValue, USER_AGENT}, request::Parts, }, }; struct ExtractUserAgent(HeaderValue); #[async_trait] impl FromRequestParts for ExtractUserAgent where S: Send + Sync, { type Rejection = (StatusCode, &'static str); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { if let Some(user_agent) = parts.headers.get(USER_AGENT) { Ok(ExtractUserAgent(user_agent.clone())) } else { Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing")) } } } async fn handler(ExtractUserAgent(user_agent): ExtractUserAgent) { // ... } let app = Router::new().route("/foo", get(handler)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` ## Implementing `FromRequest` If your extractor needs to consume the request body you must implement [`FromRequest`] ```rust,no_run use axum::{ async_trait, extract::FromRequest, response::{Response, IntoResponse}, body::Bytes, routing::get, Router, http::{ StatusCode, header::{HeaderValue, USER_AGENT}, Request, }, }; struct ValidatedBody(Bytes); #[async_trait] impl FromRequest for ValidatedBody where Bytes: FromRequest, B: Send + 'static, S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let body = Bytes::from_request(req, state) .await .map_err(IntoResponse::into_response)?; // do validation... Ok(Self(body)) } } async fn handler(ValidatedBody(body): ValidatedBody) { // ... 
} let app = Router::new().route("/foo", get(handler)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` ## Cannot implement both `FromRequest` and `FromRequestParts` Note that you will make your extractor unusable by implementing both `FromRequest` and `FromRequestParts` directly for the same type, unless it is wrapping another extractor: ```rust,compile_fail use axum::{ Router, routing::get, extract::{FromRequest, FromRequestParts}, http::{Request, request::Parts}, async_trait, }; use std::convert::Infallible; // Some extractor that doesn't wrap another extractor struct MyExtractor; // `MyExtractor` implements both `FromRequest` #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, B: Send + 'static, { type Rejection = Infallible; async fn from_request(req: Request, state: &S) -> Result { // ... # todo!() } } // and `FromRequestParts` #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // ... # todo!() } } let app = Router::new().route( "/", // This fails when we go to actually use `MyExtractor` in a handler function. // This is due to a limit in Rust's type system. // // The workaround is to implement either `FromRequest` or `FromRequestParts` // but not both, if your extractor doesn't wrap another extractor. // // See "Wrapping extractors" for how to wrap other extractors. get(|_: MyExtractor| async {}), ); # let _: Router = app; ``` # Accessing other extractors in `FromRequest` or `FromRequestParts` implementations When defining custom extractors you often need to access another extractors in your implementation. 
```rust use axum::{ async_trait, extract::{Extension, FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, http::{StatusCode, request::Parts}, response::{IntoResponse, Response}, routing::get, Router, }; #[derive(Clone)] struct State { // ... } struct AuthenticatedUser { // ... } #[async_trait] impl FromRequestParts for AuthenticatedUser where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // You can either call them directly... let TypedHeader(Authorization(token)) = TypedHeader::>::from_request_parts(parts, state) .await .map_err(|err| err.into_response())?; // ... or use `extract` / `extract_with_state` from `RequestExt` / `RequestPartsExt` use axum::RequestPartsExt; let Extension(state) = parts.extract::>() .await .map_err(|err| err.into_response())?; unimplemented!("actually perform the authorization") } } async fn handler(user: AuthenticatedUser) { // ... } let state = State { /* ... */ }; let app = Router::new().route("/", get(handler)).layer(Extension(state)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Request body limits For security reasons, [`Bytes`] will, by default, not accept bodies larger than 2MB. This also applies to extractors that uses [`Bytes`] internally such as `String`, [`Json`], and [`Form`]. For more details, including how to disable this limit, see [`DefaultBodyLimit`]. # Request body extractors Most of the time your request body type will be [`body::Body`] (a re-export of [`hyper::Body`]), which is directly supported by all extractors. 
However if you're applying a tower middleware that changes the request body type you might have to apply a different body type to some extractors: ```rust use std::{ task::{Context, Poll}, pin::Pin, }; use tower_http::map_request_body::MapRequestBodyLayer; use axum::{ extract::{self, BodyStream}, body::{Body, HttpBody}, routing::get, http::{header::HeaderMap, Request}, Router, }; struct MyBody(B); impl HttpBody for MyBody where B: HttpBody + Unpin, { type Data = B::Data; type Error = B::Error; fn poll_data( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { Pin::new(&mut self.0).poll_data(cx) } fn poll_trailers( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Pin::new(&mut self.0).poll_trailers(cx) } } let app = Router::new() .route( "/string", // `String` works directly with any body type get(|_: String| async {}) ) .route( "/body", // `extract::Body` defaults to `axum::body::Body` // but can be customized get(|_: extract::RawBody>| async {}) ) .route( "/body-stream", // same for `extract::BodyStream` get(|_: extract::BodyStream| async {}), ) .route( // and `Request<_>` "/request", get(|_: Request>| async {}) ) // middleware that changes the request body type .layer(MapRequestBodyLayer::new(MyBody)); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Running extractors from middleware Extractors can also be run from middleware: ```rust use axum::{ middleware::{self, Next}, extract::{TypedHeader, FromRequestParts}, http::{Request, StatusCode}, response::Response, headers::authorization::{Authorization, Bearer}, RequestPartsExt, Router, }; async fn auth_middleware( request: Request, next: Next, ) -> Result where B: Send, { // running extractors requires a `axum::http::request::Parts` let (mut parts, body) = request.into_parts(); // `TypedHeader>` extracts the auth token let auth: TypedHeader> = parts.extract() .await .map_err(|_| StatusCode::UNAUTHORIZED)?; if 
!token_is_valid(auth.token()) { return Err(StatusCode::UNAUTHORIZED); } // reconstruct the request let request = Request::from_parts(parts, body); Ok(next.run(request).await) } fn token_is_valid(token: &str) -> bool { // ... # false } let app = Router::new().layer(middleware::from_fn(auth_middleware)); # let _: Router<()> = app; ``` # Wrapping extractors If you want write an extractor that generically wraps another extractor (that may or may not consume the request body) you should implement both [`FromRequest`] and [`FromRequestParts`]: ```rust use axum::{ Router, routing::get, extract::{FromRequest, FromRequestParts}, http::{Request, HeaderMap, request::Parts}, async_trait, }; use std::time::{Instant, Duration}; // an extractor that wraps another and measures how long time it takes to run struct Timing { extractor: E, duration: Duration, } // we must implement both `FromRequestParts` #[async_trait] impl FromRequestParts for Timing where S: Send + Sync, T: FromRequestParts, { type Rejection = T::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let start = Instant::now(); let extractor = T::from_request_parts(parts, state).await?; let duration = start.elapsed(); Ok(Timing { extractor, duration, }) } } // and `FromRequest` #[async_trait] impl FromRequest for Timing where B: Send + 'static, S: Send + Sync, T: FromRequest, { type Rejection = T::Rejection; async fn from_request(req: Request, state: &S) -> Result { let start = Instant::now(); let extractor = T::from_request(req, state).await?; let duration = start.elapsed(); Ok(Timing { extractor, duration, }) } } async fn handler( // this uses the `FromRequestParts` impl _: Timing, // this uses the `FromRequest` impl _: Timing, ) {} # let _: axum::routing::MethodRouter = axum::routing::get(handler); ``` # Logging rejections All built-in extractors will log rejections for easier debugging. 
To see the logs, enable the `tracing` feature for axum and the `axum::rejection=trace` tracing target, for example with `RUST_LOG=info,axum::rejection=trace cargo run`. [`body::Body`]: crate::body::Body [`Bytes`]: crate::body::Bytes [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs [`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html [`Request`]: https://docs.rs/http/latest/http/struct.Request.html [`RequestParts::body_mut`]: crate::extract::RequestParts::body_mut [`JsonRejection::JsonDataError`]: rejection::JsonRejection::JsonDataError axum-axum-v0.6.20/axum/src/docs/handlers_intro.md000066400000000000000000000006101446276533600217520ustar00rootroot00000000000000In axum, a "handler" is an async function that accepts zero or more ["extractors"](crate::extract) as arguments and returns something that can be converted [into a response](crate::response). Handlers are where your application logic lives and axum applications are built by routing between handlers. [`debug_handler`]: https://docs.rs/axum-macros/latest/axum_macros/attr.debug_handler.html axum-axum-v0.6.20/axum/src/docs/method_routing/000077500000000000000000000000001446276533600214475ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/docs/method_routing/fallback.md000066400000000000000000000032171446276533600235330ustar00rootroot00000000000000Add a fallback service to the router. This service will be called if no route matches the incoming request.
```rust use axum::{ Router, routing::get, handler::Handler, response::IntoResponse, http::{StatusCode, Method, Uri}, }; let handler = get(|| async {}).fallback(fallback); let app = Router::new().route("/", handler); async fn fallback(method: Method, uri: Uri) -> (StatusCode, String) { (StatusCode::NOT_FOUND, format!("`{}` not allowed for {}", method, uri)) } # async { # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` ## When used with `MethodRouter::merge` Two routers that both have a fallback cannot be merged. Doing so results in a panic: ```rust,should_panic use axum::{ routing::{get, post}, handler::Handler, response::IntoResponse, http::{StatusCode, Uri}, }; let one = get(|| async {}).fallback(fallback_one); let two = post(|| async {}).fallback(fallback_two); let method_route = one.merge(two); async fn fallback_one() -> impl IntoResponse { /* ... */ } async fn fallback_two() -> impl IntoResponse { /* ... */ } # let app = axum::Router::new().route("/", method_route); # async { # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` ## Setting the `Allow` header By default `MethodRouter` will set the `Allow` header when returning `405 Method Not Allowed`. This is also done when the fallback is used unless the response generated by the fallback already sets the `Allow` header. This means if you use `fallback` to accept additional methods, you should make sure you set the `Allow` header correctly. axum-axum-v0.6.20/axum/src/docs/method_routing/layer.md000066400000000000000000000015461446276533600231130ustar00rootroot00000000000000Apply a [`tower::Layer`] to all routes in the router. This can be used to add additional processing to a request for a group of routes. Note that the middleware is only applied to existing routes. So you have to first add your routes (and / or fallback) and then call `layer` afterwards. 
Additional routes added after `layer` is called will not have the middleware added. Works similarly to [`Router::layer`](super::Router::layer). See that method for more details. # Example ```rust use axum::{routing::get, Router}; use tower::limit::ConcurrencyLimitLayer; async fn handler() {} let app = Router::new().route( "/", // All requests to `GET /` will be sent through `ConcurrencyLimitLayer` get(handler).layer(ConcurrencyLimitLayer::new(64)), ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` axum-axum-v0.6.20/axum/src/docs/method_routing/merge.md000066400000000000000000000007241446276533600230730ustar00rootroot00000000000000Merge two routers into one. This is useful for breaking routers into smaller pieces and combining them into one. ```rust use axum::{ routing::{get, post}, Router, }; let get = get(|| async {}); let post = post(|| async {}); let merged = get.merge(post); let app = Router::new().route("/", merged); // Our app now accepts // - GET / // - POST / # async { # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` axum-axum-v0.6.20/axum/src/docs/method_routing/route_layer.md000066400000000000000000000022221446276533600243210ustar00rootroot00000000000000Apply a [`tower::Layer`] to the router that will only run if the request matches a route. Note that the middleware is only applied to existing routes. So you have to first add your routes (and / or fallback) and then call `route_layer` afterwards. Additional routes added after `route_layer` is called will not have the middleware added. This works similarly to [`MethodRouter::layer`] except the middleware will only run if the request matches a route. This is useful for middleware that returns early (such as authorization) which might otherwise convert a `405 Method Not Allowed` into a `401 Unauthorized`.
# Example ```rust use axum::{ routing::get, Router, }; use tower_http::validate_request::ValidateRequestHeaderLayer; let app = Router::new().route( "/foo", get(|| async {}) .route_layer(ValidateRequestHeaderLayer::bearer("password")) ); // `GET /foo` with a valid token will receive `200 OK` // `GET /foo` with an invalid token will receive `401 Unauthorized` // `POST /foo` with an invalid token will receive `405 Method Not Allowed` # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` axum-axum-v0.6.20/axum/src/docs/middleware.md000066400000000000000000000437621446276533600210730ustar00rootroot00000000000000# Table of contents - [Intro](#intro) - [Applying middleware](#applying-middleware) - [Commonly used middleware](#commonly-used-middleware) - [Ordering](#ordering) - [Writing middleware](#writing-middleware) - [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure) - [Accessing state in middleware](#accessing-state-in-middleware) - [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers) - [Rewriting request URI in middleware](#rewriting-request-uri-in-middleware) # Intro axum is unique in that it doesn't have its own bespoke middleware system and instead integrates with [`tower`]. This means the ecosystem of [`tower`] and [`tower-http`] middleware all work with axum. While it's not necessary to fully understand tower to write or use middleware with axum, having at least a basic understanding of tower's concepts is recommended. See [tower's guides][tower-guides] for a general introduction. Reading the documentation for [`tower::ServiceBuilder`] is also recommended. # Applying middleware axum allows you to add middleware just about anywhere - To entire routers with [`Router::layer`] and [`Router::route_layer`]. - To method routers with [`MethodRouter::layer`] and [`MethodRouter::route_layer`].
- To individual handlers with [`Handler::layer`]. ## Applying multiple middleware It's recommended to use [`tower::ServiceBuilder`] to apply multiple middleware at once, instead of calling `layer` (or `route_layer`) repeatedly: ```rust use axum::{ routing::get, Extension, Router, }; use tower_http::{trace::TraceLayer}; use tower::ServiceBuilder; async fn handler() {} #[derive(Clone)] struct State {} let app = Router::new() .route("/", get(handler)) .layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(Extension(State {})) ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Commonly used middleware Some commonly used middleware are: - [`TraceLayer`](tower_http::trace) for high-level tracing/logging. - [`CorsLayer`](tower_http::cors) for handling CORS. - [`CompressionLayer`](tower_http::compression) for automatic compression of responses. - [`RequestIdLayer`](tower_http::request_id) and [`PropagateRequestIdLayer`](tower_http::request_id) set and propagate request ids. - [`TimeoutLayer`](tower::timeout::TimeoutLayer) for timeouts. Note this requires using [`HandleErrorLayer`](crate::error_handling::HandleErrorLayer) to convert timeouts to responses. # Ordering When you add middleware with [`Router::layer`] (or similar), all previously added routes will be wrapped in the middleware. Generally speaking, this results in middleware being executed from bottom to top.
So if you do this: ```rust use axum::{routing::get, Router}; async fn handler() {} # let layer_one = axum::Extension(()); # let layer_two = axum::Extension(()); # let layer_three = axum::Extension(()); # let app = Router::new() .route("/", get(handler)) .layer(layer_one) .layer(layer_two) .layer(layer_three); # let _: Router<(), axum::body::Body> = app; ``` Think of the middleware as being layered like an onion where each new layer wraps all previous layers: ```not_rust requests | v +----- layer_three -----+ | +---- layer_two ----+ | | | +-- layer_one --+ | | | | | | | | | | | handler | | | | | | | | | | | +-- layer_one --+ | | | +---- layer_two ----+ | +----- layer_three -----+ | v responses ``` That is: - First `layer_three` receives the request - It then does its thing and passes the request onto `layer_two` - Which passes the request onto `layer_one` - Which passes the request onto `handler` where a response is produced - That response is then passed to `layer_one` - Then to `layer_two` - And finally to `layer_three` where it's returned out of your app It's a little more complicated in practice because any middleware is free to return early and not call the next layer, for example if a request cannot be authorized, but it's a useful mental model to have. As previously mentioned, it's recommended to add multiple middleware using `tower::ServiceBuilder`, however this impacts ordering: ```rust use tower::ServiceBuilder; use axum::{routing::get, Router}; async fn handler() {} # let layer_one = axum::Extension(()); # let layer_two = axum::Extension(()); # let layer_three = axum::Extension(()); # let app = Router::new() .route("/", get(handler)) .layer( ServiceBuilder::new() .layer(layer_one) .layer(layer_two) .layer(layer_three), ); # let _: Router<(), axum::body::Body> = app; ``` `ServiceBuilder` works by composing all layers into one such that they run top to bottom.
So with the previous code `layer_one` would receive the request first, then `layer_two`, then `layer_three`, then `handler`, and then the response would bubble back up through `layer_three`, then `layer_two`, and finally `layer_one`. Executing middleware top to bottom is generally easier to understand and follow mentally, which is one of the reasons `ServiceBuilder` is recommended. # Writing middleware axum offers many ways of writing middleware, at different levels of abstraction and with different pros and cons. ## `axum::middleware::from_fn` Use [`axum::middleware::from_fn`] to write your middleware when: - You're not comfortable with implementing your own futures and would rather use the familiar `async`/`await` syntax. - You don't intend to publish your middleware as a crate for others to use. Middleware written like this are only compatible with axum. ## `axum::middleware::from_extractor` Use [`axum::middleware::from_extractor`] to write your middleware when: - You have a type that you sometimes want to use as an extractor and sometimes as a middleware. If you only need your type as a middleware, prefer [`middleware::from_fn`]. ## tower's combinators tower has several utility combinators that can be used to perform simple modifications to requests or responses. The most commonly used ones are - [`ServiceBuilder::map_request`] - [`ServiceBuilder::map_response`] - [`ServiceBuilder::then`] - [`ServiceBuilder::and_then`] You should use these when - You want to perform a small ad hoc operation, such as adding a header. - You don't intend to publish your middleware as a crate for others to use. ## `tower::Service` and `Pin<Box<dyn Future>>` For maximum control (and a lower-level API) you can write your own middleware by implementing [`tower::Service`]: Use [`tower::Service`] with `Pin<Box<dyn Future>>` to write your middleware when: - Your middleware needs to be configurable, for example via builder methods on your [`tower::Layer`] such as [`tower_http::trace::TraceLayer`].
- You do intend to publish your middleware as a crate for others to use. - You're not comfortable with implementing your own futures. A decent template for such a middleware could be: ```rust use axum::{ response::Response, body::Body, http::Request, }; use futures_util::future::BoxFuture; use tower::{Service, Layer}; use std::task::{Context, Poll}; #[derive(Clone)] struct MyLayer; impl Layer for MyLayer { type Service = MyMiddleware; fn layer(&self, inner: S) -> Self::Service { MyMiddleware { inner } } } #[derive(Clone)] struct MyMiddleware { inner: S, } impl Service> for MyMiddleware where S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, { type Response = S::Response; type Error = S::Error; // `BoxFuture` is a type alias for `Pin>` type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { let future = self.inner.call(request); Box::pin(async move { let response: Response = future.await?; Ok(response) }) } } ``` ## `tower::Service` and custom futures If you're comfortable implementing your own futures (or want to learn it) and need as much control as possible then using `tower::Service` without boxed futures is the way to go. Use [`tower::Service`] with manual futures to write your middleware when: - You want your middleware to have the lowest possible overhead. - Your middleware needs to be configurable for example via builder methods on your [`tower::Layer`] such as [`tower_http::trace::TraceLayer`]. - You do intend to publish your middleware as a crate for others to use, perhaps as part of tower-http. - You're comfortable with implementing your own futures, or want to learn how the lower levels of async Rust works. tower's ["Building a middleware from scratch"][tower-from-scratch-guide] guide is a good place to learn how to do this. 
# Error handling for middleware axum's error handling model requires handlers to always return a response. However, middleware is one possible way to introduce errors into an application. If hyper receives an error, the connection will be closed without sending a response. Thus axum requires those errors to be handled gracefully: ```rust use axum::{ routing::get, error_handling::HandleErrorLayer, http::StatusCode, BoxError, Router, }; use tower::{ServiceBuilder, timeout::TimeoutLayer}; use std::time::Duration; async fn handler() {} let app = Router::new() .route("/", get(handler)) .layer( ServiceBuilder::new() // this middleware goes above `TimeoutLayer` because it will receive // errors returned by `TimeoutLayer` .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) .layer(TimeoutLayer::new(Duration::from_secs(10))) ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` See [`error_handling`](crate::error_handling) for more details on axum's error handling model. # Routing to services/middleware and backpressure Generally routing to one of multiple services and backpressure don't mix well. Ideally you would want to ensure a service is ready to receive a request before calling it. However, in order to know which service to call, you need the request... One approach is to not consider the router service itself ready until all destination services are ready. That is the approach used by [`tower::steer::Steer`]. Another approach is to always consider all services ready (always return `Poll::Ready(Ok(()))`) from `Service::poll_ready` and then actually drive readiness inside the response future returned by `Service::call`. This works well when your services don't care about backpressure and are always ready anyway. axum expects that all services used in your app won't care about backpressure and so it uses the latter strategy.
However, that means you should avoid routing to a service (or using a middleware) that _does_ care about backpressure. At the very least you should [load shed] so requests are dropped quickly and don't keep piling up. It also means that if `poll_ready` returns an error then that error will be returned in the response future from `call` and _not_ from `poll_ready`. In that case, the underlying service will _not_ be discarded and will continue to be used for future requests. Services that expect to be discarded if `poll_ready` fails should _not_ be used with axum. One possible approach is to only apply backpressure-sensitive middleware around your entire app. This is possible because axum applications are themselves services: ```rust use axum::{ routing::get, Router, }; use tower::ServiceBuilder; # let some_backpressure_sensitive_middleware = # tower::layer::util::Identity::new(); async fn handler() { /* ... */ } let app = Router::new().route("/", get(handler)); let app = ServiceBuilder::new() .layer(some_backpressure_sensitive_middleware) .service(app); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` However, when applying middleware around your whole application in this way, you have to take care that errors are still being handled appropriately. Also note that handlers created from async functions don't care about backpressure and are always ready. So if you're not using any Tower middleware you don't have to worry about any of this. # Accessing state in middleware How to make state available to middleware depends on how the middleware is written. ## Accessing state in `axum::middleware::from_fn` Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).
## Accessing state in custom `tower::Layer`s ```rust use axum::{ Router, routing::get, middleware::{self, Next}, response::Response, extract::State, http::Request, }; use tower::{Layer, Service}; use std::task::{Context, Poll}; #[derive(Clone)] struct AppState {} #[derive(Clone)] struct MyLayer { state: AppState, } impl Layer for MyLayer { type Service = MyService; fn layer(&self, inner: S) -> Self::Service { MyService { inner, state: self.state.clone(), } } } #[derive(Clone)] struct MyService { inner: S, state: AppState, } impl Service> for MyService where S: Service>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { // Do something with `self.state`. // // See `axum::RequestExt` for how to run extractors directly from // a `Request`. self.inner.call(req) } } async fn handler(_: State) {} let state = AppState {}; let app = Router::new() .route("/", get(handler)) .layer(MyLayer { state: state.clone() }) .with_state(state); # let _: axum::Router = app; ``` # Passing state from middleware to handlers State can be passed from middleware to handlers using [request extensions]: ```rust use axum::{ Router, http::{Request, StatusCode}, routing::get, response::{IntoResponse, Response}, middleware::{self, Next}, extract::Extension, }; #[derive(Clone)] struct CurrentUser { /* ... 
*/ } async fn auth(mut req: Request, next: Next) -> Result { let auth_header = req.headers() .get(http::header::AUTHORIZATION) .and_then(|header| header.to_str().ok()); let auth_header = if let Some(auth_header) = auth_header { auth_header } else { return Err(StatusCode::UNAUTHORIZED); }; if let Some(current_user) = authorize_current_user(auth_header).await { // insert the current user into a request extension so the handler can // extract it req.extensions_mut().insert(current_user); Ok(next.run(req).await) } else { Err(StatusCode::UNAUTHORIZED) } } async fn authorize_current_user(auth_token: &str) -> Option { // ... # unimplemented!() } async fn handler( // extract the current user, set by the middleware Extension(current_user): Extension, ) { // ... } let app = Router::new() .route("/", get(handler)) .route_layer(middleware::from_fn(auth)); # let _: Router<()> = app; ``` [Response extensions] can also be used but note that request extensions are not automatically moved to response extensions. You need to manually do that for the extensions you need. # Rewriting request URI in middleware Middleware added with [`Router::layer`] will run after routing. That means it cannot be used to run middleware that rewrites the request URI. By the time the middleware runs the routing is already done. The workaround is to wrap the middleware around the entire `Router` (this works because `Router` implements [`Service`]): ```rust use tower::Layer; use axum::{ Router, ServiceExt, // for `into_make_service` response::Response, middleware::Next, http::Request, }; async fn rewrite_request_uri(req: Request, next: Next) -> Response { // ... 
# next.run(req).await } // this can be any `tower::Layer` let middleware = axum::middleware::from_fn(rewrite_request_uri); let app = Router::new(); // apply the layer around the whole `Router` // this way the middleware will run before `Router` receives the request let app_with_middleware = middleware.layer(app); # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(app_with_middleware.into_make_service()) .await .unwrap(); # }; ``` [`tower`]: https://crates.io/crates/tower [`tower-http`]: https://crates.io/crates/tower-http [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides [`axum::middleware::from_fn`]: fn@crate::middleware::from_fn [`middleware::from_fn`]: fn@crate::middleware::from_fn [tower-from-scratch-guide]: https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md [`ServiceBuilder::map_request`]: tower::ServiceBuilder::map_request [`ServiceBuilder::map_response`]: tower::ServiceBuilder::map_response [`ServiceBuilder::then`]: tower::ServiceBuilder::then [`ServiceBuilder::and_then`]: tower::ServiceBuilder::and_then [`axum::middleware::from_extractor`]: fn@crate::middleware::from_extractor [`Handler::layer`]: crate::handler::Handler::layer [`Router::layer`]: crate::routing::Router::layer [`MethodRouter::layer`]: crate::routing::MethodRouter::layer [`Router::route_layer`]: crate::routing::Router::route_layer [`MethodRouter::route_layer`]: crate::routing::MethodRouter::route_layer [request extensions]: https://docs.rs/http/latest/http/request/struct.Request.html#method.extensions [Response extensions]: https://docs.rs/http/latest/http/response/struct.Response.html#method.extensions [`State`]: crate::extract::State [`Service`]: tower::Service axum-axum-v0.6.20/axum/src/docs/response.md000066400000000000000000000206011446276533600205770ustar00rootroot00000000000000Types and traits for generating responses. 
# Table of contents - [Building responses](#building-responses) - [Returning different response types](#returning-different-response-types) - [Regarding `impl IntoResponse`](#regarding-impl-intoresponse) # Building responses Anything that implements [`IntoResponse`] can be returned from a handler. axum provides implementations for common types: ```rust,no_run use axum::{ Json, response::{Html, IntoResponse}, http::{StatusCode, Uri, header::{self, HeaderMap, HeaderName}}, }; // `()` gives an empty response async fn empty() {} // String will get a `text/plain; charset=utf-8` content-type async fn plain_text(uri: Uri) -> String { format!("Hi from {}", uri.path()) } // Bytes will get a `application/octet-stream` content-type async fn bytes() -> Vec { vec![1, 2, 3, 4] } // `Json` will get a `application/json` content-type and work with anything that // implements `serde::Serialize` async fn json() -> Json> { Json(vec!["foo".to_owned(), "bar".to_owned()]) } // `Html` will get a `text/html` content-type async fn html() -> Html<&'static str> { Html("

Hello, World!

") } // `StatusCode` gives an empty response with that status code async fn status() -> StatusCode { StatusCode::NOT_FOUND } // `HeaderMap` gives an empty response with some headers async fn headers() -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert(header::SERVER, "axum".parse().unwrap()); headers } // An array of tuples also gives headers async fn array_headers() -> [(HeaderName, &'static str); 2] { [ (header::SERVER, "axum"), (header::CONTENT_TYPE, "text/plain") ] } // Use `impl IntoResponse` to avoid writing the whole type async fn impl_trait() -> impl IntoResponse { [ (header::SERVER, "axum"), (header::CONTENT_TYPE, "text/plain") ] } ``` Additionally you can return tuples to build more complex responses from individual parts. ```rust,no_run use axum::{ Json, response::IntoResponse, http::{StatusCode, HeaderMap, Uri, header}, extract::Extension, }; // `(StatusCode, impl IntoResponse)` will override the status code of the response async fn with_status(uri: Uri) -> (StatusCode, String) { (StatusCode::NOT_FOUND, format!("Not Found: {}", uri.path())) } // Use `impl IntoResponse` to avoid having to type the whole type async fn impl_trait(uri: Uri) -> impl IntoResponse { (StatusCode::NOT_FOUND, format!("Not Found: {}", uri.path())) } // `(HeaderMap, impl IntoResponse)` to add additional headers async fn with_headers() -> impl IntoResponse { let mut headers = HeaderMap::new(); headers.insert(header::CONTENT_TYPE, "text/plain".parse().unwrap()); (headers, "foo") } // Or an array of tuples to more easily build the headers async fn with_array_headers() -> impl IntoResponse { ([(header::CONTENT_TYPE, "text/plain")], "foo") } // Use string keys for custom headers async fn with_array_headers_custom() -> impl IntoResponse { ([("x-custom", "custom")], "foo") } // `(StatusCode, headers, impl IntoResponse)` to set status and add headers // `headers` can be either a `HeaderMap` or an array of tuples async fn with_status_and_array_headers() -> impl IntoResponse { 
( StatusCode::NOT_FOUND, [(header::CONTENT_TYPE, "text/plain")], "foo", ) } // `(Extension<_>, impl IntoResponse)` to set response extensions async fn with_status_extensions() -> impl IntoResponse { ( Extension(Foo("foo")), "foo", ) } struct Foo(&'static str); // Or mix and match all the things async fn all_the_things(uri: Uri) -> impl IntoResponse { let mut header_map = HeaderMap::new(); if uri.path() == "/" { header_map.insert(header::SERVER, "axum".parse().unwrap()); } ( // set status code StatusCode::NOT_FOUND, // headers with an array [("x-custom", "custom")], // some extensions Extension(Foo("foo")), Extension(Foo("bar")), // more headers, built dynamically header_map, // and finally the body "foo", ) } ``` In general you can return tuples like: - `(StatusCode, impl IntoResponse)` - `(Parts, impl IntoResponse)` - `(Response<()>, impl IntoResponse)` - `(T1, .., Tn, impl IntoResponse)` where `T1` to `Tn` all implement [`IntoResponseParts`]. - `(StatusCode, T1, .., Tn, impl IntoResponse)` where `T1` to `Tn` all implement [`IntoResponseParts`]. - `(Parts, T1, .., Tn, impl IntoResponse)` where `T1` to `Tn` all implement [`IntoResponseParts`]. - `(Response<()>, T1, .., Tn, impl IntoResponse)` where `T1` to `Tn` all implement [`IntoResponseParts`]. This means you cannot accidentally override the status or body as [`IntoResponseParts`] only allows setting headers and extensions. 
Use [`Response`](crate::response::Response) for more low level control: ```rust,no_run use axum::{ Json, response::{IntoResponse, Response}, body::{Full, Bytes}, http::StatusCode, }; async fn response() -> Response> { Response::builder() .status(StatusCode::NOT_FOUND) .header("x-foo", "custom header") .body(Full::from("not found")) .unwrap() } ``` # Returning different response types If you need to return multiple response types, and `Result` isn't appropriate, you can call `.into_response()` to turn things into `axum::response::Response`: ```rust use axum::{ response::{IntoResponse, Redirect, Response}, http::StatusCode, }; async fn handle() -> Response { if something() { "All good!".into_response() } else if something_else() { ( StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong...", ).into_response() } else { Redirect::to("/").into_response() } } fn something() -> bool { // ... # true } fn something_else() -> bool { // ... # true } ``` # Regarding `impl IntoResponse` You can use `impl IntoResponse` as the return type from handlers to avoid typing large types. For example ```rust use axum::http::StatusCode; async fn handler() -> (StatusCode, [(&'static str, &'static str); 1], &'static str) { (StatusCode::OK, [("x-foo", "bar")], "Hello, World!") } ``` Becomes easier using `impl IntoResponse`: ```rust use axum::{http::StatusCode, response::IntoResponse}; async fn impl_into_response() -> impl IntoResponse { (StatusCode::OK, [("x-foo", "bar")], "Hello, World!") } ``` However `impl IntoResponse` has a few limitations. Firstly it can only be used to return a single type: ```rust,compile_fail use axum::{http::StatusCode, response::IntoResponse}; async fn handler() -> impl IntoResponse { if check_something() { StatusCode::NOT_FOUND } else { "Hello, World!" } } fn check_something() -> bool { # false // ... } ``` This function returns either a `StatusCode` or a `&'static str` which `impl Trait` doesn't allow. 
Secondly `impl IntoResponse` can lead to type inference issues when used with `Result` and `?`: ```rust,compile_fail use axum::{http::StatusCode, response::IntoResponse}; async fn handler() -> impl IntoResponse { create_thing()?; Ok(StatusCode::CREATED) } fn create_thing() -> Result<(), StatusCode> { # Ok(()) // ... } ``` This is because `?` supports using the [`From`] trait to convert to a different error type but it doesn't know which type to convert to, because we only specified `impl IntoResponse` as the return type. `Result` doesn't always work either: ```rust,compile_fail use axum::{http::StatusCode, response::IntoResponse}; async fn handler() -> Result { create_thing()?; Ok(StatusCode::CREATED) } fn create_thing() -> Result<(), StatusCode> { # Ok(()) // ... } ``` The solution is to use a concrete error type, such as `Result`: ```rust use axum::{http::StatusCode, response::IntoResponse}; async fn handler() -> Result { create_thing()?; Ok(StatusCode::CREATED) } fn create_thing() -> Result<(), StatusCode> { # Ok(()) // ... } ``` Because of this it is generally not recommended to use `impl IntoResponse` unless you're familiar with the details of how `impl Trait` works. [`IntoResponse`]: crate::response::IntoResponse [`IntoResponseParts`]: crate::response::IntoResponseParts [`StatusCode`]: http::StatusCode axum-axum-v0.6.20/axum/src/docs/routing/000077500000000000000000000000001446276533600201075ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/docs/routing/fallback.md000066400000000000000000000026371446276533600222000ustar00rootroot00000000000000Add a fallback [`Handler`] to the router. This service will be called if no routes matches the incoming request. ```rust use axum::{ Router, routing::get, handler::Handler, response::IntoResponse, http::{StatusCode, Uri}, }; let app = Router::new() .route("/foo", get(|| async { /* ... 
*/ })) .fallback(fallback); async fn fallback(uri: Uri) -> (StatusCode, String) { (StatusCode::NOT_FOUND, format!("No route for {}", uri)) } # async { # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` Fallbacks only apply to requests that aren't matched by anything in the router. If a handler is matched by a request but returns 404, the fallback is not called. # Handling all requests without other routes Using `Router::new().fallback(...)` to accept all requests regardless of path or method, if you don't have other routes, isn't optimal: ```rust use axum::Router; async fn handler() {} let app = Router::new().fallback(handler); # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(app.into_make_service()) .await .unwrap(); # }; ``` Running the handler directly is faster since it avoids the overhead of routing: ```rust use axum::handler::HandlerWithoutStateExt; async fn handler() {} # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(handler.into_make_service()) .await .unwrap(); # }; ``` axum-axum-v0.6.20/axum/src/docs/routing/into_make_service_with_connect_info.md000066400000000000000000000035661446276533600277100ustar00rootroot00000000000000Convert this router into a [`MakeService`] that will store `C`'s associated `ConnectInfo` in a request extension such that [`ConnectInfo`] can extract it. This enables extracting things like the client's remote address.
Extracting [`std::net::SocketAddr`] is supported out of the box: ```rust use axum::{ extract::ConnectInfo, routing::get, Router, }; use std::net::SocketAddr; let app = Router::new().route("/", get(handler)); async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("Hello {}", addr) } # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve( app.into_make_service_with_connect_info::() ) .await .expect("server failed"); # }; ``` You can implement custom a [`Connected`] like so: ```rust use axum::{ extract::connect_info::{ConnectInfo, Connected}, routing::get, Router, }; use hyper::server::conn::AddrStream; let app = Router::new().route("/", get(handler)); async fn handler( ConnectInfo(my_connect_info): ConnectInfo, ) -> String { format!("Hello {:?}", my_connect_info) } #[derive(Clone, Debug)] struct MyConnectInfo { // ... } impl Connected<&AddrStream> for MyConnectInfo { fn connect_info(target: &AddrStream) -> Self { MyConnectInfo { // ... } } } # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve( app.into_make_service_with_connect_info::() ) .await .expect("server failed"); # }; ``` See the [unix domain socket example][uds] for an example of how to use this to collect UDS connection info. [`MakeService`]: tower::make::MakeService [`Connected`]: crate::extract::connect_info::Connected [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo [uds]: https://github.com/tokio-rs/axum/blob/main/examples/unix-domain-socket/src/main.rs axum-axum-v0.6.20/axum/src/docs/routing/layer.md000066400000000000000000000037511446276533600215530ustar00rootroot00000000000000Apply a [`tower::Layer`] to all routes in the router. This can be used to add additional processing to a request for a group of routes. Note that the middleware is only applied to existing routes. So you have to first add your routes (and / or fallback) and then call `layer` afterwards. Additional routes added after `layer` is called will not have the middleware added. 
If you want to add middleware to a single handler you can either use [`MethodRouter::layer`] or [`Handler::layer`]. # Example Adding the [`tower_http::trace::TraceLayer`]: ```rust use axum::{routing::get, Router}; use tower_http::trace::TraceLayer; let app = Router::new() .route("/foo", get(|| async {})) .route("/bar", get(|| async {})) .layer(TraceLayer::new_for_http()); # let _: Router = app; ``` If you need to write your own middleware see ["Writing middleware"](crate::middleware#writing-middleware) for the different options. If you only want middleware on some routes you can use [`Router::merge`]: ```rust use axum::{routing::get, Router}; use tower_http::{trace::TraceLayer, compression::CompressionLayer}; let with_tracing = Router::new() .route("/foo", get(|| async {})) .layer(TraceLayer::new_for_http()); let with_compression = Router::new() .route("/bar", get(|| async {})) .layer(CompressionLayer::new()); // Merge everything into one `Router` let app = Router::new() .merge(with_tracing) .merge(with_compression); # let _: Router = app; ``` # Multiple middleware It's recommended to use [`tower::ServiceBuilder`] when applying multiple middleware. See [`middleware`](crate::middleware) for more details. # Runs after routing Middleware added with this method will run _after_ routing and thus cannot be used to rewrite the request URI. See ["Rewriting request URI in middleware"](crate::middleware#rewriting-request-uri-in-middleware) for more details and a workaround. # Error handling See [`middleware`](crate::middleware) for details on how error handling impacts middleware. axum-axum-v0.6.20/axum/src/docs/routing/merge.md000066400000000000000000000032521446276533600215320ustar00rootroot00000000000000Merge two routers into one. This is useful for breaking apps into smaller pieces and combining them into one. 
```rust use axum::{ routing::get, Router, }; # # async fn users_list() {} # async fn users_show() {} # async fn teams_list() {} // define some routes separately let user_routes = Router::new() .route("/users", get(users_list)) .route("/users/:id", get(users_show)); let team_routes = Router::new() .route("/teams", get(teams_list)); // combine them into one let app = Router::new() .merge(user_routes) .merge(team_routes); // could also do `user_routes.merge(team_routes)` // Our app now accepts // - GET /users // - GET /users/:id // - GET /teams # async { # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Merging routers with state When combining [`Router`]s with this method, each [`Router`] must have the same type of state. If your routers have different types you can use [`Router::with_state`] to provide the state and make the types match: ```rust use axum::{ Router, routing::get, extract::State, }; #[derive(Clone)] struct InnerState {} #[derive(Clone)] struct OuterState {} async fn inner_handler(state: State) {} let inner_router = Router::new() .route("/bar", get(inner_handler)) .with_state(InnerState {}); async fn outer_handler(state: State) {} let app = Router::new() .route("/", get(outer_handler)) .merge(inner_router) .with_state(OuterState {}); # let _: axum::Router = app; ``` # Panics - If two routers that each have a [fallback](Router::fallback) are merged. This is because `Router` only allows a single fallback. axum-axum-v0.6.20/axum/src/docs/routing/nest.md000066400000000000000000000112701446276533600214030ustar00rootroot00000000000000Nest a [`Router`] at some path. This allows you to break your application into smaller pieces and compose them together. 
# Example ```rust use axum::{ routing::{get, post}, Router, }; let user_routes = Router::new().route("/:id", get(|| async {})); let team_routes = Router::new().route("/", post(|| async {})); let api_routes = Router::new() .nest("/users", user_routes) .nest("/teams", team_routes); let app = Router::new().nest("/api", api_routes); // Our app now accepts // - GET /api/users/:id // - POST /api/teams # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # How the URI changes Note that nested routes will not see the original request URI but instead have the matched prefix stripped. This is necessary for services like static file serving to work. Use [`OriginalUri`] if you need the original request URI. # Captures from outer routes Take care when using `nest` together with dynamic routes as nesting also captures from the outer routes: ```rust use axum::{ extract::Path, routing::get, Router, }; use std::collections::HashMap; async fn users_get(Path(params): Path>) { // Both `version` and `id` were captured even though `users_api` only // explicitly captures `id`. let version = params.get("version"); let id = params.get("id"); } let users_api = Router::new().route("/users/:id", get(users_get)); let app = Router::new().nest("/:version/api", users_api); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Differences from wildcard routes Nested routes are similar to wildcard routes. 
The difference is that wildcard routes still see the whole URI whereas nested routes will have the prefix stripped: ```rust use axum::{routing::get, http::Uri, Router}; let nested_router = Router::new() .route("/", get(|uri: Uri| async { // `uri` will _not_ contain `/bar` })); let app = Router::new() .route("/foo/*rest", get(|uri: Uri| async { // `uri` will contain `/foo` })) .nest("/bar", nested_router); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Fallbacks If a nested router doesn't have its own fallback then it will inherit the fallback from the outer router: ```rust use axum::{routing::get, http::StatusCode, handler::Handler, Router}; async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } let api_routes = Router::new().route("/users", get(|| async {})); let app = Router::new() .nest("/api", api_routes) .fallback(fallback); # let _: Router = app; ``` Here requests like `GET /api/not-found` will go into `api_routes` but because it doesn't have a matching route and doesn't have its own fallback it will call the fallback from the outer router, i.e. the `fallback` function. If the nested router has its own fallback then the outer fallback will not be inherited: ```rust use axum::{ routing::get, http::StatusCode, handler::Handler, Json, Router, }; async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } async fn api_fallback() -> (StatusCode, Json) { ( StatusCode::NOT_FOUND, Json(serde_json::json!({ "status": "Not Found" })), ) } let api_routes = Router::new() .route("/users", get(|| async {})) .fallback(api_fallback); let app = Router::new() .nest("/api", api_routes) .fallback(fallback); # let _: Router = app; ``` Here requests like `GET /api/not-found` will go to `api_fallback`. # Nesting routers with state When combining [`Router`]s with this method, each [`Router`] must have the same type of state. 
If your routers have different types you can use [`Router::with_state`] to provide the state and make the types match: ```rust use axum::{ Router, routing::get, extract::State, }; #[derive(Clone)] struct InnerState {} #[derive(Clone)] struct OuterState {} async fn inner_handler(state: State) {} let inner_router = Router::new() .route("/bar", get(inner_handler)) .with_state(InnerState {}); async fn outer_handler(state: State) {} let app = Router::new() .route("/", get(outer_handler)) .nest("/foo", inner_router) .with_state(OuterState {}); # let _: axum::Router = app; ``` Note that the inner router will still inherit the fallback from the outer router. # Panics - If the route overlaps with another route. See [`Router::route`] for more details. - If the route contains a wildcard (`*`). - If `path` is empty. [`OriginalUri`]: crate::extract::OriginalUri [fallbacks]: Router::fallback axum-axum-v0.6.20/axum/src/docs/routing/route.md000066400000000000000000000071621446276533600215750ustar00rootroot00000000000000Add another route to the router. `path` is a string of path segments separated by `/`. Each segment can be either static, a capture, or a wildcard. `method_router` is the [`MethodRouter`] that should receive the request if the path matches `path`. `method_router` will commonly be a handler wrapped in a method router like [`get`](crate::routing::get). See [`handler`](crate::handler) for more details on handlers. # Static paths Examples: - `/` - `/foo` - `/users/123` If the incoming request matches the path exactly the corresponding service will be called. # Captures Paths can contain segments like `/:key` which matches any single segment and will store the value captured at `key`. Examples: - `/:key` - `/users/:id` - `/users/:id/tweets` Captures can be extracted using [`Path`](crate::extract::Path). See its documentation for more details. It is not possible to create segments that only match some types like numbers or regular expression. 
You must handle that manually in your handlers. [`MatchedPath`](crate::extract::MatchedPath) can be used to extract the matched path rather than the actual path. # Wildcards Paths can end in `/*key` which matches all segments and will store the segments captured at `key`. Examples: - `/*key` - `/assets/*path` - `/:id/:repo/*tree` Note that `/*key` doesn't match empty segments. Thus: - `/*key` doesn't match `/` but does match `/a`, `/a/`, etc. - `/x/*key` doesn't match `/x` or `/x/` but does match `/x/a`, `/x/a/`, etc. Wildcard captures can also be extracted using [`Path`](crate::extract::Path). Note that the leading slash is not included, i.e. for the route `/foo/*rest` and the path `/foo/bar/baz` the value of `rest` will be `bar/baz`. # Accepting multiple methods To accept multiple methods for the same route you can add all handlers at the same time: ```rust use axum::{Router, routing::{get, delete}, extract::Path}; let app = Router::new().route( "/", get(get_root).post(post_root).delete(delete_root), ); async fn get_root() {} async fn post_root() {} async fn delete_root() {} # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` Or you can add them one by one: ```rust # use axum::Router; # use axum::routing::{get, post, delete}; # let app = Router::new() .route("/", get(get_root)) .route("/", post(post_root)) .route("/", delete(delete_root)); # # let _: Router = app; # async fn get_root() {} # async fn post_root() {} # async fn delete_root() {} ``` # More examples ```rust use axum::{Router, routing::{get, delete}, extract::Path}; let app = Router::new() .route("/", get(root)) .route("/users", get(list_users).post(create_user)) .route("/users/:id", get(show_user)) .route("/api/:version/users/:id/action", delete(do_users_action)) .route("/assets/*path", get(serve_asset)); async fn root() {} async fn list_users() {} async fn create_user() {} async fn show_user(Path(id): Path) {} async fn 
do_users_action(Path((version, id)): Path<(String, u64)>) {} async fn serve_asset(Path(path): Path) {} # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` # Panics Panics if the route overlaps with another route: ```rust,should_panic use axum::{routing::get, Router}; let app = Router::new() .route("/", get(|| async {})) .route("/", get(|| async {})); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` The static route `/foo` and the dynamic route `/:key` are not considered to overlap and `/foo` will take precedence. Also panics if `path` is empty. axum-axum-v0.6.20/axum/src/docs/routing/route_layer.md000066400000000000000000000021661446276533600227700ustar00rootroot00000000000000Apply a [`tower::Layer`] to the router that will only run if the request matches a route. Note that the middleware is only applied to existing routes. So you have to first add your routes (and / or fallback) and then call `layer` afterwards. Additional routes added after `layer` is called will not have the middleware added. This works similarly to [`Router::layer`] except the middleware will only run if the request matches a route. This is useful for middleware that return early (such as authorization) which might otherwise convert a `404 Not Found` into a `401 Unauthorized`. 
# Example ```rust use axum::{ routing::get, Router, }; use tower_http::validate_request::ValidateRequestHeaderLayer; let app = Router::new() .route("/foo", get(|| async {})) .route_layer(ValidateRequestHeaderLayer::bearer("password")); // `GET /foo` with a valid token will receive `200 OK` // `GET /foo` with a invalid token will receive `401 Unauthorized` // `GET /not-found` with a invalid token will receive `404 Not Found` # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` axum-axum-v0.6.20/axum/src/docs/routing/route_service.md000066400000000000000000000042641446276533600233150ustar00rootroot00000000000000Add another route to the router that calls a [`Service`]. # Example ```rust,no_run use axum::{ Router, body::Body, routing::{any_service, get_service}, http::{Request, StatusCode}, error_handling::HandleErrorLayer, }; use tower_http::services::ServeFile; use http::Response; use std::{convert::Infallible, io}; use tower::service_fn; let app = Router::new() .route( // Any request to `/` goes to a service "/", // Services whose response body is not `axum::body::BoxBody` // can be wrapped in `axum::routing::any_service` (or one of the other routing filters) // to have the response body mapped any_service(service_fn(|_: Request| async { let res = Response::new(Body::from("Hi from `GET /`")); Ok::<_, Infallible>(res) })) ) .route_service( "/foo", // This service's response body is `axum::body::BoxBody` so // it can be routed to directly. 
service_fn(|req: Request| async move { let body = Body::from(format!("Hi from `{} /foo`", req.method())); let body = axum::body::boxed(body); let res = Response::new(body); Ok::<_, Infallible>(res) }) ) .route_service( // GET `/static/Cargo.toml` goes to a service from tower-http "/static/Cargo.toml", ServeFile::new("Cargo.toml"), ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` Routing to arbitrary services in this way has complications for backpressure ([`Service::poll_ready`]). See the [Routing to services and backpressure] module for more details. # Panics Panics for the same reasons as [`Router::route`] or if you attempt to route to a `Router`: ```rust,should_panic use axum::{routing::get, Router}; let app = Router::new().route_service( "/", Router::new().route("/foo", get(|| async {})), ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` Use [`Router::nest`] instead. [Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure axum-axum-v0.6.20/axum/src/docs/routing/with_state.md000066400000000000000000000152311446276533600226060ustar00rootroot00000000000000Provide the state for the router. 
```rust use axum::{Router, routing::get, extract::State}; #[derive(Clone)] struct AppState {} let routes = Router::new() .route("/", get(|State(state): State| async { // use state })) .with_state(AppState {}); # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(routes.into_make_service()) .await; # }; ``` # Returning routers with states from functions When returning `Router`s from functions it is generally recommend not set the state directly: ```rust use axum::{Router, routing::get, extract::State}; #[derive(Clone)] struct AppState {} // Don't call `Router::with_state` here fn routes() -> Router { Router::new() .route("/", get(|_: State| async {})) } // Instead do it before you run the server let routes = routes().with_state(AppState {}); # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(routes.into_make_service()) .await; # }; ``` If you do need to provide the state, and you're _not_ nesting/merging the router into another router, then return `Router` without any type parameters: ```rust # use axum::{Router, routing::get, extract::State}; # #[derive(Clone)] # struct AppState {} # // Don't return `Router` fn routes(state: AppState) -> Router { Router::new() .route("/", get(|_: State| async {})) .with_state(state) } let routes = routes(AppState {}); # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(routes.into_make_service()) .await; # }; ``` This is because we can only call `Router::into_make_service` on `Router<()>`, not `Router`. See below for more details about why that is. Note that the state defaults to `()` so `Router` and `Router<()>` is the same. 
If you are nesting/merging the router it is recommended to use a generic state type on the resulting router: ```rust # use axum::{Router, routing::get, extract::State}; # #[derive(Clone)] # struct AppState {} # fn routes(state: AppState) -> Router { Router::new() .route("/", get(|_: State| async {})) .with_state(state) } let routes = Router::new().nest("/api", routes(AppState {})); # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(routes.into_make_service()) .await; # }; ``` # State is global within the router The state passed to this method will be used for all requests this router receives. That means it is not suitable for holding state derived from a request, such as authorization data extracted in a middleware. Use [`Extension`] instead for such data. # What `S` in `Router` means `Router` means a router that is _missing_ a state of type `S` to be able to handle requests. It does _not_ mean a `Router` that _has_ a state of type `S`. For example: ```rust # use axum::{Router, routing::get, extract::State}; # #[derive(Clone)] # struct AppState {} # // A router that _needs_ an `AppState` to handle requests let router: Router = Router::new() .route("/", get(|_: State| async {})); // Once we call `Router::with_state` the router isn't missing // the state anymore, because we just provided it // // Therefore the router type becomes `Router<()>`, i.e a router // that is not missing any state let router: Router<()> = router.with_state(AppState {}); // Only `Router<()>` has the `into_make_service` method. // // You cannot call `into_make_service` on a `Router` // because it is still missing an `AppState`. # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(router.into_make_service()) .await; # }; ``` Perhaps a little counter intuitively, `Router::with_state` doesn't always return a `Router<()>`. 
Instead you get to pick what the new missing state type is: ```rust # use axum::{Router, routing::get, extract::State}; # #[derive(Clone)] # struct AppState {} # let router: Router = Router::new() .route("/", get(|_: State| async {})); // When we call `with_state` we're able to pick what the next missing state type is. // Here we pick `String`. let string_router: Router = router.with_state(AppState {}); // That allows us to add new routes that uses `String` as the state type let string_router = string_router .route("/needs-string", get(|_: State| async {})); // Provide the `String` and choose `()` as the new missing state. let final_router: Router<()> = string_router.with_state("foo".to_owned()); // Since we have a `Router<()>` we can run it. # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(final_router.into_make_service()) .await; # }; ``` This why this returning `Router` after calling `with_state` doesn't work: ```rust,compile_fail # use axum::{Router, routing::get, extract::State}; # #[derive(Clone)] # struct AppState {} # // This wont work because we're returning a `Router` // i.e. 
we're saying we're still missing an `AppState` fn routes(state: AppState) -> Router { Router::new() .route("/", get(|_: State| async {})) .with_state(state) } let app = routes(AppState {}); // We can only call `Router::into_make_service` on a `Router<()>` // but `app` is a `Router` # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(app.into_make_service()) .await; # }; ``` Instead return `Router<()>` since we have provided all the state needed: ```rust # use axum::{Router, routing::get, extract::State}; # #[derive(Clone)] # struct AppState {} # // We've provided all the state necessary so return `Router<()>` fn routes(state: AppState) -> Router<()> { Router::new() .route("/", get(|_: State| async {})) .with_state(state) } let app = routes(AppState {}); // We can now call `Router::into_make_service` # async { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(app.into_make_service()) .await; # }; ``` # A note about performance If you need a `Router` that implements `Service` but you don't need any state (perhaps you're making a library that uses axum internally) then it is recommended to call this method before you start serving requests: ```rust use axum::{Router, routing::get}; let app = Router::new() .route("/", get(|| async { /* ... */ })) // even though we don't need any state, call `with_state(())` anyway .with_state(()); # let _: Router = app; ``` This is not required but it gives axum a chance to update some internals in the router which may impact performance and reduce allocations. Note that [`Router::into_make_service`] and [`Router::into_make_service_with_connect_info`] do this automatically. 
[`Extension`]: crate::Extension axum-axum-v0.6.20/axum/src/error_handling/000077500000000000000000000000001446276533600204655ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/error_handling/mod.rs000066400000000000000000000162631446276533600216220ustar00rootroot00000000000000#![doc = include_str!("../docs/error_handling.md")] use crate::{ extract::FromRequestParts, http::Request, response::{IntoResponse, Response}, }; use std::{ convert::Infallible, fmt, future::Future, marker::PhantomData, task::{Context, Poll}, }; use tower::ServiceExt; use tower_layer::Layer; use tower_service::Service; /// [`Layer`] that applies [`HandleError`] which is a [`Service`] adapter /// that handles errors by converting them into responses. /// /// See [module docs](self) for more details on axum's error handling model. pub struct HandleErrorLayer { f: F, _extractor: PhantomData T>, } impl HandleErrorLayer { /// Create a new `HandleErrorLayer`. pub fn new(f: F) -> Self { Self { f, _extractor: PhantomData, } } } impl Clone for HandleErrorLayer where F: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), _extractor: PhantomData, } } } impl fmt::Debug for HandleErrorLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HandleErrorLayer") .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } impl Layer for HandleErrorLayer where F: Clone, { type Service = HandleError; fn layer(&self, inner: S) -> Self::Service { HandleError::new(inner, self.f.clone()) } } /// A [`Service`] adapter that handles errors by converting them into responses. /// /// See [module docs](self) for more details on axum's error handling model. pub struct HandleError { inner: S, f: F, _extractor: PhantomData T>, } impl HandleError { /// Create a new `HandleError`. 
pub fn new(inner: S, f: F) -> Self { Self { inner, f, _extractor: PhantomData, } } } impl Clone for HandleError where S: Clone, F: Clone, { fn clone(&self) -> Self { Self { inner: self.inner.clone(), f: self.f.clone(), _extractor: PhantomData, } } } impl fmt::Debug for HandleError where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HandleError") .field("inner", &self.inner) .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } impl Service> for HandleError where S: Service> + Clone + Send + 'static, S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, B: Send + 'static, { type Response = Response; type Error = Infallible; type Future = future::HandleErrorFuture; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { match inner.oneshot(req).await { Ok(res) => Ok(res.into_response()), Err(err) => Ok(f(err).await.into_response()), } }); future::HandleErrorFuture { future } } } #[allow(unused_macros)] macro_rules! impl_service { ( $($ty:ident),* $(,)? 
) => { impl Service> for HandleError where S: Service> + Clone + Send + 'static, S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, $( $ty: FromRequestParts<()> + Send,)* B: Send + 'static, { type Response = Response; type Error = Infallible; type Future = future::HandleErrorFuture; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[allow(non_snake_case)] fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &()).await { Ok(value) => value, Err(rejection) => return Ok(rejection.into_response()), }; )* let req = Request::from_parts(parts, body); match inner.oneshot(req).await { Ok(res) => Ok(res.into_response()), Err(err) => Ok(f($($ty),*, err).await.into_response()), } }); future::HandleErrorFuture { future } } } } } impl_service!(T1); impl_service!(T1, T2); impl_service!(T1, T2, T3); impl_service!(T1, T2, T3, T4); impl_service!(T1, T2, T3, T4, T5); impl_service!(T1, T2, T3, T4, T5, T6); impl_service!(T1, T2, T3, T4, T5, T6, T7); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); pub mod future { //! Future types. 
use crate::response::Response; use pin_project_lite::pin_project; use std::{ convert::Infallible, future::Future, pin::Pin, task::{Context, Poll}, }; pin_project! { /// Response future for [`HandleError`]. pub struct HandleErrorFuture { #[pin] pub(super) future: Pin> + Send + 'static >>, } } impl Future for HandleErrorFuture { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().future.poll(cx) } } } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); assert_sync::>(); } axum-axum-v0.6.20/axum/src/extension.rs000066400000000000000000000100761446276533600200560ustar00rootroot00000000000000use crate::{extract::rejection::*, response::IntoResponseParts}; use async_trait::async_trait; use axum_core::{ extract::FromRequestParts, response::{IntoResponse, Response, ResponseParts}, }; use http::{request::Parts, Request}; use std::{ convert::Infallible, task::{Context, Poll}, }; use tower_service::Service; /// Extractor and response for extensions. /// /// # As extractor /// /// This is commonly used to share state across handlers. /// /// ```rust,no_run /// use axum::{ /// Router, /// Extension, /// routing::get, /// }; /// use std::sync::Arc; /// /// // Some shared state used throughout our application /// struct State { /// // ... /// } /// /// async fn handler(state: Extension>) { /// // ... /// } /// /// let state = Arc::new(State { /* ... */ }); /// /// let app = Router::new().route("/", get(handler)) /// // Add middleware that inserts the state into all incoming request's /// // extensions. /// .layer(Extension(state)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// If the extension is missing it will reject the request with a `500 Internal /// Server Error` response. /// /// # As response /// /// Response extensions can be used to share state with middleware. 
/// /// ```rust /// use axum::{ /// Extension, /// response::IntoResponse, /// }; /// /// async fn handler() -> (Extension, &'static str) { /// ( /// Extension(Foo("foo")), /// "Hello, World!" /// ) /// } /// /// #[derive(Clone)] /// struct Foo(&'static str); /// ``` #[derive(Debug, Clone, Copy, Default)] #[must_use] pub struct Extension(pub T); #[async_trait] impl FromRequestParts for Extension where T: Clone + Send + Sync + 'static, S: Send + Sync, { type Rejection = ExtensionRejection; async fn from_request_parts(req: &mut Parts, _state: &S) -> Result { let value = req .extensions .get::() .ok_or_else(|| { MissingExtension::from_err(format!( "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.", std::any::type_name::() )) }) .map(|x| x.clone())?; Ok(Extension(value)) } } axum_core::__impl_deref!(Extension); impl IntoResponseParts for Extension where T: Send + Sync + 'static, { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { res.extensions_mut().insert(self.0); Ok(res) } } impl IntoResponse for Extension where T: Send + Sync + 'static, { fn into_response(self) -> Response { let mut res = ().into_response(); res.extensions_mut().insert(self.0); res } } impl tower_layer::Layer for Extension where T: Clone + Send + Sync + 'static, { type Service = AddExtension; fn layer(&self, inner: S) -> Self::Service { AddExtension { inner, value: self.0.clone(), } } } /// Middleware for adding some shareable value to [request extensions]. /// /// See [Sharing state with handlers](index.html#sharing-state-with-handlers) /// for more details. 
/// /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[derive(Clone, Copy, Debug)] pub struct AddExtension { pub(crate) inner: S, pub(crate) value: T, } impl Service> for AddExtension where S: Service>, T: Clone + Send + Sync + 'static, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { req.extensions_mut().insert(self.value.clone()); self.inner.call(req) } } axum-axum-v0.6.20/axum/src/extract/000077500000000000000000000000001446276533600171425ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/extract/connect_info.rs000066400000000000000000000236051446276533600221620ustar00rootroot00000000000000//! Extractor for getting connection information from a client. //! //! See [`Router::into_make_service_with_connect_info`] for more details. //! //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info use super::{Extension, FromRequestParts}; use crate::middleware::AddExtension; use async_trait::async_trait; use http::request::Parts; use hyper::server::conn::AddrStream; use std::{ convert::Infallible, fmt, future::ready, marker::PhantomData, net::SocketAddr, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// A [`MakeService`] created from a router. /// /// See [`Router::into_make_service_with_connect_info`] for more details. 
/// /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub struct IntoMakeServiceWithConnectInfo { svc: S, _connect_info: PhantomData C>, } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); } impl IntoMakeServiceWithConnectInfo { pub(crate) fn new(svc: S) -> Self { Self { svc, _connect_info: PhantomData, } } } impl fmt::Debug for IntoMakeServiceWithConnectInfo where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IntoMakeServiceWithConnectInfo") .field("svc", &self.svc) .finish() } } impl Clone for IntoMakeServiceWithConnectInfo where S: Clone, { fn clone(&self) -> Self { Self { svc: self.svc.clone(), _connect_info: PhantomData, } } } /// Trait that connected IO resources implement and use to produce information /// about the connection. /// /// The goal for this trait is to allow users to implement custom IO types that /// can still provide the same connection metadata. /// /// See [`Router::into_make_service_with_connect_info`] for more details. /// /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone + Send + Sync + 'static { /// Create type holding information about the connection. fn connect_info(target: T) -> Self; } impl Connected<&AddrStream> for SocketAddr { fn connect_info(target: &AddrStream) -> Self { target.remote_addr() } } impl Service for IntoMakeServiceWithConnectInfo where S: Clone, C: Connected, { type Response = AddExtension>; type Error = Infallible; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, target: T) -> Self::Future { let connect_info = ConnectInfo(C::connect_info(target)); let svc = Extension(connect_info).layer(self.svc.clone()); ResponseFuture::new(ready(Ok(svc))) } } opaque_future! 
{ /// Response future for [`IntoMakeServiceWithConnectInfo`]. pub type ResponseFuture = std::future::Ready>, Infallible>>; } /// Extractor for getting connection information produced by a [`Connected`]. /// /// Note this extractor requires you to use /// [`Router::into_make_service_with_connect_info`] to run your app /// otherwise it will fail at runtime. /// /// See [`Router::into_make_service_with_connect_info`] for more details. /// /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[derive(Clone, Copy, Debug)] pub struct ConnectInfo(pub T); #[async_trait] impl FromRequestParts for ConnectInfo where S: Send + Sync, T: Clone + Send + Sync + 'static, { type Rejection = as FromRequestParts>::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { match Extension::::from_request_parts(parts, state).await { Ok(Extension(connect_info)) => Ok(connect_info), Err(err) => match parts.extensions.get::>() { Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())), None => Err(err), }, } } } axum_core::__impl_deref!(ConnectInfo); /// Middleware used to mock [`ConnectInfo`] during tests. /// /// If you're accidentally using [`MockConnectInfo`] and /// [`Router::into_make_service_with_connect_info`] at the same time then /// [`Router::into_make_service_with_connect_info`] takes precedence. 
/// /// # Example /// /// ``` /// use axum::{ /// Router, /// extract::connect_info::{MockConnectInfo, ConnectInfo}, /// body::Body, /// routing::get, /// http::{Request, StatusCode}, /// }; /// use std::net::SocketAddr; /// use tower::ServiceExt; /// /// async fn handler(ConnectInfo(addr): ConnectInfo) {} /// /// // this router you can run with `app.into_make_service_with_connect_info::()` /// fn app() -> Router { /// Router::new().route("/", get(handler)) /// } /// /// // use this router for tests /// fn test_app() -> Router { /// app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))) /// } /// /// // #[tokio::test] /// async fn some_test() { /// let app = test_app(); /// /// let request = Request::new(Body::empty()); /// let response = app.oneshot(request).await.unwrap(); /// assert_eq!(response.status(), StatusCode::OK); /// } /// # /// # #[tokio::main] /// # async fn main() { /// # some_test().await; /// # } /// ``` /// /// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info #[derive(Clone, Copy, Debug)] pub struct MockConnectInfo(pub T); impl Layer for MockConnectInfo where T: Clone + Send + Sync + 'static, { type Service = as Layer>::Service; fn layer(&self, inner: S) -> Self::Service { Extension(self.clone()).layer(inner) } } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router, Server}; use std::net::{SocketAddr, TcpListener}; #[crate::test] async fn socket_addr() { async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("{addr}") } let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); let (tx, rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { let app = Router::new().route("/", get(handler)); let server = Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()); tx.send(()).unwrap(); server.await.expect("server error"); }); rx.await.unwrap(); let 
client = reqwest::Client::new(); let res = client.get(format!("http://{addr}")).send().await.unwrap(); let body = res.text().await.unwrap(); assert!(body.starts_with("127.0.0.1:")); } #[crate::test] async fn custom() { #[derive(Clone, Debug)] struct MyConnectInfo { value: &'static str, } impl Connected<&AddrStream> for MyConnectInfo { fn connect_info(_target: &AddrStream) -> Self { Self { value: "it worked!", } } } async fn handler(ConnectInfo(addr): ConnectInfo) -> &'static str { addr.value } let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); let (tx, rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { let app = Router::new().route("/", get(handler)); let server = Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()); tx.send(()).unwrap(); server.await.expect("server error"); }); rx.await.unwrap(); let client = reqwest::Client::new(); let res = client.get(format!("http://{addr}")).send().await.unwrap(); let body = res.text().await.unwrap(); assert_eq!(body, "it worked!"); } #[crate::test] async fn mock_connect_info() { async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("{addr}") } let app = Router::new() .route("/", get(handler)) .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); let client = TestClient::new(app); let res = client.get("/").send().await; let body = res.text().await; assert!(body.starts_with("0.0.0.0:1337")); } #[crate::test] async fn both_mock_and_real_connect_info() { async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("{addr}") } let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { let app = Router::new() .route("/", get(handler)) .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); let server = Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()); server.await.expect("server 
error"); }); let client = reqwest::Client::new(); let res = client.get(format!("http://{addr}")).send().await.unwrap(); let body = res.text().await.unwrap(); assert!(body.starts_with("127.0.0.1:")); } } axum-axum-v0.6.20/axum/src/extract/host.rs000066400000000000000000000123421446276533600204670ustar00rootroot00000000000000use super::{ rejection::{FailedToResolveHost, HostRejection}, FromRequestParts, }; use async_trait::async_trait; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, }; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; /// Extractor that resolves the hostname of the request. /// /// Hostname is resolved through the following, in order: /// - `Forwarded` header /// - `X-Forwarded-Host` header /// - `Host` header /// - request target / URI /// /// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make /// sure to validate them to avoid security issues. #[derive(Debug, Clone)] pub struct Host(pub String); #[async_trait] impl FromRequestParts for Host where S: Send + Sync, { type Rejection = HostRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(host) = parse_forwarded(&parts.headers) { return Ok(Host(host.to_owned())); } if let Some(host) = parts .headers .get(X_FORWARDED_HOST_HEADER_KEY) .and_then(|host| host.to_str().ok()) { return Ok(Host(host.to_owned())); } if let Some(host) = parts .headers .get(http::header::HOST) .and_then(|host| host.to_str().ok()) { return Ok(Host(host.to_owned())); } if let Some(host) = parts.uri.host() { return Ok(Host(host.to_owned())); } Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) } } #[allow(warnings)] fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { // if there are multiple `Forwarded` `HeaderMap::get` will return the first one let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?; // get the first set of values let first_value = forwarded_values.split(',').nth(0)?; // find the 
value of the `host` field first_value.split(';').find_map(|pair| { let (key, value) = pair.split_once('=')?; key.trim() .eq_ignore_ascii_case("host") .then(|| value.trim().trim_matches('"')) }) } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router}; use http::header::HeaderName; fn test_client() -> TestClient { async fn host_as_body(Host(host): Host) -> String { host } TestClient::new(Router::new().route("/", get(host_as_body))) } #[crate::test] async fn host_header() { let original_host = "some-domain:123"; let host = test_client() .get("/") .header(http::header::HOST, original_host) .send() .await .text() .await; assert_eq!(host, original_host); } #[crate::test] async fn x_forwarded_host_header() { let original_host = "some-domain:456"; let host = test_client() .get("/") .header(X_FORWARDED_HOST_HEADER_KEY, original_host) .send() .await .text() .await; assert_eq!(host, original_host); } #[crate::test] async fn x_forwarded_host_precedence_over_host_header() { let x_forwarded_host_header = "some-domain:456"; let host_header = "some-domain:123"; let host = test_client() .get("/") .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header) .header(http::header::HOST, host_header) .send() .await .text() .await; assert_eq!(host, x_forwarded_host_header); } #[crate::test] async fn uri_host() { let host = test_client().get("/").send().await.text().await; assert!(host.contains("127.0.0.1")); } #[test] fn forwarded_parsing() { // the basic case let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); // is case insensitive let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); // ipv6 let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]); let value = 
parse_forwarded(&headers).unwrap(); assert_eq!(value, "[2001:db8:cafe::17]:4711"); // multiple values in one header let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); // multiple header values let headers = header_map(&[ (FORWARDED, "host=192.0.2.60"), (FORWARDED, "host=127.0.0.1"), ]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); } fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap { let mut headers = HeaderMap::new(); for (key, value) in values { headers.append(key, value.parse().unwrap()); } headers } } axum-axum-v0.6.20/axum/src/extract/matched_path.rs000066400000000000000000000271401446276533600221350ustar00rootroot00000000000000use super::{rejection::*, FromRequestParts}; use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; use async_trait::async_trait; use http::request::Parts; use std::{collections::HashMap, sync::Arc}; /// Access the path in the router that matches the request. /// /// ``` /// use axum::{ /// Router, /// extract::MatchedPath, /// routing::get, /// }; /// /// let app = Router::new().route( /// "/users/:id", /// get(|path: MatchedPath| async move { /// let path = path.as_str(); /// // `path` will be "/users/:id" /// }) /// ); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// # Accessing `MatchedPath` via extensions /// /// `MatchedPath` can also be accessed from middleware via request extensions. /// /// This is useful for example with [`Trace`](tower_http::trace::Trace) to /// create a span that contains the matched path: /// /// ``` /// use axum::{ /// Router, /// extract::MatchedPath, /// http::Request, /// routing::get, /// }; /// use tower_http::trace::TraceLayer; /// /// let app = Router::new() /// .route("/users/:id", get(|| async { /* ... 
*/ })) /// .layer( /// TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { /// let path = if let Some(path) = req.extensions().get::() { /// path.as_str() /// } else { /// req.uri().path() /// }; /// tracing::info_span!("http-request", %path) /// }), /// ); /// # let _: Router = app; /// ``` /// /// # Matched path in nested routers /// /// Because of how [nesting] works `MatchedPath` isn't accessible in middleware on nested routes: /// /// ``` /// use axum::{ /// Router, /// RequestExt, /// routing::get, /// extract::{MatchedPath, rejection::MatchedPathRejection}, /// middleware::map_request, /// http::Request, /// body::Body, /// }; /// /// async fn access_matched_path(mut request: Request) -> Request { /// // if `/foo/bar` is called this will be `Err(_)` since that matches /// // a nested route /// let matched_path: Result = /// request.extract_parts::().await; /// /// request /// } /// /// // `MatchedPath` is always accessible on handlers added via `Router::route` /// async fn handler(matched_path: MatchedPath) {} /// /// let app = Router::new() /// .nest( /// "/foo", /// Router::new().route("/bar", get(handler)), /// ) /// .layer(map_request(access_matched_path)); /// # let _: Router = app; /// ``` /// /// [nesting]: crate::Router::nest #[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))] #[derive(Clone, Debug)] pub struct MatchedPath(pub(crate) Arc); impl MatchedPath { /// Returns a `str` representation of the path. pub fn as_str(&self) -> &str { &self.0 } } #[async_trait] impl FromRequestParts for MatchedPath where S: Send + Sync, { type Rejection = MatchedPathRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let matched_path = parts .extensions .get::() .ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))? 
.clone(); Ok(matched_path) } } #[derive(Clone, Debug)] struct MatchedNestedPath(Arc); pub(crate) fn set_matched_path_for_request( id: RouteId, route_id_to_path: &HashMap>, extensions: &mut http::Extensions, ) { let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) { matched_path } else { #[cfg(debug_assertions)] panic!("should always have a matched path for a route id"); #[cfg(not(debug_assertions))] return; }; let matched_path = append_nested_matched_path(matched_path, extensions); if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) { extensions.insert(MatchedNestedPath(matched_path)); debug_assert!(extensions.remove::().is_none()); } else { extensions.insert(MatchedPath(matched_path)); extensions.remove::(); } } // a previous `MatchedPath` might exist if we're inside a nested Router fn append_nested_matched_path(matched_path: &Arc, extensions: &http::Extensions) -> Arc { if let Some(previous) = extensions .get::() .map(|matched_path| matched_path.as_str()) .or_else(|| Some(&extensions.get::()?.0)) { let previous = previous .strip_suffix(NEST_TAIL_PARAM_CAPTURE) .unwrap_or(previous); let matched_path = format!("{previous}{matched_path}"); matched_path.into() } else { Arc::clone(matched_path) } } #[cfg(test)] mod tests { use super::*; use crate::{ body::Body, handler::HandlerWithoutStateExt, middleware::map_request, routing::{any, get}, test_helpers::*, Router, }; use http::{Request, StatusCode}; #[crate::test] async fn extracting_on_handler() { let app = Router::new().route( "/:a", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.text().await, "/:a"); } #[crate::test] async fn extracting_on_handler_in_nested_router() { let app = Router::new().nest( "/:a", Router::new().route( "/:b", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ), ); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; 
assert_eq!(res.text().await, "/:a/:b"); } #[crate::test] async fn extracting_on_handler_in_deeply_nested_router() { let app = Router::new().nest( "/:a", Router::new().nest( "/:b", Router::new().route( "/:c", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ), ), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.text().await, "/:a/:b/:c"); } #[crate::test] async fn cannot_extract_nested_matched_path_in_middleware() { async fn extract_matched_path( matched_path: Option, req: Request, ) -> Request { assert!(matched_path.is_none()); req } let app = Router::new() .nest_service("/:a", Router::new().route("/:b", get(|| async move {}))) .layer(map_request(extract_matched_path)); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn can_extract_nested_matched_path_in_middleware_using_nest() { async fn extract_matched_path( matched_path: Option, req: Request, ) -> Request { assert_eq!(matched_path.unwrap().as_str(), "/:a/:b"); req } let app = Router::new() .nest("/:a", Router::new().route("/:b", get(|| async move {}))) .layer(map_request(extract_matched_path)); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn cannot_extract_nested_matched_path_in_middleware_via_extension() { async fn assert_no_matched_path(req: Request) -> Request { assert!(req.extensions().get::().is_none()); req } let app = Router::new() .nest_service("/:a", Router::new().route("/:b", get(|| async move {}))) .layer(map_request(assert_no_matched_path)); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn can_extract_nested_matched_path_in_middleware_via_extension_using_nest() { async fn assert_matched_path(req: Request) -> Request { 
assert!(req.extensions().get::().is_some()); req } let app = Router::new() .nest("/:a", Router::new().route("/:b", get(|| async move {}))) .layer(map_request(assert_matched_path)); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn can_extract_nested_matched_path_in_middleware_on_nested_router() { async fn extract_matched_path(matched_path: MatchedPath, req: Request) -> Request { assert_eq!(matched_path.as_str(), "/:a/:b"); req } let app = Router::new().nest( "/:a", Router::new() .route("/:b", get(|| async move {})) .layer(map_request(extract_matched_path)), ); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() { async fn extract_matched_path(req: Request) -> Request { let matched_path = req.extensions().get::().unwrap(); assert_eq!(matched_path.as_str(), "/:a/:b"); req } let app = Router::new().nest( "/:a", Router::new() .route("/:b", get(|| async move {})) .layer(map_request(extract_matched_path)), ); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn extracting_on_nested_handler() { async fn handler(path: Option) { assert!(path.is_none()); } let app = Router::new().nest_service("/:a", handler.into_service()); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } // https://github.com/tokio-rs/axum/issues/1579 #[crate::test] async fn doesnt_panic_if_router_called_from_wildcard_route() { use tower::ServiceExt; let app = Router::new().route( "/*path", any(|req: Request| { Router::new() .nest("/", Router::new().route("/foo", get(|| async {}))) .oneshot(req) }), ); let client = TestClient::new(app); let res = 
client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn cant_extract_in_fallback() { async fn handler(path: Option, req: Request) { assert!(path.is_none()); assert!(req.extensions().get::().is_none()); } let app = Router::new().fallback(handler); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } } axum-axum-v0.6.20/axum/src/extract/mod.rs000066400000000000000000000046601446276533600202750ustar00rootroot00000000000000#![doc = include_str!("../docs/extract.md")] use http::header::{self, HeaderMap}; #[cfg(feature = "tokio")] pub mod connect_info; pub mod path; pub mod rejection; #[cfg(feature = "ws")] pub mod ws; mod host; mod raw_form; mod raw_query; mod request_parts; mod state; #[doc(inline)] pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts}; #[cfg(feature = "macros")] pub use axum_macros::{FromRef, FromRequest, FromRequestParts}; #[doc(inline)] #[allow(deprecated)] pub use self::{ host::Host, path::{Path, RawPathParams}, raw_form::RawForm, raw_query::RawQuery, request_parts::{BodyStream, RawBody}, state::State, }; #[doc(inline)] #[cfg(feature = "tokio")] pub use self::connect_info::ConnectInfo; #[doc(no_inline)] #[cfg(feature = "json")] pub use crate::Json; #[doc(no_inline)] pub use crate::Extension; #[cfg(feature = "form")] #[doc(no_inline)] pub use crate::form::Form; #[cfg(feature = "matched-path")] pub(crate) mod matched_path; #[cfg(feature = "matched-path")] #[doc(inline)] pub use self::matched_path::MatchedPath; #[cfg(feature = "multipart")] pub mod multipart; #[cfg(feature = "multipart")] #[doc(inline)] pub use self::multipart::Multipart; #[cfg(feature = "query")] mod query; #[cfg(feature = "query")] #[doc(inline)] pub use self::query::Query; #[cfg(feature = "original-uri")] #[doc(inline)] pub use self::request_parts::OriginalUri; #[cfg(feature = "ws")] #[doc(inline)] pub use self::ws::WebSocketUpgrade; 
#[cfg(feature = "headers")] #[doc(no_inline)] pub use crate::TypedHeader; // this is duplicated in `axum-extra/src/extract/form.rs` pub(super) fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool { let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { content_type } else { return false; }; let content_type = if let Ok(content_type) = content_type.to_str() { content_type } else { return false; }; content_type.starts_with(expected_content_type.as_ref()) } #[cfg(test)] mod tests { use crate::{routing::get, test_helpers::*, Router}; #[crate::test] async fn consume_body() { let app = Router::new().route("/", get(|body: String| async { body })); let client = TestClient::new(app); let res = client.get("/").body("foo").send().await; let body = res.text().await; assert_eq!(body, "foo"); } } axum-axum-v0.6.20/axum/src/extract/multipart.rs000066400000000000000000000272141446276533600215370ustar00rootroot00000000000000//! Extractor that parses `multipart/form-data` requests commonly used with file uploads. //! //! See [`Multipart`] for more details. use super::{BodyStream, FromRequest}; use crate::body::{Bytes, HttpBody}; use crate::BoxError; use async_trait::async_trait; use axum_core::__composite_rejection as composite_rejection; use axum_core::__define_rejection as define_rejection; use axum_core::response::{IntoResponse, Response}; use axum_core::RequestExt; use futures_util::stream::Stream; use http::header::{HeaderMap, CONTENT_TYPE}; use http::{Request, StatusCode}; use std::error::Error; use std::{ fmt, pin::Pin, task::{Context, Poll}, }; /// Extractor that parses `multipart/form-data` requests (commonly used with file uploads). /// /// โš ๏ธ Since extracting multipart form data from the request requires consuming the body, the /// `Multipart` extractor must be *last* if there are multiple extractors in a handler. 
/// See ["the order of extractors"][order-of-extractors] /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::Multipart, /// routing::post, /// Router, /// }; /// use futures_util::stream::StreamExt; /// /// async fn upload(mut multipart: Multipart) { /// while let Some(mut field) = multipart.next_field().await.unwrap() { /// let name = field.name().unwrap().to_string(); /// let data = field.bytes().await.unwrap(); /// /// println!("Length of `{}` is {} bytes", name, data.len()); /// } /// } /// /// let app = Router::new().route("/upload", post(upload)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] #[derive(Debug)] pub struct Multipart { inner: multer::Multipart<'static>, } #[async_trait] impl FromRequest for Multipart where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, S: Send + Sync, { type Rejection = MultipartRejection; async fn from_request(req: Request, state: &S) -> Result { let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; let stream_result = match req.with_limited_body() { Ok(limited) => BodyStream::from_request(limited, state).await, Err(unlimited) => BodyStream::from_request(unlimited, state).await, }; let stream = stream_result.unwrap_or_else(|err| match err {}); let multipart = multer::Multipart::new(stream, boundary); Ok(Self { inner: multipart }) } } impl Multipart { /// Yields the next [`Field`] if available. pub async fn next_field(&mut self) -> Result>, MultipartError> { let field = self .inner .next_field() .await .map_err(MultipartError::from_multer)?; if let Some(field) = field { Ok(Some(Field { inner: field, _multipart: self, })) } else { Ok(None) } } } /// A single field in a multipart stream. 
#[derive(Debug)] pub struct Field<'a> { inner: multer::Field<'static>, // multer requires there to only be one live `multer::Field` at any point. This enforces that // statically, which multer does not do, it returns an error instead. _multipart: &'a mut Multipart, } impl<'a> Stream for Field<'a> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner) .poll_next(cx) .map_err(MultipartError::from_multer) } } impl<'a> Field<'a> { /// The field name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. pub fn name(&self) -> Option<&str> { self.inner.name() } /// The file name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. pub fn file_name(&self) -> Option<&str> { self.inner.file_name() } /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field. pub fn content_type(&self) -> Option<&str> { self.inner.content_type().map(|m| m.as_ref()) } /// Get a map of headers as [`HeaderMap`]. pub fn headers(&self) -> &HeaderMap { self.inner.headers() } /// Get the full data of the field as [`Bytes`]. pub async fn bytes(self) -> Result { self.inner .bytes() .await .map_err(MultipartError::from_multer) } /// Get the full field data as text. pub async fn text(self) -> Result { self.inner.text().await.map_err(MultipartError::from_multer) } /// Stream a chunk of the field data. /// /// When the field data has been exhausted, this will return [`None`]. /// /// Note this does the same thing as `Field`'s [`Stream`] implementation. 
/// /// # Example /// /// ``` /// use axum::{ /// extract::Multipart, /// routing::post, /// response::IntoResponse, /// http::StatusCode, /// Router, /// }; /// /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> { /// while let Some(mut field) = multipart /// .next_field() /// .await /// .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))? /// { /// while let Some(chunk) = field /// .chunk() /// .await /// .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))? /// { /// println!("received {} bytes", chunk.len()); /// } /// } /// /// Ok(()) /// } /// /// let app = Router::new().route("/upload", post(upload)); /// # let _: Router = app; /// ``` pub async fn chunk(&mut self) -> Result, MultipartError> { self.inner .chunk() .await .map_err(MultipartError::from_multer) } } /// Errors associated with parsing `multipart/form-data` requests. #[derive(Debug)] pub struct MultipartError { source: multer::Error, } impl MultipartError { fn from_multer(multer: multer::Error) -> Self { Self { source: multer } } /// Get the response body text used for this rejection. pub fn body_text(&self) -> String { self.source.to_string() } /// Get the status code used for this rejection. pub fn status(&self) -> http::StatusCode { status_code_from_multer_error(&self.source) } } fn status_code_from_multer_error(err: &multer::Error) -> StatusCode { match err { multer::Error::UnknownField { .. } | multer::Error::IncompleteFieldData { .. } | multer::Error::IncompleteHeaders | multer::Error::ReadHeaderFailed(..) | multer::Error::DecodeHeaderName { .. } | multer::Error::DecodeContentType(..) | multer::Error::NoBoundary | multer::Error::DecodeHeaderValue { .. } | multer::Error::NoMultipart | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST, multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. 
} => { StatusCode::PAYLOAD_TOO_LARGE } multer::Error::StreamReadFailed(err) => { if let Some(err) = err.downcast_ref::() { return status_code_from_multer_error(err); } if err .downcast_ref::() .and_then(|err| err.source()) .and_then(|err| err.downcast_ref::()) .is_some() { return StatusCode::PAYLOAD_TOO_LARGE; } StatusCode::INTERNAL_SERVER_ERROR } _ => StatusCode::INTERNAL_SERVER_ERROR, } } impl fmt::Display for MultipartError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Error parsing `multipart/form-data` request") } } impl std::error::Error for MultipartError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.source) } } impl IntoResponse for MultipartError { fn into_response(self) -> Response { axum_core::__log_rejection!( rejection_type = Self, body_text = self.body_text(), status = self.status(), ); (self.status(), self.body_text()).into_response() } } fn parse_boundary(headers: &HeaderMap) -> Option { let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?; multer::parse_boundary(content_type).ok() } composite_rejection! { /// Rejection used for [`Multipart`]. /// /// Contains one variant for each way the [`Multipart`] extractor can fail. pub enum MultipartRejection { InvalidBoundary, } } define_rejection! { #[status = BAD_REQUEST] #[body = "Invalid `boundary` for `multipart/form-data` request"] /// Rejection type used if the `boundary` in a `multipart/form-data` is /// missing or invalid. 
pub struct InvalidBoundary; } #[cfg(test)] mod tests { use axum_core::extract::DefaultBodyLimit; use super::*; use crate::{body::Body, response::IntoResponse, routing::post, test_helpers::*, Router}; #[crate::test] async fn content_type_with_encoding() { const BYTES: &[u8] = "๐Ÿฆ€".as_bytes(); const FILE_NAME: &str = "index.html"; const CONTENT_TYPE: &str = "text/html; charset=utf-8"; async fn handle(mut multipart: Multipart) -> impl IntoResponse { let field = multipart.next_field().await.unwrap().unwrap(); assert_eq!(field.file_name().unwrap(), FILE_NAME); assert_eq!(field.content_type().unwrap(), CONTENT_TYPE); assert_eq!(field.bytes().await.unwrap(), BYTES); assert!(multipart.next_field().await.unwrap().is_none()); } let app = Router::new().route("/", post(handle)); let client = TestClient::new(app); let form = reqwest::multipart::Form::new().part( "file", reqwest::multipart::Part::bytes(BYTES) .file_name(FILE_NAME) .mime_str(CONTENT_TYPE) .unwrap(), ); client.post("/").multipart(form).send().await; } // No need for this to be a #[test], we just want to make sure it compiles fn _multipart_from_request_limited() { async fn handler(_: Multipart) {} let _app: Router<(), http_body::Limited> = Router::new().route("/", post(handler)); } #[crate::test] async fn body_too_large() { const BYTES: &[u8] = "๐Ÿฆ€".as_bytes(); async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> { while let Some(field) = multipart.next_field().await? 
{ field.bytes().await?; } Ok(()) } let app = Router::new() .route("/", post(handle)) .layer(DefaultBodyLimit::max(BYTES.len() - 1)); let client = TestClient::new(app); let form = reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES)); let res = client.post("/").multipart(form).send().await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } } axum-axum-v0.6.20/axum/src/extract/path/000077500000000000000000000000001446276533600200765ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/extract/path/de.rs000066400000000000000000000650641446276533600210470ustar00rootroot00000000000000use super::{ErrorKind, PathDeserializationError}; use crate::util::PercentDecodedStr; use serde::{ de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, forward_to_deserialize_any, Deserializer, }; use std::{any::type_name, sync::Arc}; macro_rules! unsupported_type { ($trait_fn:ident) => { fn $trait_fn(self, _: V) -> Result where V: Visitor<'de>, { Err(PathDeserializationError::unsupported_type(type_name::< V::Value, >())) } }; } macro_rules! 
parse_single_value { ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { fn $trait_fn(self, visitor: V) -> Result where V: Visitor<'de>, { if self.url_params.len() != 1 { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(1)); } let value = self.url_params[0].1.parse().map_err(|_| { PathDeserializationError::new(ErrorKind::ParseError { value: self.url_params[0].1.as_str().to_owned(), expected_type: $ty, }) })?; visitor.$visit_fn(value) } }; } pub(crate) struct PathDeserializer<'de> { url_params: &'de [(Arc, PercentDecodedStr)], } impl<'de> PathDeserializer<'de> { #[inline] pub(crate) fn new(url_params: &'de [(Arc, PercentDecodedStr)]) -> Self { PathDeserializer { url_params } } } impl<'de> Deserializer<'de> for PathDeserializer<'de> { type Error = PathDeserializationError; unsupported_type!(deserialize_bytes); unsupported_type!(deserialize_option); unsupported_type!(deserialize_identifier); unsupported_type!(deserialize_ignored_any); parse_single_value!(deserialize_bool, visit_bool, "bool"); parse_single_value!(deserialize_i8, visit_i8, "i8"); parse_single_value!(deserialize_i16, visit_i16, "i16"); parse_single_value!(deserialize_i32, visit_i32, "i32"); parse_single_value!(deserialize_i64, visit_i64, "i64"); parse_single_value!(deserialize_i128, visit_i128, "i128"); parse_single_value!(deserialize_u8, visit_u8, "u8"); parse_single_value!(deserialize_u16, visit_u16, "u16"); parse_single_value!(deserialize_u32, visit_u32, "u32"); parse_single_value!(deserialize_u64, visit_u64, "u64"); parse_single_value!(deserialize_u128, visit_u128, "u128"); parse_single_value!(deserialize_f32, visit_f32, "f32"); parse_single_value!(deserialize_f64, visit_f64, "f64"); parse_single_value!(deserialize_string, visit_string, "String"); parse_single_value!(deserialize_byte_buf, visit_string, "String"); parse_single_value!(deserialize_char, visit_char, "char"); fn deserialize_any(self, v: V) -> Result where V: Visitor<'de>, { 
self.deserialize_str(v) } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, { if self.url_params.len() != 1 { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(1)); } visitor.visit_borrowed_str(&self.url_params[0].1) } fn deserialize_unit(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_unit() } fn deserialize_unit_struct( self, _name: &'static str, visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_unit() } fn deserialize_newtype_struct( self, _name: &'static str, visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_newtype_struct(self) } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(SeqDeserializer { params: self.url_params, idx: 0, }) } fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de>, { if self.url_params.len() < len { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(len)); } visitor.visit_seq(SeqDeserializer { params: self.url_params, idx: 0, }) } fn deserialize_tuple_struct( self, _name: &'static str, len: usize, visitor: V, ) -> Result where V: Visitor<'de>, { if self.url_params.len() < len { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(len)); } visitor.visit_seq(SeqDeserializer { params: self.url_params, idx: 0, }) } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_map(MapDeserializer { params: self.url_params, value: None, key: None, }) } fn deserialize_struct( self, _name: &'static str, _fields: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { self.deserialize_map(visitor) } fn deserialize_enum( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { if self.url_params.len() != 1 { return 
Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(1)); } visitor.visit_enum(EnumDeserializer { value: self.url_params[0].1.clone().into_inner(), }) } } struct MapDeserializer<'de> { params: &'de [(Arc, PercentDecodedStr)], key: Option, value: Option<&'de PercentDecodedStr>, } impl<'de> MapAccess<'de> for MapDeserializer<'de> { type Error = PathDeserializationError; fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> where K: DeserializeSeed<'de>, { match self.params.split_first() { Some(((key, value), tail)) => { self.value = Some(value); self.params = tail; self.key = Some(KeyOrIdx::Key(key.clone())); seed.deserialize(KeyDeserializer { key: Arc::clone(key), }) .map(Some) } None => Ok(None), } } fn next_value_seed(&mut self, seed: V) -> Result where V: DeserializeSeed<'de>, { match self.value.take() { Some(value) => seed.deserialize(ValueDeserializer { key: self.key.take(), value, }), None => Err(PathDeserializationError::custom("value is missing")), } } } struct KeyDeserializer { key: Arc, } macro_rules! parse_key { ($trait_fn:ident) => { fn $trait_fn(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_str(&self.key) } }; } impl<'de> Deserializer<'de> for KeyDeserializer { type Error = PathDeserializationError; parse_key!(deserialize_identifier); parse_key!(deserialize_str); parse_key!(deserialize_string); fn deserialize_any(self, _visitor: V) -> Result where V: Visitor<'de>, { Err(PathDeserializationError::custom("Unexpected key type")) } forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char bytes byte_buf option unit unit_struct seq tuple tuple_struct map newtype_struct struct enum ignored_any } } macro_rules! 
parse_value { ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { fn $trait_fn(mut self, visitor: V) -> Result where V: Visitor<'de>, { let v = self.value.parse().map_err(|_| { if let Some(key) = self.key.take() { let kind = match key { KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey { key: key.to_string(), value: self.value.as_str().to_owned(), expected_type: $ty, }, KeyOrIdx::Idx { idx: index, key: _ } => ErrorKind::ParseErrorAtIndex { index, value: self.value.as_str().to_owned(), expected_type: $ty, }, }; PathDeserializationError::new(kind) } else { PathDeserializationError::new(ErrorKind::ParseError { value: self.value.as_str().to_owned(), expected_type: $ty, }) } })?; visitor.$visit_fn(v) } }; } #[derive(Debug)] struct ValueDeserializer<'de> { key: Option, value: &'de PercentDecodedStr, } impl<'de> Deserializer<'de> for ValueDeserializer<'de> { type Error = PathDeserializationError; unsupported_type!(deserialize_map); unsupported_type!(deserialize_identifier); parse_value!(deserialize_bool, visit_bool, "bool"); parse_value!(deserialize_i8, visit_i8, "i8"); parse_value!(deserialize_i16, visit_i16, "i16"); parse_value!(deserialize_i32, visit_i32, "i32"); parse_value!(deserialize_i64, visit_i64, "i64"); parse_value!(deserialize_i128, visit_i128, "i128"); parse_value!(deserialize_u8, visit_u8, "u8"); parse_value!(deserialize_u16, visit_u16, "u16"); parse_value!(deserialize_u32, visit_u32, "u32"); parse_value!(deserialize_u64, visit_u64, "u64"); parse_value!(deserialize_u128, visit_u128, "u128"); parse_value!(deserialize_f32, visit_f32, "f32"); parse_value!(deserialize_f64, visit_f64, "f64"); parse_value!(deserialize_string, visit_string, "String"); parse_value!(deserialize_byte_buf, visit_string, "String"); parse_value!(deserialize_char, visit_char, "char"); fn deserialize_any(self, v: V) -> Result where V: Visitor<'de>, { self.deserialize_str(v) } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_borrowed_str(self.value) } fn 
deserialize_bytes(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_borrowed_bytes(self.value.as_bytes()) } fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_some(self) } fn deserialize_unit(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_unit() } fn deserialize_unit_struct( self, _name: &'static str, visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_unit() } fn deserialize_newtype_struct( self, _name: &'static str, visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_newtype_struct(self) } fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de>, { struct PairDeserializer<'de> { key: Option, value: Option<&'de PercentDecodedStr>, } impl<'de> SeqAccess<'de> for PairDeserializer<'de> { type Error = PathDeserializationError; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { match self.key.take() { Some(KeyOrIdx::Idx { idx: _, key }) => { return seed.deserialize(KeyDeserializer { key }).map(Some); } // `KeyOrIdx::Key` is only used when deserializing maps so `deserialize_seq` // wouldn't be called for that Some(KeyOrIdx::Key(_)) => unreachable!(), None => {} }; self.value .take() .map(|value| seed.deserialize(ValueDeserializer { key: None, value })) .transpose() } } if len == 2 { match self.key { Some(key) => visitor.visit_seq(PairDeserializer { key: Some(key), value: Some(self.value), }), // `self.key` is only `None` when deserializing maps so `deserialize_seq` // wouldn't be called for that None => unreachable!(), } } else { Err(PathDeserializationError::unsupported_type(type_name::< V::Value, >())) } } fn deserialize_seq(self, _visitor: V) -> Result where V: Visitor<'de>, { Err(PathDeserializationError::unsupported_type(type_name::< V::Value, >())) } fn deserialize_tuple_struct( self, _name: &'static str, _len: usize, _visitor: V, ) -> Result where V: Visitor<'de>, { 
Err(PathDeserializationError::unsupported_type(type_name::< V::Value, >())) } fn deserialize_struct( self, _name: &'static str, _fields: &'static [&'static str], _visitor: V, ) -> Result where V: Visitor<'de>, { Err(PathDeserializationError::unsupported_type(type_name::< V::Value, >())) } fn deserialize_enum( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_enum(EnumDeserializer { value: self.value.clone().into_inner(), }) } fn deserialize_ignored_any(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_unit() } } struct EnumDeserializer { value: Arc, } impl<'de> EnumAccess<'de> for EnumDeserializer { type Error = PathDeserializationError; type Variant = UnitVariant; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> where V: de::DeserializeSeed<'de>, { Ok(( seed.deserialize(KeyDeserializer { key: self.value })?, UnitVariant, )) } } struct UnitVariant; impl<'de> VariantAccess<'de> for UnitVariant { type Error = PathDeserializationError; fn unit_variant(self) -> Result<(), Self::Error> { Ok(()) } fn newtype_variant_seed(self, _seed: T) -> Result where T: DeserializeSeed<'de>, { Err(PathDeserializationError::unsupported_type( "newtype enum variant", )) } fn tuple_variant(self, _len: usize, _visitor: V) -> Result where V: Visitor<'de>, { Err(PathDeserializationError::unsupported_type( "tuple enum variant", )) } fn struct_variant( self, _fields: &'static [&'static str], _visitor: V, ) -> Result where V: Visitor<'de>, { Err(PathDeserializationError::unsupported_type( "struct enum variant", )) } } struct SeqDeserializer<'de> { params: &'de [(Arc, PercentDecodedStr)], idx: usize, } impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { type Error = PathDeserializationError; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { match self.params.split_first() { Some(((key, value), tail)) => { self.params = tail; 
let idx = self.idx; self.idx += 1; Ok(Some(seed.deserialize(ValueDeserializer { key: Some(KeyOrIdx::Idx { idx, key: key.clone(), }), value, })?)) } None => Ok(None), } } } #[derive(Debug, Clone)] enum KeyOrIdx { Key(Arc), Idx { idx: usize, key: Arc }, } #[cfg(test)] mod tests { use super::*; use serde::Deserialize; use std::collections::HashMap; #[derive(Debug, Deserialize, Eq, PartialEq)] enum MyEnum { A, B, #[serde(rename = "c")] C, } #[derive(Debug, Deserialize, Eq, PartialEq)] struct Struct { c: String, b: bool, a: i32, } fn create_url_params(values: I) -> Vec<(Arc, PercentDecodedStr)> where I: IntoIterator, K: AsRef, V: AsRef, { values .into_iter() .map(|(k, v)| (Arc::from(k.as_ref()), PercentDecodedStr::new(v).unwrap())) .collect() } macro_rules! check_single_value { ($ty:ty, $value_str:literal, $value:expr) => { #[allow(clippy::bool_assert_comparison)] { let url_params = create_url_params(vec![("value", $value_str)]); let deserializer = PathDeserializer::new(&url_params); assert_eq!(<$ty>::deserialize(deserializer).unwrap(), $value); } }; } #[test] fn test_parse_single_value() { check_single_value!(bool, "true", true); check_single_value!(bool, "false", false); check_single_value!(i8, "-123", -123); check_single_value!(i16, "-123", -123); check_single_value!(i32, "-123", -123); check_single_value!(i64, "-123", -123); check_single_value!(i128, "123", 123); check_single_value!(u8, "123", 123); check_single_value!(u16, "123", 123); check_single_value!(u32, "123", 123); check_single_value!(u64, "123", 123); check_single_value!(u128, "123", 123); check_single_value!(f32, "123", 123.0); check_single_value!(f64, "123", 123.0); check_single_value!(String, "abc", "abc"); check_single_value!(String, "one%20two", "one two"); check_single_value!(&str, "abc", "abc"); check_single_value!(&str, "one%20two", "one two"); check_single_value!(char, "a", 'a'); let url_params = create_url_params(vec![("a", "B")]); assert_eq!( 
MyEnum::deserialize(PathDeserializer::new(&url_params)).unwrap(), MyEnum::B ); let url_params = create_url_params(vec![("a", "1"), ("b", "2")]); let error_kind = i32::deserialize(PathDeserializer::new(&url_params)) .unwrap_err() .kind; assert!(matches!( error_kind, ErrorKind::WrongNumberOfParameters { expected: 1, got: 2 } )); } #[test] fn test_parse_seq() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); assert_eq!( <(i32, bool, String)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), (1, true, "abc".to_owned()) ); #[derive(Debug, Deserialize, Eq, PartialEq)] struct TupleStruct(i32, bool, String); assert_eq!( TupleStruct::deserialize(PathDeserializer::new(&url_params)).unwrap(), TupleStruct(1, true, "abc".to_owned()) ); let url_params = create_url_params(vec![("a", "1"), ("b", "2"), ("c", "3")]); assert_eq!( >::deserialize(PathDeserializer::new(&url_params)).unwrap(), vec![1, 2, 3] ); let url_params = create_url_params(vec![("a", "c"), ("a", "B")]); assert_eq!( >::deserialize(PathDeserializer::new(&url_params)).unwrap(), vec![MyEnum::C, MyEnum::B] ); } #[test] fn test_parse_seq_tuple_string_string() { let url_params = create_url_params(vec![("a", "foo"), ("b", "bar")]); assert_eq!( >::deserialize(PathDeserializer::new(&url_params)).unwrap(), vec![ ("a".to_owned(), "foo".to_owned()), ("b".to_owned(), "bar".to_owned()) ] ); } #[test] fn test_parse_seq_tuple_string_parse() { let url_params = create_url_params(vec![("a", "1"), ("b", "2")]); assert_eq!( >::deserialize(PathDeserializer::new(&url_params)).unwrap(), vec![("a".to_owned(), 1), ("b".to_owned(), 2)] ); } #[test] fn test_parse_struct() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); assert_eq!( Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(), Struct { c: "abc".to_owned(), b: true, a: 1, } ); } #[test] fn test_parse_struct_ignoring_additional_fields() { let url_params = create_url_params(vec![ ("a", "1"), ("b", 
"true"), ("c", "abc"), ("d", "false"), ]); assert_eq!( Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(), Struct { c: "abc".to_owned(), b: true, a: 1, } ); } #[test] fn test_parse_tuple_ignoring_additional_fields() { let url_params = create_url_params(vec![ ("a", "abc"), ("b", "true"), ("c", "1"), ("d", "false"), ]); assert_eq!( <(&str, bool, u32)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), ("abc", true, 1) ); } #[test] fn test_parse_map() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); assert_eq!( >::deserialize(PathDeserializer::new(&url_params)).unwrap(), [("a", "1"), ("b", "true"), ("c", "abc")] .iter() .map(|(key, value)| ((*key).to_owned(), (*value).to_owned())) .collect() ); } macro_rules! test_parse_error { ( $params:expr, $ty:ty, $expected_error_kind:expr $(,)? ) => { let url_params = create_url_params($params); let actual_error_kind = <$ty>::deserialize(PathDeserializer::new(&url_params)) .unwrap_err() .kind; assert_eq!(actual_error_kind, $expected_error_kind); }; } #[test] fn test_wrong_number_of_parameters_error() { test_parse_error!( vec![("a", "1")], (u32, u32), ErrorKind::WrongNumberOfParameters { got: 1, expected: 2, } ); } #[test] fn test_parse_error_at_key_error() { #[derive(Debug, Deserialize)] #[allow(dead_code)] struct Params { a: u32, } test_parse_error!( vec![("a", "false")], Params, ErrorKind::ParseErrorAtKey { key: "a".to_owned(), value: "false".to_owned(), expected_type: "u32", } ); } #[test] fn test_parse_error_at_key_error_multiple() { #[derive(Debug, Deserialize)] #[allow(dead_code)] struct Params { a: u32, b: u32, } test_parse_error!( vec![("a", "false")], Params, ErrorKind::ParseErrorAtKey { key: "a".to_owned(), value: "false".to_owned(), expected_type: "u32", } ); } #[test] fn test_parse_error_at_index_error() { test_parse_error!( vec![("a", "false"), ("b", "true")], (bool, u32), ErrorKind::ParseErrorAtIndex { index: 1, value: "true".to_owned(), expected_type: "u32", 
} ); } #[test] fn test_parse_error_error() { test_parse_error!( vec![("a", "false")], u32, ErrorKind::ParseError { value: "false".to_owned(), expected_type: "u32", } ); } #[test] fn test_unsupported_type_error_nested_data_structure() { test_parse_error!( vec![("a", "false")], Vec>, ErrorKind::UnsupportedType { name: "alloc::vec::Vec", } ); } #[test] fn test_parse_seq_tuple_unsupported_key_type() { test_parse_error!( vec![("a", "false")], Vec<(u32, String)>, ErrorKind::Message("Unexpected key type".to_owned()) ); } #[test] fn test_parse_seq_wrong_tuple_length() { test_parse_error!( vec![("a", "false")], Vec<(String, String, String)>, ErrorKind::UnsupportedType { name: "(alloc::string::String, alloc::string::String, alloc::string::String)", } ); } #[test] fn test_parse_seq_seq() { test_parse_error!( vec![("a", "false")], Vec>, ErrorKind::UnsupportedType { name: "alloc::vec::Vec", } ); } } axum-axum-v0.6.20/axum/src/extract/path/mod.rs000066400000000000000000000625601446276533600212340ustar00rootroot00000000000000//! Extractor that will get captures from the URL and parse them using //! [`serde`]. mod de; use crate::{ extract::{rejection::*, FromRequestParts}, routing::url_params::UrlParams, util::PercentDecodedStr, }; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use http::{request::Parts, StatusCode}; use serde::de::DeserializeOwned; use std::{fmt, sync::Arc}; /// Extractor that will get captures from the URL and parse them using /// [`serde`]. /// /// Any percent encoded parameters will be automatically decoded. The decoded /// parameters must be valid UTF-8, otherwise `Path` will fail and return a `400 /// Bad Request` response. /// /// # Example /// /// These examples assume the `serde` feature of the [`uuid`] crate is enabled. 
/// /// [`uuid`]: https://crates.io/crates/uuid /// /// ```rust,no_run /// use axum::{ /// extract::Path, /// routing::get, /// Router, /// }; /// use uuid::Uuid; /// /// async fn users_teams_show( /// Path((user_id, team_id)): Path<(Uuid, Uuid)>, /// ) { /// // ... /// } /// /// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// If the path contains only one parameter, then you can omit the tuple. /// /// ```rust,no_run /// use axum::{ /// extract::Path, /// routing::get, /// Router, /// }; /// use uuid::Uuid; /// /// async fn user_info(Path(user_id): Path) { /// // ... /// } /// /// let app = Router::new().route("/users/:user_id", get(user_info)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Path segments also can be deserialized into any type that implements /// [`serde::Deserialize`]. This includes tuples and structs: /// /// ```rust,no_run /// use axum::{ /// extract::Path, /// routing::get, /// Router, /// }; /// use serde::Deserialize; /// use uuid::Uuid; /// /// // Path segment labels will be matched with struct field names /// #[derive(Deserialize)] /// struct Params { /// user_id: Uuid, /// team_id: Uuid, /// } /// /// async fn users_teams_show( /// Path(Params { user_id, team_id }): Path, /// ) { /// // ... /// } /// /// // When using tuples the path segments will be matched by their position in the route /// async fn users_teams_create( /// Path((user_id, team_id)): Path<(String, String)>, /// ) { /// // ... 
/// } /// /// let app = Router::new().route( /// "/users/:user_id/team/:team_id", /// get(users_teams_show).post(users_teams_create), /// ); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// If you wish to capture all path parameters you can use `HashMap` or `Vec`: /// /// ```rust,no_run /// use axum::{ /// extract::Path, /// routing::get, /// Router, /// }; /// use std::collections::HashMap; /// /// async fn params_map( /// Path(params): Path>, /// ) { /// // ... /// } /// /// async fn params_vec( /// Path(params): Path>, /// ) { /// // ... /// } /// /// let app = Router::new() /// .route("/users/:user_id/team/:team_id", get(params_map).post(params_vec)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// # Providing detailed rejection output /// /// If the URI cannot be deserialized into the target type the request will be rejected and an /// error response will be returned. See [`customize-path-rejection`] for an example of how to customize that error. 
/// /// [`serde`]: https://crates.io/crates/serde /// [`serde::Deserialize`]: https://docs.rs/serde/1.0.127/serde/trait.Deserialize.html /// [`customize-path-rejection`]: https://github.com/tokio-rs/axum/blob/main/examples/customize-path-rejection/src/main.rs #[derive(Debug)] pub struct Path(pub T); axum_core::__impl_deref!(Path); #[async_trait] impl FromRequestParts for Path where T: DeserializeOwned + Send, S: Send + Sync, { type Rejection = PathRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let params = match parts.extensions.get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { /* the router already failed to percent-decode this param as UTF-8; surface it as a deserialization error */ let err = PathDeserializationError { kind: ErrorKind::InvalidUtf8InPathParam { key: key.to_string(), }, }; let err = FailedToDeserializePathParams(err); return Err(err.into()); } None => { return Err(MissingPathParams.into()); } }; T::deserialize(de::PathDeserializer::new(params)) .map_err(|err| { PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err)) }) .map(Path) } } // this wrapper type is used as the deserializer error to hide the `serde::de::Error` impl which // would otherwise be public if we used `ErrorKind` as the error directly #[derive(Debug)] pub(crate) struct PathDeserializationError { pub(super) kind: ErrorKind, } impl PathDeserializationError { pub(super) fn new(kind: ErrorKind) -> Self { Self { kind } } pub(super) fn wrong_number_of_parameters() -> WrongNumberOfParameters<()> { WrongNumberOfParameters { got: () } } #[track_caller] pub(super) fn unsupported_type(name: &'static str) -> Self { Self::new(ErrorKind::UnsupportedType { name }) } } pub(super) struct WrongNumberOfParameters { got: G, } impl WrongNumberOfParameters { #[allow(clippy::unused_self)] pub(super) fn got(self, got: G2) -> WrongNumberOfParameters { WrongNumberOfParameters { got } } } impl WrongNumberOfParameters { pub(super) fn expected(self, expected: usize) -> PathDeserializationError {
PathDeserializationError::new(ErrorKind::WrongNumberOfParameters { got: self.got, expected, }) } } impl serde::de::Error for PathDeserializationError { #[inline] fn custom(msg: T) -> Self where T: fmt::Display, { Self { kind: ErrorKind::Message(msg.to_string()), } } } impl fmt::Display for PathDeserializationError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.kind.fmt(f) } } impl std::error::Error for PathDeserializationError {} /// The kinds of errors that can happen when deserializing into a [`Path`]. /// /// This type is obtained through [`FailedToDeserializePathParams::kind`] or /// [`FailedToDeserializePathParams::into_kind`] and is useful for building /// more precise error messages. #[derive(Debug, PartialEq, Eq)] #[non_exhaustive] pub enum ErrorKind { /// The URI contained the wrong number of parameters. WrongNumberOfParameters { /// The number of actual parameters in the URI. got: usize, /// The number of expected parameters. expected: usize, }, /// Failed to parse the value at a specific key into the expected type. /// /// This variant is used when deserializing into types that have named fields, such as structs. ParseErrorAtKey { /// The key at which the value was located. key: String, /// The value from the URI. value: String, /// The expected type of the value. expected_type: &'static str, }, /// Failed to parse the value at a specific index into the expected type. /// /// This variant is used when deserializing into sequence types, such as tuples. ParseErrorAtIndex { /// The index at which the value was located. index: usize, /// The value from the URI. value: String, /// The expected type of the value. expected_type: &'static str, }, /// Failed to parse a value into the expected type. /// /// This variant is used when deserializing into a primitive type (such as `String` and `u32`). ParseError { /// The value from the URI. value: String, /// The expected type of the value.
expected_type: &'static str, }, /// A parameter contained text that, once percent decoded, wasn't valid UTF-8. InvalidUtf8InPathParam { /// The key at which the invalid value was located. key: String, }, /// Tried to deserialize into an unsupported type, such as nested maps. /// /// This error kind is caused by programmer errors and thus gets converted into a `500 Internal /// Server Error` response. UnsupportedType { /// The name of the unsupported type. name: &'static str, }, /// Catch-all variant for errors that don't fit any other variant. Message(String), } impl fmt::Display for ErrorKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ErrorKind::Message(error) => error.fmt(f), ErrorKind::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{key}`"), ErrorKind::WrongNumberOfParameters { got, expected } => { write!( f, "Wrong number of path arguments for `Path`. Expected {expected} but got {got}" )?; if *expected == 1 { write!(f, ". Note that multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`")?; } Ok(()) } ErrorKind::UnsupportedType { name } => write!(f, "Unsupported type `{name}`"), ErrorKind::ParseErrorAtKey { key, value, expected_type, } => write!( f, "Cannot parse `{key}` with value `{value:?}` to a `{expected_type}`" ), ErrorKind::ParseError { value, expected_type, } => write!(f, "Cannot parse `{value:?}` to a `{expected_type}`"), ErrorKind::ParseErrorAtIndex { index, value, expected_type, } => write!( f, "Cannot parse value at index {index} with value `{value:?}` to a `{expected_type}`" ), } } } /// Rejection type for [`Path`](super::Path) if the captured route params couldn't be deserialized /// into the expected type. #[derive(Debug)] pub struct FailedToDeserializePathParams(PathDeserializationError); impl FailedToDeserializePathParams { /// Get a reference to the underlying error kind.
pub fn kind(&self) -> &ErrorKind { &self.0.kind } /// Convert this error into the underlying error kind. pub fn into_kind(self) -> ErrorKind { self.0.kind } /// Get the response body text used for this rejection. /// /// Parse, invalid UTF-8 and custom-message errors are prefixed with `Invalid URL: `; /// `WrongNumberOfParameters` and `UnsupportedType` use the error's `Display` output as-is. pub fn body_text(&self) -> String { match self.0.kind { ErrorKind::Message(_) | ErrorKind::InvalidUtf8InPathParam { .. } | ErrorKind::ParseError { .. } | ErrorKind::ParseErrorAtIndex { .. } | ErrorKind::ParseErrorAtKey { .. } => format!("Invalid URL: {}", self.0.kind), ErrorKind::WrongNumberOfParameters { .. } | ErrorKind::UnsupportedType { .. } => { self.0.kind.to_string() } } } /// Get the status code used for this rejection. /// /// Parse, invalid UTF-8 and custom-message errors map to `400 Bad Request`, while /// `WrongNumberOfParameters` and `UnsupportedType` map to `500 Internal Server Error`. pub fn status(&self) -> StatusCode { match self.0.kind { ErrorKind::Message(_) | ErrorKind::InvalidUtf8InPathParam { .. } | ErrorKind::ParseError { .. } | ErrorKind::ParseErrorAtIndex { .. } | ErrorKind::ParseErrorAtKey { .. } => StatusCode::BAD_REQUEST, ErrorKind::WrongNumberOfParameters { .. } | ErrorKind::UnsupportedType { .. } => { StatusCode::INTERNAL_SERVER_ERROR } } } } impl IntoResponse for FailedToDeserializePathParams { fn into_response(self) -> Response { axum_core::__log_rejection!( rejection_type = Self, body_text = self.body_text(), status = self.status(), ); (self.status(), self.body_text()).into_response() } } impl fmt::Display for FailedToDeserializePathParams { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } impl std::error::Error for FailedToDeserializePathParams {} /// Extractor that will get captures from the URL without deserializing them. /// /// In general you should prefer to use [`Path`] as it is higher level, however `RawPathParams` is /// suitable if you just want the raw params without deserializing them and thus saving some /// allocations. /// /// Any percent encoded parameters will be automatically decoded. The decoded parameters must be /// valid UTF-8, otherwise `RawPathParams` will fail and return a `400 Bad Request` response.
/// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::RawPathParams, /// routing::get, /// Router, /// }; /// /// async fn users_teams_show(params: RawPathParams) { /// for (key, value) in &params { /// println!("{key:?} = {value:?}"); /// } /// } /// /// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show)); /// # let _: Router = app; /// ``` #[derive(Debug)] pub struct RawPathParams(Vec<(Arc, PercentDecodedStr)>); #[async_trait] impl FromRequestParts for RawPathParams where S: Send + Sync, { type Rejection = RawPathParamsRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let params = match parts.extensions.get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { return Err(InvalidUtf8InPathParam { key: Arc::clone(key), } .into()); } None => { return Err(MissingPathParams.into()); } }; /* `get` only borrows the extension, so the params are cloned to return an owned value */ Ok(Self(params.clone())) } } impl RawPathParams { /// Get an iterator over the path parameters as `(key, value)` string pairs. pub fn iter(&self) -> RawPathParamsIter<'_> { self.into_iter() } } impl<'a> IntoIterator for &'a RawPathParams { type Item = (&'a str, &'a str); type IntoIter = RawPathParamsIter<'a>; fn into_iter(self) -> Self::IntoIter { RawPathParamsIter(self.0.iter()) } } /// An iterator over raw path parameters. /// /// Created with [`RawPathParams::iter`]. #[derive(Debug)] pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc, PercentDecodedStr)>); impl<'a> Iterator for RawPathParamsIter<'a> { type Item = (&'a str, &'a str); fn next(&mut self) -> Option { let (key, value) = self.0.next()?; Some((&**key, value.as_str())) } } /// Rejection used by [`RawPathParams`] if a parameter contained text that, once percent decoded, /// wasn't valid UTF-8. #[derive(Debug)] pub struct InvalidUtf8InPathParam { key: Arc, } impl InvalidUtf8InPathParam { /// Get the response body text used for this rejection.
pub fn body_text(&self) -> String { self.to_string() } /// Get the status code used for this rejection. pub fn status(&self) -> StatusCode { StatusCode::BAD_REQUEST } } impl fmt::Display for InvalidUtf8InPathParam { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Invalid UTF-8 in `{}`", self.key) } } impl std::error::Error for InvalidUtf8InPathParam {} impl IntoResponse for InvalidUtf8InPathParam { fn into_response(self) -> Response { /* log the rejection, consistent with `FailedToDeserializePathParams::into_response` */ axum_core::__log_rejection!( rejection_type = Self, body_text = self.body_text(), status = self.status(), ); (self.status(), self.body_text()).into_response() } } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; use http::StatusCode; use serde::Deserialize; use std::collections::HashMap; #[crate::test] async fn extracting_url_params() { let app = Router::new().route( "/users/:id", get(|Path(id): Path| async move { assert_eq!(id, 42); }) .post(|Path(params_map): Path>| async move { assert_eq!(params_map.get("id").unwrap(), &1337); }), ); let client = TestClient::new(app); let res = client.get("/users/42").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.post("/users/1337").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn extracting_url_params_multiple_times() { let app = Router::new().route("/users/:id", get(|_: Path, _: Path| async {})); let client = TestClient::new(app); let res = client.get("/users/42").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn percent_decoding() { let app = Router::new().route( "/:key", get(|Path(param): Path| async move { param }), ); let client = TestClient::new(app); let res = client.get("/one%20two").send().await; assert_eq!(res.text().await, "one two"); } #[crate::test] async fn supports_128_bit_numbers() { let app = Router::new() .route( "/i/:key", get(|Path(param): Path| async move { param.to_string() }), ) .route( "/u/:key", get(|Path(param): Path| async move { param.to_string() }), ); let client = TestClient::new(app); let res =
client.get("/i/123").send().await; assert_eq!(res.text().await, "123"); let res = client.get("/u/123").send().await; assert_eq!(res.text().await, "123"); } #[crate::test] async fn wildcard() { let app = Router::new() .route( "/foo/*rest", get(|Path(param): Path| async move { param }), ) .route( "/bar/*rest", get(|Path(params): Path>| async move { params.get("rest").unwrap().clone() }), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.text().await, "bar/baz"); let res = client.get("/bar/baz/qux").send().await; assert_eq!(res.text().await, "baz/qux"); } #[crate::test] async fn captures_dont_match_empty_segments() { let app = Router::new().route("/:key", get(|| async {})); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn str_reference_deserialize() { struct Param(String); impl<'de> serde::Deserialize<'de> for Param { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let s = <&str as serde::Deserialize>::deserialize(deserializer)?; Ok(Param(s.to_owned())) } } let app = Router::new().route("/:key", get(|param: Path| async move { param.0 .0 })); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.text().await, "foo"); // percent decoding should also work let res = client.get("/foo%20bar").send().await; assert_eq!(res.text().await, "foo bar"); } #[crate::test] async fn two_path_extractors() { let app = Router::new().route("/:a/:b", get(|_: Path, _: Path| async {})); let client = TestClient::new(app); let res = client.get("/a/b").send().await; /* the route has two captures but each extractor expects one, which is reported as a 500 */ assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( res.text().await, "Wrong number of path arguments for `Path`. Expected 1 but got 2.
\ Note that multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`", ); } #[crate::test] async fn deserialize_into_vec_of_tuples() { let app = Router::new().route( "/:a/:b", get(|Path(params): Path>| async move { assert_eq!( params, vec![ ("a".to_owned(), "foo".to_owned()), ("b".to_owned(), "bar".to_owned()) ] ); }), ); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn type_that_uses_deserialize_any() { use time::Date; #[derive(Deserialize)] struct Params { a: Date, b: Date, c: Date, } let app = Router::new() .route( "/single/:a", get(|Path(a): Path| async move { format!("single: {a}") }), ) .route( "/tuple/:a/:b/:c", get(|Path((a, b, c)): Path<(Date, Date, Date)>| async move { format!("tuple: {a} {b} {c}") }), ) .route( "/vec/:a/:b/:c", get(|Path(vec): Path>| async move { let [a, b, c]: [Date; 3] = vec.try_into().unwrap(); format!("vec: {a} {b} {c}") }), ) .route( "/vec_pairs/:a/:b/:c", get(|Path(vec): Path>| async move { let [(_, a), (_, b), (_, c)]: [(String, Date); 3] = vec.try_into().unwrap(); format!("vec_pairs: {a} {b} {c}") }), ) .route( "/map/:a/:b/:c", get(|Path(mut map): Path>| async move { let a = map.remove("a").unwrap(); let b = map.remove("b").unwrap(); let c = map.remove("c").unwrap(); format!("map: {a} {b} {c}") }), ) .route( "/struct/:a/:b/:c", get(|Path(params): Path| async move { format!("struct: {} {} {}", params.a, params.b, params.c) }), ); let client = TestClient::new(app); let res = client.get("/single/2023-01-01").send().await; assert_eq!(res.text().await, "single: 2023-01-01"); let res = client .get("/tuple/2023-01-01/2023-01-02/2023-01-03") .send() .await; assert_eq!(res.text().await, "tuple: 2023-01-01 2023-01-02 2023-01-03"); let res = client .get("/vec/2023-01-01/2023-01-02/2023-01-03") .send() .await; assert_eq!(res.text().await, "vec: 2023-01-01 2023-01-02 2023-01-03"); let res = client
.get("/vec_pairs/2023-01-01/2023-01-02/2023-01-03") .send() .await; assert_eq!( res.text().await, "vec_pairs: 2023-01-01 2023-01-02 2023-01-03", ); let res = client .get("/map/2023-01-01/2023-01-02/2023-01-03") .send() .await; assert_eq!(res.text().await, "map: 2023-01-01 2023-01-02 2023-01-03"); let res = client .get("/struct/2023-01-01/2023-01-02/2023-01-03") .send() .await; assert_eq!(res.text().await, "struct: 2023-01-01 2023-01-02 2023-01-03"); } #[crate::test] async fn wrong_number_of_parameters_json() { use serde_json::Value; let app = Router::new() .route("/one/:a", get(|_: Path<(Value, Value)>| async {})) .route("/two/:a/:b", get(|_: Path| async {})); let client = TestClient::new(app); let res = client.get("/one/1").send().await; assert!(res .text() .await .starts_with("Wrong number of path arguments for `Path`. Expected 2 but got 1")); let res = client.get("/two/1/2").send().await; assert!(res .text() .await .starts_with("Wrong number of path arguments for `Path`. Expected 1 but got 2")); } #[crate::test] async fn raw_path_params() { let app = Router::new().route( "/:a/:b/:c", get(|params: RawPathParams| async move { params .into_iter() .map(|(key, value)| format!("{key}={value}")) .collect::>() .join(" ") }), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; let body = res.text().await; assert_eq!(body, "a=foo b=bar c=baz"); } } axum-axum-v0.6.20/axum/src/extract/query.rs000066400000000000000000000122231446276533600206550ustar00rootroot00000000000000use super::{rejection::*, FromRequestParts}; use async_trait::async_trait; use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; /// Extractor that deserializes query strings into some type. /// /// `T` is expected to implement [`serde::Deserialize`].
/// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::Query, /// routing::get, /// Router, /// }; /// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct Pagination { /// page: usize, /// per_page: usize, /// } /// /// // This will parse query strings like `?page=2&per_page=30` into `Pagination` /// // structs. /// async fn list_things(pagination: Query) { /// let pagination: Pagination = pagination.0; /// /// // ... /// } /// /// let app = Router::new().route("/list_things", get(list_things)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// If the query string cannot be parsed it will reject the request with a `400 /// Bad Request` response. /// /// For handling values being empty vs missing see the [query-params-with-empty-strings][example] /// example. /// /// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs #[cfg_attr(docsrs, doc(cfg(feature = "query")))] #[derive(Debug, Clone, Copy, Default)] pub struct Query(pub T); #[async_trait] impl FromRequestParts for Query where T: DeserializeOwned, S: Send + Sync, { type Rejection = QueryRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { Self::try_from_uri(&parts.uri) } } impl Query where T: DeserializeOwned, { /// Attempts to construct a [`Query`] from a reference to a [`Uri`].
/// /// # Example /// ``` /// use axum::extract::Query; /// use http::Uri; /// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct ExampleParams { /// foo: String, /// bar: u32, /// } /// /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); /// let result: Query = Query::try_from_uri(&uri).unwrap(); /// assert_eq!(result.foo, String::from("hello")); /// assert_eq!(result.bar, 42); /// ``` /// /// # Errors /// /// Returns a [`QueryRejection`] if the query string cannot be deserialized into `T`. A URI /// without a query string is deserialized as if its query were empty. pub fn try_from_uri(value: &Uri) -> Result { let query = value.query().unwrap_or_default(); let params = serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?; Ok(Query(params)) } } axum_core::__impl_deref!(Query); #[cfg(test)] mod tests { use crate::{routing::get, test_helpers::TestClient, Router}; use super::*; use axum_core::extract::FromRequest; use http::{Request, StatusCode}; use serde::Deserialize; use std::fmt::Debug; async fn check(uri: impl AsRef, value: T) where T: DeserializeOwned + PartialEq + Debug, { let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); assert_eq!(Query::::from_request(req, &()).await.unwrap().0, value); } #[crate::test] async fn test_query() { #[derive(Debug, PartialEq, Deserialize)] struct Pagination { size: Option, page: Option, } check( "http://example.com/test", Pagination { size: None, page: None, }, ) .await; check( "http://example.com/test?size=10", Pagination { size: Some(10), page: None, }, ) .await; check( "http://example.com/test?size=10&page=20", Pagination { size: Some(10), page: Some(20), }, ) .await; } #[crate::test] async fn correct_rejection_status_code() { #[derive(Deserialize)] #[allow(dead_code)] struct Params { n: i32, } async fn handler(_: Query) {} let app = Router::new().route("/", get(handler)); let client = TestClient::new(app); let res = client.get("/?n=hi").send().await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); } #[test] fn test_try_from_uri() { #[derive(Deserialize)] struct TestQueryParams { foo: String, bar: u32, } let uri:
Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); let result: Query = Query::try_from_uri(&uri).unwrap(); assert_eq!(result.foo, String::from("hello")); assert_eq!(result.bar, 42); } #[test] fn test_try_from_uri_with_invalid_query() { #[derive(Deserialize)] #[allow(dead_code)] struct TestQueryParams { foo: String, bar: u32, } /* field names match the query keys, so the only failure is `bar=invalid` not parsing as a `u32` (not a missing field) */ let uri: Uri = "http://example.com/path?foo=hello&bar=invalid" .parse() .unwrap(); let result: Result, _> = Query::try_from_uri(&uri); assert!(result.is_err()); } } axum-axum-v0.6.20/axum/src/extract/raw_form.rs000066400000000000000000000063251446276533600213320ustar00rootroot00000000000000use async_trait::async_trait; use axum_core::extract::FromRequest; use bytes::{Bytes, BytesMut}; use http::{Method, Request}; use super::{ has_content_type, rejection::{InvalidFormContentType, RawFormRejection}, }; use crate::{body::HttpBody, BoxError}; /// Extractor that extracts raw form requests. /// /// For `GET` requests it will extract the raw query. For other methods it extracts the raw /// `application/x-www-form-urlencoded` encoded request body.
/// /// A `GET` request without a query string yields empty bytes. For other methods, the request /// is rejected with `InvalidFormContentType` unless its `Content-Type` is /// `application/x-www-form-urlencoded`. /// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::RawForm, /// routing::get, /// Router /// }; /// /// async fn handler(RawForm(form): RawForm) {} /// /// let app = Router::new().route("/", get(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[derive(Debug)] pub struct RawForm(pub Bytes); #[async_trait] impl FromRequest for RawForm where B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = RawFormRejection; async fn from_request(req: Request, state: &S) -> Result { if req.method() == Method::GET { let mut bytes = BytesMut::new(); if let Some(query) = req.uri().query() { bytes.extend(query.as_bytes()); } Ok(Self(bytes.freeze())) } else { if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) { return Err(InvalidFormContentType.into()); } Ok(Self(Bytes::from_request(req, state).await?)) } } } #[cfg(test)] mod tests { use http::{header::CONTENT_TYPE, Request}; use super::{InvalidFormContentType, RawForm, RawFormRejection}; use crate::{ body::{Bytes, Empty, Full}, extract::FromRequest, }; async fn check_query(uri: &str, value: &[u8]) { let req = Request::builder() .uri(uri) .body(Empty::::new()) .unwrap(); assert_eq!(RawForm::from_request(req, &()).await.unwrap().0, value); } async fn check_body(body: &'static [u8]) { let req = Request::post("http://example.com/test") .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()) .body(Full::new(Bytes::from(body))) .unwrap(); assert_eq!(RawForm::from_request(req, &()).await.unwrap().0, body); } #[crate::test] async fn test_from_query() { check_query("http://example.com/test", b"").await; check_query("http://example.com/test?page=0&size=10", b"page=0&size=10").await; } #[crate::test] async fn test_from_body() { check_body(b"").await; check_body(b"username=user&password=secure%20password").await; } #[crate::test] async fn
test_incorrect_content_type() { let req = Request::post("http://example.com/test") .body(Full::::from(Bytes::from("page=0&size=10"))) .unwrap(); assert!(matches!( RawForm::from_request(req, &()).await.unwrap_err(), RawFormRejection::InvalidFormContentType(InvalidFormContentType) )) } } axum-axum-v0.6.20/axum/src/extract/raw_query.rs000066400000000000000000000017331446276533600215320ustar00rootroot00000000000000use super::FromRequestParts; use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; /// Extractor that extracts the raw query string, without parsing it. /// /// The inner value is `None` when the request URI has no query string. /// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::RawQuery, /// routing::get, /// Router, /// }; /// /// async fn handler(RawQuery(query): RawQuery) { /// // ... /// } /// /// let app = Router::new().route("/users", get(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[derive(Debug)] pub struct RawQuery(pub Option); #[async_trait] impl FromRequestParts for RawQuery where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let query = parts.uri.query().map(|query| query.to_owned()); Ok(Self(query)) } } axum-axum-v0.6.20/axum/src/extract/rejection.rs000066400000000000000000000147671446276533600215110ustar00rootroot00000000000000//! Rejection response types. use axum_core::__composite_rejection as composite_rejection; use axum_core::__define_rejection as define_rejection; pub use crate::extract::path::{FailedToDeserializePathParams, InvalidUtf8InPathParam}; pub use axum_core::extract::rejection::*; #[cfg(feature = "json")] define_rejection! { #[status = UNPROCESSABLE_ENTITY] #[body = "Failed to deserialize the JSON body into the target type"] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] /// Rejection type for [`Json`](super::Json).
/// /// This rejection is used if the request body is syntactically valid JSON but couldn't be /// deserialized into the target type. pub struct JsonDataError(Error); } #[cfg(feature = "json")] define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to parse the request body as JSON"] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] /// Rejection type for [`Json`](super::Json). /// /// This rejection is used if the request body didn't contain syntactically valid JSON. pub struct JsonSyntaxError(Error); } #[cfg(feature = "json")] define_rejection! { #[status = UNSUPPORTED_MEDIA_TYPE] #[body = "Expected request with `Content-Type: application/json`"] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] /// Rejection type for [`Json`](super::Json) used if the `Content-Type` /// header is missing. pub struct MissingJsonContentType; } define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "Missing request extension"] /// Rejection type for [`Extension`](super::Extension) if an expected /// request extension was not found. pub struct MissingExtension(Error); } define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "No paths parameters found for matched route"] /// Rejection type used if axum's internal representation of path parameters /// is missing. This is commonly caused by extracting `Request<_>`. `Path` /// must be extracted first. pub struct MissingPathParams; } define_rejection! { #[status = UNSUPPORTED_MEDIA_TYPE] #[body = "Form requests must have `Content-Type: application/x-www-form-urlencoded`"] /// Rejection type for [`Form`](super::Form) or [`RawForm`](super::RawForm) /// used if the `Content-Type` header is missing /// or its value is not `application/x-www-form-urlencoded`. pub struct InvalidFormContentType; } define_rejection! { #[status = BAD_REQUEST] #[body = "No host found in request"] /// Rejection type used if the [`Host`](super::Host) extractor is unable to /// resolve a host. pub struct FailedToResolveHost; } define_rejection!
{ #[status = BAD_REQUEST] #[body = "Failed to deserialize form"] /// Rejection type used if the [`Form`](super::Form) extractor is unable to /// deserialize the form into the target type. pub struct FailedToDeserializeForm(Error); } define_rejection! { #[status = UNPROCESSABLE_ENTITY] #[body = "Failed to deserialize form body"] /// Rejection type used if the [`Form`](super::Form) extractor is unable to /// deserialize the form body into the target type. pub struct FailedToDeserializeFormBody(Error); } define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to deserialize query string"] /// Rejection type used if the [`Query`](super::Query) extractor is unable to /// deserialize the query string into the target type. pub struct FailedToDeserializeQueryString(Error); } composite_rejection! { /// Rejection used for [`Query`](super::Query). /// /// Contains one variant for each way the [`Query`](super::Query) extractor /// can fail. pub enum QueryRejection { FailedToDeserializeQueryString, } } composite_rejection! { /// Rejection used for [`Form`](super::Form). /// /// Contains one variant for each way the [`Form`](super::Form) extractor /// can fail. pub enum FormRejection { InvalidFormContentType, FailedToDeserializeForm, FailedToDeserializeFormBody, BytesRejection, } } composite_rejection! { /// Rejection used for [`RawForm`](super::RawForm). /// /// Contains one variant for each way the [`RawForm`](super::RawForm) extractor /// can fail. pub enum RawFormRejection { InvalidFormContentType, BytesRejection, } } #[cfg(feature = "json")] composite_rejection! { /// Rejection used for [`Json`](super::Json). /// /// Contains one variant for each way the [`Json`](super::Json) extractor /// can fail. #[cfg_attr(docsrs, doc(cfg(feature = "json")))] pub enum JsonRejection { JsonDataError, JsonSyntaxError, MissingJsonContentType, BytesRejection, } } composite_rejection! { /// Rejection used for [`Extension`](super::Extension).
/// /// Contains one variant for each way the [`Extension`](super::Extension) extractor /// can fail. pub enum ExtensionRejection { MissingExtension, } } composite_rejection! { /// Rejection used for [`Path`](super::Path). /// /// Contains one variant for each way the [`Path`](super::Path) extractor /// can fail. pub enum PathRejection { FailedToDeserializePathParams, MissingPathParams, } } composite_rejection! { /// Rejection used for [`RawPathParams`](super::RawPathParams). /// /// Contains one variant for each way the [`RawPathParams`](super::RawPathParams) extractor /// can fail. pub enum RawPathParamsRejection { InvalidUtf8InPathParam, MissingPathParams, } } composite_rejection! { /// Rejection used for [`Host`](super::Host). /// /// Contains one variant for each way the [`Host`](super::Host) extractor /// can fail. pub enum HostRejection { FailedToResolveHost, } } #[cfg(feature = "matched-path")] define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "No matched path found"] /// Rejection if no matched path could be found. /// /// See [`MatchedPath`](super::MatchedPath) for more details. #[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))] pub struct MatchedPathMissing; } #[cfg(feature = "matched-path")] composite_rejection! { /// Rejection used for [`MatchedPath`](super::MatchedPath).
#[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))] pub enum MatchedPathRejection { MatchedPathMissing, } } #[cfg(feature = "headers")] pub use crate::typed_header::{TypedHeaderRejection, TypedHeaderRejectionReason}; axum-axum-v0.6.20/axum/src/extract/request_parts.rs000066400000000000000000000160601446276533600224140ustar00rootroot00000000000000use super::{Extension, FromRequest, FromRequestParts}; use crate::{ body::{Body, Bytes, HttpBody}, BoxError, Error, }; use async_trait::async_trait; use futures_util::stream::Stream; use http::{request::Parts, Request, Uri}; use std::{ convert::Infallible, fmt, pin::Pin, task::{Context, Poll}, }; use sync_wrapper::SyncWrapper; /// Extractor that gets the original request URI regardless of nesting. /// /// This is necessary since [`Uri`](http::Uri), when used as an extractor, will /// have the prefix stripped if used in a nested service. /// /// # Example /// /// ``` /// use axum::{ /// routing::get, /// Router, /// extract::OriginalUri, /// http::Uri /// }; /// /// let api_routes = Router::new() /// .route( /// "/users", /// get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async { /// // `uri` is `/users` /// // `original_uri` is `/api/users` /// }), /// ); /// /// let app = Router::new().nest("/api", api_routes); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// # Extracting via request extensions /// /// `OriginalUri` can also be accessed from middleware via request extensions. /// This is useful, for example, with [`Trace`](tower_http::trace::Trace) to /// create a span that contains the full path, if your service might be nested: /// /// ``` /// use axum::{ /// Router, /// extract::OriginalUri, /// http::Request, /// routing::get, /// }; /// use tower_http::trace::TraceLayer; /// /// let api_routes = Router::new() /// .route("/users/:id", get(|| async { /* ...
*/ })) /// .layer( /// TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { /// let path = if let Some(path) = req.extensions().get::() { /// // This will include `/api` /// path.0.path().to_owned() /// } else { /// // The `OriginalUri` extension will always be present if using /// // `Router` unless another extractor or middleware has removed it /// req.uri().path().to_owned() /// }; /// tracing::info_span!("http-request", %path) /// }), /// ); /// /// let app = Router::new().nest("/api", api_routes); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[cfg(feature = "original-uri")] #[derive(Debug, Clone)] pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] #[async_trait] impl FromRequestParts for OriginalUri where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /* prefer the URI stored in the `OriginalUri` extension; if it is absent, fall back to this request's own (possibly prefix-stripped) URI */ let uri = Extension::::from_request_parts(parts, state) .await .unwrap_or_else(|_| Extension(OriginalUri(parts.uri.clone()))) .0; Ok(uri) } } #[cfg(feature = "original-uri")] axum_core::__impl_deref!(OriginalUri: Uri); /// Extractor that extracts the request body as a [`Stream`]. /// /// Since extracting the request body requires consuming it, the `BodyStream` extractor must be /// *last* if there are multiple extractors in a handler. /// See ["the order of extractors"][order-of-extractors]. /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::BodyStream, /// routing::get, /// Router, /// }; /// use futures_util::StreamExt; /// /// async fn handler(mut stream: BodyStream) { /// while let Some(chunk) = stream.next().await { /// // ...
/// } /// } /// /// let app = Router::new().route("/users", get(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html /// [`body::Body`]: crate::body::Body pub struct BodyStream( SyncWrapper + Send + 'static>>>, ); impl Stream for BodyStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(self.0.get_mut()).poll_data(cx) } } #[async_trait] impl FromRequest for BodyStream where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, S: Send + Sync, { type Rejection = Infallible; async fn from_request(req: Request, _state: &S) -> Result { let body = req .into_body() .map_data(Into::into) .map_err(|err| Error::new(err.into())); let stream = BodyStream(SyncWrapper::new(Box::pin(body))); Ok(stream) } } impl fmt::Debug for BodyStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("BodyStream").finish() } } #[test] fn body_stream_traits() { crate::test_helpers::assert_send::(); crate::test_helpers::assert_sync::(); } /// Extractor that extracts the raw request body. /// /// Since extracting the raw request body requires consuming it, the `RawBody` extractor must be /// *last* if there are multiple extractors in a handler. See ["the order of extractors"][order-of-extractors] /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// # Example /// /// ```rust,no_run /// use axum::{ /// extract::RawBody, /// routing::get, /// Router, /// }; /// use futures_util::StreamExt; /// /// async fn handler(RawBody(body): RawBody) { /// // ... 
/// } /// /// let app = Router::new().route("/users", get(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// [`body::Body`]: crate::body::Body #[derive(Debug, Default, Clone)] pub struct RawBody(pub B); #[async_trait] impl FromRequest for RawBody where B: Send, S: Send + Sync, { type Rejection = Infallible; async fn from_request(req: Request, _state: &S) -> Result { Ok(Self(req.into_body())) } } axum_core::__impl_deref!(RawBody); #[cfg(test)] mod tests { use crate::{extract::Extension, routing::get, test_helpers::*, Router}; use http::{Method, StatusCode}; #[crate::test] async fn extract_request_parts() { #[derive(Clone)] struct Ext; async fn handler(parts: http::request::Parts) { assert_eq!(parts.method, Method::GET); assert_eq!(parts.uri, "/"); assert_eq!(parts.version, http::Version::HTTP_11); assert_eq!(parts.headers["x-foo"], "123"); parts.extensions.get::().unwrap(); } let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext))); let res = client.get("/").header("x-foo", "123").send().await; assert_eq!(res.status(), StatusCode::OK); } } axum-axum-v0.6.20/axum/src/extract/state.rs000066400000000000000000000241261446276533600206350ustar00rootroot00000000000000use async_trait::async_trait; use axum_core::extract::{FromRef, FromRequestParts}; use http::request::Parts; use std::{ convert::Infallible, ops::{Deref, DerefMut}, }; /// Extractor for state. /// /// See ["Accessing state in middleware"][state-from-middleware] for how to /// access state in middleware. 
/// /// [state-from-middleware]: crate::middleware#accessing-state-in-middleware /// /// # With `Router` /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// /// // the application state /// // /// // here you can put configuration, database connection pools, or whatever /// // state you need /// // /// // see "When states need to implement `Clone`" for more details on why we need /// // `#[derive(Clone)]` here. /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// // create a `Router` that holds our state /// let app = Router::new() /// .route("/", get(handler)) /// // provide the state so the router can access it /// .with_state(state); /// /// async fn handler( /// // access the state via the `State` extractor /// // extracting a state of the wrong type results in a compile error /// State(state): State, /// ) { /// // use `state`... /// } /// # let _: axum::Router = app; /// ``` /// /// Note that `State` is an extractor, so be sure to put it before any body /// extractors, see ["the order of extractors"][order-of-extractors]. /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// ## Combining stateful routers /// /// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`] /// When combining [`Router`]s with one of these methods, the [`Router`]s must have /// the same state type. Generally, this can be inferred automatically: /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// // create a `Router` that will be nested within another /// let api = Router::new() /// .route("/posts", get(posts_handler)); /// /// let app = Router::new() /// .nest("/api", api) /// .with_state(state); /// /// async fn posts_handler(State(state): State) { /// // use `state`... 
/// } /// # let _: axum::Router = app; /// ``` /// /// However, if you are composing [`Router`]s that are defined in separate scopes, /// you may need to annotate the [`State`] type explicitly: /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// fn make_app() -> Router { /// let state = AppState {}; /// /// Router::new() /// .nest("/api", make_api()) /// .with_state(state) // the outer Router's state is inferred /// } /// /// // the inner Router must specify its state type to compose with the /// // outer router /// fn make_api() -> Router { /// Router::new() /// .route("/posts", get(posts_handler)) /// } /// /// async fn posts_handler(State(state): State) { /// // use `state`... /// } /// # let _: axum::Router = make_app(); /// ``` /// /// In short, a [`Router`]'s generic state type defaults to `()` /// (no state) unless [`Router::with_state`] is called or the value /// of the generic type is given explicitly. /// /// [`Router`]: crate::Router /// [`Router::merge`]: crate::Router::merge /// [`Router::nest`]: crate::Router::nest /// [`Router::with_state`]: crate::Router::with_state /// /// # With `MethodRouter` /// /// ``` /// use axum::{routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// let method_router_with_state = get(handler) /// // provide the state so the handler can access it /// .with_state(state); /// /// async fn handler(State(state): State) { /// // use `state`... /// } /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// # With `Handler` /// /// ``` /// use axum::{routing::get, handler::Handler, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// async fn handler(State(state): State) { /// // use `state`... 
/// } /// /// // provide the state so the handler can access it /// let handler_with_state = handler.with_state(state); /// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve(handler_with_state.into_make_service()) /// .await /// .expect("server failed"); /// # }; /// ``` /// /// # Substates /// /// [`State`] only allows a single state type but you can use [`FromRef`] to extract "substates": /// /// ``` /// use axum::{Router, routing::get, extract::{State, FromRef}}; /// /// // the application state /// #[derive(Clone)] /// struct AppState { /// // that holds some api specific state /// api_state: ApiState, /// } /// /// // the api specific state /// #[derive(Clone)] /// struct ApiState {} /// /// // support converting an `AppState` in an `ApiState` /// impl FromRef for ApiState { /// fn from_ref(app_state: &AppState) -> ApiState { /// app_state.api_state.clone() /// } /// } /// /// let state = AppState { /// api_state: ApiState {}, /// }; /// /// let app = Router::new() /// .route("/", get(handler)) /// .route("/api/users", get(api_users)) /// .with_state(state); /// /// async fn api_users( /// // access the api specific state /// State(api_state): State, /// ) { /// } /// /// async fn handler( /// // we can still access to top level state /// State(state): State, /// ) { /// } /// # let _: axum::Router = app; /// ``` /// /// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`. 
/// /// # For library authors /// /// If you're writing a library that has an extractor that needs state, this is the recommended way /// to do it: /// /// ```rust /// use axum_core::extract::{FromRequestParts, FromRef}; /// use http::request::Parts; /// use async_trait::async_trait; /// use std::convert::Infallible; /// /// // the extractor your library provides /// struct MyLibraryExtractor; /// /// #[async_trait] /// impl FromRequestParts for MyLibraryExtractor /// where /// // keep `S` generic but require that it can produce a `MyLibraryState` /// // this means users will have to implement `FromRef for MyLibraryState` /// MyLibraryState: FromRef, /// S: Send + Sync, /// { /// type Rejection = Infallible; /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // get a `MyLibraryState` from a reference to the state /// let state = MyLibraryState::from_ref(state); /// /// // ... /// # todo!() /// } /// } /// /// // the state your library needs /// struct MyLibraryState { /// // ... /// } /// ``` /// /// # When states need to implement `Clone` /// /// Your top level state type must implement `Clone` to be extractable with `State`: /// /// ``` /// use axum::extract::State; /// /// // no substates, so to extract to `State` we must implement `Clone` for `AppState` /// #[derive(Clone)] /// struct AppState {} /// /// async fn handler(State(state): State) { /// // ... /// } /// ``` /// /// This works because of [`impl FromRef for S where S: Clone`][`FromRef`]. /// /// This is also true if you're extracting substates, unless you _never_ extract the top level /// state itself: /// /// ``` /// use axum::extract::{State, FromRef}; /// /// // we never extract `State`, just `State`. 
So `AppState` doesn't need to /// // implement `Clone` /// struct AppState { /// inner: InnerState, /// } /// /// #[derive(Clone)] /// struct InnerState {} /// /// impl FromRef for InnerState { /// fn from_ref(app_state: &AppState) -> InnerState { /// app_state.inner.clone() /// } /// } /// /// async fn api_users(State(inner): State) { /// // ... /// } /// ``` /// /// In general however we recommend you implement `Clone` for all your state types to avoid /// potential type errors. /// /// # Shared mutable state /// /// [As state is global within a `Router`][global] you can't directly get a mutable reference to /// the state. /// /// The most basic solution is to use an `Arc>`. Which kind of mutex you need depends on /// your use case. See [the tokio docs] for more details. /// /// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send` /// futures which are incompatible with axum. If you need to hold a mutex across `.await` points, /// consider using a `tokio::sync::Mutex` instead. /// /// ## Example /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// use std::sync::{Arc, Mutex}; /// /// #[derive(Clone)] /// struct AppState { /// data: Arc>, /// } /// /// async fn handler(State(state): State) { /// let mut data = state.data.lock().expect("mutex was poisoned"); /// *data = "updated foo".to_owned(); /// /// // ... 
/// } /// /// let state = AppState { /// data: Arc::new(Mutex::new("foo".to_owned())), /// }; /// /// let app = Router::new() /// .route("/", get(handler)) /// .with_state(state); /// # let _: Router = app; /// ``` /// /// [global]: crate::Router::with_state /// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); #[async_trait] impl FromRequestParts for State where InnerState: FromRef, OuterState: Send + Sync, { type Rejection = Infallible; async fn from_request_parts( _parts: &mut Parts, state: &OuterState, ) -> Result { let inner_state = InnerState::from_ref(state); Ok(Self(inner_state)) } } impl Deref for State { type Target = S; fn deref(&self) -> &Self::Target { &self.0 } } impl DerefMut for State { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } axum-axum-v0.6.20/axum/src/extract/ws.rs000066400000000000000000000730411446276533600201460ustar00rootroot00000000000000//! Handle WebSocket connections. //! //! # Example //! //! ``` //! use axum::{ //! extract::ws::{WebSocketUpgrade, WebSocket}, //! routing::get, //! response::{IntoResponse, Response}, //! Router, //! }; //! //! let app = Router::new().route("/ws", get(handler)); //! //! async fn handler(ws: WebSocketUpgrade) -> Response { //! ws.on_upgrade(handle_socket) //! } //! //! async fn handle_socket(mut socket: WebSocket) { //! while let Some(msg) = socket.recv().await { //! let msg = if let Ok(msg) = msg { //! msg //! } else { //! // client disconnected //! return; //! }; //! //! if socket.send(msg).await.is_err() { //! // client disconnected //! return; //! } //! } //! } //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! # Passing data and/or state to an `on_upgrade` callback //! //! ``` //! use axum::{ //! extract::{ws::{WebSocketUpgrade, WebSocket}, State}, //! response::Response, //! 
routing::get, //! Router, //! }; //! //! #[derive(Clone)] //! struct AppState { //! // ... //! } //! //! async fn handler(ws: WebSocketUpgrade, State(state): State) -> Response { //! ws.on_upgrade(|socket| handle_socket(socket, state)) //! } //! //! async fn handle_socket(socket: WebSocket, state: AppState) { //! // ... //! } //! //! let app = Router::new() //! .route("/ws", get(handler)) //! .with_state(AppState { /* ... */ }); //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! # Read and write concurrently //! //! If you need to read and write concurrently from a [`WebSocket`] you can use //! [`StreamExt::split`]: //! //! ```rust,no_run //! use axum::{Error, extract::ws::{WebSocket, Message}}; //! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}}; //! //! async fn handle_socket(mut socket: WebSocket) { //! let (mut sender, mut receiver) = socket.split(); //! //! tokio::spawn(write(sender)); //! tokio::spawn(read(receiver)); //! } //! //! async fn read(receiver: SplitStream) { //! // ... //! } //! //! async fn write(sender: SplitSink) { //! // ... //! } //! ``` //! //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split use self::rejection::*; use super::FromRequestParts; use crate::{ body::{self, Bytes}, response::Response, Error, }; use async_trait::async_trait; use futures_util::{ sink::{Sink, SinkExt}, stream::{Stream, StreamExt}, }; use http::{ header::{self, HeaderMap, HeaderName, HeaderValue}, request::Parts, Method, StatusCode, }; use hyper::upgrade::{OnUpgrade, Upgraded}; use sha1::{Digest, Sha1}; use std::{ borrow::Cow, future::Future, pin::Pin, task::{Context, Poll}, }; use tokio_tungstenite::{ tungstenite::{ self as ts, protocol::{self, WebSocketConfig}, }, WebSocketStream, }; /// Extractor for establishing WebSocket connections. 
/// /// Note: This extractor requires the request method to be `GET` so it should /// always be used with [`get`](crate::routing::get). Requests with other methods will be /// rejected. /// /// See the [module docs](self) for an example. #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] pub struct WebSocketUpgrade { config: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. protocol: Option, sec_websocket_key: HeaderValue, on_upgrade: OnUpgrade, on_failed_upgrade: F, sec_websocket_protocol: Option, } impl std::fmt::Debug for WebSocketUpgrade { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("WebSocketUpgrade") .field("config", &self.config) .field("protocol", &self.protocol) .field("sec_websocket_key", &self.sec_websocket_key) .field("sec_websocket_protocol", &self.sec_websocket_protocol) .finish_non_exhaustive() } } impl WebSocketUpgrade { /// Does nothing, instead use `max_write_buffer_size`. #[deprecated] pub fn max_send_queue(self, _: usize) -> Self { self } /// The target minimum size of the write buffer to reach before writing the data /// to the underlying stream. /// /// The default value is 128 KiB. /// /// If set to `0` each message will be eagerly written to the underlying stream. /// It is often more optimal to allow them to buffer a little, hence the default value. /// /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless. pub fn write_buffer_size(mut self, size: usize) -> Self { self.config.write_buffer_size = size; self } /// The max size of the write buffer in bytes. Setting this can provide backpressure /// in the case the write buffer is filling up due to write errors. /// /// The default value is unlimited. /// /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size) /// when writes to the underlying stream are failing. 
So the **write buffer can not /// fill up if you are not observing write errors even if not flushing**. /// /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size) /// and probably a little more depending on error handling strategy. pub fn max_write_buffer_size(mut self, max: usize) -> Self { self.config.max_write_buffer_size = max; self } /// Set the maximum message size (defaults to 64 megabytes) pub fn max_message_size(mut self, max: usize) -> Self { self.config.max_message_size = Some(max); self } /// Set the maximum frame size (defaults to 16 megabytes) pub fn max_frame_size(mut self, max: usize) -> Self { self.config.max_frame_size = Some(max); self } /// Allow server to accept unmasked frames (defaults to false) pub fn accept_unmasked_frames(mut self, accept: bool) -> Self { self.config.accept_unmasked_frames = accept; self } /// Set the known protocols. /// /// If the protocol name specified by `Sec-WebSocket-Protocol` header /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and /// return the protocol name. /// /// The protocols should be listed in decreasing order of preference: if the client offers /// multiple protocols that the server could support, the server will pick the first one in /// this list. /// /// # Examples /// /// ``` /// use axum::{ /// extract::ws::{WebSocketUpgrade, WebSocket}, /// routing::get, /// response::{IntoResponse, Response}, /// Router, /// }; /// /// let app = Router::new().route("/ws", get(handler)); /// /// async fn handler(ws: WebSocketUpgrade) -> Response { /// ws.protocols(["graphql-ws", "graphql-transport-ws"]) /// .on_upgrade(|socket| async { /// // ... 
/// }) /// } /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` pub fn protocols(mut self, protocols: I) -> Self where I: IntoIterator, I::Item: Into>, { if let Some(req_protocols) = self .sec_websocket_protocol .as_ref() .and_then(|p| p.to_str().ok()) { self.protocol = protocols .into_iter() // FIXME: This will often allocate a new `String` and so is less efficient than it // could be. But that can't be fixed without breaking changes to the public API. .map(Into::into) .find(|protocol| { req_protocols .split(',') .any(|req_protocol| req_protocol.trim() == protocol) }) .map(|protocol| match protocol { Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), Cow::Borrowed(s) => HeaderValue::from_static(s), }); } self } /// Provide a callback to call if upgrading the connection fails. /// /// The connection upgrade is performed in a background task. If that fails this callback /// will be called. /// /// By default any errors will be silently ignored. /// /// # Example /// /// ``` /// use axum::{ /// extract::{WebSocketUpgrade}, /// response::Response, /// }; /// /// async fn handler(ws: WebSocketUpgrade) -> Response { /// ws.on_failed_upgrade(|error| { /// report_error(error); /// }) /// .on_upgrade(|socket| async { /* ... */ }) /// } /// # /// # fn report_error(_: axum::Error) {} /// ``` pub fn on_failed_upgrade(self, callback: C) -> WebSocketUpgrade where C: OnFailedUpdgrade, { WebSocketUpgrade { config: self.config, protocol: self.protocol, sec_websocket_key: self.sec_websocket_key, on_upgrade: self.on_upgrade, on_failed_upgrade: callback, sec_websocket_protocol: self.sec_websocket_protocol, } } /// Finalize upgrading the connection and call the provided callback with /// the stream. 
#[must_use = "to setup the WebSocket connection, this response must be returned"] pub fn on_upgrade(self, callback: C) -> Response where C: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future + Send + 'static, F: OnFailedUpdgrade, { let on_upgrade = self.on_upgrade; let config = self.config; let on_failed_upgrade = self.on_failed_upgrade; let protocol = self.protocol.clone(); tokio::spawn(async move { let upgraded = match on_upgrade.await { Ok(upgraded) => upgraded, Err(err) => { on_failed_upgrade.call(Error::new(err)); return; } }; let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) .await; let socket = WebSocket { inner: socket, protocol, }; callback(socket).await; }); #[allow(clippy::declare_interior_mutable_const)] const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); #[allow(clippy::declare_interior_mutable_const)] const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); let mut builder = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) .header(header::CONNECTION, UPGRADE) .header(header::UPGRADE, WEBSOCKET) .header( header::SEC_WEBSOCKET_ACCEPT, sign(self.sec_websocket_key.as_bytes()), ); if let Some(protocol) = self.protocol { builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); } builder.body(body::boxed(body::Empty::new())).unwrap() } } /// What to do when a connection upgrade fails. /// /// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. pub trait OnFailedUpdgrade: Send + 'static { /// Call the callback. fn call(self, error: Error); } impl OnFailedUpdgrade for F where F: FnOnce(Error) + Send + 'static, { fn call(self, error: Error) { self(error) } } /// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`. /// /// It simply ignores the error. 
#[non_exhaustive] #[derive(Debug)] pub struct DefaultOnFailedUpdgrade; impl OnFailedUpdgrade for DefaultOnFailedUpdgrade { #[inline] fn call(self, _error: Error) {} } #[async_trait] impl FromRequestParts for WebSocketUpgrade where S: Send + Sync, { type Rejection = WebSocketUpgradeRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if parts.method != Method::GET { return Err(MethodNotGet.into()); } if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { return Err(InvalidConnectionHeader.into()); } if !header_eq(&parts.headers, header::UPGRADE, "websocket") { return Err(InvalidUpgradeHeader.into()); } if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { return Err(InvalidWebSocketVersionHeader.into()); } let sec_websocket_key = parts .headers .get(header::SEC_WEBSOCKET_KEY) .ok_or(WebSocketKeyHeaderMissing)? .clone(); let on_upgrade = parts .extensions .remove::() .ok_or(ConnectionNotUpgradable)?; let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); Ok(Self { config: Default::default(), protocol: None, sec_websocket_key, on_upgrade, sec_websocket_protocol, on_failed_upgrade: DefaultOnFailedUpdgrade, }) } } fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { if let Some(header) = headers.get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { false } } fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool { let header = if let Some(header) = headers.get(&key) { header } else { return false; }; if let Ok(header) = std::str::from_utf8(header.as_bytes()) { header.to_ascii_lowercase().contains(value) } else { false } } /// A stream of WebSocket messages. /// /// See [the module level documentation](self) for more details. #[derive(Debug)] pub struct WebSocket { inner: WebSocketStream, protocol: Option, } impl WebSocket { /// Receive another message. 
/// /// Returns `None` if the stream has closed. pub async fn recv(&mut self) -> Option> { self.next().await } /// Send a message. pub async fn send(&mut self, msg: Message) -> Result<(), Error> { self.inner .send(msg.into_tungstenite()) .await .map_err(Error::new) } /// Gracefully close this WebSocket. pub async fn close(mut self) -> Result<(), Error> { self.inner.close(None).await.map_err(Error::new) } /// Return the selected WebSocket subprotocol, if one has been chosen. pub fn protocol(&self) -> Option<&HeaderValue> { self.protocol.as_ref() } } impl Stream for WebSocket { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { match futures_util::ready!(self.inner.poll_next_unpin(cx)) { Some(Ok(msg)) => { if let Some(msg) = Message::from_tungstenite(msg) { return Poll::Ready(Some(Ok(msg))); } } Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))), None => return Poll::Ready(None), } } } } impl Sink for WebSocket { type Error = Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new) } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { Pin::new(&mut self.inner) .start_send(item.into_tungstenite()) .map_err(Error::new) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new) } } /// Status code used to indicate why an endpoint is closing the WebSocket connection. pub type CloseCode = u16; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] pub struct CloseFrame<'t> { /// The reason as a code. pub code: CloseCode, /// The reason as text string. pub reason: Cow<'t, str>, } /// A WebSocket message. 
// // This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license: // Copyright (c) 2017 Alexey Galakhov // Copyright (c) 2016 Jason Housley // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message Text(String), /// A binary WebSocket message Binary(Vec), /// A ping message with the specified payload /// /// The payload here must have a length less than 125 bytes. /// /// Ping messages will be automatically responded to by the server, so you do not have to worry /// about dealing with them yourself. Ping(Vec), /// A pong message with the specified payload /// /// The payload here must have a length less than 125 bytes. 
/// /// Pong messages will be automatically sent to the client if a ping message is received, so /// you do not have to worry about constructing them yourself unless you want to implement a /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). Pong(Vec), /// A close message with the optional close frame. Close(Option>), } impl Message { fn into_tungstenite(self) -> ts::Message { match self { Self::Text(text) => ts::Message::Text(text), Self::Binary(binary) => ts::Message::Binary(binary), Self::Ping(ping) => ts::Message::Ping(ping), Self::Pong(pong) => ts::Message::Pong(pong), Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame { code: ts::protocol::frame::coding::CloseCode::from(close.code), reason: close.reason, })), Self::Close(None) => ts::Message::Close(None), } } fn from_tungstenite(message: ts::Message) -> Option { match message { ts::Message::Text(text) => Some(Self::Text(text)), ts::Message::Binary(binary) => Some(Self::Binary(binary)), ts::Message::Ping(ping) => Some(Self::Ping(ping)), ts::Message::Pong(pong) => Some(Self::Pong(pong)), ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame { code: close.code.into(), reason: close.reason, }))), ts::Message::Close(None) => Some(Self::Close(None)), // we can ignore `Frame` frames as recommended by the tungstenite maintainers // https://github.com/snapview/tungstenite-rs/issues/268 ts::Message::Frame(_) => None, } } /// Consume the WebSocket and return it as binary data. pub fn into_data(self) -> Vec { match self { Self::Text(string) => string.into_bytes(), Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data, Self::Close(None) => Vec::new(), Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), } } /// Attempt to consume the WebSocket message and convert it to a String. 
pub fn into_text(self) -> Result { match self { Self::Text(string) => Ok(string), Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data) .map_err(|err| err.utf8_error()) .map_err(Error::new)?), Self::Close(None) => Ok(String::new()), Self::Close(Some(frame)) => Ok(frame.reason.into_owned()), } } /// Attempt to get a &str from the WebSocket message, /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str, Error> { match *self { Self::Text(ref string) => Ok(string), Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => { Ok(std::str::from_utf8(data).map_err(Error::new)?) } Self::Close(None) => Ok(""), Self::Close(Some(ref frame)) => Ok(&frame.reason), } } } impl From for Message { fn from(string: String) -> Self { Message::Text(string) } } impl<'s> From<&'s str> for Message { fn from(string: &'s str) -> Self { Message::Text(string.into()) } } impl<'b> From<&'b [u8]> for Message { fn from(data: &'b [u8]) -> Self { Message::Binary(data.into()) } } impl From> for Message { fn from(data: Vec) -> Self { Message::Binary(data) } } impl From for Vec { fn from(msg: Message) -> Self { msg.into_data() } } fn sign(key: &[u8]) -> HeaderValue { use base64::engine::Engine as _; let mut sha1 = Sha1::default(); sha1.update(key); sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]); let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize())); HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value") } pub mod rejection { //! WebSocket specific rejections. use axum_core::__composite_rejection as composite_rejection; use axum_core::__define_rejection as define_rejection; define_rejection! { #[status = METHOD_NOT_ALLOWED] #[body = "Request method must be `GET`"] /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). pub struct MethodNotGet; } define_rejection! 
{ #[status = BAD_REQUEST] #[body = "Connection header did not include 'upgrade'"] /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). pub struct InvalidConnectionHeader; } define_rejection! { #[status = BAD_REQUEST] #[body = "`Upgrade` header did not include 'websocket'"] /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). pub struct InvalidUpgradeHeader; } define_rejection! { #[status = BAD_REQUEST] #[body = "`Sec-WebSocket-Version` header did not include '13'"] /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). pub struct InvalidWebSocketVersionHeader; } define_rejection! { #[status = BAD_REQUEST] #[body = "`Sec-WebSocket-Key` header missing"] /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). pub struct WebSocketKeyHeaderMissing; } define_rejection! { #[status = UPGRADE_REQUIRED] #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"] /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). /// /// This rejection is returned if the connection cannot be upgraded for example if the /// request is HTTP/1.0. /// /// See [MDN] for more details about connection upgrades. /// /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade pub struct ConnectionNotUpgradable; } composite_rejection! { /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade). /// /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade) /// extractor can fail. pub enum WebSocketUpgradeRejection { MethodNotGet, InvalidConnectionHeader, InvalidUpgradeHeader, InvalidWebSocketVersionHeader, WebSocketKeyHeaderMissing, ConnectionNotUpgradable, } } } pub mod close_code { //! Constants for [`CloseCode`]s. //! //! [`CloseCode`]: super::CloseCode /// Indicates a normal closure, meaning that the purpose for which the connection was /// established has been fulfilled. 
pub const NORMAL: u16 = 1000; /// Indicates that an endpoint is "going away", such as a server going down or a browser having /// navigated away from a page. pub const AWAY: u16 = 1001; /// Indicates that an endpoint is terminating the connection due to a protocol error. pub const PROTOCOL: u16 = 1002; /// Indicates that an endpoint is terminating the connection because it has received a type of /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if /// it receives a binary message). pub const UNSUPPORTED: u16 = 1003; /// Indicates that no status code was included in a closing frame. pub const STATUS: u16 = 1005; /// Indicates an abnormal closure. pub const ABNORMAL: u16 = 1006; /// Indicates that an endpoint is terminating the connection because it has received data /// within a message that was not consistent with the type of the message (e.g., non-UTF-8 /// RFC3629 data within a text message). pub const INVALID: u16 = 1007; /// Indicates that an endpoint is terminating the connection because it has received a message /// that violates its policy. This is a generic status code that can be returned when there is /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to /// hide specific details about the policy. pub const POLICY: u16 = 1008; /// Indicates that an endpoint is terminating the connection because it has received a message /// that is too big for it to process. pub const SIZE: u16 = 1009; /// Indicates that an endpoint (client) is terminating the connection because it has expected /// the server to negotiate one or more extension, but the server didn't return them in the /// response message of the WebSocket handshake. The list of extensions that are needed should /// be given as the reason for closing. Note that this status code is not used by the server, /// because it can fail the WebSocket handshake instead. 
pub const EXTENSION: u16 = 1010; /// Indicates that a server is terminating the connection because it encountered an unexpected /// condition that prevented it from fulfilling the request. pub const ERROR: u16 = 1011; /// Indicates that the server is restarting. pub const RESTART: u16 = 1012; /// Indicates that the server is overloaded and the client should either connect to a different /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an /// action. pub const AGAIN: u16 = 1013; } #[cfg(test)] mod tests { use super::*; use crate::{body::Body, routing::get, Router}; use http::{Request, Version}; use tower::ServiceExt; #[crate::test] async fn rejects_http_1_0_requests() { let svc = get(|ws: Result| { let rejection = ws.unwrap_err(); assert!(matches!( rejection, WebSocketUpgradeRejection::ConnectionNotUpgradable(_) )); std::future::ready(()) }); let req = Request::builder() .version(Version::HTTP_10) .method(Method::GET) .header("upgrade", "websocket") .header("connection", "Upgrade") .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==") .header("sec-websocket-version", "13") .body(Body::empty()) .unwrap(); let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); } #[allow(dead_code)] fn default_on_failed_upgrade() { async fn handler(ws: WebSocketUpgrade) -> Response { ws.on_upgrade(|_| async {}) } let _: Router = Router::new().route("/", get(handler)); } #[allow(dead_code)] fn on_failed_upgrade() { async fn handler(ws: WebSocketUpgrade) -> Response { ws.on_failed_upgrade(|_error: Error| println!("oops!")) .on_upgrade(|_| async {}) } let _: Router = Router::new().route("/", get(handler)); } } axum-axum-v0.6.20/axum/src/form.rs000066400000000000000000000163461446276533600170130ustar00rootroot00000000000000use crate::body::HttpBody; use crate::extract::{rejection::*, FromRequest, RawForm}; use crate::BoxError; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use 
axum_core::RequestExt; use http::header::CONTENT_TYPE; use http::{Request, StatusCode}; use serde::de::DeserializeOwned; use serde::Serialize; /// URL encoded extractor and response. /// /// # As extractor /// /// If used as an extractor `Form` will deserialize the query parameters for `GET` and `HEAD` /// requests and `application/x-www-form-urlencoded` encoded request bodies for other methods. It /// supports any type that implements [`serde::Deserialize`]. /// /// โš ๏ธ Since parsing form data might require consuming the request body, the `Form` extractor must be /// *last* if there are multiple extractors in a handler. See ["the order of /// extractors"][order-of-extractors] /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// ```rust /// use axum::Form; /// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct SignUp { /// username: String, /// password: String, /// } /// /// async fn accept_form(Form(sign_up): Form) { /// // ... /// } /// ``` /// /// Note that `Content-Type: multipart/form-data` requests are not supported. Use [`Multipart`] /// instead. 
/// /// # As response /// /// ```rust /// use axum::Form; /// use serde::Serialize; /// /// #[derive(Serialize)] /// struct Payload { /// value: String, /// } /// /// async fn handler() -> Form { /// Form(Payload { value: "foo".to_owned() }) /// } /// ``` /// /// [`Multipart`]: crate::extract::Multipart #[cfg_attr(docsrs, doc(cfg(feature = "form")))] #[derive(Debug, Clone, Copy, Default)] #[must_use] pub struct Form(pub T); #[async_trait] impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = FormRejection; async fn from_request(req: Request, _state: &S) -> Result { let is_get_or_head = req.method() == http::Method::GET || req.method() == http::Method::HEAD; match req.extract().await { Ok(RawForm(bytes)) => { let value = serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection { if is_get_or_head { FailedToDeserializeForm::from_err(err).into() } else { FailedToDeserializeFormBody::from_err(err).into() } })?; Ok(Form(value)) } Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)), Err(RawFormRejection::InvalidFormContentType(r)) => { Err(FormRejection::InvalidFormContentType(r)) } } } } impl IntoResponse for Form where T: Serialize, { fn into_response(self) -> Response { match serde_urlencoded::to_string(&self.0) { Ok(body) => ( [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())], body, ) .into_response(), Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(), } } } axum_core::__impl_deref!(Form); #[cfg(test)] mod tests { use super::*; use crate::{ body::{Empty, Full}, routing::{on, MethodFilter}, test_helpers::TestClient, Router, }; use bytes::Bytes; use http::{header::CONTENT_TYPE, Method, Request}; use mime::APPLICATION_WWW_FORM_URLENCODED; use serde::{Deserialize, Serialize}; use std::fmt::Debug; #[derive(Debug, PartialEq, Serialize, Deserialize)] struct Pagination { size: Option, page: 
Option, } async fn check_query(uri: impl AsRef, value: T) { let req = Request::builder() .uri(uri.as_ref()) .body(Empty::::new()) .unwrap(); assert_eq!(Form::::from_request(req, &()).await.unwrap().0, value); } async fn check_body(value: T) { let req = Request::builder() .uri("http://example.com/test") .method(Method::POST) .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref()) .body(Full::::new( serde_urlencoded::to_string(&value).unwrap().into(), )) .unwrap(); assert_eq!(Form::::from_request(req, &()).await.unwrap().0, value); } #[crate::test] async fn test_form_query() { check_query( "http://example.com/test", Pagination { size: None, page: None, }, ) .await; check_query( "http://example.com/test?size=10", Pagination { size: Some(10), page: None, }, ) .await; check_query( "http://example.com/test?size=10&page=20", Pagination { size: Some(10), page: Some(20), }, ) .await; } #[crate::test] async fn test_form_body() { check_body(Pagination { size: None, page: None, }) .await; check_body(Pagination { size: Some(10), page: None, }) .await; check_body(Pagination { size: Some(10), page: Some(20), }) .await; } #[crate::test] async fn test_incorrect_content_type() { let req = Request::builder() .uri("http://example.com/test") .method(Method::POST) .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) .body(Full::::new( serde_urlencoded::to_string(&Pagination { size: Some(10), page: None, }) .unwrap() .into(), )) .unwrap(); assert!(matches!( Form::::from_request(req, &()) .await .unwrap_err(), FormRejection::InvalidFormContentType(InvalidFormContentType) )); } #[tokio::test] async fn deserialize_error_status_codes() { #[allow(dead_code)] #[derive(Deserialize)] struct Payload { a: i32, } let app = Router::new().route( "/", on( MethodFilter::GET | MethodFilter::POST, |_: Form| async {}, ), ); let client = TestClient::new(app); let res = client.get("/?a=false").send().await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); let res = client .post("/") 
.header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref()) .body("a=false") .send() .await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); } } axum-axum-v0.6.20/axum/src/handler/000077500000000000000000000000001446276533600171055ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/handler/future.rs000066400000000000000000000023461446276533600207720ustar00rootroot00000000000000//! Handler future types. use crate::response::Response; use futures_util::future::Map; use http::Request; use pin_project_lite::pin_project; use std::{convert::Infallible, future::Future, pin::Pin, task::Context}; use tower::util::Oneshot; use tower_service::Service; opaque_future! { /// The response future for [`IntoService`](super::IntoService). pub type IntoServiceFuture = Map< F, fn(Response) -> Result, >; } pin_project! { /// The response future for [`Layered`](super::Layered). pub struct LayeredFuture where S: Service>, { #[pin] inner: Map>, fn(Result) -> Response>, } } impl LayeredFuture where S: Service>, { pub(super) fn new( inner: Map>, fn(Result) -> Response>, ) -> Self { Self { inner } } } impl Future for LayeredFuture where S: Service>, { type Output = Response; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll { self.project().inner.poll(cx) } } axum-axum-v0.6.20/axum/src/handler/into_service_state_in_extension.rs000066400000000000000000000045131446276533600261310ustar00rootroot00000000000000use super::Handler; use crate::response::Response; use http::Request; use std::{ convert::Infallible, fmt, marker::PhantomData, task::{Context, Poll}, }; use tower_service::Service; pub(crate) struct IntoServiceStateInExtension { handler: H, _marker: PhantomData (T, S, B)>, } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); assert_sync::>(); } impl IntoServiceStateInExtension { pub(crate) fn new(handler: H) -> Self { Self { handler, _marker: PhantomData, } } } impl fmt::Debug for IntoServiceStateInExtension { fn 
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IntoServiceStateInExtension") .finish_non_exhaustive() } } impl Clone for IntoServiceStateInExtension where H: Clone, { fn clone(&self) -> Self { Self { handler: self.handler.clone(), _marker: PhantomData, } } } impl Service> for IntoServiceStateInExtension where H: Handler + Clone + Send + 'static, B: Send + 'static, S: Send + Sync + 'static, { type Response = Response; type Error = Infallible; type Future = super::future::IntoServiceFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { // `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or // from `Layered` which buffers in `::call` and is therefore // also always ready. Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { use futures_util::future::FutureExt; let state = req .extensions_mut() .remove::() .expect("state extension missing. This is a bug in axum, please file an issue"); let handler = self.handler.clone(); let future = Handler::call(handler, req, state); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) } } axum-axum-v0.6.20/axum/src/handler/mod.rs000066400000000000000000000334171446276533600202420ustar00rootroot00000000000000//! Async functions that can be used to handle requests. //! #![doc = include_str!("../docs/handlers_intro.md")] //! //! Some examples of handlers: //! //! ```rust //! use axum::{body::Bytes, http::StatusCode}; //! //! // Handler that immediately returns an empty `200 OK` response. //! async fn unit_handler() {} //! //! // Handler that immediately returns an empty `200 OK` response with a plain //! // text body. //! async fn string_handler() -> String { //! "Hello, World!".to_string() //! } //! //! // Handler that buffers the request body and returns it. //! // //! // This works because `Bytes` implements `FromRequest` //! // and therefore can be used as an extractor. //! // //! 
// `String` and `StatusCode` both implement `IntoResponse` and //! // therefore `Result` also implements `IntoResponse` //! async fn echo(body: Bytes) -> Result { //! if let Ok(string) = String::from_utf8(body.to_vec()) { //! Ok(string) //! } else { //! Err(StatusCode::BAD_REQUEST) //! } //! } //! ``` //! //! Instead of a direct `StatusCode`, it makes sense to use intermediate error type //! that can ultimately be converted to `Response`. This allows using `?` operator //! in handlers. See those examples: //! //! * [`anyhow-error-response`][anyhow] for generic boxed errors //! * [`error-handling-and-dependency-injection`][ehdi] for application-specific detailed errors //! //! [anyhow]: https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs //! [ehdi]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling-and-dependency-injection/src/main.rs //! #![doc = include_str!("../docs/debugging_handler_type_errors.md")] #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::Body, extract::{FromRequest, FromRequestParts}, response::{IntoResponse, Response}, routing::IntoMakeService, }; use http::Request; use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin}; use tower::ServiceExt; use tower_layer::Layer; use tower_service::Service; pub mod future; mod service; pub use self::service::HandlerService; /// Trait for async functions that can be used to handle requests. /// /// You shouldn't need to depend on this trait directly. It is automatically /// implemented to closures of the right types. /// /// See the [module docs](crate::handler) for more details. 
/// /// # Converting `Handler`s into [`Service`]s /// /// To convert `Handler`s into [`Service`]s you have to call either /// [`HandlerWithoutStateExt::into_service`] or [`Handler::with_state`]: /// /// ``` /// use tower::Service; /// use axum::{ /// extract::State, /// body::Body, /// http::Request, /// handler::{HandlerWithoutStateExt, Handler}, /// }; /// /// // this handler doesn't require any state /// async fn one() {} /// // so it can be converted to a service with `HandlerWithoutStateExt::into_service` /// assert_service(one.into_service()); /// /// // this handler requires state /// async fn two(_: State) {} /// // so we have to provide it /// let handler_with_state = two.with_state(String::new()); /// // which gives us a `Service` /// assert_service(handler_with_state); /// /// // helper to check that a value implements `Service` /// fn assert_service(service: S) /// where /// S: Service>, /// {} /// ``` #[doc = include_str!("../docs/debugging_handler_type_errors.md")] /// /// # Handlers that aren't functions /// /// The `Handler` trait is also implemented for `T: IntoResponse`. That allows easily returning /// fixed data for routes: /// /// ``` /// use axum::{ /// Router, /// routing::{get, post}, /// Json, /// http::StatusCode, /// }; /// use serde_json::json; /// /// let app = Router::new() /// // respond with a fixed string /// .route("/", get("Hello, World!")) /// // or return some mock data /// .route("/users", post(( /// StatusCode::CREATED, /// Json(json!({ "id": 1, "username": "alice" })), /// ))); /// # let _: Router = app; /// ``` #[cfg_attr( nightly_error_messages, rustc_on_unimplemented( note = "Consider using `#[axum::debug_handler]` to improve the error message" ) )] pub trait Handler: Clone + Send + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; /// Call the handler with the given request. 
fn call(self, req: Request, state: S) -> Self::Future; /// Apply a [`tower::Layer`] to the handler. /// /// All requests to the handler will be processed by the layer's /// corresponding middleware. /// /// This can be used to add additional processing to a request for a single /// handler. /// /// Note this differs from [`routing::Router::layer`](crate::routing::Router::layer) /// which adds a middleware to a group of routes. /// /// If you're applying middleware that produces errors you have to handle the errors /// so they're converted into responses. You can learn more about doing that /// [here](crate::error_handling). /// /// # Example /// /// Adding the [`tower::limit::ConcurrencyLimit`] middleware to a handler /// can be done like so: /// /// ```rust /// use axum::{ /// routing::get, /// handler::Handler, /// Router, /// }; /// use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit}; /// /// async fn handler() { /* ... */ } /// /// let layered_handler = handler.layer(ConcurrencyLimitLayer::new(64)); /// let app = Router::new().route("/", get(layered_handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` fn layer(self, layer: L) -> Layered where L: Layer> + Clone, L::Service: Service>, { Layered { layer, handler: self, _marker: PhantomData, } } /// Convert the handler into a [`Service`] by providing the state fn with_state(self, state: S) -> HandlerService { HandlerService::new(self, state) } } impl Handler<((),), S, B> for F where F: FnOnce() -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, B: Send + 'static, { type Future = Pin + Send>>; fn call(self, _req: Request, _state: S) -> Self::Future { Box::pin(async move { self().await.into_response() }) } } macro_rules! 
impl_handler { ( [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused_mut)] impl Handler<(M, $($ty,)* $last,), S, B> for F where F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static, Fut: Future + Send, B: Send + 'static, S: Send + Sync + 'static, Res: IntoResponse, $( $ty: FromRequestParts + Send, )* $last: FromRequest + Send, { type Future = Pin + Send>>; fn call(self, req: Request, state: S) -> Self::Future { Box::pin(async move { let (mut parts, body) = req.into_parts(); let state = &state; $( let $ty = match $ty::from_request_parts(&mut parts, state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* let req = Request::from_parts(parts, body); let $last = match $last::from_request(req, state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; let res = self($($ty,)* $last,).await; res.into_response() }) } } }; } all_the_tuples!(impl_handler); mod private { // Marker type for `impl Handler for T` #[allow(missing_debug_implementations)] pub enum IntoResponseHandler {} } impl Handler for T where T: IntoResponse + Clone + Send + 'static, B: Send + 'static, { type Future = std::future::Ready; fn call(self, _req: Request, _state: S) -> Self::Future { std::future::ready(self.into_response()) } } /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// /// Created with [`Handler::layer`]. See that method for more details. 
pub struct Layered { layer: L, handler: H, _marker: PhantomData (T, S, B, B2)>, } impl fmt::Debug for Layered where L: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Layered") .field("layer", &self.layer) .finish() } } impl Clone for Layered where L: Clone, H: Clone, { fn clone(&self) -> Self { Self { layer: self.layer.clone(), handler: self.handler.clone(), _marker: PhantomData, } } } impl Handler for Layered where L: Layer> + Clone + Send + 'static, H: Handler, L::Service: Service, Error = Infallible> + Clone + Send + 'static, >>::Response: IntoResponse, >>::Future: Send, T: 'static, S: 'static, B: Send + 'static, B2: Send + 'static, { type Future = future::LayeredFuture; fn call(self, req: Request, state: S) -> Self::Future { use futures_util::future::{FutureExt, Map}; let svc = self.handler.with_state(state); let svc = self.layer.layer(svc); let future: Map< _, fn( Result< >>::Response, >>::Error, >, ) -> _, > = svc.oneshot(req).map(|result| match result { Ok(res) => res.into_response(), Err(err) => match err {}, }); future::LayeredFuture::new(future) } } /// Extension trait for [`Handler`]s that don't have state. /// /// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`]. /// /// [`MakeService`]: tower::make::MakeService pub trait HandlerWithoutStateExt: Handler { /// Convert the handler into a [`Service`] and no state. fn into_service(self) -> HandlerService; /// Convert the handler into a [`MakeService`] and no state. /// /// See [`HandlerService::into_make_service`] for more details. /// /// [`MakeService`]: tower::make::MakeService fn into_make_service(self) -> IntoMakeService>; /// Convert the handler into a [`MakeService`] which stores information /// about the incoming connection and has no state. /// /// See [`HandlerService::into_make_service_with_connect_info`] for more details. 
/// /// [`MakeService`]: tower::make::MakeService #[cfg(feature = "tokio")] fn into_make_service_with_connect_info( self, ) -> IntoMakeServiceWithConnectInfo, C>; } impl HandlerWithoutStateExt for H where H: Handler, { fn into_service(self) -> HandlerService { self.with_state(()) } fn into_make_service(self) -> IntoMakeService> { self.into_service().into_make_service() } #[cfg(feature = "tokio")] fn into_make_service_with_connect_info( self, ) -> IntoMakeServiceWithConnectInfo, C> { self.into_service().into_make_service_with_connect_info() } } #[cfg(test)] mod tests { use super::*; use crate::{body, extract::State, test_helpers::*}; use http::StatusCode; use std::time::Duration; use tower_http::{ compression::CompressionLayer, limit::RequestBodyLimitLayer, map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer, timeout::TimeoutLayer, }; #[crate::test] async fn handler_into_service() { async fn handle(body: String) -> impl IntoResponse { format!("you said: {body}") } let client = TestClient::new(handle.into_service()); let res = client.post("/").body("hi there!").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "you said: hi there!"); } #[crate::test] async fn with_layer_that_changes_request_body_and_state() { async fn handle(State(state): State<&'static str>) -> &'static str { state } let svc = handle .layer(( RequestBodyLimitLayer::new(1024), TimeoutLayer::new(Duration::from_secs(10)), MapResponseBodyLayer::new(body::boxed), CompressionLayer::new(), )) .layer(MapRequestBodyLayer::new(body::boxed)) .with_state("foo"); let client = TestClient::new(svc); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } } axum-axum-v0.6.20/axum/src/handler/service.rs000066400000000000000000000117041446276533600211160ustar00rootroot00000000000000use super::Handler; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::response::Response; use 
crate::routing::IntoMakeService; use http::Request; use std::{ convert::Infallible, fmt, marker::PhantomData, task::{Context, Poll}, }; use tower_service::Service; /// An adapter that makes a [`Handler`] into a [`Service`]. /// /// Created with [`Handler::with_state`] or [`HandlerWithoutStateExt::into_service`]. /// /// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service pub struct HandlerService { handler: H, state: S, _marker: PhantomData (T, B)>, } impl HandlerService { /// Get a reference to the state. pub fn state(&self) -> &S { &self.state } /// Convert the handler into a [`MakeService`]. /// /// This allows you to serve a single handler if you don't need any routing: /// /// ```rust /// use axum::{ /// Server, /// handler::Handler, /// extract::State, /// http::{Uri, Method}, /// response::IntoResponse, /// }; /// use std::net::SocketAddr; /// /// #[derive(Clone)] /// struct AppState {} /// /// async fn handler(State(state): State) { /// // ... /// } /// /// let app = handler.with_state(AppState {}); /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) /// .serve(app.into_make_service()) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService pub fn into_make_service(self) -> IntoMakeService> { IntoMakeService::new(self) } /// Convert the handler into a [`MakeService`] which stores information /// about the incoming connection. /// /// See [`Router::into_make_service_with_connect_info`] for more details. 
/// /// ```rust /// use axum::{ /// Server, /// handler::Handler, /// response::IntoResponse, /// extract::{ConnectInfo, State}, /// }; /// use std::net::SocketAddr; /// /// #[derive(Clone)] /// struct AppState {}; /// /// async fn handler( /// ConnectInfo(addr): ConnectInfo, /// State(state): State, /// ) -> String { /// format!("Hello {}", addr) /// } /// /// let app = handler.with_state(AppState {}); /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) /// .serve(app.into_make_service_with_connect_info::()) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[cfg(feature = "tokio")] pub fn into_make_service_with_connect_info( self, ) -> IntoMakeServiceWithConnectInfo, C> { IntoMakeServiceWithConnectInfo::new(self) } } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); assert_sync::>(); } impl HandlerService { pub(super) fn new(handler: H, state: S) -> Self { Self { handler, state, _marker: PhantomData, } } } impl fmt::Debug for HandlerService { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IntoService").finish_non_exhaustive() } } impl Clone for HandlerService where H: Clone, S: Clone, { fn clone(&self) -> Self { Self { handler: self.handler.clone(), state: self.state.clone(), _marker: PhantomData, } } } impl Service> for HandlerService where H: Handler + Clone + Send + 'static, B: Send + 'static, S: Clone + Send + Sync, { type Response = Response; type Error = Infallible; type Future = super::future::IntoServiceFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { // `IntoService` can only be constructed from async functions which are always ready, or // from `Layered` which buffers in `::call` and is therefore // also always ready. 
Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { use futures_util::future::FutureExt; let handler = self.handler.clone(); let future = Handler::call(handler, req, self.state.clone()); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) } } axum-axum-v0.6.20/axum/src/json.rs000066400000000000000000000227411446276533600170150ustar00rootroot00000000000000use crate::{ body::{Bytes, HttpBody}, extract::{rejection::*, FromRequest}, BoxError, }; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use bytes::{BufMut, BytesMut}; use http::{ header::{self, HeaderMap, HeaderValue}, Request, StatusCode, }; use serde::{de::DeserializeOwned, Serialize}; /// JSON Extractor / Response. /// /// When used as an extractor, it can deserialize request bodies into some type that /// implements [`serde::Deserialize`]. The request will be rejected (and a [`JsonRejection`] will /// be returned) if: /// /// - The request doesn't have a `Content-Type: application/json` (or similar) header. /// - The body doesn't contain syntactically valid JSON. /// - The body contains syntactically valid JSON but it couldn't be deserialized into the target /// type. /// - Buffering the request body fails. /// /// โš ๏ธ Since parsing JSON requires consuming the request body, the `Json` extractor must be /// *last* if there are multiple extractors in a handler. /// See ["the order of extractors"][order-of-extractors] /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// See [`JsonRejection`] for more details. 
/// /// # Extractor example /// /// ```rust,no_run /// use axum::{ /// extract, /// routing::post, /// Router, /// }; /// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct CreateUser { /// email: String, /// password: String, /// } /// /// async fn create_user(extract::Json(payload): extract::Json) { /// // payload is a `CreateUser` /// } /// /// let app = Router::new().route("/users", post(create_user)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// When used as a response, it can serialize any type that implements [`serde::Serialize`] to /// `JSON`, and will automatically set `Content-Type: application/json` header. /// /// # Response example /// /// ``` /// use axum::{ /// extract::Path, /// routing::get, /// Router, /// Json, /// }; /// use serde::Serialize; /// use uuid::Uuid; /// /// #[derive(Serialize)] /// struct User { /// id: Uuid, /// username: String, /// } /// /// async fn get_user(Path(user_id) : Path) -> Json { /// let user = find_user(user_id).await; /// Json(user) /// } /// /// async fn find_user(user_id: Uuid) -> User { /// // ... 
/// # unimplemented!() /// } /// /// let app = Router::new().route("/users/:id", get(get_user)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[derive(Debug, Clone, Copy, Default)] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] #[must_use] pub struct Json(pub T); #[async_trait] impl FromRequest for Json where T: DeserializeOwned, B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, S: Send + Sync, { type Rejection = JsonRejection; async fn from_request(req: Request, state: &S) -> Result { if json_content_type(req.headers()) { let bytes = Bytes::from_request(req, state).await?; let deserializer = &mut serde_json::Deserializer::from_slice(&bytes); let value = match serde_path_to_error::deserialize(deserializer) { Ok(value) => value, Err(err) => { let rejection = match err.inner().classify() { serde_json::error::Category::Data => JsonDataError::from_err(err).into(), serde_json::error::Category::Syntax | serde_json::error::Category::Eof => { JsonSyntaxError::from_err(err).into() } serde_json::error::Category::Io => { if cfg!(debug_assertions) { // we don't use `serde_json::from_reader` and instead always buffer // bodies first, so we shouldn't encounter any IO errors unreachable!() } else { JsonSyntaxError::from_err(err).into() } } }; return Err(rejection); } }; Ok(Json(value)) } else { Err(MissingJsonContentType.into()) } } } fn json_content_type(headers: &HeaderMap) -> bool { let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { content_type } else { return false; }; let content_type = if let Ok(content_type) = content_type.to_str() { content_type } else { return false; }; let mime = if let Ok(mime) = content_type.parse::() { mime } else { return false; }; let is_json_content_type = mime.type_() == "application" && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); is_json_content_type } axum_core::__impl_deref!(Json); 
impl From for Json { fn from(inner: T) -> Self { Self(inner) } } impl IntoResponse for Json where T: Serialize, { fn into_response(self) -> Response { // Use a small initial capacity of 128 bytes like serde_json::to_vec // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189 let mut buf = BytesMut::with_capacity(128).writer(); match serde_json::to_writer(&mut buf, &self.0) { Ok(()) => ( [( header::CONTENT_TYPE, HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), )], buf.into_inner().freeze(), ) .into_response(), Err(err) => ( StatusCode::INTERNAL_SERVER_ERROR, [( header::CONTENT_TYPE, HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), )], err.to_string(), ) .into_response(), } } } #[cfg(test)] mod tests { use super::*; use crate::{routing::post, test_helpers::*, Router}; use serde::Deserialize; use serde_json::{json, Value}; #[crate::test] async fn deserialize_body() { #[derive(Debug, Deserialize)] struct Input { foo: String, } let app = Router::new().route("/", post(|input: Json| async { input.0.foo })); let client = TestClient::new(app); let res = client.post("/").json(&json!({ "foo": "bar" })).send().await; let body = res.text().await; assert_eq!(body, "bar"); } #[crate::test] async fn consume_body_to_json_requires_json_content_type() { #[derive(Debug, Deserialize)] struct Input { foo: String, } let app = Router::new().route("/", post(|input: Json| async { input.0.foo })); let client = TestClient::new(app); let res = client.post("/").body(r#"{ "foo": "bar" }"#).send().await; let status = res.status(); assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE); } #[crate::test] async fn json_content_types() { async fn valid_json_content_type(content_type: &str) -> bool { println!("testing {content_type:?}"); let app = Router::new().route("/", post(|Json(_): Json| async {})); let res = TestClient::new(app) .post("/") .header("content-type", content_type) .body("{}") .send() .await; res.status() == StatusCode::OK } 
assert!(valid_json_content_type("application/json").await); assert!(valid_json_content_type("application/json; charset=utf-8").await); assert!(valid_json_content_type("application/json;charset=utf-8").await); assert!(valid_json_content_type("application/cloudevents+json").await); assert!(!valid_json_content_type("text/json").await); } #[crate::test] async fn invalid_json_syntax() { let app = Router::new().route("/", post(|_: Json| async {})); let client = TestClient::new(app); let res = client .post("/") .body("{") .header("content-type", "application/json") .send() .await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); } #[derive(Deserialize)] struct Foo { #[allow(dead_code)] a: i32, #[allow(dead_code)] b: Vec, } #[derive(Deserialize)] struct Bar { #[allow(dead_code)] x: i32, #[allow(dead_code)] y: i32, } #[crate::test] async fn invalid_json_data() { let app = Router::new().route("/", post(|_: Json| async {})); let client = TestClient::new(app); let res = client .post("/") .body("{\"a\": 1, \"b\": [{\"x\": 2}]}") .header("content-type", "application/json") .send() .await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); let body_text = res.text().await; assert_eq!( body_text, "Failed to deserialize the JSON body into the target type: b[0]: missing field `y` at line 1 column 23" ); } } axum-axum-v0.6.20/axum/src/lib.rs000066400000000000000000000364701446276533600166160ustar00rootroot00000000000000#![cfg_attr(nightly_error_messages, feature(rustc_attrs))] //! axum is a web application framework that focuses on ergonomics and modularity. //! //! # Table of contents //! //! - [High-level features](#high-level-features) //! - [Compatibility](#compatibility) //! - [Example](#example) //! - [Routing](#routing) //! - [Handlers](#handlers) //! - [Extractors](#extractors) //! - [Responses](#responses) //! - [Error handling](#error-handling) //! - [Middleware](#middleware) //! - [Sharing state with handlers](#sharing-state-with-handlers) //! 
- [Building integrations for axum](#building-integrations-for-axum) //! - [Required dependencies](#required-dependencies) //! - [Examples](#examples) //! - [Feature flags](#feature-flags) //! //! # High-level features //! //! - Route requests to handlers with a macro-free API. //! - Declaratively parse requests using extractors. //! - Simple and predictable error handling model. //! - Generate responses with minimal boilerplate. //! - Take full advantage of the [`tower`] and [`tower-http`] ecosystem of //! middleware, services, and utilities. //! //! In particular, the last point is what sets `axum` apart from other frameworks. //! `axum` doesn't have its own middleware system but instead uses //! [`tower::Service`]. This means `axum` gets timeouts, tracing, compression, //! authorization, and more, for free. It also enables you to share middleware with //! applications written using [`hyper`] or [`tonic`]. //! //! # Compatibility //! //! axum is designed to work with [tokio] and [hyper]. Runtime and //! transport layer independence is not a goal, at least for the time being. //! //! # Example //! //! The "Hello, World!" of axum is: //! //! ```rust,no_run //! use axum::{ //! routing::get, //! Router, //! }; //! //! #[tokio::main] //! async fn main() { //! // build our application with a single route //! let app = Router::new().route("/", get(|| async { "Hello, World!" })); //! //! // run it with hyper on localhost:3000 //! axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) //! .serve(app.into_make_service()) //! .await //! .unwrap(); //! } //! ``` //! //! Note using `#[tokio::main]` requires you enable tokio's `macros` and `rt-multi-thread` features //! or just `full` to enable all features (`cargo add tokio --features macros,rt-multi-thread`). //! //! # Routing //! //! [`Router`] is used to setup which paths goes to which services: //! //! ```rust //! use axum::{Router, routing::get}; //! //! // our router //! let app = Router::new() //! 
.route("/", get(root)) //! .route("/foo", get(get_foo).post(post_foo)) //! .route("/foo/bar", get(foo_bar)); //! //! // which calls one of these handlers //! async fn root() {} //! async fn get_foo() {} //! async fn post_foo() {} //! async fn foo_bar() {} //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! See [`Router`] for more details on routing. //! //! # Handlers //! #![doc = include_str!("docs/handlers_intro.md")] //! //! See [`handler`](crate::handler) for more details on handlers. //! //! # Extractors //! //! An extractor is a type that implements [`FromRequest`] or [`FromRequestParts`]. Extractors are //! how you pick apart the incoming request to get the parts your handler needs. //! //! ```rust //! use axum::extract::{Path, Query, Json}; //! use std::collections::HashMap; //! //! // `Path` gives you the path parameters and deserializes them. //! async fn path(Path(user_id): Path) {} //! //! // `Query` gives you the query parameters and deserializes them. //! async fn query(Query(params): Query>) {} //! //! // Buffer the request body and deserialize it as JSON into a //! // `serde_json::Value`. `Json` supports any type that implements //! // `serde::Deserialize`. //! async fn json(Json(payload): Json) {} //! ``` //! //! See [`extract`](crate::extract) for more details on extractors. //! //! # Responses //! //! Anything that implements [`IntoResponse`] can be returned from handlers. //! //! ```rust,no_run //! use axum::{ //! body::Body, //! routing::get, //! response::Json, //! Router, //! }; //! use serde_json::{Value, json}; //! //! // `&'static str` becomes a `200 OK` with `content-type: text/plain; charset=utf-8` //! async fn plain_text() -> &'static str { //! "foo" //! } //! //! // `Json` gives a content-type of `application/json` and works with any type //! // that implements `serde::Serialize` //! async fn json() -> Json { //! Json(json!({ "data": 42 })) //! } //! //! 
let app = Router::new() //! .route("/plain_text", get(plain_text)) //! .route("/json", get(json)); //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! See [`response`](crate::response) for more details on building responses. //! //! # Error handling //! //! axum aims to have a simple and predictable error handling model. That means //! it is simple to convert errors into responses and you are guaranteed that //! all errors are handled. //! //! See [`error_handling`](crate::error_handling) for more details on axum's //! error handling model and how to handle errors gracefully. //! //! # Middleware //! //! There are several different ways to write middleware for axum. See //! [`middleware`](crate::middleware) for more details. //! //! # Sharing state with handlers //! //! It is common to share some state between handlers. For example, a //! pool of database connections or clients to other services may need to //! be shared. //! //! The three most common ways of doing that are: //! - Using the [`State`] extractor //! - Using request extensions //! - Using closure captures //! //! ## Using the [`State`] extractor //! //! ```rust,no_run //! use axum::{ //! extract::State, //! routing::get, //! Router, //! }; //! use std::sync::Arc; //! //! struct AppState { //! // ... //! } //! //! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route("/", get(handler)) //! .with_state(shared_state); //! //! async fn handler( //! State(state): State>, //! ) { //! // ... //! } //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! You should prefer using [`State`] if possible since it's more type safe. The downside is that //! it's less dynamic than request extensions. //! //! See [`State`] for more details about accessing state. //! //! ## Using request extensions //! //! 
Another way to extract state in handlers is using [`Extension`](crate::extract::Extension) as //! layer and extractor: //! //! ```rust,no_run //! use axum::{ //! extract::Extension, //! routing::get, //! Router, //! }; //! use std::sync::Arc; //! //! struct AppState { //! // ... //! } //! //! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route("/", get(handler)) //! .layer(Extension(shared_state)); //! //! async fn handler( //! Extension(state): Extension>, //! ) { //! // ... //! } //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! The downside to this approach is that you'll get runtime errors //! (specifically a `500 Internal Server Error` response) if you try and extract //! an extension that doesn't exist, perhaps because you forgot to add the //! middleware or because you're extracting the wrong type. //! //! ## Using closure captures //! //! State can also be passed directly to handlers using closure captures: //! //! ```rust,no_run //! use axum::{ //! Json, //! extract::{Extension, Path}, //! routing::{get, post}, //! Router, //! }; //! use std::sync::Arc; //! use serde::Deserialize; //! //! struct AppState { //! // ... //! } //! //! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route( //! "/users", //! post({ //! let shared_state = Arc::clone(&shared_state); //! move |body| create_user(body, shared_state) //! }), //! ) //! .route( //! "/users/:id", //! get({ //! let shared_state = Arc::clone(&shared_state); //! move |path| get_user(path, shared_state) //! }), //! ); //! //! async fn get_user(Path(user_id): Path, state: Arc) { //! // ... //! } //! //! async fn create_user(Json(payload): Json, state: Arc) { //! // ... //! } //! //! #[derive(Deserialize)] //! struct CreateUserPayload { //! // ... //! } //! # async { //! 
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! //! The downside to this approach is that it's a little more verbose than using //! [`State`] or extensions. //! //! # Building integrations for axum //! //! Libraries authors that want to provide [`FromRequest`], [`FromRequestParts`], or //! [`IntoResponse`] implementations should depend on the [`axum-core`] crate, instead of `axum` if //! possible. [`axum-core`] contains core types and traits and is less likely to receive breaking //! changes. //! //! # Required dependencies //! //! To use axum there are a few dependencies you have to pull in as well: //! //! ```toml //! [dependencies] //! axum = "" //! hyper = { version = "", features = ["full"] } //! tokio = { version = "", features = ["full"] } //! tower = "" //! ``` //! //! The `"full"` feature for hyper and tokio isn't strictly necessary but it's //! the easiest way to get started. //! //! Note that [`hyper::Server`] is re-exported by axum so if that's all you need //! then you don't have to explicitly depend on hyper. //! //! Tower isn't strictly necessary either but helpful for testing. See the //! testing example in the repo to learn more about testing axum apps. //! //! # Examples //! //! The axum repo contains [a number of examples][examples] that show how to put all the //! pieces together. //! //! # Feature flags //! //! axum uses a set of [feature flags] to reduce the amount of compiled and //! optional dependencies. //! //! The following optional features are available: //! //! Name | Description | Default? //! ---|---|--- //! `headers` | Enables extracting typed headers via [`TypedHeader`] | No //! `http1` | Enables hyper's `http1` feature | Yes //! `http2` | Enables hyper's `http2` feature | No //! `json` | Enables the [`Json`] type and some similar convenience functionality | Yes //! `macros` | Enables optional utility macros | No //! 
`matched-path` | Enables capturing of every request's router path and the [`MatchedPath`] extractor | Yes //! `multipart` | Enables parsing `multipart/form-data` requests with [`Multipart`] | No //! `original-uri` | Enables capturing of every request's original URI and the [`OriginalUri`] extractor | Yes //! `tokio` | Enables `tokio` as a dependency and `axum::Server`, `SSE` and `extract::connect_info` types. | Yes //! `tower-log` | Enables `tower`'s `log` feature | Yes //! `tracing` | Log rejections from built-in extractors | No //! `ws` | Enables WebSockets support via [`extract::ws`] | No //! `form` | Enables the `Form` extractor | Yes //! `query` | Enables the `Query` extractor | Yes //! //! [`TypedHeader`]: crate::extract::TypedHeader //! [`MatchedPath`]: crate::extract::MatchedPath //! [`Multipart`]: crate::extract::Multipart //! [`OriginalUri`]: crate::extract::OriginalUri //! [`tower`]: https://crates.io/crates/tower //! [`tower-http`]: https://crates.io/crates/tower-http //! [`tokio`]: http://crates.io/crates/tokio //! [`hyper`]: http://crates.io/crates/hyper //! [`tonic`]: http://crates.io/crates/tonic //! [feature flags]: https://doc.rust-lang.org/cargo/reference/features.html#the-features-section //! [`IntoResponse`]: crate::response::IntoResponse //! [`Timeout`]: tower::timeout::Timeout //! [examples]: https://github.com/tokio-rs/axum/tree/main/examples //! [`Router::merge`]: crate::routing::Router::merge //! [`axum::Server`]: hyper::server::Server //! [`Service`]: tower::Service //! [`Service::poll_ready`]: tower::Service::poll_ready //! [`Service`'s]: tower::Service //! [`tower::Service`]: tower::Service //! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides //! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ //! [`FromRequest`]: crate::extract::FromRequest //! [`FromRequestParts`]: crate::extract::FromRequestParts //! [`HeaderMap`]: http::header::HeaderMap //! [`Request`]: http::Request //! 
[customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs //! [axum-macros]: https://docs.rs/axum-macros //! [`debug_handler`]: https://docs.rs/axum-macros/latest/axum_macros/attr.debug_handler.html //! [`Handler`]: crate::handler::Handler //! [`Infallible`]: std::convert::Infallible //! [load shed]: tower::load_shed //! [`axum-core`]: http://crates.io/crates/axum-core //! [`State`]: crate::extract::State #![warn( clippy::all, clippy::todo, clippy::empty_enum, clippy::enum_glob_use, clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, clippy::needless_continue, clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, clippy::suboptimal_flops, clippy::lossy_float_literal, clippy::rest_pat_in_fully_bound_structs, clippy::fn_params_excessive_bools, clippy::exit, clippy::inefficient_to_string, clippy::linkedlist, clippy::macro_use_imports, clippy::option_option, clippy::verbose_file_reads, clippy::unnested_or_patterns, clippy::str_to_string, rust_2018_idioms, future_incompatible, nonstandard_style, missing_debug_implementations, missing_docs )] #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] #![forbid(unsafe_code)] #![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] #[macro_use] pub(crate) mod macros; mod boxed; mod extension; #[cfg(feature = "form")] mod form; #[cfg(feature = "json")] mod json; mod service_ext; #[cfg(feature = "headers")] mod typed_header; mod util; pub mod body; pub mod error_handling; pub mod extract; pub mod handler; pub mod middleware; pub mod response; pub mod routing; #[cfg(test)] mod test_helpers; #[doc(no_inline)] pub use async_trait::async_trait; 
#[cfg(feature = "headers")] #[doc(no_inline)] pub use headers; #[doc(no_inline)] pub use http; #[cfg(feature = "tokio")] #[doc(no_inline)] pub use hyper::Server; #[doc(inline)] pub use self::extension::Extension; #[doc(inline)] #[cfg(feature = "json")] pub use self::json::Json; #[doc(inline)] pub use self::routing::Router; #[doc(inline)] #[cfg(feature = "headers")] pub use self::typed_header::TypedHeader; #[doc(inline)] #[cfg(feature = "form")] pub use self::form::Form; #[doc(inline)] pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt}; #[cfg(feature = "macros")] pub use axum_macros::debug_handler; pub use self::service_ext::ServiceExt; #[cfg(test)] use axum_macros::__private_axum_test as test; axum-axum-v0.6.20/axum/src/macros.rs000066400000000000000000000043631446276533600173300ustar00rootroot00000000000000//! Internal macros macro_rules! opaque_future { ($(#[$m:meta])* pub type $name:ident = $actual:ty;) => { opaque_future! { $(#[$m])* pub type $name<> = $actual; } }; ($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => { pin_project_lite::pin_project! { $(#[$m])* pub struct $name<$($param),*> { #[pin] future: $actual, } } impl<$($param),*> $name<$($param),*> { pub(crate) fn new(future: $actual) -> Self { Self { future } } } impl<$($param),*> std::fmt::Debug for $name<$($param),*> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct(stringify!($name)).finish_non_exhaustive() } } impl<$($param),*> std::future::Future for $name<$($param),*> where $actual: std::future::Future, { type Output = <$actual as std::future::Future>::Output; #[inline] fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { self.project().future.poll(cx) } } }; } #[rustfmt::skip] macro_rules! 
all_the_tuples { ($name:ident) => { $name!([], T1); $name!([T1], T2); $name!([T1, T2], T3); $name!([T1, T2, T3], T4); $name!([T1, T2, T3, T4], T5); $name!([T1, T2, T3, T4, T5], T6); $name!([T1, T2, T3, T4, T5, T6], T7); $name!([T1, T2, T3, T4, T5, T6, T7], T8); $name!([T1, T2, T3, T4, T5, T6, T7, T8], T9); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], T14); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], T15); $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], T16); }; } axum-axum-v0.6.20/axum/src/middleware/000077500000000000000000000000001446276533600176055ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/middleware/from_extractor.rs000066400000000000000000000247451446276533600232250ustar00rootroot00000000000000use crate::{ extract::FromRequestParts, response::{IntoResponse, Response}, }; use futures_util::{future::BoxFuture, ready}; use http::Request; use pin_project_lite::pin_project; use std::{ fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Create a middleware from an extractor. /// /// If the extractor succeeds the value will be discarded and the inner service /// will be called. If the extractor fails the rejection will be returned and /// the inner service will _not_ be called. /// /// This can be used to perform validation of requests if the validation doesn't /// produce any useful output, and run the extractor for several handlers /// without repeating it in the function signature. /// /// Note that if the extractor consumes the request body, as `String` or /// [`Bytes`] does, an empty body will be left in its place. 
Thus wont be /// accessible to subsequent extractors or handlers. /// /// # Example /// /// ```rust /// use axum::{ /// extract::FromRequestParts, /// middleware::from_extractor, /// routing::{get, post}, /// Router, /// http::{header, StatusCode, request::Parts}, /// }; /// use async_trait::async_trait; /// /// // An extractor that performs authorization. /// struct RequireAuth; /// /// #[async_trait] /// impl FromRequestParts for RequireAuth /// where /// S: Send + Sync, /// { /// type Rejection = StatusCode; /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// let auth_header = parts /// .headers /// .get(header::AUTHORIZATION) /// .and_then(|value| value.to_str().ok()); /// /// match auth_header { /// Some(auth_header) if token_is_valid(auth_header) => { /// Ok(Self) /// } /// _ => Err(StatusCode::UNAUTHORIZED), /// } /// } /// } /// /// fn token_is_valid(token: &str) -> bool { /// // ... /// # false /// } /// /// async fn handler() { /// // If we get here the request has been authorized /// } /// /// async fn other_handler() { /// // If we get here the request has been authorized /// } /// /// let app = Router::new() /// .route("/", get(handler)) /// .route("/foo", post(other_handler)) /// // The extractor will run before all routes /// .route_layer(from_extractor::()); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// [`Bytes`]: bytes::Bytes pub fn from_extractor() -> FromExtractorLayer { from_extractor_with_state(()) } /// Create a middleware from an extractor with the given state. /// /// See [`State`](crate::extract::State) for more details about accessing state. pub fn from_extractor_with_state(state: S) -> FromExtractorLayer { FromExtractorLayer { state, _marker: PhantomData, } } /// [`Layer`] that applies [`FromExtractor`] that runs an extractor and /// discards the value. /// /// See [`from_extractor`] for more details. 
/// /// [`Layer`]: tower::Layer #[must_use] pub struct FromExtractorLayer { state: S, _marker: PhantomData E>, } impl Clone for FromExtractorLayer where S: Clone, { fn clone(&self) -> Self { Self { state: self.state.clone(), _marker: PhantomData, } } } impl fmt::Debug for FromExtractorLayer where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromExtractorLayer") .field("state", &self.state) .field("extractor", &format_args!("{}", std::any::type_name::())) .finish() } } impl Layer for FromExtractorLayer where S: Clone, { type Service = FromExtractor; fn layer(&self, inner: T) -> Self::Service { FromExtractor { inner, state: self.state.clone(), _extractor: PhantomData, } } } /// Middleware that runs an extractor and discards the value. /// /// See [`from_extractor`] for more details. pub struct FromExtractor { inner: T, state: S, _extractor: PhantomData E>, } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); assert_sync::>(); } impl Clone for FromExtractor where T: Clone, S: Clone, { fn clone(&self) -> Self { Self { inner: self.inner.clone(), state: self.state.clone(), _extractor: PhantomData, } } } impl fmt::Debug for FromExtractor where T: fmt::Debug, S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromExtractor") .field("inner", &self.inner) .field("state", &self.state) .field("extractor", &format_args!("{}", std::any::type_name::())) .finish() } } impl Service> for FromExtractor where E: FromRequestParts + 'static, B: Send + 'static, T: Service> + Clone, T::Response: IntoResponse, S: Clone + Send + Sync + 'static, { type Response = Response; type Error = T::Error; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let state = self.state.clone(); let extract_future = Box::pin(async move { let (mut parts, body) = req.into_parts(); let 
extracted = E::from_request_parts(&mut parts, &state).await; let req = Request::from_parts(parts, body); (req, extracted) }); ResponseFuture { state: State::Extracting { future: extract_future, }, svc: Some(self.inner.clone()), } } } pin_project! { /// Response future for [`FromExtractor`]. #[allow(missing_debug_implementations)] pub struct ResponseFuture where E: FromRequestParts, T: Service>, { #[pin] state: State, svc: Option, } } pin_project! { #[project = StateProj] enum State where E: FromRequestParts, T: Service>, { Extracting { future: BoxFuture<'static, (Request, Result)>, }, Call { #[pin] future: T::Future }, } } impl Future for ResponseFuture where E: FromRequestParts, T: Service>, T::Response: IntoResponse, { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { let mut this = self.as_mut().project(); let new_state = match this.state.as_mut().project() { StateProj::Extracting { future } => { let (req, extracted) = ready!(future.as_mut().poll(cx)); match extracted { Ok(_) => { let mut svc = this.svc.take().expect("future polled after completion"); let future = svc.call(req); State::Call { future } } Err(err) => { let res = err.into_response(); return Poll::Ready(Ok(res)); } } } StateProj::Call { future } => { return future .poll(cx) .map(|result| result.map(IntoResponse::into_response)); } }; this.state.set(new_state); } } } #[cfg(test)] mod tests { use super::*; use crate::{async_trait, handler::Handler, routing::get, test_helpers::*, Router}; use axum_core::extract::FromRef; use http::{header, request::Parts, StatusCode}; use tower_http::limit::RequestBodyLimitLayer; #[crate::test] async fn test_from_extractor() { #[derive(Clone)] struct Secret(&'static str); struct RequireAuth; #[async_trait::async_trait] impl FromRequestParts for RequireAuth where S: Send + Sync, Secret: FromRef, { type Rejection = StatusCode; async fn from_request_parts( parts: &mut Parts, state: &S, ) -> Result { let Secret(secret) = 
Secret::from_ref(state); if let Some(auth) = parts .headers .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) { if auth == secret { return Ok(Self); } } Err(StatusCode::UNAUTHORIZED) } } async fn handler() {} let state = Secret("secret"); let app = Router::new().route( "/", get(handler.layer(from_extractor_with_state::(state))), ); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::UNAUTHORIZED); let res = client .get("/") .header(http::header::AUTHORIZATION, "secret") .send() .await; assert_eq!(res.status(), StatusCode::OK); } // just needs to compile #[allow(dead_code)] fn works_with_request_body_limit() { struct MyExtractor; #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, { type Rejection = std::convert::Infallible; async fn from_request_parts( _parts: &mut Parts, _state: &S, ) -> Result { unimplemented!() } } let _: Router = Router::new() .layer(from_extractor::()) .layer(RequestBodyLimitLayer::new(1)); } } axum-axum-v0.6.20/axum/src/middleware/from_fn.rs000066400000000000000000000264051446276533600216100ustar00rootroot00000000000000use crate::response::{IntoResponse, Response}; use axum_core::extract::{FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; use std::{ any::type_name, convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower::{util::BoxCloneService, ServiceBuilder}; use tower_layer::Layer; use tower_service::Service; /// Create a middleware from an async function. /// /// `from_fn` requires the function given to /// /// 1. Be an `async fn`. /// 2. Take one or more [extractors] as the first arguments. /// 3. Take [`Next`](Next) as the final argument. /// 4. Return something that implements [`IntoResponse`]. /// /// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`]. 
/// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::{self, Request}, /// routing::get, /// response::Response, /// middleware::{self, Next}, /// }; /// /// async fn my_middleware( /// request: Request, /// next: Next, /// ) -> Response { /// // do something with `request`... /// /// let response = next.run(request).await; /// /// // do something with `response`... /// /// response /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .layer(middleware::from_fn(my_middleware)); /// # let app: Router = app; /// ``` /// /// # Running extractors /// /// ```rust /// use axum::{ /// Router, /// extract::TypedHeader, /// http::StatusCode, /// headers::authorization::{Authorization, Bearer}, /// http::Request, /// middleware::{self, Next}, /// response::Response, /// routing::get, /// }; /// /// async fn auth( /// // run the `TypedHeader` extractor /// TypedHeader(auth): TypedHeader>, /// // you can also add more extractors here but the last /// // extractor must implement `FromRequest` which /// // `Request` does /// request: Request, /// next: Next, /// ) -> Result { /// if token_is_valid(auth.token()) { /// let response = next.run(request).await; /// Ok(response) /// } else { /// Err(StatusCode::UNAUTHORIZED) /// } /// } /// /// fn token_is_valid(token: &str) -> bool { /// // ... /// # false /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn(auth)); /// # let app: Router = app; /// ``` /// /// [extractors]: crate::extract::FromRequest /// [`State`]: crate::extract::State pub fn from_fn(f: F) -> FromFnLayer { from_fn_with_state((), f) } /// Create a middleware from an async function with the given state. /// /// See [`State`](crate::extract::State) for more details about accessing state. 
/// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::{Request, StatusCode}, /// routing::get, /// response::{IntoResponse, Response}, /// middleware::{self, Next}, /// extract::State, /// }; /// /// #[derive(Clone)] /// struct AppState { /* ... */ } /// /// async fn my_middleware( /// State(state): State, /// // you can add more extractors here but the last /// // extractor must implement `FromRequest` which /// // `Request` does /// request: Request, /// next: Next, /// ) -> Response { /// // do something with `request`... /// /// let response = next.run(request).await; /// /// // do something with `response`... /// /// response /// } /// /// let state = AppState { /* ... */ }; /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { FromFnLayer { f, state, _extractor: PhantomData, } } /// A [`tower::Layer`] from an async function. /// /// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s. /// /// Created with [`from_fn`]. See that function for more details. 
#[must_use] pub struct FromFnLayer { f: F, state: S, _extractor: PhantomData T>, } impl Clone for FromFnLayer where F: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), state: self.state.clone(), _extractor: self._extractor, } } } impl Layer for FromFnLayer where F: Clone, S: Clone, { type Service = FromFn; fn layer(&self, inner: I) -> Self::Service { FromFn { f: self.f.clone(), state: self.state.clone(), inner, _extractor: PhantomData, } } } impl fmt::Debug for FromFnLayer where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") // Write out the type name, without quoting it as `&type_name::()` would .field("f", &format_args!("{}", type_name::())) .field("state", &self.state) .finish() } } /// A middleware created from an async function. /// /// Created with [`from_fn`]. See that function for more details. pub struct FromFn { f: F, inner: I, state: S, _extractor: PhantomData T>, } impl Clone for FromFn where F: Clone, I: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), inner: self.inner.clone(), state: self.state.clone(), _extractor: self._extractor, } } } macro_rules! 
impl_service { ( [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused_mut)] impl Service> for FromFn where F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static, $( $ty: FromRequestParts + Send, )* $last: FromRequest + Send, Fut: Future + Send + 'static, Out: IntoResponse + 'static, I: Service, Error = Infallible> + Clone + Send + 'static, I::Response: IntoResponse, I::Future: Send + 'static, B: Send + 'static, S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let mut f = self.f.clone(); let state = self.state.clone(); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* let req = Request::from_parts(parts, body); let $last = match $last::from_request(req, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; let inner = ServiceBuilder::new() .boxed_clone() .map_response(IntoResponse::into_response) .service(ready_inner); let next = Next { inner }; f($($ty,)* $last, next).await.into_response() }); ResponseFuture { inner: future } } } }; } all_the_tuples!(impl_service); impl fmt::Debug for FromFn where S: fmt::Debug, I: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("f", &format_args!("{}", type_name::())) .field("inner", &self.inner) .field("state", &self.state) .finish() } } /// The remainder of a middleware stack, including the handler. pub struct Next { inner: BoxCloneService, Response, Infallible>, } impl Next { /// Execute the remaining middleware stack. 
pub async fn run(mut self, req: Request) -> Response { match self.inner.call(req).await { Ok(res) => res, Err(err) => match err {}, } } } impl fmt::Debug for Next { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("inner", &self.inner) .finish() } } impl Clone for Next { fn clone(&self) -> Self { Self { inner: self.inner.clone(), } } } impl Service> for Next { type Response = Response; type Error = Infallible; type Future = Pin> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { self.inner.call(req) } } /// Response future for [`FromFn`]. pub struct ResponseFuture { inner: BoxFuture<'static, Response>, } impl Future for ResponseFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.as_mut().poll(cx).map(Ok) } } impl fmt::Debug for ResponseFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ResponseFuture").finish() } } #[cfg(test)] mod tests { use super::*; use crate::{body::Body, routing::get, Router}; use http::{HeaderMap, StatusCode}; use tower::ServiceExt; #[crate::test] async fn basic() { async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse { req.headers_mut() .insert("x-axum-test", "ok".parse().unwrap()); next.run(req).await } async fn handle(headers: HeaderMap) -> String { headers["x-axum-test"].to_str().unwrap().to_owned() } let app = Router::new() .route("/", get(handle)) .layer(from_fn(insert_header)); let res = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = hyper::body::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"ok"); } } axum-axum-v0.6.20/axum/src/middleware/map_request.rs000066400000000000000000000274601446276533600225110ustar00rootroot00000000000000use crate::response::{IntoResponse, Response}; use 
axum_core::extract::{FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; use std::{ any::type_name, convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Create a middleware from an async function that transforms a request. /// /// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific /// extractors. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_request, /// http::Request, /// }; /// /// async fn set_header(mut request: Request) -> Request { /// request.headers_mut().insert("x-foo", "foo".parse().unwrap()); /// request /// } /// /// async fn handler(request: Request) { /// // `request` will have an `x-foo` header /// } /// /// let app = Router::new() /// .route("/", get(handler)) /// .layer(map_request(set_header)); /// # let _: Router = app; /// ``` /// /// # Rejecting the request /// /// The function given to `map_request` is allowed to also return a `Result` which can be used to /// reject the request and return a response immediately, without calling the remaining /// middleware. /// /// Specifically the valid return types are: /// /// - `Request` /// - `Result, E> where E: IntoResponse` /// /// ``` /// use axum::{ /// Router, /// http::{Request, StatusCode}, /// routing::get, /// middleware::map_request, /// }; /// /// async fn auth(request: Request) -> Result, StatusCode> { /// let auth_header = request.headers() /// .get(http::header::AUTHORIZATION) /// .and_then(|header| header.to_str().ok()); /// /// match auth_header { /// Some(auth_header) if token_is_valid(auth_header) => Ok(request), /// _ => Err(StatusCode::UNAUTHORIZED), /// } /// } /// /// fn token_is_valid(token: &str) -> bool { /// // ... /// # false /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... 
*/ })) /// .route_layer(map_request(auth)); /// # let app: Router = app; /// ``` /// /// # Running extractors /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_request, /// extract::Path, /// http::Request, /// }; /// use std::collections::HashMap; /// /// async fn log_path_params( /// Path(path_params): Path>, /// request: Request, /// ) -> Request { /// tracing::debug!(?path_params); /// request /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .layer(map_request(log_path_params)); /// # let _: Router = app; /// ``` /// /// Note that to access state you must use either [`map_request_with_state`]. pub fn map_request(f: F) -> MapRequestLayer { map_request_with_state((), f) } /// Create a middleware from an async function that transforms a request, with the given state. /// /// See [`State`](crate::extract::State) for more details about accessing state. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::{Request, StatusCode}, /// routing::get, /// response::IntoResponse, /// middleware::map_request_with_state, /// extract::State, /// }; /// /// #[derive(Clone)] /// struct AppState { /* ... */ } /// /// async fn my_middleware( /// State(state): State, /// // you can add more extractors here but the last /// // extractor must implement `FromRequest` which /// // `Request` does /// request: Request, /// ) -> Request { /// // do something with `state` and `request`... /// request /// } /// /// let state = AppState { /* ... */ }; /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(map_request_with_state(state.clone(), my_middleware)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` pub fn map_request_with_state(state: S, f: F) -> MapRequestLayer { MapRequestLayer { f, state, _extractor: PhantomData, } } /// A [`tower::Layer`] from an async function that transforms a request. /// /// Created with [`map_request`]. 
See that function for more details. #[must_use] pub struct MapRequestLayer { f: F, state: S, _extractor: PhantomData T>, } impl Clone for MapRequestLayer where F: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), state: self.state.clone(), _extractor: self._extractor, } } } impl Layer for MapRequestLayer where F: Clone, S: Clone, { type Service = MapRequest; fn layer(&self, inner: I) -> Self::Service { MapRequest { f: self.f.clone(), state: self.state.clone(), inner, _extractor: PhantomData, } } } impl fmt::Debug for MapRequestLayer where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapRequestLayer") // Write out the type name, without quoting it as `&type_name::()` would .field("f", &format_args!("{}", type_name::())) .field("state", &self.state) .finish() } } /// A middleware created from an async function that transforms a request. /// /// Created with [`map_request`]. See that function for more details. pub struct MapRequest { f: F, inner: I, state: S, _extractor: PhantomData T>, } impl Clone for MapRequest where F: Clone, I: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), inner: self.inner.clone(), state: self.state.clone(), _extractor: self._extractor, } } } macro_rules! 
impl_service { ( [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused_mut)] impl Service> for MapRequest where F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static, $( $ty: FromRequestParts + Send, )* $last: FromRequest + Send, Fut: Future + Send + 'static, Fut::Output: IntoMapRequestResult + Send + 'static, I: Service, Error = Infallible> + Clone + Send + 'static, I::Response: IntoResponse, I::Future: Send + 'static, B: Send + 'static, S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let mut f = self.f.clone(); let state = self.state.clone(); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* let req = Request::from_parts(parts, body); let $last = match $last::from_request(req, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; match f($($ty,)* $last).await.into_map_request_result() { Ok(req) => { ready_inner.call(req).await.into_response() } Err(res) => { res } } }); ResponseFuture { inner: future } } } }; } all_the_tuples!(impl_service); impl fmt::Debug for MapRequest where S: fmt::Debug, I: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapRequest") .field("f", &format_args!("{}", type_name::())) .field("inner", &self.inner) .field("state", &self.state) .finish() } } /// Response future for [`MapRequest`]. 
pub struct ResponseFuture { inner: BoxFuture<'static, Response>, } impl Future for ResponseFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.as_mut().poll(cx).map(Ok) } } impl fmt::Debug for ResponseFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ResponseFuture").finish() } } mod private { use crate::{http::Request, response::IntoResponse}; pub trait Sealed {} impl Sealed for Result, E> where E: IntoResponse {} impl Sealed for Request {} } /// Trait implemented by types that can be returned from [`map_request`], /// [`map_request_with_state`]. /// /// This trait is sealed such that it cannot be implemented outside this crate. pub trait IntoMapRequestResult: private::Sealed { /// Perform the conversion. fn into_map_request_result(self) -> Result, Response>; } impl IntoMapRequestResult for Result, E> where E: IntoResponse, { fn into_map_request_result(self) -> Result, Response> { self.map_err(IntoResponse::into_response) } } impl IntoMapRequestResult for Request { fn into_map_request_result(self) -> Result, Response> { Ok(self) } } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router}; use http::{HeaderMap, StatusCode}; #[crate::test] async fn works() { async fn add_header(mut req: Request) -> Request { req.headers_mut().insert("x-foo", "foo".parse().unwrap()); req } async fn handler(headers: HeaderMap) -> Response { headers["x-foo"] .to_str() .unwrap() .to_owned() .into_response() } let app = Router::new() .route("/", get(handler)) .layer(map_request(add_header)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } #[crate::test] async fn works_for_short_circutting() { async fn add_header(_req: Request) -> Result, (StatusCode, &'static str)> { Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong")) } async fn handler(_headers: HeaderMap) -> Response { 
unreachable!() } let app = Router::new() .route("/", get(handler)) .layer(map_request(add_header)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(res.text().await, "something went wrong"); } } axum-axum-v0.6.20/axum/src/middleware/map_response.rs000066400000000000000000000237051446276533600226550ustar00rootroot00000000000000use crate::response::{IntoResponse, Response}; use axum_core::extract::FromRequestParts; use futures_util::future::BoxFuture; use http::Request; use std::{ any::type_name, convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Create a middleware from an async function that transforms a response. /// /// This differs from [`tower::util::MapResponse`] in that it allows you to easily run axum-specific /// extractors. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_response, /// response::Response, /// }; /// /// async fn set_header(mut response: Response) -> Response { /// response.headers_mut().insert("x-foo", "foo".parse().unwrap()); /// response /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .layer(map_response(set_header)); /// # let _: Router = app; /// ``` /// /// # Running extractors /// /// It is also possible to run extractors that implement [`FromRequestParts`]. These will be run /// before calling the handler. /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_response, /// extract::Path, /// response::Response, /// }; /// use std::collections::HashMap; /// /// async fn log_path_params( /// Path(path_params): Path>, /// response: Response, /// ) -> Response { /// tracing::debug!(?path_params); /// response /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... 
*/ })) /// .layer(map_response(log_path_params)); /// # let _: Router = app; /// ``` /// /// Note that to access state you must use either [`map_response_with_state`]. /// /// # Returning any `impl IntoResponse` /// /// It is also possible to return anything that implements [`IntoResponse`] /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_response, /// response::{Response, IntoResponse}, /// }; /// use std::collections::HashMap; /// /// async fn set_header(response: Response) -> impl IntoResponse { /// ( /// [("x-foo", "foo")], /// response, /// ) /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .layer(map_response(set_header)); /// # let _: Router = app; /// ``` pub fn map_response(f: F) -> MapResponseLayer { map_response_with_state((), f) } /// Create a middleware from an async function that transforms a response, with the given state. /// /// See [`State`](crate::extract::State) for more details about accessing state. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::StatusCode, /// routing::get, /// response::Response, /// middleware::map_response_with_state, /// extract::State, /// }; /// /// #[derive(Clone)] /// struct AppState { /* ... */ } /// /// async fn my_middleware( /// State(state): State, /// // you can add more extractors here but they must /// // all implement `FromRequestParts` /// // `FromRequest` is not allowed /// response: Response, /// ) -> Response { /// // do something with `state` and `response`... /// response /// } /// /// let state = AppState { /* ... */ }; /// /// let app = Router::new() /// .route("/", get(|| async { /* ... 
*/ })) /// .route_layer(map_response_with_state(state.clone(), my_middleware)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` pub fn map_response_with_state(state: S, f: F) -> MapResponseLayer { MapResponseLayer { f, state, _extractor: PhantomData, } } /// A [`tower::Layer`] from an async function that transforms a response. /// /// Created with [`map_response`]. See that function for more details. #[must_use] pub struct MapResponseLayer { f: F, state: S, _extractor: PhantomData T>, } impl Clone for MapResponseLayer where F: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), state: self.state.clone(), _extractor: self._extractor, } } } impl Layer for MapResponseLayer where F: Clone, S: Clone, { type Service = MapResponse; fn layer(&self, inner: I) -> Self::Service { MapResponse { f: self.f.clone(), state: self.state.clone(), inner, _extractor: PhantomData, } } } impl fmt::Debug for MapResponseLayer where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapResponseLayer") // Write out the type name, without quoting it as `&type_name::()` would .field("f", &format_args!("{}", type_name::())) .field("state", &self.state) .finish() } } /// A middleware created from an async function that transforms a response. /// /// Created with [`map_response`]. See that function for more details. pub struct MapResponse { f: F, inner: I, state: S, _extractor: PhantomData T>, } impl Clone for MapResponse where F: Clone, I: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), inner: self.inner.clone(), state: self.state.clone(), _extractor: self._extractor, } } } macro_rules! 
impl_service { ( $($ty:ident),* ) => { #[allow(non_snake_case, unused_mut)] impl Service> for MapResponse where F: FnMut($($ty,)* Response) -> Fut + Clone + Send + 'static, $( $ty: FromRequestParts + Send, )* Fut: Future + Send + 'static, Fut::Output: IntoResponse + Send + 'static, I: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, I::Future: Send + 'static, B: Send + 'static, ResBody: Send + 'static, S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let mut f = self.f.clone(); let _state = self.state.clone(); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &_state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* let req = Request::from_parts(parts, body); match ready_inner.call(req).await { Ok(res) => { f($($ty,)* res).await.into_response() } Err(err) => match err {} } }); ResponseFuture { inner: future } } } }; } impl_service!(); impl_service!(T1); impl_service!(T1, T2); impl_service!(T1, T2, T3); impl_service!(T1, T2, T3, T4); impl_service!(T1, T2, T3, T4, T5); impl_service!(T1, T2, T3, T4, T5, T6); impl_service!(T1, T2, T3, T4, T5, T6, T7); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); impl_service!(T1, T2, T3, T4, T5, T6, 
T7, T8, T9, T10, T11, T12, T13, T14, T15); impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); impl fmt::Debug for MapResponse where S: fmt::Debug, I: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapResponse") .field("f", &format_args!("{}", type_name::())) .field("inner", &self.inner) .field("state", &self.state) .finish() } } /// Response future for [`MapResponse`]. pub struct ResponseFuture { inner: BoxFuture<'static, Response>, } impl Future for ResponseFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.as_mut().poll(cx).map(Ok) } } impl fmt::Debug for ResponseFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ResponseFuture").finish() } } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use crate::{test_helpers::TestClient, Router}; #[crate::test] async fn works() { async fn add_header(mut res: Response) -> Response { res.headers_mut().insert("x-foo", "foo".parse().unwrap()); res } let app = Router::new().layer(map_response(add_header)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.headers()["x-foo"], "foo"); } } axum-axum-v0.6.20/axum/src/middleware/mod.rs000066400000000000000000000017171446276533600207400ustar00rootroot00000000000000//! Utilities for writing middleware //! 
#![doc = include_str!("../docs/middleware.md")] mod from_extractor; mod from_fn; mod map_request; mod map_response; pub use self::from_extractor::{ from_extractor, from_extractor_with_state, FromExtractor, FromExtractorLayer, }; pub use self::from_fn::{from_fn, from_fn_with_state, FromFn, FromFnLayer, Next}; pub use self::map_request::{ map_request, map_request_with_state, IntoMapRequestResult, MapRequest, MapRequestLayer, }; pub use self::map_response::{ map_response, map_response_with_state, MapResponse, MapResponseLayer, }; pub use crate::extension::AddExtension; pub mod future { //! Future types. pub use super::from_extractor::ResponseFuture as FromExtractorResponseFuture; pub use super::from_fn::ResponseFuture as FromFnResponseFuture; pub use super::map_request::ResponseFuture as MapRequestResponseFuture; pub use super::map_response::ResponseFuture as MapResponseResponseFuture; } axum-axum-v0.6.20/axum/src/response/000077500000000000000000000000001446276533600173265ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/response/mod.rs000066400000000000000000000143661446276533600204650ustar00rootroot00000000000000#![doc = include_str!("../docs/response.md")] use crate::body::{Bytes, Full}; use http::{header, HeaderValue}; mod redirect; #[cfg(feature = "tokio")] pub mod sse; #[doc(no_inline)] #[cfg(feature = "json")] pub use crate::Json; #[doc(no_inline)] #[cfg(feature = "headers")] pub use crate::TypedHeader; #[cfg(feature = "form")] #[doc(no_inline)] pub use crate::form::Form; #[doc(no_inline)] pub use crate::Extension; #[doc(inline)] pub use axum_core::response::{ AppendHeaders, ErrorResponse, IntoResponse, IntoResponseParts, Response, ResponseParts, Result, }; #[doc(inline)] pub use self::redirect::Redirect; #[doc(inline)] #[cfg(feature = "tokio")] pub use sse::Sse; /// An HTML response. /// /// Will automatically get `Content-Type: text/html`. 
#[derive(Clone, Copy, Debug)] #[must_use] pub struct Html(pub T); impl IntoResponse for Html where T: Into>, { fn into_response(self) -> Response { ( [( header::CONTENT_TYPE, HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), )], self.0.into(), ) .into_response() } } impl From for Html { fn from(inner: T) -> Self { Self(inner) } } #[cfg(test)] mod tests { use crate::extract::Extension; use crate::{body::Body, routing::get, Router}; use axum_core::response::IntoResponse; use http::HeaderMap; use http::{StatusCode, Uri}; // just needs to compile #[allow(dead_code)] fn impl_trait_result_works() { async fn impl_trait_ok() -> Result { Ok(()) } async fn impl_trait_err() -> Result<(), impl IntoResponse> { Err(()) } async fn impl_trait_both(uri: Uri) -> Result { if uri.path() == "/" { Ok(()) } else { Err(()) } } async fn impl_trait(uri: Uri) -> impl IntoResponse { if uri.path() == "/" { Ok(()) } else { Err(()) } } _ = Router::<(), Body>::new() .route("/", get(impl_trait_ok)) .route("/", get(impl_trait_err)) .route("/", get(impl_trait_both)) .route("/", get(impl_trait)); } // just needs to compile #[allow(dead_code)] fn tuple_responses() { async fn status() -> impl IntoResponse { StatusCode::OK } async fn status_headermap() -> impl IntoResponse { (StatusCode::OK, HeaderMap::new()) } async fn status_header_array() -> impl IntoResponse { (StatusCode::OK, [("content-type", "text/plain")]) } async fn status_headermap_body() -> impl IntoResponse { (StatusCode::OK, HeaderMap::new(), String::new()) } async fn status_header_array_body() -> impl IntoResponse { ( StatusCode::OK, [("content-type", "text/plain")], String::new(), ) } async fn status_headermap_impl_into_response() -> impl IntoResponse { (StatusCode::OK, HeaderMap::new(), impl_into_response()) } async fn status_header_array_impl_into_response() -> impl IntoResponse { ( StatusCode::OK, [("content-type", "text/plain")], impl_into_response(), ) } fn impl_into_response() -> impl IntoResponse {} async fn 
status_header_array_extension_body() -> impl IntoResponse { ( StatusCode::OK, [("content-type", "text/plain")], Extension(1), String::new(), ) } async fn status_header_array_extension_mixed_body() -> impl IntoResponse { ( StatusCode::OK, [("content-type", "text/plain")], Extension(1), HeaderMap::new(), String::new(), ) } // async fn headermap() -> impl IntoResponse { HeaderMap::new() } async fn header_array() -> impl IntoResponse { [("content-type", "text/plain")] } async fn headermap_body() -> impl IntoResponse { (HeaderMap::new(), String::new()) } async fn header_array_body() -> impl IntoResponse { ([("content-type", "text/plain")], String::new()) } async fn headermap_impl_into_response() -> impl IntoResponse { (HeaderMap::new(), impl_into_response()) } async fn header_array_impl_into_response() -> impl IntoResponse { ([("content-type", "text/plain")], impl_into_response()) } async fn header_array_extension_body() -> impl IntoResponse { ( [("content-type", "text/plain")], Extension(1), String::new(), ) } async fn header_array_extension_mixed_body() -> impl IntoResponse { ( [("content-type", "text/plain")], Extension(1), HeaderMap::new(), String::new(), ) } _ = Router::<(), Body>::new() .route("/", get(status)) .route("/", get(status_headermap)) .route("/", get(status_header_array)) .route("/", get(status_headermap_body)) .route("/", get(status_header_array_body)) .route("/", get(status_headermap_impl_into_response)) .route("/", get(status_header_array_impl_into_response)) .route("/", get(status_header_array_extension_body)) .route("/", get(status_header_array_extension_mixed_body)) .route("/", get(headermap)) .route("/", get(header_array)) .route("/", get(headermap_body)) .route("/", get(header_array_body)) .route("/", get(headermap_impl_into_response)) .route("/", get(header_array_impl_into_response)) .route("/", get(header_array_extension_body)) .route("/", get(header_array_extension_mixed_body)); } } 
axum-axum-v0.6.20/axum/src/response/redirect.rs000066400000000000000000000061531446276533600215020ustar00rootroot00000000000000use axum_core::response::{IntoResponse, Response}; use http::{header::LOCATION, HeaderValue, StatusCode}; /// Response that redirects the request to another location. /// /// # Example /// /// ```rust /// use axum::{ /// routing::get, /// response::Redirect, /// Router, /// }; /// /// let app = Router::new() /// .route("/old", get(|| async { Redirect::permanent("/new") })) /// .route("/new", get(|| async { "Hello!" })); /// # async { /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[must_use = "needs to be returned from a handler or otherwise turned into a Response to be useful"] #[derive(Debug, Clone)] pub struct Redirect { status_code: StatusCode, location: HeaderValue, } impl Redirect { /// Create a new [`Redirect`] that uses a [`303 See Other`][mdn] status code. /// /// This redirect instructs the client to change the method to GET for the subsequent request /// to the given `uri`, which is useful after successful form submission, file upload or when /// you generally don't want the redirected-to page to observe the original request method and /// body (if non-empty). If you want to preserve the request method and body, /// [`Redirect::temporary`] should be used instead. /// /// # Panics /// /// If `uri` isn't a valid [`HeaderValue`]. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 pub fn to(uri: &str) -> Self { Self::with_status_code(StatusCode::SEE_OTHER, uri) } /// Create a new [`Redirect`] that uses a [`307 Temporary Redirect`][mdn] status code. /// /// This has the same behavior as [`Redirect::to`], except it will preserve the original HTTP /// method and body. /// /// # Panics /// /// If `uri` isn't a valid [`HeaderValue`]. 
/// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/307 pub fn temporary(uri: &str) -> Self { Self::with_status_code(StatusCode::TEMPORARY_REDIRECT, uri) } /// Create a new [`Redirect`] that uses a [`308 Permanent Redirect`][mdn] status code. /// /// # Panics /// /// If `uri` isn't a valid [`HeaderValue`]. /// /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308 pub fn permanent(uri: &str) -> Self { Self::with_status_code(StatusCode::PERMANENT_REDIRECT, uri) } // This is intentionally not public since other kinds of redirects might not // use the `Location` header, namely `304 Not Modified`. // // We're open to adding more constructors upon request, if they make sense :) fn with_status_code(status_code: StatusCode, uri: &str) -> Self { assert!( status_code.is_redirection(), "not a redirection status code" ); Self { status_code, location: HeaderValue::try_from(uri).expect("URI isn't a valid header value"), } } } impl IntoResponse for Redirect { fn into_response(self) -> Response { (self.status_code, [(LOCATION, self.location)]).into_response() } } axum-axum-v0.6.20/axum/src/response/sse.rs000066400000000000000000000504701446276533600204740ustar00rootroot00000000000000//! Server-Sent Events (SSE) responses. //! //! # Example //! //! ``` //! use axum::{ //! Router, //! routing::get, //! response::sse::{Event, KeepAlive, Sse}, //! }; //! use std::{time::Duration, convert::Infallible}; //! use tokio_stream::StreamExt as _ ; //! use futures_util::stream::{self, Stream}; //! //! let app = Router::new().route("/sse", get(sse_handler)); //! //! async fn sse_handler() -> Sse>> { //! // A `Stream` that repeats an event every second //! let stream = stream::repeat_with(|| Event::default().data("hi!")) //! .map(Ok) //! .throttle(Duration::from_secs(1)); //! //! Sse::new(stream).keep_alive(KeepAlive::default()) //! } //! # async { //! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! 
``` use crate::{ body::{Bytes, HttpBody}, BoxError, }; use axum_core::{ body, response::{IntoResponse, Response}, }; use bytes::{BufMut, BytesMut}; use futures_util::{ ready, stream::{Stream, TryStream}, }; use pin_project_lite::pin_project; use std::{ fmt, future::Future, pin::Pin, task::{Context, Poll}, time::Duration, }; use sync_wrapper::SyncWrapper; use tokio::time::Sleep; /// An SSE response #[derive(Clone)] #[must_use] pub struct Sse { stream: S, keep_alive: Option, } impl Sse { /// Create a new [`Sse`] response that will respond with the given stream of /// [`Event`]s. /// /// See the [module docs](self) for more details. pub fn new(stream: S) -> Self where S: TryStream + Send + 'static, S::Error: Into, { Sse { stream, keep_alive: None, } } /// Configure the interval between keep-alive messages. /// /// Defaults to no keep-alive messages. pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { self.keep_alive = Some(keep_alive); self } } impl fmt::Debug for Sse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Sse") .field("stream", &format_args!("{}", std::any::type_name::())) .field("keep_alive", &self.keep_alive) .finish() } } impl IntoResponse for Sse where S: Stream> + Send + 'static, E: Into, { fn into_response(self) -> Response { ( [ (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()), (http::header::CACHE_CONTROL, "no-cache"), ], body::boxed(Body { event_stream: SyncWrapper::new(self.stream), keep_alive: self.keep_alive.map(KeepAliveStream::new), }), ) .into_response() } } pin_project! 
{ struct Body { #[pin] event_stream: SyncWrapper, #[pin] keep_alive: Option, } } impl HttpBody for Body where S: Stream>, { type Data = Bytes; type Error = E; fn poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let this = self.project(); match this.event_stream.get_pin_mut().poll_next(cx) { Poll::Pending => { if let Some(keep_alive) = this.keep_alive.as_pin_mut() { keep_alive.poll_event(cx).map(|e| Some(Ok(e))) } else { Poll::Pending } } Poll::Ready(Some(Ok(event))) => { if let Some(keep_alive) = this.keep_alive.as_pin_mut() { keep_alive.reset(); } Poll::Ready(Some(Ok(event.finalize()))) } Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))), Poll::Ready(None) => Poll::Ready(None), } } fn poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll, Self::Error>> { Poll::Ready(Ok(None)) } } /// Server-sent event #[derive(Debug, Default, Clone)] #[must_use] pub struct Event { buffer: BytesMut, flags: EventFlags, } impl Event { /// Set the event's data data field(s) (`data:`) /// /// Newlines in `data` will automatically be broken across `data:` fields. /// /// This corresponds to [`MessageEvent`'s data field]. /// /// Note that events with an empty data field will be ignored by the browser. /// /// # Panics /// /// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE. /// - Panics if `data` or `json_data` have already been called. /// /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data pub fn data(mut self, data: T) -> Event where T: AsRef, { if self.flags.contains(EventFlags::HAS_DATA) { panic!("Called `EventBuilder::data` multiple times"); } for line in memchr_split(b'\n', data.as_ref().as_bytes()) { self.field("data", line); } self.flags.insert(EventFlags::HAS_DATA); self } /// Set the event's data field to a value serialized as unformatted JSON (`data:`). /// /// This corresponds to [`MessageEvent`'s data field]. 
/// /// # Panics /// /// Panics if `data` or `json_data` have already been called. /// /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data #[cfg(feature = "json")] pub fn json_data(mut self, data: T) -> serde_json::Result where T: serde::Serialize, { if self.flags.contains(EventFlags::HAS_DATA) { panic!("Called `EventBuilder::json_data` multiple times"); } self.buffer.extend_from_slice(b"data:"); serde_json::to_writer((&mut self.buffer).writer(), &data)?; self.buffer.put_u8(b'\n'); self.flags.insert(EventFlags::HAS_DATA); Ok(self) } /// Set the event's comment field (`:`). /// /// This field will be ignored by most SSE clients. /// /// Unlike other functions, this function can be called multiple times to add many comments. /// /// # Panics /// /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in /// comments. pub fn comment(mut self, comment: T) -> Event where T: AsRef, { self.field("", comment.as_ref()); self } /// Set the event's name field (`event:`). /// /// This corresponds to the `type` parameter given when calling `addEventListener` on an /// [`EventSource`]. For example, `.event("update")` should correspond to /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a /// [`message` event] instead. /// /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event /// /// # Panics /// /// - Panics if `event` contains any newlines or carriage returns. /// - Panics if this function has already been called on this event. pub fn event(mut self, event: T) -> Event where T: AsRef, { if self.flags.contains(EventFlags::HAS_EVENT) { panic!("Called `EventBuilder::event` multiple times"); } self.flags.insert(EventFlags::HAS_EVENT); self.field("event", event.as_ref()); self } /// Set the event's retry timeout field (`retry:`). 
/// /// This sets how long clients will wait before reconnecting if they are disconnected from the /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they /// wish, such as if they implement exponential backoff. /// /// # Panics /// /// Panics if this function has already been called on this event. pub fn retry(mut self, duration: Duration) -> Event { if self.flags.contains(EventFlags::HAS_RETRY) { panic!("Called `EventBuilder::retry` multiple times"); } self.flags.insert(EventFlags::HAS_RETRY); self.buffer.extend_from_slice(b"retry:"); let secs = duration.as_secs(); let millis = duration.subsec_millis(); if secs > 0 { // format seconds self.buffer .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes()); // pad milliseconds if millis < 10 { self.buffer.extend_from_slice(b"00"); } else if millis < 100 { self.buffer.extend_from_slice(b"0"); } } // format milliseconds self.buffer .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes()); self.buffer.put_u8(b'\n'); self } /// Set the event's identifier field (`id:`). /// /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself, /// the browser will set that field to the last known message ID, starting with the empty /// string. /// /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId /// /// # Panics /// /// - Panics if `id` contains any newlines, carriage returns or null characters. /// - Panics if this function has already been called on this event. 
pub fn id(mut self, id: T) -> Event where T: AsRef, { if self.flags.contains(EventFlags::HAS_ID) { panic!("Called `EventBuilder::id` multiple times"); } self.flags.insert(EventFlags::HAS_ID); let id = id.as_ref().as_bytes(); assert_eq!( memchr::memchr(b'\0', id), None, "Event ID cannot contain null characters", ); self.field("id", id); self } fn field(&mut self, name: &str, value: impl AsRef<[u8]>) { let value = value.as_ref(); assert_eq!( memchr::memchr2(b'\r', b'\n', value), None, "SSE field value cannot contain newlines or carriage returns", ); self.buffer.extend_from_slice(name.as_bytes()); self.buffer.put_u8(b':'); // Prevent values that start with spaces having that space stripped if value.starts_with(b" ") { self.buffer.put_u8(b' '); } self.buffer.extend_from_slice(value); self.buffer.put_u8(b'\n'); } fn finalize(mut self) -> Bytes { self.buffer.put_u8(b'\n'); self.buffer.freeze() } } bitflags::bitflags! { #[derive(Default)] struct EventFlags: u8 { const HAS_DATA = 0b0001; const HAS_EVENT = 0b0010; const HAS_RETRY = 0b0100; const HAS_ID = 0b1000; } } /// Configure the interval between keep-alive messages, the content /// of each message, and the associated stream. #[derive(Debug, Clone)] #[must_use] pub struct KeepAlive { event: Bytes, max_interval: Duration, } impl KeepAlive { /// Create a new `KeepAlive`. pub fn new() -> Self { Self { event: Bytes::from_static(b":\n\n"), max_interval: Duration::from_secs(15), } } /// Customize the interval between keep-alive messages. /// /// Default is 15 seconds. pub fn interval(mut self, time: Duration) -> Self { self.max_interval = time; self } /// Customize the text of the keep-alive message. /// /// Default is an empty comment. /// /// # Panics /// /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE /// comments. pub fn text(self, text: I) -> Self where I: AsRef, { self.event(Event::default().comment(text)) } /// Customize the event of the keep-alive message. 
/// /// Default is an empty comment. /// /// # Panics /// /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE /// comments. pub fn event(mut self, event: Event) -> Self { self.event = event.finalize(); self } } impl Default for KeepAlive { fn default() -> Self { Self::new() } } pin_project! { #[derive(Debug)] struct KeepAliveStream { keep_alive: KeepAlive, #[pin] alive_timer: Sleep, } } impl KeepAliveStream { fn new(keep_alive: KeepAlive) -> Self { Self { alive_timer: tokio::time::sleep(keep_alive.max_interval), keep_alive, } } fn reset(self: Pin<&mut Self>) { let this = self.project(); this.alive_timer .reset(tokio::time::Instant::now() + this.keep_alive.max_interval); } fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut().project(); ready!(this.alive_timer.poll(cx)); let event = this.keep_alive.event.clone(); self.reset(); Poll::Ready(event) } } fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> { MemchrSplit { needle, haystack: Some(haystack), } } struct MemchrSplit<'a> { needle: u8, haystack: Option<&'a [u8]>, } impl<'a> Iterator for MemchrSplit<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { let haystack = self.haystack?; if let Some(pos) = memchr::memchr(self.needle, haystack) { let (front, back) = haystack.split_at(pos); self.haystack = Some(&back[1..]); Some(front) } else { self.haystack.take() } } } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; use futures_util::stream; use std::{collections::HashMap, convert::Infallible}; use tokio_stream::StreamExt as _; #[test] fn leading_space_is_not_stripped() { let no_leading_space = Event::default().data("\tfoobar"); assert_eq!(&*no_leading_space.finalize(), b"data:\tfoobar\n\n"); let leading_space = Event::default().data(" foobar"); assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n"); } #[crate::test] async fn basic() { let app = Router::new().route( 
"/", get(|| async { let stream = stream::iter(vec![ Event::default().data("one").comment("this is a comment"), Event::default() .json_data(serde_json::json!({ "foo": "bar" })) .unwrap(), Event::default() .event("three") .retry(Duration::from_secs(30)) .id("unique-id"), ]) .map(Ok::<_, Infallible>); Sse::new(stream) }), ); let client = TestClient::new(app); let mut stream = client.get("/").send().await; assert_eq!(stream.headers()["content-type"], "text/event-stream"); assert_eq!(stream.headers()["cache-control"], "no-cache"); let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("data").unwrap(), "one"); assert_eq!(event_fields.get("comment").unwrap(), "this is a comment"); let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}"); assert!(event_fields.get("comment").is_none()); let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("event").unwrap(), "three"); assert_eq!(event_fields.get("retry").unwrap(), "30000"); assert_eq!(event_fields.get("id").unwrap(), "unique-id"); assert!(event_fields.get("comment").is_none()); assert!(stream.chunk_text().await.is_none()); } #[tokio::test(start_paused = true)] async fn keep_alive() { const DELAY: Duration = Duration::from_secs(5); let app = Router::new().route( "/", get(|| async { let stream = stream::repeat_with(|| Event::default().data("msg")) .map(Ok::<_, Infallible>) .throttle(DELAY); Sse::new(stream).keep_alive( KeepAlive::new() .interval(Duration::from_secs(1)) .text("keep-alive-text"), ) }), ); let client = TestClient::new(app); let mut stream = client.get("/").send().await; for _ in 0..5 { // first message should be an event let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("data").unwrap(), "msg"); // then 4 seconds of keep-alive messages for _ in 0..4 { tokio::time::sleep(Duration::from_secs(1)).await; let 
event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text"); } } } #[tokio::test(start_paused = true)] async fn keep_alive_ends_when_the_stream_ends() { const DELAY: Duration = Duration::from_secs(5); let app = Router::new().route( "/", get(|| async { let stream = stream::repeat_with(|| Event::default().data("msg")) .map(Ok::<_, Infallible>) .throttle(DELAY) .take(2); Sse::new(stream).keep_alive( KeepAlive::new() .interval(Duration::from_secs(1)) .text("keep-alive-text"), ) }), ); let client = TestClient::new(app); let mut stream = client.get("/").send().await; // first message should be an event let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("data").unwrap(), "msg"); // then 4 seconds of keep-alive messages for _ in 0..4 { tokio::time::sleep(Duration::from_secs(1)).await; let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text"); } // then the last event let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("data").unwrap(), "msg"); // then no more events or keep-alive messages assert!(stream.chunk_text().await.is_none()); } fn parse_event(payload: &str) -> HashMap { let mut fields = HashMap::new(); let mut lines = payload.lines().peekable(); while let Some(line) = lines.next() { if line.is_empty() { assert!(lines.next().is_none()); break; } let (mut key, value) = line.split_once(':').unwrap(); let value = value.trim(); if key.is_empty() { key = "comment"; } fields.insert(key.to_owned(), value.to_owned()); } fields } #[test] fn memchr_spliting() { assert_eq!( memchr_split(2, &[]).collect::>(), [&[]] as [&[u8]; 1] ); assert_eq!( memchr_split(2, &[2]).collect::>(), [&[], &[]] as [&[u8]; 2] ); assert_eq!( memchr_split(2, &[1]).collect::>(), [&[1]] as [&[u8]; 1] ); assert_eq!( memchr_split(2, &[1, 2]).collect::>(), [&[1], &[]] 
as [&[u8]; 2] ); assert_eq!( memchr_split(2, &[2, 1]).collect::>(), [&[], &[1]] as [&[u8]; 2] ); assert_eq!( memchr_split(2, &[1, 2, 2, 1]).collect::>(), [&[1], &[], &[1]] as [&[u8]; 3] ); } } axum-axum-v0.6.20/axum/src/routing/000077500000000000000000000000001446276533600171575ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/routing/future.rs000066400000000000000000000002061446276533600210350ustar00rootroot00000000000000//! Future types. pub use super::{ into_make_service::IntoMakeServiceFuture, route::{InfallibleRouteFuture, RouteFuture}, }; axum-axum-v0.6.20/axum/src/routing/into_make_service.rs000066400000000000000000000022241446276533600232130ustar00rootroot00000000000000use std::{ convert::Infallible, future::ready, task::{Context, Poll}, }; use tower_service::Service; /// A [`MakeService`] that produces axum router services. /// /// [`MakeService`]: tower::make::MakeService #[derive(Debug, Clone)] pub struct IntoMakeService { svc: S, } impl IntoMakeService { pub(crate) fn new(svc: S) -> Self { Self { svc } } } impl Service for IntoMakeService where S: Clone, { type Response = S; type Error = Infallible; type Future = IntoMakeServiceFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _target: T) -> Self::Future { IntoMakeServiceFuture::new(ready(Ok(self.svc.clone()))) } } opaque_future! { /// Response future for [`IntoMakeService`]. pub type IntoMakeServiceFuture = std::future::Ready>; } #[cfg(test)] mod tests { use super::*; use crate::body::Body; #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); } } axum-axum-v0.6.20/axum/src/routing/method_filter.rs000066400000000000000000000062771446276533600223660ustar00rootroot00000000000000use bitflags::bitflags; use http::Method; use std::{ fmt, fmt::{Debug, Formatter}, }; bitflags! { /// A filter that matches one or more HTTP methods. pub struct MethodFilter: u16 { /// Match `DELETE` requests. 
const DELETE = 0b000000010; /// Match `GET` requests. const GET = 0b000000100; /// Match `HEAD` requests. const HEAD = 0b000001000; /// Match `OPTIONS` requests. const OPTIONS = 0b000010000; /// Match `PATCH` requests. const PATCH = 0b000100000; /// Match `POST` requests. const POST = 0b001000000; /// Match `PUT` requests. const PUT = 0b010000000; /// Match `TRACE` requests. const TRACE = 0b100000000; } } /// Error type used when converting a [`Method`] to a [`MethodFilter`] fails. #[derive(Debug)] pub struct NoMatchingMethodFilter { method: Method, } impl NoMatchingMethodFilter { /// Get the [`Method`] that couldn't be converted to a [`MethodFilter`]. pub fn method(&self) -> &Method { &self.method } } impl fmt::Display for NoMatchingMethodFilter { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "no `MethodFilter` for `{}`", self.method.as_str()) } } impl std::error::Error for NoMatchingMethodFilter {} impl TryFrom for MethodFilter { type Error = NoMatchingMethodFilter; fn try_from(m: Method) -> Result { match m { Method::DELETE => Ok(MethodFilter::DELETE), Method::GET => Ok(MethodFilter::GET), Method::HEAD => Ok(MethodFilter::HEAD), Method::OPTIONS => Ok(MethodFilter::OPTIONS), Method::PATCH => Ok(MethodFilter::PATCH), Method::POST => Ok(MethodFilter::POST), Method::PUT => Ok(MethodFilter::PUT), Method::TRACE => Ok(MethodFilter::TRACE), other => Err(NoMatchingMethodFilter { method: other }), } } } #[cfg(test)] mod tests { use super::*; #[test] fn from_http_method() { assert_eq!( MethodFilter::try_from(Method::DELETE).unwrap(), MethodFilter::DELETE ); assert_eq!( MethodFilter::try_from(Method::GET).unwrap(), MethodFilter::GET ); assert_eq!( MethodFilter::try_from(Method::HEAD).unwrap(), MethodFilter::HEAD ); assert_eq!( MethodFilter::try_from(Method::OPTIONS).unwrap(), MethodFilter::OPTIONS ); assert_eq!( MethodFilter::try_from(Method::PATCH).unwrap(), MethodFilter::PATCH ); assert_eq!( MethodFilter::try_from(Method::POST).unwrap(), 
MethodFilter::POST ); assert_eq!( MethodFilter::try_from(Method::PUT).unwrap(), MethodFilter::PUT ); assert_eq!( MethodFilter::try_from(Method::TRACE).unwrap(), MethodFilter::TRACE ); assert!(MethodFilter::try_from(http::Method::CONNECT) .unwrap_err() .to_string() .contains("CONNECT")); } } axum-axum-v0.6.20/axum/src/routing/method_routing.rs000066400000000000000000001414641446276533600225660ustar00rootroot00000000000000//! Route to services and handlers based on HTTP methods. use super::{future::InfallibleRouteFuture, IntoMakeService}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::{Body, Bytes, HttpBody}, boxed::BoxedIntoRoute, error_handling::{HandleError, HandleErrorLayer}, handler::Handler, http::{Method, Request, StatusCode}, response::Response, routing::{future::RouteFuture, Fallback, MethodFilter, Route}, }; use axum_core::response::IntoResponse; use bytes::BytesMut; use std::{ convert::Infallible, fmt, task::{Context, Poll}, }; use tower::{service_fn, util::MapResponseLayer}; use tower_layer::Layer; use tower_service::Service; macro_rules! top_level_service_fn { ( $name:ident, GET ) => { top_level_service_fn!( /// Route `GET` requests to the given service. /// /// # Example /// /// ```rust /// use axum::{ /// http::Request, /// Router, /// routing::get_service, /// }; /// use http::Response; /// use std::convert::Infallible; /// use hyper::Body; /// /// let service = tower::service_fn(|request: Request| async { /// Ok::<_, Infallible>(Response::new(Body::empty())) /// }); /// /// // Requests to `GET /` will go to `service`. /// let app = Router::new().route("/", get_service(service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. 
$name, GET ); }; ( $name:ident, $method:ident ) => { top_level_service_fn!( #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")] /// /// See [`get_service`] for an example. $name, $method ); }; ( $(#[$m:meta])+ $name:ident, $method:ident ) => { $(#[$m])+ pub fn $name(svc: T) -> MethodRouter where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone, { on_service(MethodFilter::$method, svc) } }; } macro_rules! top_level_handler_fn { ( $name:ident, GET ) => { top_level_handler_fn!( /// Route `GET` requests to the given handler. /// /// # Example /// /// ```rust /// use axum::{ /// routing::get, /// Router, /// }; /// /// async fn handler() {} /// /// // Requests to `GET /` will go to `handler`. /// let app = Router::new().route("/", get(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. $name, GET ); }; ( $name:ident, $method:ident ) => { top_level_handler_fn!( #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")] /// /// See [`get`] for an example. $name, $method ); }; ( $(#[$m:meta])+ $name:ident, $method:ident ) => { $(#[$m])+ pub fn $name(handler: H) -> MethodRouter where H: Handler, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static, { on(MethodFilter::$method, handler) } }; } macro_rules! chained_service_fn { ( $name:ident, GET ) => { chained_service_fn!( /// Chain an additional service that will only accept `GET` requests. 
/// /// # Example /// /// ```rust /// use axum::{ /// http::Request, /// Router, /// routing::post_service, /// }; /// use http::Response; /// use std::convert::Infallible; /// use hyper::Body; /// /// let service = tower::service_fn(|request: Request| async { /// Ok::<_, Infallible>(Response::new(Body::empty())) /// }); /// /// let other_service = tower::service_fn(|request: Request| async { /// Ok::<_, Infallible>(Response::new(Body::empty())) /// }); /// /// // Requests to `POST /` will go to `service` and `GET /` will go to /// // `other_service`. /// let app = Router::new().route("/", post_service(service).get_service(other_service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. $name, GET ); }; ( $name:ident, $method:ident ) => { chained_service_fn!( #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")] /// /// See [`MethodRouter::get_service`] for an example. $name, $method ); }; ( $(#[$m:meta])+ $name:ident, $method:ident ) => { $(#[$m])+ #[track_caller] pub fn $name(self, svc: T) -> Self where T: Service, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { self.on_service(MethodFilter::$method, svc) } }; } macro_rules! chained_handler_fn { ( $name:ident, GET ) => { chained_handler_fn!( /// Chain an additional handler that will only accept `GET` requests. /// /// # Example /// /// ```rust /// use axum::{routing::post, Router}; /// /// async fn handler() {} /// /// async fn other_handler() {} /// /// // Requests to `POST /` will go to `handler` and `GET /` will go to /// // `other_handler`. 
/// let app = Router::new().route("/", post(handler).get(other_handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. $name, GET ); }; ( $name:ident, $method:ident ) => { chained_handler_fn!( #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")] /// /// See [`MethodRouter::get`] for an example. $name, $method ); }; ( $(#[$m:meta])+ $name:ident, $method:ident ) => { $(#[$m])+ #[track_caller] pub fn $name(self, handler: H) -> Self where H: Handler, T: 'static, S: Send + Sync + 'static, { self.on(MethodFilter::$method, handler) } }; } top_level_service_fn!(delete_service, DELETE); top_level_service_fn!(get_service, GET); top_level_service_fn!(head_service, HEAD); top_level_service_fn!(options_service, OPTIONS); top_level_service_fn!(patch_service, PATCH); top_level_service_fn!(post_service, POST); top_level_service_fn!(put_service, PUT); top_level_service_fn!(trace_service, TRACE); /// Route requests with the given method to the service. /// /// # Example /// /// ```rust /// use axum::{ /// http::Request, /// routing::on, /// Router, /// routing::{MethodFilter, on_service}, /// }; /// use http::Response; /// use std::convert::Infallible; /// use hyper::Body; /// /// let service = tower::service_fn(|request: Request| async { /// Ok::<_, Infallible>(Response::new(Body::empty())) /// }); /// /// // Requests to `POST /` will go to `service`. 
/// let app = Router::new().route("/", on_service(MethodFilter::POST, service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` pub fn on_service(filter: MethodFilter, svc: T) -> MethodRouter where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone, { MethodRouter::new().on_service(filter, svc) } /// Route requests to the given service regardless of its method. /// /// # Example /// /// ```rust /// use axum::{ /// http::Request, /// Router, /// routing::any_service, /// }; /// use http::Response; /// use std::convert::Infallible; /// use hyper::Body; /// /// let service = tower::service_fn(|request: Request| async { /// Ok::<_, Infallible>(Response::new(Body::empty())) /// }); /// /// // All requests to `/` will go to `service`. /// let app = Router::new().route("/", any_service(service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Additional methods can still be chained: /// /// ```rust /// use axum::{ /// http::Request, /// Router, /// routing::any_service, /// }; /// use http::Response; /// use std::convert::Infallible; /// use hyper::Body; /// /// let service = tower::service_fn(|request: Request| async { /// # Ok::<_, Infallible>(Response::new(Body::empty())) /// // ... /// }); /// /// let other_service = tower::service_fn(|request: Request| async { /// # Ok::<_, Infallible>(Response::new(Body::empty())) /// // ... /// }); /// /// // `POST /` goes to `other_service`. 
All other requests go to `service` /// let app = Router::new().route("/", any_service(service).post_service(other_service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` pub fn any_service(svc: T) -> MethodRouter where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone, { MethodRouter::new() .fallback_service(svc) .skip_allow_header() } top_level_handler_fn!(delete, DELETE); top_level_handler_fn!(get, GET); top_level_handler_fn!(head, HEAD); top_level_handler_fn!(options, OPTIONS); top_level_handler_fn!(patch, PATCH); top_level_handler_fn!(post, POST); top_level_handler_fn!(put, PUT); top_level_handler_fn!(trace, TRACE); /// Route requests with the given method to the handler. /// /// # Example /// /// ```rust /// use axum::{ /// routing::on, /// Router, /// routing::MethodFilter, /// }; /// /// async fn handler() {} /// /// // Requests to `POST /` will go to `handler`. /// let app = Router::new().route("/", on(MethodFilter::POST, handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` pub fn on(filter: MethodFilter, handler: H) -> MethodRouter where H: Handler, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static, { MethodRouter::new().on(filter, handler) } /// Route requests with the given handler regardless of the method. /// /// # Example /// /// ```rust /// use axum::{ /// routing::any, /// Router, /// }; /// /// async fn handler() {} /// /// // All requests to `/` will go to `handler`. 
/// let app = Router::new().route("/", any(handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Additional methods can still be chained: /// /// ```rust /// use axum::{ /// routing::any, /// Router, /// }; /// /// async fn handler() {} /// /// async fn other_handler() {} /// /// // `POST /` goes to `other_handler`. All other requests go to `handler` /// let app = Router::new().route("/", any(handler).post(other_handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` pub fn any(handler: H) -> MethodRouter where H: Handler, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static, { MethodRouter::new().fallback(handler).skip_allow_header() } /// A [`Service`] that accepts requests based on a [`MethodFilter`] and /// allows chaining additional handlers and services. /// /// # When does `MethodRouter` implement [`Service`]? /// /// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires. /// /// ``` /// use tower::Service; /// use axum::{routing::get, extract::State, body::Body, http::Request}; /// /// // this `MethodRouter` doesn't require any state, i.e. 
the state is `()`, /// let method_router = get(|| async {}); /// // and thus it implements `Service` /// assert_service(method_router); /// /// // this requires a `String` and doesn't implement `Service` /// let method_router = get(|_: State| async {}); /// // until you provide the `String` with `.with_state(...)` /// let method_router_with_state = method_router.with_state(String::new()); /// // and then it implements `Service` /// assert_service(method_router_with_state); /// /// // helper to check that a value implements `Service` /// fn assert_service(service: S) /// where /// S: Service>, /// {} /// ``` #[must_use] pub struct MethodRouter { get: MethodEndpoint, head: MethodEndpoint, delete: MethodEndpoint, options: MethodEndpoint, patch: MethodEndpoint, post: MethodEndpoint, put: MethodEndpoint, trace: MethodEndpoint, fallback: Fallback, allow_header: AllowHeader, } #[derive(Clone, Debug)] enum AllowHeader { /// No `Allow` header value has been built-up yet. This is the default state None, /// Don't set an `Allow` header. This is used when `any` or `any_service` are called. Skip, /// The current value of the `Allow` header. 
Bytes(BytesMut), } impl AllowHeader { fn merge(self, other: Self) -> Self { match (self, other) { (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip, (AllowHeader::None, AllowHeader::None) => AllowHeader::None, (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick), (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick), (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => { a.extend_from_slice(b","); a.extend_from_slice(&b); AllowHeader::Bytes(a) } } } } impl fmt::Debug for MethodRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MethodRouter") .field("get", &self.get) .field("head", &self.head) .field("delete", &self.delete) .field("options", &self.options) .field("patch", &self.patch) .field("post", &self.post) .field("put", &self.put) .field("trace", &self.trace) .field("fallback", &self.fallback) .field("allow_header", &self.allow_header) .finish() } } impl MethodRouter where B: HttpBody + Send + 'static, S: Clone, { /// Chain an additional handler that will accept requests matching the given /// `MethodFilter`. 
/// /// # Example /// /// ```rust /// use axum::{ /// routing::get, /// Router, /// routing::MethodFilter /// }; /// /// async fn handler() {} /// /// async fn other_handler() {} /// /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to /// // `other_handler` /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[track_caller] pub fn on(self, filter: MethodFilter, handler: H) -> Self where H: Handler, T: 'static, S: Send + Sync + 'static, { self.on_endpoint( filter, MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)), ) } chained_handler_fn!(delete, DELETE); chained_handler_fn!(get, GET); chained_handler_fn!(head, HEAD); chained_handler_fn!(options, OPTIONS); chained_handler_fn!(patch, PATCH); chained_handler_fn!(post, POST); chained_handler_fn!(put, PUT); chained_handler_fn!(trace, TRACE); /// Add a fallback [`Handler`] to the router. pub fn fallback(mut self, handler: H) -> Self where H: Handler, T: 'static, S: Send + Sync + 'static, { self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self } } impl MethodRouter<(), B, Infallible> where B: HttpBody + Send + 'static, { /// Convert the handler into a [`MakeService`]. 
/// /// This allows you to serve a single handler if you don't need any routing: /// /// ```rust /// use axum::{ /// Server, /// handler::Handler, /// http::{Uri, Method}, /// response::IntoResponse, /// routing::get, /// }; /// use std::net::SocketAddr; /// /// async fn handler(method: Method, uri: Uri, body: String) -> String { /// format!("received `{} {}` with body `{:?}`", method, uri, body) /// } /// /// let router = get(handler).post(handler); /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) /// .serve(router.into_make_service()) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService pub fn into_make_service(self) -> IntoMakeService { IntoMakeService::new(self.with_state(())) } /// Convert the router into a [`MakeService`] which stores information /// about the incoming connection. /// /// See [`Router::into_make_service_with_connect_info`] for more details. /// /// ```rust /// use axum::{ /// Server, /// handler::Handler, /// response::IntoResponse, /// extract::ConnectInfo, /// routing::get, /// }; /// use std::net::SocketAddr; /// /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String { /// format!("Hello {}", addr) /// } /// /// let router = get(handler).post(handler); /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) /// .serve(router.into_make_service_with_connect_info::()) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[cfg(feature = "tokio")] pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { IntoMakeServiceWithConnectInfo::new(self.with_state(())) } } impl MethodRouter where B: HttpBody + Send + 'static, S: Clone, { /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all /// requests. 
pub fn new() -> Self { let fallback = Route::new(service_fn(|_: Request| async { Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()) })); Self { get: MethodEndpoint::None, head: MethodEndpoint::None, delete: MethodEndpoint::None, options: MethodEndpoint::None, patch: MethodEndpoint::None, post: MethodEndpoint::None, put: MethodEndpoint::None, trace: MethodEndpoint::None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), } } /// Provide the state for the router. pub fn with_state(self, state: S) -> MethodRouter { MethodRouter { get: self.get.with_state(&state), head: self.head.with_state(&state), delete: self.delete.with_state(&state), options: self.options.with_state(&state), patch: self.patch.with_state(&state), post: self.post.with_state(&state), put: self.put.with_state(&state), trace: self.trace.with_state(&state), allow_header: self.allow_header, fallback: self.fallback.with_state(state), } } /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. 
/// /// # Example /// /// ```rust /// use axum::{ /// http::Request, /// Router, /// routing::{MethodFilter, on_service}, /// }; /// use http::Response; /// use std::convert::Infallible; /// use hyper::Body; /// /// let service = tower::service_fn(|request: Request| async { /// Ok::<_, Infallible>(Response::new(Body::empty())) /// }); /// /// // Requests to `DELETE /` will go to `service` /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` #[track_caller] pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where T: Service, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc))) } #[track_caller] fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint) -> Self { // written as a separate function to generate less IR #[track_caller] fn set_endpoint( method_name: &str, out: &mut MethodEndpoint, endpoint: &MethodEndpoint, endpoint_filter: MethodFilter, filter: MethodFilter, allow_header: &mut AllowHeader, methods: &[&'static str], ) where MethodEndpoint: Clone, S: Clone, { if endpoint_filter.contains(filter) { if out.is_some() { panic!( "Overlapping method route. 
Cannot add two method routes that both handle \ `{method_name}`", ) } *out = endpoint.clone(); for method in methods { append_allow_header(allow_header, method); } } } set_endpoint( "GET", &mut self.get, &endpoint, filter, MethodFilter::GET, &mut self.allow_header, &["GET", "HEAD"], ); set_endpoint( "HEAD", &mut self.head, &endpoint, filter, MethodFilter::HEAD, &mut self.allow_header, &["HEAD"], ); set_endpoint( "TRACE", &mut self.trace, &endpoint, filter, MethodFilter::TRACE, &mut self.allow_header, &["TRACE"], ); set_endpoint( "PUT", &mut self.put, &endpoint, filter, MethodFilter::PUT, &mut self.allow_header, &["PUT"], ); set_endpoint( "POST", &mut self.post, &endpoint, filter, MethodFilter::POST, &mut self.allow_header, &["POST"], ); set_endpoint( "PATCH", &mut self.patch, &endpoint, filter, MethodFilter::PATCH, &mut self.allow_header, &["PATCH"], ); set_endpoint( "OPTIONS", &mut self.options, &endpoint, filter, MethodFilter::OPTIONS, &mut self.allow_header, &["OPTIONS"], ); set_endpoint( "DELETE", &mut self.delete, &endpoint, filter, MethodFilter::DELETE, &mut self.allow_header, &["DELETE"], ); self } chained_service_fn!(delete_service, DELETE); chained_service_fn!(get_service, GET); chained_service_fn!(head_service, HEAD); chained_service_fn!(options_service, OPTIONS); chained_service_fn!(patch_service, PATCH); chained_service_fn!(post_service, POST); chained_service_fn!(put_service, PUT); chained_service_fn!(trace_service, TRACE); #[doc = include_str!("../docs/method_routing/fallback.md")] pub fn fallback_service(mut self, svc: T) -> Self where T: Service, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { self.fallback = Fallback::Service(Route::new(svc)); self } #[doc = include_str!("../docs/method_routing/layer.md")] pub fn layer(self, layer: L) -> MethodRouter where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 
'static, >>::Future: Send + 'static, E: 'static, S: 'static, NewReqBody: HttpBody + 'static, NewError: 'static, { let layer_fn = move |route: Route| route.layer(layer.clone()); MethodRouter { get: self.get.map(layer_fn.clone()), head: self.head.map(layer_fn.clone()), delete: self.delete.map(layer_fn.clone()), options: self.options.map(layer_fn.clone()), patch: self.patch.map(layer_fn.clone()), post: self.post.map(layer_fn.clone()), put: self.put.map(layer_fn.clone()), trace: self.trace.map(layer_fn.clone()), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, } } #[doc = include_str!("../docs/method_routing/route_layer.md")] #[track_caller] pub fn route_layer(mut self, layer: L) -> MethodRouter where L: Layer> + Clone + Send + 'static, L::Service: Service, Error = E> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Future: Send + 'static, E: 'static, S: 'static, { if self.get.is_none() && self.head.is_none() && self.delete.is_none() && self.options.is_none() && self.patch.is_none() && self.post.is_none() && self.put.is_none() && self.trace.is_none() { panic!( "Adding a route_layer before any routes is a no-op. \ Add the routes you want the layer to apply to first." 
); } let layer_fn = move |svc| { let svc = layer.layer(svc); let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); Route::new(svc) }; self.get = self.get.map(layer_fn.clone()); self.head = self.head.map(layer_fn.clone()); self.delete = self.delete.map(layer_fn.clone()); self.options = self.options.map(layer_fn.clone()); self.patch = self.patch.map(layer_fn.clone()); self.post = self.post.map(layer_fn.clone()); self.put = self.put.map(layer_fn.clone()); self.trace = self.trace.map(layer_fn); self } #[track_caller] pub(crate) fn merge_for_path( mut self, path: Option<&str>, other: MethodRouter, ) -> Self { // written using inner functions to generate less IR #[track_caller] fn merge_inner( path: Option<&str>, name: &str, first: MethodEndpoint, second: MethodEndpoint, ) -> MethodEndpoint { match (first, second) { (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None, (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick, _ => { if let Some(path) = path { panic!( "Overlapping method route. Handler for `{name} {path}` already exists" ); } else { panic!( "Overlapping method route. 
Cannot merge two method routes that both \ define `{name}`" ); } } } } self.get = merge_inner(path, "GET", self.get, other.get); self.head = merge_inner(path, "HEAD", self.head, other.head); self.delete = merge_inner(path, "DELETE", self.delete, other.delete); self.options = merge_inner(path, "OPTIONS", self.options, other.options); self.patch = merge_inner(path, "PATCH", self.patch, other.patch); self.post = merge_inner(path, "POST", self.post, other.post); self.put = merge_inner(path, "PUT", self.put, other.put); self.trace = merge_inner(path, "TRACE", self.trace, other.trace); self.fallback = self .fallback .merge(other.fallback) .expect("Cannot merge two `MethodRouter`s that both have a fallback"); self.allow_header = self.allow_header.merge(other.allow_header); self } #[doc = include_str!("../docs/method_routing/merge.md")] #[track_caller] pub fn merge(self, other: MethodRouter) -> Self { self.merge_for_path(None, other) } /// Apply a [`HandleErrorLayer`]. /// /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`. pub fn handle_error(self, f: F) -> MethodRouter where F: Clone + Send + Sync + 'static, HandleError, F, T>: Service, Error = Infallible>, , F, T> as Service>>::Future: Send, , F, T> as Service>>::Response: IntoResponse + Send, T: 'static, E: 'static, B: 'static, S: 'static, { self.layer(HandleErrorLayer::new(f)) } fn skip_allow_header(mut self) -> Self { self.allow_header = AllowHeader::Skip; self } pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { macro_rules! 
call { ( $req:expr, $method:expr, $method_variant:ident, $svc:expr ) => { if $method == Method::$method_variant { match $svc { MethodEndpoint::None => {} MethodEndpoint::Route(route) => { return RouteFuture::from_future(route.oneshot_inner($req)) .strip_body($method == Method::HEAD); } MethodEndpoint::BoxedHandler(handler) => { let mut route = handler.clone().into_route(state); return RouteFuture::from_future(route.oneshot_inner($req)) .strip_body($method == Method::HEAD); } } } }; } let method = req.method().clone(); // written with a pattern match like this to ensure we call all routes let Self { get, head, delete, options, patch, post, put, trace, fallback, allow_header, } = self; call!(req, method, HEAD, head); call!(req, method, HEAD, get); call!(req, method, GET, get); call!(req, method, POST, post); call!(req, method, OPTIONS, options); call!(req, method, PATCH, patch); call!(req, method, PUT, put); call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); let future = fallback.call_with_state(req, state); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), AllowHeader::Skip => future, AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()), } } } fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { match allow_header { AllowHeader::None => { *allow_header = AllowHeader::Bytes(BytesMut::from(method)); } AllowHeader::Skip => {} AllowHeader::Bytes(allow_header) => { if let Ok(s) = std::str::from_utf8(allow_header) { if !s.contains(method) { allow_header.extend_from_slice(b","); allow_header.extend_from_slice(method.as_bytes()); } } else { #[cfg(debug_assertions)] panic!("`allow_header` contained invalid uft-8. 
This should never happen") } } } } impl Clone for MethodRouter { fn clone(&self) -> Self { Self { get: self.get.clone(), head: self.head.clone(), delete: self.delete.clone(), options: self.options.clone(), patch: self.patch.clone(), post: self.post.clone(), put: self.put.clone(), trace: self.trace.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), } } } impl Default for MethodRouter where B: HttpBody + Send + 'static, S: Clone, { fn default() -> Self { Self::new() } } enum MethodEndpoint { None, Route(Route), BoxedHandler(BoxedIntoRoute), } impl MethodEndpoint where S: Clone, { fn is_some(&self) -> bool { matches!(self, Self::Route(_) | Self::BoxedHandler(_)) } fn is_none(&self) -> bool { matches!(self, Self::None) } fn map(self, f: F) -> MethodEndpoint where S: 'static, B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static, { match self { Self::None => MethodEndpoint::None, Self::Route(route) => MethodEndpoint::Route(f(route)), Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)), } } fn with_state(self, state: &S) -> MethodEndpoint { match self { MethodEndpoint::None => MethodEndpoint::None, MethodEndpoint::Route(route) => MethodEndpoint::Route(route), MethodEndpoint::BoxedHandler(handler) => { MethodEndpoint::Route(handler.into_route(state.clone())) } } } } impl Clone for MethodEndpoint { fn clone(&self) -> Self { match self { Self::None => Self::None, Self::Route(inner) => Self::Route(inner.clone()), Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()), } } } impl fmt::Debug for MethodEndpoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::None => f.debug_tuple("None").finish(), Self::Route(inner) => inner.fmt(f), Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(), } } } impl Service> for MethodRouter<(), B, E> where B: HttpBody + Send + 'static, { type Response = Response; type Error = E; type 
Future = RouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[inline] fn call(&mut self, req: Request) -> Self::Future { self.call_with_state(req, ()) } } impl Handler<(), S, B> for MethodRouter where S: Clone + 'static, B: HttpBody + Send + 'static, { type Future = InfallibleRouteFuture; fn call(mut self, req: Request, state: S) -> Self::Future { InfallibleRouteFuture::new(self.call_with_state(req, state)) } } #[cfg(test)] mod tests { use super::*; use crate::{ body::Body, error_handling::HandleErrorLayer, extract::State, handler::HandlerWithoutStateExt, }; use axum_core::response::IntoResponse; use http::{header::ALLOW, HeaderMap}; use std::time::Duration; use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt}; use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer}; #[crate::test] async fn method_not_allowed_by_default() { let mut svc = MethodRouter::new(); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert!(body.is_empty()); } #[crate::test] async fn get_service_fn() { async fn handle(_req: Request) -> Result, Infallible> { Ok(Response::new(Body::from("ok"))) } let mut svc = get_service(service_fn(handle)); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(body, "ok"); } #[crate::test] async fn get_handler() { let mut svc = MethodRouter::new().get(ok); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(body, "ok"); } #[crate::test] async fn get_accepts_head() { let mut svc = MethodRouter::new().get(ok); let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::OK); assert!(body.is_empty()); } #[crate::test] async fn head_takes_precedence_over_get() { let mut svc = MethodRouter::new().head(created).get(ok); let (status, _, body) = call(Method::HEAD, &mut 
svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Not a test and never called here; it appears to exist only so that this
    // combination of builder calls has to type-check.
    #[allow(dead_code)]
    fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use all the things 💣️
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(|_| async {
                            StatusCode::REQUEST_TIMEOUT
                        }))
                        .layer(TimeoutLayer::new(Duration::from_secs(10))),
                ),
        );

        crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], "GET,POST"); } #[crate::test] async fn allow_header_noop_middleware() { let mut svc = MethodRouter::new() .get(ok) .layer(tower::layer::util::Identity::new()); let (status, headers, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], "GET,HEAD"); } #[crate::test] #[should_panic( expected = "Overlapping method route. Cannot add two method routes that both handle `GET`" )] async fn handler_overlaps() { let _: MethodRouter<()> = get(ok).get(ok); } #[crate::test] #[should_panic( expected = "Overlapping method route. Cannot add two method routes that both handle `POST`" )] async fn service_overlaps() { let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service()); } #[crate::test] async fn get_head_does_not_overlap() { let _: MethodRouter<()> = get(ok).head(ok); } #[crate::test] async fn head_get_does_not_overlap() { let _: MethodRouter<()> = head(ok).get(ok); } #[crate::test] async fn accessing_state() { let mut svc = MethodRouter::new() .get(|State(state): State<&'static str>| async move { state }) .with_state("state"); let (status, _, text) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(text, "state"); } #[crate::test] async fn fallback_accessing_state() { let mut svc = MethodRouter::new() .fallback(|State(state): State<&'static str>| async move { state }) .with_state("state"); let (status, _, text) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(text, "state"); } #[crate::test] async fn merge_accessing_state() { let one = get(|State(state): State<&'static str>| async move { state }); let two = post(|State(state): State<&'static str>| async move { state }); let mut svc = one.merge(two).with_state("state"); let (status, _, text) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(text, 
"state"); let (status, _, _) = call(Method::POST, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(text, "state"); } async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) where S: Service, Error = Infallible>, S::Response: IntoResponse, { let request = Request::builder() .uri("/") .method(method) .body(Body::empty()) .unwrap(); let response = svc .ready() .await .unwrap() .call(request) .await .unwrap() .into_response(); let (parts, body) = response.into_parts(); let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap(); (parts.status, parts.headers, body) } async fn ok() -> (StatusCode, &'static str) { (StatusCode::OK, "ok") } async fn created() -> (StatusCode, &'static str) { (StatusCode::CREATED, "created") } } axum-axum-v0.6.20/axum/src/routing/mod.rs000066400000000000000000000431651446276533600203150ustar00rootroot00000000000000//! Routing between [`Service`]s and handlers. use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::{Body, HttpBody}, boxed::BoxedIntoRoute, handler::Handler, util::try_downcast, }; use axum_core::response::{IntoResponse, Response}; use http::Request; use std::{ convert::Infallible, fmt, task::{Context, Poll}, }; use sync_wrapper::SyncWrapper; use tower_layer::Layer; use tower_service::Service; pub mod future; pub mod method_routing; mod into_make_service; mod method_filter; mod not_found; pub(crate) mod path_router; mod route; mod strip_prefix; pub(crate) mod url_params; #[cfg(test)] mod tests; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::method_routing::{ any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, options, options_service, patch, patch_service, post, post_service, put, put_service, trace, trace_service, 
MethodRouter, }; macro_rules! panic_on_err { ($expr:expr) => { match $expr { Ok(x) => x, Err(err) => panic!("{err}"), } }; } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub(crate) struct RouteId(u32); /// The router type for composing handlers and services. #[must_use] pub struct Router { path_router: PathRouter, fallback_router: PathRouter, default_fallback: bool, catch_all_fallback: Fallback, } impl Clone for Router { fn clone(&self) -> Self { Self { path_router: self.path_router.clone(), fallback_router: self.fallback_router.clone(), default_fallback: self.default_fallback, catch_all_fallback: self.catch_all_fallback.clone(), } } } impl Default for Router where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { fn default() -> Self { Self::new() } } impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") .field("path_router", &self.path_router) .field("fallback_router", &self.fallback_router) .field("default_fallback", &self.default_fallback) .field("catch_all_fallback", &self.catch_all_fallback) .finish() } } pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback"; pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback"; impl Router where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { /// Create a new `Router`. /// /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. 
pub fn new() -> Self { Self { path_router: Default::default(), fallback_router: PathRouter::new_fallback(), default_fallback: true, catch_all_fallback: Fallback::Default(Route::new(NotFound)), } } #[doc = include_str!("../docs/routing/route.md")] #[track_caller] pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { panic_on_err!(self.path_router.route(path, method_router)); self } #[doc = include_str!("../docs/routing/route_service.md")] pub fn route_service(mut self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let service = match try_downcast::, _>(service) { Ok(_) => { panic!( "Invalid route: `Router::route_service` cannot be used with `Router`s. \ Use `Router::nest` instead" ); } Err(service) => service, }; panic_on_err!(self.path_router.route_service(path, service)); self } #[doc = include_str!("../docs/routing/nest.md")] #[track_caller] pub fn nest(mut self, path: &str, router: Router) -> Self { let Router { path_router, fallback_router, default_fallback, // we don't need to inherit the catch-all fallback. It is only used for CONNECT // requests with an empty path. If we were to inherit the catch-all fallback // it would end up matching `/{path}/*` which doesn't match empty paths. catch_all_fallback: _, } = router; panic_on_err!(self.path_router.nest(path, path_router)); if !default_fallback { panic_on_err!(self.fallback_router.nest(path, fallback_router)); } self } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. 
#[track_caller] pub fn nest_service(mut self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { panic_on_err!(self.path_router.nest_service(path, service)); self } #[doc = include_str!("../docs/routing/merge.md")] #[track_caller] pub fn merge(mut self, other: R) -> Self where R: Into>, { const PANIC_MSG: &str = "Failed to merge fallbacks. This is a bug in axum. Please file an issue"; let Router { path_router, fallback_router: mut other_fallback, default_fallback, catch_all_fallback, } = other.into(); panic_on_err!(self.path_router.merge(path_router)); match (self.default_fallback, default_fallback) { // both have the default fallback // use the one from other (true, true) => { self.fallback_router.merge(other_fallback).expect(PANIC_MSG); } // self has default fallback, other has a custom fallback (true, false) => { self.fallback_router.merge(other_fallback).expect(PANIC_MSG); self.default_fallback = false; } // self has a custom fallback, other has a default (false, true) => { let fallback_router = std::mem::take(&mut self.fallback_router); other_fallback.merge(fallback_router).expect(PANIC_MSG); self.fallback_router = other_fallback; } // both have a custom fallback, not allowed (false, false) => { panic!("Cannot merge two `Router`s that both have a fallback") } }; self.catch_all_fallback = self .catch_all_fallback .merge(catch_all_fallback) .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); self } #[doc = include_str!("../docs/routing/layer.md")] pub fn layer(self, layer: L) -> Router where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, NewReqBody: HttpBody + 'static, { Router { path_router: self.path_router.layer(layer.clone()), fallback_router: self.fallback_router.layer(layer.clone()), 
default_fallback: self.default_fallback, catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), } } #[doc = include_str!("../docs/routing/route_layer.md")] #[track_caller] pub fn route_layer(self, layer: L) -> Self where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, { Router { path_router: self.path_router.route_layer(layer), fallback_router: self.fallback_router, default_fallback: self.default_fallback, catch_all_fallback: self.catch_all_fallback, } } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] pub fn fallback(mut self, handler: H) -> Self where H: Handler, T: 'static, { self.catch_all_fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. 
pub fn fallback_service(mut self, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let route = Route::new(service); self.catch_all_fallback = Fallback::Service(route.clone()); self.fallback_endpoint(Endpoint::Route(route)) } fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { self.fallback_router.set_fallback(endpoint); self.default_fallback = false; self } #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { Router { path_router: self.path_router.with_state(state.clone()), fallback_router: self.fallback_router.with_state(state.clone()), default_fallback: self.default_fallback, catch_all_fallback: self.catch_all_fallback.with_state(state), } } pub(crate) fn call_with_state( &mut self, mut req: Request, state: S, ) -> RouteFuture { // required for opaque routers to still inherit the fallback // TODO(david): remove this feature in 0.7 if !self.default_fallback { req.extensions_mut().insert(SuperFallback(SyncWrapper::new( self.fallback_router.clone(), ))); } match self.path_router.call_with_state(req, state) { Ok(future) => future, Err((mut req, state)) => { let super_fallback = req .extensions_mut() .remove::>() .map(|SuperFallback(path_router)| path_router.into_inner()); if let Some(mut super_fallback) = super_fallback { match super_fallback.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => { return self.catch_all_fallback.call_with_state(req, state); } } } match self.fallback_router.call_with_state(req, state) { Ok(future) => future, Err((req, state)) => self.catch_all_fallback.call_with_state(req, state), } } } } } impl Router<(), B> where B: HttpBody + Send + 'static, { /// Convert this router into a [`MakeService`], that is a [`Service`] whose /// response is another service. 
/// /// This is useful when running your application with hyper's /// [`Server`](hyper::server::Server): /// /// ``` /// use axum::{ /// routing::get, /// Router, /// }; /// /// let app = Router::new().route("/", get(|| async { "Hi!" })); /// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve(app.into_make_service()) /// .await /// .expect("server failed"); /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService pub fn into_make_service(self) -> IntoMakeService { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request IntoMakeService::new(self.with_state(())) } #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")] #[cfg(feature = "tokio")] pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request IntoMakeServiceWithConnectInfo::new(self.with_state(())) } } impl Service> for Router<(), B> where B: HttpBody + Send + 'static, { type Response = Response; type Error = Infallible; type Future = RouteFuture; #[inline] fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[inline] fn call(&mut self, req: Request) -> Self::Future { self.call_with_state(req, ()) } } enum Fallback { Default(Route), Service(Route), BoxedHandler(BoxedIntoRoute), } impl Fallback where S: Clone, { fn merge(self, other: Self) -> Option { match (self, other) { (Self::Default(_), pick @ Self::Default(_)) => Some(pick), (Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick), _ => None, } } fn map(self, f: F) -> Fallback where S: 'static, B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static, { match self { Self::Default(route) => Fallback::Default(f(route)), Self::Service(route) => Fallback::Service(f(route)), 
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)), } } fn with_state(self, state: S) -> Fallback { match self { Fallback::Default(route) => Fallback::Default(route), Fallback::Service(route) => Fallback::Service(route), Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), } } fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { match self { Fallback::Default(route) | Fallback::Service(route) => { RouteFuture::from_future(route.oneshot_inner(req)) } Fallback::BoxedHandler(handler) => { let mut route = handler.clone().into_route(state); RouteFuture::from_future(route.oneshot_inner(req)) } } } } impl Clone for Fallback { fn clone(&self) -> Self { match self { Self::Default(inner) => Self::Default(inner.clone()), Self::Service(inner) => Self::Service(inner.clone()), Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()), } } } impl fmt::Debug for Fallback { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(), Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(), } } } #[allow(clippy::large_enum_variant)] enum Endpoint { MethodRouter(MethodRouter), Route(Route), } impl Endpoint where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { fn layer(self, layer: L) -> Endpoint where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, NewReqBody: HttpBody + 'static, { match self { Endpoint::MethodRouter(method_router) => { Endpoint::MethodRouter(method_router.layer(layer)) } Endpoint::Route(route) => Endpoint::Route(route.layer(layer)), } } } impl Clone for Endpoint { fn clone(&self) -> Self { match self { Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()), Self::Route(inner) => 
Self::Route(inner.clone()), } } } impl fmt::Debug for Endpoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MethodRouter(method_router) => { f.debug_tuple("MethodRouter").field(method_router).finish() } Self::Route(route) => f.debug_tuple("Route").field(route).finish(), } } } struct SuperFallback(SyncWrapper>); #[test] #[allow(warnings)] fn traits() { use crate::test_helpers::*; assert_send::>(); } axum-axum-v0.6.20/axum/src/routing/not_found.rs000066400000000000000000000016211446276533600215200ustar00rootroot00000000000000use crate::response::Response; use axum_core::response::IntoResponse; use http::{Request, StatusCode}; use std::{ convert::Infallible, future::ready, task::{Context, Poll}, }; use tower_service::Service; /// A [`Service`] that responds with `404 Not Found` to all requests. /// /// This is used as the bottom service in a method router. You shouldn't have to /// use it manually. #[derive(Clone, Copy, Debug)] pub(super) struct NotFound; impl Service> for NotFound where B: Send + 'static, { type Response = Response; type Error = Infallible; type Future = std::future::Ready>; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _req: Request) -> Self::Future { ready(Ok(StatusCode::NOT_FOUND.into_response())) } } axum-axum-v0.6.20/axum/src/routing/path_router.rs000066400000000000000000000362041446276533600220660ustar00rootroot00000000000000use crate::body::HttpBody; use axum_core::response::IntoResponse; use http::Request; use matchit::MatchError; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; use tower_layer::Layer; use tower_service::Service; use super::{ future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM, }; pub(super) struct PathRouter { routes: HashMap>, node: Arc, prev_route_id: RouteId, } impl PathRouter where B: 
HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { pub(super) fn new_fallback() -> Self { let mut this = Self::default(); this.set_fallback(Endpoint::Route(Route::new(NotFound))); this } pub(super) fn set_fallback(&mut self, endpoint: Endpoint) { self.replace_endpoint("/", endpoint.clone()); self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint); } } impl PathRouter where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { pub(super) fn route( &mut self, path: &str, method_router: MethodRouter, ) -> Result<(), Cow<'static, str>> { fn validate_path(path: &str) -> Result<(), &'static str> { if path.is_empty() { return Err("Paths must start with a `/`. Use \"/\" for root routes"); } else if !path.starts_with('/') { return Err("Paths must start with a `/`"); } Ok(()) } validate_path(path)?; let id = self.next_route_id(); let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self .node .path_to_route_id .get(path) .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) { // if we're adding a new `MethodRouter` to a route that already has one just // merge them. This makes `.route("/", get(_)).route("/", post(_))` work let service = Endpoint::MethodRouter( prev_method_router .clone() .merge_for_path(Some(path), method_router), ); self.routes.insert(route_id, service); return Ok(()); } else { Endpoint::MethodRouter(method_router) }; self.set_node(path, id)?; self.routes.insert(id, endpoint); Ok(()) } pub(super) fn route_service( &mut self, path: &str, service: T, ) -> Result<(), Cow<'static, str>> where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { self.route_endpoint(path, Endpoint::Route(Route::new(service))) } pub(super) fn route_endpoint( &mut self, path: &str, endpoint: Endpoint, ) -> Result<(), Cow<'static, str>> { if path.is_empty() { return Err("Paths must start with a `/`. 
Use \"/\" for root routes".into()); } else if !path.starts_with('/') { return Err("Paths must start with a `/`".into()); } let id = self.next_route_id(); self.set_node(path, id)?; self.routes.insert(id, endpoint); Ok(()) } fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> { let mut node = Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); if let Err(err) = node.insert(path, id) { return Err(format!("Invalid route {path:?}: {err}")); } self.node = Arc::new(node); Ok(()) } pub(super) fn merge( &mut self, other: PathRouter, ) -> Result<(), Cow<'static, str>> { let PathRouter { routes, node, prev_route_id: _, } = other; for (id, route) in routes { let path = node .route_id_to_path .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) { // when merging two routers it doesn't matter if you do `a.merge(b)` or // `b.merge(a)`. This must also be true for fallbacks. // // However all fallback routers will have routes for `/` and `/*` so when merging // we have to ignore the top level fallbacks on one side otherwise we get // conflicts. // // `Router::merge` makes sure that when merging fallbacks `other` always has the // fallback we want to keep. It panics if both routers have a custom fallback. Thus // it is always okay to ignore one fallback and `Router::merge` also makes sure the // one we can ignore is that of `self`. 
self.replace_endpoint(path, route); } else { match route { Endpoint::MethodRouter(method_router) => self.route(path, method_router)?, Endpoint::Route(route) => self.route_service(path, route)?, } } } Ok(()) } pub(super) fn nest( &mut self, path: &str, router: PathRouter, ) -> Result<(), Cow<'static, str>> { let prefix = validate_nest_path(path); let PathRouter { routes, node, prev_route_id: _, } = router; for (id, endpoint) in routes { let inner_path = node .route_id_to_path .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); let path = path_for_nested_route(prefix, inner_path); match endpoint.layer(StripPrefix::layer(prefix)) { Endpoint::MethodRouter(method_router) => { self.route(&path, method_router)?; } Endpoint::Route(route) => { self.route_endpoint(&path, Endpoint::Route(route))?; } } } Ok(()) } pub(super) fn nest_service(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>> where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let path = validate_nest_path(path); let prefix = path; let path = if path.ends_with('/') { format!("{path}*{NEST_TAIL_PARAM}") } else { format!("{path}/*{NEST_TAIL_PARAM}") }; let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix))); self.route_endpoint(&path, endpoint.clone())?; // `/*rest` is not matched by `/` so we need to also register a router at the // prefix itself. 
Otherwise if you were to nest at `/foo` then `/foo` itself // wouldn't match, which it should self.route_endpoint(prefix, endpoint.clone())?; if !prefix.ends_with('/') { // same goes for `/foo/`, that should also match self.route_endpoint(&format!("{prefix}/"), endpoint)?; } Ok(()) } pub(super) fn layer(self, layer: L) -> PathRouter where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, NewReqBody: HttpBody + 'static, { let routes = self .routes .into_iter() .map(|(id, endpoint)| { let route = endpoint.layer(layer.clone()); (id, route) }) .collect(); PathRouter { routes, node: self.node, prev_route_id: self.prev_route_id, } } #[track_caller] pub(super) fn route_layer(self, layer: L) -> Self where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, { if self.routes.is_empty() { panic!( "Adding a route_layer before any routes is a no-op. \ Add the routes you want the layer to apply to first." 
); } let routes = self .routes .into_iter() .map(|(id, endpoint)| { let route = endpoint.layer(layer.clone()); (id, route) }) .collect(); PathRouter { routes, node: self.node, prev_route_id: self.prev_route_id, } } pub(super) fn with_state(self, state: S) -> PathRouter { let routes = self .routes .into_iter() .map(|(id, endpoint)| { let endpoint: Endpoint = match endpoint { Endpoint::MethodRouter(method_router) => { Endpoint::MethodRouter(method_router.with_state(state.clone())) } Endpoint::Route(route) => Endpoint::Route(route), }; (id, endpoint) }) .collect(); PathRouter { routes, node: self.node, prev_route_id: self.prev_route_id, } } pub(super) fn call_with_state( &mut self, mut req: Request, state: S, ) -> Result, (Request, S)> { #[cfg(feature = "original-uri")] { use crate::extract::OriginalUri; if req.extensions().get::().is_none() { let original_uri = OriginalUri(req.uri().clone()); req.extensions_mut().insert(original_uri); } } let path = req.uri().path().to_owned(); match self.node.at(&path) { Ok(match_) => { let id = *match_.value; if !IS_FALLBACK { #[cfg(feature = "matched-path")] crate::extract::matched_path::set_matched_path_for_request( id, &self.node.route_id_to_path, req.extensions_mut(), ); } url_params::insert_url_params(req.extensions_mut(), match_.params); let endpont = self .routes .get_mut(&id) .expect("no route for id. This is a bug in axum. 
Please file an issue"); match endpont { Endpoint::MethodRouter(method_router) => { Ok(method_router.call_with_state(req, state)) } Endpoint::Route(route) => Ok(route.clone().call(req)), } } // explicitly handle all variants in case matchit adds // new ones we need to handle differently Err( MatchError::NotFound | MatchError::ExtraTrailingSlash | MatchError::MissingTrailingSlash, ) => Err((req, state)), } } pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint) { match self.node.at(path) { Ok(match_) => { let id = *match_.value; self.routes.insert(id, endpoint); } Err(_) => self .route_endpoint(path, endpoint) .expect("path wasn't matched so endpoint shouldn't exist"), } } fn next_route_id(&mut self) -> RouteId { let next_id = self .prev_route_id .0 .checked_add(1) .expect("Over `u32::MAX` routes created. If you need this, please file an issue."); self.prev_route_id = RouteId(next_id); self.prev_route_id } } impl Default for PathRouter { fn default() -> Self { Self { routes: Default::default(), node: Default::default(), prev_route_id: RouteId(0), } } } impl fmt::Debug for PathRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PathRouter") .field("routes", &self.routes) .field("node", &self.node) .finish() } } impl Clone for PathRouter { fn clone(&self) -> Self { Self { routes: self.routes.clone(), node: self.node.clone(), prev_route_id: self.prev_route_id, } } } /// Wrapper around `matchit::Router` that supports merging two `Router`s. 
#[derive(Clone, Default)] struct Node { inner: matchit::Router, route_id_to_path: HashMap>, path_to_route_id: HashMap, RouteId>, } impl Node { fn insert( &mut self, path: impl Into, val: RouteId, ) -> Result<(), matchit::InsertError> { let path = path.into(); self.inner.insert(&path, val)?; let shared_path: Arc = path.into(); self.route_id_to_path.insert(val, shared_path.clone()); self.path_to_route_id.insert(shared_path, val); Ok(()) } fn at<'n, 'p>( &'n self, path: &'p str, ) -> Result, MatchError> { self.inner.at(path) } } impl fmt::Debug for Node { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Node") .field("paths", &self.route_id_to_path) .finish() } } #[track_caller] fn validate_nest_path(path: &str) -> &str { if path.is_empty() { // nesting at `""` and `"/"` should mean the same thing return "/"; } if path.contains('*') { panic!("Invalid route: nested routes cannot contain wildcards (*)"); } path } pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> { debug_assert!(prefix.starts_with('/')); debug_assert!(path.starts_with('/')); if prefix.ends_with('/') { format!("{prefix}{}", path.trim_start_matches('/')).into() } else if path == "/" { prefix.into() } else { format!("{prefix}{path}").into() } } axum-axum-v0.6.20/axum/src/routing/route.rs000066400000000000000000000152571446276533600206750ustar00rootroot00000000000000use crate::{ body::{boxed, Body, Empty, HttpBody}, response::Response, }; use axum_core::response::IntoResponse; use bytes::Bytes; use http::{ header::{self, CONTENT_LENGTH}, HeaderMap, HeaderValue, Request, }; use pin_project_lite::pin_project; use std::{ convert::Infallible, fmt, future::Future, pin::Pin, task::{Context, Poll}, }; use tower::{ util::{BoxCloneService, MapResponseLayer, Oneshot}, ServiceBuilder, ServiceExt, }; use tower_layer::Layer; use tower_service::Service; /// How routes are stored inside a [`Router`](super::Router). 
/// /// You normally shouldn't need to care about this type. It's used in /// [`Router::layer`](super::Router::layer). pub struct Route(BoxCloneService, Response, E>); impl Route { pub(crate) fn new(svc: T) -> Self where T: Service, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { Self(BoxCloneService::new( svc.map_response(IntoResponse::into_response), )) } pub(crate) fn oneshot_inner( &mut self, req: Request, ) -> Oneshot, Response, E>, Request> { self.0.clone().oneshot(req) } pub(crate) fn layer(self, layer: L) -> Route where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, NewReqBody: 'static, NewError: 'static, { let layer = ServiceBuilder::new() .map_err(Into::into) .layer(MapResponseLayer::new(IntoResponse::into_response)) .layer(layer) .into_inner(); Route::new(layer.layer(self)) } } impl Clone for Route { fn clone(&self) -> Self { Self(self.0.clone()) } } impl fmt::Debug for Route { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Route").finish() } } impl Service> for Route where B: HttpBody, { type Response = Response; type Error = E; type Future = RouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[inline] fn call(&mut self, req: Request) -> Self::Future { RouteFuture::from_future(self.oneshot_inner(req)) } } pin_project! { /// Response future for [`Route`]. pub struct RouteFuture { #[pin] kind: RouteFutureKind, strip_body: bool, allow_header: Option, } } pin_project! 
{ #[project = RouteFutureKindProj] enum RouteFutureKind { Future { #[pin] future: Oneshot< BoxCloneService, Response, E>, Request, >, }, Response { response: Option, } } } impl RouteFuture { pub(crate) fn from_future( future: Oneshot, Response, E>, Request>, ) -> Self { Self { kind: RouteFutureKind::Future { future }, strip_body: false, allow_header: None, } } pub(crate) fn strip_body(mut self, strip_body: bool) -> Self { self.strip_body = strip_body; self } pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self { self.allow_header = Some(allow_header); self } } impl Future for RouteFuture where B: HttpBody, { type Output = Result; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = match this.kind.project() { RouteFutureKindProj::Future { future } => match future.poll(cx) { Poll::Ready(Ok(res)) => res, Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, }, RouteFutureKindProj::Response { response } => { response.take().expect("future polled after completion") } }; set_allow_header(res.headers_mut(), this.allow_header); // make sure to set content-length before removing the body set_content_length(res.size_hint(), res.headers_mut()); let res = if *this.strip_body { res.map(|_| boxed(Empty::new())) } else { res }; Poll::Ready(Ok(res)) } } fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option) { match allow_header.take() { Some(allow_header) if !headers.contains_key(header::ALLOW) => { headers.insert( header::ALLOW, HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"), ); } _ => {} } } fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) { if headers.contains_key(CONTENT_LENGTH) { return; } if let Some(size) = size_hint.exact() { let header_value = if size == 0 { #[allow(clippy::declare_interior_mutable_const)] const ZERO: HeaderValue = HeaderValue::from_static("0"); ZERO } else { let mut 
buffer = itoa::Buffer::new(); HeaderValue::from_str(buffer.format(size)).unwrap() }; headers.insert(CONTENT_LENGTH, header_value); } } pin_project! { /// A [`RouteFuture`] that always yields a [`Response`]. pub struct InfallibleRouteFuture { #[pin] future: RouteFuture, } } impl InfallibleRouteFuture { pub(crate) fn new(future: RouteFuture) -> Self { Self { future } } } impl Future for InfallibleRouteFuture where B: HttpBody, { type Output = Response; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match futures_util::ready!(self.project().future.poll(cx)) { Ok(response) => Poll::Ready(response), Err(err) => match err {}, } } } #[cfg(test)] mod tests { use super::*; #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); } } axum-axum-v0.6.20/axum/src/routing/strip_prefix.rs000066400000000000000000000265511446276533600222540ustar00rootroot00000000000000use http::{Request, Uri}; use std::{ sync::Arc, task::{Context, Poll}, }; use tower::Layer; use tower_layer::layer_fn; use tower_service::Service; #[derive(Clone)] pub(super) struct StripPrefix { inner: S, prefix: Arc, } impl StripPrefix { pub(super) fn new(inner: S, prefix: &str) -> Self { Self { inner, prefix: prefix.into(), } } pub(super) fn layer(prefix: &str) -> impl Layer + Clone { let prefix = Arc::from(prefix); layer_fn(move |inner| Self { inner, prefix: Arc::clone(&prefix), }) } } impl Service> for StripPrefix where S: Service>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) { *req.uri_mut() = new_uri; } self.inner.call(req) } } fn strip_prefix(uri: &Uri, prefix: &str) -> Option { let path_and_query = uri.path_and_query()?; // Check whether the prefix matches the path and if so how long the matching prefix is. 
// // For example: // // prefix = /api // path = /api/users // ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4 // characters we get the remainder // // prefix = /api/:version // path = /api/v0/users // ^^^^^^^ this much is matched and the length is 7. let mut matching_prefix_length = Some(0); for item in zip_longest(segments(path_and_query.path()), segments(prefix)) { // count the `/` *matching_prefix_length.as_mut().unwrap() += 1; match item { Item::Both(path_segment, prefix_segment) => { if prefix_segment.starts_with(':') || path_segment == prefix_segment { // the prefix segment is either a param, which matches anything, or // it actually matches the path segment *matching_prefix_length.as_mut().unwrap() += path_segment.len(); } else if prefix_segment.is_empty() { // the prefix ended in a `/` so we got a match. // // For example: // // prefix = /foo/ // path = /foo/bar // // The prefix matches and the new path should be `/bar` break; } else { // the prefix segment didn't match so there is no match matching_prefix_length = None; break; } } // the path had more segments than the prefix but we got a match. // // For example: // // prefix = /foo // path = /foo/bar Item::First(_) => { break; } // the prefix had more segments than the path so there is no match Item::Second(_) => { matching_prefix_length = None; break; } } } // if the prefix matches it will always do so up until a `/`, it cannot match only // part of a segment. 
Therefore this will always be at a char boundary and `split_at` wont // panic let after_prefix = uri.path().split_at(matching_prefix_length?).1; let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) { (true, None) => after_prefix.parse().unwrap(), (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(), (false, None) => format!("/{after_prefix}").parse().unwrap(), (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(), }; let mut parts = uri.clone().into_parts(); parts.path_and_query = Some(new_path_and_query); Some(Uri::from_parts(parts).unwrap()) } fn segments(s: &str) -> impl Iterator { assert!( s.starts_with('/'), "path didn't start with '/'. axum should have caught this higher up." ); s.split('/') // skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"] // otherwise .skip(1) } fn zip_longest(a: I, b: I2) -> impl Iterator> where I: Iterator, I2: Iterator, { let a = a.map(Some).chain(std::iter::repeat_with(|| None)); let b = b.map(Some).chain(std::iter::repeat_with(|| None)); a.zip(b).map_while(|(a, b)| match (a, b) { (Some(a), Some(b)) => Some(Item::Both(a, b)), (Some(a), None) => Some(Item::First(a)), (None, Some(b)) => Some(Item::Second(b)), (None, None) => None, }) } #[derive(Debug)] enum Item { Both(T, T), First(T), Second(T), } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use quickcheck::Arbitrary; use quickcheck_macros::quickcheck; macro_rules! 
test { ( $name:ident, uri = $uri:literal, prefix = $prefix:literal, expected = $expected:expr, ) => { #[test] fn $name() { let uri = $uri.parse().unwrap(); let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string()); assert_eq!(new_uri.as_deref(), $expected); } }; } test!(empty, uri = "/", prefix = "/", expected = Some("/"),); test!( single_segment, uri = "/a", prefix = "/a", expected = Some("/"), ); test!( single_segment_root_uri, uri = "/", prefix = "/a", expected = None, ); // the prefix is empty, so removing it should have no effect test!( single_segment_root_prefix, uri = "/a", prefix = "/", expected = Some("/a"), ); test!( single_segment_no_match, uri = "/a", prefix = "/b", expected = None, ); test!( single_segment_trailing_slash, uri = "/a/", prefix = "/a/", expected = Some("/"), ); test!( single_segment_trailing_slash_2, uri = "/a", prefix = "/a/", expected = None, ); test!( single_segment_trailing_slash_3, uri = "/a/", prefix = "/a", expected = Some("/"), ); test!( multi_segment, uri = "/a/b", prefix = "/a", expected = Some("/b"), ); test!( multi_segment_2, uri = "/b/a", prefix = "/a", expected = None, ); test!( multi_segment_3, uri = "/a", prefix = "/a/b", expected = None, ); test!( multi_segment_4, uri = "/a/b", prefix = "/b", expected = None, ); test!( multi_segment_trailing_slash, uri = "/a/b/", prefix = "/a/b/", expected = Some("/"), ); test!( multi_segment_trailing_slash_2, uri = "/a/b", prefix = "/a/b/", expected = None, ); test!( multi_segment_trailing_slash_3, uri = "/a/b/", prefix = "/a/b", expected = Some("/"), ); test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),); test!( param_1, uri = "/a", prefix = "/:param", expected = Some("/"), ); test!( param_2, uri = "/a/b", prefix = "/:param", expected = Some("/b"), ); test!( param_3, uri = "/b/a", prefix = "/:param", expected = Some("/a"), ); test!( param_4, uri = "/a/b", prefix = "/a/:param", expected = Some("/"), ); test!(param_5, uri = "/b/a", prefix = "/a/:param", 
expected = None,); test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,); test!( param_7, uri = "/b/a", prefix = "/:param/a", expected = Some("/"), ); test!( param_8, uri = "/a/b/c", prefix = "/a/:param/c", expected = Some("/"), ); test!( param_9, uri = "/c/b/a", prefix = "/a/:param/c", expected = None, ); test!( param_10, uri = "/a/", prefix = "/:param", expected = Some("/"), ); test!(param_11, uri = "/a", prefix = "/:param/", expected = None,); test!( param_12, uri = "/a/", prefix = "/:param/", expected = Some("/"), ); test!( param_13, uri = "/a/a", prefix = "/a/", expected = Some("/a"), ); #[quickcheck] fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool { let UriAndPrefix { uri, prefix } = uri_and_prefix; strip_prefix(&uri, &prefix); true } #[derive(Clone, Debug)] struct UriAndPrefix { uri: Uri, prefix: String, } impl Arbitrary for UriAndPrefix { fn arbitrary(g: &mut quickcheck::Gen) -> Self { let mut uri = String::new(); let mut prefix = String::new(); let size = u8_between(1, 20, g); for _ in 0..size { let segment = ascii_alphanumeric(g); uri.push('/'); uri.push_str(&segment); prefix.push('/'); let make_matching_segment = bool::arbitrary(g); let make_capture = bool::arbitrary(g); match (make_matching_segment, make_capture) { (_, true) => { prefix.push_str(":a"); } (true, false) => { prefix.push_str(&segment); } (false, false) => { prefix.push_str(&ascii_alphanumeric(g)); } } } if bool::arbitrary(g) { uri.push('/'); } if bool::arbitrary(g) { prefix.push('/'); } Self { uri: uri.parse().unwrap(), prefix, } } } fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String { #[derive(Clone)] struct AsciiAlphanumeric(String); impl Arbitrary for AsciiAlphanumeric { fn arbitrary(g: &mut quickcheck::Gen) -> Self { let mut out = String::new(); let size = u8_between(1, 20, g) as usize; while out.len() < size { let c = char::arbitrary(g); if c.is_ascii_alphanumeric() { out.push(c); } } Self(out) } } let out = AsciiAlphanumeric::arbitrary(g).0; 
assert!(!out.is_empty()); out } fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 { loop { let size = u8::arbitrary(g); if size > lower && size <= upper { break size; } } } } axum-axum-v0.6.20/axum/src/routing/tests/000077500000000000000000000000001446276533600203215ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/routing/tests/fallback.rs000066400000000000000000000253511446276533600224340ustar00rootroot00000000000000use tower::ServiceExt; use super::*; use crate::middleware::{map_request, map_response}; #[crate::test] async fn basic() { let app = Router::new() .route("/foo", get(|| async {})) .fallback(|| async { "fallback" }); let client = TestClient::new(app); assert_eq!(client.get("/foo").send().await.status(), StatusCode::OK); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } #[crate::test] async fn nest() { let app = Router::new() .nest("/foo", Router::new().route("/bar", get(|| async {}))) .fallback(|| async { "fallback" }); let client = TestClient::new(app); assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } #[crate::test] async fn or() { let one = Router::new().route("/one", get(|| async {})); let two = Router::new().route("/two", get(|| async {})); let app = one.merge(two).fallback(|| async { "fallback" }); let client = TestClient::new(app); assert_eq!(client.get("/one").send().await.status(), StatusCode::OK); assert_eq!(client.get("/two").send().await.status(), StatusCode::OK); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } #[crate::test] async fn fallback_accessing_state() { let app = Router::new() .fallback(|State(state): State<&'static str>| async move { state }) .with_state("state"); let 
client = TestClient::new(app); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "state"); } async fn inner_fallback() -> impl IntoResponse { (StatusCode::NOT_FOUND, "inner") } async fn outer_fallback() -> impl IntoResponse { (StatusCode::NOT_FOUND, "outer") } #[crate::test] async fn nested_router_inherits_fallback() { let inner = Router::new(); let app = Router::new().nest("/foo", inner).fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn doesnt_inherit_fallback_if_overriden() { let inner = Router::new().fallback(inner_fallback); let app = Router::new().nest("/foo", inner).fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn deeply_nested_inherit_from_top() { let app = Router::new() .nest("/foo", Router::new().nest("/bar", Router::new())) .fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn deeply_nested_inherit_from_middle() { let app = Router::new().nest( "/foo", Router::new() .nest("/bar", Router::new()) .fallback(outer_fallback), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn with_middleware_on_inner_fallback() { async fn never_called(_: Request) -> Request { panic!("should never be called") } let inner = 
Router::new().layer(map_request(never_called)); let app = Router::new().nest("/foo", inner).fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn also_inherits_default_layered_fallback() { async fn set_header(mut res: Response) -> Response { res.headers_mut() .insert("x-from-fallback", "1".parse().unwrap()); res } let inner = Router::new(); let app = Router::new() .nest("/foo", inner) .fallback(outer_fallback) .layer(map_response(set_header)); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.headers()["x-from-fallback"], "1"); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn fallback_inherited_into_nested_router_service() { let inner = Router::new() .route( "/bar", get(|State(state): State<&'static str>| async move { state }), ) .with_state("inner"); // with a different state let app = Router::<()>::new() .nest_service("/foo", inner) .fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/foo/not-found").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn fallback_inherited_into_nested_opaque_service() { let inner = Router::new() .route( "/bar", get(|State(state): State<&'static str>| async move { state }), ) .with_state("inner") // even if the service is made more opaque it should still inherit the fallback .boxed_clone(); // with a different state let app = Router::<()>::new() .nest_service("/foo", inner) .fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/foo/not-found").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn nest_fallback_on_inner() { let app = Router::new() 
.nest( "/foo", Router::new() .route("/", get(|| async {})) .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }), ) .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); let client = TestClient::new(app); let res = client.get("/foo/not-found").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner fallback"); } // https://github.com/tokio-rs/axum/issues/1931 #[crate::test] async fn doesnt_panic_if_used_with_nested_router() { async fn handler() {} let routes_static = Router::new().nest_service("/", crate::routing::get_service(handler.into_service())); let routes_all = Router::new().fallback_service(routes_static); let client = TestClient::new(routes_all); let res = client.get("/foobar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn issue_2072() { let nested_routes = Router::new().fallback(inner_fallback); let app = Router::new() .nest("/nested", nested_routes) .merge(Router::new()); let client = TestClient::new(app); let res = client.get("/nested/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, ""); } #[crate::test] async fn issue_2072_outer_fallback_before_merge() { let nested_routes = Router::new().fallback(inner_fallback); let app = Router::new() .nest("/nested", nested_routes) .fallback(outer_fallback) .merge(Router::new()); let client = TestClient::new(app); let res = client.get("/nested/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn issue_2072_outer_fallback_after_merge() { let nested_routes = 
Router::new().fallback(inner_fallback); let app = Router::new() .nest("/nested", nested_routes) .merge(Router::new()) .fallback(outer_fallback); let client = TestClient::new(app); let res = client.get("/nested/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn merge_router_with_fallback_into_nested_router_with_fallback() { let nested_routes = Router::new().fallback(inner_fallback); let app = Router::new() .nest("/nested", nested_routes) .merge(Router::new().fallback(outer_fallback)); let client = TestClient::new(app); let res = client.get("/nested/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn merging_nested_router_with_fallback_into_router_with_fallback() { let nested_routes = Router::new().fallback(inner_fallback); let app = Router::new() .fallback(outer_fallback) .merge(Router::new().nest("/nested", nested_routes)); let client = TestClient::new(app); let res = client.get("/nested/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn merge_empty_into_router_with_fallback() { let app = Router::new().fallback(outer_fallback).merge(Router::new()); let client = TestClient::new(app); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } #[crate::test] async fn 
merge_router_with_fallback_into_empty() { let app = Router::new().merge(Router::new().fallback(outer_fallback)); let client = TestClient::new(app); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } axum-axum-v0.6.20/axum/src/routing/tests/get_to_head.rs000066400000000000000000000040371446276533600231350ustar00rootroot00000000000000use super::*; use http::Method; use tower::ServiceExt; mod for_handlers { use super::*; use http::HeaderMap; #[crate::test] async fn get_handles_head() { let app = Router::new().route( "/", get(|| async { let mut headers = HeaderMap::new(); headers.insert("x-some-header", "foobar".parse().unwrap()); (headers, "you shouldn't see this") }), ); // don't use reqwest because it always strips bodies from HEAD responses let res = app .oneshot( Request::builder() .uri("/") .method(Method::HEAD) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["x-some-header"], "foobar"); let body = hyper::body::to_bytes(res.into_body()).await.unwrap(); assert_eq!(body.len(), 0); } } mod for_services { use super::*; use crate::routing::get_service; #[crate::test] async fn get_handles_head() { let app = Router::new().route( "/", get_service(service_fn(|_req: Request| async move { Ok::<_, Infallible>( ([("x-some-header", "foobar")], "you shouldn't see this").into_response(), ) })), ); // don't use reqwest because it always strips bodies from HEAD responses let res = app .oneshot( Request::builder() .uri("/") .method(Method::HEAD) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.headers()["x-some-header"], "foobar"); let body = hyper::body::to_bytes(res.into_body()).await.unwrap(); assert_eq!(body.len(), 0); } } 
axum-axum-v0.6.20/axum/src/routing/tests/handle_error.rs000066400000000000000000000054641446276533600233440ustar00rootroot00000000000000use super::*; use std::future::{pending, ready}; use tower::{timeout::TimeoutLayer, ServiceBuilder}; async fn unit() {} async fn forever() { pending().await } fn timeout() -> TimeoutLayer { TimeoutLayer::new(Duration::from_millis(10)) } #[derive(Clone)] struct Svc; impl Service for Svc { type Response = Response; type Error = hyper::Error; type Future = Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _req: R) -> Self::Future { ready(Ok(Response::new(Body::empty()))) } } #[crate::test] async fn handler() { let app = Router::new().route( "/", get(forever.layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) .layer(timeout()), )), ); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } #[crate::test] async fn handler_multiple_methods_first() { let app = Router::new().route( "/", get(forever.layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) .layer(timeout()), )) .post(unit), ); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } #[crate::test] async fn handler_multiple_methods_middle() { let app = Router::new().route( "/", delete(unit) .get( forever.layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) .layer(timeout()), ), ) .post(unit), ); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } #[crate::test] async fn handler_multiple_methods_last() { let app = Router::new().route( "/", delete(unit).get( forever.layer( ServiceBuilder::new() 
.layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) .layer(timeout()), ), ), ); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } axum-axum-v0.6.20/axum/src/routing/tests/merge.rs000066400000000000000000000266701446276533600220010ustar00rootroot00000000000000use super::*; use crate::{error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json}; use serde_json::{json, Value}; use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer}; #[crate::test] async fn basic() { let one = Router::new() .route("/foo", get(|| async {})) .route("/bar", get(|| async {})); let two = Router::new().route("/baz", get(|| async {})); let app = one.merge(two); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/baz").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/qux").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn multiple_ors_balanced_differently() { let one = Router::new().route("/one", get(|| async { "one" })); let two = Router::new().route("/two", get(|| async { "two" })); let three = Router::new().route("/three", get(|| async { "three" })); let four = Router::new().route("/four", get(|| async { "four" })); test( "one", one.clone() .merge(two.clone()) .merge(three.clone()) .merge(four.clone()), ) .await; test( "two", one.clone() .merge(two.clone()) .merge(three.clone().merge(four.clone())), ) .await; test( "three", one.clone() .merge(two.clone().merge(three.clone()).merge(four.clone())), ) .await; test("four", one.merge(two.merge(three.merge(four)))).await; async fn test(name: &str, app: Router) { let client = TestClient::new(app); for n in ["one", "two", "three", "four"].iter() { 
println!("running: {name} / {n}"); let res = client.get(&format!("/{n}")).send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, *n); } } } #[crate::test] async fn nested_or() { let bar = Router::new().route("/bar", get(|| async { "bar" })); let baz = Router::new().route("/baz", get(|| async { "baz" })); let bar_or_baz = bar.merge(baz); let client = TestClient::new(bar_or_baz.clone()); assert_eq!(client.get("/bar").send().await.text().await, "bar"); assert_eq!(client.get("/baz").send().await.text().await, "baz"); let client = TestClient::new(Router::new().nest("/foo", bar_or_baz)); assert_eq!(client.get("/foo/bar").send().await.text().await, "bar"); assert_eq!(client.get("/foo/baz").send().await.text().await, "baz"); } #[crate::test] async fn or_with_route_following() { let one = Router::new().route("/one", get(|| async { "one" })); let two = Router::new().route("/two", get(|| async { "two" })); let app = one.merge(two).route("/three", get(|| async { "three" })); let client = TestClient::new(app); let res = client.get("/one").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/two").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/three").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn layer() { let one = Router::new().route("/foo", get(|| async {})); let two = Router::new() .route("/bar", get(|| async {})) .layer(ConcurrencyLimitLayer::new(10)); let app = one.merge(two); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn layer_and_handle_error() { let one = Router::new().route("/foo", get(|| async {})); let two = Router::new() .route("/timeout", get(std::future::pending::<()>)) .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|_| async { 
StatusCode::REQUEST_TIMEOUT })) .layer(TimeoutLayer::new(Duration::from_millis(10))), ); let app = one.merge(two); let client = TestClient::new(app); let res = client.get("/timeout").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } #[crate::test] async fn nesting() { let one = Router::new().route("/foo", get(|| async {})); let two = Router::new().nest("/bar", Router::new().route("/baz", get(|| async {}))); let app = one.merge(two); let client = TestClient::new(app); let res = client.get("/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn boxed() { let one = Router::new().route("/foo", get(|| async {})); let two = Router::new().route("/bar", get(|| async {})); let app = one.merge(two); let client = TestClient::new(app); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn many_ors() { let app = Router::new() .route("/r1", get(|| async {})) .merge(Router::new().route("/r2", get(|| async {}))) .merge(Router::new().route("/r3", get(|| async {}))) .merge(Router::new().route("/r4", get(|| async {}))) .merge(Router::new().route("/r5", get(|| async {}))) .merge(Router::new().route("/r6", get(|| async {}))) .merge(Router::new().route("/r7", get(|| async {}))); let client = TestClient::new(app); for n in 1..=7 { let res = client.get(&format!("/r{n}")).send().await; assert_eq!(res.status(), StatusCode::OK); } let res = client.get("/r8").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn services() { use crate::routing::get_service; let app = Router::new() .route( "/foo", get_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::empty())) })), ) .merge(Router::new().route( "/bar", get_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::empty())) })), )); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.status(), 
StatusCode::OK); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } async fn all_the_uris( uri: Uri, OriginalUri(original_uri): OriginalUri, req: Request, ) -> impl IntoResponse { Json(json!({ "uri": uri.to_string(), "request_uri": req.uri().to_string(), "original_uri": original_uri.to_string(), })) } #[crate::test] async fn nesting_and_seeing_the_right_uri() { let one = Router::new().nest("/foo/", Router::new().route("/bar", get(all_the_uris))); let two = Router::new().route("/foo", get(all_the_uris)); let client = TestClient::new(one.merge(two)); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/bar", "request_uri": "/bar", "original_uri": "/foo/bar", }) ); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/foo", "request_uri": "/foo", "original_uri": "/foo", }) ); } #[crate::test] async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { let one = Router::new().nest( "/foo/", Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), ); let two = Router::new().route("/foo", get(all_the_uris)); let client = TestClient::new(one.merge(two)); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/baz", "request_uri": "/baz", "original_uri": "/foo/bar/baz", }) ); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/foo", "request_uri": "/foo", "original_uri": "/foo", }) ); } #[crate::test] async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { let one = Router::new().nest( "/one", Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), ); let two = Router::new().nest("/two", Router::new().route("/qux", get(all_the_uris))); let three = 
Router::new().route("/three", get(all_the_uris)); let client = TestClient::new(one.merge(two).merge(three)); let res = client.get("/one/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/baz", "request_uri": "/baz", "original_uri": "/one/bar/baz", }) ); let res = client.get("/two/qux").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/qux", "request_uri": "/qux", "original_uri": "/two/qux", }) ); let res = client.get("/three").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/three", "request_uri": "/three", "original_uri": "/three", }) ); } #[crate::test] async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { let one = Router::new().nest( "/one", Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))), ); let two = Router::new().route("/two/foo", get(all_the_uris)); let client = TestClient::new(one.merge(two)); let res = client.get("/one/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/bar", "request_uri": "/bar", "original_uri": "/one/foo/bar", }) ); let res = client.get("/two/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, json!({ "uri": "/two/foo", "request_uri": "/two/foo", "original_uri": "/two/foo", }) ); } #[crate::test] async fn middleware_that_return_early() { let private = Router::new() .route("/", get(|| async {})) .layer(ValidateRequestHeaderLayer::bearer("password")); let public = Router::new().route("/public", get(|| async {})); let client = TestClient::new(private.merge(public)); assert_eq!( client.get("/").send().await.status(), StatusCode::UNAUTHORIZED ); assert_eq!( client .get("/") .header("authorization", "Bearer password") .send() .await .status(), StatusCode::OK ); assert_eq!( 
client.get("/doesnt-exist").send().await.status(), StatusCode::NOT_FOUND ); assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); } axum-axum-v0.6.20/axum/src/routing/tests/mod.rs000066400000000000000000000740341446276533600214560ustar00rootroot00000000000000use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, extract::{self, DefaultBodyLimit, FromRef, Path, State}, handler::{Handler, HandlerWithoutStateExt}, response::IntoResponse, routing::{ delete, get, get_service, on, on_service, patch, patch_service, path_router::path_for_nested_route, post, MethodFilter, }, test_helpers::{ tracing_helpers::{capture_tracing, TracingEvent}, *, }, BoxError, Extension, Json, Router, }; use futures_util::stream::StreamExt; use http::{ header::CONTENT_LENGTH, header::{ALLOW, HOST}, HeaderMap, Method, Request, Response, StatusCode, Uri, }; use hyper::Body; use serde::Deserialize; use serde_json::json; use std::{ convert::Infallible, future::{ready, Ready}, sync::atomic::{AtomicBool, AtomicUsize, Ordering}, task::{Context, Poll}, time::Duration, }; use tower::{ service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt, }; use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer}; use tower_service::Service; mod fallback; mod get_to_head; mod handle_error; mod merge; mod nest; #[crate::test] async fn hello_world() { async fn root(_: Request) -> &'static str { "Hello, World!" 
} async fn foo(_: Request) -> &'static str { "foo" } async fn users_create(_: Request) -> &'static str { "users#create" } let app = Router::new() .route("/", get(root).post(foo)) .route("/users", post(users_create)); let client = TestClient::new(app); let res = client.get("/").send().await; let body = res.text().await; assert_eq!(body, "Hello, World!"); let res = client.post("/").send().await; let body = res.text().await; assert_eq!(body, "foo"); let res = client.post("/users").send().await; let body = res.text().await; assert_eq!(body, "users#create"); } #[crate::test] async fn routing() { let app = Router::new() .route( "/users", get(|_: Request| async { "users#index" }) .post(|_: Request| async { "users#create" }), ) .route("/users/:id", get(|_: Request| async { "users#show" })) .route( "/users/:id/action", get(|_: Request| async { "users#action" }), ); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/users").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#index"); let res = client.post("/users").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#create"); let res = client.get("/users/1").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#show"); let res = client.get("/users/1/action").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#action"); } #[crate::test] async fn router_type_doesnt_change() { let app: Router = Router::new() .route( "/", on(MethodFilter::GET, |_: Request| async { "hi from GET" }) .on(MethodFilter::POST, |_: Request| async { "hi from POST" }), ) .layer(tower_http::compression::CompressionLayer::new()); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "hi from GET"); let res = 
client.post("/").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "hi from POST"); } #[crate::test] async fn routing_between_services() { use std::convert::Infallible; use tower::service_fn; async fn handle(_: Request) -> &'static str { "handler" } let app = Router::new() .route( "/one", get_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::from("one get"))) })) .post_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::from("one post"))) })) .on_service( MethodFilter::PUT, service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::from("one put"))) }), ), ) .route("/two", on_service(MethodFilter::GET, handle.into_service())); let client = TestClient::new(app); let res = client.get("/one").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one get"); let res = client.post("/one").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one post"); let res = client.put("/one").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one put"); let res = client.get("/two").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "handler"); } #[crate::test] async fn middleware_on_single_route() { use tower::ServiceBuilder; use tower_http::{compression::CompressionLayer, trace::TraceLayer}; async fn handle(_: Request) -> &'static str { "Hello, World!" 
} let app = Router::new().route( "/", get(handle.layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(CompressionLayer::new()) .into_inner(), )), ); let client = TestClient::new(app); let res = client.get("/").send().await; let body = res.text().await; assert_eq!(body, "Hello, World!"); } #[crate::test] async fn service_in_bottom() { async fn handler(_req: Request) -> Result, Infallible> { Ok(Response::new(hyper::Body::empty())) } let app = Router::new().route("/", get_service(service_fn(handler))); TestClient::new(app); } #[crate::test] async fn wrong_method_handler() { let app = Router::new() .route("/", get(|| async {}).post(|| async {})) .route("/foo", patch(|| async {})); let client = TestClient::new(app); let res = client.patch("/").send().await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST"); let res = client.patch("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.post("/foo").send().await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "PATCH"); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn wrong_method_service() { #[derive(Clone)] struct Svc; impl Service for Svc { type Response = Response>; type Error = Infallible; type Future = Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _req: R) -> Self::Future { ready(Ok(Response::new(Empty::new()))) } } let app = Router::new() .route("/", get_service(Svc).post_service(Svc)) .route("/foo", patch_service(Svc)); let client = TestClient::new(app); let res = client.patch("/").send().await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST"); let res = client.patch("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.post("/foo").send().await; 
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "PATCH"); let res = client.get("/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn multiple_methods_for_one_handler() { async fn root(_: Request) -> &'static str { "Hello, World!" } let app = Router::new().route("/", on(MethodFilter::GET | MethodFilter::POST, root)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.post("/").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn wildcard_sees_whole_url() { let app = Router::new().route("/api/*rest", get(|uri: Uri| async move { uri.to_string() })); let client = TestClient::new(app); let res = client.get("/api/foo/bar").send().await; assert_eq!(res.text().await, "/api/foo/bar"); } #[crate::test] async fn middleware_applies_to_routes_above() { let app = Router::new() .route("/one", get(std::future::pending::<()>)) .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|_: BoxError| async move { StatusCode::REQUEST_TIMEOUT })) .layer(TimeoutLayer::new(Duration::new(0, 0))), ) .route("/two", get(|| async {})); let client = TestClient::new(app); let res = client.get("/one").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); let res = client.get("/two").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn not_found_for_extra_trailing_slash() { let app = Router::new().route("/foo", get(|| async {})); let client = TestClient::new(app); let res = client.get("/foo/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn not_found_for_missing_trailing_slash() { let app = Router::new().route("/foo/", get(|| async {})); let client = TestClient::new(app); let res = client.get("/foo").send().await; 
assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn with_and_without_trailing_slash() { let app = Router::new() .route("/foo", get(|| async { "without tsr" })) .route("/foo/", get(|| async { "with tsr" })); let client = TestClient::new(app); let res = client.get("/foo/").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "with tsr"); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "without tsr"); } // for https://github.com/tokio-rs/axum/issues/420 #[crate::test] async fn wildcard_doesnt_match_just_trailing_slash() { let app = Router::new().route( "/x/*path", get(|Path(path): Path| async move { path }), ); let client = TestClient::new(app); let res = client.get("/x").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/x/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/x/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "foo/bar"); } #[crate::test] async fn what_matches_wildcard() { let app = Router::new() .route("/*key", get(|| async { "root" })) .route("/x/*key", get(|| async { "x" })) .fallback(|| async { "fallback" }); let client = TestClient::new(app); let get = |path| { let f = client.get(path).send(); async move { f.await.text().await } }; assert_eq!(get("/").await, "fallback"); assert_eq!(get("/a").await, "root"); assert_eq!(get("/a/").await, "root"); assert_eq!(get("/a/b").await, "root"); assert_eq!(get("/a/b/").await, "root"); assert_eq!(get("/x").await, "root"); assert_eq!(get("/x/").await, "root"); assert_eq!(get("/x/a").await, "x"); assert_eq!(get("/x/a/").await, "x"); assert_eq!(get("/x/a/b").await, "x"); assert_eq!(get("/x/a/b/").await, "x"); } #[crate::test] async fn static_and_dynamic_paths() { let app = Router::new() .route( "/:key", get(|Path(key): Path| async move { format!("dynamic: {key}") }), ) 
.route("/foo", get(|| async { "static" })); let client = TestClient::new(app); let res = client.get("/bar").send().await; assert_eq!(res.text().await, "dynamic: bar"); let res = client.get("/foo").send().await; assert_eq!(res.text().await, "static"); } #[crate::test] #[should_panic(expected = "Paths must start with a `/`. Use \"/\" for root routes")] async fn empty_route() { let app = Router::new().route("", get(|| async {})); TestClient::new(app); } #[crate::test] async fn middleware_still_run_for_unmatched_requests() { #[derive(Clone)] struct CountMiddleware(S); static COUNT: AtomicUsize = AtomicUsize::new(0); impl Service for CountMiddleware where S: Service, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, req: R) -> Self::Future { COUNT.fetch_add(1, Ordering::SeqCst); self.0.call(req) } } let app = Router::new() .route("/", get(|| async {})) .layer(tower::layer::layer_fn(CountMiddleware)); let client = TestClient::new(app); assert_eq!(COUNT.load(Ordering::SeqCst), 0); client.get("/").send().await; assert_eq!(COUNT.load(Ordering::SeqCst), 1); client.get("/not-found").send().await; assert_eq!(COUNT.load(Ordering::SeqCst), 2); } #[crate::test] #[should_panic(expected = "\ Invalid route: `Router::route_service` cannot be used with `Router`s. 
\ Use `Router::nest` instead\ ")] async fn routing_to_router_panics() { TestClient::new(Router::new().route_service("/", Router::new())); } #[crate::test] async fn route_layer() { let app = Router::new() .route("/foo", get(|| async {})) .route_layer(ValidateRequestHeaderLayer::bearer("password")); let client = TestClient::new(app); let res = client .get("/foo") .header("authorization", "Bearer password") .send() .await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::UNAUTHORIZED); let res = client.get("/not-found").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); // it would be nice if this would return `405 Method Not Allowed` // but that requires knowing more about which method route we're calling, which we // don't know currently since its just a generic `Service` let res = client.post("/foo").send().await; assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } #[crate::test] async fn different_methods_added_in_different_routes() { let app = Router::new() .route("/", get(|| async { "GET" })) .route("/", post(|| async { "POST" })); let client = TestClient::new(app); let res = client.get("/").send().await; let body = res.text().await; assert_eq!(body, "GET"); let res = client.post("/").send().await; let body = res.text().await; assert_eq!(body, "POST"); } #[crate::test] #[should_panic(expected = "Cannot merge two `Router`s that both have a fallback")] async fn merging_routers_with_fallbacks_panics() { async fn fallback() {} let one = Router::new().fallback(fallback); let two = Router::new().fallback(fallback); TestClient::new(one.merge(two)); } #[test] #[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")] fn routes_with_overlapping_method_routes() { async fn handler() {} let _: Router = Router::new() .route("/foo/bar", get(handler)) .route("/foo/bar", get(handler)); } #[test] #[should_panic(expected = "Overlapping method route. 
Handler for `GET /foo/bar` already exists")] fn merging_with_overlapping_method_routes() { async fn handler() {} let app: Router = Router::new().route("/foo/bar", get(handler)); _ = app.clone().merge(app); } #[crate::test] async fn merging_routers_with_same_paths_but_different_methods() { let one = Router::new().route("/", get(|| async { "GET" })); let two = Router::new().route("/", post(|| async { "POST" })); let client = TestClient::new(one.merge(two)); let res = client.get("/").send().await; let body = res.text().await; assert_eq!(body, "GET"); let res = client.post("/").send().await; let body = res.text().await; assert_eq!(body, "POST"); } #[crate::test] async fn head_content_length_through_hyper_server() { let app = Router::new() .route("/", get(|| async { "foo" })) .route("/json", get(|| async { Json(json!({ "foo": 1 })) })); let client = TestClient::new(app); let res = client.head("/").send().await; assert_eq!(res.headers()["content-length"], "3"); assert!(res.text().await.is_empty()); let res = client.head("/json").send().await; assert_eq!(res.headers()["content-length"], "9"); assert!(res.text().await.is_empty()); } #[crate::test] async fn head_content_length_through_hyper_server_that_hits_fallback() { let app = Router::new().fallback(|| async { "foo" }); let client = TestClient::new(app); let res = client.head("/").send().await; assert_eq!(res.headers()["content-length"], "3"); } #[crate::test] async fn head_with_middleware_applied() { use tower_http::compression::{predicate::SizeAbove, CompressionLayer}; let app = Router::new() .nest( "/", Router::new().route("/", get(|| async { "Hello, World!" 
})), ) .layer(CompressionLayer::new().compress_when(SizeAbove::new(0))); let client = TestClient::new(app); // send GET request let res = client .get("/") .header("accept-encoding", "gzip") .send() .await; assert_eq!(res.headers()["transfer-encoding"], "chunked"); // cannot have `transfer-encoding: chunked` and `content-length` assert!(!res.headers().contains_key("content-length")); // send HEAD request let res = client .head("/") .header("accept-encoding", "gzip") .send() .await; // no response body so no `transfer-encoding` assert!(!res.headers().contains_key("transfer-encoding")); // no content-length since we cannot know it since the response // is compressed assert!(!res.headers().contains_key("content-length")); } #[crate::test] #[should_panic(expected = "Paths must start with a `/`")] async fn routes_must_start_with_slash() { let app = Router::new().route(":foo", get(|| async {})); TestClient::new(app); } #[crate::test] async fn body_limited_by_default() { let app = Router::new() .route("/bytes", post(|_: Bytes| async {})) .route("/string", post(|_: String| async {})) .route("/json", post(|_: Json| async {})); let client = TestClient::new(app); for uri in ["/bytes", "/string", "/json"] { println!("calling {uri}"); let stream = futures_util::stream::repeat("a".repeat(1000)).map(Ok::<_, hyper::Error>); let body = Body::wrap_stream(stream); let res_future = client .post(uri) .header("content-type", "application/json") .body(body) .send(); let res = tokio::time::timeout(Duration::from_secs(3), res_future) .await .expect("never got response"); assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } } #[crate::test] async fn disabling_the_default_limit() { let app = Router::new() .route("/", post(|_: Bytes| async {})) .layer(DefaultBodyLimit::disable()); let client = TestClient::new(app); // `DEFAULT_LIMIT` is 2mb so make a body larger than that let body = Body::from("a".repeat(3_000_000)); let res = client.post("/").body(body).send().await; 
assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn limited_body_with_content_length() { const LIMIT: usize = 3; let app = Router::new() .route( "/", post(|headers: HeaderMap, _body: Bytes| async move { assert!(headers.get(CONTENT_LENGTH).is_some()); }), ) .layer(RequestBodyLimitLayer::new(LIMIT)); let client = TestClient::new(app); let res = client.post("/").body("a".repeat(LIMIT)).send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.post("/").body("a".repeat(LIMIT * 2)).send().await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } #[crate::test] async fn changing_the_default_limit() { let new_limit = 2; let app = Router::new() .route("/", post(|_: Bytes| async {})) .layer(DefaultBodyLimit::max(new_limit)); let client = TestClient::new(app); let res = client .post("/") .body(Body::from("a".repeat(new_limit))) .send() .await; assert_eq!(res.status(), StatusCode::OK); let res = client .post("/") .body(Body::from("a".repeat(new_limit + 1))) .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } #[crate::test] async fn limited_body_with_streaming_body() { const LIMIT: usize = 3; let app = Router::new() .route( "/", post(|headers: HeaderMap, _body: Bytes| async move { assert!(headers.get(CONTENT_LENGTH).is_none()); }), ) .layer(RequestBodyLimitLayer::new(LIMIT)); let client = TestClient::new(app); let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT))]); let res = client .post("/") .body(Body::wrap_stream(stream)) .send() .await; assert_eq!(res.status(), StatusCode::OK); let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT * 2))]); let res = client .post("/") .body(Body::wrap_stream(stream)) .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } #[crate::test] async fn extract_state() { #[derive(Clone)] struct AppState { value: i32, inner: InnerState, } #[derive(Clone)] struct InnerState { value: i32, } impl FromRef for 
InnerState { fn from_ref(state: &AppState) -> Self { state.inner.clone() } } async fn handler(State(outer): State, State(inner): State) { assert_eq!(outer.value, 1); assert_eq!(inner.value, 2); } let state = AppState { value: 1, inner: InnerState { value: 2 }, }; let app = Router::new().route("/", get(handler)).with_state(state); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn explicitly_set_state() { let app = Router::new() .route_service( "/", get(|State(state): State<&'static str>| async move { state }).with_state("foo"), ) .with_state("..."); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } #[crate::test] async fn layer_response_into_response() { fn map_response(_res: Response) -> Result, impl IntoResponse> { let headers = [("x-foo", "bar")]; let status = StatusCode::IM_A_TEAPOT; Err((headers, status)) } let app = Router::new() .route("/", get(|| async {})) .layer(MapResponseLayer::new(map_response)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.headers()["x-foo"], "bar"); assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); } #[allow(dead_code)] fn method_router_fallback_with_state() { async fn fallback(_: State<&'static str>) {} async fn not_found(_: State<&'static str>) {} let state = "foo"; let _: Router = Router::new() .fallback(get(fallback).fallback(not_found)) .with_state(state); } #[test] fn test_path_for_nested_route() { assert_eq!(path_for_nested_route("/", "/"), "/"); assert_eq!(path_for_nested_route("/a", "/"), "/a"); assert_eq!(path_for_nested_route("/", "/b"), "/b"); assert_eq!(path_for_nested_route("/a/", "/"), "/a/"); assert_eq!(path_for_nested_route("/", "/b/"), "/b/"); assert_eq!(path_for_nested_route("/a", "/b"), "/a/b"); assert_eq!(path_for_nested_route("/a/", "/b"), "/a/b"); assert_eq!(path_for_nested_route("/a", "/b/"), "/a/b/"); 
assert_eq!(path_for_nested_route("/a/", "/b/"), "/a/b/"); } #[crate::test] async fn state_isnt_cloned_too_much() { static SETUP_DONE: AtomicBool = AtomicBool::new(false); static COUNT: AtomicUsize = AtomicUsize::new(0); struct AppState; impl Clone for AppState { fn clone(&self) -> Self { #[rustversion::since(1.65)] #[track_caller] fn count() { if SETUP_DONE.load(Ordering::SeqCst) { let bt = std::backtrace::Backtrace::force_capture(); let bt = bt .to_string() .lines() .filter(|line| line.contains("axum") || line.contains("./src")) .collect::>() .join("\n"); println!("AppState::Clone:\n===============\n{}\n", bt); COUNT.fetch_add(1, Ordering::SeqCst); } } #[rustversion::not(since(1.65))] fn count() { if SETUP_DONE.load(Ordering::SeqCst) { COUNT.fetch_add(1, Ordering::SeqCst); } } count(); Self } } let app = Router::new() .route("/", get(|_: State| async {})) .with_state(AppState); let client = TestClient::new(app); // ignore clones made during setup SETUP_DONE.store(true, Ordering::SeqCst); client.get("/").send().await; assert_eq!(COUNT.load(Ordering::SeqCst), 4); } #[crate::test] async fn logging_rejections() { #[derive(Deserialize, Eq, PartialEq, Debug)] #[serde(deny_unknown_fields)] struct RejectionEvent { message: String, status: u16, body: String, rejection_type: String, } let events = capture_tracing::(|| async { let app = Router::new() .route("/extension", get(|_: Extension| async {})) .route("/string", post(|_: String| async {})); let client = TestClient::new(app); assert_eq!( client.get("/extension").send().await.status(), StatusCode::INTERNAL_SERVER_ERROR ); assert_eq!( client .post("/string") .body(Vec::from([0, 159, 146, 150])) .send() .await .status(), StatusCode::BAD_REQUEST, ); }) .await; assert_eq!( dbg!(events), Vec::from([ TracingEvent { fields: RejectionEvent { message: "rejecting request".to_owned(), status: 500, body: "Missing request extension: Extension of \ type `core::convert::Infallible` was not found. \ Perhaps you forgot to add it? 
See `axum::Extension`." .to_owned(), rejection_type: "axum::extract::rejection::MissingExtension".to_owned(), }, target: "axum::rejection".to_owned(), level: "TRACE".to_owned(), }, TracingEvent { fields: RejectionEvent { message: "rejecting request".to_owned(), status: 400, body: "Request body didn't contain valid UTF-8: \ invalid utf-8 sequence of 1 bytes from index 1" .to_owned(), rejection_type: "axum_core::extract::rejection::InvalidUtf8".to_owned(), }, target: "axum::rejection".to_owned(), level: "TRACE".to_owned(), }, ]) ) } // https://github.com/tokio-rs/axum/issues/1955 #[crate::test] async fn connect_going_to_custom_fallback() { let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") }); let req = Request::builder() .uri("example.com:443") .method(Method::CONNECT) .header(HOST, "example.com:443") .body(Body::empty()) .unwrap(); let res = app.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap(); assert_eq!(text, "custom fallback"); } // https://github.com/tokio-rs/axum/issues/1955 #[crate::test] async fn connect_going_to_default_fallback() { let app = Router::new(); let req = Request::builder() .uri("example.com:443") .method(Method::CONNECT) .header(HOST, "example.com:443") .body(Body::empty()) .unwrap(); let res = app.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_FOUND); let body = hyper::body::to_bytes(res).await.unwrap(); assert!(body.is_empty()); } #[crate::test] async fn impl_handler_for_into_response() { let app = Router::new().route("/things", post((StatusCode::CREATED, "thing created"))); let client = TestClient::new(app); let res = client.post("/things").send().await; assert_eq!(res.status(), StatusCode::CREATED); assert_eq!(res.text().await, "thing created"); } 
axum-axum-v0.6.20/axum/src/routing/tests/nest.rs000066400000000000000000000330231446276533600216410ustar00rootroot00000000000000use super::*; use crate::{body::boxed, extract::Extension}; use std::collections::HashMap; use tower_http::services::ServeDir; #[crate::test] async fn nesting_apps() { let api_routes = Router::new() .route( "/users", get(|| async { "users#index" }).post(|| async { "users#create" }), ) .route( "/users/:id", get( |params: extract::Path>| async move { format!( "{}: users#show ({})", params.get("version").unwrap(), params.get("id").unwrap() ) }, ), ) .route( "/games/:id", get( |params: extract::Path>| async move { format!( "{}: games#show ({})", params.get("version").unwrap(), params.get("id").unwrap() ) }, ), ); let app = Router::new() .route("/", get(|| async { "hi" })) .nest("/:version/api", api_routes); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "hi"); let res = client.get("/v0/api/users").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#index"); let res = client.get("/v0/api/users/123").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "v0: users#show (123)"); let res = client.get("/v0/api/games/123").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "v0: games#show (123)"); } #[crate::test] async fn wrong_method_nest() { let nested_app = Router::new().route("/", get(|| async {})); let app = Router::new().nest("/", nested_app); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.post("/").send().await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD"); let res = client.patch("/foo").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn nesting_router_at_root() 
{ let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); let app = Router::new().nest("/", nested); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/foo"); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn nesting_router_at_empty_path() { let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); let app = Router::new().nest("", nested); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/foo"); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[crate::test] async fn nesting_handler_at_root() { let app = Router::new().nest_service("/", get(|uri: Uri| async move { uri.to_string() })); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/"); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/foo"); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/foo/bar"); } #[crate::test] async fn nested_url_extractor() { let app = Router::new().nest( "/foo", Router::new().nest( "/bar", Router::new() .route("/baz", get(|uri: Uri| async move { uri.to_string() })) .route( "/qux", get(|req: Request| async move { req.uri().to_string() }), ), ), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, 
"/baz"); let res = client.get("/foo/bar/qux").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/qux"); } #[crate::test] async fn nested_url_original_extractor() { let app = Router::new().nest( "/foo", Router::new().nest( "/bar", Router::new().route( "/baz", get(|uri: extract::OriginalUri| async move { uri.0.to_string() }), ), ), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/foo/bar/baz"); } #[crate::test] async fn nested_service_sees_stripped_uri() { let app = Router::new().nest( "/foo", Router::new().nest( "/bar", Router::new().route_service( "/baz", service_fn(|req: Request| async move { let body = boxed(Body::from(req.uri().to_string())); Ok::<_, Infallible>(Response::new(body)) }), ), ), ); let client = TestClient::new(app); let res = client.get("/foo/bar/baz").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/baz"); } #[crate::test] async fn nest_static_file_server() { let app = Router::new().nest_service("/static", ServeDir::new(".")); let client = TestClient::new(app); let res = client.get("/static/README.md").send().await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn nested_multiple_routes() { let app = Router::new() .nest( "/api", Router::new() .route("/users", get(|| async { "users" })) .route("/teams", get(|| async { "teams" })), ) .route("/", get(|| async { "root" })); let client = TestClient::new(app); assert_eq!(client.get("/").send().await.text().await, "root"); assert_eq!(client.get("/api/users").send().await.text().await, "users"); assert_eq!(client.get("/api/teams").send().await.text().await, "teams"); } #[test] #[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"] fn nested_service_at_root_with_other_routes() { let _: Router = Router::new() 
.nest_service("/", Router::new().route("/users", get(|| async {}))) .route("/", get(|| async {})); } #[test] fn nested_at_root_with_other_routes() { let _: Router = Router::new() .nest("/", Router::new().route("/users", get(|| async {}))) .route("/", get(|| async {})); } #[crate::test] async fn multiple_top_level_nests() { let app = Router::new() .nest( "/one", Router::new().route("/route", get(|| async { "one" })), ) .nest( "/two", Router::new().route("/route", get(|| async { "two" })), ); let client = TestClient::new(app); assert_eq!(client.get("/one/route").send().await.text().await, "one"); assert_eq!(client.get("/two/route").send().await.text().await, "two"); } #[crate::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { _ = Router::<(), Body>::new().nest("/one/*rest", Router::new()); } #[crate::test] async fn outer_middleware_still_see_whole_url() { #[derive(Clone)] struct SetUriExtension(S); #[derive(Clone)] struct Uri(http::Uri); impl Service> for SetUriExtension where S: Service>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { let uri = Uri(req.uri().clone()); req.extensions_mut().insert(uri); self.0.call(req) } } async fn handler(Extension(Uri(middleware_uri)): Extension) -> impl IntoResponse { middleware_uri.to_string() } let app = Router::new() .route("/", get(handler)) .route("/foo", get(handler)) .route("/foo/bar", get(handler)) .nest("/one", Router::new().route("/two", get(handler))) .fallback(handler) .layer(tower::layer::layer_fn(SetUriExtension)); let client = TestClient::new(app); assert_eq!(client.get("/").send().await.text().await, "/"); assert_eq!(client.get("/foo").send().await.text().await, "/foo"); assert_eq!(client.get("/foo/bar").send().await.text().await, "/foo/bar"); assert_eq!( 
client.get("/not-found").send().await.text().await, "/not-found" ); assert_eq!(client.get("/one/two").send().await.text().await, "/one/two"); } #[crate::test] async fn nest_at_capture() { let api_routes = Router::new().route( "/:b", get(|Path((a, b)): Path<(String, String)>| async move { format!("a={a} b={b}") }), ); let app = Router::new().nest("/:a", api_routes); let client = TestClient::new(app); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "a=foo b=bar"); } #[crate::test] async fn nest_with_and_without_trailing() { let app = Router::new().nest_service("/foo", get(|| async {})); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/foo/").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } #[tokio::test] async fn nesting_with_root_inner_router() { let app = Router::new() .nest_service("/service", Router::new().route("/", get(|| async {}))) .nest("/router", Router::new().route("/", get(|| async {}))) .nest("/router-slash/", Router::new().route("/", get(|| async {}))); let client = TestClient::new(app); // `/service/` does match the `/service` prefix and the remaining path is technically // empty, which is the same as `/` which matches `.route("/", _)` let res = client.get("/service").send().await; assert_eq!(res.status(), StatusCode::OK); // `/service/` does match the `/service` prefix and the remaining path is `/` // which matches `.route("/", _)` // // this is perhaps a little surprising but don't think there is much we can do let res = client.get("/service/").send().await; assert_eq!(res.status(), StatusCode::OK); // at least it does work like you'd expect when using `nest` let res = client.get("/router").send().await; assert_eq!(res.status(), StatusCode::OK); let res = client.get("/router/").send().await; 
assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/router-slash").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); let res = client.get("/router-slash/").send().await; assert_eq!(res.status(), StatusCode::OK); } macro_rules! nested_route_test { ( $name:ident, // the path we nest the inner router at nest = $nested_path:literal, // the route the inner router accepts route = $route_path:literal, // the route we expect to be able to call expected = $expected_path:literal $(,)? ) => { #[crate::test] async fn $name() { let inner = Router::new().route($route_path, get(|| async {})); let app = Router::new().nest($nested_path, inner); let client = TestClient::new(app); let res = client.get($expected_path).send().await; let status = res.status(); assert_eq!(status, StatusCode::OK, "Router"); } }; } // test cases taken from https://github.com/tokio-rs/axum/issues/714#issuecomment-1058144460 nested_route_test!(nest_1, nest = "", route = "/", expected = "/"); nested_route_test!(nest_2, nest = "", route = "/a", expected = "/a"); nested_route_test!(nest_3, nest = "", route = "/a/", expected = "/a/"); nested_route_test!(nest_4, nest = "/", route = "/", expected = "/"); nested_route_test!(nest_5, nest = "/", route = "/a", expected = "/a"); nested_route_test!(nest_6, nest = "/", route = "/a/", expected = "/a/"); nested_route_test!(nest_7, nest = "/a", route = "/", expected = "/a"); nested_route_test!(nest_8, nest = "/a", route = "/a", expected = "/a/a"); nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/"); nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/"); nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a"); nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/"); axum-axum-v0.6.20/axum/src/routing/url_params.rs000066400000000000000000000027071446276533600217000ustar00rootroot00000000000000use crate::util::PercentDecodedStr; use http::Extensions; use 
matchit::Params; use std::sync::Arc; pub(crate) enum UrlParams { Params(Vec<(Arc, PercentDecodedStr)>), InvalidUtf8InPathParam { key: Arc }, } pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) { let current_params = extensions.get_mut(); if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params { // nothing to do here since an error was stored earlier return; } let params = params .iter() .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) .filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM)) .map(|(k, v)| { if let Some(decoded) = PercentDecodedStr::new(v) { Ok((Arc::from(k), decoded)) } else { Err(Arc::from(k)) } }) .collect::, _>>(); match (current_params, params) { (Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => { unreachable!("we check for this state earlier in this method") } (_, Err(invalid_key)) => { extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key }); } (Some(UrlParams::Params(current)), Ok(params)) => { current.extend(params); } (None, Ok(params)) => { extensions.insert(UrlParams::Params(params)); } } } axum-axum-v0.6.20/axum/src/service_ext.rs000066400000000000000000000036541446276533600203660ustar00rootroot00000000000000#[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::routing::IntoMakeService; use tower_service::Service; /// Extension trait that adds additional methods to any [`Service`]. pub trait ServiceExt: Service + Sized { /// Convert this service into a [`MakeService`], that is a [`Service`] whose /// response is another service. /// /// This is commonly used when applying middleware around an entire [`Router`]. See ["Rewriting /// request URI in middleware"] for more details. 
/// /// [`MakeService`]: tower::make::MakeService /// ["Rewriting request URI in middleware"]: crate::middleware#rewriting-request-uri-in-middleware /// [`Router`]: crate::Router fn into_make_service(self) -> IntoMakeService; /// Convert this service into a [`MakeService`], that will store `C`'s /// associated `ConnectInfo` in a request extension such that [`ConnectInfo`] /// can extract it. /// /// This enables extracting things like the client's remote address. /// This is commonly used when applying middleware around an entire [`Router`]. See ["Rewriting /// request URI in middleware"] for more details. /// /// [`MakeService`]: tower::make::MakeService /// ["Rewriting request URI in middleware"]: crate::middleware#rewriting-request-uri-in-middleware /// [`Router`]: crate::Router /// [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo #[cfg(feature = "tokio")] fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo; } impl ServiceExt for S where S: Service + Sized, { fn into_make_service(self) -> IntoMakeService { IntoMakeService::new(self) } #[cfg(feature = "tokio")] fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { IntoMakeServiceWithConnectInfo::new(self) } } axum-axum-v0.6.20/axum/src/test_helpers/000077500000000000000000000000001446276533600201715ustar00rootroot00000000000000axum-axum-v0.6.20/axum/src/test_helpers/mod.rs000066400000000000000000000005121446276533600213140ustar00rootroot00000000000000#![allow(clippy::disallowed_names)] use crate::{body::HttpBody, BoxError}; mod test_client; pub(crate) use self::test_client::*; pub(crate) mod tracing_helpers; pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {} pub(crate) struct NotSendSync(*const ()); axum-axum-v0.6.20/axum/src/test_helpers/test_client.rs000066400000000000000000000102521446276533600230540ustar00rootroot00000000000000use super::{BoxError, HttpBody}; use bytes::Bytes; use http::{ 
header::{HeaderName, HeaderValue}, Request, StatusCode, }; use hyper::{Body, Server}; use std::net::{SocketAddr, TcpListener}; use tower::make::Shared; use tower_service::Service; pub(crate) struct TestClient { client: reqwest::Client, addr: SocketAddr, } impl TestClient { pub(crate) fn new(svc: S) -> Self where S: Service, Response = http::Response> + Clone + Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Data: Send, ResBody::Error: Into, S::Future: Send, S::Error: Into, { let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket"); let addr = listener.local_addr().unwrap(); println!("Listening on {addr}"); tokio::spawn(async move { let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc)); server.await.expect("server error"); }); let client = reqwest::Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap(); TestClient { client, addr } } pub(crate) fn get(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.get(format!("http://{}{}", self.addr, url)), } } pub(crate) fn head(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.head(format!("http://{}{}", self.addr, url)), } } pub(crate) fn post(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.post(format!("http://{}{}", self.addr, url)), } } #[allow(dead_code)] pub(crate) fn put(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.put(format!("http://{}{}", self.addr, url)), } } #[allow(dead_code)] pub(crate) fn patch(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.patch(format!("http://{}{}", self.addr, url)), } } } pub(crate) struct RequestBuilder { builder: reqwest::RequestBuilder, } impl RequestBuilder { pub(crate) async fn send(self) -> TestResponse { TestResponse { response: self.builder.send().await.unwrap(), } } pub(crate) fn body(mut self, body: impl Into) -> Self { self.builder = 
self.builder.body(body); self } pub(crate) fn json(mut self, json: &T) -> Self where T: serde::Serialize, { self.builder = self.builder.json(json); self } pub(crate) fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { self.builder = self.builder.header(key, value); self } #[allow(dead_code)] pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self { self.builder = self.builder.multipart(form); self } } #[derive(Debug)] pub(crate) struct TestResponse { response: reqwest::Response, } impl TestResponse { #[allow(dead_code)] pub(crate) async fn bytes(self) -> Bytes { self.response.bytes().await.unwrap() } pub(crate) async fn text(self) -> String { self.response.text().await.unwrap() } #[allow(dead_code)] pub(crate) async fn json(self) -> T where T: serde::de::DeserializeOwned, { self.response.json().await.unwrap() } pub(crate) fn status(&self) -> StatusCode { self.response.status() } pub(crate) fn headers(&self) -> &http::HeaderMap { self.response.headers() } pub(crate) async fn chunk(&mut self) -> Option { self.response.chunk().await.unwrap() } pub(crate) async fn chunk_text(&mut self) -> Option { let chunk = self.chunk().await?; Some(String::from_utf8(chunk.to_vec()).unwrap()) } } axum-axum-v0.6.20/axum/src/test_helpers/tracing_helpers.rs000066400000000000000000000046761446276533600237250ustar00rootroot00000000000000use std::{ future::Future, io, sync::{Arc, Mutex}, }; use serde::{de::DeserializeOwned, Deserialize}; use tracing_subscriber::prelude::*; use tracing_subscriber::{filter::Targets, fmt::MakeWriter}; #[derive(Deserialize, Eq, PartialEq, Debug)] #[serde(deny_unknown_fields)] pub(crate) struct TracingEvent { pub(crate) fields: T, pub(crate) target: String, pub(crate) level: String, } /// Run an async closure and capture the tracing output it produces. 
pub(crate) async fn capture_tracing(f: F) -> Vec> where F: Fn() -> Fut, Fut: Future, T: DeserializeOwned, { let (make_writer, handle) = TestMakeWriter::new(); let subscriber = tracing_subscriber::registry().with( tracing_subscriber::fmt::layer() .with_writer(make_writer) .with_target(true) .without_time() .with_ansi(false) .json() .flatten_event(false) .with_filter("axum=trace".parse::().unwrap()), ); let guard = tracing::subscriber::set_default(subscriber); f().await; drop(guard); handle .take() .lines() .map(|line| serde_json::from_str(line).unwrap()) .collect() } struct TestMakeWriter { write: Arc>>>, } impl TestMakeWriter { fn new() -> (Self, Handle) { let write = Arc::new(Mutex::new(Some(Vec::::new()))); ( Self { write: write.clone(), }, Handle { write }, ) } } impl<'a> MakeWriter<'a> for TestMakeWriter { type Writer = Writer<'a>; fn make_writer(&'a self) -> Self::Writer { Writer(self) } } struct Writer<'a>(&'a TestMakeWriter); impl<'a> io::Write for Writer<'a> { fn write(&mut self, buf: &[u8]) -> io::Result { match &mut *self.0.write.lock().unwrap() { Some(vec) => { let len = buf.len(); vec.extend(buf); Ok(len) } None => Err(io::Error::new( io::ErrorKind::Other, "inner writer has been taken", )), } } fn flush(&mut self) -> io::Result<()> { Ok(()) } } struct Handle { write: Arc>>>, } impl Handle { fn take(self) -> String { let vec = self.write.lock().unwrap().take().unwrap(); String::from_utf8(vec).unwrap() } } axum-axum-v0.6.20/axum/src/typed_header.rs000066400000000000000000000134571446276533600205050ustar00rootroot00000000000000use crate::extract::FromRequestParts; use async_trait::async_trait; use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use headers::HeaderMapExt; use http::request::Parts; use std::convert::Infallible; /// Extractor and response that works with typed header values from [`headers`]. 
/// /// # As extractor /// /// In general, it's recommended to extract only the needed headers via `TypedHeader` rather than /// removing all headers with the `HeaderMap` extractor. /// /// ```rust,no_run /// use axum::{ /// TypedHeader, /// headers::UserAgent, /// routing::get, /// Router, /// }; /// /// async fn users_teams_show( /// TypedHeader(user_agent): TypedHeader, /// ) { /// // ... /// } /// /// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// # As response /// /// ```rust /// use axum::{ /// TypedHeader, /// response::IntoResponse, /// headers::ContentType, /// }; /// /// async fn handler() -> (TypedHeader, &'static str) { /// ( /// TypedHeader(ContentType::text_utf8()), /// "Hello, World!", /// ) /// } /// ``` #[cfg(feature = "headers")] #[derive(Debug, Clone, Copy)] #[must_use] pub struct TypedHeader(pub T); #[async_trait] impl FromRequestParts for TypedHeader where T: headers::Header, S: Send + Sync, { type Rejection = TypedHeaderRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let mut values = parts.headers.get_all(T::name()).iter(); let is_missing = values.size_hint() == (0, Some(0)); T::decode(&mut values) .map(Self) .map_err(|err| TypedHeaderRejection { name: T::name(), reason: if is_missing { // Report a more precise rejection for the missing header case. 
TypedHeaderRejectionReason::Missing } else { TypedHeaderRejectionReason::Error(err) }, }) } } axum_core::__impl_deref!(TypedHeader); impl IntoResponseParts for TypedHeader where T: headers::Header, { type Error = Infallible; fn into_response_parts(self, mut res: ResponseParts) -> Result { res.headers_mut().typed_insert(self.0); Ok(res) } } impl IntoResponse for TypedHeader where T: headers::Header, { fn into_response(self) -> Response { let mut res = ().into_response(); res.headers_mut().typed_insert(self.0); res } } /// Rejection used for [`TypedHeader`](super::TypedHeader). #[cfg(feature = "headers")] #[derive(Debug)] pub struct TypedHeaderRejection { name: &'static http::header::HeaderName, reason: TypedHeaderRejectionReason, } impl TypedHeaderRejection { /// Name of the header that caused the rejection pub fn name(&self) -> &http::header::HeaderName { self.name } /// Reason why the header extraction has failed pub fn reason(&self) -> &TypedHeaderRejectionReason { &self.reason } } /// Additional information regarding a [`TypedHeaderRejection`] #[cfg(feature = "headers")] #[derive(Debug)] #[non_exhaustive] pub enum TypedHeaderRejectionReason { /// The header was missing from the HTTP request Missing, /// An error occured when parsing the header from the HTTP request Error(headers::Error), } impl IntoResponse for TypedHeaderRejection { fn into_response(self) -> Response { (http::StatusCode::BAD_REQUEST, self.to_string()).into_response() } } impl std::fmt::Display for TypedHeaderRejection { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.reason { TypedHeaderRejectionReason::Missing => { write!(f, "Header of type `{}` was missing", self.name) } TypedHeaderRejectionReason::Error(err) => { write!(f, "{} ({})", err, self.name) } } } } impl std::error::Error for TypedHeaderRejection { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match &self.reason { TypedHeaderRejectionReason::Error(err) => Some(err), 
TypedHeaderRejectionReason::Missing => None, } } } #[cfg(test)] mod tests { use super::*; use crate::{response::IntoResponse, routing::get, test_helpers::*, Router}; #[crate::test] async fn typed_header() { async fn handle( TypedHeader(user_agent): TypedHeader, TypedHeader(cookies): TypedHeader, ) -> impl IntoResponse { let user_agent = user_agent.as_str(); let cookies = cookies.iter().collect::>(); format!("User-Agent={user_agent:?}, Cookie={cookies:?}") } let app = Router::new().route("/", get(handle)); let client = TestClient::new(app); let res = client .get("/") .header("user-agent", "foobar") .header("cookie", "a=1; b=2") .header("cookie", "c=3") .send() .await; let body = res.text().await; assert_eq!( body, r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"# ); let res = client.get("/").header("user-agent", "foobar").send().await; let body = res.text().await; assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#); let res = client.get("/").header("cookie", "a=1").send().await; let body = res.text().await; assert_eq!(body, "Header of type `user-agent` was missing"); } } axum-axum-v0.6.20/axum/src/util.rs000066400000000000000000000024371446276533600170210ustar00rootroot00000000000000use pin_project_lite::pin_project; use std::{ops::Deref, sync::Arc}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct PercentDecodedStr(Arc); impl PercentDecodedStr { pub(crate) fn new(s: S) -> Option where S: AsRef, { percent_encoding::percent_decode(s.as_ref().as_bytes()) .decode_utf8() .ok() .map(|decoded| Self(decoded.as_ref().into())) } pub(crate) fn as_str(&self) -> &str { &self.0 } pub(crate) fn into_inner(self) -> Arc { self.0 } } impl Deref for PercentDecodedStr { type Target = str; #[inline] fn deref(&self) -> &Self::Target { self.as_str() } } pin_project! 
{ #[project = EitherProj] pub(crate) enum Either { A { #[pin] inner: A }, B { #[pin] inner: B }, } } pub(crate) fn try_downcast(k: K) -> Result where T: 'static, K: Send + 'static, { let mut k = Some(k); if let Some(k) = ::downcast_mut::>(&mut k) { Ok(k.take().unwrap()) } else { Err(k.unwrap()) } } #[test] fn test_try_downcast() { assert_eq!(try_downcast::(5_u32), Err(5_u32)); assert_eq!(try_downcast::(5_i32), Ok(5_i32)); } axum-axum-v0.6.20/deny.toml000066400000000000000000000017471446276533600155740ustar00rootroot00000000000000[advisories] vulnerability = "deny" unmaintained = "warn" notice = "warn" ignore = [] [licenses] unlicensed = "warn" allow = [] deny = [] copyleft = "warn" allow-osi-fsf-free = "either" confidence-threshold = 0.8 [bans] multiple-versions = "deny" highlight = "all" skip-tree = [ # currently duplicated through header, reqwest, tower-http and cookie # C.f. https://github.com/tokio-rs/axum/pull/1641 { name = "base64" }, # parking_lot pulls in old versions of windows-sys { name = "windows-sys" }, # old version pulled in by rustls via ring { name = "spin" }, # lots still pulls in syn 1.x { name = "syn" }, # pulled in by tracing-subscriber { name = "regex-syntax" }, # pulled in by tracing-subscriber { name = "regex-automata" }, # pulled in by h2 { name = "indexmap" }, # pulled in by axum, can't update without breaking changes { name = "bitflags" }, ] [sources] unknown-registry = "warn" unknown-git = "warn" allow-git = [] axum-axum-v0.6.20/examples/000077500000000000000000000000001446276533600155455ustar00rootroot00000000000000axum-axum-v0.6.20/examples/README.md000066400000000000000000000003601446276533600170230ustar00rootroot00000000000000# Examples This folder contains numerous example showing how to use axum. Each example is setup as its own crate so its dependencies are clear. 
For a list of what the community built with axum, please see the list [here](../ECOSYSTEM.md).axum-axum-v0.6.20/examples/anyhow-error-response/000077500000000000000000000000001446276533600220355ustar00rootroot00000000000000axum-axum-v0.6.20/examples/anyhow-error-response/Cargo.toml000066400000000000000000000003231446276533600237630ustar00rootroot00000000000000[package] name = "example-anyhow-error-response" version = "0.1.0" edition = "2021" publish = false [dependencies] anyhow = "1.0" axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } axum-axum-v0.6.20/examples/anyhow-error-response/src/000077500000000000000000000000001446276533600226245ustar00rootroot00000000000000axum-axum-v0.6.20/examples/anyhow-error-response/src/main.rs000066400000000000000000000024761446276533600241270ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-anyhow-error-response //! ``` use axum::{ http::StatusCode, response::{IntoResponse, Response}, routing::get, Router, }; use std::net::SocketAddr; #[tokio::main] async fn main() { let app = Router::new().route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler() -> Result<(), AppError> { try_thing()?; Ok(()) } fn try_thing() -> Result<(), anyhow::Error> { anyhow::bail!("it failed!") } // Make our own error that wraps `anyhow::Error`. struct AppError(anyhow::Error); // Tell axum how to convert `AppError` into a response. impl IntoResponse for AppError { fn into_response(self) -> Response { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", self.0), ) .into_response() } } // This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into // `Result<_, AppError>`. That way you don't need to do that manually. 
impl From for AppError where E: Into, { fn from(err: E) -> Self { Self(err.into()) } } axum-axum-v0.6.20/examples/async-graphql/000077500000000000000000000000001446276533600203165ustar00rootroot00000000000000axum-axum-v0.6.20/examples/async-graphql/README.md000066400000000000000000000000611446276533600215720ustar00rootroot00000000000000See . axum-axum-v0.6.20/examples/chat/000077500000000000000000000000001446276533600164645ustar00rootroot00000000000000axum-axum-v0.6.20/examples/chat/Cargo.toml000066400000000000000000000005311446276533600204130ustar00rootroot00000000000000[package] name = "example-chat" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["ws"] } futures = "0.3" tokio = { version = "1", features = ["full"] } tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/chat/chat.html000066400000000000000000000035671446276533600203040ustar00rootroot00000000000000 WebSocket Chat

WebSocket Chat Example

axum-axum-v0.6.20/examples/chat/src/000077500000000000000000000000001446276533600172535ustar00rootroot00000000000000axum-axum-v0.6.20/examples/chat/src/main.rs000066400000000000000000000117451446276533600205550ustar00rootroot00000000000000//! Example chat application. //! //! Run with //! //! ```not_rust //! cargo run -p example-chat //! ``` use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, State, }, response::{Html, IntoResponse}, routing::get, Router, }; use futures::{sink::SinkExt, stream::StreamExt}; use std::{ collections::HashSet, net::SocketAddr, sync::{Arc, Mutex}, }; use tokio::sync::broadcast; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // Our shared state struct AppState { // We require unique usernames. This tracks which usernames have been taken. user_set: Mutex>, // Channel used to send messages to all connected clients. tx: broadcast::Sender, } #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_chat=trace".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // Set up application state for use with with_state(). let user_set = Mutex::new(HashSet::new()); let (tx, _rx) = broadcast::channel(100); let app_state = Arc::new(AppState { user_set, tx }); let app = Router::new() .route("/", get(index)) .route("/websocket", get(websocket_handler)) .with_state(app_state); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn websocket_handler( ws: WebSocketUpgrade, State(state): State>, ) -> impl IntoResponse { ws.on_upgrade(|socket| websocket(socket, state)) } // This function deals with a single websocket connection, i.e., a single // connected client / user, for which we will spawn two independent tasks (for // receiving / sending chat messages). 
async fn websocket(stream: WebSocket, state: Arc) { // By splitting, we can send and receive at the same time. let (mut sender, mut receiver) = stream.split(); // Username gets set in the receive loop, if it's valid. let mut username = String::new(); // Loop until a text message is found. while let Some(Ok(message)) = receiver.next().await { if let Message::Text(name) = message { // If username that is sent by client is not taken, fill username string. check_username(&state, &mut username, &name); // If not empty we want to quit the loop else we want to quit function. if !username.is_empty() { break; } else { // Only send our client that username is taken. let _ = sender .send(Message::Text(String::from("Username already taken."))) .await; return; } } } // We subscribe *before* sending the "joined" message, so that we will also // display it to our client. let mut rx = state.tx.subscribe(); // Now send the "joined" message to all subscribers. let msg = format!("{} joined.", username); tracing::debug!("{}", msg); let _ = state.tx.send(msg); // Spawn the first task that will receive broadcast messages and send text // messages over the websocket to our client. let mut send_task = tokio::spawn(async move { while let Ok(msg) = rx.recv().await { // In any websocket error, break loop. if sender.send(Message::Text(msg)).await.is_err() { break; } } }); // Clone things we want to pass (move) to the receiving task. let tx = state.tx.clone(); let name = username.clone(); // Spawn a task that takes messages from the websocket, prepends the user // name, and sends them to all broadcast subscribers. let mut recv_task = tokio::spawn(async move { while let Some(Ok(Message::Text(text))) = receiver.next().await { // Add username before message. let _ = tx.send(format!("{}: {}", name, text)); } }); // If any one of the tasks run to completion, we abort the other. tokio::select! 
{ _ = (&mut send_task) => recv_task.abort(), _ = (&mut recv_task) => send_task.abort(), }; // Send "user left" message (similar to "joined" above). let msg = format!("{} left.", username); tracing::debug!("{}", msg); let _ = state.tx.send(msg); // Remove username from map so new clients can take it again. state.user_set.lock().unwrap().remove(&username); } fn check_username(state: &AppState, string: &mut String, name: &str) { let mut user_set = state.user_set.lock().unwrap(); if !user_set.contains(name) { user_set.insert(name.to_owned()); string.push_str(name); } } // Include utf-8 file at **compile** time. async fn index() -> Html<&'static str> { Html(std::include_str!("../chat.html")) } axum-axum-v0.6.20/examples/consume-body-in-extractor-or-middleware/000077500000000000000000000000001446276533600253175ustar00rootroot00000000000000axum-axum-v0.6.20/examples/consume-body-in-extractor-or-middleware/Cargo.toml000066400000000000000000000006231446276533600272500ustar00rootroot00000000000000[package] name = "example-consume-body-in-extractor-or-middleware" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = "0.14" tokio = { version = "1.0", features = ["full"] } tower = "0.4" tower-http = { version = "0.4.0", features = ["map-request-body", "util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/consume-body-in-extractor-or-middleware/src/000077500000000000000000000000001446276533600261065ustar00rootroot00000000000000axum-axum-v0.6.20/examples/consume-body-in-extractor-or-middleware/src/main.rs000066400000000000000000000055671446276533600274150ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-consume-body-in-extractor-or-middleware //! 
``` use axum::{ async_trait, body::{self, BoxBody, Bytes, Full}, extract::FromRequest, http::{Request, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, routing::post, Router, }; use std::net::SocketAddr; use tower::ServiceBuilder; use tower_http::ServiceBuilderExt; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_consume_body_in_extractor_or_middleware=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let app = Router::new().route("/", post(handler)).layer( ServiceBuilder::new() .map_request_body(body::boxed) .layer(middleware::from_fn(print_request_body)), ); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // middleware that shows how to consume the request body upfront async fn print_request_body( request: Request, next: Next, ) -> Result { let request = buffer_request_body(request).await?; Ok(next.run(request).await) } // the trick is to take the request apart, buffer the body, do what you need to do, then put // the request back together async fn buffer_request_body(request: Request) -> Result, Response> { let (parts, body) = request.into_parts(); // this wont work if the body is an long running stream let bytes = hyper::body::to_bytes(body) .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; do_thing_with_request_body(bytes.clone()); Ok(Request::from_parts(parts, body::boxed(Full::from(bytes)))) } fn do_thing_with_request_body(bytes: Bytes) { tracing::debug!(body = ?bytes); } async fn handler(BufferRequestBody(body): BufferRequestBody) { tracing::debug!(?body, "handler received body"); } // extractor that shows how to consume the request body upfront struct 
BufferRequestBody(Bytes); // we must implement `FromRequest` (and not `FromRequestParts`) to consume the body #[async_trait] impl FromRequest for BufferRequestBody where S: Send + Sync, { type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let body = Bytes::from_request(req, state) .await .map_err(|err| err.into_response())?; do_thing_with_request_body(body.clone()); Ok(Self(body)) } } axum-axum-v0.6.20/examples/cors/000077500000000000000000000000001446276533600165135ustar00rootroot00000000000000axum-axum-v0.6.20/examples/cors/Cargo.toml000066400000000000000000000003531446276533600204440ustar00rootroot00000000000000[package] name = "example-cors" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tower-http = { version = "0.4.0", features = ["cors"] } axum-axum-v0.6.20/examples/cors/src/000077500000000000000000000000001446276533600173025ustar00rootroot00000000000000axum-axum-v0.6.20/examples/cors/src/main.rs000066400000000000000000000032271446276533600206000ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-cors //! 
``` use axum::{ http::{HeaderValue, Method}, response::{Html, IntoResponse}, routing::get, Json, Router, }; use std::net::SocketAddr; use tower_http::cors::CorsLayer; #[tokio::main] async fn main() { let frontend = async { let app = Router::new().route("/", get(html)); serve(app, 3000).await; }; let backend = async { let app = Router::new().route("/json", get(json)).layer( // see https://docs.rs/tower-http/latest/tower_http/cors/index.html // for more details // // pay attention that for some request types like posting content-type: application/json // it is required to add ".allow_headers([http::header::CONTENT_TYPE])" // or see this issue https://github.com/tokio-rs/axum/issues/849 CorsLayer::new() .allow_origin("http://localhost:3000".parse::().unwrap()) .allow_methods([Method::GET]), ); serve(app, 4000).await; }; tokio::join!(frontend, backend); } async fn serve(app: Router, port: u16) { let addr = SocketAddr::from(([127, 0, 0, 1], port)); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn html() -> impl IntoResponse { Html( r#" "#, ) } async fn json() -> impl IntoResponse { Json(vec!["one", "two", "three"]) } axum-axum-v0.6.20/examples/customize-extractor-error/000077500000000000000000000000001446276533600227275ustar00rootroot00000000000000axum-axum-v0.6.20/examples/customize-extractor-error/Cargo.toml000066400000000000000000000007151446276533600246620ustar00rootroot00000000000000[package] name = "example-customize-extractor-error" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } axum-extra = { path = "../../axum-extra" } axum-macros = { path = "../../axum-macros" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = "1.0" tokio = { version = "1.20", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } 
axum-axum-v0.6.20/examples/customize-extractor-error/README.md000066400000000000000000000010661446276533600242110ustar00rootroot00000000000000This example explores 3 different ways you can create custom rejections for already existing extractors - [`with_rejection`](src/with_rejection.rs): Uses `axum_extra::extract::WithRejection` to transform one rejection into another - [`derive_from_request`](src/derive_from_request.rs): Uses `axum_macros::FromRequest` to wrap another extractor and customize the rejection - [`custom_extractor`](src/custom_extractor.rs): Manual implementation of `FromRequest` that wraps another extractor Run with ```sh cargo run -p example-customize-extractor-error ``` axum-axum-v0.6.20/examples/customize-extractor-error/src/000077500000000000000000000000001446276533600235165ustar00rootroot00000000000000axum-axum-v0.6.20/examples/customize-extractor-error/src/custom_extractor.rs000066400000000000000000000042171446276533600274750ustar00rootroot00000000000000//! Manual implementation of `FromRequest` that wraps another extractor //! //! + Powerful API: Implementing `FromRequest` grants access to `RequestParts` //! and `async/await`. This means that you can create more powerful rejections //! - Boilerplate: Requires creating a new extractor for every custom rejection //! 
- Complexity: Manually implementing `FromRequest` results on more complex code use axum::{ async_trait, extract::{rejection::JsonRejection, FromRequest, MatchedPath}, http::Request, http::StatusCode, response::IntoResponse, RequestPartsExt, }; use serde_json::{json, Value}; pub async fn handler(Json(value): Json) -> impl IntoResponse { Json(dbg!(value)); } // We define our own `Json` extractor that customizes the error from `axum::Json` pub struct Json(pub T); #[async_trait] impl FromRequest for Json where axum::Json: FromRequest, S: Send + Sync, B: Send + 'static, { type Rejection = (StatusCode, axum::Json); async fn from_request(req: Request, state: &S) -> Result { let (mut parts, body) = req.into_parts(); // We can use other extractors to provide better rejection messages. // For example, here we are using `axum::extract::MatchedPath` to // provide a better error message. // // Have to run that first since `Json` extraction consumes the request. let path = parts .extract::() .await .map(|path| path.as_str().to_owned()) .ok(); let req = Request::from_parts(parts, body); match axum::Json::::from_request(req, state).await { Ok(value) => Ok(Self(value.0)), // convert the error from `axum::Json` into whatever we want Err(rejection) => { let payload = json!({ "message": rejection.body_text(), "origin": "custom_extractor", "path": path, }); Err((rejection.status(), axum::Json(payload))) } } } } axum-axum-v0.6.20/examples/customize-extractor-error/src/derive_from_request.rs000066400000000000000000000036011446276533600301350ustar00rootroot00000000000000//! Uses `axum_macros::FromRequest` to wrap another extractor and customize the //! rejection //! //! + Easy learning curve: Deriving `FromRequest` generates a `FromRequest` //! implementation for your type using another extractor. You only need //! to provide a `From` impl between the original rejection type and the //! target rejection. Crates like [`thiserror`] can provide such conversion //! using derive macros. //! 
- Boilerplate: Requires deriving `FromRequest` for every custom rejection //! - There are some known limitations: [FromRequest#known-limitations] //! //! [`thiserror`]: https://crates.io/crates/thiserror //! [FromRequest#known-limitations]: https://docs.rs/axum-macros/*/axum_macros/derive.FromRequest.html#known-limitations use axum::{extract::rejection::JsonRejection, http::StatusCode, response::IntoResponse}; use axum_macros::FromRequest; use serde_json::{json, Value}; pub async fn handler(Json(value): Json) -> impl IntoResponse { Json(dbg!(value)); } // create an extractor that internally uses `axum::Json` but has a custom rejection #[derive(FromRequest)] #[from_request(via(axum::Json), rejection(ApiError))] pub struct Json(T); // We create our own rejection type #[derive(Debug)] pub struct ApiError { status: StatusCode, message: String, } // We implement `From for ApiError` impl From for ApiError { fn from(rejection: JsonRejection) -> Self { Self { status: rejection.status(), message: rejection.body_text(), } } } // We implement `IntoResponse` so `ApiError` can be used as a response impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let payload = json!({ "message": self.message, "origin": "derive_from_request" }); (self.status, axum::Json(payload)).into_response() } } axum-axum-v0.6.20/examples/customize-extractor-error/src/main.rs000066400000000000000000000021511446276533600250070ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-customize-extractor-error //! 
``` mod custom_extractor; mod derive_from_request; mod with_rejection; use axum::{routing::post, Router}; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_customize_extractor_error=trace".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // Build our application with some routes let app = Router::new() .route("/with-rejection", post(with_rejection::handler)) .route("/custom-extractor", post(custom_extractor::handler)) .route("/derive-from-request", post(derive_from_request::handler)); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } axum-axum-v0.6.20/examples/customize-extractor-error/src/with_rejection.rs000066400000000000000000000041711446276533600271040ustar00rootroot00000000000000//! Uses `axum_extra::extract::WithRejection` to transform one rejection into //! another //! //! + Easy learning curve: `WithRejection` acts as a wrapper for another //! already existing extractor. You only need to provide a `From` impl //! between the original rejection type and the target rejection. Crates like //! `thiserror` can provide such conversion using derive macros. See //! [`thiserror`] //! - Verbose types: types become much larger, which makes them difficult to //! read. Current limitations on type aliasing makes impossible to destructure //! a type alias. See [#1116] //! //! [`thiserror`]: https://crates.io/crates/thiserror //! 
[#1116]: https://github.com/tokio-rs/axum/issues/1116#issuecomment-1186197684 use axum::{extract::rejection::JsonRejection, response::IntoResponse, Json}; use axum_extra::extract::WithRejection; use serde_json::{json, Value}; use thiserror::Error; pub async fn handler( // `WithRejection` will extract `Json` from the request. If it fails, // `JsonRejection` will be transform into `ApiError` and returned as response // to the client. // // The second constructor argument is not meaningful and can be safely ignored WithRejection(Json(value), _): WithRejection, ApiError>, ) -> impl IntoResponse { Json(dbg!(value)) } // We derive `thiserror::Error` #[derive(Debug, Error)] pub enum ApiError { // The `#[from]` attribute generates `From for ApiError` // implementation. See `thiserror` docs for more information #[error(transparent)] JsonExtractorRejection(#[from] JsonRejection), } // We implement `IntoResponse` so ApiError can be used as a response impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let (status, message) = match self { ApiError::JsonExtractorRejection(json_rejection) => { (json_rejection.status(), json_rejection.body_text()) } }; let payload = json!({ "message": message, "origin": "with_rejection" }); (status, Json(payload)).into_response() } } axum-axum-v0.6.20/examples/customize-path-rejection/000077500000000000000000000000001446276533600225015ustar00rootroot00000000000000axum-axum-v0.6.20/examples/customize-path-rejection/Cargo.toml000066400000000000000000000005411446276533600244310ustar00rootroot00000000000000[package] name = "example-customize-path-rejection" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } 
axum-axum-v0.6.20/examples/customize-path-rejection/src/000077500000000000000000000000001446276533600232705ustar00rootroot00000000000000axum-axum-v0.6.20/examples/customize-path-rejection/src/main.rs000066400000000000000000000122121446276533600245600ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-customize-path-rejection //! ``` use axum::{ async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}, http::{request::Parts, StatusCode}, response::IntoResponse, routing::get, Router, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_customize_path_rejection=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with a route let app = Router::new().route("/users/:user_id/teams/:team_id", get(handler)); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler(Path(params): Path) -> impl IntoResponse { axum::Json(params) } #[derive(Debug, Deserialize, Serialize)] struct Params { user_id: u32, team_id: u32, } // We define our own `Path` extractor that customizes the error from `axum::extract::Path` struct Path(T); #[async_trait] impl FromRequestParts for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, S: Send + Sync, { type Rejection = (StatusCode, axum::Json); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { match axum::extract::Path::::from_request_parts(parts, state).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { let (status, body) = match rejection { 
PathRejection::FailedToDeserializePathParams(inner) => { let mut status = StatusCode::BAD_REQUEST; let kind = inner.into_kind(); let body = match &kind { ErrorKind::WrongNumberOfParameters { .. } => PathError { message: kind.to_string(), location: None, }, ErrorKind::ParseErrorAtKey { key, .. } => PathError { message: kind.to_string(), location: Some(key.clone()), }, ErrorKind::ParseErrorAtIndex { index, .. } => PathError { message: kind.to_string(), location: Some(index.to_string()), }, ErrorKind::ParseError { .. } => PathError { message: kind.to_string(), location: None, }, ErrorKind::InvalidUtf8InPathParam { key } => PathError { message: kind.to_string(), location: Some(key.clone()), }, ErrorKind::UnsupportedType { .. } => { // this error is caused by the programmer using an unsupported type // (such as nested maps) so respond with `500` instead status = StatusCode::INTERNAL_SERVER_ERROR; PathError { message: kind.to_string(), location: None, } } ErrorKind::Message(msg) => PathError { message: msg.clone(), location: None, }, _ => PathError { message: format!("Unhandled deserialization error: {}", kind), location: None, }, }; (status, body) } PathRejection::MissingPathParams(error) => ( StatusCode::INTERNAL_SERVER_ERROR, PathError { message: error.to_string(), location: None, }, ), _ => ( StatusCode::INTERNAL_SERVER_ERROR, PathError { message: format!("Unhandled path rejection: {}", rejection), location: None, }, ), }; Err((status, axum::Json(body))) } } } } #[derive(Serialize)] struct PathError { message: String, location: Option, } axum-axum-v0.6.20/examples/diesel-async-postgres/000077500000000000000000000000001446276533600217715ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-async-postgres/Cargo.toml000066400000000000000000000007451446276533600237270ustar00rootroot00000000000000[package] name = "example-diesel-async-postgres" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } axum-macros = { path = 
"../../axum-macros" } bb8 = "0.8" diesel = "2" diesel-async = { version = "0.3", features = ["postgres", "bb8"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/diesel-async-postgres/migrations/000077500000000000000000000000001446276533600241455ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-async-postgres/migrations/2023-03-14-180127_add_users/000077500000000000000000000000001446276533600277665ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-async-postgres/migrations/2023-03-14-180127_add_users/down.sql000066400000000000000000000001021446276533600314470ustar00rootroot00000000000000-- This file should undo anything in "up.sql" DROP TABLE "users"; axum-axum-v0.6.20/examples/diesel-async-postgres/migrations/2023-03-14-180127_add_users/up.sql000066400000000000000000000001741446276533600311350ustar00rootroot00000000000000-- Your SQL goes here CREATE TABLE "users"( "id" SERIAL PRIMARY KEY, "name" TEXT NOT NULL, "hair_color" TEXT ); axum-axum-v0.6.20/examples/diesel-async-postgres/src/000077500000000000000000000000001446276533600225605ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-async-postgres/src/main.rs000066400000000000000000000074771446276533600240710ustar00rootroot00000000000000//! Run with //! //! ```sh //! export DATABASE_URL=postgres://localhost/your_db //! diesel migration run //! cargo run -p example-diesel-async-postgres //! ``` //! //! Checkout the [diesel webpage](https://diesel.rs) for //! longer guides about diesel //! //! Checkout the [crates.io source code](https://github.com/rust-lang/crates.io/) //! 
for a real world application using axum and diesel use axum::{ async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, response::Json, routing::{get, post}, Router, }; use diesel::prelude::*; use diesel_async::{ pooled_connection::AsyncDieselConnectionManager, AsyncPgConnection, RunQueryDsl, }; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // normally part of your generated schema.rs file table! { users (id) { id -> Integer, name -> Text, hair_color -> Nullable, } } #[derive(serde::Serialize, Selectable, Queryable)] struct User { id: i32, name: String, hair_color: Option, } #[derive(serde::Deserialize, Insertable)] #[diesel(table_name = users)] struct NewUser { name: String, hair_color: Option, } type Pool = bb8::Pool>; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_diesel_async_postgres=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let db_url = std::env::var("DATABASE_URL").unwrap(); // set up connection pool let config = AsyncDieselConnectionManager::::new(db_url); let pool = bb8::Pool::builder().build(config).await.unwrap(); // build our application with some routes let app = Router::new() .route("/user/list", get(list_users)) .route("/user/create", post(create_user)) .with_state(pool); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn create_user( State(pool): State, Json(new_user): Json, ) -> Result, (StatusCode, String)> { let mut conn = pool.get().await.map_err(internal_error)?; let res = diesel::insert_into(users::table) .values(new_user) .returning(User::as_returning()) .get_result(&mut conn) .await .map_err(internal_error)?; Ok(Json(res)) } // we can also write a custom extractor that grabs 
a connection from the pool // which setup is appropriate depends on your application struct DatabaseConnection( bb8::PooledConnection<'static, AsyncDieselConnectionManager>, ); #[async_trait] impl FromRequestParts for DatabaseConnection where S: Send + Sync, Pool: FromRef, { type Rejection = (StatusCode, String); async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { let pool = Pool::from_ref(state); let conn = pool.get_owned().await.map_err(internal_error)?; Ok(Self(conn)) } } async fn list_users( DatabaseConnection(mut conn): DatabaseConnection, ) -> Result>, (StatusCode, String)> { let res = users::table .select(User::as_select()) .load(&mut conn) .await .map_err(internal_error)?; Ok(Json(res)) } /// Utility function for mapping any error into a `500 Internal Server Error` /// response. fn internal_error(err: E) -> (StatusCode, String) where E: std::error::Error, { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } axum-axum-v0.6.20/examples/diesel-postgres/000077500000000000000000000000001446276533600206565ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-postgres/Cargo.toml000066400000000000000000000010201446276533600225770ustar00rootroot00000000000000[package] name = "example-diesel-postgres" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } axum-macros = { path = "../../axum-macros" } deadpool-diesel = { version = "0.4.1", features = ["postgres"] } diesel = { version = "2", features = ["postgres"] } diesel_migrations = "2" serde = { version = "1.0", features = ["derive"] } serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } 
axum-axum-v0.6.20/examples/diesel-postgres/migrations/000077500000000000000000000000001446276533600230325ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-postgres/migrations/2023-03-14-180127_add_users/000077500000000000000000000000001446276533600266535ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-postgres/migrations/2023-03-14-180127_add_users/down.sql000066400000000000000000000001021446276533600303340ustar00rootroot00000000000000-- This file should undo anything in "up.sql" DROP TABLE "users"; axum-axum-v0.6.20/examples/diesel-postgres/migrations/2023-03-14-180127_add_users/up.sql000066400000000000000000000001741446276533600300220ustar00rootroot00000000000000-- Your SQL goes here CREATE TABLE "users"( "id" SERIAL PRIMARY KEY, "name" TEXT NOT NULL, "hair_color" TEXT ); axum-axum-v0.6.20/examples/diesel-postgres/src/000077500000000000000000000000001446276533600214455ustar00rootroot00000000000000axum-axum-v0.6.20/examples/diesel-postgres/src/main.rs000066400000000000000000000072141446276533600227430ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-diesel-postgres //! ``` //! //! Checkout the [diesel webpage](https://diesel.rs) for //! longer guides about diesel //! //! Checkout the [crates.io source code](https://github.com/rust-lang/crates.io/) //! for a real world application using axum and diesel use axum::{ extract::State, http::StatusCode, response::Json, routing::{get, post}, Router, }; use diesel::prelude::*; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // this embeddes the migrations into the application binary // the migration path is releative to the `CARGO_MANIFEST_DIR` pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/"); // normally part of your generated schema.rs file table! 
{ users (id) { id -> Integer, name -> Text, hair_color -> Nullable, } } #[derive(serde::Serialize, Selectable, Queryable)] struct User { id: i32, name: String, hair_color: Option, } #[derive(serde::Deserialize, Insertable)] #[diesel(table_name = users)] struct NewUser { name: String, hair_color: Option, } #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let db_url = std::env::var("DATABASE_URL").unwrap(); // set up connection pool let manager = deadpool_diesel::postgres::Manager::new(db_url, deadpool_diesel::Runtime::Tokio1); let pool = deadpool_diesel::postgres::Pool::builder(manager) .build() .unwrap(); // run the migrations on server startup { let conn = pool.get().await.unwrap(); conn.interact(|conn| conn.run_pending_migrations(MIGRATIONS).map(|_| ())) .await .unwrap() .unwrap(); } // build our application with some routes let app = Router::new() .route("/user/list", get(list_users)) .route("/user/create", post(create_user)) .with_state(pool); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn create_user( State(pool): State, Json(new_user): Json, ) -> Result, (StatusCode, String)> { let conn = pool.get().await.map_err(internal_error)?; let res = conn .interact(|conn| { diesel::insert_into(users::table) .values(new_user) .returning(User::as_returning()) .get_result(conn) }) .await .map_err(internal_error)? .map_err(internal_error)?; Ok(Json(res)) } async fn list_users( State(pool): State, ) -> Result>, (StatusCode, String)> { let conn = pool.get().await.map_err(internal_error)?; let res = conn .interact(|conn| users::table.select(User::as_select()).load(conn)) .await .map_err(internal_error)? 
.map_err(internal_error)?; Ok(Json(res)) } /// Utility function for mapping any error into a `500 Internal Server Error` /// response. fn internal_error(err: E) -> (StatusCode, String) where E: std::error::Error, { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } axum-axum-v0.6.20/examples/error-handling-and-dependency-injection/000077500000000000000000000000001446276533600253145ustar00rootroot00000000000000axum-axum-v0.6.20/examples/error-handling-and-dependency-injection/Cargo.toml000066400000000000000000000007301446276533600272440ustar00rootroot00000000000000[package] name = "example-error-handling-and-dependency-injection" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.0", features = ["v4", "serde"] } axum-axum-v0.6.20/examples/error-handling-and-dependency-injection/src/000077500000000000000000000000001446276533600261035ustar00rootroot00000000000000axum-axum-v0.6.20/examples/error-handling-and-dependency-injection/src/main.rs000066400000000000000000000103061446276533600273750ustar00rootroot00000000000000//! Example showing how to convert errors into responses and how one might do //! dependency injection using trait objects. //! //! Run with //! //! ```not_rust //! cargo run -p example-error-handling-and-dependency-injection //! 
``` use axum::{ async_trait, extract::{Path, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use serde_json::json; use std::{net::SocketAddr, sync::Arc}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use uuid::Uuid; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_error_handling_and_dependency_injection=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // Inject a `UserRepo` into our handlers via a trait object. This could be // the live implementation or just a mock for testing. let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo; // Build our application with some routes let app = Router::new() .route("/users/:id", get(users_show)) .route("/users", post(users_create)) .with_state(user_repo); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } /// Handler for `GET /users/:id`. /// /// Extracts the user repo from request extensions and calls it. `UserRepoError`s /// are automatically converted into `AppError` which implements `IntoResponse` /// so it can be returned from handlers directly. async fn users_show( Path(user_id): Path, State(user_repo): State, ) -> Result, AppError> { let user = user_repo.find(user_id).await?; Ok(user.into()) } /// Handler for `POST /users`. async fn users_create( State(user_repo): State, Json(params): Json, ) -> Result, AppError> { let user = user_repo.create(params).await?; Ok(user.into()) } /// Our app's top level error type. enum AppError { /// Something went wrong when calling the user repo. UserRepo(UserRepoError), } /// This makes it possible to use `?` to automatically convert a `UserRepoError` /// into an `AppError`. 
impl From for AppError { fn from(inner: UserRepoError) -> Self { AppError::UserRepo(inner) } } impl IntoResponse for AppError { fn into_response(self) -> Response { let (status, error_message) = match self { AppError::UserRepo(UserRepoError::NotFound) => { (StatusCode::NOT_FOUND, "User not found") } AppError::UserRepo(UserRepoError::InvalidUsername) => { (StatusCode::UNPROCESSABLE_ENTITY, "Invalid username") } }; let body = Json(json!({ "error": error_message, })); (status, body).into_response() } } /// Example implementation of `UserRepo`. struct ExampleUserRepo; #[async_trait] impl UserRepo for ExampleUserRepo { async fn find(&self, _user_id: Uuid) -> Result { unimplemented!() } async fn create(&self, _params: CreateUser) -> Result { unimplemented!() } } /// Type alias that makes it easier to extract `UserRepo` trait objects. type DynUserRepo = Arc; /// A trait that defines things a user repo might support. #[async_trait] trait UserRepo { /// Loop up a user by their id. async fn find(&self, user_id: Uuid) -> Result; /// Create a new user. async fn create(&self, params: CreateUser) -> Result; } #[derive(Debug, Serialize)] struct User { id: Uuid, username: String, } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct CreateUser { username: String, } /// Errors that can happen when using the user repo. 
#[derive(Debug)] enum UserRepoError { #[allow(dead_code)] NotFound, #[allow(dead_code)] InvalidUsername, } axum-axum-v0.6.20/examples/form/000077500000000000000000000000001446276533600165105ustar00rootroot00000000000000axum-axum-v0.6.20/examples/form/Cargo.toml000066400000000000000000000004721446276533600204430ustar00rootroot00000000000000[package] name = "example-form" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/form/src/000077500000000000000000000000001446276533600172775ustar00rootroot00000000000000axum-axum-v0.6.20/examples/form/src/main.rs000066400000000000000000000033221446276533600205710ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-form //! ``` use axum::{extract::Form, response::Html, routing::get, Router}; use serde::Deserialize; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_form=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with some routes let app = Router::new().route("/", get(show_form).post(accept_form)); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn show_form() -> Html<&'static str> { Html( r#"
"#, ) } #[derive(Deserialize, Debug)] #[allow(dead_code)] struct Input { name: String, email: String, } async fn accept_form(Form(input): Form) { dbg!(&input); } axum-axum-v0.6.20/examples/global-404-handler/000077500000000000000000000000001446276533600207255ustar00rootroot00000000000000axum-axum-v0.6.20/examples/global-404-handler/Cargo.toml000066400000000000000000000005061446276533600226560ustar00rootroot00000000000000[package] name = "example-global-404-handler" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/global-404-handler/src/000077500000000000000000000000001446276533600215145ustar00rootroot00000000000000axum-axum-v0.6.20/examples/global-404-handler/src/main.rs000066400000000000000000000022641446276533600230120ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-global-404-handler //! ``` use axum::{ http::StatusCode, response::{Html, IntoResponse}, routing::get, Router, }; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_global_404_handler=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with a route let app = Router::new().route("/", get(handler)); // add a fallback service for handling routes to unknown paths let app = app.fallback(handler_404); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler() -> Html<&'static str> { Html("

Hello, World!

") } async fn handler_404() -> impl IntoResponse { (StatusCode::NOT_FOUND, "nothing to see here") } axum-axum-v0.6.20/examples/graceful-shutdown/000077500000000000000000000000001446276533600212065ustar00rootroot00000000000000axum-axum-v0.6.20/examples/graceful-shutdown/Cargo.toml000066400000000000000000000003001446276533600231270ustar00rootroot00000000000000[package] name = "example-graceful-shutdown" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } axum-axum-v0.6.20/examples/graceful-shutdown/src/000077500000000000000000000000001446276533600217755ustar00rootroot00000000000000axum-axum-v0.6.20/examples/graceful-shutdown/src/main.rs000066400000000000000000000023761446276533600232770ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-graceful-shutdown //! kill or ctrl-c //! ``` use axum::{response::Html, routing::get, Router}; use std::net::SocketAddr; use tokio::signal; #[tokio::main] async fn main() { // build our application with a route let app = Router::new().route("/", get(handler)); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .with_graceful_shutdown(shutdown_signal()) .await .unwrap(); } async fn handler() -> Html<&'static str> { Html("

Hello, World!

") } async fn shutdown_signal() { let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } println!("signal received, starting graceful shutdown"); } axum-axum-v0.6.20/examples/handle-head-request/000077500000000000000000000000001446276533600213655ustar00rootroot00000000000000axum-axum-v0.6.20/examples/handle-head-request/Cargo.toml000066400000000000000000000004711446276533600233170ustar00rootroot00000000000000[package] name = "example-handle-head-request" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } [dev-dependencies] hyper = { version = "0.14", features = ["full"] } tower = { version = "0.4", features = ["util"] } axum-axum-v0.6.20/examples/handle-head-request/src/000077500000000000000000000000001446276533600221545ustar00rootroot00000000000000axum-axum-v0.6.20/examples/handle-head-request/src/main.rs000066400000000000000000000043311446276533600234470ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-handle-head-request //! ``` use axum::response::{IntoResponse, Response}; use axum::{http, routing::get, Router}; use std::net::SocketAddr; fn app() -> Router { Router::new().route("/get-head", get(get_head_handler)) } #[tokio::main] async fn main() { let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); axum::Server::bind(&addr) .serve(app().into_make_service()) .await .unwrap(); } // GET routes will also be called for HEAD requests but will have the response body removed. // You can handle the HEAD method explicitly by extracting `http::Method` from the request. 
async fn get_head_handler(method: http::Method) -> Response { // it usually only makes sense to special-case HEAD // if computing the body has some relevant cost if method == http::Method::HEAD { return ([("x-some-header", "header from HEAD")]).into_response(); } // then do some computing task in GET do_some_computing_task(); ([("x-some-header", "header from GET")], "body from GET").into_response() } fn do_some_computing_task() { // TODO } #[cfg(test)] mod tests { use super::*; use axum::body::Body; use axum::http::{Request, StatusCode}; use tower::ServiceExt; #[tokio::test] async fn test_get() { let app = app(); let response = app .oneshot(Request::get("/get-head").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.headers()["x-some-header"], "header from GET"); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); assert_eq!(&body[..], b"body from GET"); } #[tokio::test] async fn test_implicit_head() { let app = app(); let response = app .oneshot(Request::head("/get-head").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); assert_eq!(response.headers()["x-some-header"], "header from HEAD"); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); assert!(body.is_empty()); } } axum-axum-v0.6.20/examples/hello-world/000077500000000000000000000000001446276533600177755ustar00rootroot00000000000000axum-axum-v0.6.20/examples/hello-world/Cargo.toml000066400000000000000000000002721446276533600217260ustar00rootroot00000000000000[package] name = "example-hello-world" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } 
axum-axum-v0.6.20/examples/hello-world/src/000077500000000000000000000000001446276533600205645ustar00rootroot00000000000000axum-axum-v0.6.20/examples/hello-world/src/main.rs000066400000000000000000000011071446276533600220550ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-hello-world //! ``` use axum::{response::Html, routing::get, Router}; use std::net::SocketAddr; #[tokio::main] async fn main() { // build our application with a route let app = Router::new().route("/", get(handler)); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler() -> Html<&'static str> { Html("

Hello, World!

") } axum-axum-v0.6.20/examples/http-proxy/000077500000000000000000000000001446276533600177035ustar00rootroot00000000000000axum-axum-v0.6.20/examples/http-proxy/Cargo.toml000066400000000000000000000005601446276533600216340ustar00rootroot00000000000000[package] name = "example-http-proxy" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = { version = "0.14", features = ["full"] } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["make"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/http-proxy/src/000077500000000000000000000000001446276533600204725ustar00rootroot00000000000000axum-axum-v0.6.20/examples/http-proxy/src/main.rs000066400000000000000000000055461446276533600217760ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! $ cargo run -p example-http-proxy //! ``` //! //! In another terminal: //! //! ```not_rust //! $ curl -v -x "127.0.0.1:3000" https://tokio.rs //! ``` //! //! Example is based on use axum::{ body::{self, Body}, http::{Method, Request, StatusCode}, response::{IntoResponse, Response}, routing::get, Router, }; use hyper::upgrade::Upgraded; use std::net::SocketAddr; use tokio::net::TcpStream; use tower::{make::Shared, ServiceExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let router_svc = Router::new().route("/", get(|| async { "Hello, World!" 
})); let service = tower::service_fn(move |req: Request| { let router_svc = router_svc.clone(); async move { if req.method() == Method::CONNECT { proxy(req).await } else { router_svc.oneshot(req).await.map_err(|err| match err {}) } } }); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .http1_preserve_header_case(true) .http1_title_case_headers(true) .serve(Shared::new(service)) .await .unwrap(); } async fn proxy(req: Request) -> Result { tracing::trace!(?req); if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) { tokio::task::spawn(async move { match hyper::upgrade::on(req).await { Ok(upgraded) => { if let Err(e) = tunnel(upgraded, host_addr).await { tracing::warn!("server io error: {}", e); }; } Err(e) => tracing::warn!("upgrade error: {}", e), } }); Ok(Response::new(body::boxed(body::Empty::new()))) } else { tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri()); Ok(( StatusCode::BAD_REQUEST, "CONNECT must be to a socket address", ) .into_response()) } } async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> { let mut server = TcpStream::connect(addr).await?; let (from_client, from_server) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?; tracing::debug!( "client wrote {} bytes and received {} bytes", from_client, from_server ); Ok(()) } axum-axum-v0.6.20/examples/hyper-1-0/000077500000000000000000000000001446276533600171675ustar00rootroot00000000000000axum-axum-v0.6.20/examples/hyper-1-0/Cargo.toml000066400000000000000000000010371446276533600211200ustar00rootroot00000000000000[package] name = "example-hyper-1-0" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = { version = "1.0.0-rc.4", features = ["full"] } hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "f898015" } tokio = { version = "1.0", features = ["full"] } tower-http = { version = 
"0.4", features = ["trace"] } tower-hyper-http-body-compat = { version = "0.2", features = ["http1", "server"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/hyper-1-0/src/000077500000000000000000000000001446276533600177565ustar00rootroot00000000000000axum-axum-v0.6.20/examples/hyper-1-0/src/main.rs000066400000000000000000000036731446276533600212610ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-hyper-1-0 //! ``` use axum::{routing::get, Router}; use std::net::SocketAddr; use tokio::net::TcpListener; use tower_http::trace::TraceLayer; use tower_hyper_http_body_compat::{ HttpBody1ToHttpBody04, TowerService03HttpServiceAsHyper1HttpService, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // this is hyper 1.0 use hyper::{body::Incoming, server::conn::http1}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_hyper_1_0=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // you have to use `HttpBody1ToHttpBody04` as the second type parameter to `Router` let app: Router<_, HttpBody1ToHttpBody04> = Router::new() .route("/", get(|| async { "Hello, World!" })) // we can still add regular tower middleware .layer(TraceLayer::new_for_http()); // `Router` implements tower-service 0.3's `Service` trait. Convert that to something // that implements hyper 1.0's `Service` trait. 
let service = TowerService03HttpServiceAsHyper1HttpService::new(app); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let tcp_listener = TcpListener::bind(addr).await.unwrap(); tracing::debug!("listening on {addr}"); loop { let (tcp_stream, _) = tcp_listener.accept().await.unwrap(); let tcp_stream = hyper_util::rt::TokioIo::new(tcp_stream); let service = service.clone(); tokio::task::spawn(async move { if let Err(http_err) = http1::Builder::new() .keep_alive(true) .serve_connection(tcp_stream, service) .await { eprintln!("Error while serving HTTP connection: {http_err}"); } }); } } axum-axum-v0.6.20/examples/jwt/000077500000000000000000000000001446276533600163515ustar00rootroot00000000000000axum-axum-v0.6.20/examples/jwt/Cargo.toml000066400000000000000000000006331446276533600203030ustar00rootroot00000000000000[package] name = "example-jwt" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["headers"] } headers = "0.3" jsonwebtoken = "8.0" once_cell = "1.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/jwt/src/000077500000000000000000000000001446276533600171405ustar00rootroot00000000000000axum-axum-v0.6.20/examples/jwt/src/main.rs000066400000000000000000000131621446276533600204350ustar00rootroot00000000000000//! Example JWT authorization/authentication. //! //! Run with //! //! ```not_rust //! JWT_SECRET=secret cargo run -p example-jwt //! 
``` use axum::{ async_trait, extract::{FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::{get, post}, Json, RequestPartsExt, Router, }; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::json; use std::{fmt::Display, net::SocketAddr}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // Quick instructions // // - get an authorization token: // // curl -s \ // -w '\n' \ // -H 'Content-Type: application/json' \ // -d '{"client_id":"foo","client_secret":"bar"}' \ // http://localhost:3000/authorize // // - visit the protected area using the authorized token // // curl -s \ // -w '\n' \ // -H 'Content-Type: application/json' \ // -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \ // http://localhost:3000/protected // // - try to visit the protected area using an invalid token // // curl -s \ // -w '\n' \ // -H 'Content-Type: application/json' \ // -H 'Authorization: Bearer blahblahblah' \ // http://localhost:3000/protected static KEYS: Lazy = Lazy::new(|| { let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); Keys::new(secret.as_bytes()) }); #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_jwt=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let app = Router::new() .route("/protected", get(protected)) .route("/authorize", post(authorize)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn protected(claims: Claims) -> Result { // 
Send the protected data to the user Ok(format!( "Welcome to the protected area :)\nYour data:\n{}", claims )) } async fn authorize(Json(payload): Json) -> Result, AuthError> { // Check if the user sent the credentials if payload.client_id.is_empty() || payload.client_secret.is_empty() { return Err(AuthError::MissingCredentials); } // Here you can check the user credentials from a database if payload.client_id != "foo" || payload.client_secret != "bar" { return Err(AuthError::WrongCredentials); } let claims = Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), // Mandatory expiry time as UTC timestamp exp: 2000000000, // May 2033 }; // Create the authorization token let token = encode(&Header::default(), &claims, &KEYS.encoding) .map_err(|_| AuthError::TokenCreation)?; // Send the authorized token Ok(Json(AuthBody::new(token))) } impl Display for Claims { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Email: {}\nCompany: {}", self.sub, self.company) } } impl AuthBody { fn new(access_token: String) -> Self { Self { access_token, token_type: "Bearer".to_string(), } } } #[async_trait] impl FromRequestParts for Claims where S: Send + Sync, { type Rejection = AuthError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|_| AuthError::InvalidToken)?; // Decode the user data let token_data = decode::(bearer.token(), &KEYS.decoding, &Validation::default()) .map_err(|_| AuthError::InvalidToken)?; Ok(token_data.claims) } } impl IntoResponse for AuthError { fn into_response(self) -> Response { let (status, error_message) = match self { AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"), AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), 
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), }; let body = Json(json!({ "error": error_message, })); (status, body).into_response() } } struct Keys { encoding: EncodingKey, decoding: DecodingKey, } impl Keys { fn new(secret: &[u8]) -> Self { Self { encoding: EncodingKey::from_secret(secret), decoding: DecodingKey::from_secret(secret), } } } #[derive(Debug, Serialize, Deserialize)] struct Claims { sub: String, company: String, exp: usize, } #[derive(Debug, Serialize)] struct AuthBody { access_token: String, token_type: String, } #[derive(Debug, Deserialize)] struct AuthPayload { client_id: String, client_secret: String, } #[derive(Debug)] enum AuthError { WrongCredentials, MissingCredentials, TokenCreation, InvalidToken, } axum-axum-v0.6.20/examples/key-value-store/000077500000000000000000000000001446276533600206015ustar00rootroot00000000000000axum-axum-v0.6.20/examples/key-value-store/Cargo.toml000066400000000000000000000010001446276533600225200ustar00rootroot00000000000000[package] name = "example-key-value-store" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] } tower-http = { version = "0.4.0", features = [ "add-extension", "auth", "compression-full", "limit", "trace", ] } tower-layer = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/key-value-store/src/000077500000000000000000000000001446276533600213705ustar00rootroot00000000000000axum-axum-v0.6.20/examples/key-value-store/src/main.rs000066400000000000000000000104101446276533600226560ustar00rootroot00000000000000//! Simple in-memory key/value store showing features of axum. //! //! Run with: //! //! ```not_rust //! cargo run -p example-key-value-store //! 
``` use axum::{ body::Bytes, error_handling::HandleErrorLayer, extract::{DefaultBodyLimit, Path, State}, handler::Handler, http::StatusCode, response::IntoResponse, routing::{delete, get}, Router, }; use std::{ borrow::Cow, collections::HashMap, net::SocketAddr, sync::{Arc, RwLock}, time::Duration, }; use tower::{BoxError, ServiceBuilder}; use tower_http::{ compression::CompressionLayer, limit::RequestBodyLimitLayer, trace::TraceLayer, validate_request::ValidateRequestHeaderLayer, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_key_value_store=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let shared_state = SharedState::default(); // Build our application by composing routes let app = Router::new() .route( "/:key", // Add compression to `kv_get` get(kv_get.layer(CompressionLayer::new())) // But don't compress `kv_set` .post_service( kv_set .layer(( DefaultBodyLimit::disable(), RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */), )) .with_state(Arc::clone(&shared_state)), ), ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` .nest("/admin", admin_routes()) // Add middleware to all routes .layer( ServiceBuilder::new() // Handle errors from middleware .layer(HandleErrorLayer::new(handle_error)) .load_shed() .concurrency_limit(1024) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()), ) .with_state(Arc::clone(&shared_state)); // Run our app with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } type SharedState = Arc>; #[derive(Default)] struct AppState { db: HashMap, } async fn kv_get( Path(key): Path, State(state): State, ) -> Result { let db = &state.read().unwrap().db; if let 
Some(value) = db.get(&key) { Ok(value.clone()) } else { Err(StatusCode::NOT_FOUND) } } async fn kv_set(Path(key): Path, State(state): State, bytes: Bytes) { state.write().unwrap().db.insert(key, bytes); } async fn list_keys(State(state): State) -> String { let db = &state.read().unwrap().db; db.keys() .map(|key| key.to_string()) .collect::>() .join("\n") } fn admin_routes() -> Router { async fn delete_all_keys(State(state): State) { state.write().unwrap().db.clear(); } async fn remove_key(Path(key): Path, State(state): State) { state.write().unwrap().db.remove(&key); } Router::new() .route("/keys", delete(delete_all_keys)) .route("/key/:key", delete(remove_key)) // Require bearer auth for all admin routes .layer(ValidateRequestHeaderLayer::bearer("secret-token")) } async fn handle_error(error: BoxError) -> impl IntoResponse { if error.is::() { return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")); } if error.is::() { return ( StatusCode::SERVICE_UNAVAILABLE, Cow::from("service is overloaded, try again later"), ); } ( StatusCode::INTERNAL_SERVER_ERROR, Cow::from(format!("Unhandled internal error: {}", error)), ) } axum-axum-v0.6.20/examples/listen-multiple-addrs/000077500000000000000000000000001446276533600217675ustar00rootroot00000000000000axum-axum-v0.6.20/examples/listen-multiple-addrs/Cargo.toml000066400000000000000000000003541446276533600237210ustar00rootroot00000000000000[package] name = "listen-multiple-addrs" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = { version = "0.14", features = ["full"] } tokio = { version = "1", features = ["full"] } axum-axum-v0.6.20/examples/listen-multiple-addrs/src/000077500000000000000000000000001446276533600225565ustar00rootroot00000000000000axum-axum-v0.6.20/examples/listen-multiple-addrs/src/main.rs000066400000000000000000000034161446276533600240540ustar00rootroot00000000000000//! Showcases how listening on multiple addrs is possible by //! 
implementing Accept for a custom struct. //! //! This may be useful in cases where the platform does not //! listen on both IPv4 and IPv6 when the IPv6 catch-all listener is used (`::`), //! [like older versions of Windows.](https://docs.microsoft.com/en-us/windows/win32/winsock/dual-stack-sockets) use axum::{routing::get, Router}; use hyper::server::{accept::Accept, conn::AddrIncoming}; use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, task::{Context, Poll}, }; #[tokio::main] async fn main() { let app = Router::new().route("/", get(|| async { "Hello, World!" })); let localhost_v4 = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080); let incoming_v4 = AddrIncoming::bind(&localhost_v4).unwrap(); let localhost_v6 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 8080); let incoming_v6 = AddrIncoming::bind(&localhost_v6).unwrap(); let combined = CombinedIncoming { a: incoming_v4, b: incoming_v6, }; axum::Server::builder(combined) .serve(app.into_make_service()) .await .unwrap(); } struct CombinedIncoming { a: AddrIncoming, b: AddrIncoming, } impl Accept for CombinedIncoming { type Conn = ::Conn; type Error = ::Error; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { if let Poll::Ready(Some(value)) = Pin::new(&mut self.a).poll_accept(cx) { return Poll::Ready(Some(value)); } if let Poll::Ready(Some(value)) = Pin::new(&mut self.b).poll_accept(cx) { return Poll::Ready(Some(value)); } Poll::Pending } } axum-axum-v0.6.20/examples/low-level-openssl/000077500000000000000000000000001446276533600211345ustar00rootroot00000000000000axum-axum-v0.6.20/examples/low-level-openssl/Cargo.toml000066400000000000000000000007571446276533600230750ustar00rootroot00000000000000[package] name = "example-low-level-openssl" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } futures-util = { version = "0.3", default-features = false, features = ["alloc"] } hyper = { version = "0.14", features = ["full"] } openssl = 
"0.10" tokio = { version = "1", features = ["full"] } tokio-openssl = "0.6" tower = { version = "0.4", features = ["make"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/low-level-openssl/self_signed_certs/000077500000000000000000000000001446276533600246165ustar00rootroot00000000000000axum-axum-v0.6.20/examples/low-level-openssl/self_signed_certs/cert.pem000066400000000000000000000024261446276533600262620ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg +MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr AopvZ09uEQ== -----END CERTIFICATE----- axum-axum-v0.6.20/examples/low-level-openssl/self_signed_certs/key.pem000066400000000000000000000032501446276533600261110ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- 
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB 12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI +2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY 5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV 5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg 3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW w37pCS7ykL+7gp7V0WShYsw= -----END PRIVATE KEY----- axum-axum-v0.6.20/examples/low-level-openssl/src/000077500000000000000000000000001446276533600217235ustar00rootroot00000000000000axum-axum-v0.6.20/examples/low-level-openssl/src/main.rs000066400000000000000000000047541446276533600232270ustar00rootroot00000000000000use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, 
SslMethod}; use tokio_openssl::SslStream; use axum::{extract::ConnectInfo, routing::get, Router}; use futures_util::future::poll_fn; use hyper::server::{ accept::Accept, conn::{AddrIncoming, Http}, }; use std::{net::SocketAddr, path::PathBuf, pin::Pin, sync::Arc}; use tokio::net::TcpListener; use tower::MakeService; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_low_level_openssl=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()).unwrap(); tls_builder .set_certificate_file( PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("self_signed_certs") .join("cert.pem"), SslFiletype::PEM, ) .unwrap(); tls_builder .set_private_key_file( PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("self_signed_certs") .join("key.pem"), SslFiletype::PEM, ) .unwrap(); tls_builder.check_private_key().unwrap(); let acceptor = tls_builder.build(); let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); let mut listener = AddrIncoming::from_listener(listener).unwrap(); let protocol = Arc::new(Http::new()); let mut app = Router::new() .route("/", get(handler)) .into_make_service_with_connect_info::(); tracing::info!("listening on https://localhost:3000"); loop { let stream = poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)) .await .unwrap() .unwrap(); let acceptor = acceptor.clone(); let protocol = protocol.clone(); let svc = app.make_service(&stream); tokio::spawn(async move { let ssl = Ssl::new(acceptor.context()).unwrap(); let mut tls_stream = SslStream::new(ssl, stream).unwrap(); SslStream::accept(Pin::new(&mut tls_stream)).await.unwrap(); let _ = protocol .serve_connection(tls_stream, svc.await.unwrap()) .await; }); } } async fn handler(ConnectInfo(addr): ConnectInfo) -> String { addr.to_string() } 
axum-axum-v0.6.20/examples/low-level-rustls/000077500000000000000000000000001446276533600210055ustar00rootroot00000000000000axum-axum-v0.6.20/examples/low-level-rustls/Cargo.toml000066400000000000000000000007641446276533600227440ustar00rootroot00000000000000[package] name = "example-low-level-rustls" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } futures-util = { version = "0.3", default-features = false, features = ["alloc"] } hyper = { version = "0.14", features = ["full"] } rustls-pemfile = "0.3" tokio = { version = "1", features = ["full"] } tokio-rustls = "0.23" tower = { version = "0.4", features = ["make"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/low-level-rustls/self_signed_certs/000077500000000000000000000000001446276533600244675ustar00rootroot00000000000000axum-axum-v0.6.20/examples/low-level-rustls/self_signed_certs/cert.pem000066400000000000000000000024261446276533600261330ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W 
ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg +MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr AopvZ09uEQ== -----END CERTIFICATE----- axum-axum-v0.6.20/examples/low-level-rustls/self_signed_certs/key.pem000066400000000000000000000032501446276533600257620ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB 12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI +2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY 5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV 5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg 3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa 
Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW w37pCS7ykL+7gp7V0WShYsw= -----END PRIVATE KEY----- axum-axum-v0.6.20/examples/low-level-rustls/src/000077500000000000000000000000001446276533600215745ustar00rootroot00000000000000axum-axum-v0.6.20/examples/low-level-rustls/src/main.rs000066400000000000000000000055471446276533600231010ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-low-level-rustls //! ``` use axum::{extract::ConnectInfo, routing::get, Router}; use futures_util::future::poll_fn; use hyper::server::{ accept::Accept, conn::{AddrIncoming, Http}, }; use rustls_pemfile::{certs, pkcs8_private_keys}; use std::{ fs::File, io::BufReader, net::SocketAddr, path::{Path, PathBuf}, pin::Pin, sync::Arc, }; use tokio::net::TcpListener; use tokio_rustls::{ rustls::{Certificate, PrivateKey, ServerConfig}, TlsAcceptor, }; use tower::MakeService; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_tls_rustls=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let rustls_config = rustls_server_config( PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("self_signed_certs") .join("key.pem"), PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("self_signed_certs") .join("cert.pem"), ); let acceptor = TlsAcceptor::from(rustls_config); let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); let mut listener = AddrIncoming::from_listener(listener).unwrap(); let protocol = Arc::new(Http::new()); let mut app = Router::new() .route("/", get(handler)) .into_make_service_with_connect_info::(); loop { let stream = poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx)) .await .unwrap() .unwrap(); let acceptor = 
acceptor.clone(); let protocol = protocol.clone(); let svc = app.make_service(&stream); tokio::spawn(async move { if let Ok(stream) = acceptor.accept(stream).await { let _ = protocol.serve_connection(stream, svc.await.unwrap()).await; } }); } } async fn handler(ConnectInfo(addr): ConnectInfo) -> String { addr.to_string() } fn rustls_server_config(key: impl AsRef, cert: impl AsRef) -> Arc { let mut key_reader = BufReader::new(File::open(key).unwrap()); let mut cert_reader = BufReader::new(File::open(cert).unwrap()); let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0)); let certs = certs(&mut cert_reader) .unwrap() .into_iter() .map(Certificate) .collect(); let mut config = ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(certs, key) .expect("bad certificate/key"); config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Arc::new(config) } axum-axum-v0.6.20/examples/multipart-form/000077500000000000000000000000001446276533600205275ustar00rootroot00000000000000axum-axum-v0.6.20/examples/multipart-form/Cargo.toml000066400000000000000000000005551446276533600224640ustar00rootroot00000000000000[package] name = "example-multipart-form" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["multipart"] } tokio = { version = "1.0", features = ["full"] } tower-http = { version = "0.4.0", features = ["limit", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/multipart-form/src/000077500000000000000000000000001446276533600213165ustar00rootroot00000000000000axum-axum-v0.6.20/examples/multipart-form/src/main.rs000066400000000000000000000043651446276533600226200ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-multipart-form //! 
``` use axum::{ extract::{DefaultBodyLimit, Multipart}, response::Html, routing::get, Router, }; use std::net::SocketAddr; use tower_http::limit::RequestBodyLimitLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_multipart_form=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with some routes let app = Router::new() .route("/", get(show_form).post(accept_form)) .layer(DefaultBodyLimit::disable()) .layer(RequestBodyLimitLayer::new( 250 * 1024 * 1024, /* 250mb */ )) .layer(tower_http::trace::TraceLayer::new_for_http()); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn show_form() -> Html<&'static str> { Html( r#"
"#, ) } async fn accept_form(mut multipart: Multipart) { while let Some(field) = multipart.next_field().await.unwrap() { let name = field.name().unwrap().to_string(); let file_name = field.file_name().unwrap().to_string(); let content_type = field.content_type().unwrap().to_string(); let data = field.bytes().await.unwrap(); println!( "Length of `{}` (`{}`: `{}`) is {} bytes", name, file_name, content_type, data.len() ); } } axum-axum-v0.6.20/examples/oauth/000077500000000000000000000000001446276533600166655ustar00rootroot00000000000000axum-axum-v0.6.20/examples/oauth/Cargo.toml000066400000000000000000000010621446276533600206140ustar00rootroot00000000000000[package] name = "example-oauth" version = "0.1.0" edition = "2021" publish = false [dependencies] async-session = "3.0.0" axum = { path = "../../axum", features = ["headers"] } headers = "0.3" http = "0.2" oauth2 = "4.1" # Use Rustls because it makes it easier to cross-compile on CI reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/oauth/src/000077500000000000000000000000001446276533600174545ustar00rootroot00000000000000axum-axum-v0.6.20/examples/oauth/src/main.rs000066400000000000000000000177341446276533600207620ustar00rootroot00000000000000//! Example OAuth (Discord) implementation. //! //! 1) Create a new application at //! 2) Visit the OAuth2 tab to get your CLIENT_ID and CLIENT_SECRET //! 3) Add a new redirect URI (for this example: `http://127.0.0.1:3000/auth/authorized`) //! 4) Run with the following (replacing values appropriately): //! ```not_rust //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! 
``` use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ rejection::TypedHeaderRejectionReason, FromRef, FromRequestParts, Query, State, TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, routing::get, RequestPartsExt, Router, }; use http::{header, request::Parts}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; use std::{env, net::SocketAddr}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; static COOKIE_NAME: &str = "SESSION"; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_oauth=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // `MemoryStore` is just used as an example. Don't use this in production. 
let store = MemoryStore::new(); let oauth_client = oauth_client(); let app_state = AppState { store, oauth_client, }; let app = Router::new() .route("/", get(index)) .route("/auth/discord", get(discord_auth)) .route("/auth/authorized", get(login_authorized)) .route("/protected", get(protected)) .route("/logout", get(logout)) .with_state(app_state); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } #[derive(Clone)] struct AppState { store: MemoryStore, oauth_client: BasicClient, } impl FromRef for MemoryStore { fn from_ref(state: &AppState) -> Self { state.store.clone() } } impl FromRef for BasicClient { fn from_ref(state: &AppState) -> Self { state.oauth_client.clone() } } fn oauth_client() -> BasicClient { // Environment variables (* = required): // *"CLIENT_ID" "REPLACE_ME"; // *"CLIENT_SECRET" "REPLACE_ME"; // "REDIRECT_URL" "http://127.0.0.1:3000/auth/authorized"; // "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code"; // "TOKEN_URL" "https://discord.com/api/oauth2/token"; let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!"); let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!"); let redirect_url = env::var("REDIRECT_URL") .unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string()); let auth_url = env::var("AUTH_URL").unwrap_or_else(|_| { "https://discord.com/api/oauth2/authorize?response_type=code".to_string() }); let token_url = env::var("TOKEN_URL") .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string()); BasicClient::new( ClientId::new(client_id), Some(ClientSecret::new(client_secret)), AuthUrl::new(auth_url).unwrap(), Some(TokenUrl::new(token_url).unwrap()), ) .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) } // The user data we'll get back from Discord. 
// https://discord.com/developers/docs/resources/user#user-object-user-structure #[derive(Debug, Serialize, Deserialize)] struct User { id: String, avatar: Option, username: String, discriminator: String, } // Session is optional async fn index(user: Option) -> impl IntoResponse { match user { Some(u) => format!( "Hey {}! You're logged in!\nYou may now access `/protected`.\nLog out with `/logout`.", u.username ), None => "You're not logged in.\nVisit `/auth/discord` to do so.".to_string(), } } async fn discord_auth(State(client): State) -> impl IntoResponse { let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .url(); // Redirect to Discord's oauth service Redirect::to(auth_url.as_ref()) } // Valid user session required. If there is none, redirect to the auth page async fn protected(user: User) -> impl IntoResponse { format!( "Welcome to the protected area :)\nHere's your info:\n{:?}", user ) } async fn logout( State(store): State, TypedHeader(cookies): TypedHeader, ) -> impl IntoResponse { let cookie = cookies.get(COOKIE_NAME).unwrap(); let session = match store.load_session(cookie.to_string()).await.unwrap() { Some(s) => s, // No session active, just redirect None => return Redirect::to("/"), }; store.destroy_session(session).await.unwrap(); Redirect::to("/") } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct AuthRequest { code: String, state: String, } async fn login_authorized( Query(query): Query, State(store): State, State(oauth_client): State, ) -> impl IntoResponse { // Get an auth token let token = oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await .unwrap(); // Fetch user data from discord let client = reqwest::Client::new(); let user_data: User = client // https://discord.com/developers/docs/resources/user#get-current-user .get("https://discordapp.com/api/users/@me") .bearer_auth(token.access_token().secret()) 
.send() .await .unwrap() .json::() .await .unwrap(); // Create a new session filled with user data let mut session = Session::new(); session.insert("user", &user_data).unwrap(); // Store session and get corresponding cookie let cookie = store.store_session(session).await.unwrap().unwrap(); // Build the cookie let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); // Set cookie let mut headers = HeaderMap::new(); headers.insert(SET_COOKIE, cookie.parse().unwrap()); (headers, Redirect::to("/")) } struct AuthRedirect; impl IntoResponse for AuthRedirect { fn into_response(self) -> Response { Redirect::temporary("/auth/discord").into_response() } } #[async_trait] impl FromRequestParts for User where MemoryStore: FromRef, S: Send + Sync, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let store = MemoryStore::from_ref(state); let cookies = parts .extract::>() .await .map_err(|e| match *e.name() { header::COOKIE => match e.reason() { TypedHeaderRejectionReason::Missing => AuthRedirect, _ => panic!("unexpected error getting Cookie header(s): {}", e), }, _ => panic!("unexpected error getting cookies: {}", e), })?; let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?; let session = store .load_session(session_cookie.to_string()) .await .unwrap() .ok_or(AuthRedirect)?; let user = session.get::("user").ok_or(AuthRedirect)?; Ok(user) } } axum-axum-v0.6.20/examples/parse-body-based-on-content-type/000077500000000000000000000000001446276533600237275ustar00rootroot00000000000000axum-axum-v0.6.20/examples/parse-body-based-on-content-type/Cargo.toml000066400000000000000000000005261446276533600256620ustar00rootroot00000000000000[package] name = "example-parse-body-based-on-content-type" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = 
["derive"] } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/parse-body-based-on-content-type/src/000077500000000000000000000000001446276533600245165ustar00rootroot00000000000000axum-axum-v0.6.20/examples/parse-body-based-on-content-type/src/main.rs000066400000000000000000000044701446276533600260150ustar00rootroot00000000000000//! Provides a RESTful web server managing some Todos. //! //! Run with //! //! ```not_rust //! cargo run -p example-parse-body-based-on-content-type //! ``` use axum::{ async_trait, extract::FromRequest, http::{header::CONTENT_TYPE, Request, StatusCode}, response::{IntoResponse, Response}, routing::post, Form, Json, RequestExt, Router, }; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { "example_parse_body_based_on_content_type=debug,tower_http=debug".into() }), ) .with(tracing_subscriber::fmt::layer()) .init(); let app = Router::new().route("/", post(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } #[derive(Debug, Serialize, Deserialize)] struct Payload { foo: String, } async fn handler(JsonOrForm(payload): JsonOrForm) { dbg!(payload); } struct JsonOrForm(T); #[async_trait] impl FromRequest for JsonOrForm where B: Send + 'static, S: Send + Sync, Json: FromRequest<(), B>, Form: FromRequest<(), B>, T: 'static, { type Rejection = Response; async fn from_request(req: Request, _state: &S) -> Result { let content_type_header = req.headers().get(CONTENT_TYPE); let content_type = content_type_header.and_then(|value| value.to_str().ok()); if let Some(content_type) = content_type 
{ if content_type.starts_with("application/json") { let Json(payload) = req.extract().await.map_err(IntoResponse::into_response)?; return Ok(Self(payload)); } if content_type.starts_with("application/x-www-form-urlencoded") { let Form(payload) = req.extract().await.map_err(IntoResponse::into_response)?; return Ok(Self(payload)); } } Err(StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response()) } } axum-axum-v0.6.20/examples/print-request-response/000077500000000000000000000000001446276533600222235ustar00rootroot00000000000000axum-axum-v0.6.20/examples/print-request-response/Cargo.toml000066400000000000000000000006061446276533600241550ustar00rootroot00000000000000[package] name = "example-print-request-response" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = { version = "0.14", features = ["full"] } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util", "filter"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/print-request-response/src/000077500000000000000000000000001446276533600230125ustar00rootroot00000000000000axum-axum-v0.6.20/examples/print-request-response/src/main.rs000066400000000000000000000041601446276533600243050ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-print-request-response //! 
``` use axum::{ body::{Body, Bytes}, http::{Request, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, routing::post, Router, }; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_print_request_response=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let app = Router::new() .route("/", post(|| async move { "Hello from `POST /`" })) .layer(middleware::from_fn(print_request_response)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn print_request_response( req: Request, next: Next, ) -> Result { let (parts, body) = req.into_parts(); let bytes = buffer_and_print("request", body).await?; let req = Request::from_parts(parts, Body::from(bytes)); let res = next.run(req).await; let (parts, body) = res.into_parts(); let bytes = buffer_and_print("response", body).await?; let res = Response::from_parts(parts, Body::from(bytes)); Ok(res) } async fn buffer_and_print(direction: &str, body: B) -> Result where B: axum::body::HttpBody, B::Error: std::fmt::Display, { let bytes = match hyper::body::to_bytes(body).await { Ok(bytes) => bytes, Err(err) => { return Err(( StatusCode::BAD_REQUEST, format!("failed to read {} body: {}", direction, err), )); } }; if let Ok(body) = std::str::from_utf8(&bytes) { tracing::debug!("{} body = {:?}", direction, body); } Ok(bytes) } axum-axum-v0.6.20/examples/prometheus-metrics/000077500000000000000000000000001446276533600214045ustar00rootroot00000000000000axum-axum-v0.6.20/examples/prometheus-metrics/Cargo.toml000066400000000000000000000005121446276533600233320ustar00rootroot00000000000000[package] name = "example-prometheus-metrics" version = "0.1.0" edition 
= "2021" publish = false [dependencies] axum = { path = "../../axum" } metrics = "0.18" metrics-exporter-prometheus = "0.8" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/prometheus-metrics/src/000077500000000000000000000000001446276533600221735ustar00rootroot00000000000000axum-axum-v0.6.20/examples/prometheus-metrics/src/main.rs000066400000000000000000000066411446276533600234740ustar00rootroot00000000000000//! Someday tower-http will hopefully have a metrics middleware, until then you can track //! metrics like this. //! //! Run with //! //! ```not_rust //! cargo run -p example-prometheus-metrics //! ``` use axum::{ extract::MatchedPath, http::Request, middleware::{self, Next}, response::IntoResponse, routing::get, Router, }; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use std::{ future::ready, net::SocketAddr, time::{Duration, Instant}, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; fn metrics_app() -> Router { let recorder_handle = setup_metrics_recorder(); Router::new().route("/metrics", get(move || ready(recorder_handle.render()))) } fn main_app() -> Router { Router::new() .route("/fast", get(|| async {})) .route( "/slow", get(|| async { tokio::time::sleep(Duration::from_secs(1)).await; }), ) .route_layer(middleware::from_fn(track_metrics)) } async fn start_main_server() { let app = main_app(); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap() } async fn start_metrics_server() { let app = metrics_app(); // NOTE: expose metrics enpoint on a different port let addr = SocketAddr::from(([127, 0, 0, 1], 3001)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap() } #[tokio::main] async fn main() { 
tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_todos=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // The `/metrics` endpoint should not be publicly available. If behind a reverse proxy, this // can be achieved by rejecting requests to `/metrics`. In this example, a second server is // started on another port to expose `/metrics`. let (_main_server, _metrics_server) = tokio::join!(start_main_server(), start_metrics_server()); } fn setup_metrics_recorder() -> PrometheusHandle { const EXPONENTIAL_SECONDS: &[f64] = &[ 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, ]; PrometheusBuilder::new() .set_buckets_for_metric( Matcher::Full("http_requests_duration_seconds".to_string()), EXPONENTIAL_SECONDS, ) .unwrap() .install_recorder() .unwrap() } async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { let start = Instant::now(); let path = if let Some(matched_path) = req.extensions().get::() { matched_path.as_str().to_owned() } else { req.uri().path().to_owned() }; let method = req.method().clone(); let response = next.run(req).await; let latency = start.elapsed().as_secs_f64(); let status = response.status().as_u16().to_string(); let labels = [ ("method", method.to_string()), ("path", path), ("status", status), ]; metrics::increment_counter!("http_requests_total", &labels); metrics::histogram!("http_requests_duration_seconds", latency, &labels); response } axum-axum-v0.6.20/examples/query-params-with-empty-strings/000077500000000000000000000000001446276533600237675ustar00rootroot00000000000000axum-axum-v0.6.20/examples/query-params-with-empty-strings/Cargo.toml000066400000000000000000000005011446276533600257130ustar00rootroot00000000000000[package] name = "example-query-params-with-empty-strings" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = "0.14" serde = { version = "1.0", 
features = ["derive"] } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util"] } axum-axum-v0.6.20/examples/query-params-with-empty-strings/src/000077500000000000000000000000001446276533600245565ustar00rootroot00000000000000axum-axum-v0.6.20/examples/query-params-with-empty-strings/src/main.rs000066400000000000000000000062541446276533600260570ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-query-params-with-empty-strings //! ``` use axum::{extract::Query, routing::get, Router}; use serde::{de, Deserialize, Deserializer}; use std::{fmt, str::FromStr}; #[tokio::main] async fn main() { axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) .serve(app().into_make_service()) .await .unwrap(); } fn app() -> Router { Router::new().route("/", get(handler)) } async fn handler(Query(params): Query) -> String { format!("{:?}", params) } /// See the tests below for which combinations of `foo` and `bar` result in /// which deserializations. /// /// This example only shows one possible way to do this. [`serde_with`] provides /// another way. Use which ever method works best for you. 
/// /// [`serde_with`]: https://docs.rs/serde_with/1.11.0/serde_with/rust/string_empty_as_none/index.html #[derive(Debug, Deserialize)] #[allow(dead_code)] struct Params { #[serde(default, deserialize_with = "empty_string_as_none")] foo: Option, bar: Option, } /// Serde deserialization decorator to map empty Strings to None, fn empty_string_as_none<'de, D, T>(de: D) -> Result, D::Error> where D: Deserializer<'de>, T: FromStr, T::Err: fmt::Display, { let opt = Option::::deserialize(de)?; match opt.as_deref() { None | Some("") => Ok(None), Some(s) => FromStr::from_str(s).map_err(de::Error::custom).map(Some), } } #[cfg(test)] mod tests { use super::*; use axum::{body::Body, http::Request}; use tower::ServiceExt; #[tokio::test] async fn test_something() { assert_eq!( send_request_get_body("foo=1&bar=bar").await, r#"Params { foo: Some(1), bar: Some("bar") }"#, ); assert_eq!( send_request_get_body("foo=&bar=bar").await, r#"Params { foo: None, bar: Some("bar") }"#, ); assert_eq!( send_request_get_body("foo=&bar=").await, r#"Params { foo: None, bar: Some("") }"#, ); assert_eq!( send_request_get_body("foo=1").await, r#"Params { foo: Some(1), bar: None }"#, ); assert_eq!( send_request_get_body("bar=bar").await, r#"Params { foo: None, bar: Some("bar") }"#, ); assert_eq!( send_request_get_body("foo=").await, r#"Params { foo: None, bar: None }"#, ); assert_eq!( send_request_get_body("bar=").await, r#"Params { foo: None, bar: Some("") }"#, ); assert_eq!( send_request_get_body("").await, r#"Params { foo: None, bar: None }"#, ); } async fn send_request_get_body(query: &str) -> String { let body = app() .oneshot( Request::builder() .uri(format!("/?{}", query)) .body(Body::empty()) .unwrap(), ) .await .unwrap() .into_body(); let bytes = hyper::body::to_bytes(body).await.unwrap(); String::from_utf8(bytes.to_vec()).unwrap() } } 
axum-axum-v0.6.20/examples/readme/000077500000000000000000000000001446276533600170025ustar00rootroot00000000000000axum-axum-v0.6.20/examples/readme/Cargo.toml000066400000000000000000000005221446276533600207310ustar00rootroot00000000000000[package] name = "example-readme" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.68" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/readme/src/000077500000000000000000000000001446276533600175715ustar00rootroot00000000000000axum-axum-v0.6.20/examples/readme/src/main.rs000066400000000000000000000031551446276533600210670ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-readme //! ``` use axum::{ http::StatusCode, response::IntoResponse, routing::{get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; #[tokio::main] async fn main() { // initialize tracing tracing_subscriber::fmt::init(); // build our application with a route let app = Router::new() // `GET /` goes to `root` .route("/", get(root)) // `POST /users` goes to `create_user` .route("/users", post(create_user)); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // basic handler that responds with a static string async fn root() -> &'static str { "Hello, World!" 
} async fn create_user( // this argument tells axum to parse the request body // as JSON into a `CreateUser` type Json(payload): Json, ) -> impl IntoResponse { // insert your application logic here let user = User { id: 1337, username: payload.username, }; // this will be converted into a JSON response // with a status code of `201 Created` (StatusCode::CREATED, Json(user)) } // the input to our `create_user` handler #[derive(Deserialize)] struct CreateUser { username: String, } // the output to our `create_user` handler #[derive(Serialize)] struct User { id: u64, username: String, } axum-axum-v0.6.20/examples/rest-grpc-multiplex/000077500000000000000000000000001446276533600214745ustar00rootroot00000000000000axum-axum-v0.6.20/examples/rest-grpc-multiplex/Cargo.toml000066400000000000000000000010311446276533600234170ustar00rootroot00000000000000[package] name = "example-rest-grpc-multiplex" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } futures = "0.3" hyper = { version = "0.14", features = ["full"] } prost = "0.11" tokio = { version = "1", features = ["full"] } tonic = { version = "0.9" } tonic-reflection = "0.9" tower = { version = "0.4", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [build-dependencies] tonic-build = { version = "0.9", features = ["prost"] } axum-axum-v0.6.20/examples/rest-grpc-multiplex/build.rs000066400000000000000000000004461446276533600231450ustar00rootroot00000000000000use std::{env, path::PathBuf}; fn main() { let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); tonic_build::configure() .file_descriptor_set_path(out_dir.join("helloworld_descriptor.bin")) .compile(&["proto/helloworld.proto"], &["/proto"]) .unwrap(); } 
axum-axum-v0.6.20/examples/rest-grpc-multiplex/proto/000077500000000000000000000000001446276533600226375ustar00rootroot00000000000000axum-axum-v0.6.20/examples/rest-grpc-multiplex/proto/helloworld.proto000066400000000000000000000021051446276533600260750ustar00rootroot00000000000000// Copyright 2015 gRPC authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto3"; option java_multiple_files = true; option java_package = "io.grpc.examples.helloworld"; option java_outer_classname = "HelloWorldProto"; package helloworld; // The greeting service definition. service Greeter { // Sends a greeting rpc SayHello (HelloRequest) returns (HelloReply) {} } // The request message containing the user's name. message HelloRequest { string name = 1; } // The response message containing the greetings message HelloReply { string message = 1; } axum-axum-v0.6.20/examples/rest-grpc-multiplex/src/000077500000000000000000000000001446276533600222635ustar00rootroot00000000000000axum-axum-v0.6.20/examples/rest-grpc-multiplex/src/main.rs000066400000000000000000000043511446276533600235600ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-rest-grpc-multiplex //! 
``` use self::multiplex_service::MultiplexService; use axum::{routing::get, Router}; use proto::{ greeter_server::{Greeter, GreeterServer}, HelloReply, HelloRequest, }; use std::net::SocketAddr; use tonic::{Response as TonicResponse, Status}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod multiplex_service; mod proto { tonic::include_proto!("helloworld"); pub(crate) const FILE_DESCRIPTOR_SET: &[u8] = tonic::include_file_descriptor_set!("helloworld_descriptor"); } #[derive(Default)] struct GrpcServiceImpl {} #[tonic::async_trait] impl Greeter for GrpcServiceImpl { async fn say_hello( &self, request: tonic::Request, ) -> Result, Status> { tracing::info!("Got a request from {:?}", request.remote_addr()); let reply = HelloReply { message: format!("Hello {}!", request.into_inner().name), }; Ok(TonicResponse::new(reply)) } } async fn web_root() -> &'static str { "Hello, World!" } #[tokio::main] async fn main() { // initialize tracing tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_rest_grpc_multiplex=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build the rest service let rest = Router::new().route("/", get(web_root)); // build the grpc service let reflection_service = tonic_reflection::server::Builder::configure() .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) .build() .unwrap(); let grpc = tonic::transport::Server::builder() .add_service(reflection_service) .add_service(GreeterServer::new(GrpcServiceImpl::default())) .into_service(); // combine them into one service let service = MultiplexService::new(rest, grpc); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(tower::make::Shared::new(service)) .await .unwrap(); } 
axum-axum-v0.6.20/examples/rest-grpc-multiplex/src/multiplex_service.rs000066400000000000000000000065321446276533600264020ustar00rootroot00000000000000use axum::{body::BoxBody, http::header::CONTENT_TYPE, response::IntoResponse}; use futures::{future::BoxFuture, ready}; use hyper::{Body, Request, Response}; use std::{ convert::Infallible, task::{Context, Poll}, }; use tower::Service; pub struct MultiplexService { rest: A, rest_ready: bool, grpc: B, grpc_ready: bool, } impl MultiplexService { pub fn new(rest: A, grpc: B) -> Self { Self { rest, rest_ready: false, grpc, grpc_ready: false, } } } impl Clone for MultiplexService where A: Clone, B: Clone, { fn clone(&self) -> Self { Self { rest: self.rest.clone(), grpc: self.grpc.clone(), // the cloned services probably wont be ready rest_ready: false, grpc_ready: false, } } } impl Service> for MultiplexService where A: Service, Error = Infallible>, A::Response: IntoResponse, A::Future: Send + 'static, B: Service>, B::Response: IntoResponse, B::Future: Send + 'static, { type Response = Response; type Error = B::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { // drive readiness for each inner service and record which is ready loop { match (self.rest_ready, self.grpc_ready) { (true, true) => { return Ok(()).into(); } (false, _) => { ready!(self.rest.poll_ready(cx)).map_err(|err| match err {})?; self.rest_ready = true; } (_, false) => { ready!(self.grpc.poll_ready(cx))?; self.grpc_ready = true; } } } } fn call(&mut self, req: Request) -> Self::Future { // require users to call `poll_ready` first, if they don't we're allowed to panic // as per the `tower::Service` contract assert!( self.grpc_ready, "grpc service not ready. Did you forget to call `poll_ready`?" ); assert!( self.rest_ready, "rest service not ready. Did you forget to call `poll_ready`?" 
); // if we get a grpc request call the grpc service, otherwise call the rest service // when calling a service it becomes not-ready so we have drive readiness again if is_grpc_request(&req) { self.grpc_ready = false; let future = self.grpc.call(req); Box::pin(async move { let res = future.await?; Ok(res.into_response()) }) } else { self.rest_ready = false; let future = self.rest.call(req); Box::pin(async move { let res = future.await.map_err(|err| match err {})?; Ok(res.into_response()) }) } } } fn is_grpc_request(req: &Request) -> bool { req.headers() .get(CONTENT_TYPE) .map(|content_type| content_type.as_bytes()) .filter(|content_type| content_type.starts_with(b"application/grpc")) .is_some() } axum-axum-v0.6.20/examples/reverse-proxy/000077500000000000000000000000001446276533600203775ustar00rootroot00000000000000axum-axum-v0.6.20/examples/reverse-proxy/Cargo.toml000066400000000000000000000003341446276533600223270ustar00rootroot00000000000000[package] name = "example-reverse-proxy" version = "0.1.0" edition = "2021" [dependencies] axum = { path = "../../axum" } hyper = { version = "0.14", features = ["full"] } tokio = { version = "1", features = ["full"] } axum-axum-v0.6.20/examples/reverse-proxy/src/000077500000000000000000000000001446276533600211665ustar00rootroot00000000000000axum-axum-v0.6.20/examples/reverse-proxy/src/main.rs000066400000000000000000000027561446276533600224720ustar00rootroot00000000000000//! Reverse proxy listening in "localhost:4000" will proxy all requests to "localhost:3000" //! endpoint. //! //! Run with //! //! ```not_rust //! cargo run -p example-reverse-proxy //! 
``` use axum::{ extract::State, http::{uri::Uri, Request, Response}, routing::get, Router, }; use hyper::{client::HttpConnector, Body}; use std::net::SocketAddr; type Client = hyper::client::Client; #[tokio::main] async fn main() { tokio::spawn(server()); let client = Client::new(); let app = Router::new().route("/", get(handler)).with_state(client); let addr = SocketAddr::from(([127, 0, 0, 1], 4000)); println!("reverse proxy listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler(State(client): State, mut req: Request) -> Response { let path = req.uri().path(); let path_query = req .uri() .path_and_query() .map(|v| v.as_str()) .unwrap_or(path); let uri = format!("http://127.0.0.1:3000{}", path_query); *req.uri_mut() = Uri::try_from(uri).unwrap(); client.request(req).await.unwrap() } async fn server() { let app = Router::new().route("/", get(|| async { "Hello, world!" })); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("server listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } axum-axum-v0.6.20/examples/routes-and-handlers-close-together/000077500000000000000000000000001446276533600243465ustar00rootroot00000000000000axum-axum-v0.6.20/examples/routes-and-handlers-close-together/Cargo.toml000066400000000000000000000003211446276533600262720ustar00rootroot00000000000000[package] name = "example-routes-and-handlers-close-together" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } axum-axum-v0.6.20/examples/routes-and-handlers-close-together/src/000077500000000000000000000000001446276533600251355ustar00rootroot00000000000000axum-axum-v0.6.20/examples/routes-and-handlers-close-together/src/main.rs000066400000000000000000000020361446276533600264300ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! 
cargo run -p example-routes-and-handlers-close-together //! ``` use axum::{ routing::{get, post, MethodRouter}, Router, }; use std::net::SocketAddr; #[tokio::main] async fn main() { let app = Router::new() .merge(root()) .merge(get_foo()) .merge(post_foo()); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } fn root() -> Router { async fn handler() -> &'static str { "Hello, World!" } route("/", get(handler)) } fn get_foo() -> Router { async fn handler() -> &'static str { "Hi from `GET /foo`" } route("/foo", get(handler)) } fn post_foo() -> Router { async fn handler() -> &'static str { "Hi from `POST /foo`" } route("/foo", post(handler)) } fn route(path: &str, method_router: MethodRouter<()>) -> Router { Router::new().route(path, method_router) } axum-axum-v0.6.20/examples/simple-router-wasm/000077500000000000000000000000001446276533600213215ustar00rootroot00000000000000axum-axum-v0.6.20/examples/simple-router-wasm/Cargo.toml000066400000000000000000000006351446276533600232550ustar00rootroot00000000000000[package] name = "example-simple-router-wasm" version = "0.1.0" edition = "2018" publish = false [dependencies] # `default-features = false` to not depend on tokio features which don't support wasm # you can still pull in tokio manually and only add features that tokio supports for wasm axum = { path = "../../axum", default-features = false } futures-executor = "0.3.21" http = "0.2.7" tower-service = "0.3.1" axum-axum-v0.6.20/examples/simple-router-wasm/src/000077500000000000000000000000001446276533600221105ustar00rootroot00000000000000axum-axum-v0.6.20/examples/simple-router-wasm/src/main.rs000066400000000000000000000032201446276533600233770ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-simple-router-wasm //! ``` //! //! This example shows what using axum in a wasm context might look like. This example should //! 
always compile with `--target wasm32-unknown-unknown`. //! //! [`mio`](https://docs.rs/mio/latest/mio/index.html), tokio's IO layer, does not support the //! `wasm32-unknown-unknown` target which is why this crate requires `default-features = false` //! for axum. //! //! Most serverless runtimes expect an exported function that takes in a single request and returns //! a single response, much like axum's `Handler` trait. In this example, the handler function is //! `app` with `main` acting as the serverless runtime which originally receives the request and //! calls the app function. //! //! We can use axum's routing, extractors, tower services, and everything else to implement //! our serverless function, even though we are running axum in a wasm context. use axum::{ response::{Html, Response}, routing::get, Router, }; use futures_executor::block_on; use http::Request; use tower_service::Service; fn main() { let request: Request = Request::builder() .uri("https://serverless.example/api/") .body("Some Body Data".into()) .unwrap(); let response: Response = block_on(app(request)); assert_eq!(200, response.status()); } #[allow(clippy::let_and_return)] async fn app(request: Request) -> Response { let mut router = Router::new().route("/api/", get(index)); let response = router.call(request).await.unwrap(); response } async fn index() -> Html<&'static str> { Html("

Hello, World!

") } axum-axum-v0.6.20/examples/sqlx-postgres/000077500000000000000000000000001446276533600204005ustar00rootroot00000000000000axum-axum-v0.6.20/examples/sqlx-postgres/Cargo.toml000066400000000000000000000005431446276533600223320ustar00rootroot00000000000000[package] name = "example-sqlx-postgres" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "any", "postgres"] } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/sqlx-postgres/src/000077500000000000000000000000001446276533600211675ustar00rootroot00000000000000axum-axum-v0.6.20/examples/sqlx-postgres/src/main.rs000066400000000000000000000061151446276533600224640ustar00rootroot00000000000000//! Example of application using //! //! Run with //! //! ```not_rust //! cargo run -p example-sqlx-postgres //! ``` //! //! Test with curl: //! //! ```not_rust //! curl 127.0.0.1:3000 //! curl -X POST 127.0.0.1:3000 //! 
``` use axum::{ async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, Router, }; use sqlx::postgres::{PgPool, PgPoolOptions}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use std::{net::SocketAddr, time::Duration}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let db_connection_str = std::env::var("DATABASE_URL") .unwrap_or_else(|_| "postgres://postgres:password@localhost".to_string()); // setup connection pool let pool = PgPoolOptions::new() .max_connections(5) .acquire_timeout(Duration::from_secs(3)) .connect(&db_connection_str) .await .expect("can't connect to database"); // build our application with some routes let app = Router::new() .route( "/", get(using_connection_pool_extractor).post(using_connection_extractor), ) .with_state(pool); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // we can extract the connection pool with `State` async fn using_connection_pool_extractor( State(pool): State, ) -> Result { sqlx::query_scalar("select 'hello world from pg'") .fetch_one(&pool) .await .map_err(internal_error) } // we can also write a custom extractor that grabs a connection from the pool // which setup is appropriate depends on your application struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] impl FromRequestParts for DatabaseConnection where PgPool: FromRef, S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { let pool = PgPool::from_ref(state); let conn = pool.acquire().await.map_err(internal_error)?; Ok(Self(conn)) } } async fn 
using_connection_extractor( DatabaseConnection(mut conn): DatabaseConnection, ) -> Result { sqlx::query_scalar("select 'hello world from pg'") .fetch_one(&mut *conn) .await .map_err(internal_error) } /// Utility function for mapping any error into a `500 Internal Server Error` /// response. fn internal_error(err: E) -> (StatusCode, String) where E: std::error::Error, { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } axum-axum-v0.6.20/examples/sse/000077500000000000000000000000001446276533600163375ustar00rootroot00000000000000axum-axum-v0.6.20/examples/sse/Cargo.toml000066400000000000000000000006221446276533600202670ustar00rootroot00000000000000[package] name = "example-sse" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["headers"] } futures = "0.3" headers = "0.3" tokio = { version = "1.0", features = ["full"] } tokio-stream = "0.1" tower-http = { version = "0.4.0", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/sse/assets/000077500000000000000000000000001446276533600176415ustar00rootroot00000000000000axum-axum-v0.6.20/examples/sse/assets/index.html000066400000000000000000000000421446276533600216320ustar00rootroot00000000000000 axum-axum-v0.6.20/examples/sse/assets/script.js000066400000000000000000000002141446276533600215000ustar00rootroot00000000000000var eventSource = new EventSource('sse'); eventSource.onmessage = function(event) { console.log('Message from server ', event.data); } axum-axum-v0.6.20/examples/sse/src/000077500000000000000000000000001446276533600171265ustar00rootroot00000000000000axum-axum-v0.6.20/examples/sse/src/main.rs000066400000000000000000000036011446276533600204200ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-sse //! 
``` use axum::{ extract::TypedHeader, response::sse::{Event, Sse}, routing::get, Router, }; use futures::stream::{self, Stream}; use std::{convert::Infallible, net::SocketAddr, path::PathBuf, time::Duration}; use tokio_stream::StreamExt as _; use tower_http::{services::ServeDir, trace::TraceLayer}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_sse=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true); // build our application with a route let app = Router::new() .fallback_service(static_files_service) .route("/sse", get(sse_handler)) .layer(TraceLayer::new_for_http()); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn sse_handler( TypedHeader(user_agent): TypedHeader, ) -> Sse>> { println!("`{}` connected", user_agent.as_str()); // A `Stream` that repeats an event every second let stream = stream::repeat_with(|| Event::default().data("hi!")) .map(Ok) .throttle(Duration::from_secs(1)); Sse::new(stream).keep_alive( axum::response::sse::KeepAlive::new() .interval(Duration::from_secs(1)) .text("keep-alive-text"), ) } axum-axum-v0.6.20/examples/static-file-server/000077500000000000000000000000001446276533600212555ustar00rootroot00000000000000axum-axum-v0.6.20/examples/static-file-server/Cargo.toml000066400000000000000000000006601446276533600232070ustar00rootroot00000000000000[package] name = "example-static-file-server" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } axum-extra = { path = "../../axum-extra" } 
tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.4.0", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/static-file-server/assets/000077500000000000000000000000001446276533600225575ustar00rootroot00000000000000axum-axum-v0.6.20/examples/static-file-server/assets/index.html000066400000000000000000000000231446276533600245470ustar00rootroot00000000000000Hi from index.html axum-axum-v0.6.20/examples/static-file-server/assets/script.js000066400000000000000000000000361446276533600244200ustar00rootroot00000000000000console.log("Hello, World!"); axum-axum-v0.6.20/examples/static-file-server/src/000077500000000000000000000000001446276533600220445ustar00rootroot00000000000000axum-axum-v0.6.20/examples/static-file-server/src/main.rs000066400000000000000000000070651446276533600233460ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-static-file-server //! 
``` use axum::{ body::Body, handler::HandlerWithoutStateExt, http::{Request, StatusCode}, routing::get, Router, }; use std::net::SocketAddr; use tower::ServiceExt; use tower_http::{ services::{ServeDir, ServeFile}, trace::TraceLayer, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_static_file_server=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); tokio::join!( serve(using_serve_dir(), 3001), serve(using_serve_dir_with_assets_fallback(), 3002), serve(using_serve_dir_only_from_root_via_fallback(), 3003), serve(using_serve_dir_with_handler_as_service(), 3004), serve(two_serve_dirs(), 3005), serve(calling_serve_dir_from_a_handler(), 3006), ); } fn using_serve_dir() -> Router { // serve the file in the "assets" directory under `/assets` Router::new().nest_service("/assets", ServeDir::new("assets")) } fn using_serve_dir_with_assets_fallback() -> Router { // `ServeDir` allows setting a fallback if an asset is not found // so with this `GET /assets/doesnt-exist.jpg` will return `index.html` // rather than a 404 let serve_dir = ServeDir::new("assets").not_found_service(ServeFile::new("assets/index.html")); Router::new() .route("/foo", get(|| async { "Hi from /foo" })) .nest_service("/assets", serve_dir.clone()) .fallback_service(serve_dir) } fn using_serve_dir_only_from_root_via_fallback() -> Router { // you can also serve the assets directly from the root (not nested under `/assets`) // by only setting a `ServeDir` as the fallback let serve_dir = ServeDir::new("assets").not_found_service(ServeFile::new("assets/index.html")); Router::new() .route("/foo", get(|| async { "Hi from /foo" })) .fallback_service(serve_dir) } fn using_serve_dir_with_handler_as_service() -> Router { async fn handle_404() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not 
found") } let serve_dir = ServeDir::new("assets").not_found_service(handle_404.into_service()); Router::new() .route("/foo", get(|| async { "Hi from /foo" })) .fallback_service(serve_dir) } fn two_serve_dirs() -> Router { // you can also have two `ServeDir`s nested at different paths let serve_dir_from_assets = ServeDir::new("assets"); let serve_dir_from_dist = ServeDir::new("dist"); Router::new() .nest_service("/assets", serve_dir_from_assets) .nest_service("/dist", serve_dir_from_dist) } #[allow(clippy::let_and_return)] fn calling_serve_dir_from_a_handler() -> Router { // via `tower::Service::call`, or more conveniently `tower::ServiceExt::oneshot` you can // call `ServeDir` yourself from a handler Router::new().nest_service( "/foo", get(|request: Request| async { let service = ServeDir::new("assets"); let result = service.oneshot(request).await; result }), ) } async fn serve(app: Router, port: u16) { let addr = SocketAddr::from(([127, 0, 0, 1], port)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.layer(TraceLayer::new_for_http()).into_make_service()) .await .unwrap(); } axum-axum-v0.6.20/examples/stream-to-file/000077500000000000000000000000001446276533600203755ustar00rootroot00000000000000axum-axum-v0.6.20/examples/stream-to-file/Cargo.toml000066400000000000000000000005571446276533600223340ustar00rootroot00000000000000[package] name = "example-stream-to-file" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["multipart"] } futures = "0.3" tokio = { version = "1.0", features = ["full"] } tokio-util = { version = "0.7", features = ["io"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } 
axum-axum-v0.6.20/examples/stream-to-file/src/000077500000000000000000000000001446276533600211645ustar00rootroot00000000000000axum-axum-v0.6.20/examples/stream-to-file/src/main.rs000066400000000000000000000103661446276533600224640ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-stream-to-file //! ``` use axum::{ body::Bytes, extract::{BodyStream, Multipart, Path}, http::StatusCode, response::{Html, Redirect}, routing::{get, post}, BoxError, Router, }; use futures::{Stream, TryStreamExt}; use std::{io, net::SocketAddr}; use tokio::{fs::File, io::BufWriter}; use tokio_util::io::StreamReader; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; const UPLOADS_DIRECTORY: &str = "uploads"; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_stream_to_file=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // save files to a separate directory to not override files in the current directory tokio::fs::create_dir(UPLOADS_DIRECTORY) .await .expect("failed to create `uploads` directory"); let app = Router::new() .route("/", get(show_form).post(accept_form)) .route("/file/:file_name", post(save_request_body)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // Handler that streams the request body to a file. // // POST'ing to `/file/foo.txt` will create a file called `foo.txt`. async fn save_request_body( Path(file_name): Path, body: BodyStream, ) -> Result<(), (StatusCode, String)> { stream_to_file(&file_name, body).await } // Handler that returns HTML for a multipart form. async fn show_form() -> Html<&'static str> { Html( r#" Upload something!
"#, ) } // Handler that accepts a multipart form upload and streams each field to a file. async fn accept_form(mut multipart: Multipart) -> Result { while let Some(field) = multipart.next_field().await.unwrap() { let file_name = if let Some(file_name) = field.file_name() { file_name.to_owned() } else { continue; }; stream_to_file(&file_name, field).await?; } Ok(Redirect::to("/")) } // Save a `Stream` to a file async fn stream_to_file(path: &str, stream: S) -> Result<(), (StatusCode, String)> where S: Stream>, E: Into, { if !path_is_valid(path) { return Err((StatusCode::BAD_REQUEST, "Invalid path".to_owned())); } async { // Convert the stream into an `AsyncRead`. let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err)); let body_reader = StreamReader::new(body_with_io_error); futures::pin_mut!(body_reader); // Create the file. `File` implements `AsyncWrite`. let path = std::path::Path::new(UPLOADS_DIRECTORY).join(path); let mut file = BufWriter::new(File::create(path).await?); // Copy the body into the file. 
tokio::io::copy(&mut body_reader, &mut file).await?; Ok::<_, io::Error>(()) } .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())) } // to prevent directory traversal attacks we ensure the path consists of exactly one normal // component fn path_is_valid(path: &str) -> bool { let path = std::path::Path::new(path); let mut components = path.components().peekable(); if let Some(first) = components.peek() { if !matches!(first, std::path::Component::Normal(_)) { return false; } } components.count() == 1 } axum-axum-v0.6.20/examples/templates/000077500000000000000000000000001446276533600175435ustar00rootroot00000000000000axum-axum-v0.6.20/examples/templates/Cargo.toml000066400000000000000000000004341446276533600214740ustar00rootroot00000000000000[package] name = "example-templates" version = "0.1.0" edition = "2021" publish = false [dependencies] askama = "0.11" axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/templates/src/000077500000000000000000000000001446276533600203325ustar00rootroot00000000000000axum-axum-v0.6.20/examples/templates/src/main.rs000066400000000000000000000031471446276533600216310ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-templates //! 
``` use askama::Template; use axum::{ extract, http::StatusCode, response::{Html, IntoResponse, Response}, routing::get, Router, }; use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_templates=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with some routes let app = Router::new().route("/greet/:name", get(greet)); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn greet(extract::Path(name): extract::Path) -> impl IntoResponse { let template = HelloTemplate { name }; HtmlTemplate(template) } #[derive(Template)] #[template(path = "hello.html")] struct HelloTemplate { name: String, } struct HtmlTemplate(T); impl IntoResponse for HtmlTemplate where T: Template, { fn into_response(self) -> Response { match self.0.render() { Ok(html) => Html(html).into_response(), Err(err) => ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to render template. Error: {}", err), ) .into_response(), } } } axum-axum-v0.6.20/examples/templates/templates/000077500000000000000000000000001446276533600215415ustar00rootroot00000000000000axum-axum-v0.6.20/examples/templates/templates/hello.html000066400000000000000000000000341446276533600235270ustar00rootroot00000000000000

Hello, {{ name }}!

axum-axum-v0.6.20/examples/testing-websockets/000077500000000000000000000000001446276533600213715ustar00rootroot00000000000000axum-axum-v0.6.20/examples/testing-websockets/Cargo.toml000066400000000000000000000004611446276533600233220ustar00rootroot00000000000000[package] name = "example-testing-websockets" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["ws"] } futures = "0.3" hyper = { version = "0.14", features = ["full"] } tokio = { version = "1.0", features = ["full"] } tokio-tungstenite = "0.20" axum-axum-v0.6.20/examples/testing-websockets/src/000077500000000000000000000000001446276533600221605ustar00rootroot00000000000000axum-axum-v0.6.20/examples/testing-websockets/src/main.rs000066400000000000000000000111721446276533600234540ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo test -p example-testing-websockets //! ``` use axum::{ extract::{ ws::{Message, WebSocket}, WebSocketUpgrade, }, response::Response, routing::get, Router, }; use futures::{Sink, SinkExt, Stream, StreamExt}; use std::net::SocketAddr; #[tokio::main] async fn main() { let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!("listening on {addr}"); axum::Server::bind(&addr) .serve(app().into_make_service()) .await .unwrap(); } fn app() -> Router { // WebSocket routes can generally be tested in two ways: // // - Integration tests where you run the server and connect with a real WebSocket client. // - Unit tests where you mock the socket as some generic send/receive type // // Which version you pick is up to you. Generally we recommend the integration test version // unless your app has a lot of setup that makes it hard to run in a test. Router::new() .route("/integration-testable", get(integration_testable_handler)) .route("/unit-testable", get(unit_testable_handler)) } // A WebSocket handler that echos any message it receives. 
// // This one we'll be integration testing so it can be written in the regular way. async fn integration_testable_handler(ws: WebSocketUpgrade) -> Response { ws.on_upgrade(integration_testable_handle_socket) } async fn integration_testable_handle_socket(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.recv().await { if let Message::Text(msg) = msg { if socket .send(Message::Text(format!("You said: {msg}"))) .await .is_err() { break; } } } } // The unit testable version requires some changes. // // By splitting the socket into an `impl Sink` and `impl Stream` we can test without providing a // real socket and instead using channels, which also implement `Sink` and `Stream`. async fn unit_testable_handler(ws: WebSocketUpgrade) -> Response { ws.on_upgrade(|socket| { let (write, read) = socket.split(); unit_testable_handle_socket(write, read) }) } // The implementation is largely the same as `integration_testable_handle_socket` expect we call // methods from `SinkExt` and `StreamExt`. async fn unit_testable_handle_socket(mut write: W, mut read: R) where W: Sink + Unpin, R: Stream> + Unpin, { while let Some(Ok(msg)) = read.next().await { if let Message::Text(msg) = msg { if write .send(Message::Text(format!("You said: {msg}"))) .await .is_err() { break; } } } } #[cfg(test)] mod tests { use super::*; use std::net::Ipv4Addr; use tokio_tungstenite::tungstenite; // We can integration test one handler by running the server in a background task and // connecting to it like any other client would. 
#[tokio::test] async fn integration_test() { let server = axum::Server::bind(&SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))) .serve(app().into_make_service()); let addr = server.local_addr(); tokio::spawn(server); let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/integration-testable")) .await .unwrap(); socket .send(tungstenite::Message::text("foo")) .await .unwrap(); let msg = match socket.next().await.unwrap().unwrap() { tungstenite::Message::Text(msg) => msg, other => panic!("expected a text message but got {other:?}"), }; assert_eq!(msg, "You said: foo"); } // We can unit test the other handler by creating channels to read and write from. #[tokio::test] async fn unit_test() { // Need to use "futures" channels rather than "tokio" channels as they implement `Sink` and // `Stream` let (socket_write, mut test_rx) = futures::channel::mpsc::channel(1024); let (mut test_tx, socket_read) = futures::channel::mpsc::channel(1024); tokio::spawn(unit_testable_handle_socket(socket_write, socket_read)); test_tx .send(Ok(Message::Text("foo".to_owned()))) .await .unwrap(); let msg = match test_rx.next().await.unwrap() { Message::Text(msg) => msg, other => panic!("expected a text message but got {other:?}"), }; assert_eq!(msg, "You said: foo"); } } axum-axum-v0.6.20/examples/testing/000077500000000000000000000000001446276533600172225ustar00rootroot00000000000000axum-axum-v0.6.20/examples/testing/Cargo.toml000066400000000000000000000007321446276533600211540ustar00rootroot00000000000000[package] name = "example-testing" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } hyper = { version = "0.14", features = ["full"] } mime = "0.3" serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } tower-http = { version = "0.4.0", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] tower = { version = "0.4", features = ["util"] 
} axum-axum-v0.6.20/examples/testing/src/000077500000000000000000000000001446276533600200115ustar00rootroot00000000000000axum-axum-v0.6.20/examples/testing/src/main.rs000066400000000000000000000136441446276533600213130ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo test -p example-testing //! ``` use std::net::SocketAddr; use axum::{ extract::ConnectInfo, routing::{get, post}, Json, Router, }; use tower_http::trace::TraceLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_testing=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app().into_make_service()) .await .unwrap(); } /// Having a function that produces our app makes it easy to call it from tests /// without having to create an HTTP server. #[allow(dead_code)] fn app() -> Router { Router::new() .route("/", get(|| async { "Hello, World!" })) .route( "/json", post(|payload: Json| async move { Json(serde_json::json!({ "data": payload.0 })) }), ) .route( "/requires-connect-into", get(|ConnectInfo(addr): ConnectInfo| async move { format!("Hi {addr}") }), ) // We can still add middleware .layer(TraceLayer::new_for_http()) } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, extract::connect_info::MockConnectInfo, http::{self, Request, StatusCode}, }; use serde_json::{json, Value}; use std::net::{SocketAddr, TcpListener}; use tower::Service; // for `call` use tower::ServiceExt; // for `oneshot` and `ready` #[tokio::test] async fn hello_world() { let app = app(); // `Router` implements `tower::Service>` so we can // call it like any tower service, no need to run an HTTP server. 
let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); assert_eq!(&body[..], b"Hello, World!"); } #[tokio::test] async fn json() { let app = app(); let response = app .oneshot( Request::builder() .method(http::Method::POST) .uri("/json") .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) .body(Body::from( serde_json::to_vec(&json!([1, 2, 3, 4])).unwrap(), )) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let body: Value = serde_json::from_slice(&body).unwrap(); assert_eq!(body, json!({ "data": [1, 2, 3, 4] })); } #[tokio::test] async fn not_found() { let app = app(); let response = app .oneshot( Request::builder() .uri("/does-not-exist") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::NOT_FOUND); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); assert!(body.is_empty()); } // You can also spawn a server and talk to it like any other HTTP server: #[tokio::test] async fn the_real_deal() { let listener = TcpListener::bind("0.0.0.0:0".parse::().unwrap()).unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { axum::Server::from_tcp(listener) .unwrap() .serve(app().into_make_service()) .await .unwrap(); }); let client = hyper::Client::new(); let response = client .request( Request::builder() .uri(format!("http://{}", addr)) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); assert_eq!(&body[..], b"Hello, World!"); } // You can use `ready()` and `call()` to avoid using `clone()` // in multiple request #[tokio::test] async fn multiple_request() { let mut app = app(); let request = 
Request::builder().uri("/").body(Body::empty()).unwrap(); let response = app.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); let request = Request::builder().uri("/").body(Body::empty()).unwrap(); let response = app.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); } // Here we're calling `/requires-connect-into` which requires `ConnectInfo` // // That is normally set with `Router::into_make_service_with_connect_info` but we can't easily // use that during tests. The solution is instead to set the `MockConnectInfo` layer during // tests. #[tokio::test] async fn with_into_make_service_with_connect_info() { let mut app = app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 3000)))); let request = Request::builder() .uri("/requires-connect-into") .body(Body::empty()) .unwrap(); let response = app.ready().await.unwrap().call(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); } } axum-axum-v0.6.20/examples/tls-rustls/000077500000000000000000000000001446276533600177015ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tls-rustls/Cargo.toml000066400000000000000000000005101446276533600216250ustar00rootroot00000000000000[package] name = "example-tls-rustls" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } axum-server = { version = "0.3", features = ["tls-rustls"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/tls-rustls/self_signed_certs/000077500000000000000000000000001446276533600233635ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tls-rustls/self_signed_certs/cert.pem000066400000000000000000000024261446276533600250270ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL 
BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg +MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr AopvZ09uEQ== -----END CERTIFICATE----- axum-axum-v0.6.20/examples/tls-rustls/self_signed_certs/key.pem000066400000000000000000000032501446276533600246560ustar00rootroot00000000000000-----BEGIN PRIVATE KEY----- MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB 12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI +2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw 
zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY 5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV 5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg 3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW w37pCS7ykL+7gp7V0WShYsw= -----END PRIVATE KEY----- axum-axum-v0.6.20/examples/tls-rustls/src/000077500000000000000000000000001446276533600204705ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tls-rustls/src/main.rs000066400000000000000000000054431446276533600217700ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-tls-rustls //! 
``` use axum::{ extract::Host, handler::HandlerWithoutStateExt, http::{StatusCode, Uri}, response::Redirect, routing::get, BoxError, Router, }; use axum_server::tls_rustls::RustlsConfig; use std::{net::SocketAddr, path::PathBuf}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[derive(Clone, Copy)] struct Ports { http: u16, https: u16, } #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_tls_rustls=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let ports = Ports { http: 7878, https: 3000, }; // optional: spawn a second server to redirect http requests to this server tokio::spawn(redirect_http_to_https(ports)); // configure certificate and private key used by https let config = RustlsConfig::from_pem_file( PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("self_signed_certs") .join("cert.pem"), PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("self_signed_certs") .join("key.pem"), ) .await .unwrap(); let app = Router::new().route("/", get(handler)); // run https server let addr = SocketAddr::from(([127, 0, 0, 1], ports.https)); tracing::debug!("listening on {}", addr); axum_server::bind_rustls(addr, config) .serve(app.into_make_service()) .await .unwrap(); } async fn handler() -> &'static str { "Hello, World!" } async fn redirect_http_to_https(ports: Ports) { fn make_https(host: String, uri: Uri, ports: Ports) -> Result { let mut parts = uri.into_parts(); parts.scheme = Some(axum::http::uri::Scheme::HTTPS); if parts.path_and_query.is_none() { parts.path_and_query = Some("/".parse().unwrap()); } let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); parts.authority = Some(https_host.parse()?); Ok(Uri::from_parts(parts)?) 
} let redirect = move |Host(host): Host, uri: Uri| async move { match make_https(host, uri, ports) { Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), Err(error) => { tracing::warn!(%error, "failed to convert URI to HTTPS"); Err(StatusCode::BAD_REQUEST) } } }; let addr = SocketAddr::from(([127, 0, 0, 1], ports.http)); tracing::debug!("http redirect listening on {}", addr); axum::Server::bind(&addr) .serve(redirect.into_make_service()) .await .unwrap(); } axum-axum-v0.6.20/examples/todos/000077500000000000000000000000001446276533600166755ustar00rootroot00000000000000axum-axum-v0.6.20/examples/todos/Cargo.toml000066400000000000000000000007701446276533600206310ustar00rootroot00000000000000[package] name = "example-todos" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util", "timeout"] } tower-http = { version = "0.4.0", features = ["add-extension", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.0", features = ["serde", "v4"] } axum-axum-v0.6.20/examples/todos/src/000077500000000000000000000000001446276533600174645ustar00rootroot00000000000000axum-axum-v0.6.20/examples/todos/src/main.rs000066400000000000000000000102541446276533600207600ustar00rootroot00000000000000//! Provides a RESTful web server managing some Todos. //! //! API will be: //! //! - `GET /todos`: return a JSON list of Todos. //! - `POST /todos`: create a new Todo. //! - `PUT /todos/:id`: update a specific Todo. //! - `DELETE /todos/:id`: delete a specific Todo. //! //! Run with //! //! ```not_rust //! cargo run -p example-todos //! 
``` use axum::{ error_handling::HandleErrorLayer, extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, routing::{get, patch}, Json, Router, }; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, net::SocketAddr, sync::{Arc, RwLock}, time::Duration, }; use tower::{BoxError, ServiceBuilder}; use tower_http::trace::TraceLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use uuid::Uuid; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_todos=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let db = Db::default(); // Compose the routes let app = Router::new() .route("/todos", get(todos_index).post(todos_create)) .route("/todos/:id", patch(todos_update).delete(todos_delete)) // Add middleware to all routes .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|error: BoxError| async move { if error.is::() { Ok(StatusCode::REQUEST_TIMEOUT) } else { Err(( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), )) } })) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) .into_inner(), ) .with_state(db); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // The query parameters for todos index #[derive(Debug, Deserialize, Default)] pub struct Pagination { pub offset: Option, pub limit: Option, } async fn todos_index( pagination: Option>, State(db): State, ) -> impl IntoResponse { let todos = db.read().unwrap(); let Query(pagination) = pagination.unwrap_or_default(); let todos = todos .values() .skip(pagination.offset.unwrap_or(0)) .take(pagination.limit.unwrap_or(usize::MAX)) .cloned() .collect::>(); Json(todos) } #[derive(Debug, Deserialize)] struct CreateTodo { text: String, } async fn 
todos_create(State(db): State, Json(input): Json) -> impl IntoResponse { let todo = Todo { id: Uuid::new_v4(), text: input.text, completed: false, }; db.write().unwrap().insert(todo.id, todo.clone()); (StatusCode::CREATED, Json(todo)) } #[derive(Debug, Deserialize)] struct UpdateTodo { text: Option, completed: Option, } async fn todos_update( Path(id): Path, State(db): State, Json(input): Json, ) -> Result { let mut todo = db .read() .unwrap() .get(&id) .cloned() .ok_or(StatusCode::NOT_FOUND)?; if let Some(text) = input.text { todo.text = text; } if let Some(completed) = input.completed { todo.completed = completed; } db.write().unwrap().insert(todo.id, todo.clone()); Ok(Json(todo)) } async fn todos_delete(Path(id): Path, State(db): State) -> impl IntoResponse { if db.write().unwrap().remove(&id).is_some() { StatusCode::NO_CONTENT } else { StatusCode::NOT_FOUND } } type Db = Arc>>; #[derive(Debug, Serialize, Clone)] struct Todo { id: Uuid, text: String, completed: bool, } axum-axum-v0.6.20/examples/tokio-postgres/000077500000000000000000000000001446276533600205365ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tokio-postgres/Cargo.toml000066400000000000000000000005171446276533600224710ustar00rootroot00000000000000[package] name = "example-tokio-postgres" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } bb8 = "0.7.1" bb8-postgres = "0.7.0" tokio = { version = "1.0", features = ["full"] } tokio-postgres = "0.7.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/tokio-postgres/src/000077500000000000000000000000001446276533600213255ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tokio-postgres/src/main.rs000066400000000000000000000060221446276533600226170ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-tokio-postgres //! 
``` use axum::{ async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, Router, }; use bb8::{Pool, PooledConnection}; use bb8_postgres::PostgresConnectionManager; use std::net::SocketAddr; use tokio_postgres::NoTls; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // set up connection pool let manager = PostgresConnectionManager::new_from_stringlike("host=localhost user=postgres", NoTls) .unwrap(); let pool = Pool::builder().build(manager).await.unwrap(); // build our application with some routes let app = Router::new() .route( "/", get(using_connection_pool_extractor).post(using_connection_extractor), ) .with_state(pool); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } type ConnectionPool = Pool>; async fn using_connection_pool_extractor( State(pool): State, ) -> Result { let conn = pool.get().await.map_err(internal_error)?; let row = conn .query_one("select 1 + 1", &[]) .await .map_err(internal_error)?; let two: i32 = row.try_get(0).map_err(internal_error)?; Ok(two.to_string()) } // we can also write a custom extractor that grabs a connection from the pool // which setup is appropriate depends on your application struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] impl FromRequestParts for DatabaseConnection where ConnectionPool: FromRef, S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { let pool = ConnectionPool::from_ref(state); let conn = pool.get_owned().await.map_err(internal_error)?; 
Ok(Self(conn)) } } async fn using_connection_extractor( DatabaseConnection(conn): DatabaseConnection, ) -> Result { let row = conn .query_one("select 1 + 1", &[]) .await .map_err(internal_error)?; let two: i32 = row.try_get(0).map_err(internal_error)?; Ok(two.to_string()) } /// Utility function for mapping any error into a `500 Internal Server Error` /// response. fn internal_error(err: E) -> (StatusCode, String) where E: std::error::Error, { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } axum-axum-v0.6.20/examples/tracing-aka-logging/000077500000000000000000000000001446276533600213525ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tracing-aka-logging/Cargo.toml000066400000000000000000000005471446276533600233100ustar00rootroot00000000000000[package] name = "example-tracing-aka-logging" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["tracing"] } tokio = { version = "1.0", features = ["full"] } tower-http = { version = "0.4.0", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/tracing-aka-logging/src/000077500000000000000000000000001446276533600221415ustar00rootroot00000000000000axum-axum-v0.6.20/examples/tracing-aka-logging/src/main.rs000066400000000000000000000066321446276533600234420ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-tracing-aka-logging //! 
``` use axum::{ body::Bytes, extract::MatchedPath, http::{HeaderMap, Request}, response::{Html, Response}, routing::get, Router, }; use std::{net::SocketAddr, time::Duration}; use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; use tracing::{info_span, Span}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { // axum logs rejections from built-in extractors with the `axum::rejection` // target, at `TRACE` level. `axum::rejection=trace` enables showing those events "example_tracing_aka_logging=debug,tower_http=debug,axum::rejection=trace".into() }), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with a route let app = Router::new() .route("/", get(handler)) // `TraceLayer` is provided by tower-http so you have to add that as a dependency. // It provides good defaults but is also very customizable. // // See https://docs.rs/tower-http/0.1.1/tower_http/trace/index.html for more details. // // If you want to customize the behavior using closures here is how. .layer( TraceLayer::new_for_http() .make_span_with(|request: &Request<_>| { // Log the matched route's path (with placeholders not filled in). // Use request.uri() or OriginalUri if you want the real path. let matched_path = request .extensions() .get::() .map(MatchedPath::as_str); info_span!( "http_request", method = ?request.method(), matched_path, some_other_field = tracing::field::Empty, ) }) .on_request(|_request: &Request<_>, _span: &Span| { // You can use `_span.record("some_other_field", value)` in one of these // closures to attach a value to the initially empty field in the info_span // created above. }) .on_response(|_response: &Response, _latency: Duration, _span: &Span| { // ... }) .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| { // ... 
}) .on_eos( |_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| { // ... }, ) .on_failure( |_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { // ... }, ), ); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler() -> Html<&'static str> { Html("

Hello, World!

") } axum-axum-v0.6.20/examples/unix-domain-socket/000077500000000000000000000000001446276533600212635ustar00rootroot00000000000000axum-axum-v0.6.20/examples/unix-domain-socket/Cargo.toml000066400000000000000000000006101446276533600232100ustar00rootroot00000000000000[package] name = "example-unix-domain-socket" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } futures = "0.3" hyper = { version = "0.14", features = ["full"] } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/unix-domain-socket/src/000077500000000000000000000000001446276533600220525ustar00rootroot00000000000000axum-axum-v0.6.20/examples/unix-domain-socket/src/main.rs000066400000000000000000000115331446276533600233470ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-unix-domain-socket //! 
``` #[cfg(unix)] #[tokio::main] async fn main() { unix::server().await; } #[cfg(not(unix))] fn main() { println!("This example requires unix") } #[cfg(unix)] mod unix { use axum::{ body::Body, extract::connect_info::{self, ConnectInfo}, http::{Method, Request, StatusCode, Uri}, routing::get, Router, }; use futures::ready; use hyper::{ client::connect::{Connected, Connection}, server::accept::Accept, }; use std::{ io, path::PathBuf, pin::Pin, sync::Arc, task::{Context, Poll}, }; use tokio::{ io::{AsyncRead, AsyncWrite}, net::{unix::UCred, UnixListener, UnixStream}, }; use tower::BoxError; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub async fn server() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let path = PathBuf::from("/tmp/axum/helloworld"); let _ = tokio::fs::remove_file(&path).await; tokio::fs::create_dir_all(path.parent().unwrap()) .await .unwrap(); let uds = UnixListener::bind(path.clone()).unwrap(); tokio::spawn(async { let app = Router::new().route("/", get(handler)); axum::Server::builder(ServerAccept { uds }) .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); }); let connector = tower::service_fn(move |_: Uri| { let path = path.clone(); Box::pin(async move { let stream = UnixStream::connect(path).await?; Ok::<_, io::Error>(ClientConnection { stream }) }) }); let client = hyper::Client::builder().build(connector); let request = Request::builder() .method(Method::GET) .uri("http://uri-doesnt-matter.com") .body(Body::empty()) .unwrap(); let response = client.request(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); let body = String::from_utf8(body.to_vec()).unwrap(); assert_eq!(body, "Hello, World!"); } async fn handler(ConnectInfo(info): ConnectInfo) -> &'static str { println!("new 
connection from `{:?}`", info); "Hello, World!" } struct ServerAccept { uds: UnixListener, } impl Accept for ServerAccept { type Conn = UnixStream; type Error = BoxError; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { let (stream, _addr) = ready!(self.uds.poll_accept(cx))?; Poll::Ready(Some(Ok(stream))) } } struct ClientConnection { stream: UnixStream, } impl AsyncWrite for ClientConnection { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.stream).poll_write(cx, buf) } fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_flush(cx) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_shutdown(cx) } } impl AsyncRead for ClientConnection { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } } impl Connection for ClientConnection { fn connected(&self) -> Connected { Connected::new() } } #[derive(Clone, Debug)] #[allow(dead_code)] struct UdsConnectInfo { peer_addr: Arc, peer_cred: UCred, } impl connect_info::Connected<&UnixStream> for UdsConnectInfo { fn connect_info(target: &UnixStream) -> Self { let peer_addr = target.peer_addr().unwrap(); let peer_cred = target.peer_cred().unwrap(); Self { peer_addr: Arc::new(peer_addr), peer_cred, } } } } axum-axum-v0.6.20/examples/validator/000077500000000000000000000000001446276533600175325ustar00rootroot00000000000000axum-axum-v0.6.20/examples/validator/Cargo.toml000066400000000000000000000006711446276533600214660ustar00rootroot00000000000000[package] edition = "2021" name = "example-validator" publish = false version = "0.1.0" [dependencies] async-trait = "0.1.67" axum = { path = "../../axum" } http-body = "0.4.3" serde = { version = "1.0", features = ["derive"] } thiserror = "1.0.29" tokio = { version = "1.0", features = 
["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } validator = { version = "0.14.0", features = ["derive"] } axum-axum-v0.6.20/examples/validator/src/000077500000000000000000000000001446276533600203215ustar00rootroot00000000000000axum-axum-v0.6.20/examples/validator/src/main.rs000066400000000000000000000053051446276533600216160ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-validator //! //! curl '127.0.0.1:3000?name=' //! -> Input validation error: [name: Can not be empty] //! //! curl '127.0.0.1:3000?name=LT' //! ->

Hello, LT!

//! ``` use async_trait::async_trait; use axum::{ extract::{rejection::FormRejection, Form, FromRequest}, http::{Request, StatusCode}, response::{Html, IntoResponse, Response}, routing::get, Router, }; use serde::{de::DeserializeOwned, Deserialize}; use std::net::SocketAddr; use thiserror::Error; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use validator::Validate; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_validator=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with a route let app = Router::new().route("/", get(handler)); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } #[derive(Debug, Deserialize, Validate)] pub struct NameInput { #[validate(length(min = 1, message = "Can not be empty"))] pub name: String, } async fn handler(ValidatedForm(input): ValidatedForm) -> Html { Html(format!("

Hello, {}!

", input.name)) } #[derive(Debug, Clone, Copy, Default)] pub struct ValidatedForm(pub T); #[async_trait] impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, S: Send + Sync, Form: FromRequest, B: Send + 'static, { type Rejection = ServerError; async fn from_request(req: Request, state: &S) -> Result { let Form(value) = Form::::from_request(req, state).await?; value.validate()?; Ok(ValidatedForm(value)) } } #[derive(Debug, Error)] pub enum ServerError { #[error(transparent)] ValidationError(#[from] validator::ValidationErrors), #[error(transparent)] AxumFormRejection(#[from] FormRejection), } impl IntoResponse for ServerError { fn into_response(self) -> Response { match self { ServerError::ValidationError(_) => { let message = format!("Input validation error: [{}]", self).replace('\n', ", "); (StatusCode::BAD_REQUEST, message) } ServerError::AxumFormRejection(_) => (StatusCode::BAD_REQUEST, self.to_string()), } .into_response() } } axum-axum-v0.6.20/examples/versioning/000077500000000000000000000000001446276533600177305ustar00rootroot00000000000000axum-axum-v0.6.20/examples/versioning/Cargo.toml000066400000000000000000000004151446276533600216600ustar00rootroot00000000000000[package] name = "example-versioning" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum-axum-v0.6.20/examples/versioning/src/000077500000000000000000000000001446276533600205175ustar00rootroot00000000000000axum-axum-v0.6.20/examples/versioning/src/main.rs000066400000000000000000000036121446276533600220130ustar00rootroot00000000000000//! Run with //! //! ```not_rust //! cargo run -p example-versioning //! 
``` use axum::{ async_trait, extract::{FromRequestParts, Path}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::get, RequestPartsExt, Router, }; use std::{collections::HashMap, net::SocketAddr}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_versioning=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with some routes let app = Router::new().route("/:version/foo", get(handler)); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn handler(version: Version) { println!("received request with version {:?}", version); } #[derive(Debug)] enum Version { V1, V2, V3, } #[async_trait] impl FromRequestParts for Version where S: Send + Sync, { type Rejection = Response; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let params: Path> = parts.extract().await.map_err(IntoResponse::into_response)?; let version = params .get("version") .ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?; match version.as_str() { "v1" => Ok(Version::V1), "v2" => Ok(Version::V2), "v3" => Ok(Version::V3), _ => Err((StatusCode::NOT_FOUND, "unknown version").into_response()), } } } axum-axum-v0.6.20/examples/websockets/000077500000000000000000000000001446276533600177165ustar00rootroot00000000000000axum-axum-v0.6.20/examples/websockets/Cargo.toml000066400000000000000000000012411446276533600216440ustar00rootroot00000000000000[package] name = "example-websockets" version = "0.1.0" edition = "2021" publish = false [dependencies] axum = { path = "../../axum", features = ["ws", "headers"] } futures = "0.3" futures-util = { version = "0.3", default-features = false, 
features = ["sink", "std"] } headers = "0.3" tokio = { version = "1.0", features = ["full"] } tokio-tungstenite = "0.20" tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.4.0", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [[bin]] name = "example-websockets" path = "src/main.rs" [[bin]] name = "example-client" path = "src/client.rs" axum-axum-v0.6.20/examples/websockets/assets/000077500000000000000000000000001446276533600212205ustar00rootroot00000000000000axum-axum-v0.6.20/examples/websockets/assets/index.html000066400000000000000000000001531446276533600232140ustar00rootroot00000000000000
Open the console to see stuff, then refresh to initiate exchange. axum-axum-v0.6.20/examples/websockets/assets/script.js000066400000000000000000000012201446276533600230550ustar00rootroot00000000000000const socket = new WebSocket('ws://localhost:3000/ws'); socket.addEventListener('open', function (event) { socket.send('Hello Server!'); }); socket.addEventListener('message', function (event) { console.log('Message from server ', event.data); }); setTimeout(() => { const obj = { hello: "world" }; const blob = new Blob([JSON.stringify(obj, null, 2)], { type: "application/json", }); console.log("Sending blob over websocket"); socket.send(blob); }, 1000); setTimeout(() => { socket.send('About done here...'); console.log("Sending close over websocket"); socket.close(3000, "Crash and Burn!"); }, 3000);axum-axum-v0.6.20/examples/websockets/src/000077500000000000000000000000001446276533600205055ustar00rootroot00000000000000axum-axum-v0.6.20/examples/websockets/src/client.rs000066400000000000000000000127121446276533600223340ustar00rootroot00000000000000//! Based on tokio-tungstenite example websocket client, but with multiple //! concurrent websocket clients in one package //! //! This will connect to a server specified in the SERVER with N_CLIENTS //! concurrent connections, and then flood some test messages over websocket. //! This will also print whatever it gets into stdout. //! //! Note that this is not currently optimized for performance, especially around //! stdout mutex management. Rather it's intended to show an example of working with axum's //! websocket server and how the client-side and server-side code can be quite similar. //! 
use futures_util::stream::FuturesUnordered; use futures_util::{SinkExt, StreamExt}; use std::borrow::Cow; use std::ops::ControlFlow; use std::time::Instant; // we will use tungstenite for websocket client impl (same library as what axum is using) use tokio_tungstenite::{ connect_async, tungstenite::protocol::{frame::coding::CloseCode, CloseFrame, Message}, }; const N_CLIENTS: usize = 2; //set to desired number const SERVER: &str = "ws://127.0.0.1:3000/ws"; #[tokio::main] async fn main() { let start_time = Instant::now(); //spawn several clients that will concurrently talk to the server let mut clients = (0..N_CLIENTS) .map(|cli| tokio::spawn(spawn_client(cli))) .collect::>(); //wait for all our clients to exit while clients.next().await.is_some() {} let end_time = Instant::now(); //total time should be the same no matter how many clients we spawn println!( "Total time taken {:#?} with {N_CLIENTS} concurrent clients, should be about 6.45 seconds.", end_time - start_time ); } //creates a client. quietly exits on failure. async fn spawn_client(who: usize) { let ws_stream = match connect_async(SERVER).await { Ok((stream, response)) => { println!("Handshake for client {} has been completed", who); // This will be the HTTP response, same as with server this is the last moment we // can still access HTTP stuff. println!("Server response was {:?}", response); stream } Err(e) => { println!("WebSocket handshake for client {who} failed with {e}!"); return; } }; let (mut sender, mut receiver) = ws_stream.split(); //we can ping the server for start sender .send(Message::Ping("Hello, Server!".into())) .await .expect("Can not send!"); //spawn an async sender to push some more messages into the server let mut send_task = tokio::spawn(async move { for i in 1..30 { // In any websocket error, break loop. if sender .send(Message::Text(format!("Message number {}...", i))) .await .is_err() { //just as with server, if send fails there is nothing we can do but exit. 
return; } tokio::time::sleep(std::time::Duration::from_millis(300)).await; } // When we are done we may want our client to close connection cleanly. println!("Sending close to {}...", who); if let Err(e) = sender .send(Message::Close(Some(CloseFrame { code: CloseCode::Normal, reason: Cow::from("Goodbye"), }))) .await { println!("Could not send Close due to {:?}, probably it is ok?", e); }; }); //receiver just prints whatever it gets let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { // print message and break if instructed to do so if process_message(msg, who).is_break() { break; } } }); //wait for either task to finish and kill the other task tokio::select! { _ = (&mut send_task) => { recv_task.abort(); }, _ = (&mut recv_task) => { send_task.abort(); } } } /// Function to handle messages we get (with a slight twist that Frame variant is visible /// since we are working with the underlying tungstenite library directly without axum here). fn process_message(msg: Message, who: usize) -> ControlFlow<(), ()> { match msg { Message::Text(t) => { println!(">>> {} got str: {:?}", who, t); } Message::Binary(d) => { println!(">>> {} got {} bytes: {:?}", who, d.len(), d); } Message::Close(c) => { if let Some(cf) = c { println!( ">>> {} got close with code {} and reason `{}`", who, cf.code, cf.reason ); } else { println!(">>> {} somehow got close message without CloseFrame", who); } return ControlFlow::Break(()); } Message::Pong(v) => { println!(">>> {} got pong with {:?}", who, v); } // Just as with axum server, the underlying tungstenite websocket library // will handle Ping for you automagically by replying with Pong and copying the // v according to spec. But if you need the contents of the pings you can see them here. 
Message::Ping(v) => { println!(">>> {} got ping with {:?}", who, v); } Message::Frame(_) => { unreachable!("This is never supposed to happen") } } ControlFlow::Continue(()) } axum-axum-v0.6.20/examples/websockets/src/main.rs000066400000000000000000000206371446276533600220070ustar00rootroot00000000000000//! Example websocket server. //! //! Run the server with //! ```not_rust //! cargo run -p example-websockets --bin example-websockets //! ``` //! //! Run a browser client with //! ```not_rust //! firefox http://localhost:3000 //! ``` //! //! Alternatively you can run the rust client (showing two //! concurrent websocket connections being established) with //! ```not_rust //! cargo run -p example-websockets --bin example-client //! ``` use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, TypedHeader, }, response::IntoResponse, routing::get, Router, }; use std::borrow::Cow; use std::ops::ControlFlow; use std::{net::SocketAddr, path::PathBuf}; use tower_http::{ services::ServeDir, trace::{DefaultMakeSpan, TraceLayer}, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; //allows to extract the IP of connecting user use axum::extract::connect_info::ConnectInfo; use axum::extract::ws::CloseFrame; //allows to split the websocket stream into separate TX and RX branches use futures::{sink::SinkExt, stream::StreamExt}; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "example_websockets=debug,tower_http=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); // build our application with some routes let app = Router::new() .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) .route("/ws", get(ws_handler)) // logging so we can see whats going on .layer( TraceLayer::new_for_http() 
.make_span_with(DefaultMakeSpan::default().include_headers(true)), ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); } /// The handler for the HTTP request (this gets called when the HTTP GET lands at the start /// of websocket negotiation). After this completes, the actual switching from HTTP to /// websocket protocol will occur. /// This is the last point where we can extract TCP/IP metadata such as IP address of the client /// as well as things from HTTP headers such as user-agent of the browser etc. async fn ws_handler( ws: WebSocketUpgrade, user_agent: Option>, ConnectInfo(addr): ConnectInfo, ) -> impl IntoResponse { let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { user_agent.to_string() } else { String::from("Unknown browser") }; println!("`{user_agent}` at {addr} connected."); // finalize the upgrade process by returning upgrade callback. // we can customize the callback by sending additional info such as address. ws.on_upgrade(move |socket| handle_socket(socket, addr)) } /// Actual websocket statemachine (one will be spawned per connection) async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { //send a ping (unsupported by some browsers) just to kick things off and get a response if socket.send(Message::Ping(vec![1, 2, 3])).await.is_ok() { println!("Pinged {}...", who); } else { println!("Could not send ping {}!", who); // no Error here since the only thing we can do is to close the connection. // If we can not send messages, there is no way to salvage the statemachine anyway. return; } // receive single message from a client (we can either receive or send with socket). // this will likely be the Pong for our Ping or a hello message from client. // waiting for message from a client will block this task, but will not block other client's // connections. 
if let Some(msg) = socket.recv().await { if let Ok(msg) = msg { if process_message(msg, who).is_break() { return; } } else { println!("client {who} abruptly disconnected"); return; } } // Since each client gets individual statemachine, we can pause handling // when necessary to wait for some external event (in this case illustrated by sleeping). // Waiting for this client to finish getting its greetings does not prevent other clients from // connecting to server and receiving their greetings. for i in 1..5 { if socket .send(Message::Text(format!("Hi {i} times!"))) .await .is_err() { println!("client {who} abruptly disconnected"); return; } tokio::time::sleep(std::time::Duration::from_millis(100)).await; } // By splitting socket we can send and receive at the same time. In this example we will send // unsolicited messages to client based on some sort of server's internal event (i.e .timer). let (mut sender, mut receiver) = socket.split(); // Spawn a task that will push several messages to the client (does not matter what client does) let mut send_task = tokio::spawn(async move { let n_msg = 20; for i in 0..n_msg { // In case of any websocket error, we exit. if sender .send(Message::Text(format!("Server message {i} ..."))) .await .is_err() { return i; } tokio::time::sleep(std::time::Duration::from_millis(300)).await; } println!("Sending close to {who}..."); if let Err(e) = sender .send(Message::Close(Some(CloseFrame { code: axum::extract::ws::close_code::NORMAL, reason: Cow::from("Goodbye"), }))) .await { println!("Could not send Close due to {}, probably it is ok?", e); } n_msg }); // This second task will receive messages from client and print them on server console let mut recv_task = tokio::spawn(async move { let mut cnt = 0; while let Some(Ok(msg)) = receiver.next().await { cnt += 1; // print message and break if instructed to do so if process_message(msg, who).is_break() { break; } } cnt }); // If any one of the tasks exit, abort the other. tokio::select! 
{ rv_a = (&mut send_task) => { match rv_a { Ok(a) => println!("{} messages sent to {}", a, who), Err(a) => println!("Error sending messages {:?}", a) } recv_task.abort(); }, rv_b = (&mut recv_task) => { match rv_b { Ok(b) => println!("Received {} messages", b), Err(b) => println!("Error receiving messages {:?}", b) } send_task.abort(); } } // returning from the handler closes the websocket connection println!("Websocket context {} destroyed", who); } /// helper to print contents of messages to stdout. Has special treatment for Close. fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { match msg { Message::Text(t) => { println!(">>> {} sent str: {:?}", who, t); } Message::Binary(d) => { println!(">>> {} sent {} bytes: {:?}", who, d.len(), d); } Message::Close(c) => { if let Some(cf) = c { println!( ">>> {} sent close with code {} and reason `{}`", who, cf.code, cf.reason ); } else { println!(">>> {} somehow sent close message without CloseFrame", who); } return ControlFlow::Break(()); } Message::Pong(v) => { println!(">>> {} sent pong with {:?}", who, v); } // You should never need to manually handle Message::Ping, as axum's websocket library // will do so for you automagically by replying with Pong and copying the v according to // spec. But if you need the contents of the pings you can see them here. Message::Ping(v) => { println!(">>> {} sent ping with {:?}", who, v); } } ControlFlow::Continue(()) }