pax_global_header00006660000000000000000000000064147233133030014511gustar00rootroot0000000000000052 comment=6beddfcc0f4bdf030a08e8100e574b5468e97871 soketto-0.8.1/000077500000000000000000000000001472331330300132075ustar00rootroot00000000000000soketto-0.8.1/.editorconfig000066400000000000000000000001441472331330300156630ustar00rootroot00000000000000root = true [*] charset=utf-8 end_of_line=lf indent_size=4 indent_style=space max_line_length=100 soketto-0.8.1/.github/000077500000000000000000000000001472331330300145475ustar00rootroot00000000000000soketto-0.8.1/.github/dependabot.yml000066400000000000000000000003201472331330300173720ustar00rootroot00000000000000version: 2 updates: - package-ecosystem: "cargo" directory: "/" schedule: interval: "weekly" - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" soketto-0.8.1/.github/workflows/000077500000000000000000000000001472331330300166045ustar00rootroot00000000000000soketto-0.8.1/.github/workflows/ci.yml000066400000000000000000000050321472331330300177220ustar00rootroot00000000000000name: Rust on: push: # Run jobs when commits are pushed to # master or release-like branches: branches: - master - release* pull_request: # Run jobs for any external PR that wants # to merge to master, too: branches: - master env: CARGO_TERM_COLOR: always jobs: build: name: Check Code runs-on: ubuntu-latest steps: - name: Checkout sources uses: actions/checkout@v4.2.2 - name: Install Rust stable toolchain uses: actions-rs/toolchain@v1.0.7 with: profile: minimal toolchain: stable override: true - name: Rust Cache uses: Swatinem/rust-cache@v2.7.5 - name: Build uses: actions-rs/cargo@v1.0.3 with: command: check args: --all-targets --all-features fmt: name: Run rustfmt runs-on: ubuntu-latest steps: - name: Checkout sources uses: actions/checkout@v4.2.2 - name: Install Rust stable toolchain uses: actions-rs/toolchain@v1.0.7 with: profile: minimal toolchain: stable override: true components: clippy, rustfmt - name: Rust 
Cache uses: Swatinem/rust-cache@v2.7.5 - name: Cargo fmt uses: actions-rs/cargo@v1.0.3 with: command: fmt args: --all -- --check docs: name: Check Documentation runs-on: ubuntu-latest steps: - name: Checkout sources uses: actions/checkout@v4.2.2 - name: Install Rust stable toolchain uses: actions-rs/toolchain@v1.0.7 with: profile: minimal toolchain: stable override: true - name: Rust Cache uses: Swatinem/rust-cache@v2.7.5 - name: Check internal documentation links run: RUSTDOCFLAGS="--deny broken_intra_doc_links" cargo doc --verbose --workspace --no-deps --document-private-items tests: name: Run tests runs-on: ubuntu-latest steps: - name: Checkout sources uses: actions/checkout@v4.2.2 - name: Install Rust stable toolchain uses: actions-rs/toolchain@v1.0.7 with: profile: minimal toolchain: stable override: true - name: Rust Cache uses: Swatinem/rust-cache@v2.7.5 - name: Cargo build uses: actions-rs/cargo@v1.0.3 with: command: build args: --workspace - name: Cargo test uses: actions-rs/cargo@v1.0.3 with: command: test soketto-0.8.1/.gitignore000066400000000000000000000000301472331330300151700ustar00rootroot00000000000000target Cargo.lock *.dat soketto-0.8.1/CHANGELOG.md000066400000000000000000000123731472331330300150260ustar00rootroot00000000000000# Changelog The format is based on [Keep a Changelog]. [Keep a Changelog]: http://keepachangelog.com/en/1.0.0/ ## 0.8.1 - [fixed] ignore I/O error after successful close handshake [#115](https://github.com/paritytech/soketto/pull/115) ## 0.8.0 - [changed] move to rust 2021 [#56](https://github.com/paritytech/soketto/pull/56) - [changed] Replace sha-1 v0.9 with sha1 v0.10 [#62](https://github.com/paritytech/soketto/pull/62) - [changed] Update hyper requirement from v0.14 to v1.0 [#99](https://github.com/paritytech/soketto/pull/99) - [changed] Update base64 requirement from 0.13 to 0.22 [#97](https://github.com/paritytech/soketto/pull/97) - [changed] Bump MSRV to 1.71.1. 
- [fixed] doc typo on Client resource field [#79](https://github.com/paritytech/soketto/pull/79) ## 0.7.1 - [fixed] Advance reader when a message that is too big is received [#54](https://github.com/paritytech/soketto/pull/54) ## 0.7.0 - [added] Added the `handshake::http` module and example usage at `examples/hyper_server.rs` to make using Soketto in conjunction with libraries that use the `http` types (like Hyper) simpler [#45](https://github.com/paritytech/soketto/pull/45) [#48](https://github.com/paritytech/soketto/pull/48) - [added] Allow setting custom headers on the client to be sent to WebSocket servers when the opening handshake is performed [#47](https://github.com/paritytech/soketto/pull/47) ## 0.6.0 - [changed] Expose the `Origin` headers from the client handshake on `ClientRequest` [#35](https://github.com/paritytech/soketto/pull/35) - [changed] Update handshake error to expose a couple of new variants (`IncompleteHttpRequest` and `SecWebSocketKeyInvalidLength`) [#35](https://github.com/paritytech/soketto/pull/35) - [added] Add `send_text_owned` method to `Sender` as an optimisation when you can pass an owned `String` in [#36](https://github.com/paritytech/soketto/pull/36) - [updated] Run rustfmt over the repository, and minor tidy up [#41](https://github.com/paritytech/soketto/pull/41) ## 0.5.0 - Update examples to Tokio 1 [#27](https://github.com/paritytech/soketto/pull/27) - Update deps and remove unnecessary transients [#30](https://github.com/paritytech/soketto/pull/30) - Add CLOSE reason handling [#31](https://github.com/paritytech/soketto/pull/31) - Fix handshake with case-sensitive servers [#32](https://github.com/paritytech/soketto/pull/32) ## 0.4.2 - Added connection ID to log output (#21). - Added `ClientRequest::path` to access the path requested by the client (See #23 by @mward for details). - Updated `sha-1` dependency to 0.9 (#24). ## 0.4.1 - Update some `dev-dependencies`. ## 0.4.0 - Remove all `unsafe` code blocks.
- Remove internal use of `futures::io::BufWriter`. - `Extension::decode` now takes a `&mut Vec<u8>` instead of a `BytesMut`. - `Incoming::Pong` contains the PONG payload data slice inline. - `Data` no longer contains application data, but reports only the number of bytes. The actual data is written directly into the `&mut Vec<u8>` parameter of `Receiver::receive` or `Receiver::receive_data`. - `Receiver::into_stream` has been removed. ## 0.3.2 - Bugfix release. `Codec::encode_header` contained a hidden assumption that a `usize` would be 8 bytes long, which is obviously only true on 64-bit architectures. See #18 for details. ## 0.3.1 - A method `into_inner` to get back the socket has been added to `handshake::{Client, Server}`. ## 0.3.0 Update to use and work with async/await: - `Connection` has been split into a `Sender` and `Receiver` pair with async methods to send and receive data or control frames such as Pings or Pongs. - `connection::into_stream` has been added to get a `futures::stream::Stream` from a `Receiver`. - A `connection::Builder` has been added to set up connection parameters. `handshake::Client` and `handshake::Server` no longer have an `into_connection` method, but an `into_builder` one which returns the `Builder` and allows further configuration. - `base::Data` has been moved to `data`. In addition `data::Incoming` supports control frame data. - `base::Codec` no longer implements `Encoder`/`Decoder` traits but has inherent methods for encoding and decoding websocket frame headers. - `base::Frame` has been removed. The `base::Codec` only deals with headers. - The `handshake` module contains separate sub-modules for `client` and `server` handshakes. Some handshake-related types have been refactored slightly. - The `decode` methods of `Extension`s work on `&mut BytesMut` parameters instead of `Data`. For `encode` a new type `Storage` has been added which unifies different types of data, i.e. shared, unique and owned data. ## 0.2.3 - Maintenance release.
## 0.2.2 - Improved handshake header matching which is now more robust and can cope with repeated header names and comma separated values. ## 0.2.1 - The DEFLATE extension now allows custom maximum window bits for client and server. - Fix handling of reserved bits in base codec. ## 0.2.0 - Change `Extension` trait and add an optional DEFLATE extension (RFC 7692). For now the possibility to use reserved opcodes in extensions is not enabled. The DEFLATE extension does not support setting of window bits other than 15 currently. - Limit the max. buffer size in `Connection` (see `Connection::set_max_buffer_size`). ## 0.1.0 Initial release. soketto-0.8.1/Cargo.toml000066400000000000000000000031401472331330300151350ustar00rootroot00000000000000[package] name = "soketto" version = "0.8.1" authors = ["Parity Technologies ", "Jason Ozias "] description = "A websocket protocol implementation." keywords = ["websocket", "codec", "async", "futures"] categories = ["network-programming", "asynchronous", "web-programming::websocket"] license = "Apache-2.0 OR MIT" readme = "README.md" repository = "https://github.com/paritytech/soketto" edition = "2021" rust-version = "1.71.1" [package.metadata.docs.rs] all-features = true [features] default = [] deflate = ["flate2"] [dependencies] base64 = { default-features = false, features = ["alloc"], version = "0.22" } bytes = { default-features = false, version = "1.0" } flate2 = { default-features = false, features = ["zlib"], optional = true, version = "1.0.13" } futures = { default-features = false, features = ["bilock", "std", "unstable"], version = "0.3.1" } httparse = { default-features = false, features = ["std"], version = "1.3.4" } log = { default-features = false, version = "0.4.8" } rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" } sha1 = { default-features = false, version = "0.10" } http = { version = "1", optional = true } [dev-dependencies] quickcheck = "1" tokio = { version = "1", features = 
["full"] } tokio-util = { version = "0.7", features = ["compat"] } tokio-stream = { version = "0.1", features = ["net"] } http-body-util = "0.1" hyper = { version = "1.2", features = ["full"] } hyper-util = { version = "0.1", features = ["tokio"] } env_logger = "0.11.1" [[example]] name = "hyper_server" required-features = ["http"] soketto-0.8.1/LICENSE-APACHE000066400000000000000000000251371472331330300151430ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. soketto-0.8.1/LICENSE-MIT000066400000000000000000000021251472331330300146430ustar00rootroot00000000000000Copyright (c) 2019 Parity Technologies (UK) Ltd. Copyright (c) 2016 twist developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. soketto-0.8.1/README.md000066400000000000000000000003241472331330300144650ustar00rootroot00000000000000# Soketto An implementation of the [RFC 6455][1] websocket protocol. 
This crate is a heavily modified fork of the [twist][2] crate. [1]: https://tools.ietf.org/html/rfc6455 [2]: https://crates.io/crates/twist soketto-0.8.1/RELEASING.md000066400000000000000000000052371472331330300150510ustar00rootroot00000000000000# Release Checklist These steps assume that you've checked out the Soketto repository and are in the root directory of it. We also assume that ongoing work done is being merged directly to the `master` branch. 1. Ensure that everything you'd like to see released is on the `master` branch. 2. Create a release branch off `master`, for example `release-v0.6.0`. The branch name should start with `release` so that we can target commits with CI. Decide how far the version needs to be bumped based on the changes to date. If unsure what to bump the version to (e.g. is it a major, minor or patch release), check with the Parity Tools team. 3. Check that you're happy with the current documentation. ``` cargo doc --open --all-features ``` CI checks for broken internal links at the moment. Optionally you can also confirm that any external links are still valid like so: ``` cargo install cargo-deadlinks cargo deadlinks --check-http -- --all-features ``` If there are minor issues with the documentation, they can be fixed in the release branch. 4. Bump the crate version in `Cargo.toml` to whatever was decided in step 2. 5. Update `CHANGELOG.md` to reflect the difference between this release and last. If you're unsure of what to add, check with the Tools team. One way to gain some inspiration on what to write is to look at the [closed PRs](https://github.com/paritytech/soketto/pulls?q=is%3Apr+is%3Aclosed). You can also look through the commit history to find the code changes since the last release (eg `git log --pretty LAST_VERSION_TAG..HEAD`). 6. Commit any of the above changes to the release branch and open a PR in GitHub with a base of `master`. 7. Once the branch has been reviewed and passes CI, merge it. 8. 
Now, we're ready to publish the release to crates.io. Check out `master`, ensuring we're looking at the latest merge (`git pull`). Next, do a dry run to make sure that things seem sane: ``` cargo publish --dry-run ``` If we're happy with everything, proceed with the release: ``` cargo publish ``` 9. If the release was successful, then tag the commit that we released in the `master` branch with the version that we just released, for example: ``` git tag v0.6.0 # use the version number you've just published to crates.io, not this one git push --tags ``` Once this is pushed, go along to [the releases page on GitHub](https://github.com/paritytech/soketto/releases) and draft a new release which points to the tag you just pushed to `master` above. Copy the changelog comments for the current release into the release description. soketto-0.8.1/examples/000077500000000000000000000000001472331330300150255ustar00rootroot00000000000000soketto-0.8.1/examples/autobahn_client.rs000066400000000000000000000073411472331330300205370ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. // Example to be used with the autobahn test suite, a fully automated test // suite to verify client and server implementations of websocket // implementation. // // Once started, the tests can be executed with: wstest -m fuzzingserver // // See https://github.com/crossbario/autobahn-testsuite for details.
use futures::io::{BufReader, BufWriter}; use soketto::{connection, handshake, BoxedError}; use std::str::FromStr; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION"); #[tokio::main] async fn main() -> Result<(), BoxedError> { let n = num_of_cases().await?; for i in 1..=n { if let Err(e) = run_case(i).await { log::error!("case {}: {:?}", i, e) } } update_report().await?; Ok(()) } async fn num_of_cases() -> Result { let socket = TcpStream::connect("127.0.0.1:9001").await?; let mut client = new_client(socket, "/getCaseCount"); assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); let (_, mut receiver) = client.into_builder().finish(); let mut data = Vec::new(); let kind = receiver.receive_data(&mut data).await?; assert!(kind.is_text()); let num = usize::from_str(std::str::from_utf8(&data)?)?; log::info!("{} cases to run", num); Ok(num) } async fn run_case(n: usize) -> Result<(), BoxedError> { log::info!("running case {}", n); let resource = format!("/runCase?case={}&agent=soketto-{}", n, SOKETTO_VERSION); let socket = TcpStream::connect("127.0.0.1:9001").await?; let mut client = new_client(socket, &resource); assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); let (mut sender, mut receiver) = client.into_builder().finish(); let mut message = Vec::new(); loop { message.clear(); match receiver.receive_data(&mut message).await { Ok(soketto::Data::Binary(n)) => { assert_eq!(n, message.len()); sender.send_binary_mut(&mut message).await?; sender.flush().await? } Ok(soketto::Data::Text(n)) => { assert_eq!(n, message.len()); sender.send_text(std::str::from_utf8(&message)?).await?; sender.flush().await? 
} Err(connection::Error::Closed) => return Ok(()), Err(e) => return Err(e.into()), } } } async fn update_report() -> Result<(), BoxedError> { log::info!("requesting report generation"); let resource = format!("/updateReports?agent=soketto-{}", SOKETTO_VERSION); let socket = TcpStream::connect("127.0.0.1:9001").await?; let mut client = new_client(socket, &resource); assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); client.into_builder().finish().0.close().await?; Ok(()) } #[cfg(not(feature = "deflate"))] fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "127.0.0.1:9001", path) } #[cfg(feature = "deflate")] fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(64 * 1024, socket.compat())); let mut client = handshake::Client::new(socket, "127.0.0.1:9001", path); let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Client); client.add_extension(Box::new(deflate)); client } soketto-0.8.1/examples/autobahn_server.rs000066400000000000000000000053051472331330300205650ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. // Example to be used with the autobahn test suite, a fully automated test // suite to verify client and server implementations of websocket // implementation. // // Once started, the tests can be executed with: wstest -m fuzzingclient // // See https://github.com/crossbario/autobahn-testsuite for details. 
use futures::io::{BufReader, BufWriter}; use soketto::{connection, handshake, BoxedError}; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; #[tokio::main] async fn main() -> Result<(), BoxedError> { let listener = TcpListener::bind("127.0.0.1:9001").await?; let mut incoming = TcpListenerStream::new(listener); while let Some(socket) = incoming.next().await { let mut server = new_server(socket?); let key = { let req = server.receive_request().await?; req.key() }; let accept = handshake::server::Response::Accept { key, protocol: None }; server.send_response(&accept).await?; let (mut sender, mut receiver) = server.into_builder().finish(); let mut message = Vec::new(); loop { message.clear(); match receiver.receive_data(&mut message).await { Ok(soketto::Data::Binary(n)) => { assert_eq!(n, message.len()); sender.send_binary_mut(&mut message).await?; sender.flush().await? } Ok(soketto::Data::Text(n)) => { assert_eq!(n, message.len()); if let Ok(txt) = std::str::from_utf8(&message) { sender.send_text(txt).await?; sender.flush().await? 
} else { break; } } Err(connection::Error::Closed) => break, Err(e) => { log::error!("connection error: {}", e); break; } } } } Ok(()) } #[cfg(not(feature = "deflate"))] fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { handshake::Server::new(BufReader::new(BufWriter::new(socket.compat()))) } #[cfg(feature = "deflate")] fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(16 * 1024, socket.compat())); let mut server = handshake::Server::new(socket); let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); server.add_extension(Box::new(deflate)); server } soketto-0.8.1/examples/hyper_server.rs000066400000000000000000000122261472331330300201130ustar00rootroot00000000000000// Copyright (c) 2021 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. // An example of how to use of Soketto alongside Hyper, so that we can handle // standard HTTP traffic with Hyper, and WebSocket connections with Soketto, on // the same port. // // To try this, start up the example (`cargo run --example hyper_server`) and then // navigate to localhost:3000 and, in the browser JS console, run: // // ``` // var socket = new WebSocket("ws://localhost:3000"); // socket.onmessage = function(msg) { console.log(msg) }; // socket.send("Hello!"); // ``` // // You'll see any messages you send echoed back. 
use std::net::SocketAddr; use futures::io::{BufReader, BufWriter}; use hyper::server::conn::http1; use hyper::{body::Bytes, service::service_fn, Request, Response}; use hyper_util::rt::TokioIo; use soketto::{ handshake::http::{is_upgrade_request, Server}, BoxedError, }; use tokio_util::compat::TokioAsyncReadCompatExt; type FullBody = http_body_util::Full; /// Start up a hyper server. #[tokio::main] async fn main() -> Result<(), BoxedError> { env_logger::init(); let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); let listener = tokio::net::TcpListener::bind(addr).await?; log::info!( "Listening on http://{:?} — connect and I'll echo back anything you send!", listener.local_addr().unwrap() ); loop { let stream = match listener.accept().await { Ok((stream, addr)) => { log::info!("Accepting new connection: {addr}"); stream } Err(e) => { log::error!("Accepting new connection failed: {e}"); continue; } }; tokio::spawn(async { let io = TokioIo::new(stream); let conn = http1::Builder::new().serve_connection(io, service_fn(handler)); // Enable upgrades on the connection for the websocket upgrades to work. let conn = conn.with_upgrades(); // Log any errors that might have occurred during the connection. if let Err(err) = conn.await { log::error!("HTTP connection failed {err}"); } }); } } /// Handle incoming HTTP Requests. async fn handler(req: Request) -> Result, BoxedError> { if is_upgrade_request(&req) { // Create a new handshake server. let mut server = Server::new(); // Add any extensions that we want to use. #[cfg(feature = "deflate")] { let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); server.add_extension(Box::new(deflate)); } // Attempt the handshake. 
match server.receive_request(&req) { // The handshake has been successful so far; return the response we're given back // and spawn a task to handle the long-running WebSocket server: Ok(response) => { tokio::spawn(async move { if let Err(e) = websocket_echo_messages(server, req).await { log::error!("Error upgrading to websocket connection: {}", e); } }); Ok(response.map(|()| FullBody::default())) } // We tried to upgrade and failed early on; tell the client about the failure however we like: Err(e) => { log::error!("Could not upgrade connection: {}", e); Ok(Response::new(FullBody::from("Something went wrong upgrading!"))) } } } else { // The request wasn't an upgrade request; let's treat it as a standard HTTP request: Ok(Response::new(FullBody::from("Hello HTTP!"))) } } /// Echo any messages we get from the client back to them async fn websocket_echo_messages(server: Server, req: Request) -> Result<(), BoxedError> { // The negotiation to upgrade to a WebSocket connection has been successful so far. Next, we get back the underlying // stream using `hyper::upgrade::on`, and hand this to a Soketto server to use to handle the WebSocket communication // on this socket. // // Note: awaiting this won't succeed until the handshake response has been returned to the client, so this must be // spawned on a separate task so as not to block that response being handed back. let stream = hyper::upgrade::on(req).await?; let io = TokioIo::new(stream); let stream = BufReader::new(BufWriter::new(io.compat())); // Get back a reader and writer that we can use to send and receive websocket messages. let (mut sender, mut receiver) = server.into_builder(stream).finish(); // Echo any received messages back to the client: let mut message = Vec::new(); loop { message.clear(); match receiver.receive_data(&mut message).await { Ok(soketto::Data::Binary(n)) => { assert_eq!(n, message.len()); sender.send_binary_mut(&mut message).await?; sender.flush().await? 
} Ok(soketto::Data::Text(n)) => { assert_eq!(n, message.len()); if let Ok(txt) = std::str::from_utf8(&message) { sender.send_text(txt).await?; sender.flush().await? } else { break; } } Err(soketto::connection::Error::Closed) => break, Err(e) => { eprintln!("Websocket connection error: {}", e); break; } } } Ok(()) } soketto-0.8.1/rustfmt.toml000066400000000000000000000001171472331330300156070ustar00rootroot00000000000000hard_tabs = true max_width = 120 use_small_heuristics = "Max" edition = "2018" soketto-0.8.1/src/000077500000000000000000000000001472331330300137765ustar00rootroot00000000000000soketto-0.8.1/src/base.rs000066400000000000000000000435301472331330300152630ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // Copyright (c) 2016 twist developers // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. // This file is largely based on the original twist implementation. // See [frame/base.rs] and [codec/base.rs]. // // [frame/base.rs]: https://github.com/rustyhorde/twist/blob/449d8b75c2/src/frame/base.rs // [codec/base.rs]: https://github.com/rustyhorde/twist/blob/449d8b75c2/src/codec/base.rs //! A websocket [base frame][base] codec. //! //! [base]: https://tools.ietf.org/html/rfc6455#section-5.2 use crate::{as_u64, Parsing}; use std::{fmt, io}; /// Max. size of a frame header. pub(crate) const MAX_HEADER_SIZE: usize = 14; /// Max. size of a control frame payload. pub(crate) const MAX_CTRL_BODY_SIZE: u64 = 125; // OpCode ///////////////////////////////////////////////////////////////////////////////////////// /// Operation codes defined in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.2). #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Copy)] pub enum OpCode { /// A continuation frame of a fragmented message. 
Continue, /// A text data frame. Text, /// A binary data frame. Binary, /// A close control frame. Close, /// A ping control frame. Ping, /// A pong control frame. Pong, /// A reserved op code. Reserved3, /// A reserved op code. Reserved4, /// A reserved op code. Reserved5, /// A reserved op code. Reserved6, /// A reserved op code. Reserved7, /// A reserved op code. Reserved11, /// A reserved op code. Reserved12, /// A reserved op code. Reserved13, /// A reserved op code. Reserved14, /// A reserved op code. Reserved15, } impl OpCode { /// Is this a control opcode? pub fn is_control(self) -> bool { if let OpCode::Close | OpCode::Ping | OpCode::Pong = self { true } else { false } } /// Is this opcode reserved? pub fn is_reserved(self) -> bool { match self { OpCode::Reserved3 | OpCode::Reserved4 | OpCode::Reserved5 | OpCode::Reserved6 | OpCode::Reserved7 | OpCode::Reserved11 | OpCode::Reserved12 | OpCode::Reserved13 | OpCode::Reserved14 | OpCode::Reserved15 => true, _ => false, } } } impl fmt::Display for OpCode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { OpCode::Continue => f.write_str("Continue"), OpCode::Text => f.write_str("Text"), OpCode::Binary => f.write_str("Binary"), OpCode::Close => f.write_str("Close"), OpCode::Ping => f.write_str("Ping"), OpCode::Pong => f.write_str("Pong"), OpCode::Reserved3 => f.write_str("Reserved:3"), OpCode::Reserved4 => f.write_str("Reserved:4"), OpCode::Reserved5 => f.write_str("Reserved:5"), OpCode::Reserved6 => f.write_str("Reserved:6"), OpCode::Reserved7 => f.write_str("Reserved:7"), OpCode::Reserved11 => f.write_str("Reserved:11"), OpCode::Reserved12 => f.write_str("Reserved:12"), OpCode::Reserved13 => f.write_str("Reserved:13"), OpCode::Reserved14 => f.write_str("Reserved:14"), OpCode::Reserved15 => f.write_str("Reserved:15"), } } } /// Error returned by `OpCode::try_from` if an unknown opcode /// number is encountered. 
#[derive(Clone, Debug)]
pub struct UnknownOpCode(());

impl fmt::Display for UnknownOpCode {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		f.write_str("unknown opcode")
	}
}

impl std::error::Error for UnknownOpCode {}

impl TryFrom<u8> for OpCode {
	type Error = UnknownOpCode;

	/// Map the 4-bit opcode number of a frame header to an [`OpCode`].
	///
	/// Values above 15 do not fit the opcode field and yield [`UnknownOpCode`].
	fn try_from(val: u8) -> Result<Self, Self::Error> {
		match val {
			0 => Ok(OpCode::Continue),
			1 => Ok(OpCode::Text),
			2 => Ok(OpCode::Binary),
			3 => Ok(OpCode::Reserved3),
			4 => Ok(OpCode::Reserved4),
			5 => Ok(OpCode::Reserved5),
			6 => Ok(OpCode::Reserved6),
			7 => Ok(OpCode::Reserved7),
			8 => Ok(OpCode::Close),
			9 => Ok(OpCode::Ping),
			10 => Ok(OpCode::Pong),
			11 => Ok(OpCode::Reserved11),
			12 => Ok(OpCode::Reserved12),
			13 => Ok(OpCode::Reserved13),
			14 => Ok(OpCode::Reserved14),
			15 => Ok(OpCode::Reserved15),
			_ => Err(UnknownOpCode(())),
		}
	}
}

impl From<OpCode> for u8 {
	/// The opcode number written into the low nibble of a frame's first byte.
	fn from(opcode: OpCode) -> u8 {
		match opcode {
			OpCode::Continue => 0,
			OpCode::Text => 1,
			OpCode::Binary => 2,
			OpCode::Close => 8,
			OpCode::Ping => 9,
			OpCode::Pong => 10,
			OpCode::Reserved3 => 3,
			OpCode::Reserved4 => 4,
			OpCode::Reserved5 => 5,
			OpCode::Reserved6 => 6,
			OpCode::Reserved7 => 7,
			OpCode::Reserved11 => 11,
			OpCode::Reserved12 => 12,
			OpCode::Reserved13 => 13,
			OpCode::Reserved14 => 14,
			OpCode::Reserved15 => 15,
		}
	}
}

// Frame header ///////////////////////////////////////////////////////////////////////////////////

/// A websocket base frame header, i.e. everything but the payload.
#[derive(Debug, Clone)]
pub struct Header {
	fin: bool,
	rsv1: bool,
	rsv2: bool,
	rsv3: bool,
	masked: bool,
	opcode: OpCode,
	mask: u32,
	payload_len: usize,
}

impl fmt::Display for Header {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		write!(
			f,
			"({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))",
			self.opcode,
			self.fin as u8,
			self.rsv1 as u8,
			self.rsv2 as u8,
			self.rsv3 as u8,
			self.masked as u8,
			self.mask,
			self.payload_len
		)
	}
}

impl Header {
	/// Create a new frame header with a given [`OpCode`].
	///
	/// The header starts out with `fin` set, all reserved bits cleared, no mask
	/// and a payload length of 0.
	pub fn new(oc: OpCode) -> Self {
		Header { fin: true, rsv1: false, rsv2: false, rsv3: false, masked: false, opcode: oc, mask: 0, payload_len: 0 }
	}

	/// Is the `fin` flag set?
	pub fn is_fin(&self) -> bool {
		self.fin
	}

	/// Set the `fin` flag.
	pub fn set_fin(&mut self, fin: bool) -> &mut Self {
		self.fin = fin;
		self
	}

	/// Is the `rsv1` flag set?
	pub fn is_rsv1(&self) -> bool {
		self.rsv1
	}

	/// Set the `rsv1` flag.
	pub fn set_rsv1(&mut self, rsv1: bool) -> &mut Self {
		self.rsv1 = rsv1;
		self
	}

	/// Is the `rsv2` flag set?
	pub fn is_rsv2(&self) -> bool {
		self.rsv2
	}

	/// Set the `rsv2` flag.
	pub fn set_rsv2(&mut self, rsv2: bool) -> &mut Self {
		self.rsv2 = rsv2;
		self
	}

	/// Is the `rsv3` flag set?
	pub fn is_rsv3(&self) -> bool {
		self.rsv3
	}

	/// Set the `rsv3` flag.
	pub fn set_rsv3(&mut self, rsv3: bool) -> &mut Self {
		self.rsv3 = rsv3;
		self
	}

	/// Is the `masked` flag set?
	pub fn is_masked(&self) -> bool {
		self.masked
	}

	/// Set the `masked` flag.
	pub fn set_masked(&mut self, masked: bool) -> &mut Self {
		self.masked = masked;
		self
	}

	/// Get the `opcode`.
	pub fn opcode(&self) -> OpCode {
		self.opcode
	}

	/// Set the `opcode`
	pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Self {
		self.opcode = opcode;
		self
	}

	/// Get the `mask`.
	pub fn mask(&self) -> u32 {
		self.mask
	}

	/// Set the `mask`
	pub fn set_mask(&mut self, mask: u32) -> &mut Self {
		self.mask = mask;
		self
	}

	/// Get the payload length.
	pub fn payload_len(&self) -> usize {
		self.payload_len
	}

	/// Set the payload length.
	pub fn set_payload_len(&mut self, len: usize) -> &mut Self {
		self.payload_len = len;
		self
	}
}

// Base codec ////////////////////////////////////////////////////////////////////////////////////.

/// If the payload length byte is 126, the following two bytes represent the
/// actual payload length.
const TWO_EXT: u8 = 126;

/// If the payload length byte is 127, the following eight bytes represent
/// the actual payload length.
const EIGHT_EXT: u8 = 127;

/// Codec for encoding/decoding websocket [base] frames.
/// /// [base]: https://tools.ietf.org/html/rfc6455#section-5.2 #[derive(Debug, Clone)] pub struct Codec { /// Maximum size of payload data per frame. max_data_size: usize, /// Bits reserved by an extension. reserved_bits: u8, /// Scratch buffer used during header encoding. header_buffer: [u8; MAX_HEADER_SIZE], } impl Default for Codec { fn default() -> Self { Codec { max_data_size: 256 * 1024 * 1024, reserved_bits: 0, header_buffer: [0; MAX_HEADER_SIZE] } } } impl Codec { /// Create a new base frame codec. /// /// The codec will support decoding payload lengths up to 256 MiB /// (use `set_max_data_size` to change this value). pub fn new() -> Self { Codec::default() } /// Get the configured maximum payload length. pub fn max_data_size(&self) -> usize { self.max_data_size } /// Limit the maximum size of payload data to `size` bytes. pub fn set_max_data_size(&mut self, size: usize) -> &mut Self { self.max_data_size = size; self } /// The reserved bits currently configured. pub fn reserved_bits(&self) -> (bool, bool, bool) { let r = self.reserved_bits; (r & 4 == 4, r & 2 == 2, r & 1 == 1) } /// Add to the reserved bits in use. pub fn add_reserved_bits(&mut self, bits: (bool, bool, bool)) -> &mut Self { let (r1, r2, r3) = bits; self.reserved_bits |= (r1 as u8) << 2 | (r2 as u8) << 1 | r3 as u8; self } /// Reset the reserved bits. pub fn clear_reserved_bits(&mut self) { self.reserved_bits = 0 } /// Decode a websocket frame header. 
pub fn decode_header(&self, bytes: &[u8]) -> Result, Error> { if bytes.len() < 2 { return Ok(Parsing::NeedMore(2 - bytes.len())); } let first = bytes[0]; let second = bytes[1]; let mut offset = 2; let fin = first & 0x80 != 0; let opcode = OpCode::try_from(first & 0xF)?; if opcode.is_reserved() { return Err(Error::ReservedOpCode); } if opcode.is_control() && !fin { return Err(Error::FragmentedControl); } let mut header = Header::new(opcode); header.set_fin(fin); let rsv1 = first & 0x40 != 0; if rsv1 && (self.reserved_bits & 4 == 0) { return Err(Error::InvalidReservedBit(1)); } header.set_rsv1(rsv1); let rsv2 = first & 0x20 != 0; if rsv2 && (self.reserved_bits & 2 == 0) { return Err(Error::InvalidReservedBit(2)); } header.set_rsv2(rsv2); let rsv3 = first & 0x10 != 0; if rsv3 && (self.reserved_bits & 1 == 0) { return Err(Error::InvalidReservedBit(3)); } header.set_rsv3(rsv3); header.set_masked(second & 0x80 != 0); let len: u64 = match second & 0x7F { TWO_EXT => { if bytes.len() < offset + 2 { return Ok(Parsing::NeedMore(offset + 2 - bytes.len())); } let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]); offset += 2; u64::from(len) } EIGHT_EXT => { if bytes.len() < offset + 8 { return Ok(Parsing::NeedMore(offset + 8 - bytes.len())); } let mut b = [0; 8]; b.copy_from_slice(&bytes[offset..offset + 8]); offset += 8; u64::from_be_bytes(b) } n => u64::from(n), }; if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() { return Err(Error::InvalidControlFrameLen); } let len: usize = if len > as_u64(self.max_data_size) { return Err(Error::PayloadTooLarge { actual: len, maximum: as_u64(self.max_data_size) }); } else { len as usize }; header.set_payload_len(len); if header.is_masked() { if bytes.len() < offset + 4 { return Ok(Parsing::NeedMore(offset + 4 - bytes.len())); } let mut b = [0; 4]; b.copy_from_slice(&bytes[offset..offset + 4]); offset += 4; header.set_mask(u32::from_be_bytes(b)); } Ok(Parsing::Done { value: header, offset }) } /// Encode a websocket 
frame header. pub fn encode_header(&mut self, header: &Header) -> &[u8] { let mut offset = 0; let mut first_byte = 0_u8; if header.is_fin() { first_byte |= 0x80 } if header.is_rsv1() { first_byte |= 0x40 } if header.is_rsv2() { first_byte |= 0x20 } if header.is_rsv3() { first_byte |= 0x10 } let opcode: u8 = header.opcode().into(); first_byte |= opcode; self.header_buffer[offset] = first_byte; offset += 1; let mut second_byte = 0_u8; if header.is_masked() { second_byte |= 0x80 } let len = header.payload_len(); if len < usize::from(TWO_EXT) { second_byte |= len as u8; self.header_buffer[offset] = second_byte; offset += 1; } else if len <= usize::from(u16::max_value()) { second_byte |= TWO_EXT; self.header_buffer[offset] = second_byte; offset += 1; self.header_buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_be_bytes()); offset += 2; } else { second_byte |= EIGHT_EXT; self.header_buffer[offset] = second_byte; offset += 1; self.header_buffer[offset..offset + 8].copy_from_slice(&as_u64(len).to_be_bytes()); offset += 8; } if header.is_masked() { self.header_buffer[offset..offset + 4].copy_from_slice(&header.mask().to_be_bytes()); offset += 4; } &self.header_buffer[..offset] } /// Use the given header's mask and apply it to the data. pub fn apply_mask(header: &Header, data: &mut [u8]) { if header.is_masked() { let mask = header.mask().to_be_bytes(); for (byte, &key) in data.iter_mut().zip(mask.iter().cycle()) { *byte ^= key; } } } } /// Error cases the base frame decoder may encounter. #[non_exhaustive] #[derive(Debug)] pub enum Error { /// An I/O error has been encountered. Io(io::Error), /// Some unknown opcode number has been decoded. UnknownOpCode, /// The opcode decoded is reserved. ReservedOpCode, /// A fragmented control frame (fin bit not set) has been decoded. FragmentedControl, /// A control frame with an invalid length code has been decoded. InvalidControlFrameLen, /// The reserved bit is invalid. 
InvalidReservedBit(u8), /// The payload length of a frame exceeded the configured maximum. PayloadTooLarge { actual: u64, maximum: u64 }, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Error::Io(e) => write!(f, "i/o error: {}", e), Error::UnknownOpCode => f.write_str("unknown opcode"), Error::ReservedOpCode => f.write_str("reserved opcode"), Error::FragmentedControl => f.write_str("fragmented control frame"), Error::InvalidControlFrameLen => f.write_str("invalid control frame length"), Error::InvalidReservedBit(n) => write!(f, "invalid reserved bit: {}", n), Error::PayloadTooLarge { actual, maximum } => { write!(f, "payload too large: len = {}, maximum = {}", actual, maximum) } } } } impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Error::Io(e) => Some(e), Error::UnknownOpCode | Error::ReservedOpCode | Error::FragmentedControl | Error::InvalidControlFrameLen | Error::InvalidReservedBit(_) | Error::PayloadTooLarge { .. } => None, } } } impl From for Error { fn from(e: io::Error) -> Self { Error::Io(e) } } impl From for Error { fn from(_: UnknownOpCode) -> Self { Error::UnknownOpCode } } // Tests ////////////////////////////////////////////////////////////////////////////////////////// #[cfg(test)] mod test { use super::{Codec, Error, OpCode}; use crate::Parsing; use quickcheck::QuickCheck; #[test] fn decode_partial_header() { let partial_header: &[u8] = &[0x89]; assert!(matches! { Codec::new().decode_header(partial_header), Ok(Parsing::NeedMore(1)) }) } #[test] fn decode_partial_len() { let partial_length_1: &[u8] = &[0x89, 0xFE, 0x01]; assert!(matches! { Codec::new().decode_header(partial_length_1), Ok(Parsing::NeedMore(1)) }); let partial_length_2: &[u8] = &[0x89, 0xFF, 0x01, 0x02, 0x03, 0x04]; assert!(matches! 
{ Codec::new().decode_header(partial_length_2), Ok(Parsing::NeedMore(4)) }) } #[test] fn decode_partial_mask() { let partial_mask: &[u8] = &[0x82, 0xFE, 0x01, 0x02, 0x00, 0x00]; assert!(matches! { Codec::new().decode_header(partial_mask), Ok(Parsing::NeedMore(2)) }) } #[test] fn decode_partial_payload() { let partial_payload: &mut [u8] = &mut [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00]; if let Ok(Parsing::Done { value, offset }) = Codec::new().decode_header(partial_payload) { assert_eq!(3, value.payload_len() - (partial_payload.len() - offset)) } else { assert!(false) } } #[test] fn decode_invalid_control_payload_len() { // Payload on control frame must be 125 bytes or less. 2nd byte must be 0xFD or less. let ctrl_payload_len: &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; assert!(matches! { Codec::new().decode_header(ctrl_payload_len), Err(Error::InvalidControlFrameLen) }) } /// Checking that rsv1, rsv2, and rsv3 bit set returns error. #[test] fn decode_reserved() { // rsv1, rsv2, and rsv3. let reserved = [0x90, 0xa0, 0xc0]; for res in &reserved { let mut buf = [0; 2]; buf[0] |= *res; assert!(matches! { Codec::new().decode_header(&buf), Err(Error::InvalidReservedBit(_)) }) } } /// Checking that a control frame, where fin bit is 0, returns an error. #[test] fn decode_fragmented_control() { let second_bytes = [8, 9, 10]; for sb in &second_bytes { let mut buf = [0; 2]; buf[0] |= *sb; assert!(matches! { Codec::new().decode_header(&buf), Err(Error::FragmentedControl) }) } } /// Checking that reserved opcodes return an error. #[test] fn decode_reserved_opcodes() { let reserved = [3, 4, 5, 6, 7, 11, 12, 13, 14, 15]; for res in &reserved { let mut buf = [0; 2]; buf[0] |= 0x80 | *res; assert!(matches! { Codec::new().decode_header(&buf), Err(Error::ReservedOpCode) }) } } #[test] fn decode_ping_no_data() { let ping_no_data: &mut [u8] = &mut [0x89, 0x80, 0x00, 0x00, 0x00, 0x01]; let c = Codec::new(); if let Ok(Parsing::Done { value: header, .. 
}) = c.decode_header(ping_no_data) { assert!(header.is_fin()); assert!(!header.is_rsv1()); assert!(!header.is_rsv2()); assert!(!header.is_rsv3()); assert!(header.opcode() == OpCode::Ping); assert!(header.payload_len() == 0) } else { assert!(false) } } #[test] fn reserved_bits() { fn property(bits: (bool, bool, bool)) -> bool { let mut c = Codec::new(); assert_eq!((false, false, false), c.reserved_bits()); c.add_reserved_bits(bits); bits == c.reserved_bits() } QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool) } } soketto-0.8.1/src/connection.rs000066400000000000000000000526151472331330300165140ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. //! A persistent websocket connection after the handshake phase, represented //! as a [`Sender`] and [`Receiver`] pair. use crate::data::{ByteSlice125, Data, Incoming}; use crate::{ base::{self, Header, OpCode, MAX_HEADER_SIZE}, extension::Extension, Parsing, Storage, }; use bytes::{Buf, BytesMut}; use futures::{ io::{ReadHalf, WriteHalf}, lock::BiLock, prelude::*, }; use std::{fmt, io, str}; /// Accumulated max. size of a complete message. const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; /// Max. size of a single message frame. const MAX_FRAME_SIZE: usize = MAX_MESSAGE_SIZE; /// Is the connection used by a client or server? #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum Mode { /// Client-side of a connection (implies masking of payload data). Client, /// Server-side of a connection. Server, } impl Mode { pub fn is_client(self) -> bool { if let Mode::Client = self { true } else { false } } pub fn is_server(self) -> bool { !self.is_client() } } /// Connection ID. 
#[derive(Clone, Copy, Debug)] struct Id(u32); impl fmt::Display for Id { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:08x}", self.0) } } /// The sending half of a connection. #[derive(Debug)] pub struct Sender { id: Id, mode: Mode, codec: base::Codec, writer: BiLock>, mask_buffer: Vec, extensions: BiLock>>, has_extensions: bool, } /// The receiving half of a connection. #[derive(Debug)] pub struct Receiver { id: Id, mode: Mode, codec: base::Codec, reader: ReadHalf, writer: BiLock>, extensions: BiLock>>, has_extensions: bool, buffer: BytesMut, ctrl_buffer: BytesMut, max_message_size: usize, is_closed: bool, } /// A connection builder. /// /// Allows configuring certain parameters and extensions before /// creating the [`Sender`]/[`Receiver`] pair that represents the /// connection. #[derive(Debug)] pub struct Builder { id: Id, mode: Mode, socket: T, codec: base::Codec, extensions: Vec>, buffer: BytesMut, max_message_size: usize, } impl Builder { /// Create a new `Builder` from the given async I/O resource and mode. /// /// **Note**: Use this type only after a successful [handshake][0]. /// You can either use this crate's [handshake functionality][1] /// or perform the handshake by some other means. /// /// [0]: https://tools.ietf.org/html/rfc6455#section-4 /// [1]: crate::handshake pub fn new(socket: T, mode: Mode) -> Self { let mut codec = base::Codec::default(); codec.set_max_data_size(MAX_FRAME_SIZE); Builder { id: Id(rand::random()), mode, socket, codec, extensions: Vec::new(), buffer: BytesMut::new(), max_message_size: MAX_MESSAGE_SIZE, } } /// Set a custom buffer to use. pub fn set_buffer(&mut self, b: BytesMut) { self.buffer = b } /// Add extensions to use with this connection. /// /// Only enabled extensions will be considered. 
pub fn add_extensions(&mut self, extensions: I) where I: IntoIterator>, { for e in extensions.into_iter().filter(|e| e.is_enabled()) { log::debug!("{}: using extension: {}", self.id, e.name()); self.codec.add_reserved_bits(e.reserved_bits()); self.extensions.push(e) } } /// Set the maximum size of a complete message. /// /// Message fragments will be buffered and concatenated up to this value, /// i.e. the sum of all message frames payload lengths will not be greater /// than this maximum. However, extensions may increase the total message /// size further, e.g. by decompressing the payload data. pub fn set_max_message_size(&mut self, max: usize) { self.max_message_size = max } /// Set the maximum size of a single websocket frame payload. pub fn set_max_frame_size(&mut self, max: usize) { self.codec.set_max_data_size(max); } /// Create a configured [`Sender`]/[`Receiver`] pair. pub fn finish(self) -> (Sender, Receiver) { let (rhlf, whlf) = self.socket.split(); let (wrt1, wrt2) = BiLock::new(whlf); let has_extensions = !self.extensions.is_empty(); let (ext1, ext2) = BiLock::new(self.extensions); let recv = Receiver { id: self.id, mode: self.mode, reader: rhlf, writer: wrt1, codec: self.codec.clone(), extensions: ext1, has_extensions, buffer: self.buffer, ctrl_buffer: BytesMut::new(), max_message_size: self.max_message_size, is_closed: false, }; let send = Sender { id: self.id, mode: self.mode, writer: wrt2, mask_buffer: Vec::new(), codec: self.codec, extensions: ext2, has_extensions, }; (send, recv) } } impl Receiver { /// Receive the next websocket message. /// /// The received frames forming the complete message will be appended to /// the given `message` argument. The returned [`Incoming`] value describes /// the type of data that was received, e.g. binary or textual data. /// /// Interleaved PONG frames are returned immediately as `Data::Pong` /// values. 
If PONGs are not expected or uninteresting, /// [`Receiver::receive_data`] may be used instead which skips over PONGs /// and considers only application payload data. pub async fn receive(&mut self, message: &mut Vec) -> Result, Error> { let mut first_fragment_opcode = None; let mut length: usize = 0; let message_len = message.len(); loop { if self.is_closed { log::debug!("{}: cannot receive, connection is closed", self.id); return Err(Error::Closed); } self.ctrl_buffer.clear(); let mut header = self.receive_header().await?; log::trace!("{}: recv: {}", self.id, header); // Handle control frames: PING, PONG and CLOSE. if header.opcode().is_control() { self.read_buffer(&header).await?; self.ctrl_buffer = self.buffer.split_to(header.payload_len()); base::Codec::apply_mask(&header, &mut self.ctrl_buffer); if header.opcode() == OpCode::Pong { return Ok(Incoming::Pong(&self.ctrl_buffer[..])); } if let Some(close_reason) = self.on_control(&header).await? { log::trace!("{}: recv, incoming CLOSE: {:?}", self.id, close_reason); return Ok(Incoming::Closed(close_reason)); } continue; } length = length.saturating_add(header.payload_len()); // Check if total message does not exceed maximum. if length > self.max_message_size { log::warn!("{}: accumulated message length exceeds maximum", self.id); // Discard bytes that were too large to fit in the buffer. discard_bytes(length as u64, &mut self.reader).await?; return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size }); } // Get the frame's payload data bytes from buffer or socket. 
{ let old_msg_len = message.len(); let bytes_to_read = { let required = header.payload_len(); let buffered = self.buffer.len(); if buffered == 0 { required } else if required > buffered { message.extend_from_slice(&self.buffer); self.buffer.clear(); required - buffered } else { message.extend_from_slice(&self.buffer.split_to(required)); 0 } }; if bytes_to_read > 0 { let n = message.len(); message.resize(n + bytes_to_read, 0u8); self.reader.read_exact(&mut message[n..]).await? } debug_assert_eq!(header.payload_len(), message.len() - old_msg_len); base::Codec::apply_mask(&header, &mut message[old_msg_len..]); } match (header.is_fin(), header.opcode()) { (false, OpCode::Continue) => { // Intermediate message fragment. if first_fragment_opcode.is_none() { log::debug!("{}: continue frame while not processing message fragments", self.id); return Err(Error::UnexpectedOpCode(OpCode::Continue)); } continue; } (false, oc) => { // Initial message fragment. if first_fragment_opcode.is_some() { log::debug!("{}: initial fragment while processing a fragmented message", self.id); return Err(Error::UnexpectedOpCode(oc)); } first_fragment_opcode = Some(oc); self.decode_with_extensions(&mut header, message).await?; continue; } (true, OpCode::Continue) => { // Last message fragment. if let Some(oc) = first_fragment_opcode.take() { header.set_payload_len(message.len()); log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len()); self.decode_with_extensions(&mut header, message).await?; header.set_opcode(oc); } else { log::debug!("{}: last continue frame while not processing message fragments", self.id); return Err(Error::UnexpectedOpCode(OpCode::Continue)); } } (true, oc) => { // Regular non-fragmented message. if first_fragment_opcode.is_some() { log::debug!("{}: regular message while processing fragmented message", self.id); return Err(Error::UnexpectedOpCode(oc)); } self.decode_with_extensions(&mut header, message).await? 
} } let num_bytes = message.len() - message_len; if header.opcode() == OpCode::Text { return Ok(Incoming::Data(Data::Text(num_bytes))); } else { return Ok(Incoming::Data(Data::Binary(num_bytes))); } } } /// Receive the next websocket message, skipping over control frames. pub async fn receive_data(&mut self, message: &mut Vec) -> Result { loop { if let Incoming::Data(d) = self.receive(message).await? { return Ok(d); } } } /// Read the next frame header. async fn receive_header(&mut self) -> Result { loop { match self.codec.decode_header(&self.buffer)? { Parsing::Done { value: header, offset } => { debug_assert!(offset <= MAX_HEADER_SIZE); self.buffer.advance(offset); return Ok(header); } Parsing::NeedMore(n) => crate::read(&mut self.reader, &mut self.buffer, n).await?, } } } /// Read the complete payload data into the read buffer. async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> { if header.payload_len() <= self.buffer.len() { return Ok(()); } let i = self.buffer.len(); let d = header.payload_len() - i; self.buffer.resize(i + d, 0u8); self.reader.read_exact(&mut self.buffer[i..]).await?; Ok(()) } /// Answer incoming control frames. 
/// `PING`: replied to immediately with a `PONG` /// `PONG`: no action /// `CLOSE`: replied to immediately with a `CLOSE`; returns the [`CloseReason`] /// All other [`OpCode`]s return [`Error::UnexpectedOpCode`] async fn on_control(&mut self, header: &Header) -> Result, Error> { match header.opcode() { OpCode::Ping => { let mut answer = Header::new(OpCode::Pong); let mut unused = Vec::new(); let mut data = Storage::Unique(&mut self.ctrl_buffer); write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused) .await?; self.flush().await?; Ok(None) } OpCode::Pong => Ok(None), OpCode::Close => { log::trace!("{}: Acknowledging CLOSE to sender", self.id); let (mut header, reason) = close_answer(&self.ctrl_buffer)?; // Write back a Close frame let mut unused = Vec::new(); if let Some(CloseReason { code, .. }) = reason { let mut data = code.to_be_bytes(); let mut data = Storage::Unique(&mut data); let _ = write( self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused, ) .await; } else { let mut data = Storage::Unique(&mut []); let _ = write( self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused, ) .await; } self.flush().await?; // Close down the connection but the I/O stream could already be closed and // we don't want propagate such error to the user if the I/O was already closed. _ = self.writer.lock().await.close().await; self.is_closed = true; Ok(reason) } OpCode::Binary | OpCode::Text | OpCode::Continue | OpCode::Reserved3 | OpCode::Reserved4 | OpCode::Reserved5 | OpCode::Reserved6 | OpCode::Reserved7 | OpCode::Reserved11 | OpCode::Reserved12 | OpCode::Reserved13 | OpCode::Reserved14 | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())), } } /// Apply all extensions to the given header and the internal message buffer. 
///
/// Does nothing when `has_extensions` is false.
async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec) -> Result<(), Error> {
    // Avoid taking the extensions lock at all when there is nothing to apply.
    if !self.has_extensions {
        return Ok(());
    }
    // Extensions run in list order; the first failure aborts decoding and is
    // wrapped in `Error::Extension`.
    for e in self.extensions.lock().await.iter_mut() {
        log::trace!("{}: decoding with extension: {}", self.id, e.name());
        e.decode(header, message).map_err(Error::Extension)?
    }
    Ok(())
}

/// Flush the socket buffer.
///
/// This is a no-op once the connection has been marked closed (`is_closed`,
/// which is set after a `CLOSE` frame has been acknowledged). Flush failures
/// are reported as [`Error::Closed`].
async fn flush(&mut self) -> Result<(), Error> {
    log::trace!("{}: Receiver flushing connection", self.id);
    if self.is_closed {
        return Ok(());
    }
    self.writer.lock().await.flush().await.or(Err(Error::Closed))
}
}

impl Sender {
    /// Send a text value over the websocket connection.
    ///
    /// The text is borrowed (`Storage::Shared`); if the frame has to be masked,
    /// the payload is copied into a scratch buffer first. Extensions are applied.
    pub async fn send_text(&mut self, data: impl AsRef) -> Result<(), Error> {
        let mut header = Header::new(OpCode::Text);
        self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await
    }

    /// Send a text value over the websocket connection.
    ///
    /// This method performs one copy fewer than [`Sender::send_text`].
    pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> {
        let mut header = Header::new(OpCode::Text);
        self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await
    }

    /// Send some binary data over the websocket connection.
    ///
    /// The data is borrowed (`Storage::Shared`); extensions are applied.
    pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> {
        let mut header = Header::new(OpCode::Binary);
        self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await
    }

    /// Send some binary data over the websocket connection.
    ///
    /// This method performs one copy fewer than [`Sender::send_binary`].
    /// The `data` buffer may be modified by this method, e.g. if masking is necessary.
    pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> {
        let mut header = Header::new(OpCode::Binary);
        self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await
    }

    /// Ping the remote end.
    ///
    /// The payload is limited to 125 bytes by [`ByteSlice125`]. Control frames
    /// bypass the extensions (written via `Sender::write`).
    pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> {
        let mut header = Header::new(OpCode::Ping);
        self.write(&mut header, &mut Storage::Shared(data.as_ref())).await
    }

    /// Send an unsolicited Pong to the remote.
    ///
    /// Like [`Sender::send_ping`], this bypasses the extensions.
    pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> {
        let mut header = Header::new(OpCode::Pong);
        self.write(&mut header, &mut Storage::Shared(data.as_ref())).await
    }

    /// Flush the socket buffer.
    ///
    /// Flush failures are reported as [`Error::Closed`].
    pub async fn flush(&mut self) -> Result<(), Error> {
        log::trace!("{}: Sender flushing connection", self.id);
        self.writer.lock().await.flush().await.or(Err(Error::Closed))
    }

    /// Send a close message and close the connection.
    ///
    /// Writes a `CLOSE` frame with status code 1000 (bypassing extensions),
    /// flushes, and finally closes the writer.
    pub async fn close(&mut self) -> Result<(), Error> {
        log::trace!("{}: closing connection", self.id);
        let mut header = Header::new(OpCode::Close);
        let code = 1000_u16.to_be_bytes(); // 1000 = normal closure
        self.write(&mut header, &mut Storage::Shared(&code[..])).await?;
        self.flush().await?;
        self.writer.lock().await.close().await.or(Err(Error::Closed))
    }

    /// Send arbitrary websocket frames.
    ///
    /// Before sending, extensions will be applied to header and payload data.
    async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> {
        // Avoid taking the extensions lock when none are in use.
        if !self.has_extensions {
            return self.write(header, data).await;
        }

        for e in self.extensions.lock().await.iter_mut() {
            log::trace!("{}: encoding with extension: {}", self.id, e.name());
            e.encode(header, data).map_err(Error::Extension)?
        }

        self.write(header, data).await
    }

    /// Write final header and payload data to socket.
    ///
    /// The data will be masked if necessary.
    /// No extensions will be applied to header and payload data.
    async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> {
        write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await
    }
}

/// Write header and payload data to socket.
///
/// In client mode the frame is marked as masked with a fresh random key
/// (`rand::random()`); otherwise the payload is written unmodified. The
/// payload length in `header` is always set from `data`. Every write failure
/// is reported as [`Error::Closed`].
///
/// `mask_buffer` is scratch space, used only for `Storage::Shared` payloads:
/// they are borrowed immutably and are therefore copied before masking.
/// `Storage::Unique` and `Storage::Owned` payloads are masked in place.
async fn write(
    id: Id,
    mode: Mode,
    codec: &mut base::Codec,
    writer: &mut BiLock>,
    header: &mut Header,
    data: &mut Storage<'_>,
    mask_buffer: &mut Vec,
) -> Result<(), Error> {
    if mode.is_client() {
        header.set_masked(true);
        header.set_mask(rand::random());
    }
    header.set_payload_len(data.as_ref().len());

    log::trace!("{}: send: {}", id, header);

    let header_bytes = codec.encode_header(&header);
    // The lock is held across the header and the payload writes so that both
    // parts of the frame are written back-to-back.
    let mut w = writer.lock().await;
    w.write_all(&header_bytes).await.or(Err(Error::Closed))?;

    if !header.is_masked() {
        return w.write_all(data.as_ref()).await.or(Err(Error::Closed));
    }

    match data {
        Storage::Shared(slice) => {
            // Shared data cannot be modified: mask a copy instead.
            mask_buffer.clear();
            mask_buffer.extend_from_slice(slice);
            base::Codec::apply_mask(header, mask_buffer);
            w.write_all(mask_buffer).await.or(Err(Error::Closed))
        }
        Storage::Unique(slice) => {
            base::Codec::apply_mask(header, slice);
            w.write_all(slice).await.or(Err(Error::Closed))
        }
        Storage::Owned(ref mut bytes) => {
            base::Codec::apply_mask(header, bytes);
            w.write_all(bytes).await.or(Err(Error::Closed))
        }
    }
}

/// Create a close frame based on the given data. The close frame is echoed back
/// to the sender.
fn close_answer(data: &[u8]) -> Result<(Header, Option), Error> { let answer = Header::new(OpCode::Close); if data.len() < 2 { return Ok((answer, None)); } // Check that the reason string is properly encoded let descr = std::str::from_utf8(&data[2..])?.into(); let code = u16::from_be_bytes([data[0], data[1]]); let reason = CloseReason { code, descr: Some(descr) }; // Status codes are defined in // https://tools.ietf.org/html/rfc6455#section-7.4.1 and // https://mailarchive.ietf.org/arch/msg/hybi/P_1vbD9uyHl63nbIIbFxKMfSwcM/ match code { | 1000 ..= 1003 | 1007 ..= 1011 | 1012 // Service Restart | 1013 // Try Again Later | 1015 | 3000 ..= 4999 => Ok((answer, Some(reason))), // acceptable codes _ => { // invalid code => protocol error (1002) Ok((answer, Some(CloseReason { code: 1002, descr: None}))) } } } /// Errors which may occur when sending or receiving messages. #[non_exhaustive] #[derive(Debug)] pub enum Error { /// An I/O error was encountered. Io(io::Error), /// The base codec errored. Codec(base::Error), /// An extension produced an error while encoding or decoding. Extension(crate::BoxedError), /// An unexpected opcode was encountered. UnexpectedOpCode(OpCode), /// A close reason was not correctly UTF-8 encoded. Utf8(str::Utf8Error), /// The total message payload data size exceeds the configured maximum. MessageTooLarge { current: usize, maximum: usize }, /// The connection is closed. Closed, } /// Reason for closing the connection. 
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CloseReason {
    /// The close status code.
    pub code: u16,
    /// Optional reason text accompanying the status code.
    pub descr: Option<String>,
}

/// Human-readable rendering of a connection error.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Error::Closed => f.write_str("connection closed"),
            Error::Io(e) => write!(f, "i/o error: {}", e),
            Error::Codec(e) => write!(f, "codec error: {}", e),
            Error::Extension(e) => write!(f, "extension error: {}", e),
            Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c),
            Error::Utf8(e) => write!(f, "utf-8 error: {}", e),
            Error::MessageTooLarge { current, maximum } => {
                write!(f, "message too large: len >= {}, maximum = {}", current, maximum)
            }
        }
    }
}

/// Exposes the wrapped error (if any) as the source.
impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Error::Io(e) => e,
            Error::Codec(e) => e,
            Error::Extension(e) => &**e,
            Error::Utf8(e) => e,
            Error::UnexpectedOpCode(_) | Error::MessageTooLarge { .. } | Error::Closed => return None,
        };
        Some(inner)
    }
}

/// An unexpected EOF is reported as a closed connection; every other I/O
/// error is kept as [`Error::Io`].
impl From<io::Error> for Error {
    fn from(e: io::Error) -> Self {
        match e.kind() {
            io::ErrorKind::UnexpectedEof => Self::Closed,
            _ => Self::Io(e),
        }
    }
}

impl From<str::Utf8Error> for Error {
    fn from(e: str::Utf8Error) -> Self {
        Self::Utf8(e)
    }
}

impl From<base::Error> for Error {
    fn from(e: base::Error) -> Self {
        Self::Codec(e)
    }
}

/// Discard `n` bytes from the underlying reader.
///
/// Returns the number of bytes actually discarded, which is less than `n` if
/// the reader reaches EOF first.
async fn discard_bytes<R: AsyncRead + Unpin>(n: u64, reader: R) -> Result<u64, io::Error> {
    futures::io::copy(&mut reader.take(n), &mut futures::io::sink()).await
}

#[cfg(test)]
mod tests {
    use super::discard_bytes;
    use futures::{io::Cursor, AsyncReadExt};

    #[tokio::test]
    async fn discard_bytes_works() {
        let bytes: Vec<u8> = (0..5).collect();
        let mut cursor = Cursor::new(bytes);
        discard_bytes(1_u64, &mut cursor).await.unwrap();
        let mut read = vec![0; 4];
        cursor.read_exact(&mut read).await.unwrap();
        assert_eq!(read, vec![1, 2, 3, 4]);
    }
}
soketto-0.8.1/src/data.rs000066400000000000000000000052211472331330300152550ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0
// or the MIT
// license , at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.

//! Types describing various forms of payload data.

use std::fmt;

use crate::connection::CloseReason;

/// Data received from the remote end.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Incoming<'a> {
    /// Text or binary data.
    Data(Data),
    /// Data sent with a PONG control frame.
    Pong(&'a [u8]),
    /// The other end closed the connection.
    Closed(CloseReason),
}

impl Incoming<'_> {
    /// Is this text or binary data?
    pub fn is_data(&self) -> bool {
        matches!(self, Incoming::Data(_))
    }

    /// Is this a PONG?
    pub fn is_pong(&self) -> bool {
        matches!(self, Incoming::Pong(_))
    }

    /// Is this text data?
    pub fn is_text(&self) -> bool {
        matches!(self, Incoming::Data(Data::Text(_)))
    }

    /// Is this binary data?
    pub fn is_binary(&self) -> bool {
        matches!(self, Incoming::Data(Data::Binary(_)))
    }
}

/// Kind and size of a received data message.
///
/// Only the number of payload bytes is recorded here; the payload itself is
/// appended to the buffer the caller passed to the receiver.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// Textual data (number of bytes).
    Text(usize),
    /// Binary data (number of bytes).
    Binary(usize),
}

impl Data {
    /// Is this text data?
    pub fn is_text(&self) -> bool {
        matches!(self, Data::Text(_))
    }

    /// Is this binary data?
    pub fn is_binary(&self) -> bool {
        matches!(self, Data::Binary(_))
    }

    /// The length of data (number of bytes).
    pub fn len(&self) -> usize {
        match self {
            Data::Text(n) | Data::Binary(n) => *n,
        }
    }

    /// Is the data zero bytes long?
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Wrapper type which restricts the length of its byte slice to 125 bytes,
/// the maximum payload size of a control frame (RFC 6455, section 5.5).
#[derive(Debug)]
pub struct ByteSlice125<'a>(&'a [u8]);

/// Error, if converting to [`ByteSlice125`] fails.
#[derive(Clone, Debug)]
pub struct SliceTooLarge(());

impl fmt::Display for SliceTooLarge {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Slice larger than 125 bytes")
    }
}

impl std::error::Error for SliceTooLarge {}

impl<'a> TryFrom<&'a [u8]> for ByteSlice125<'a> {
    type Error = SliceTooLarge;

    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        if value.len() > 125 {
            Err(SliceTooLarge(()))
        } else {
            Ok(ByteSlice125(value))
        }
    }
}

impl AsRef<[u8]> for ByteSlice125<'_> {
    fn as_ref(&self) -> &[u8] {
        self.0
    }
}
soketto-0.8.1/src/extension.rs000066400000000000000000000113561472331330300163660ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd.
// Copyright (c) 2016 twist developers
//
// Licensed under the Apache License, Version 2.0
// or the MIT
// license , at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.

//! Websocket extensions as per [RFC 6455][rfc6455].
//!
//! [rfc6455]: https://tools.ietf.org/html/rfc6455#section-9

#[cfg(feature = "deflate")]
pub mod deflate;

use crate::{base::Header, BoxedError, Storage};
use std::{borrow::Cow, fmt};

/// A websocket extension as per RFC 6455, section 9.
///
/// Extensions are invoked during handshake and subsequently during base
/// frame encoding and decoding. The invocation during handshake differs
/// on client and server side.
///
/// # Server
///
/// 1.
All extensions should consider themselves as disabled but available. /// 2. When receiving a handshake request from a client, for each extension /// with a matching name, [`Extension::configure`] will be applied to the /// request parameters. The extension may internally enable itself. /// 3. When sending back the response, for each extension whose /// [`Extension::is_enabled`] returns true, the extension name and its /// parameters (as returned by [`Extension::params`]) will be included in the /// response. /// /// # Client /// /// 1. All extensions should consider themselves as disabled but available. /// 2. When creating the handshake request, all extensions and its parameters /// (as returned by [`Extension::params`]) will be included in the request. /// 3. When receiving the response from the server, for every extension with /// a matching name in the response, [`Extension::configure`] will be applied /// to the response parameters. The extension may internally enable itself. /// /// After this handshake phase, extensions have been configured and are /// potentially enabled. Enabled extensions can then be used for further base /// frame processing. pub trait Extension: std::fmt::Debug { /// Is this extension enabled? fn is_enabled(&self) -> bool; /// The name of this extension. fn name(&self) -> &str; /// The parameters this extension wants to send for negotiation. fn params(&self) -> &[Param]; /// Configure this extension with the parameters received from negotiation. fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError>; /// Encode a frame, given as frame header and payload data. fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError>; /// Decode a frame. /// /// The frame header is given, as well as the accumulated payload data, i.e. /// the concatenated payload data of all message fragments. 
fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError>; /// The reserved bits this extension uses. fn reserved_bits(&self) -> (bool, bool, bool) { (false, false, false) } } impl Extension for Box { fn is_enabled(&self) -> bool { (**self).is_enabled() } fn name(&self) -> &str { (**self).name() } fn params(&self) -> &[Param] { (**self).params() } fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { (**self).configure(params) } fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { (**self).encode(header, data) } fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { (**self).decode(header, data) } fn reserved_bits(&self) -> (bool, bool, bool) { (**self).reserved_bits() } } /// Extension parameter (used for negotiation). #[derive(Debug, Clone, PartialEq, Eq)] pub struct Param<'a> { name: Cow<'a, str>, value: Option>, } impl<'a> fmt::Display for Param<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(v) = &self.value { write!(f, "{} = {}", self.name, v) } else { write!(f, "{}", self.name) } } } impl<'a> Param<'a> { /// Create a new parameter with the given name. pub fn new(name: impl Into>) -> Self { Param { name: name.into(), value: None } } /// Access the parameter name. pub fn name(&self) -> &str { &self.name } /// Access the optional parameter value. pub fn value(&self) -> Option<&str> { self.value.as_ref().map(|v| v.as_ref()) } /// Set the parameter to the given value. pub fn set_value(&mut self, value: Option>>) -> &mut Self { self.value = value.map(Into::into); self } /// Turn this parameter into one that owns its values. 
pub fn acquire(self) -> Param<'static> { Param { name: Cow::Owned(self.name.into_owned()), value: self.value.map(|v| Cow::Owned(v.into_owned())) } } } soketto-0.8.1/src/extension/000077500000000000000000000000001472331330300160125ustar00rootroot00000000000000soketto-0.8.1/src/extension/deflate.rs000066400000000000000000000243171472331330300177730ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. //! Deflate compression extension mostly conformant with [RFC 7692][rfc7692]. //! //! [rfc7692]: https://tools.ietf.org/html/rfc7692 use crate::{ as_u64, base::{Header, OpCode}, connection::Mode, extension::{Extension, Param}, BoxedError, Storage, }; use flate2::{write::DeflateDecoder, Compress, Compression, FlushCompress, Status}; use std::{ convert::TryInto, io::{self, Write}, mem, }; const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover"; const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits"; const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover"; const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits"; /// The deflate extension type. /// /// The extension does currently not support max. window bits other than the /// default, which is 15 and will ask for no context takeover during handshake. #[derive(Debug)] pub struct Deflate { mode: Mode, enabled: bool, buffer: Vec, params: Vec>, our_max_window_bits: u8, their_max_window_bits: u8, await_last_fragment: bool, } impl Deflate { /// Create a new deflate extension either on client or server side. 
pub fn new(mode: Mode) -> Self { let params = match mode { Mode::Server => Vec::new(), Mode::Client => { let mut params = Vec::new(); params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); params.push(Param::new(CLIENT_MAX_WINDOW_BITS)); params } }; Deflate { mode, enabled: false, buffer: Vec::new(), params, our_max_window_bits: 15, their_max_window_bits: 15, await_last_fragment: false, } } /// Set the server's max. window bits. /// /// The value must be within 9 ..= 15. /// The extension must be in client mode. /// /// By including this parameter, a client limits the LZ77 sliding window /// size that the server will use to compress messages. A server accepts /// by including the "server_max_window_bits" extension parameter in the /// response with the same or smaller value as the offer. pub fn set_max_server_window_bits(&mut self, max: u8) { assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode"); assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15"); self.their_max_window_bits = max; // upper bound of the server's window let mut p = Param::new(SERVER_MAX_WINDOW_BITS); p.set_value(Some(max.to_string())); self.params.push(p) } /// Set the client's max. window bits. /// /// The value must be within 9 ..= 15. /// The extension must be in client mode. /// /// The parameter informs the server that even if it doesn't include the /// "client_max_window_bits" extension parameter in the response with a /// value greater than the one in the negotiation offer or if it doesn't /// include the extension parameter at all, the client is not going to /// use an LZ77 sliding window size greater than one given here. /// The server may also respond with a smaller value which allows the client /// to reduce its sliding window even more. pub fn set_max_client_window_bits(&mut self, max: u8) { assert!(self.mode == Mode::Client, "setting max. 
client window bits requires client mode"); assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15"); self.our_max_window_bits = max; // upper bound of the client's window if let Some(p) = self.params.iter_mut().find(|p| p.name() == CLIENT_MAX_WINDOW_BITS) { p.set_value(Some(max.to_string())); } else { let mut p = Param::new(CLIENT_MAX_WINDOW_BITS); p.set_value(Some(max.to_string())); self.params.push(p) } } fn set_their_max_window_bits(&mut self, p: &Param, expected: Option) -> Result<(), ()> { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { if v < 8 || v > 15 { log::debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v); return Err(()); } if let Some(x) = expected { if v > x { log::debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x); return Err(()); } } self.their_max_window_bits = std::cmp::max(9, v); } Ok(()) } } impl Extension for Deflate { fn name(&self) -> &str { "permessage-deflate" } fn is_enabled(&self) -> bool { self.enabled } fn params(&self) -> &[Param] { &self.params } fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { match self.mode { Mode::Server => { self.params.clear(); for p in params { log::trace!("configure server with: {}", p); match p.name() { CLIENT_MAX_WINDOW_BITS => { if self.set_their_max_window_bits(&p, None).is_err() { // we just accept the client's offer as is => no need to reply return Ok(()); } } SERVER_MAX_WINDOW_BITS => { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { // The RFC allows 8 to 15 bits, but due to zlib limitations we // only support 9 to 15. 
if v < 9 || v > 15 { log::debug!("unacceptable server_max_window_bits: {}", v); return Ok(()); } let mut x = Param::new(SERVER_MAX_WINDOW_BITS); x.set_value(Some(v.to_string())); self.params.push(x); self.our_max_window_bits = v; } else { log::debug!("invalid server_max_window_bits: {:?}", p.value()); return Ok(()); } } CLIENT_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)), SERVER_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)), _ => { log::debug!("{}: unknown parameter: {}", self.name(), p.name()); return Ok(()); } } } } Mode::Client => { let mut server_no_context_takeover = false; for p in params { log::trace!("configure client with: {}", p); match p.name() { SERVER_NO_CONTEXT_TAKEOVER => server_no_context_takeover = true, CLIENT_NO_CONTEXT_TAKEOVER => {} // must be supported SERVER_MAX_WINDOW_BITS => { let expected = Some(self.their_max_window_bits); if self.set_their_max_window_bits(&p, expected).is_err() { return Ok(()); } } CLIENT_MAX_WINDOW_BITS => { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { if v < 8 || v > 15 { log::debug!("unacceptable client_max_window_bits: {}", v); return Ok(()); } use std::cmp::{max, min}; // Due to zlib limitations we have to use 9 as a lower bound // here, even if the server allowed us to go down to 8 bits. 
self.our_max_window_bits = min(self.our_max_window_bits, max(9, v)); } } _ => { log::debug!("{}: unknown parameter: {}", self.name(), p.name()); return Ok(()); } } } if !server_no_context_takeover { log::debug!("{}: server did not confirm no context takeover", self.name()); return Ok(()); } } } self.enabled = true; Ok(()) } fn reserved_bits(&self) -> (bool, bool, bool) { (true, false, false) } fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { if data.is_empty() { return Ok(()); } match header.opcode() { OpCode::Binary | OpCode::Text if header.is_rsv1() => { if !header.is_fin() { self.await_last_fragment = true; log::trace!("deflate: not decoding {}; awaiting last fragment", header); return Ok(()); } log::trace!("deflate: decoding {}", header) } OpCode::Continue if header.is_fin() && self.await_last_fragment => { self.await_last_fragment = false; log::trace!("deflate: decoding {}", header) } _ => { log::trace!("deflate: not decoding {}", header); return Ok(()); } } // Restore LEN and NLEN: data.extend_from_slice(&[0, 0, 0xFF, 0xFF]); // cf. RFC 7692, 7.2.2 self.buffer.clear(); let mut decoder = DeflateDecoder::new(&mut self.buffer); decoder.write_all(&data)?; decoder.finish()?; mem::swap(data, &mut self.buffer); header.set_rsv1(false); header.set_payload_len(data.len()); Ok(()) } fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { if data.as_ref().is_empty() { return Ok(()); } if let OpCode::Binary | OpCode::Text = header.opcode() { log::trace!("deflate: encoding {}", header) } else { log::trace!("deflate: not encoding {}", header); return Ok(()); } self.buffer.clear(); self.buffer.reserve(data.as_ref().len()); let mut encoder = Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits); // Compress all input bytes. 
while encoder.total_in() < as_u64(data.as_ref().len()) { let i: usize = encoder.total_in().try_into()?; match encoder.compress_vec(&data.as_ref()[i..], &mut self.buffer, FlushCompress::None)? { Status::BufError => self.buffer.reserve(4096), Status::Ok => continue, Status::StreamEnd => break, } } // We need to append an empty deflate block if not there yet (RFC 7692, 7.2.1). while !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { self.buffer.reserve(5); // Make sure there is room for the trailing end bytes. match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? { Status::Ok => continue, Status::BufError => continue, // more capacity is reserved above Status::StreamEnd => break, } } // If we still have not seen the empty deflate block appended, something is wrong. if !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { log::error!("missing 00 00 FF FF"); return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()); } self.buffer.truncate(self.buffer.len() - 4); // Remove 00 00 FF FF; cf. RFC 7692, 7.2.1 if let Storage::Owned(d) = data { mem::swap(d, &mut self.buffer) } else { *data = Storage::Owned(mem::take(&mut self.buffer)) } header.set_rsv1(true); header.set_payload_len(data.as_ref().len()); Ok(()) } } soketto-0.8.1/src/handshake.rs000066400000000000000000000232621472331330300162770ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. //! Websocket [handshake]s. //! //! 
[handshake]: https://tools.ietf.org/html/rfc6455#section-4 pub mod client; #[cfg(feature = "http")] pub mod http; pub mod server; use crate::extension::{Extension, Param}; use base64::Engine; use bytes::BytesMut; use sha1::{Digest, Sha1}; use std::{fmt, io, str}; pub use client::{Client, ServerResponse}; pub use server::{ClientRequest, Server}; // Defined in RFC 6455 and used to generate the `Sec-WebSocket-Accept` header // in the server handshake response. const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // How many HTTP headers do we support during parsing? const MAX_NUM_HEADERS: usize = 32; // Some HTTP headers we need to check during parsing. const SEC_WEBSOCKET_EXTENSIONS: &str = "Sec-WebSocket-Extensions"; const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol"; /// Check a set of headers contains a specific one. fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> { enum State { Init, // Start state Name, // Header name found Match, // Header value matches } headers .iter() .filter(|h| h.name.eq_ignore_ascii_case(name)) .fold(Ok(State::Init), |result, header| { if let Ok(State::Match) = result { return result; } if str::from_utf8(header.value)?.split(',').any(|v| v.trim().eq_ignore_ascii_case(ours)) { return Ok(State::Match); } Ok(State::Name) }) .and_then(|state| match state { State::Init => Err(Error::HeaderNotFound(name.into())), State::Name => Err(Error::UnexpectedHeader(name.into())), State::Match => Ok(()), }) } /// Pick the first header with the given name and apply the given closure to it. fn with_first_header<'a, F, R>(headers: &[httparse::Header<'a>], name: &str, f: F) -> Result where F: Fn(&'a [u8]) -> Result, { if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) { f(h.value) } else { Err(Error::HeaderNotFound(name.into())) } } // Configure all extensions with parsed parameters. 
fn configure_extensions(extensions: &mut [Box], line: &str) -> Result<(), Error> { for e in line.split(',') { let mut ext_parts = e.split(';'); if let Some(name) = ext_parts.next() { let name = name.trim(); if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) { let mut params = Vec::new(); for p in ext_parts { let mut key_value = p.split('='); if let Some(key) = key_value.next().map(str::trim) { let val = key_value.next().map(|v| v.trim().trim_matches('"')); let mut p = Param::new(key); p.set_value(val); params.push(p) } } ext.configure(¶ms).map_err(Error::Extension)? } } } Ok(()) } // Write all extensions to the given buffer. fn append_extensions<'a, I>(extensions: I, bytes: &mut BytesMut) where I: IntoIterator>, { let mut iter = extensions.into_iter().peekable(); if iter.peek().is_some() { bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ") } append_extension_header_value(iter, bytes) } // Write the extension header value to the given buffer. fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable, bytes: &mut BytesMut) where I: Iterator>, { while let Some(e) = extensions_iter.next() { bytes.extend_from_slice(e.name().as_bytes()); for p in e.params() { bytes.extend_from_slice(b"; "); bytes.extend_from_slice(p.name().as_bytes()); if let Some(v) = p.value() { bytes.extend_from_slice(b"="); bytes.extend_from_slice(v.as_bytes()) } } if extensions_iter.peek().is_some() { bytes.extend_from_slice(b", ") } } } // This function takes a 16 byte key (base64 encoded, and so 24 bytes of input) that is expected via // the `Sec-WebSocket-Key` header during a websocket handshake, and writes the response that's expected // to be handed back in the response header `Sec-WebSocket-Accept`. // // The response is a base64 encoding of a 160bit hash. base64 encoding uses 1 ascii character per 6 bits, // and 160 / 6 = 26.66 characters. 
The output is padded with '=' to the nearest 4 characters, so we need 28 // bytes in total for all of the characters. // // See https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 for more information on this. fn generate_accept_key<'k>(key_base64: &WebSocketKey) -> [u8; 28] { let mut digest = Sha1::new(); digest.update(key_base64); digest.update(KEY); let d = digest.finalize(); let mut output_buf = [0; 28]; let n = base64::engine::general_purpose::STANDARD .encode_slice(d, &mut output_buf) .expect("encoding to base64 is exactly 28 bytes; qed"); debug_assert_eq!(n, 28, "encoding to base64 should be exactly 28 bytes"); output_buf } /// Enumeration of possible handshake errors. #[non_exhaustive] #[derive(Debug)] pub enum Error { /// An I/O error has been encountered. Io(io::Error), /// An HTTP version =/= 1.1 was encountered. UnsupportedHttpVersion, /// An incomplete HTTP request. IncompleteHttpRequest, /// The value of the `Sec-WebSocket-Key` header is of unexpected length. SecWebSocketKeyInvalidLength(usize), /// The handshake request was not a GET request. InvalidRequestMethod, /// An HTTP header has not been present. HeaderNotFound(String), /// An HTTP header value was not expected. UnexpectedHeader(String), /// The Sec-WebSocket-Accept header value did not match. InvalidSecWebSocketAccept, /// The server returned an extension we did not ask for. UnsolicitedExtension, /// The server returned a protocol we did not ask for. UnsolicitedProtocol, /// An extension produced an error while encoding or decoding. Extension(crate::BoxedError), /// The HTTP entity could not be parsed successfully. Http(crate::BoxedError), /// UTF-8 decoding failed. 
Utf8(str::Utf8Error), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Error::Io(e) => write!(f, "i/o error: {}", e), Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"), Error::IncompleteHttpRequest => f.write_str("http request was incomplete"), Error::SecWebSocketKeyInvalidLength(len) => { write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len) } Error::InvalidRequestMethod => f.write_str("handshake was not a GET request"), Error::HeaderNotFound(name) => write!(f, "header {} not found", name), Error::UnexpectedHeader(name) => write!(f, "header {} had an unexpected value", name), Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"), Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"), Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"), Error::Extension(e) => write!(f, "extension error: {}", e), Error::Http(e) => write!(f, "http parser error: {}", e), Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e), } } } impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Error::Io(e) => Some(e), Error::Extension(e) => Some(&**e), Error::Http(e) => Some(&**e), Error::Utf8(e) => Some(e), Error::UnsupportedHttpVersion | Error::IncompleteHttpRequest | Error::SecWebSocketKeyInvalidLength(_) | Error::InvalidRequestMethod | Error::HeaderNotFound(_) | Error::UnexpectedHeader(_) | Error::InvalidSecWebSocketAccept | Error::UnsolicitedExtension | Error::UnsolicitedProtocol => None, } } } impl From for Error { fn from(e: io::Error) -> Self { Error::Io(e) } } impl From for Error { fn from(e: str::Utf8Error) -> Self { Error::Utf8(e) } } /// Owned value of the `Sec-WebSocket-Key` header. /// /// Per [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1): /// /// ```text /// (...) 
The value of this header field MUST be a /// nonce consisting of a randomly selected 16-byte value that has /// been base64-encoded (see Section 4 of [RFC4648]). (...) /// ``` /// /// Base64 encoding of the nonce produces 24 ASCII bytes, padding included. pub type WebSocketKey = [u8; 24]; #[cfg(test)] mod tests { use super::expect_ascii_header; #[test] fn header_match() { let headers = &[ httparse::Header { name: "foo", value: b"a,b,c,d" }, httparse::Header { name: "foo", value: b"x" }, httparse::Header { name: "foo", value: b"y, z, a" }, httparse::Header { name: "bar", value: b"xxx" }, httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" }, httparse::Header { name: "baz", value: b"123" }, ]; assert!(expect_ascii_header(headers, "foo", "a").is_ok()); assert!(expect_ascii_header(headers, "foo", "b").is_ok()); assert!(expect_ascii_header(headers, "foo", "c").is_ok()); assert!(expect_ascii_header(headers, "foo", "d").is_ok()); assert!(expect_ascii_header(headers, "foo", "x").is_ok()); assert!(expect_ascii_header(headers, "foo", "y").is_ok()); assert!(expect_ascii_header(headers, "foo", "z").is_ok()); assert!(expect_ascii_header(headers, "foo", "a").is_ok()); assert!(expect_ascii_header(headers, "bar", "xxx").is_ok()); assert!(expect_ascii_header(headers, "bar", "sdfsdf 423 42 424").is_ok()); assert!(expect_ascii_header(headers, "baz", "123").is_ok()); assert!(expect_ascii_header(headers, "baz", "???").is_err()); assert!(expect_ascii_header(headers, "???", "x").is_err()); } } soketto-0.8.1/src/handshake/000077500000000000000000000000001472331330300157245ustar00rootroot00000000000000soketto-0.8.1/src/handshake/client.rs000066400000000000000000000177161472331330300175640ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. 
All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. //! Websocket client [handshake]. //! //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 use super::{ append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY, MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, }; use crate::connection::{self, Mode}; use crate::{extension::Extension, Parsing}; use base64::Engine; use bytes::{Buf, BytesMut}; use futures::prelude::*; use sha1::{Digest, Sha1}; use std::{mem, str}; pub use httparse::Header; const BLOCK_SIZE: usize = 8 * 1024; /// Websocket client handshake. #[derive(Debug)] pub struct Client<'a, T> { /// The underlying async I/O resource. socket: T, /// The HTTP host to send the handshake to. host: &'a str, /// The HTTP host resource. resource: &'a str, /// The HTTP headers. headers: &'a [Header<'a>], /// A buffer holding the base-64 encoded request nonce. nonce: WebSocketKey, /// The protocols to include in the handshake. protocols: Vec<&'a str>, /// The extensions the client wishes to include in the request. extensions: Vec>, /// Encoding/decoding buffer. buffer: BytesMut, } impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { /// Create a new client handshake for some host and resource. pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self { Client { socket, host, resource, headers: &[], nonce: [0; 24], protocols: Vec::new(), extensions: Vec::new(), buffer: BytesMut::new(), } } /// Override the buffer to use for request/response handling. pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { self.buffer = b; self } /// Extract the buffer. pub fn take_buffer(&mut self) -> BytesMut { mem::take(&mut self.buffer) } /// Set connection headers to a slice. 
These headers are not checked for validity, /// the caller of this method is responsible for verification as well as avoiding /// conflicts with internally set headers. pub fn set_headers(&mut self, h: &'a [Header]) -> &mut Self { self.headers = h; self } /// Add a protocol to be included in the handshake. pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { self.protocols.push(p); self } /// Add an extension to be included in the handshake. pub fn add_extension(&mut self, e: Box) -> &mut Self { self.extensions.push(e); self } /// Get back all extensions. pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { self.extensions.drain(..) } /// Initiate client handshake request to server and get back the response. pub async fn handshake(&mut self) -> Result { self.buffer.clear(); self.encode_request(); self.socket.write_all(&self.buffer).await?; self.socket.flush().await?; self.buffer.clear(); loop { crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; if let Parsing::Done { value, offset } = self.decode_response()? { self.buffer.advance(offset); return Ok(value); } } } /// Turn this handshake into a [`connection::Builder`]. pub fn into_builder(mut self) -> connection::Builder { let mut builder = connection::Builder::new(self.socket, Mode::Client); builder.set_buffer(self.buffer); builder.add_extensions(self.extensions.drain(..)); builder } /// Get out the inner socket of the client. pub fn into_inner(self) -> T { self.socket } /// Encode the client handshake as a request, ready to be sent to the server. 
fn encode_request(&mut self) {
    // A fresh random 16-byte nonce per request. It is kept base64-encoded in `self.nonce`
    // because `decode_response` hashes it to verify `Sec-WebSocket-Accept`.
    let nonce: [u8; 16] = rand::random();
    base64::engine::general_purpose::STANDARD
        .encode_slice(nonce, &mut self.nonce)
        .expect("encoding 16 bytes to base64 is exactly 24 bytes; qed");

    self.buffer.extend_from_slice(b"GET ");
    self.buffer.extend_from_slice(self.resource.as_bytes());
    self.buffer.extend_from_slice(b" HTTP/1.1");
    self.buffer.extend_from_slice(b"\r\nHost: ");
    self.buffer.extend_from_slice(self.host.as_bytes());
    self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade");
    self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
    self.buffer.extend_from_slice(&self.nonce);

    // Caller-supplied headers are written verbatim; `set_headers` documents that
    // validating them is the caller's responsibility.
    for h in self.headers {
        self.buffer.extend_from_slice(b"\r\n");
        self.buffer.extend_from_slice(h.name.as_bytes());
        self.buffer.extend_from_slice(b": ");
        self.buffer.extend_from_slice(h.value);
    }

    // Protocols are sent as one comma-separated header value.
    if let Some((last, prefix)) = self.protocols.split_last() {
        self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
        for p in prefix {
            self.buffer.extend_from_slice(p.as_bytes());
            self.buffer.extend_from_slice(b",")
        }
        self.buffer.extend_from_slice(last.as_bytes())
    }

    append_extensions(&self.extensions, &mut self.buffer);

    // The final CRLF pair terminates the header section.
    self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n")
}

/// Decode the server response to this client request.
fn decode_response(&mut self) -> Result, Error> { let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; let mut response = httparse::Response::new(&mut header_buf); let offset = match response.parse(self.buffer.as_ref()) { Ok(httparse::Status::Complete(off)) => off, Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())), Err(e) => return Err(Error::Http(Box::new(e))), }; if response.version != Some(1) { return Err(Error::UnsupportedHttpVersion); } match response.code { Some(101) => (), Some(code @ (301..=303)) | Some(code @ 307) | Some(code @ 308) => { // redirect response let location = with_first_header(response.headers, "Location", |loc| Ok(String::from(std::str::from_utf8(loc)?)))?; let response = ServerResponse::Redirect { status_code: code, location }; return Ok(Parsing::Done { value: response, offset }); } other => { let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) }; return Ok(Parsing::Done { value: response, offset }); } } expect_ascii_header(response.headers, "Upgrade", "websocket")?; expect_ascii_header(response.headers, "Connection", "upgrade")?; with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| { let mut digest = Sha1::new(); digest.update(&self.nonce); digest.update(KEY); let ours = base64::engine::general_purpose::STANDARD.encode(digest.finalize()); if ours.as_bytes() != theirs { return Err(Error::InvalidSecWebSocketAccept); } Ok(()) })?; // Parse `Sec-WebSocket-Extensions` headers. for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? } // Match `Sec-WebSocket-Protocol` header. 
let mut selected_proto = None; if let Some(tp) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) { selected_proto = Some(String::from(p)) } else { return Err(Error::UnsolicitedProtocol); } } let response = ServerResponse::Accepted { protocol: selected_proto }; Ok(Parsing::Done { value: response, offset }) } } /// Handshake response received from the server. #[derive(Debug)] pub enum ServerResponse { /// The server has accepted our request. Accepted { /// The protocol (if any) the server has selected. protocol: Option, }, /// The server is redirecting us to some other location. Redirect { /// The HTTP response status code. status_code: u16, /// The location URL we should go to. location: String, }, /// The server rejected our request. Rejected { /// HTTP response status code. status_code: u16, }, } soketto-0.8.1/src/handshake/http.rs000066400000000000000000000132411472331330300172520ustar00rootroot00000000000000// Copyright (c) 2021 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. /*! This module somewhat mirrors [`crate::handshake::server`], except it's focus is on working with [`http::Request`] and [`http::Response`] types, making it easier to integrate with external web servers such as Hyper. See `examples/hyper_server.rs` from this crate's repository for example usage. */ use super::{WebSocketKey, SEC_WEBSOCKET_EXTENSIONS}; use crate::connection::{self, Mode}; use crate::extension::Extension; use crate::handshake; use bytes::BytesMut; use futures::prelude::*; use http::{header, HeaderMap, Response}; use std::mem; /// A re-export of [`handshake::Error`]. pub type Error = handshake::Error; /// Websocket handshake server. 
This is similar to [`handshake::Server`], but it is /// focused on performing the WebSocket handshake using a provided [`http::Request`], as opposed /// to decoding the request internally. pub struct Server { // Extensions the server supports. extensions: Vec>, // Encoding/decoding buffer. buffer: BytesMut, } impl Server { /// Create a new server handshake. pub fn new() -> Self { Server { extensions: Vec::new(), buffer: BytesMut::new() } } /// Override the buffer to use for request/response handling. pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { self.buffer = b; self } /// Extract the buffer. pub fn take_buffer(&mut self) -> BytesMut { mem::take(&mut self.buffer) } /// Add an extension the server supports. pub fn add_extension(&mut self, e: Box) -> &mut Self { self.extensions.push(e); self } /// Get back all extensions. pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { self.extensions.drain(..) } /// Attempt to interpret the provided [`http::Request`] as a WebSocket Upgrade request. If successful, this /// returns an [`http::Response`] that should be returned to the client to complete the handshake. pub fn receive_request(&mut self, req: &http::Request) -> Result, Error> { if !is_upgrade_request(&req) { return Err(Error::InvalidSecWebSocketAccept); } let key = match req.headers().get("Sec-WebSocket-Key") { Some(key) => key, None => { return Err(Error::HeaderNotFound("Sec-WebSocket-Key".into()).into()); } }; if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") { return Err(Error::HeaderNotFound("Sec-WebSocket-Version".into()).into()); } // Pull out the Sec-WebSocket-Key and generate the appropriate response to it. let key: &WebSocketKey = match key.as_bytes().try_into() { Ok(key) => key, Err(_) => return Err(Error::InvalidSecWebSocketAccept), }; let accept_key = handshake::generate_accept_key(key); // Get extension information out of the request as we'll need this as well. 
let extension_config = req .headers() .iter() .filter(|&(name, _)| name.as_str().eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) .map(|(_, value)| Ok(std::str::from_utf8(value.as_bytes())?.to_string())) .collect::, Error>>()?; // Attempt to set the extension configuration params that the client requested. for config_str in &extension_config { handshake::configure_extensions(&mut self.extensions, &config_str)?; } // Build a response that should be sent back to the client to acknowledge the upgrade. let mut response = Response::builder() .status(http::StatusCode::SWITCHING_PROTOCOLS) .header(http::header::CONNECTION, "upgrade") .header(http::header::UPGRADE, "websocket") .header("Sec-WebSocket-Accept", &accept_key[..]); // Tell the client about the agreed-upon extension configuration. We reuse code to build up the // extension header value, but that does make this a little more clunky. if !self.extensions.is_empty() { let mut buf = bytes::BytesMut::new(); let enabled_extensions = self.extensions.iter().filter(|e| e.is_enabled()).peekable(); handshake::append_extension_header_value(enabled_extensions, &mut buf); response = response.header("Sec-WebSocket-Extensions", buf.as_ref()); } let response = response.body(()).expect("bug: failed to build response"); Ok(response) } /// Turn this handshake into a [`connection::Builder`]. pub fn into_builder(mut self, socket: T) -> connection::Builder { let mut builder = connection::Builder::new(socket, Mode::Server); builder.set_buffer(self.buffer); builder.add_extensions(self.extensions.drain(..)); builder } } /// Check if an [`http::Request`] looks like a valid websocket upgrade request. pub fn is_upgrade_request(request: &http::Request) -> bool { header_contains_value(request.headers(), header::CONNECTION, b"upgrade") && header_contains_value(request.headers(), header::UPGRADE, b"websocket") } // Check if there is a header of the given name containing the wanted value. 
fn header_contains_value(headers: &HeaderMap, header: header::HeaderName, value: &[u8]) -> bool { pub fn trim(x: &[u8]) -> &[u8] { let from = match x.iter().position(|x| !x.is_ascii_whitespace()) { Some(i) => i, None => return &[], }; let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap(); &x[from..=to] } for header in headers.get_all(header) { if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) { return true; } } false } soketto-0.8.1/src/handshake/server.rs000066400000000000000000000236651472331330300176140ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. //! Websocket server [handshake]. //! //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 use super::{ append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, }; use crate::connection::{self, Mode}; use crate::extension::Extension; use bytes::BytesMut; use futures::prelude::*; use std::{mem, str}; // Most HTTP servers default to 8KB limit on headers const MAX_HEADERS_SIZE: usize = 8 * 1024; const BLOCK_SIZE: usize = 8 * 1024; /// Websocket handshake server. #[derive(Debug)] pub struct Server<'a, T> { socket: T, /// Protocols the server supports. protocols: Vec<&'a str>, /// Extensions the server supports. extensions: Vec>, /// Encoding/decoding buffer. buffer: BytesMut, } impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { /// Create a new server handshake. pub fn new(socket: T) -> Self { Server { socket, protocols: Vec::new(), extensions: Vec::new(), buffer: BytesMut::new() } } /// Override the buffer to use for request/response handling. 
pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
    self.buffer = b;
    self
}

/// Extract the buffer, leaving an empty one in its place.
pub fn take_buffer(&mut self) -> BytesMut {
    mem::take(&mut self.buffer)
}

/// Add a protocol the server supports.
pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
    self.protocols.push(p);
    self
}

/// Add an extension the server supports.
pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
    self.extensions.push(e);
    self
}

/// Get back all extensions, removing them from this handshake.
pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
    self.extensions.drain(..)
}

/// Await an incoming client handshake request.
///
/// Reads from the socket until the end of the HTTP header section (`\r\n\r\n`) has been
/// seen, or until `MAX_HEADERS_SIZE` bytes are buffered, and then decodes the request.
///
/// # Errors
///
/// Any I/O error from the socket (end of stream is reported as `UnexpectedEof`), or any
/// error produced while decoding the request.
pub async fn receive_request(&mut self) -> Result<ClientRequest<'_>, Error> {
    self.buffer.clear();

    // Offset from which the next pass searches for the header terminator.
    let mut skip = 0;

    loop {
        crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;

        let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE);

        // We don't expect body, so can search for the CRLF headers tail from
        // the end of the buffer.
        if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") {
            break;
        }

        // Give up if we've reached the limit. We could emit a specific error here,
        // but httparse will produce meaningful error for us regardless.
        if limit == MAX_HEADERS_SIZE {
            break;
        }

        // Skip bytes that did not contain CRLF in the next iteration.
        // If we only read a partial CRLF sequence, we would miss it if we skipped the full buffer
        // length, hence backing off the full 4 bytes.
        skip = self.buffer.len().saturating_sub(4);
    }

    self.decode_request()
}

/// Respond to the client.
///
/// The response is encoded into the internal buffer, written and flushed, after which
/// the buffer is cleared again.
pub async fn send_response(&mut self, r: &Response<'_>) -> Result<(), Error> {
    self.buffer.clear();
    self.encode_response(r);
    self.socket.write_all(&self.buffer).await?;
    self.socket.flush().await?;
    self.buffer.clear();
    Ok(())
}

/// Turn this handshake into a [`connection::Builder`].
pub fn into_builder(mut self) -> connection::Builder { let mut builder = connection::Builder::new(self.socket, Mode::Server); builder.set_buffer(self.buffer); builder.add_extensions(self.extensions.drain(..)); builder } /// Get out the inner socket of the server. pub fn into_inner(self) -> T { self.socket } // Decode client handshake request. fn decode_request(&mut self) -> Result { let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; let mut request = httparse::Request::new(&mut header_buf); match request.parse(self.buffer.as_ref()) { Ok(httparse::Status::Complete(_)) => (), Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest), Err(e) => return Err(Error::Http(Box::new(e))), }; if request.method != Some("GET") { return Err(Error::InvalidRequestMethod); } if request.version != Some(1) { return Err(Error::UnsupportedHttpVersion); } let host = with_first_header(&request.headers, "Host", Ok)?; expect_ascii_header(request.headers, "Upgrade", "websocket")?; expect_ascii_header(request.headers, "Connection", "upgrade")?; expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?; let origin = request.headers.iter().find_map( |h| { if h.name.eq_ignore_ascii_case("Origin") { Some(h.value) } else { None } }, ); let headers = RequestHeaders { host, origin }; let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| { WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len())) })?; for h in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? 
} let mut protocols = Vec::new(); for p in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) { protocols.push(p) } } let path = request.path.unwrap_or("/"); Ok(ClientRequest { ws_key, protocols, path, headers }) } // Encode server handshake response. fn encode_response(&mut self, response: &Response<'_>) { match response { Response::Accept { key, protocol } => { let accept_value = super::generate_accept_key(&key); self.buffer.extend_from_slice( concat![ "HTTP/1.1 101 Switching Protocols", "\r\nServer: soketto-", env!("CARGO_PKG_VERSION"), "\r\nUpgrade: websocket", "\r\nConnection: upgrade", "\r\nSec-WebSocket-Accept: ", ] .as_bytes(), ); self.buffer.extend_from_slice(&accept_value); if let Some(p) = protocol { self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); self.buffer.extend_from_slice(p.as_bytes()) } append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer); self.buffer.extend_from_slice(b"\r\n\r\n") } Response::Reject { status_code } => { self.buffer.extend_from_slice(b"HTTP/1.1 "); let (_, reason) = if let Ok(i) = STATUSCODES.binary_search_by_key(status_code, |(n, _)| *n) { STATUSCODES[i] } else { (500, "500 Internal Server Error") }; self.buffer.extend_from_slice(reason.as_bytes()); self.buffer.extend_from_slice(b"\r\n\r\n") } } } } /// Handshake request received from the client. #[derive(Debug)] pub struct ClientRequest<'a> { ws_key: WebSocketKey, protocols: Vec<&'a str>, path: &'a str, headers: RequestHeaders<'a>, } /// Select HTTP headers sent by the client. #[derive(Debug, Copy, Clone)] pub struct RequestHeaders<'a> { /// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header. pub host: &'a [u8], /// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided. 
pub origin: Option<&'a [u8]>, } impl<'a> ClientRequest<'a> { /// The `Sec-WebSocket-Key` header nonce value. pub fn key(&self) -> WebSocketKey { self.ws_key } /// The protocols the client is proposing. pub fn protocols(&self) -> impl Iterator { self.protocols.iter().cloned() } /// The path the client is requesting. pub fn path(&self) -> &str { self.path } /// Select HTTP headers sent by the client. pub fn headers(&self) -> RequestHeaders { self.headers } } /// Handshake response the server sends back to the client. #[derive(Debug)] pub enum Response<'a> { /// The server accepts the handshake request. Accept { key: WebSocketKey, protocol: Option<&'a str> }, /// The server rejects the handshake request. Reject { status_code: u16 }, } /// Known status codes and their reason phrases. const STATUSCODES: &[(u16, &str)] = &[ (100, "100 Continue"), (101, "101 Switching Protocols"), (102, "102 Processing"), (200, "200 OK"), (201, "201 Created"), (202, "202 Accepted"), (203, "203 Non Authoritative Information"), (204, "204 No Content"), (205, "205 Reset Content"), (206, "206 Partial Content"), (207, "207 Multi-Status"), (208, "208 Already Reported"), (226, "226 IM Used"), (300, "300 Multiple Choices"), (301, "301 Moved Permanently"), (302, "302 Found"), (303, "303 See Other"), (304, "304 Not Modified"), (305, "305 Use Proxy"), (307, "307 Temporary Redirect"), (308, "308 Permanent Redirect"), (400, "400 Bad Request"), (401, "401 Unauthorized"), (402, "402 Payment Required"), (403, "403 Forbidden"), (404, "404 Not Found"), (405, "405 Method Not Allowed"), (406, "406 Not Acceptable"), (407, "407 Proxy Authentication Required"), (408, "408 Request Timeout"), (409, "409 Conflict"), (410, "410 Gone"), (411, "411 Length Required"), (412, "412 Precondition Failed"), (413, "413 Payload Too Large"), (414, "414 URI Too Long"), (415, "415 Unsupported Media Type"), (416, "416 Range Not Satisfiable"), (417, "417 Expectation Failed"), (418, "418 I'm a teapot"), (421, "421 Misdirected 
Request"), (422, "422 Unprocessable Entity"), (423, "423 Locked"), (424, "424 Failed Dependency"), (426, "426 Upgrade Required"), (428, "428 Precondition Required"), (429, "429 Too Many Requests"), (431, "431 Request Header Fields Too Large"), (451, "451 Unavailable For Legal Reasons"), (500, "500 Internal Server Error"), (501, "501 Not Implemented"), (502, "502 Bad Gateway"), (503, "503 Service Unavailable"), (504, "504 Gateway Timeout"), (505, "505 HTTP Version Not Supported"), (506, "506 Variant Also Negotiates"), (507, "507 Insufficient Storage"), (508, "508 Loop Detected"), (510, "510 Not Extended"), (511, "511 Network Authentication Required"), ]; soketto-0.8.1/src/lib.rs000066400000000000000000000136701472331330300151210ustar00rootroot00000000000000// Copyright (c) 2019 Parity Technologies (UK) Ltd. // Copyright (c) 2016 twist developers // // Licensed under the Apache License, Version 2.0 // or the MIT // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. //! An implementation of the [RFC 6455][rfc6455] websocket protocol. //! //! To begin a websocket connection one first needs to perform a [handshake], //! either as [client] or [server], in order to upgrade from HTTP. //! Once successful, the client or server can transition to a connection, //! i.e. a [Sender]/[Receiver] pair and send and receive textual or //! binary data. //! //! **Note**: While it is possible to only receive websocket messages it is //! not possible to only send websocket messages. Receiving data is required //! in order to react to control frames such as PING or CLOSE. While those will be //! answered transparently they have to be received in the first place, so //! calling [`connection::Receiver::receive`] is imperative. //! //! **Note**: None of the `async` methods are safe to cancel so their `Future`s //! must not be dropped unless they return `Poll::Ready`. //! //! 
# Client example //! //! ```no_run //! # use tokio_util::compat::TokioAsyncReadCompatExt; //! # async fn doc() -> Result<(), soketto::BoxedError> { //! use soketto::handshake::{Client, ServerResponse}; //! //! // First, we need to establish a TCP connection. //! let socket = tokio::net::TcpStream::connect("...").await?; //! //! // Then we configure the client handshake. //! let mut client = Client::new(socket.compat(), "...", "/"); //! //! // And finally we perform the handshake and handle the result. //! let (mut sender, mut receiver) = match client.handshake().await? { //! ServerResponse::Accepted { .. } => client.into_builder().finish(), //! ServerResponse::Redirect { status_code, location } => unimplemented!("follow location URL"), //! ServerResponse::Rejected { status_code } => unimplemented!("handle failure") //! }; //! //! // Over the established websocket connection we can send //! sender.send_text("some text").await?; //! sender.send_text("some more text").await?; //! sender.flush().await?; //! //! // ... and receive data. //! let mut data = Vec::new(); //! receiver.receive_data(&mut data).await?; //! //! # Ok(()) //! # } //! //! ``` //! //! # Server example //! //! ```no_run //! # use tokio_util::compat::TokioAsyncReadCompatExt; //! # use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; //! # async fn doc() -> Result<(), soketto::BoxedError> { //! use soketto::{handshake::{Server, ClientRequest, server::Response}}; //! //! // First, we listen for incoming connections. //! let listener = tokio::net::TcpListener::bind("...").await?; //! let mut incoming = TcpListenerStream::new(listener); //! //! while let Some(socket) = incoming.next().await { //! // For each incoming connection we perform a handshake. //! let mut server = Server::new(socket?.compat()); //! //! let websocket_key = { //! let req = server.receive_request().await?; //! req.key() //! }; //! //! // Here we accept the client unconditionally. //! 
let accept = Response::Accept { key: websocket_key, protocol: None }; //! server.send_response(&accept).await?; //! //! // And we can finally transition to a websocket connection. //! let (mut sender, mut receiver) = server.into_builder().finish(); //! //! let mut data = Vec::new(); //! let data_type = receiver.receive_data(&mut data).await?; //! //! if data_type.is_text() { //! sender.send_text(std::str::from_utf8(&data)?).await? //! } else { //! sender.send_binary(&data).await? //! } //! //! sender.close().await? //! } //! //! # Ok(()) //! # } //! //! ``` //! //! See `examples/hyper_server.rs` from this crate's repository for an example of //! starting up a WebSocket server alongside an Hyper HTTP server. //! //! [client]: handshake::Client //! [server]: handshake::Server //! [Sender]: connection::Sender //! [Receiver]: connection::Receiver //! [rfc6455]: https://tools.ietf.org/html/rfc6455 //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 #![forbid(unsafe_code)] pub mod base; pub mod connection; pub mod data; pub mod extension; pub mod handshake; use bytes::BytesMut; use futures::io::{AsyncRead, AsyncReadExt}; use std::io; pub use connection::{Mode, Receiver, Sender}; pub use data::{Data, Incoming}; pub type BoxedError = Box; /// A parsing result. #[derive(Debug, Clone)] pub enum Parsing { /// Parsing completed. Done { /// The parsed value. value: T, /// The offset into the byte slice that has been consumed. offset: usize, }, /// Parsing is incomplete and needs more data. NeedMore(N), } /// A buffer type used for implementing `Extension`s. #[derive(Debug)] pub enum Storage<'a> { /// A read-only shared byte slice. Shared(&'a [u8]), /// A mutable byte slice. Unique(&'a mut [u8]), /// An owned byte buffer. 
Owned(Vec), } impl AsRef<[u8]> for Storage<'_> { fn as_ref(&self) -> &[u8] { match self { Storage::Shared(d) => d, Storage::Unique(d) => d, Storage::Owned(b) => b.as_ref(), } } } /// Helper function to allow casts from `usize` to `u64` only on platforms /// where the sizes are guaranteed to fit. #[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] const fn as_u64(a: usize) -> u64 { a as u64 } /// Fill the buffer from the given `AsyncRead` impl with up to `max` bytes. async fn read(reader: &mut R, dest: &mut BytesMut, max: usize) -> io::Result<()> where R: AsyncRead + Unpin, { let i = dest.len(); dest.resize(i + max, 0u8); let n = reader.read(&mut dest[i..]).await?; dest.truncate(i + n); if n == 0 { return Err(io::ErrorKind::UnexpectedEof.into()); } log::trace!("read {} bytes", n); Ok(()) }