condure-1.10.0/.cargo_vcs_info.json0000644000000001360000000000100125520ustar { "git": { "sha1": "4bdc9c9cc94325fb62f75c1867c6739b4ad46f7a" }, "path_in_vcs": "" }condure-1.10.0/.gitignore000064400000000000000000000000241046102023000133260ustar 00000000000000 /target **/*.rs.bk condure-1.10.0/CHANGELOG.md000064400000000000000000000025671046102023000131650ustar 00000000000000Condure Changelog ================= v. 1.10.0 (2023-06-29) * Add support for outgoing connections. * Ability to set mode/user/group when listening on Unix socket. v. 1.9.2 (2023-02-06) * Fix WebSocket compression with fragmented messages. v. 1.9.1 (2023-01-19) * Fix crash in stream connection handler. v. 1.9.0 (2022-12-05) * Support permessage-deflate WebSocket compression. v. 1.8.0 (2022-11-03) * Compatibility with httparse 1.8. * Add more benchmarks. v. 1.7.0 (2022-08-18) * Fix worker thread hang when backend buffer is full. v. 1.6.0 (2022-04-15) * Significantly reduce connection memory usage. * Allow up to 64 headers in requests and responses. * Allow WebSocket requests that include a Content-Length of 0. v. 1.5.0 (2022-03-11) * Ability to listen on a Unix socket for client connections. v. 1.4.1 (2021-10-24) * Fix crash when sending too fast to clients. v. 1.4.0 (2021-10-22) * Port connection handler to use async functions. v. 1.3.1 (2021-08-11) * Fixes for high load. v. 1.3.0 (2021-07-29) * Port to async/await. v. 1.2.0 (2021-05-04) * Send PING/PONG frame data to clients. * Port to mio 0.7. v. 1.1.0 (2020-11-02) * TLS support. * Don't preallocate connection buffers. * Start using async/await in some places. v. 1.0.1 (2020-07-24) * Remove some unsafe usage. v. 1.0.0 (2020-07-21) * Stable version. condure-1.10.0/COPYING000064400000000000000000000261361046102023000124050ustar 00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. 
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

    Copyright [yyyy] [name of copyright owner]

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.

condure-1.10.0/Cargo.lock

# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3

[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"

[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"

[[package]]
name = "anstream"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163"
dependencies = [
 "anstyle",
 "anstyle-parse",
 "anstyle-query",
 "anstyle-wincon",
 "colorchoice",
 "is-terminal",
 "utf8parse",
]

[[package]]
name = "anstyle"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd"

[[package]]
name = "anstyle-parse"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333"
dependencies = [
 "utf8parse",
]

[[package]]
name = "anstyle-query"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b"
dependencies = [
 "windows-sys",
]

[[package]]
name = "anstyle-wincon"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188"
dependencies = [
 "anstyle",
 "windows-sys",
]

[[package]]
name = "arrayvec"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"

[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"

[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"

[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"

[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
 "generic-array",
]

[[package]]
name = "bumpalo"
version = "3.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"

[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"

[[package]]
name = "cc"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"

[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"

[[package]]
name = "ciborium"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926"
dependencies = [
 "ciborium-io",
 "ciborium-ll",
 "serde",
]

[[package]]
name = "ciborium-io"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656"

[[package]]
name = "ciborium-ll"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b"
dependencies = [
 "ciborium-io",
 "half",
]

[[package]]
name = "clap"
version = "4.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bba77a07e4489fb41bd90e8d4201c3eb246b3c2c9ea2ba0bddd6c1d1df87db7d"
dependencies = [
 "clap_builder",
]

[[package]]
name = "clap_builder"
version = "4.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c9b4a88bb4bc35d3d6f65a21b0f0bafe9c894fa00978de242c555ec28bea1c0"
dependencies = [
 "anstream",
 "anstyle",
 "bitflags",
 "clap_lex",
 "once_cell",
 "strsim",
 "terminal_size",
]

[[package]]
name = "clap_lex"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"

[[package]]
name = "colorchoice"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"

[[package]]
name = "condure"
version = "1.10.0"
dependencies = [
 "arrayvec",
 "base64",
 "clap",
 "criterion",
 "env_logger",
 "httparse",
 "ipnet",
 "libc",
 "log",
 "miniz_oxide",
 "mio",
 "openssl",
 "paste",
 "sha1",
 "signal-hook",
 "slab",
 "socket2",
 "test-log",
 "thiserror",
 "time",
 "url",
 "zmq",
]

[[package]]
name = "cpufeatures"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03e69e28e9f7f77debdedbaafa2866e1de9ba56df55a8bd7cfc724c25a09987c"
dependencies = [
 "libc",
]

[[package]]
name = "criterion"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
 "anes",
 "cast",
 "ciborium",
 "clap",
 "criterion-plot",
 "is-terminal",
 "itertools",
 "num-traits",
 "once_cell",
 "oorandom",
 "plotters",
 "rayon",
 "regex",
 "serde",
 "serde_derive",
 "serde_json",
 "tinytemplate",
 "walkdir",
]

[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
 "cast",
 "itertools",
]

[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
 "cfg-if",
 "crossbeam-utils",
]

[[package]]
name = "crossbeam-deque"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [
 "cfg-if",
 "crossbeam-epoch",
 "crossbeam-utils",
]

[[package]]
name = "crossbeam-epoch"
version = "0.9.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
dependencies = [
 "autocfg",
 "cfg-if",
 "crossbeam-utils",
 "memoffset",
 "scopeguard",
]

[[package]]
name = "crossbeam-utils"
version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
dependencies = [
 "cfg-if",
]

[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
 "generic-array",
 "typenum",
]

[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
 "block-buffer",
 "crypto-common",
]

[[package]]
name = "either"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"

[[package]]
name = "env_logger"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
dependencies = [
 "log",
]

[[package]]
name = "errno"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a"
dependencies = [
 "errno-dragonfly",
 "libc",
 "windows-sys",
]

[[package]]
name = "errno-dragonfly"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf"
dependencies = [
 "cc",
 "libc",
]

[[package]]
name = "error-chain"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9435d864e017c3c6afeac1654189b06cdb491cf2ff73dbf0d73b0f292f42ff8"

[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
 "foreign-types-shared",
]

[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"

[[package]]
name = "form_urlencoded"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652"
dependencies = [
 "percent-encoding",
]

[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
 "typenum",
 "version_check",
]

[[package]]
name = "half"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"

[[package]]
name = "hermit-abi"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"

[[package]]
name = "httparse"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"

[[package]]
name = "idna"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
dependencies = [
 "unicode-bidi",
 "unicode-normalization",
]

[[package]]
name = "io-lifetimes"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2"
dependencies = [
 "hermit-abi",
 "libc",
 "windows-sys",
]

[[package]]
name = "ipnet"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"

[[package]]
name = "is-terminal"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f"
dependencies = [
 "hermit-abi",
 "io-lifetimes",
 "rustix",
 "windows-sys",
]

[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
 "either",
]

[[package]]
name = "itoa"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"

[[package]]
name = "js-sys"
version = "0.3.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a"
dependencies = [
 "wasm-bindgen",
]

[[package]]
name = "libc"
version = "0.2.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"

[[package]]
name = "linux-raw-sys"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"

[[package]]
name = "log"
version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"

[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
 "autocfg",
]

[[package]]
name = "metadeps"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73b122901b3a675fac8cecf68dcb2f0d3036193bc861d1ac0e1c337f7d5254c2"
dependencies = [
 "error-chain",
 "pkg-config",
 "toml",
]

[[package]]
name = "miniz_oxide"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
dependencies = [
 "adler",
]

[[package]]
name = "mio"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2"
dependencies = [
 "libc",
 "log",
 "wasi",
 "windows-sys",
]

[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
 "autocfg",
]

[[package]]
name = "num_cpus"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
 "hermit-abi",
 "libc",
]

[[package]]
name = "num_threads"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
dependencies = [
 "libc",
]

[[package]]
name = "once_cell"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"

[[package]]
name = "oorandom"
version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"

[[package]]
name = "openssl"
version = "0.10.55"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d"
dependencies = [
 "bitflags",
 "cfg-if",
 "foreign-types",
 "libc",
 "once_cell",
 "openssl-macros",
 "openssl-sys",
]

[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.22",
]

[[package]]
name = "openssl-sys"
version = "0.9.90"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6"
dependencies = [
 "cc",
 "libc",
 "pkg-config",
 "vcpkg",
]

[[package]]
name = "paste"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"

[[package]]
name = "percent-encoding"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"

[[package]]
name = "pkg-config"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"

[[package]]
name = "plotters"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45"
dependencies = [
 "num-traits",
 "plotters-backend",
 "plotters-svg",
 "wasm-bindgen",
 "web-sys",
]

[[package]]
name = "plotters-backend"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609"

[[package]]
name = "plotters-svg"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab"
dependencies = [
 "plotters-backend",
]

[[package]]
name = "proc-macro2"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb"
dependencies = [
 "unicode-ident",
]

[[package]]
name = "quote"
version = "1.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105"
dependencies = [
 "proc-macro2",
]

[[package]]
name = "rayon"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b"
dependencies = [
 "either",
 "rayon-core",
]

[[package]]
name = "rayon-core"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d"
dependencies = [
 "crossbeam-channel",
 "crossbeam-deque",
 "crossbeam-utils",
 "num_cpus",
]

[[package]]
name = "regex"
version = "1.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
dependencies = [
 "regex-syntax",
]

[[package]]
name = "regex-syntax"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78"

[[package]]
name = "rustix"
version = "0.37.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0"
dependencies = [
 "bitflags",
 "errno",
 "io-lifetimes",
 "libc",
 "linux-raw-sys",
 "windows-sys",
]

[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"

[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
 "winapi-util",
]

[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"

[[package]]
name = "serde"
version = "1.0.164"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d"
dependencies = [
 "serde_derive",
]

[[package]]
name = "serde_derive"
version = "1.0.164"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.22",
]

[[package]]
name = "serde_json"
version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3"
dependencies = [
 "itoa",
 "ryu",
 "serde",
]

[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
 "cfg-if",
 "cpufeatures",
 "digest",
]

[[package]]
name = "signal-hook"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9"
dependencies = [
 "libc",
 "signal-hook-registry",
]

[[package]]
name = "signal-hook-registry"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
dependencies = [
 "libc",
]

[[package]]
name = "slab"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d"
dependencies = [
 "autocfg",
]

[[package]]
name = "socket2"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662"
dependencies = [
 "libc",
 "winapi",
]

[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"

[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
 "proc-macro2",
 "quote",
 "unicode-ident",
]

[[package]]
name = "syn"
version = "2.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616"
dependencies = [
 "proc-macro2",
 "quote",
 "unicode-ident",
]

[[package]]
name = "terminal_size"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e6bf6f19e9f8ed8d4048dc22981458ebcf406d67e94cd422e5ecd73d63b3237"
dependencies = [
 "rustix",
 "windows-sys",
]

[[package]]
name = "test-log"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9601d162c1d77e62c1ea0bc8116cd1caf143ce3af947536c3c9052a1677fe0c"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 1.0.109",
]

[[package]]
name = "thiserror"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac"
dependencies = [
 "thiserror-impl",
]

[[package]]
name = "thiserror-impl"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [
 "proc-macro2",
 "quote",
 "syn 2.0.22",
]

[[package]]
name = "time"
version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd"
dependencies = [
 "itoa",
 "libc",
 "num_threads",
 "serde",
 "time-core",
 "time-macros",
]

[[package]]
name = "time-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"

[[package]]
name = "time-macros"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b"
dependencies = [
 "time-core",
]

[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
 "serde",
 "serde_json",
]

[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
 "tinyvec_macros",
]

[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"

[[package]]
name = "toml"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "736b60249cb25337bc196faa43ee12c705e426f3d55c214d73a4e7be06f92cb4"
[[package]] name = "typenum" version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "unicode-bidi" version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" [[package]] name = "unicode-normalization" version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" dependencies = [ "tinyvec", ] [[package]] name = "url" version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] [[package]] name = "utf8parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "vcpkg" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version_check" version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" dependencies = [ "same-file", "winapi-util", ] [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" dependencies = [ "cfg-if", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", "syn 2.0.22", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" dependencies = [ "quote", "wasm-bindgen-macro-support", ] [[package]] name = "wasm-bindgen-macro-support" version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", "syn 2.0.22", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "web-sys" version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" dependencies = [ "js-sys", "wasm-bindgen", ] [[package]] name = "winapi" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ "winapi-i686-pc-windows-gnu", "winapi-x86_64-pc-windows-gnu", ] [[package]] name = 
"winapi-i686-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" dependencies = [ "winapi", ] [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-sys" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ "windows-targets", ] [[package]] name = "windows-targets" version = "0.48.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", "windows_i686_gnu", "windows_i686_msvc", "windows_x86_64_gnu", "windows_x86_64_gnullvm", "windows_x86_64_msvc", ] [[package]] name = "windows_aarch64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" [[package]] name = "windows_aarch64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" [[package]] name = "windows_i686_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" [[package]] name = "windows_i686_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" [[package]] name = "windows_x86_64_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" [[package]] name = "windows_x86_64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" [[package]] name = "windows_x86_64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" [[package]] name = "zmq" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aad98a7a617d608cd9e1127147f630d24af07c7cd95ba1533246d96cbdd76c66" dependencies = [ "bitflags", "libc", "log", "zmq-sys", ] [[package]] name = "zmq-sys" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d33a2c51dde24d5b451a2ed4b488266df221a5eaee2ee519933dc46b9a9b3648" dependencies = [ "libc", "metadeps", ] condure-1.10.0/Cargo.toml0000644000000051310000000000100105500ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. 
[package] edition = "2018" name = "condure" version = "1.10.0" authors = ["Justin Karneges "] autobins = false description = "HTTP/WebSocket connection manager" readme = "README.md" license = "Apache-2.0" repository = "https://github.com/fanout/condure" [package.metadata.deb] build-depends = "pkg-config, libzmq3-dev" extended-description = """ Condure is a service that manages network connections on behalf of server applications, in order to allow controlling the connections from multiple processes. Applications communicate with Condure over ZeroMQ. Condure can only manage connections for protocols it knows about. Currently this is HTTP/1 and WebSockets. The project was inspired by Mongrel2. """ separate-debug-symbols = true [package.metadata.rpm] package = "condure" [package.metadata.rpm.cargo] buildflags = ["--release"] [package.metadata.rpm.targets.condure] path = "/usr/bin/condure" [profile.dev] panic = "abort" [profile.release] panic = "abort" [[bin]] name = "condure" [[bench]] name = "server" harness = false [[bench]] name = "client" harness = false [dependencies.arrayvec] version = "0.7" [dependencies.base64] version = "0.13" [dependencies.clap] version = "4.3" features = [ "cargo", "string", "wrap_help", ] [dependencies.httparse] version = "1.7" [dependencies.ipnet] version = "2" [dependencies.libc] version = "0.2" [dependencies.log] version = "0.4" [dependencies.miniz_oxide] version = "0.6" [dependencies.mio] version = "0.8" features = [ "os-poll", "os-ext", "net", ] [dependencies.openssl] version = "0.10" [dependencies.paste] version = "1.0" [dependencies.sha1] version = "0.10" [dependencies.signal-hook] version = "0.3" [dependencies.slab] version = "0.4" [dependencies.socket2] version = "0.4" [dependencies.thiserror] version = "1.0" [dependencies.time] version = "0.3" features = [ "formatting", "local-offset", "macros", ] [dependencies.url] version = "2.3" [dependencies.zmq] version = "0.9" [dev-dependencies.criterion] version = "0.5" 
[dev-dependencies.env_logger] version = "0.9" default-features = false [dev-dependencies.test-log] version = "0.2" condure-1.10.0/Cargo.toml.orig000064400000000000000000000032441046102023000142340ustar 00000000000000[package] name = "condure" version = "1.10.0" authors = ["Justin Karneges "] license = "Apache-2.0" description = "HTTP/WebSocket connection manager" repository = "https://github.com/fanout/condure" readme = "README.md" edition = "2018" autobins = false [[bin]] name = "condure" [profile.dev] panic = "abort" [profile.release] panic = "abort" [dependencies] arrayvec = "0.7" base64 = "0.13" clap = { version = "4.3", features = ["cargo", "string", "wrap_help"] } httparse = "1.7" ipnet = "2" libc = "0.2" log = "0.4" miniz_oxide = "0.6" mio = { version = "0.8", features = ["os-poll", "os-ext", "net"] } openssl = "0.10" paste = "1.0" sha1 = "0.10" signal-hook = "0.3" slab = "0.4" socket2 = "0.4" thiserror = "1.0" time = { version = "0.3", features = ["formatting", "local-offset", "macros"] } url = "2.3" zmq = "0.9" [dev-dependencies] criterion = "0.5" env_logger = { version = "0.9", default-features = false } test-log = "0.2" [[bench]] name = "server" harness = false [[bench]] name = "client" harness = false [package.metadata.deb] extended-description = """\ Condure is a service that manages network connections on behalf of server applications, in order to allow controlling the connections from multiple processes. Applications communicate with Condure over ZeroMQ. Condure can only manage connections for protocols it knows about. Currently this is HTTP/1 and WebSockets. The project was inspired by Mongrel2. 
""" separate-debug-symbols = true build-depends = "pkg-config, libzmq3-dev" [package.metadata.rpm] package = "condure" [package.metadata.rpm.cargo] buildflags = ["--release"] [package.metadata.rpm.targets] condure = { path = "/usr/bin/condure" } condure-1.10.0/README.md000064400000000000000000000133501046102023000126230ustar 00000000000000# Condure Condure is a service that manages network connections in order to allow controlling the connections from multiple processes. It can manage incoming connections as well as outgoing connections. Applications communicate with Condure over [ZeroMQ](https://zeromq.org/). Condure can only manage connections for protocols it knows about. Currently this is HTTP/1 and WebSockets. See [Supported protocols](#supported-protocols). The project was inspired by [Mongrel2](https://mongrel2.org/). ## Use cases * Pass connection ownership from one process to another. * Restart an application without its connections getting disconnected. * Balance connection ownership among multiple processes. 
## Basic usage

Start the server:

```
$ condure --listen 8000 --zclient-stream ipc://client
```

Connect a handler to it, such as this simple Python program:

```py
# this handler responds to every request with "hello world"

import os
import time
import tnetstring
import zmq

instance_id = 'basichandler.{}'.format(os.getpid()).encode()

ctx = zmq.Context()
in_sock = ctx.socket(zmq.PULL)
in_sock.connect('ipc://client-out')
out_sock = ctx.socket(zmq.PUB)
out_sock.connect('ipc://client-in')

# await subscription
time.sleep(0.01)

while True:
    m_raw = in_sock.recv()
    # strip the leading "T" that marks a tnetstring-encoded message
    req = tnetstring.loads(m_raw[1:])
    print('IN {}'.format(req))

    resp = {}
    resp[b'from'] = instance_id
    resp[b'id'] = req[b'id']
    resp[b'code'] = 200
    resp[b'reason'] = b'OK'
    resp[b'headers'] = [[b'Content-Type', b'text/plain']]
    resp[b'body'] = b'hello world\n'

    print('OUT {}'.format(resp))
    # address the reply to the Condure instance that sent the request
    out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp))
```

A client request:

```
$ curl -i http://localhost:8000
HTTP/1.1 200 OK
Content-Type: text/plain
Content-Length: 12

hello world
```

The process that receives the request doesn't need to be the same one that responds!
For example, here's a program that outputs request IDs to stdout:

```py
# this handler just outputs the request ID

import tnetstring
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PULL)
sock.connect('ipc://client-out')

while True:
    m = sock.recv_multipart()
    req = tnetstring.loads(m[0][1:])
    print('{} {}'.format(req[b'from'].decode(), req[b'id'].decode()))
```

We can see request ID information when a client request is made:

```
$ python examples/printreq.py
condure 0-0-0
```

From another shell we can respond using a program like this:

```py
# this program sends a response to a certain request ID

import sys
import time
import tnetstring
import zmq

body = sys.argv[1]
addr = sys.argv[2].encode()
rid = sys.argv[3].encode()

ctx = zmq.Context()
sock = ctx.socket(zmq.PUB)
sock.connect('ipc://client-in')

# await subscription
time.sleep(0.01)

resp = {}
resp[b'from'] = b'sendresp'
resp[b'id'] = rid
resp[b'code'] = 200
resp[b'reason'] = b'OK'
resp[b'headers'] = [[b'Content-Type', b'text/plain']]
resp[b'body'] = '{}\n'.format(body).encode()

m = [addr + b' T' + tnetstring.dumps(resp)]

sock.send_multipart(m)
```

For example:

```
$ python examples/sendresp.py "responding from another process" condure 0-0-0
```

The client sees:

```
$ curl -i http://localhost:8000
HTTP/1.1 200 OK
Content-Type: text/plain
Content-Length: 32

responding from another process
```

For easy testing, the programs can be piped together:

```
$ python -u examples/printreq.py | xargs -n 2 python examples/sendresp.py "responding from another process"
```

## Suspending and resuming connections

When passing control of a connection from one process to another, it is important to suspend the connection first. This is done by sending a `handoff-start` message and waiting for a `handoff-proceed` message. At that point, the connection information can be given to another process, and the connection can be resumed by sending any message (such as `keep-alive`). See the [ZHTTP spec](https://rfc.zeromq.org/spec/33/).
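The handoff steps above can be sketched in plain Python. This is a hypothetical illustration, not Condure's code or the `tnetstring` package. `encode` and `decode` are simplified stand-ins for `tnetstring.dumps` and `tnetstring.loads` that handle only the types used here. `Session` is an invented helper, and ZeroMQ socket I/O is omitted. Messages are framed as `<address> T<tnetstring>`, the same way the handlers above address their responses. Each outgoing message carries `from`, `id` and an incrementing `seq`, as the streaming examples in the repository do.

```python
# Hypothetical sketch of a ZHTTP-style handoff: suspend with handoff-start,
# wait for handoff-proceed, then resume from the new owner with keep-alive.
# encode/decode are simplified stand-ins for tnetstring.dumps/loads.


def encode(value):
    """Encode bytes, str, int, bool, None, list or dict as a tnetstring."""
    if isinstance(value, bool):  # test bool first: bool is a subclass of int
        payload, tag = (b"true" if value else b"false"), b"!"
    elif isinstance(value, int):
        payload, tag = str(value).encode(), b"#"
    elif isinstance(value, str):
        payload, tag = value.encode(), b","
    elif isinstance(value, bytes):
        payload, tag = value, b","
    elif value is None:
        payload, tag = b"", b"~"
    elif isinstance(value, list):
        payload, tag = b"".join(encode(v) for v in value), b"]"
    elif isinstance(value, dict):
        payload = b"".join(encode(k) + encode(v) for k, v in value.items())
        tag = b"}"
    else:
        raise TypeError("unsupported type: {!r}".format(type(value)))
    return str(len(payload)).encode() + b":" + payload + tag


def decode(data):
    """Decode one tnetstring from the front of data; return (value, rest)."""
    size, _, rest = data.partition(b":")
    size = int(size)
    payload, tag, rest = rest[:size], rest[size:size + 1], rest[size + 1:]
    if tag == b",":
        return payload, rest
    if tag == b"#":
        return int(payload), rest
    if tag == b"!":
        return payload == b"true", rest
    if tag == b"~":
        return None, rest
    if tag in (b"]", b"}"):
        items = []
        while payload:
            item, payload = decode(payload)
            items.append(item)
        if tag == b"]":
            return items, rest
        # dict payloads alternate key, value, key, value, ...
        return dict(zip(items[0::2], items[1::2])), rest
    raise ValueError("unsupported tnetstring tag: {!r}".format(tag))


class Session:
    """Per-connection state that one handler would pass to the next owner."""

    def __init__(self, instance_id, condure_addr, rid, seq=0):
        self.instance_id = instance_id
        self.condure_addr = condure_addr  # the 'from' value of the request
        self.rid = rid
        self.seq = seq
        self.suspended = False

    def frame(self, msg):
        # Copy the message, stamp from/id/seq, and prefix the Condure address.
        msg = dict(msg)
        msg[b"from"] = self.instance_id
        msg[b"id"] = self.rid
        msg[b"seq"] = self.seq
        self.seq += 1
        return self.condure_addr + b" T" + encode(msg)

    def start_handoff(self):
        # Step 1: ask Condure to suspend the connection.
        return self.frame({b"type": b"handoff-start"})

    def handle(self, raw):
        # Step 2: incoming messages start with "T"; handoff-proceed confirms
        # that the connection is suspended and may be passed on.
        msg, _ = decode(raw[1:])
        if msg.get(b"type") == b"handoff-proceed":
            self.suspended = True
        return msg

    def resume(self):
        # Step 3 (new owner): any message resumes; keep-alive carries no data.
        self.suspended = False
        return self.frame({b"type": b"keep-alive"})
```

In a real handoff, the old owner would pass `condure_addr`, `rid` and `seq` to the new process by some channel of its own. The new process would then build a `Session` from those values and call `resume()`.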
## REQ mode

In addition to the stream mode, which uses PUSH/ROUTER/SUB sockets, there is a "REQ" mode that uses a DEALER socket. To enable it, set `req` as the mode on a listen port. This mode can be handy for implementing simple request/response servers using ZeroMQ.

## Supported protocols

Condure supports HTTP/1 and WebSockets.

Condure manages connections at layer 7 and only supports protocols it knows about. This keeps it simple to use. Handling arbitrary protocols would require applications to build protocol stacks capable of suspending/resuming sessions at arbitrary byte positions in TCP streams, which would make Condure prohibitively difficult to use. Instead, Condure is protocol-aware and provides parsed frames to applications, so applications only need to support suspending/resuming sessions at frame boundaries.

## Performance

Condure was built for high performance. It uses numerous optimization techniques, including minimal heap allocations, ring buffers, vectored I/O, hierarchical timing wheels, and fast data structures (e.g. slabs).

Condure has been tested with over 1M concurrent connections on a single instance using just 2 workers (4 threads total). See https://blog.fanout.io/2020/08/11/rewriting-pushpins-connection-manager-in-rust/ for details.

## Comparison to Mongrel2

* Condure supports acting as a server and as a client.
* Condure supports multiple cores.
* Condure supports listening on multiple ports without requiring multiple processes.
* Condure does not support multiple routes and is not intended to be a shared server. Each application that wants to keep connections in a separate process should spawn its own Condure instance.
* Condure has no config file. Configuration is supplied using command-line arguments.
* Condure uses a different ZeroMQ-based protocol, [ZHTTP](https://rfc.zeromq.org/spec/33/), which is easier to use and more reliable than Mongrel2's protocol.
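The Performance section above lists slabs among Condure's fast data structures, and Condure's Cargo.toml depends on the Rust `slab` crate. The sketch below shows the general idea in Python. It is a hypothetical, simplified model, not Condure's code or the `slab` crate's API. Entries live in a growable list, and freed indexes go on a free list for reuse. Insert, lookup and removal are therefore O(1) (amortized for list growth), and keys stay small, dense integers that are convenient as lookup keys.

```python
# Hypothetical sketch of a slab: a list of slots plus a free list of vacant
# indexes, giving O(1) lookup/removal and amortized O(1) insert with reusable
# integer keys.

_VACANT = object()  # sentinel so that None can still be stored as a value


class Slab:
    def __init__(self):
        self._slots = []  # each slot holds a value or _VACANT
        self._free = []   # indexes of vacant slots, reused last-in first-out
        self._len = 0

    def __len__(self):
        return self._len

    def insert(self, value):
        """Store value and return its key, reusing a freed slot if possible."""
        if self._free:
            key = self._free.pop()
            self._slots[key] = value
        else:
            key = len(self._slots)
            self._slots.append(value)
        self._len += 1
        return key

    def get(self, key):
        """Return the value under key, or None if vacant or out of range."""
        if 0 <= key < len(self._slots) and self._slots[key] is not _VACANT:
            return self._slots[key]
        return None

    def remove(self, key):
        """Vacate key and return its value; raise KeyError if not occupied."""
        if not 0 <= key < len(self._slots) or self._slots[key] is _VACANT:
            raise KeyError(key)
        value = self._slots[key]
        self._slots[key] = _VACANT
        self._free.append(key)
        self._len -= 1
        return value


conns = Slab()
a = conns.insert("conn-a")  # key 0
b = conns.insert("conn-b")  # key 1
conns.remove(a)             # slot 0 goes on the free list
c = conns.insert("conn-c")  # reuses key 0
```

Reusing freed slots keeps the storage compact under connection churn. The trade-off is that a stale key may later refer to a different entry, so callers must stop using a key once it has been removed.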
## Future plans

* HTTP/2
* HTTP/3
condure-1.10.0/benches/client.rs000064400000000000000000000125401046102023000145770ustar 00000000000000
/*
 * Copyright (C) 2023 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use condure::channel;
use condure::client::TestClient;
use condure::executor::Executor;
use condure::future::{
    AsyncReadExt, AsyncSender, AsyncTcpListener, AsyncTcpStream, AsyncWriteExt,
};
use condure::reactor::Reactor;
use criterion::{criterion_group, criterion_main, Criterion};
use mio::net::TcpListener;
use std::net::SocketAddr;
use std::rc::Rc;
use std::str;

const REQS_PER_ITER: usize = 10;

fn req<F1, F2>(listener: TcpListener, start: F1, wait: F2) -> TcpListener
where
    F1: Fn(SocketAddr) + 'static,
    F2: Fn() + 'static,
{
    let executor = Executor::new(REQS_PER_ITER + 1);

    let addr = listener.local_addr().unwrap();

    let (s, r) = channel::channel(1);

    for _ in 0..REQS_PER_ITER {
        start(addr);
    }

    let spawner = executor.spawner();

    executor
        .spawn(async move {
            let s = AsyncSender::new(s);
            let listener = AsyncTcpListener::new(listener);

            for _ in 0..REQS_PER_ITER {
                let (stream, _) = listener.accept().await.unwrap();
                let mut stream = AsyncTcpStream::new(stream);

                spawner
                    .spawn(async move {
                        let mut buf = Vec::new();
                        let mut req_end = 0;

                        while req_end == 0 {
                            let mut chunk = [0; 1024];
                            let size = stream.read(&mut chunk).await.unwrap();
                            buf.extend_from_slice(&chunk[..size]);

                            // find the end of the request headers; windows(4)
                            // yields nothing while buf has fewer than 4 bytes,
                            // avoiding the underflow of `buf.len() - 3`
                            if let Some(pos) = buf.windows(4).position(|w| w == b"\r\n\r\n") {
                                req_end = pos + 4;
                            }
                        }

                        let expected = format!(
                            concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"),
                            addr
                        );

                        assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected);

                        stream
                            .write(
                                b"HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n",
                            )
                            .await
                            .unwrap();
                    })
                    .unwrap();
            }

            s.send(listener.into_inner()).await.unwrap();
        })
        .unwrap();

    executor
        .run(|timeout| Reactor::current().unwrap().poll(timeout))
        .unwrap();

    for _ in 0..REQS_PER_ITER {
        wait();
    }

    // hand the listener back so the next iteration can reuse it
    r.recv().unwrap()
}

fn criterion_benchmark(c: &mut Criterion) {
    let mut req_listener = Some(TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap());
    let mut stream_listener = Some(TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap());

    let _reactor = Reactor::new(REQS_PER_ITER * 10);

    {
        let client = Rc::new(TestClient::new(1));

        c.bench_function("req_client workers=1", |b| {
            b.iter(|| {
                let c1 = Rc::clone(&client);
                let c2 = Rc::clone(&client);

                req_listener = Some(req(
                    req_listener.take().unwrap(),
                    move |addr| c1.do_req(addr),
                    move || c2.wait_req(),
                ))
            })
        });

        c.bench_function("stream_client workers=1", |b| {
            b.iter(|| {
                let c1 = Rc::clone(&client);
                let c2 = Rc::clone(&client);

                stream_listener = Some(req(
                    stream_listener.take().unwrap(),
                    move |addr| c1.do_stream_http(addr),
                    move || c2.wait_stream(),
                ))
            })
        });
    }

    {
        let client = Rc::new(TestClient::new(2));

        c.bench_function("req_client workers=2", |b| {
            b.iter(|| {
                let c1 = Rc::clone(&client);
                let c2 = Rc::clone(&client);

                req_listener = Some(req(
                    req_listener.take().unwrap(),
                    move |addr| c1.do_req(addr),
                    move || c2.wait_req(),
                ))
            })
        });

        c.bench_function("stream_client workers=2", |b| {
            b.iter(|| {
                let c1 = Rc::clone(&client);
                let c2 = Rc::clone(&client);

                stream_listener = Some(req(
                    stream_listener.take().unwrap(),
                    move |addr| c1.do_stream_http(addr),
                    move || c2.wait_stream(),
                ))
            })
        });
    }
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
condure-1.10.0/benches/server.rs000064400000000000000000000112661046102023000146330ustar 00000000000000
/*
 * Copyright (C) 2020 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use condure::connection::testutil::{
    BenchServerReqConnection, BenchServerReqHandler, BenchServerStreamConnection,
    BenchServerStreamHandler,
};
use condure::executor::Executor;
use condure::future::{AsyncReadExt, AsyncTcpStream, AsyncWriteExt};
use condure::reactor::Reactor;
use condure::server::TestServer;
use condure::websocket::testutil::{BenchRecvMessage, BenchSendMessage};
use criterion::{criterion_group, criterion_main, Criterion};
use std::io::{self, Write};
use std::net::SocketAddr;
use std::str;

const REQS_PER_ITER: usize = 10;

fn req(addr: SocketAddr) {
    let reactor = Reactor::new(REQS_PER_ITER * 10);
    let executor = Executor::new(REQS_PER_ITER);

    for _ in 0..REQS_PER_ITER {
        executor
            .spawn(async move {
                let mut client = AsyncTcpStream::connect(&[addr]).await.unwrap();

                client
                    .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n")
                    .await
                    .unwrap();

                let mut resp = [0u8; 1024];
                let mut resp = io::Cursor::new(&mut resp[..]);

                loop {
                    let mut buf = [0; 1024];

                    let size = client.read(&mut buf).await.unwrap();
                    if size == 0 {
                        break;
                    }

                    // write_all fails instead of silently truncating if the
                    // response overflows the 1024-byte buffer
                    resp.write_all(&buf[..size]).unwrap();
                }

                let size = resp.position() as usize;
                let resp = str::from_utf8(&resp.get_ref()[..size]).unwrap();

                assert_eq!(resp, "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n");
            })
            .unwrap();
    }

    executor.run(|timeout| reactor.poll(timeout)).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
    {
        let t = BenchServerReqHandler::new();

        c.bench_function("req_handler", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchServerStreamHandler::new();

        c.bench_function("stream_handler", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchServerReqConnection::new();

        c.bench_function("req_connection", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchServerStreamConnection::new();

        c.bench_function("stream_connection", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchSendMessage::new(false);

        c.bench_function("ws_send", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchSendMessage::new(true);

        c.bench_function("ws_send_deflate", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchRecvMessage::new(false);

        c.bench_function("ws_recv", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let t = BenchRecvMessage::new(true);

        c.bench_function("ws_recv_deflate", |b| {
            b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput)
        });
    }

    {
        let server = TestServer::new(1);
        let req_addr = server.req_addr();
        let stream_addr = server.stream_addr();

        c.bench_function("req_server workers=1", |b| b.iter(|| req(req_addr)));
        c.bench_function("stream_server workers=1", |b| b.iter(|| req(stream_addr)));
    }

    {
        let server = TestServer::new(2);
        let req_addr = server.req_addr();
        let stream_addr = server.stream_addr();

        c.bench_function("req_server workers=2", |b| b.iter(|| req(req_addr)));
        c.bench_function("stream_server workers=2", |b| b.iter(|| req(stream_addr)));
    }
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
condure-1.10.0/examples/basichandler.py000064400000000000000000000014371046102023000161560ustar 00000000000000
# this handler responds to every request with "hello world"

import os
import time
import tnetstring
import zmq

instance_id = 'basichandler.{}'.format(os.getpid()).encode()

ctx = zmq.Context()
in_sock = ctx.socket(zmq.PULL)
in_sock.connect('ipc://client-out')
out_sock = ctx.socket(zmq.PUB)
out_sock.connect('ipc://client-in')

# await subscription
time.sleep(0.01)

while True:
    m_raw = in_sock.recv()
    req = tnetstring.loads(m_raw[1:])
    print('IN {}'.format(req))

    resp = {}
    resp[b'from'] = instance_id
    resp[b'id'] = req[b'id']
    resp[b'code'] = 200
    resp[b'reason'] = b'OK'
    resp[b'headers'] = [[b'Content-Type', b'text/plain']]
    resp[b'body'] = b'hello world\n'

    print('OUT {}'.format(resp))
    out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp))
condure-1.10.0/examples/get.py000064400000000000000000000014061046102023000143120ustar 00000000000000
import sys
import uuid
import tnetstring
import zmq

if len(sys.argv) < 2:
    print('usage: {} [url]'.format(sys.argv[0]))
    sys.exit(1)

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect('ipc://server')

req = {
    b'method': b'GET',
    b'uri': sys.argv[1].encode('utf-8'),
    #b'follow-redirects': True,
    #b'ignore-tls-errors': True,
}

sock.send(b'T' + tnetstring.dumps(req))

resp = tnetstring.loads(sock.recv()[1:])
if b'type' in resp and resp[b'type'] == b'error':
    print('error: {}'.format(resp[b'condition']))
    sys.exit(1)

print('code={} reason=[{}]'.format(resp[b'code'], resp[b'reason']))
for h in resp[b'headers']:
    print('{}: {}'.format(h[0], h[1]))

if b'body' in resp:
    print('\n{}'.format(resp[b'body']))
else:
    print('\n')
condure-1.10.0/examples/getstream.py000064400000000000000000000027741046102023000155370ustar 00000000000000
import sys
import time
import uuid
import tnetstring
import zmq

client_id = b'getstream.py'

ctx = zmq.Context()
out_sock = ctx.socket(zmq.PUSH)
out_sock.connect('ipc://server-in')
out_stream_sock = ctx.socket(zmq.ROUTER)
out_stream_sock.connect('ipc://server-in-stream')
in_sock = ctx.socket(zmq.SUB)
in_sock.setsockopt(zmq.SUBSCRIBE, client_id)
in_sock.connect('ipc://server-out')

time.sleep(0.5)

rid = str(uuid.uuid4()).encode('utf-8')
inseq = 0
outseq = 0
out_sock.send(b'T' + tnetstring.dumps({
    b'from': client_id,
    b'id': rid,
    b'seq': outseq,
    b'method': b'GET',
    b'uri': sys.argv[1].encode('utf-8'),
    b'stream': True,
    b'credits': 8192,
}))
outseq += 1

while True:
    buf = in_sock.recv()
    at = buf.find(b' ')
    receiver = buf[:at]
    indata = tnetstring.loads(buf[at + 2:])
    if indata[b'id'] != rid:
        continue
    print('IN: {}'.format(indata))
    assert(indata[b'seq'] == inseq)
    inseq += 1
    if (b'type' in indata and (indata[b'type'] == b'error' or indata[b'type'] == b'cancel')) or (b'type' not in indata and b'more' not in indata):
        break
    raddr = indata[b'from']
    if b'body' in indata and len(indata[b'body']) > 0:
        outdata = {
            b'id': rid,
            b'from': client_id,
            b'seq': outseq,
            b'type': b'credit',
            b'credits': len(indata[b'body']),
        }
        print('OUT: {}'.format(outdata))
        out_stream_sock.send_multipart([raddr, b'', b'T' + tnetstring.dumps(outdata)])
        outseq += 1
condure-1.10.0/examples/holdhandler.py000064400000000000000000000077401046102023000160260ustar 00000000000000
# this handler holds all connections open

import os
import time
import datetime
import calendar
import tnetstring
import zmq

CONN_TTL = 60000
EXPIRE_INTERVAL = 60000

instance_id = 'holdhandler.{}'.format(os.getpid()).encode('utf-8')

ctx = zmq.Context()
in_sock = ctx.socket(zmq.PULL)
in_sock.connect('ipc://client-out')
in_stream_sock = ctx.socket(zmq.ROUTER)
in_stream_sock.identity = instance_id
in_stream_sock.connect('ipc://client-out-stream')
out_sock = ctx.socket(zmq.PUB)
out_sock.connect('ipc://client-in')

poller = zmq.Poller()
poller.register(in_sock, zmq.POLLIN)
poller.register(in_stream_sock, zmq.POLLIN)

class Connection(object):
    def __init__(self, rid):
        self.rid = rid
        self.seq = 0
        self.exp_time = None

    def send_msg(self, msg):
        msg[b'from'] = instance_id
        msg[b'id'] = self.rid[1]
        msg[b'seq'] = self.seq
        self.seq += 1

        print('OUT {} {}'.format(self.rid[0], msg))
        out_sock.send(self.rid[0] + b' T' + tnetstring.dumps(msg))

    def send_header(self):
        msg = {}
        msg[b'code'] = 200
        msg[b'reason'] = b'OK'
        msg[b'headers'] = [[b'Content-Type', b'text/plain']]
        msg[b'more'] = True

        self.send_msg(msg)

    def send_body(self, data):
        msg = {}
        msg[b'body'] = data
        msg[b'more'] = True

        self.send_msg(msg)

def send_body(to_addr, conns, data):
    ids = []
    for c in conns:
        ids.append({b'id': c.rid[1], b'seq': c.seq})
        c.seq += 1

    msg = {}
    msg[b'from'] = instance_id
    msg[b'id'] = ids
    msg[b'body'] = data
    msg[b'more'] = True

    print('OUT {} {}'.format(to_addr, msg))
    out_sock.send(to_addr + b' T' + tnetstring.dumps(msg))

conns = {}
last_exp_time = int(time.time())

while True:
    socks = dict(poller.poll(1000))
    if socks.get(in_sock) == zmq.POLLIN:
        m_raw = in_sock.recv()
    elif socks.get(in_stream_sock) == zmq.POLLIN:
        m_list = in_stream_sock.recv_multipart()
        m_raw = m_list[2]
    else:
        m_raw = None

    now = int(time.time() * 1000)

    if m_raw is not None:
        req = tnetstring.loads(m_raw[1:])
        print('IN {}'.format(req))

        m_from = req[b'from']
        m_id = req[b'id']
        m_type = req.get(b'type', b'')

        ids = []
        if isinstance(m_id, list):
            for id_seq in m_id:
                ids.append(id_seq[b'id'])
        else:
            ids.append(m_id)

        new_ids = []
        known_conns = []
        for i in ids:
            rid = (m_from, i)

            c = conns.get(rid)
            if c:
                c.exp_time = now + CONN_TTL
                known_conns.append(c)
            else:
                new_ids.append(rid)

        # data
        if not m_type:
            for rid in new_ids:
                c = Connection(rid)
                conns[rid] = c
                c.exp_time = now + CONN_TTL
                c.send_header()
        elif c:
            if m_type == b'keep-alive':
                dt = datetime.datetime.utcnow()
                ts = calendar.timegm(dt.timetuple())
                body = (
                    'id: TCPKaliMsgTS-{:016x}.\n'
                    'event: message\n'
                    'data: {:04}-{:02}-{:02}T{:02}:{:02}:{:02}\n\n'
                ).format(
                    (ts * 1000000) + dt.microsecond,
                    dt.year,
                    dt.month,
                    dt.day,
                    dt.hour,
                    dt.minute,
                    dt.second
                ).encode()
                send_body(m_from, known_conns, body)
            elif m_type ==
b'cancel': for c in known_conns: del conns[c.rid] if now >= last_exp_time + EXPIRE_INTERVAL: last_exp_time = now to_remove = [] for rid, c in conns.items(): if now >= c.exp_time: to_remove.append(rid) for rid in to_remove: print('expired {}'.format(rid)) del conns[rid] condure-1.10.0/examples/printreq.py000064400000000000000000000004611046102023000153770ustar 00000000000000# this handler just outputs the request ID import tnetstring import zmq ctx = zmq.Context() sock = ctx.socket(zmq.PULL) sock.connect('ipc://client-out') while True: m = sock.recv_multipart() req = tnetstring.loads(m[0][1:]) print('{} {}'.format(req[b'from'].decode(), req[b'id'].decode())) condure-1.10.0/examples/reqhandler.py000064400000000000000000000010371046102023000156600ustar 00000000000000# this handler responds to every request with "hello world" import tnetstring import zmq ctx = zmq.Context() sock = ctx.socket(zmq.REP) sock.connect('ipc://client') while True: m_raw = sock.recv() req = tnetstring.loads(m_raw[1:]) print('IN {}'.format(req)) resp = {} resp[b'id'] = req[b'id'] resp[b'code'] = 200 resp[b'reason'] = b'OK' resp[b'headers'] = [[b'Content-Type', b'text/plain']] resp[b'body'] = b'hello world\n' print('OUT {}'.format(resp)) sock.send(b'T' + tnetstring.dumps(resp)) condure-1.10.0/examples/sendresp.py000064400000000000000000000010741046102023000153570ustar 00000000000000# this program sends a response to a certain request ID import sys import time import tnetstring import zmq body = sys.argv[1] addr = sys.argv[2].encode() rid = sys.argv[3].encode() ctx = zmq.Context() sock = ctx.socket(zmq.PUB) sock.connect('ipc://client-in') # await subscription time.sleep(0.01) resp = {} resp[b'from'] = b'sendresp' resp[b'id'] = rid resp[b'code'] = 200 resp[b'reason'] = b'OK' resp[b'headers'] = [[b'Content-Type', b'text/plain']] resp[b'body'] = '{}\n'.format(body).encode() m = [addr + b' T' + tnetstring.dumps(resp)] sock.send_multipart(m) 
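The example handlers above all share one wire framing. Requests read from the PULL socket are `T` followed by a tnetstring, and responses published on the PUB socket are the receiver address, a space, `T`, and a tnetstring. getstream.py parses that same layout by splitting at the first space. The sketch below round-trips this framing. It is a minimal illustration inferred only from these example scripts. The names `dumps`, `loads`, `parse_request`, `frame_response` and `parse_response` are hypothetical. The small tnetstring codec is a stand-in for the third-party `tnetstring` package that the examples import: it handles only bytes, ints, bools, None, lists and dicts, and it does not reproduce that package's API.

```python
# Sketch of the zhttp framing seen in the example handlers (inferred from
# the scripts, not from condure documentation):
#   inbound request:   b'T' + tnetstring
#   outbound response: receiver address + b' T' + tnetstring
# dumps()/loads() are a minimal stand-in for the `tnetstring` package and
# only handle bytes, int, bool, None, list and dict.


def dumps(value):
    # bool must be tested before int, since bool is a subclass of int
    if isinstance(value, bool):
        data, tag = (b'true' if value else b'false'), b'!'
    elif isinstance(value, int):
        data, tag = str(value).encode(), b'#'
    elif isinstance(value, bytes):
        data, tag = value, b','
    elif value is None:
        data, tag = b'', b'~'
    elif isinstance(value, list):
        data, tag = b''.join(dumps(v) for v in value), b']'
    elif isinstance(value, dict):
        data = b''.join(dumps(k) + dumps(v) for k, v in value.items())
        tag = b'}'
    else:
        raise TypeError('unsupported type: {}'.format(type(value)))
    # every element is "<length>:<payload><type tag>"
    return str(len(data)).encode() + b':' + data + tag


def _parse(buf):
    # decode one element and return (value, unparsed remainder)
    colon = buf.index(b':')
    end = colon + 1 + int(buf[:colon])
    data, tag, rest = buf[colon + 1:end], buf[end:end + 1], buf[end + 1:]
    if tag == b',':
        return data, rest
    if tag == b'#':
        return int(data), rest
    if tag == b'!':
        return data == b'true', rest
    if tag == b'~':
        return None, rest
    if tag == b']':
        items = []
        while data:
            item, data = _parse(data)
            items.append(item)
        return items, rest
    if tag == b'}':
        out = {}
        while data:
            key, data = _parse(data)
            val, data = _parse(data)
            out[key] = val
        return out, rest
    raise ValueError('unknown type tag: {!r}'.format(tag))


def loads(buf):
    value, rest = _parse(buf)
    if rest:
        raise ValueError('trailing data after tnetstring')
    return value


def parse_request(raw):
    # the handlers call tnetstring.loads(m_raw[1:]) to skip the leading b'T'
    if raw[:1] != b'T':
        raise ValueError('expected tnetstring-encoded message')
    return loads(raw[1:])


def frame_response(addr, msg):
    # same layout as req[b'from'] + b' T' + tnetstring.dumps(resp)
    return addr + b' T' + dumps(msg)


def parse_response(buf):
    # mirrors getstream.py: split at the first space, skip the b'T' marker
    at = buf.find(b' ')
    return buf[:at], loads(buf[at + 2:])


if __name__ == '__main__':
    req = parse_request(b'T' + dumps({b'id': b'1', b'from': b'client'}))
    wire = frame_response(req[b'from'], {b'id': req[b'id'], b'code': 200})
    print(wire)
    print(parse_response(wire))
```

Because every tnetstring element is length-prefixed, the decoder never scans for delimiters inside a payload. Response bodies can therefore carry arbitrary bytes without escaping.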
condure-1.10.0/examples/streamhandler.py000064400000000000000000000051031046102023000163620ustar 00000000000000# this handler responds to every request with "hello world" import os import tnetstring import zmq instance_id = 'streamhandler.{}'.format(os.getpid()).encode('utf-8') ctx = zmq.Context() in_sock = ctx.socket(zmq.PULL) in_sock.connect('ipc://client-out') in_stream_sock = ctx.socket(zmq.ROUTER) in_stream_sock.identity = instance_id in_stream_sock.connect('ipc://client-out-stream') out_sock = ctx.socket(zmq.PUB) out_sock.connect('ipc://client-in') poller = zmq.Poller() poller.register(in_sock, zmq.POLLIN) poller.register(in_stream_sock, zmq.POLLIN) while True: socks = dict(poller.poll(None)) if socks.get(in_sock) == zmq.POLLIN: m_raw = in_sock.recv() elif socks.get(in_stream_sock) == zmq.POLLIN: m_list = in_stream_sock.recv_multipart() m_raw = m_list[2] else: continue req = tnetstring.loads(m_raw[1:]) print('IN {}'.format(req)) if req.get(b'type'): # skip all non-data messages continue if req.get(b'uri', b'').startswith(b'ws'): resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 0 resp[b'code'] = 101 resp[b'reason'] = b'Switching Protocols' resp[b'credits'] = 1024 print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 1 resp[b'body'] = b'hello world' print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 2 resp[b'type'] = b'close' print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) else: resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 0 resp[b'code'] = 200 resp[b'reason'] = b'OK' resp[b'headers'] = [[b'Content-Type', b'text/plain']] resp[b'more'] = True resp[b'credits'] = 1024 print('OUT {} 
{}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 1 resp[b'body'] = b'hello world\n' print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) condure-1.10.0/rustfmt.toml000064400000000000000000000000251046102023000137400ustar 00000000000000# using the defaults condure-1.10.0/src/app.rs000064400000000000000000000311051046102023000132570ustar 00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::client::Client; use crate::server::{Server, MSG_RETAINED_PER_CONNECTION_MAX, MSG_RETAINED_PER_WORKER_MAX}; use crate::websocket; use crate::zhttpsocket; use crate::zmq::SpecInfo; use ipnet::IpNet; use log::info; use signal_hook; use signal_hook::consts::TERM_SIGNALS; use signal_hook::iterator::Signals; use std::cmp; use std::path::PathBuf; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::Duration; const INIT_HWM: usize = 128; fn make_specs(base: &str, is_server: bool) -> Result<(String, String, String), String> { if base.starts_with("ipc:") { if is_server { Ok(( format!("{}-{}", base, "in"), format!("{}-{}", base, "in-stream"), format!("{}-{}", base, "out"), )) } else { Ok(( format!("{}-{}", base, "out"), format!("{}-{}", base, "out-stream"), format!("{}-{}", base, "in"), )) } } else if base.starts_with("tcp:") { match base.rfind(':') { Some(pos) => match base[(pos + 1)..base.len()].parse::() { Ok(port) => Ok(( format!("{}:{}", &base[..pos], port), format!("{}:{}", &base[..pos], port + 1), format!("{}:{}", &base[..pos], port + 2), )), Err(e) => Err(format!("error parsing tcp port in base spec: {}", e)), }, None => Err("tcp base spec must specify port".into()), } } else { Err("base spec must be ipc or tcp".into()) } } pub enum ListenSpec { Tcp { addr: std::net::SocketAddr, tls: bool, default_cert: Option, }, Local { path: PathBuf, mode: Option, user: Option, group: Option, }, } pub struct ListenConfig { pub spec: ListenSpec, pub stream: bool, } pub struct Config { pub instance_id: String, pub workers: usize, pub req_maxconn: usize, pub stream_maxconn: usize, pub buffer_size: usize, pub body_buffer_size: usize, pub messages_max: usize, pub req_timeout: Duration, pub stream_timeout: Duration, pub listen: Vec, pub zclient_req: Vec, pub zclient_stream: Vec, pub zclient_connect: bool, pub zserver_req: Vec, pub zserver_stream: Vec, pub zserver_connect: bool, pub ipc_file_mode: u32, pub certs_dir: PathBuf, pub allow_compression: bool, pub 
deny: Vec, } pub struct App { _server: Option, _client: Option, } impl App { pub fn new(config: &Config) -> Result { if config.req_maxconn < config.workers { return Err("req maxconn must be >= workers".into()); } if config.stream_maxconn < config.workers { return Err("stream maxconn must be >= workers".into()); } let zmq_context = Arc::new(zmq::Context::new()); // set hwm to 5% of maxconn let other_hwm = cmp::max((config.req_maxconn + config.stream_maxconn) / 20, 1); let handle_bound = cmp::max(other_hwm / config.workers, 1); let maxconn = config.req_maxconn + config.stream_maxconn; let server = if !config.listen.is_empty() { let mut any_req = false; let mut any_stream = false; for lc in config.listen.iter() { if lc.stream { any_stream = true; } else { any_req = true; } } let mut zsockman = zhttpsocket::ClientSocketManager::new( Arc::clone(&zmq_context), &config.instance_id, (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * config.workers), INIT_HWM, other_hwm, handle_bound, ); if any_req { let mut specs = Vec::new(); for spec in config.zclient_req.iter() { if config.zclient_connect { info!("zhttp client connect {}", spec); } else { info!("zhttp client bind {}", spec); } specs.push(SpecInfo { spec: spec.clone(), bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_client_req_specs(&specs) { return Err(format!("failed to set zhttp client req specs: {}", e)); } } if any_stream { let mut out_specs = Vec::new(); let mut out_stream_specs = Vec::new(); let mut in_specs = Vec::new(); for spec in config.zclient_stream.iter() { let (out_spec, out_stream_spec, in_spec) = make_specs(spec, false)?; if config.zclient_connect { info!( "zhttp client connect {} {} {}", out_spec, out_stream_spec, in_spec ); } else { info!( "zhttp client bind {} {} {}", out_spec, out_stream_spec, in_spec ); } out_specs.push(SpecInfo { spec: out_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); 
out_stream_specs.push(SpecInfo { spec: out_stream_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); in_specs.push(SpecInfo { spec: in_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_client_stream_specs(&out_specs, &out_stream_specs, &in_specs) { return Err(format!("failed to set zhttp client stream specs: {}", e)); } } Some(Server::new( &config.instance_id, config.workers, config.req_maxconn, config.stream_maxconn, config.buffer_size, config.body_buffer_size, config.messages_max, config.req_timeout, config.stream_timeout, &config.listen, config.certs_dir.as_path(), config.allow_compression, zsockman, handle_bound, )?) } else { None }; let client = if !config.zserver_req.is_empty() || !config.zserver_stream.is_empty() { let mut zsockman = zhttpsocket::ServerSocketManager::new( Arc::clone(&zmq_context), &config.instance_id, (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * config.workers), INIT_HWM, other_hwm, handle_bound, config.stream_maxconn, ); if !config.zserver_req.is_empty() { let mut specs = Vec::new(); for spec in config.zserver_req.iter() { if config.zserver_connect { info!("zhttp server connect {}", spec); } else { info!("zhttp server bind {}", spec); } specs.push(SpecInfo { spec: spec.clone(), bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_server_req_specs(&specs) { return Err(format!("failed to set zhttp server req specs: {}", e)); } } let zsockman = Arc::new(zsockman); let client = Client::new( &config.instance_id, config.workers, config.req_maxconn, config.stream_maxconn, config.buffer_size, config.body_buffer_size, config.messages_max, config.req_timeout, config.stream_timeout, config.allow_compression, &config.deny, zsockman.clone(), handle_bound, )?; // stream specs must only be applied after client is initialized if !config.zserver_stream.is_empty() { let mut in_specs = 
Vec::new(); let mut in_stream_specs = Vec::new(); let mut out_specs = Vec::new(); for spec in config.zserver_stream.iter() { let (in_spec, in_stream_spec, out_spec) = make_specs(spec, true)?; if config.zserver_connect { info!( "zhttp server connect {} {} {}", in_spec, in_stream_spec, out_spec ); } else { info!( "zhttp server bind {} {} {}", in_spec, in_stream_spec, out_spec ); } in_specs.push(SpecInfo { spec: in_spec, bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); in_stream_specs.push(SpecInfo { spec: in_stream_spec, bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); out_specs.push(SpecInfo { spec: out_spec, bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_server_stream_specs(&in_specs, &in_stream_specs, &out_specs) { return Err(format!("failed to set zhttp server stream specs: {}", e)); } } Some(client) } else { None }; Ok(Self { _server: server, _client: client, }) } pub fn wait_for_term(&self) { let mut signals = Signals::new(TERM_SIGNALS).unwrap(); let term_now = Arc::new(AtomicBool::new(false)); // ensure two term signals in a row cause the app to immediately exit for signal_type in TERM_SIGNALS { signal_hook::flag::register_conditional_shutdown( *signal_type, 1, // exit code Arc::clone(&term_now), ) .unwrap(); signal_hook::flag::register(*signal_type, Arc::clone(&term_now)).unwrap(); } // wait for termination for signal in &mut signals { match signal { signal_type if TERM_SIGNALS.contains(&signal_type) => break, _ => unreachable!(), } } } pub fn sizes() -> Vec<(String, usize)> { let mut out = Vec::new(); out.extend(Server::task_sizes()); out.extend(Client::task_sizes()); out.push(( "deflate_codec_state".to_string(), websocket::deflate_codec_state_size(), )); out } } condure-1.10.0/src/arena.rs000064400000000000000000000411761046102023000135760ustar 00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc.
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use slab::Slab; use std::cell::{RefCell, RefMut}; use std::mem; use std::ops::{Deref, DerefMut}; use std::sync::{Mutex, MutexGuard}; pub struct EntryGuard<'a, T> { entries: RefMut<'a, Slab<T>>, entry: &'a mut T, key: usize, } impl<T> EntryGuard<'_, T> { fn remove(mut self) { self.entries.remove(self.key); } } impl<T> Deref for EntryGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { self.entry } } impl<T> DerefMut for EntryGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entry } } // this is essentially a sharable slab for use within a single thread. // operations are protected by a RefCell. when an element is retrieved for // reading or modification, it is wrapped in an EntryGuard which keeps the // entire slab borrowed until the caller is done working with the element pub struct Memory<T> { entries: RefCell<Slab<T>>, } impl<T> Memory<T> { pub fn new(capacity: usize) -> Self { // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: RefCell::new(s), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.borrow(); entries.len() } fn insert(&self, e: T) -> Result<usize, ()> { let mut entries = self.entries.borrow_mut(); // out of capacity.
by preventing inserts beyond the capacity, we // ensure the underlying memory won't get moved due to a realloc if entries.len() == entries.capacity() { return Err(()); } Ok(entries.insert(e)) } fn get<'a>(&'a self, key: usize) -> Option<EntryGuard<'a, T>> { let mut entries = self.entries.borrow_mut(); let entry = entries.get_mut(key)?; // slab element addresses are guaranteed to be stable once created, // and the only place we remove the element is in EntryGuard's // remove method which consumes itself, therefore it is safe to // assume the element will live at least as long as the EntryGuard // and we can extend the lifetime of the reference beyond the // RefMut let entry = unsafe { mem::transmute::<&mut T, &'a mut T>(entry) }; Some(EntryGuard { entries, entry, key, }) } // for tests, as a way to confirm the memory isn't moving. be careful // with this. the very first element inserted will be at index 0, but // if the slab has been used and cleared, then the next element // inserted may not be at index 0 and calling this method afterward // will panic #[cfg(test)] fn entry0_ptr(&self) -> *const T { let entries = self.entries.borrow(); entries.get(0).unwrap() as *const T } } pub struct SyncEntryGuard<'a, T> { entries: MutexGuard<'a, Slab<T>>, entry: &'a mut T, key: usize, } impl<T> SyncEntryGuard<'_, T> { fn remove(mut self) { self.entries.remove(self.key); } } impl<T> Deref for SyncEntryGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { self.entry } } impl<T> DerefMut for SyncEntryGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entry } } // this is essentially a thread-safe slab. operations are protected by a // mutex.
when an element is retrieved for reading or modification, it is // wrapped in a SyncEntryGuard which keeps the entire slab locked until the // caller is done working with the element pub struct SyncMemory<T> { entries: Mutex<Slab<T>>, } impl<T> SyncMemory<T> { pub fn new(capacity: usize) -> Self { // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: Mutex::new(s), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.len() } fn insert(&self, e: T) -> Result<usize, ()> { let mut entries = self.entries.lock().unwrap(); // out of capacity. by preventing inserts beyond the capacity, we // ensure the underlying memory won't get moved due to a realloc if entries.len() == entries.capacity() { return Err(()); } Ok(entries.insert(e)) } fn get<'a>(&'a self, key: usize) -> Option<SyncEntryGuard<'a, T>> { let mut entries = self.entries.lock().unwrap(); let entry = entries.get_mut(key)?; // slab element addresses are guaranteed to be stable once created, // and the only place we remove the element is in SyncEntryGuard's // remove method which consumes itself, therefore it is safe to // assume the element will live at least as long as the SyncEntryGuard // and we can extend the lifetime of the reference beyond the // MutexGuard let entry = unsafe { mem::transmute::<&mut T, &'a mut T>(entry) }; Some(SyncEntryGuard { entries, entry, key, }) } // for tests, as a way to confirm the memory isn't moving. be careful // with this.
the very first element inserted will be at index 0, but // if the slab has been used and cleared, then the next element // inserted may not be at index 0 and calling this method afterward // will panic #[cfg(test)] fn entry0_ptr(&self) -> *const T { let entries = self.entries.lock().unwrap(); entries.get(0).unwrap() as *const T } } pub struct ReusableValue<T> { reusable: std::sync::Arc<Reusable<T>>, value: *mut T, key: usize, } impl<T> ReusableValue<T> { // vec element addresses are guaranteed to be stable once created, // and elements are only removed when the Reusable is dropped, and // the Arc'd Reusable is guaranteed to live as long as // ReusableValue, therefore it is safe to assume the element will // live at least as long as the ReusableValue fn get(&self) -> &T { unsafe { &*self.value } } fn get_mut(&mut self) -> &mut T { unsafe { &mut *self.value } } } impl<T> Drop for ReusableValue<T> { fn drop(&mut self) { let mut entries = self.reusable.entries.lock().unwrap(); entries.0.remove(self.key); } } impl<T> Deref for ReusableValue<T> { type Target = T; fn deref(&self) -> &Self::Target { self.get() } } impl<T> DerefMut for ReusableValue<T> { fn deref_mut(&mut self) -> &mut Self::Target { self.get_mut() } } // like Memory, but for preinitializing each value and reusing pub struct Reusable<T> { entries: Mutex<(Slab<()>, Vec<T>)>, } impl<T> Reusable<T> { pub fn new<F>(capacity: usize, init_fn: F) -> Self where F: Fn() -> T, { let mut values = Vec::with_capacity(capacity); for _ in 0..capacity { values.push(init_fn()); } // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: Mutex::new((s, values)), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.0.len() } #[allow(clippy::result_unit_err)] pub fn reserve(self: &std::sync::Arc<Self>) -> Result<ReusableValue<T>, ()> { let mut entries = self.entries.lock().unwrap(); // out of capacity.
the number of buffers is fixed if entries.0.len() == entries.0.capacity() { return Err(()); } let key = entries.0.insert(()); let value = &mut entries.1[key] as *mut T; Ok(ReusableValue { reusable: self.clone(), value, key, }) } } pub struct RcEntry<T> { value: T, refs: usize, } pub type RcMemory<T> = Memory<RcEntry<T>>; pub struct Rc<T> { memory: std::rc::Rc<RcMemory<T>>, key: usize, } impl<T> Rc<T> { #[allow(clippy::result_unit_err)] pub fn new(v: T, memory: &std::rc::Rc<RcMemory<T>>) -> Result<Self, ()> { let key = memory.insert(RcEntry { value: v, refs: 1 })?; Ok(Self { memory: std::rc::Rc::clone(memory), key, }) } #[allow(clippy::should_implement_trait)] pub fn clone(rc: &Rc<T>) -> Self { let mut e = rc.memory.get(rc.key).unwrap(); e.refs += 1; Self { memory: rc.memory.clone(), key: rc.key, } } pub fn get<'a>(&'a self) -> &'a T { let e = self.memory.get(self.key).unwrap(); // get a reference to the inner value let value = &e.value; // entry addresses are guaranteed to be stable once created, and the // entry managed by this Rc won't be dropped until this Rc drops, // therefore it is safe to assume the entry managed by this Rc will // live at least as long as this Rc, and we can extend the lifetime // of the reference beyond the EntryGuard unsafe { mem::transmute::<&T, &'a T>(value) } } } impl<T> Drop for Rc<T> { fn drop(&mut self) { let mut e = self.memory.get(self.key).unwrap(); if e.refs == 1 { e.remove(); return; } e.refs -= 1; } } pub type ArcMemory<T> = SyncMemory<RcEntry<T>>; pub struct Arc<T> { memory: std::sync::Arc<ArcMemory<T>>, key: usize, } impl<T> Arc<T> { #[allow(clippy::result_unit_err)] pub fn new(v: T, memory: &std::sync::Arc<ArcMemory<T>>) -> Result<Self, ()> { let key = memory.insert(RcEntry { value: v, refs: 1 })?; Ok(Self { memory: memory.clone(), key, }) } #[allow(clippy::should_implement_trait)] pub fn clone(rc: &Arc<T>) -> Self { let mut e = rc.memory.get(rc.key).unwrap(); e.refs += 1; Self { memory: rc.memory.clone(), key: 
rc.key, } } pub fn get<'a>(&'a self) -> &'a T { let e = self.memory.get(self.key).unwrap(); // get a reference to the inner value let value =
&e.value; // entry addresses are guaranteed to be stable once created, and the // entry managed by this Arc won't be dropped until this Arc drops, // therefore it is safe to assume the entry managed by this Arc will // live at least as long as this Arc, and we can extend the lifetime // of the reference beyond the SyncEntryGuard unsafe { mem::transmute::<&T, &'a T>(value) } } } impl<T> Drop for Arc<T> { fn drop(&mut self) { let mut e = self.memory.get(self.key).unwrap(); if e.refs == 1 { e.remove(); return; } e.refs -= 1; } } // adapted from https://github.com/rust-lang/rfcs/pull/2802 pub fn recycle_vec<T, U>(mut v: Vec<T>) -> Vec<U> { assert_eq!(core::mem::size_of::<T>(), core::mem::size_of::<U>()); assert_eq!(core::mem::align_of::<T>(), core::mem::align_of::<U>()); v.clear(); let ptr = v.as_mut_ptr(); let capacity = v.capacity(); mem::forget(v); let ptr = ptr as *mut U; unsafe { Vec::from_raw_parts(ptr, 0, capacity) } } // ReusableVec inspired by recycle_vec pub struct ReusableVecHandle<'a, T> { vec: &'a mut Vec<T>, } impl<T> ReusableVecHandle<'_, T> { pub fn get_ref(&self) -> &Vec<T> { self.vec } pub fn get_mut(&mut self) -> &mut Vec<T> { self.vec } } impl<T> Drop for ReusableVecHandle<'_, T> { fn drop(&mut self) { self.vec.clear(); } } impl<T> Deref for ReusableVecHandle<'_, T> { type Target = Vec<T>; fn deref(&self) -> &Self::Target { self.get_ref() } } impl<T> DerefMut for ReusableVecHandle<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.get_mut() } } pub struct ReusableVec { vec: Vec<()>, size: usize, align: usize, } impl ReusableVec { pub fn new<T>(capacity: usize) -> Self { let size = mem::size_of::<T>(); let align = mem::align_of::<T>(); let vec: Vec<T> = Vec::with_capacity(capacity); // safety: we must cast to Vec<U> before using, where U has the same // size and alignment as T let vec: Vec<()> = unsafe { mem::transmute(vec) }; Self { vec, size, align } } pub fn get_as_new<U>(&mut self) -> ReusableVecHandle<'_, U> { let size = mem::size_of::<U>(); let align = 
mem::align_of::<U>(); // if these don't match, panic.
it's up to the user to ensure the type // is acceptable assert_eq!(self.size, size); assert_eq!(self.align, align); let vec: &mut Vec<()> = &mut self.vec; // safety: U has the expected size and alignment let vec: &mut Vec<U> = unsafe { mem::transmute(vec) }; // the vec starts empty, and is always cleared when the handle drops. // get_as_new() borrows self mutably, so it's not possible to create // a handle when one already exists assert!(vec.is_empty()); ReusableVecHandle { vec } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_reusable() { let reusable = std::sync::Arc::new(Reusable::new(2, || vec![0; 128])); assert_eq!(reusable.len(), 0); let mut buf1 = reusable.reserve().unwrap(); assert_eq!(reusable.len(), 1); let mut buf2 = reusable.reserve().unwrap(); assert_eq!(reusable.len(), 2); // no room assert!(reusable.reserve().is_err()); buf1[..5].copy_from_slice(b"hello"); buf2[..5].copy_from_slice(b"world"); assert_eq!(&buf1[..5], b"hello"); assert_eq!(&buf2[..5], b"world"); mem::drop(buf1); assert_eq!(reusable.len(), 1); mem::drop(buf2); assert_eq!(reusable.len(), 0); } #[test] fn test_rc() { let memory = std::rc::Rc::new(RcMemory::new(2)); assert_eq!(memory.len(), 0); let e0a = Rc::new(123 as i32, &memory).unwrap(); assert_eq!(memory.len(), 1); let p = memory.entry0_ptr(); let e0b = Rc::clone(&e0a); assert_eq!(memory.len(), 1); assert_eq!(memory.entry0_ptr(), p); let e1a = Rc::new(456 as i32, &memory).unwrap(); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); // no room assert!(Rc::new(789 as i32, &memory).is_err()); assert_eq!(*e0a.get(), 123); assert_eq!(*e0b.get(), 123); assert_eq!(*e1a.get(), 456); mem::drop(e0b); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); mem::drop(e0a); assert_eq!(memory.len(), 1); mem::drop(e1a); assert_eq!(memory.len(), 0); } #[test] fn test_arc() { let memory = std::sync::Arc::new(ArcMemory::new(2)); assert_eq!(memory.len(), 0); let e0a = Arc::new(123 as i32, &memory).unwrap();
assert_eq!(memory.len(), 1); let p = memory.entry0_ptr(); let e0b = Arc::clone(&e0a); assert_eq!(memory.len(), 1); assert_eq!(memory.entry0_ptr(), p); let e1a = Arc::new(456 as i32, &memory).unwrap(); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); // no room assert!(Arc::new(789 as i32, &memory).is_err()); assert_eq!(*e0a.get(), 123); assert_eq!(*e0b.get(), 123); assert_eq!(*e1a.get(), 456); mem::drop(e0b); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); mem::drop(e0a); assert_eq!(memory.len(), 1); mem::drop(e1a); assert_eq!(memory.len(), 0); } #[test] fn test_reusable_vec() { let mut vec_mem = ReusableVec::new::(100); let mut vec = vec_mem.get_as_new::(); assert_eq!(vec.capacity(), 100); assert_eq!(vec.len(), 0); vec.push(1); assert_eq!(vec.len(), 1); mem::drop(vec); let vec = vec_mem.get_as_new::(); assert_eq!(vec.capacity(), 100); assert_eq!(vec.len(), 0); } } condure-1.10.0/src/buffer.rs000064400000000000000000000752561046102023000137670ustar 00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::future::{AsyncWrite, AsyncWriteExt}; use std::cell::RefCell; use std::cmp; use std::io; use std::io::{Read, Write}; use std::mem::{self, MaybeUninit}; use std::rc::Rc; use std::slice; pub const VECTORED_MAX: usize = 8; pub fn trim_for_display(s: &str, max: usize) -> String { // NOTE: O(n) let char_len = s.chars().count(); if char_len > max && max >= 7 { let dist = max / 2; let mut left_end = 0; let mut right_start = 0; // NOTE: O(n) for (i, (pos, _)) in s.char_indices().enumerate() { // dist guaranteed to be < char_len if i == dist { left_end = pos; } // (char_len - dist + 3) guaranteed to be < char_len if i == char_len - dist + 3 { right_start = pos; } } let left = &s[..left_end]; let right = &s[right_start..]; format!("{}...{}", left, right) } else { s.to_owned() } } #[allow(clippy::len_without_is_empty)] pub trait RefRead { fn len(&self) -> usize; fn get_ref(&self) -> &[u8]; fn get_mut(&mut self) -> &mut [u8]; fn consume(&mut self, amt: usize); fn get_ref_vectored<'data, 'bufs>( &'data self, bufs: &'bufs mut [&'data [u8]], ) -> &'bufs mut [&'data [u8]] { assert!(!bufs.is_empty()); bufs[0] = self.get_ref(); &mut bufs[..1] } fn get_mut_vectored<'data, 'bufs, const N: usize>( &'data mut self, bufs: &'bufs mut MaybeUninit<[&'data mut [u8]; N]>, ) -> &'bufs mut [&'data mut [u8]] { let bufs = unsafe { bufs.assume_init_mut() }; bufs[0] = self.get_mut(); &mut bufs[..1] } } impl RefRead for io::Cursor<&mut [u8]> { fn len(&self) -> usize { RefRead::get_ref(self).len() } fn get_ref(&self) -> &[u8] { let pos = self.position() as usize; &self.get_ref()[pos..] } fn get_mut(&mut self) -> &mut [u8] { let pos = self.position() as usize; &mut self.get_mut()[pos..] 
} fn consume(&mut self, amt: usize) { let pos = self.position(); self.set_position(pos + (amt as u64)); } } pub fn write_vectored_offset( writer: &mut W, bufs: &[&[u8]], offset: usize, ) -> Result { if bufs.is_empty() { return Ok(0); } let mut offset = offset; let mut start = 0; while offset >= bufs[start].len() { // on the last buf? if start + 1 >= bufs.len() { // exceeding the last buf is an error if offset > bufs[start].len() { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } return Ok(0); } offset -= bufs[start].len(); start += 1; } let mut arr = [io::IoSlice::new(&b""[..]); VECTORED_MAX]; let mut arr_len = 0; for (index, &buf) in bufs.iter().enumerate().skip(start) { let buf = if index == start { &buf[offset..] } else { buf }; arr[arr_len] = io::IoSlice::new(buf); arr_len += 1; } writer.write_vectored(&arr[..arr_len]) } pub async fn write_vectored_offset_async( writer: &mut W, bufs: &[&[u8]], offset: usize, ) -> Result { if bufs.is_empty() { return Ok(0); } let mut offset = offset; let mut start = 0; while offset >= bufs[start].len() { // on the last buf? if start + 1 >= bufs.len() { // exceeding the last buf is an error if offset > bufs[start].len() { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } return Ok(0); } offset -= bufs[start].len(); start += 1; } let mut arr = [io::IoSlice::new(&b""[..]); VECTORED_MAX]; let mut arr_len = 0; for (index, &buf) in bufs.iter().enumerate().skip(start) { let buf = if index == start { &buf[offset..] 
        } else {
            buf
        };

        arr[arr_len] = io::IoSlice::new(buf);
        arr_len += 1;
    }

    writer.write_vectored(&arr[..arr_len]).await
}

struct LimitBufsRestore<T> {
    index: usize,
    ptr: T,
    len: usize,
}

pub struct LimitBufsGuard<'a, 'b> {
    bufs: &'b mut [&'a [u8]],
    start: usize,
    end: usize,
    restore: Option<LimitBufsRestore<*const u8>>,
}

impl<'a: 'b, 'b> LimitBufsGuard<'a, 'b> {
    pub fn as_slice(&self) -> &[&'a [u8]] {
        &self.bufs[self.start..self.end]
    }
}

impl<'a: 'b, 'b> Drop for LimitBufsGuard<'a, 'b> {
    fn drop(&mut self) {
        if let Some(restore) = self.restore.take() {
            // SAFETY: ptr and len were collected earlier from the original
            // memory referred to by the slice at this index and they are
            // still valid. the only issue with reconstructing the slice is
            // that we currently have a different slice using the same memory
            // at this index. however, this is safe because we also replace
            // the slice at this index and the two slices don't coexist
            unsafe {
                self.bufs[restore.index] = slice::from_raw_parts(restore.ptr, restore.len);
            }
        }
    }
}

pub struct LimitBufsMutGuard<'a, 'b> {
    bufs: &'b mut [&'a mut [u8]],
    start: usize,
    end: usize,
    restore: Option<LimitBufsRestore<*mut u8>>,
}

impl<'a: 'b, 'b> LimitBufsMutGuard<'a, 'b> {
    pub fn as_slice(&mut self) -> &mut [&'a mut [u8]] {
        &mut self.bufs[self.start..self.end]
    }
}

impl<'a: 'b, 'b> Drop for LimitBufsMutGuard<'a, 'b> {
    fn drop(&mut self) {
        if let Some(restore) = self.restore.take() {
            // SAFETY: ptr and len were collected earlier from the original
            // memory referred to by the slice at this index and they are
            // still valid. the only issue with reconstructing the slice is
            // that we currently have a different slice using the same memory
            // at this index.
            // however, this is safe because we also replace
            // the slice at this index and the two slices don't coexist
            unsafe {
                self.bufs[restore.index] = slice::from_raw_parts_mut(restore.ptr, restore.len);
            }
        }
    }
}

pub trait LimitBufs<'a, 'b> {
    fn limit(&'b mut self, size: usize) -> LimitBufsGuard<'a, 'b>;
}

impl<'a: 'b, 'b> LimitBufs<'a, 'b> for [&'a [u8]] {
    fn limit(&'b mut self, size: usize) -> LimitBufsGuard<'a, 'b> {
        let mut end = self.len();
        let mut restore = None;
        let mut want = size;

        for (index, item) in self.iter_mut().enumerate() {
            let buf: &[u8] = item;
            let buf_len = buf.len();

            if buf_len >= want {
                let len = buf.len();
                let ptr = buf.as_ptr();

                restore = Some(LimitBufsRestore { index, ptr, len });

                // SAFETY: ptr and len were obtained above and are still
                // valid. we just need to be careful about using them again
                // later on from the restore field
                unsafe {
                    *item = &slice::from_raw_parts(ptr, len)[..want];
                }

                end = index + 1;
                break;
            }

            want -= buf_len;
        }

        LimitBufsGuard {
            bufs: self,
            start: 0,
            end,
            restore,
        }
    }
}

pub trait LimitBufsMut<'a: 'b, 'b> {
    fn skip(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b>;

    fn limit(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b>;
}

impl<'a: 'b, 'b> LimitBufsMut<'a, 'b> for [&'a mut [u8]] {
    fn skip(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b> {
        let mut start = 0;
        let end = self.len();
        let mut restore = None;
        let mut skip = size;

        for (index, item) in self.iter_mut().enumerate() {
            let buf: &mut [u8] = item;
            let buf_len = buf.len();

            if buf_len >= skip {
                let len = buf.len();
                let ptr = buf.as_mut_ptr();

                restore = Some(LimitBufsRestore { index, ptr, len });

                // SAFETY: ptr and len were obtained above and are still
                // valid.
                // we just need to be careful about using them again
                // later on from the restore field
                unsafe {
                    *item = &mut slice::from_raw_parts_mut(ptr, len)[skip..];
                }

                start = index;
                break;
            }

            skip -= buf_len;
        }

        LimitBufsMutGuard {
            bufs: self,
            start,
            end,
            restore,
        }
    }

    fn limit(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b> {
        let mut end = self.len();
        let mut restore = None;
        let mut want = size;

        for (index, item) in self.iter_mut().enumerate() {
            let buf: &mut [u8] = item;
            let buf_len = buf.len();

            if buf_len >= want {
                let len = buf.len();
                let ptr = buf.as_mut_ptr();

                restore = Some(LimitBufsRestore { index, ptr, len });

                // SAFETY: ptr and len were obtained above and are still
                // valid. we just need to be careful about using them again
                // later on from the restore field
                unsafe {
                    *item = &mut slice::from_raw_parts_mut(ptr, len)[..want];
                }

                end = index + 1;
                break;
            }

            want -= buf_len;
        }

        LimitBufsMutGuard {
            bufs: self,
            start: 0,
            end,
            restore,
        }
    }
}

pub struct Buffer {
    buf: Vec<u8>,
    start: usize,
    end: usize,
}

impl Buffer {
    pub fn new(size: usize) -> Buffer {
        let buf = vec![0; size];

        Buffer {
            buf,
            start: 0,
            end: 0,
        }
    }

    pub fn clear(&mut self) {
        self.start = 0;
        self.end = 0;
    }

    pub fn read_avail(&self) -> usize {
        self.end - self.start
    }

    pub fn read_buf(&self) -> &[u8] {
        &self.buf[self.start..self.end]
    }

    pub fn read_commit(&mut self, amount: usize) {
        assert!(self.start + amount <= self.end);

        self.start += amount;
    }

    pub fn write_avail(&self) -> usize {
        self.buf.len() - self.end
    }

    pub fn write_buf(&mut self) -> &mut [u8] {
        let len = self.buf.len();

        &mut self.buf[self.end..len]
    }

    pub fn write_commit(&mut self, amount: usize) {
        assert!(self.end + amount <= self.buf.len());

        self.end += amount;
    }
}

#[cfg(test)]
impl Read for Buffer {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        // fully qualified to work around future method warning
        // https://github.com/rust-lang/rust/issues/48919
        let src = Buffer::read_buf(self);

        let size = cmp::min(src.len(), buf.len());
        buf[..size].copy_from_slice(&src[..size]);

        self.read_commit(size);

        Ok(size)
    }
}

impl Write for Buffer {
    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
        if !buf.is_empty() && self.write_avail() == 0 {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        let dest = self.write_buf();

        let size = cmp::min(dest.len(), buf.len());

        dest[..size].copy_from_slice(&buf[..size]);

        self.write_commit(size);

        Ok(size)
    }

    fn flush(&mut self) -> Result<(), io::Error> {
        Ok(())
    }
}

pub struct TmpBuffer(RefCell<Vec<u8>>);

#[allow(clippy::len_without_is_empty)]
impl TmpBuffer {
    pub fn new(size: usize) -> Self {
        Self(RefCell::new(vec![0; size]))
    }

    pub fn len(&self) -> usize {
        self.0.borrow().len()
    }
}

// holds a Vec<u8> but only exposes the portion of it considered to be
// readable ("filled"). any remaining bytes may be zeroed or uninitialized
// and are not considered to be readable
pub struct FilledBuf {
    data: Vec<u8>,
    filled: usize,
}

impl FilledBuf {
    // panics if filled is larger than data.len()
    pub fn new(data: Vec<u8>, filled: usize) -> Self {
        assert!(filled <= data.len());

        Self { data, filled }
    }

    pub fn filled(&self) -> &[u8] {
        &self.data[..self.filled]
    }

    pub fn filled_len(&self) -> usize {
        self.filled
    }

    pub fn into_inner(self) -> Vec<u8> {
        self.data
    }
}

pub struct BaseRingBuffer<T> {
    buf: T,
    start: usize,
    end: usize,
    tmp: Rc<TmpBuffer>,
}

impl<T: AsRef<[u8]> + AsMut<[u8]>> BaseRingBuffer<T> {
    pub fn capacity(&self) -> usize {
        self.buf.as_ref().len()
    }

    pub fn clear(&mut self) {
        self.start = 0;
        self.end = 0;
    }

    pub fn write_from<R: Read>(&mut self, r: &mut R) -> Result<usize, io::Error> {
        let size = match r.read(self.write_buf()) {
            Ok(size) => size,
            Err(e) => return Err(e),
        };

        self.write_commit(size);

        Ok(size)
    }

    pub fn read_avail(&self) -> usize {
        self.end - self.start
    }

    pub fn read_buf(&self) -> &[u8] {
        let buf = self.buf.as_ref();

        let end = cmp::min(self.end, buf.len());

        &buf[self.start..end]
    }

    pub fn read_buf_mut(&mut self) -> &mut [u8] {
        let buf = self.buf.as_mut();

        let end = cmp::min(self.end, buf.len());

        &mut buf[self.start..end]
    }

    pub fn read_commit(&mut
        self, amount: usize) {
        assert!(self.start + amount <= self.end);

        let buf = self.buf.as_ref();

        self.start += amount;

        if self.start == self.end {
            self.start = 0;
            self.end = 0;
        } else if self.start >= buf.len() {
            self.start -= buf.len();
            self.end -= buf.len();
        }
    }

    pub fn write_avail(&self) -> usize {
        self.buf.as_ref().len() - (self.end - self.start)
    }

    pub fn write_buf(&mut self) -> &mut [u8] {
        let buf = self.buf.as_mut();

        let (start, end) = if self.end < buf.len() {
            (self.end, buf.len())
        } else {
            (self.end - buf.len(), self.start)
        };

        &mut buf[start..end]
    }

    pub fn write_commit(&mut self, amount: usize) {
        assert!((self.end - self.start) + amount <= self.buf.as_ref().len());

        self.end += amount;
    }

    // return true if the readable bytes have not wrapped
    pub fn is_readable_contiguous(&self) -> bool {
        self.end <= self.buf.as_ref().len()
    }

    pub fn align(&mut self) -> usize {
        if self.start == 0 {
            return 0;
        }

        let buf = self.buf.as_mut();

        let size = self.end - self.start;

        if self.end <= buf.len() {
            // if the buffer hasn't wrapped, simply copy down
            buf.copy_within(self.start.., 0);
        } else if size <= self.start {
            // if the buffer has wrapped, but the wrapped part can be copied
            // without overlapping, then copy the wrapped part followed by
            // the initial part
            let left_size = self.end - buf.len();
            let right_size = buf.len() - self.start;

            buf.copy_within(..left_size, right_size);
            buf.copy_within(self.start..(self.start + right_size), 0);
        } else {
            // if the buffer has wrapped and the wrapped part can't be copied
            // without overlapping, then use a temporary buffer to
            // facilitate. the smaller part is copied to the temp buffer, then
            // the larger and smaller parts (in that order) are copied into
            // their intended locations.
            // in the worst case, up to 50% of
            // the buffer may be copied twice
            let left_size = self.end - buf.len();
            let right_size = buf.len() - self.start;

            let (lsize, lsrc, ldest, hsize, hsrc, hdest);

            if left_size < right_size {
                lsize = left_size;
                hsize = right_size;
                lsrc = 0;
                ldest = hsize;
                hsrc = self.start;
                hdest = 0;
            } else {
                lsize = right_size;
                hsize = left_size;
                lsrc = self.start;
                ldest = 0;
                hsrc = 0;
                hdest = lsize;
            }

            let mut tmp = self.tmp.0.borrow_mut();

            tmp[..lsize].copy_from_slice(&buf[lsrc..(lsrc + lsize)]);
            buf.copy_within(hsrc..(hsrc + hsize), hdest);
            buf[ldest..(ldest + lsize)].copy_from_slice(&tmp[..lsize]);
        }

        self.start = 0;
        self.end = size;

        size
    }

    pub fn get_tmp(&self) -> &Rc<TmpBuffer> {
        &self.tmp
    }
}

#[cfg(test)]
impl<T: AsRef<[u8]> + AsMut<[u8]>> Read for BaseRingBuffer<T> {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        let mut pos = 0;

        while pos < buf.len() && self.read_avail() > 0 {
            // fully qualified to work around future method warning
            // https://github.com/rust-lang/rust/issues/48919
            let src = Self::read_buf(self);

            let size = cmp::min(src.len(), buf.len() - pos);

            buf[pos..(pos + size)].copy_from_slice(&src[..size]);

            self.read_commit(size);

            pos += size;
        }

        Ok(pos)
    }
}

impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for BaseRingBuffer<T> {
    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
        if !buf.is_empty() && self.write_avail() == 0 {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        let mut pos = 0;

        while pos < buf.len() && self.write_avail() > 0 {
            let dest = self.write_buf();

            let size = cmp::min(dest.len(), buf.len() - pos);

            dest[..size].copy_from_slice(&buf[pos..(pos + size)]);

            self.write_commit(size);

            pos += size;
        }

        Ok(pos)
    }

    fn flush(&mut self) -> Result<(), io::Error> {
        Ok(())
    }
}

impl<T: AsRef<[u8]> + AsMut<[u8]>> RefRead for BaseRingBuffer<T> {
    fn len(&self) -> usize {
        self.read_avail()
    }

    fn get_ref(&self) -> &[u8] {
        self.read_buf()
    }

    fn get_mut(&mut self) -> &mut [u8] {
        self.read_buf_mut()
    }

    fn consume(&mut self, amt: usize) {
        self.read_commit(amt);
    }

    fn get_ref_vectored<'data, 'bufs>(
        &'data self,
        bufs: &'bufs
            mut [&'data [u8]],
    ) -> &'bufs mut [&'data [u8]] {
        assert!(!bufs.is_empty());

        let buf = self.buf.as_ref();
        let buf_len = buf.len();

        if self.end > buf_len && bufs.len() >= 2 {
            let (part1, part2) = buf.split_at(self.start);

            bufs[0] = part2;
            bufs[1] = &part1[..(self.end - buf_len)];

            &mut bufs[..2]
        } else {
            // if the data has wrapped but only one slot was provided, return
            // just the contiguous first part, as read_buf() does, instead of
            // slicing past the end of the buffer
            bufs[0] = &buf[self.start..cmp::min(self.end, buf_len)];

            &mut bufs[..1]
        }
    }

    fn get_mut_vectored<'data, 'bufs, const N: usize>(
        &'data mut self,
        bufs: &'bufs mut MaybeUninit<[&'data mut [u8]; N]>,
    ) -> &'bufs mut [&'data mut [u8]] {
        let bufs = unsafe { bufs.assume_init_mut() };

        let buf = self.buf.as_mut();
        let buf_len = buf.len();

        if self.end > buf_len && bufs.len() >= 2 {
            let (part1, part2) = buf.split_at_mut(self.start);

            bufs[0] = part2;
            bufs[1] = &mut part1[..(self.end - buf_len)];

            &mut bufs[..2]
        } else {
            // if the data has wrapped but only one slot was provided, return
            // just the contiguous first part, as read_buf_mut() does
            bufs[0] = &mut buf[self.start..cmp::min(self.end, buf_len)];

            &mut bufs[..1]
        }
    }
}

impl BaseRingBuffer<Vec<u8>> {
    pub fn new(size: usize, tmp: &Rc<TmpBuffer>) -> Self {
        assert!(size <= tmp.len());

        let buf = vec![0; size];

        BaseRingBuffer {
            buf,
            start: 0,
            end: 0,
            tmp: Rc::clone(tmp),
        }
    }

    // extract inner buffer, aligning it first if necessary, and replace it
    // with an empty buffer. this should be cheap if the inner buffer is
    // already aligned. afterwards, the ringbuffer will have a capacity of
    // zero and will be essentially unusable until set_inner is called with a
    // non-empty buffer
    pub fn take_inner(&mut self) -> FilledBuf {
        self.align();

        let data = mem::take(&mut self.buf);
        let filled = self.end;
        self.end = 0;

        FilledBuf::new(data, filled)
    }

    // replace the inner buffer. this should be cheap if the original inner
    // buffer is empty, which is the case if take_inner was called earlier.
    // panics if the new buffer is larger than the tmp buffer
    pub fn set_inner(&mut self, buf: FilledBuf) {
        let filled = buf.filled_len();

        let data = buf.into_inner();

        assert!(data.len() <= self.tmp.len());

        self.buf = data;
        self.start = 0;
        self.end = filled;
    }

    pub fn swap_inner(&mut self, other: &mut Self) {
        let buf = self.take_inner();
        self.set_inner(other.take_inner());
        other.set_inner(buf);
    }
}

impl<'a> BaseRingBuffer<&'a mut [u8]> {
    pub fn new(buf: &'a mut [u8], tmp: &Rc<TmpBuffer>) -> Self {
        assert!(buf.len() <= tmp.len());

        BaseRingBuffer {
            buf,
            start: 0,
            end: 0,
            tmp: Rc::clone(tmp),
        }
    }
}

pub type RingBuffer = BaseRingBuffer<Vec<u8>>;
pub type SliceRingBuffer<'a> = BaseRingBuffer<&'a mut [u8]>;

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};

    #[test]
    fn test_write_vectored_offset() {
        struct MyWriter {
            bufs: Vec<String>,
        }

        impl MyWriter {
            fn new() -> Self {
                Self { bufs: Vec::new() }
            }
        }

        impl Write for MyWriter {
            fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
                self.bufs.push(String::from_utf8(buf.to_vec()).unwrap());

                Ok(buf.len())
            }

            fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result<usize, io::Error> {
                let mut total = 0;

                for buf in bufs {
                    total += buf.len();
                    self.bufs.push(String::from_utf8(buf.to_vec()).unwrap());
                }

                Ok(total)
            }

            fn flush(&mut self) -> Result<(), io::Error> {
                Ok(())
            }
        }

        // empty
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[], 0);
        assert_eq!(r.unwrap(), 0);

        // offset too large
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 6);
        assert!(r.is_err());

        // offset too large
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 12);
        assert!(r.is_err());

        // nothing to write
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 5);
        assert_eq!(r.unwrap(), 0);

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 0);
        assert_eq!(r.unwrap(), 5);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "apple");

        let mut w = MyWriter::new();
        let r =
            write_vectored_offset(&mut w, &[b"apple"], 3);
        assert_eq!(r.unwrap(), 2);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "le");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 3);
        assert_eq!(r.unwrap(), 8);
        assert_eq!(w.bufs.len(), 2);
        assert_eq!(w.bufs[0], "le");
        assert_eq!(w.bufs[1], "banana");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 5);
        assert_eq!(r.unwrap(), 6);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "banana");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 6);
        assert_eq!(r.unwrap(), 5);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "anana");
    }

    #[test]
    fn test_buffer() {
        let mut b = Buffer::new(8);

        assert_eq!(b.read_avail(), 0);
        assert_eq!(b.write_avail(), 8);

        let size = b.write(b"hello").unwrap();
        assert_eq!(size, 5);
        assert_eq!(b.read_avail(), 5);
        assert_eq!(b.write_avail(), 3);

        let size = b.write(b"world").unwrap();
        assert_eq!(size, 3);
        assert_eq!(b.read_avail(), 8);
        assert_eq!(b.write_avail(), 0);

        let mut tmp = [0; 16];
        let size = b.read(&mut tmp).unwrap();
        assert_eq!(&tmp[..size], b"hellowor");

        b.clear();
        assert_eq!(b.read_avail(), 0);
        assert_eq!(b.write_avail(), 8);
    }

    #[test]
    fn test_ringbuffer() {
        let mut buf = [0u8; 8];

        let tmp = Rc::new(TmpBuffer::new(8));

        let mut r = RingBuffer::new(8, &tmp);

        assert_eq!(r.read_avail(), 0);
        assert_eq!(r.write_avail(), 8);

        r.write(b"12345").unwrap();

        assert_eq!(r.read_avail(), 5);
        assert_eq!(r.write_avail(), 3);

        r.write(b"678").unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.write_avail(), 0);
        assert_eq!(r.read_buf(), b"12345678");
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], b"12345678");

        r.read(&mut buf[..5]).unwrap();

        assert_eq!(r.read_avail(), 3);
        assert_eq!(r.write_avail(), 5);
        assert_eq!(r.write_buf().len(), 5);

        r.write(b"9abcd").unwrap();
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.write_avail(), 0);

        r.read(&mut buf[5..]).unwrap();

        assert_eq!(r.read_avail(), 5);
        assert_eq!(r.write_avail(), 3);

        r.read(&mut buf[..5]).unwrap();

        assert_eq!(r.read_avail(), 0);
        assert_eq!(r.write_avail(), 8);
        assert_eq!(&buf, b"9abcd678");

        r.write(b"12345").unwrap();
        r.read(&mut buf[..2]).unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);
        assert_eq!(r.read_avail(), 3);
        assert_eq!(r.read_buf(), b"345");
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], b"345");
        assert_eq!(r.write_avail(), 5);
        assert_eq!(r.write_buf().len(), 3);

        r.align();
        assert_eq!(r.read_avail(), 3);
        assert_eq!(r.read_buf(), b"345");
        assert_eq!(r.write_avail(), 5);
        assert_eq!(r.write_buf().len(), 5);

        r.write(b"6789a").unwrap();
        r.read(&mut buf[..2]).unwrap();
        r.write(b"bc").unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"56789a");
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], b"56789a");
        assert_eq!(bufs[1], b"bc");
        assert_eq!(r.write_avail(), 0);

        r.align();
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"56789abc");
        assert_eq!(r.write_avail(), 0);

        r.read(&mut buf[..6]).unwrap();
        r.write(b"def123").unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"bc");
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], b"bc");
        assert_eq!(bufs[1], b"def123");
        assert_eq!(r.write_avail(), 0);

        r.align();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"bcdef123");
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], b"bcdef123");
        assert_eq!(r.write_avail(), 0);

        r.clear();
        r.write(b"12345678").unwrap();
        r.read(&mut buf[..6]).unwrap();
        r.write(b"9abc").unwrap();
        assert_eq!(r.read_avail(), 6);
        assert_eq!(r.read_buf().len(), 2);

        r.align();
        assert_eq!(r.read_avail(), 6);
        assert_eq!(r.read_buf().len(), 6);
    }

    #[test]
    fn test_slice_ringbuffer() {
        let mut buf = [0; 8];

        let mut backing_buf = [0; 8];
        let tmp = Rc::new(TmpBuffer::new(8));
        let mut r = SliceRingBuffer::new(&mut backing_buf, &tmp);

        r.write(b"12345678").unwrap();

        let size = r.read(&mut buf[..4]).unwrap();
        assert_eq!(&buf[..size], b"1234");

        r.write(b"90ab").unwrap();

        let size = r.read(&mut buf).unwrap();
        assert_eq!(&buf[..size], b"567890ab");
    }

    #[test]
    fn test_limitbufs() {
        let mut buf1 = [b'1', b'2', b'3', b'4'];
        let mut buf2 = [b'5', b'6', b'7', b'8'];
        let mut buf3 = [b'9', b'0', b'a', b'b'];

        let mut bufs = [buf1.as_slice(), buf2.as_slice(), buf3.as_slice()];

        {
            let limited = bufs.limit(7);
            let limited = limited.as_slice();
            assert_eq!(limited.len(), 2);
            assert_eq!(&limited[0], b"1234");
            assert_eq!(&limited[1], b"567");
        }

        assert_eq!(bufs.len(), 3);
        assert_eq!(&bufs[0], b"1234");
        assert_eq!(&bufs[1], b"5678");
        assert_eq!(&bufs[2], b"90ab");

        let mut bufs = [
            buf1.as_mut_slice(),
            buf2.as_mut_slice(),
            buf3.as_mut_slice(),
        ];

        {
            let mut limited = bufs.limit(7);
            let limited = limited.as_slice();
            assert_eq!(limited.len(), 2);
            assert_eq!(&limited[0], b"1234");
            assert_eq!(&limited[1], b"567");
        }

        {
            let mut limited = bufs.skip(7);
            let limited = limited.as_slice();
            assert_eq!(limited.len(), 2);
            assert_eq!(&limited[0], b"8");
            assert_eq!(&limited[1], b"90ab");
        }

        assert_eq!(bufs.len(), 3);
        assert_eq!(&bufs[0], b"1234");
        assert_eq!(&bufs[1], b"5678");
        assert_eq!(&bufs[2], b"90ab");
    }
}
condure-1.10.0/src/channel.rs000064400000000000000000000577571046102023000141340ustar 00000000000000
/*
 * Copyright (C) 2020-2023 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::event;
use crate::list;
use slab::Slab;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::mem;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::sync::Arc;

pub struct Sender<T> {
    sender: Option<mpsc::SyncSender<T>>,
    read_set_readiness: event::SetReadiness,
    write_registration: event::Registration,
    cts: Option<Arc<AtomicBool>>,
}

impl<T> Sender<T> {
    // NOTE: only makes sense for rendezvous channels
    pub fn can_send(&self) -> bool {
        match &self.cts {
            Some(cts) => cts.load(Ordering::Relaxed),
            None => true,
        }
    }

    pub fn get_write_registration(&self) -> &event::Registration {
        &self.write_registration
    }

    pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError<T>> {
        if let Some(cts) = &self.cts {
            if cts
                .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed)
                .is_err()
            {
                return Err(mpsc::TrySendError::Full(t));
            }

            // cts will only be true if a read was performed while the queue
            // was empty, and this function is the only place where the queue
            // is written to.
            // this means the try_send call below will only
            // fail if the receiver disconnected
        }

        match self.sender.as_ref().unwrap().try_send(t) {
            Ok(_) => {
                self.read_set_readiness
                    .set_readiness(mio::Interest::READABLE)
                    .unwrap();

                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    pub fn send(&self, t: T) -> Result<(), mpsc::SendError<T>> {
        if self.cts.is_some() {
            panic!("blocking send with rendezvous channel not supported")
        }

        match self.sender.as_ref().unwrap().send(t) {
            Ok(_) => {
                self.read_set_readiness
                    .set_readiness(mio::Interest::READABLE)
                    .unwrap();

                Ok(())
            }
            Err(e) => Err(e),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        mem::drop(self.sender.take().unwrap());

        self.read_set_readiness
            .set_readiness(mio::Interest::READABLE)
            .unwrap();
    }
}

pub struct Receiver<T> {
    receiver: mpsc::Receiver<T>,
    read_registration: event::Registration,
    write_set_readiness: event::SetReadiness,
    cts: Option<Arc<AtomicBool>>,
}

impl<T> Receiver<T> {
    pub fn get_read_registration(&self) -> &event::Registration {
        &self.read_registration
    }

    pub fn try_recv(&self) -> Result<T, mpsc::TryRecvError> {
        match self.receiver.try_recv() {
            Ok(t) => {
                if self.cts.is_none() {
                    self.write_set_readiness
                        .set_readiness(mio::Interest::WRITABLE)
                        .unwrap();
                }

                Ok(t)
            }
            Err(mpsc::TryRecvError::Empty) if self.cts.is_some() => {
                let cts = self.cts.as_ref().unwrap();

                if cts
                    .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
                {
                    self.write_set_readiness
                        .set_readiness(mio::Interest::WRITABLE)
                        .unwrap();
                }

                Err(mpsc::TryRecvError::Empty)
            }
            Err(e) => Err(e),
        }
    }

    pub fn recv(&self) -> Result<T, mpsc::RecvError> {
        let t = self.receiver.recv()?;

        if self.cts.is_none() {
            self.write_set_readiness
                .set_readiness(mio::Interest::WRITABLE)
                .unwrap();
        }

        Ok(t)
    }
}

pub fn channel<T>(bound: usize) -> (Sender<T>, Receiver<T>) {
    let (read_reg, read_sr) = event::Registration::new();
    let (write_reg, write_sr) = event::Registration::new();

    // rendezvous channel
    if bound == 0 {
        let (s, r) = mpsc::sync_channel::<T>(1);

        let cts = Arc::new(AtomicBool::new(false));

        let sender = Sender {
            sender: Some(s),
            read_set_readiness:
                read_sr,
            write_registration: write_reg,
            cts: Some(Arc::clone(&cts)),
        };

        let receiver = Receiver {
            receiver: r,
            read_registration: read_reg,
            write_set_readiness: write_sr,
            cts: Some(Arc::clone(&cts)),
        };

        (sender, receiver)
    } else {
        let (s, r) = mpsc::sync_channel::<T>(bound);

        let sender = Sender {
            sender: Some(s),
            read_set_readiness: read_sr,
            write_registration: write_reg,
            cts: None,
        };

        let receiver = Receiver {
            receiver: r,
            read_registration: read_reg,
            write_set_readiness: write_sr,
            cts: None,
        };

        // channel is immediately writable
        receiver
            .write_set_readiness
            .set_readiness(mio::Interest::WRITABLE)
            .unwrap();

        (sender, receiver)
    }
}

struct LocalSenderData {
    notified: bool,
    write_set_readiness: event::LocalSetReadiness,
}

struct LocalSenders {
    nodes: Slab<list::Node<LocalSenderData>>,
    waiting: list::List,
}

struct LocalChannel<T> {
    queue: RefCell<VecDeque<T>>,
    senders: RefCell<LocalSenders>,
    read_set_readiness: RefCell<Option<event::LocalSetReadiness>>,
}

impl<T> LocalChannel<T> {
    fn senders_is_empty(&self) -> bool {
        self.senders.borrow().nodes.is_empty()
    }

    fn add_sender(&self, write_sr: event::LocalSetReadiness) -> Result<usize, ()> {
        let mut senders = self.senders.borrow_mut();

        if senders.nodes.len() == senders.nodes.capacity() {
            return Err(());
        }

        let key = senders.nodes.insert(list::Node::new(LocalSenderData {
            notified: false,
            write_set_readiness: write_sr,
        }));

        Ok(key)
    }

    fn remove_sender(&self, key: usize) {
        let senders = &mut *self.senders.borrow_mut();

        senders.waiting.remove(&mut senders.nodes, key);
        senders.nodes.remove(key);

        if senders.nodes.is_empty() {
            if let Some(read_sr) = &*self.read_set_readiness.borrow() {
                // notify for disconnect
                read_sr.set_readiness(mio::Interest::READABLE).unwrap();
            }
        }
    }

    fn set_sender_waiting(&self, key: usize) {
        let senders = &mut *self.senders.borrow_mut();

        // add if not already present
        if senders.nodes[key].prev.is_none() && senders.waiting.head != Some(key) {
            senders.waiting.push_back(&mut senders.nodes, key);
        }
    }

    fn notify_one_sender(&self) {
        let senders = &mut *self.senders.borrow_mut();

        // notify next waiting sender, if any
        if let
            Some(key) = senders.waiting.pop_front(&mut senders.nodes)
        {
            let sender = &mut senders.nodes[key].value;

            sender.notified = true;

            sender
                .write_set_readiness
                .set_readiness(mio::Interest::WRITABLE)
                .unwrap();
        }
    }

    fn sender_is_notified(&self, key: usize) -> bool {
        self.senders.borrow().nodes[key].value.notified
    }

    fn clear_sender_notified(&self, key: usize) {
        self.senders.borrow_mut().nodes[key].value.notified = false;
    }
}

pub struct LocalSender<T> {
    channel: Rc<LocalChannel<T>>,
    key: usize,
    write_registration: event::LocalRegistration,
}

impl<T> LocalSender<T> {
    pub fn get_write_registration(&self) -> &event::LocalRegistration {
        &self.write_registration
    }

    // if this returns true, then the next call to try_send() by any sender
    // is guaranteed to not return TrySendError::Full.
    // if this returns false, the sender is added to the wait list
    pub fn check_send(&self) -> bool {
        let queue = self.channel.queue.borrow();

        let can_send = queue.len() < queue.capacity();

        if !can_send {
            self.channel.set_sender_waiting(self.key);
        }

        can_send
    }

    pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError<T>> {
        // we are acting, so clear the notified flag
        self.channel.clear_sender_notified(self.key);

        let read_sr = &*self.channel.read_set_readiness.borrow();

        let read_sr = match read_sr {
            Some(sr) => sr,
            None => {
                // receiver is disconnected
                return Err(mpsc::TrySendError::Disconnected(t));
            }
        };

        let mut queue = self.channel.queue.borrow_mut();

        if queue.len() < queue.capacity() {
            queue.push_back(t);

            read_sr.set_readiness(mio::Interest::READABLE).unwrap();

            Ok(())
        } else {
            self.channel.set_sender_waiting(self.key);

            Err(mpsc::TrySendError::Full(t))
        }
    }

    pub fn cancel(&self) {
        // if we were notified but never acted on it, notify the next waiting sender, if any
        if self.channel.sender_is_notified(self.key) {
            self.channel.clear_sender_notified(self.key);

            self.channel.notify_one_sender();
        }
    }

    // NOTE: if the receiver is dropped while there are multiple senders,
    // only one of the senders will be notified of the disconnect
    #[allow(clippy::result_unit_err)]
    pub fn try_clone(
        &self,
        memory: &Rc<arena::RcMemory<event::LocalRegistrationEntry>>,
    ) -> Result<Self, ()> {
        let (write_reg, write_sr) = event::LocalRegistration::new(memory);

        let key = self.channel.add_sender(write_sr)?;

        Ok(Self {
            channel: self.channel.clone(),
            key,
            write_registration: write_reg,
        })
    }

    // returns error if a receiver already exists
    #[allow(clippy::result_unit_err)]
    pub fn make_receiver(
        &self,
        memory: &Rc<arena::RcMemory<event::LocalRegistrationEntry>>,
    ) -> Result<LocalReceiver<T>, ()> {
        if self.channel.read_set_readiness.borrow().is_some() {
            return Err(());
        }

        let (read_reg, read_sr) = event::LocalRegistration::new(memory);

        *self.channel.read_set_readiness.borrow_mut() = Some(read_sr);

        Ok(LocalReceiver {
            channel: self.channel.clone(),
            read_registration: read_reg,
        })
    }
}

impl<T> Drop for LocalSender<T> {
    fn drop(&mut self) {
        self.cancel();

        self.channel.remove_sender(self.key);
    }
}

pub struct LocalReceiver<T> {
    channel: Rc<LocalChannel<T>>,
    read_registration: event::LocalRegistration,
}

impl<T> LocalReceiver<T> {
    pub fn get_read_registration(&self) -> &event::LocalRegistration {
        &self.read_registration
    }

    pub fn try_recv(&self) -> Result<T, mpsc::TryRecvError> {
        let mut queue = self.channel.queue.borrow_mut();

        if queue.is_empty() {
            if self.channel.senders_is_empty() {
                return Err(mpsc::TryRecvError::Disconnected);
            }

            return Err(mpsc::TryRecvError::Empty);
        }

        let value = queue.pop_front().unwrap();

        self.channel.notify_one_sender();

        Ok(value)
    }

    pub fn clear(&self) {
        // loop over try_recv() in order to notify senders
        while self.try_recv().is_ok() {}
    }
}

impl<T> Drop for LocalReceiver<T> {
    fn drop(&mut self) {
        *self.channel.read_set_readiness.borrow_mut() = None;

        self.channel.notify_one_sender();
    }
}

pub fn local_channel<T>(
    bound: usize,
    max_senders: usize,
    memory: &Rc<arena::RcMemory<event::LocalRegistrationEntry>>,
) -> (LocalSender<T>, LocalReceiver<T>) {
    let (read_reg, read_sr) = event::LocalRegistration::new(memory);
    let (write_reg, write_sr) = event::LocalRegistration::new(memory);

    // no support for rendezvous channels
    assert!(bound > 0);

    // need to support at least one sender
    assert!(max_senders > 0);

    let channel = Rc::new(LocalChannel {
        queue:
            RefCell::new(VecDeque::with_capacity(bound)),
        senders: RefCell::new(LocalSenders {
            nodes: Slab::with_capacity(max_senders),
            waiting: list::List::default(),
        }),
        read_set_readiness: RefCell::new(Some(read_sr)),
    });

    let key = channel.add_sender(write_sr).unwrap();

    let sender = LocalSender {
        channel: channel.clone(),
        key,
        write_registration: write_reg,
    };

    let receiver = LocalReceiver {
        channel,
        read_registration: read_reg,
    };

    (sender, receiver)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time;

    #[test]
    fn test_send_recv_bound0() {
        let (sender, receiver) = channel(0);

        assert_eq!(sender.can_send(), false);

        let result = sender.try_send(42);
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TrySendError::Full(42));

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        assert_eq!(sender.can_send(), true);

        let result = sender.try_send(42);
        assert_eq!(result.is_ok(), true);

        assert_eq!(sender.can_send(), false);

        let result = receiver.try_recv();
        assert_eq!(result.is_ok(), true);

        let v = result.unwrap();
        assert_eq!(v, 42);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        mem::drop(sender);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Disconnected);
    }

    #[test]
    fn test_send_recv_bound1() {
        let (sender, receiver) = channel(1);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        let result = sender.try_send(42);
        assert_eq!(result.is_ok(), true);

        let result = sender.try_send(42);
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TrySendError::Full(42));

        let result = receiver.try_recv();
        assert_eq!(result.is_ok(), true);

        let v = result.unwrap();
        assert_eq!(v, 42);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        mem::drop(sender);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Disconnected);
    }

    #[test]
    fn test_notify_bound0() {
        let (sender, receiver) = channel(0);

        let mut poller = event::Poller::new(2).unwrap();

        poller
            .register_custom(
                sender.get_write_registration(),
                mio::Token(1),
                mio::Interest::WRITABLE,
            )
            .unwrap();

        poller
            .register_custom(
                receiver.get_read_registration(),
                mio::Token(2),
                mio::Interest::READABLE,
            )
            .unwrap();

        assert_eq!(sender.can_send(), false);

        poller.poll(Some(time::Duration::from_millis(0))).unwrap();

        assert_eq!(poller.iter_events().next(), None);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        poller.poll(None).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), mio::Token(1));
        assert_eq!(event.is_writable(), true);
        assert_eq!(it.next(), None);

        assert_eq!(sender.can_send(), true);

        sender.try_send(42).unwrap();

        poller.poll(None).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), mio::Token(2));
        assert_eq!(event.is_readable(), true);
        assert_eq!(it.next(), None);

        let v = receiver.try_recv().unwrap();
        assert_eq!(v, 42);

        mem::drop(sender);

        poller.poll(None).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), mio::Token(2));
        assert_eq!(event.is_readable(), true);
        assert_eq!(it.next(), None);

        let e = receiver.try_recv().unwrap_err();
        assert_eq!(e, mpsc::TryRecvError::Disconnected);
    }

    #[test]
    fn test_notify_bound1() {
        let (sender, receiver) = channel(1);

        let mut poller = event::Poller::new(2).unwrap();

        poller
            .register_custom(
                sender.get_write_registration(),
                mio::Token(1),
                mio::Interest::WRITABLE,
            )
            .unwrap();

        poller
            .register_custom(
                receiver.get_read_registration(),
                mio::Token(2),
                mio::Interest::READABLE,
            )
            .unwrap();

        poller.poll(Some(time::Duration::from_millis(0))).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), mio::Token(1));
        assert_eq!(event.is_writable(), true);
        assert_eq!(it.next(), None);

        sender.try_send(42).unwrap();

        poller.poll(None).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), mio::Token(2));
        assert_eq!(event.is_readable(), true);
        assert_eq!(it.next(), None);

        let v = receiver.try_recv().unwrap();
        assert_eq!(v, 42);

        mem::drop(sender);

        poller.poll(None).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), mio::Token(2));
        assert_eq!(event.is_readable(), true);
        assert_eq!(it.next(), None);

        let e = receiver.try_recv().unwrap_err();
        assert_eq!(e, mpsc::TryRecvError::Disconnected);
    }

    #[test]
    fn test_local_send_recv() {
        let poller = event::Poller::new(6).unwrap();

        let (sender1, receiver) = local_channel(1, 2, poller.local_registration_memory());

        assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Empty));
        assert_eq!(sender1.try_send(1), Ok(()));
        assert_eq!(receiver.try_recv(), Ok(1));

        let sender2 = sender1
            .try_clone(poller.local_registration_memory())
            .unwrap();

        assert_eq!(sender1.try_send(2), Ok(()));

        let channel = sender2.channel.clone();

        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            false
        );

        assert_eq!(sender2.try_send(3), Err(mpsc::TrySendError::Full(3)));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), false);
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(2));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            true
        );

        assert_eq!(sender2.try_send(3), Ok(()));
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(3));

        mem::drop(sender1);
        mem::drop(sender2);

        assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Disconnected));
    }

    #[test]
    fn test_local_send_disc() {
        let poller = event::Poller::new(4).unwrap();

        let (sender, receiver) = local_channel(1, 1, poller.local_registration_memory());

        mem::drop(receiver);
        assert_eq!(sender.try_send(1), Err(mpsc::TrySendError::Disconnected(1)));
    }

    #[test]
    fn test_local_cancel() {
        let poller = event::Poller::new(6).unwrap();

        let (sender1, receiver) = local_channel(1, 2, poller.local_registration_memory());

        let sender2 = sender1
            .try_clone(poller.local_registration_memory())
            .unwrap();

        let channel = sender2.channel.clone();

        assert_eq!(sender1.try_send(1), Ok(()));
        assert_eq!(sender2.try_send(2), Err(mpsc::TrySendError::Full(2)));
        assert_eq!(sender1.try_send(3), Err(mpsc::TrySendError::Full(3)));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), false);
        assert_eq!(
            channel.senders.borrow().nodes[sender1.key].value.notified,
            false
        );
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(1));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), false);
        assert_eq!(
            channel.senders.borrow().nodes[sender1.key].value.notified,
            false
        );
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            true
        );

        sender2.cancel();
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender1.key].value.notified,
            true
        );
        assert_eq!(
            channel.senders.borrow().nodes[sender2.key].value.notified,
            false
        );

        assert_eq!(sender1.try_send(3), Ok(()));
        assert_eq!(
            channel.senders.borrow().nodes[sender1.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(3));
    }

    #[test]
    fn test_local_check_send() {
        let poller = event::Poller::new(4).unwrap();

        let (sender, receiver) = local_channel(1, 1, poller.local_registration_memory());
        assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Empty));

        let channel = sender.channel.clone();

        assert_eq!(sender.check_send(), true);
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            false
        );

        assert_eq!(sender.try_send(1), Ok(()));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            false
        );

        assert_eq!(sender.check_send(), false);
        assert_eq!(channel.senders.borrow().waiting.is_empty(), false);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(1));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            true
        );

        assert_eq!(sender.try_send(2), Ok(()));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(2));
    }
}
condure-1.10.0/src/client.rs000064400000000000000000002742651046102023000137750ustar 00000000000000
/*
 * Copyright (C) 2023 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::buffer::TmpBuffer;
use crate::can_move_mio_sockets_between_threads;
use crate::channel;
use crate::connection::{
    client_req_connection, client_stream_connection, ConnectionPool, StreamSharedData,
};
use crate::event;
use crate::executor::{Executor, Spawner};
use crate::future::{
    event_wait, select_2, select_5, select_6, select_option, yield_to_local_events,
    AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver, CancellationSender, CancellationToken,
    Select2, Select5, Select6, Timeout,
};
use crate::list;
use crate::pin;
use crate::reactor::Reactor;
use crate::resolver::Resolver;
use crate::tnetstring;
use crate::zhttppacket;
use crate::zhttpsocket::{self, SessionKey, FROM_MAX, REQ_ID_MAX};
use crate::zmq::{MultipartHeader, SpecInfo};
use arrayvec::ArrayVec;
use ipnet::IpNet;
use log::{debug, error, info, warn};
use mio::unix::SourceFd;
use slab::Slab;
use std::cell::Cell;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::convert::TryFrom;
use std::io::{self, Write};
use std::mem;
use std::rc::Rc;
use std::str;
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::Duration;

const REQ_SENDER_BOUND: usize = 1;

// we read and process each request message one at a time, wrapping it in an
// rc, and sending it to connections via channels. on the other side of each
// channel, the message is received and processed immediately, except for the
// first message.
// this means the max number of messages retained per
// connection is the channel bound per connection plus one
pub const MSG_RETAINED_PER_CONNECTION_MAX: usize = REQ_SENDER_BOUND + 1;

// the max number of messages retained outside of connections is one per
// handle we read from (req and stream), in preparation for sending to any
// connections
pub const MSG_RETAINED_PER_WORKER_MAX: usize = 2;

// run x1
// req_handle_task x1
// stream_handle_task x1
// keep_alives_task x1
const WORKER_NON_CONNECTION_TASKS_MAX: usize = 10;

// this is meant to be an average max of registrations per task, in order
// to determine the total number of registrations sufficient for all tasks,
// however it is not enforced per task
const REGISTRATIONS_PER_TASK_MAX: usize = 32;

const REACTOR_BUDGET: u32 = 100;

const KEEP_ALIVE_TIMEOUT_MS: usize = 45_000;
const KEEP_ALIVE_BATCH_MS: usize = 100;
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(KEEP_ALIVE_BATCH_MS as u64);
const KEEP_ALIVE_BATCHES: usize = KEEP_ALIVE_TIMEOUT_MS / KEEP_ALIVE_BATCH_MS;
const BULK_PACKET_SIZE_MAX: usize = 65_000;

const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000);

const RESOLVER_THREADS: usize = 10;

fn local_channel<T>(
    bound: usize,
    max_senders: usize,
) -> (channel::LocalSender<T>, channel::LocalReceiver<T>) {
    let (s, r) = channel::local_channel(
        bound,
        max_senders,
        &Reactor::current().unwrap().local_registration_memory(),
    );

    (s, r)
}

fn async_local_channel<T>(
    bound: usize,
    max_senders: usize,
) -> (AsyncLocalSender<T>, AsyncLocalReceiver<T>) {
    let (s, r) = local_channel(bound, max_senders);

    let s = AsyncLocalSender::new(s);
    let r = AsyncLocalReceiver::new(r);

    (s, r)
}

struct BatchKey {
    addr_index: usize,
    nkey: usize,
}

struct BatchGroup<'a, 'b> {
    addr: &'b [u8],
    ids: arena::ReusableVecHandle<'b, zhttppacket::Id<'a>>,
}

impl<'a> BatchGroup<'a, '_> {
    fn addr(&self) -> &[u8] {
        self.addr
    }

    fn ids(&self) -> &[zhttppacket::Id<'a>] {
        &self.ids
    }
}

struct Batch {
    nodes: Slab<list::Node<usize>>,
    addrs: Vec<(ArrayVec<u8, FROM_MAX>,
list::List)>,
    addr_index: usize,
    group_ids: arena::ReusableVec,
    last_group_ckeys: Vec<usize>,
}

impl Batch {
    fn new(capacity: usize) -> Self {
        Self {
            nodes: Slab::with_capacity(capacity),
            addrs: Vec::with_capacity(capacity),
            addr_index: 0,
            group_ids: arena::ReusableVec::new::<zhttppacket::Id>(capacity),
            last_group_ckeys: Vec::with_capacity(capacity),
        }
    }

    fn len(&self) -> usize {
        self.nodes.len()
    }

    fn capacity(&self) -> usize {
        self.nodes.capacity()
    }

    fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    fn clear(&mut self) {
        self.addrs.clear();
        self.nodes.clear();
        self.addr_index = 0;
    }

    fn add(&mut self, to_addr: &[u8], ckey: usize) -> Result<BatchKey, ()> {
        let mut pos = self.addrs.len();

        for (i, a) in self.addrs.iter().enumerate() {
            if a.0.as_ref() == to_addr {
                pos = i;
            }
        }

        if pos == self.addrs.len() {
            // connection limits to_addr to FROM_MAX so this is guaranteed to succeed
            let a = ArrayVec::try_from(to_addr).unwrap();

            self.addrs.push((a, list::List::default()));
        }

        if self.nodes.len() == self.nodes.capacity() {
            return Err(());
        }

        let nkey = self.nodes.insert(list::Node::new(ckey));
        self.addrs[pos].1.push_back(&mut self.nodes, nkey);

        Ok(BatchKey {
            addr_index: pos,
            nkey,
        })
    }

    fn remove(&mut self, key: BatchKey) {
        self.addrs[key.addr_index]
            .1
            .remove(&mut self.nodes, key.nkey);
        self.nodes.remove(key.nkey);
    }

    fn take_group<'a, 'b: 'a, F>(&'a mut self, get_ids: F) -> Option<BatchGroup>
    where
        F: Fn(usize) -> (&'b [u8], u32),
    {
        // find the next addr with items
        while self.addr_index < self.addrs.len() && self.addrs[self.addr_index].1.is_empty() {
            self.addr_index += 1;
        }

        // if all are empty, we're done
        if self.addr_index == self.addrs.len() {
            return None;
        }

        let (addr, keys) = &mut self.addrs[self.addr_index];

        self.last_group_ckeys.clear();

        let mut ids = self.group_ids.get_as_new();

        // get ids/seqs
        while ids.len() < zhttppacket::IDS_MAX {
            let nkey = match keys.pop_front(&mut self.nodes) {
                Some(nkey) => nkey,
                None => break,
            };

            let ckey = self.nodes[nkey].value;
            self.nodes.remove(nkey);

            let (id, seq) = get_ids(ckey);
            self.last_group_ckeys.push(ckey);
            ids.push(zhttppacket::Id { id, seq: Some(seq) });
        }

        Some(BatchGroup { addr, ids })
    }

    fn last_group_ckeys(&self) -> &[usize] {
        &self.last_group_ckeys
    }
}

enum BatchType {
    KeepAlive,
    Cancel,
}

struct ChannelPool<T> {
    items: RefCell<VecDeque<(channel::LocalSender<T>, channel::LocalReceiver<T>)>>,
}

impl<T> ChannelPool<T> {
    fn new(capacity: usize) -> Self {
        Self {
            items: RefCell::new(VecDeque::with_capacity(capacity)),
        }
    }

    fn take(&self) -> Option<(channel::LocalSender<T>, channel::LocalReceiver<T>)> {
        let p = &mut *self.items.borrow_mut();

        p.pop_back()
    }

    fn push(&self, pair: (channel::LocalSender<T>, channel::LocalReceiver<T>)) {
        let p = &mut *self.items.borrow_mut();

        p.push_back(pair);
    }
}

struct ConnectionDone {
    ckey: usize,
}

struct ConnectionItem {
    id: Option<SessionKey>,
    stop: Option<CancellationSender>,
    zreceiver_sender: Option<channel::LocalSender<(arena::Rc<zhttppacket::OwnedRequest>, usize)>>,
    shared: Option<arena::Rc<StreamSharedData>>,
    batch_key: Option<BatchKey>,
}

struct ConnectionItems {
    nodes: Slab<list::Node<ConnectionItem>>,
    nodes_by_id: HashMap<SessionKey, usize>,
    batch: Batch,
}

impl ConnectionItems {
    fn new(capacity: usize, batch: Batch) -> Self {
        Self {
            nodes: Slab::with_capacity(capacity),
            nodes_by_id: HashMap::with_capacity(capacity),
            batch,
        }
    }
}

struct ConnectionsInner {
    active: list::List,
    count: usize,
    max: usize,
}

struct Connections {
    items: Rc<RefCell<ConnectionItems>>,
    inner: RefCell<ConnectionsInner>,
}

impl Connections {
    fn new(items: Rc<RefCell<ConnectionItems>>, max: usize) -> Self {
        Self {
            items,
            inner: RefCell::new(ConnectionsInner {
                active: list::List::default(),
                count: 0,
                max,
            }),
        }
    }

    fn count(&self) -> usize {
        self.inner.borrow().count
    }

    fn max(&self) -> usize {
        self.inner.borrow().max
    }

    fn add(
        &self,
        stop: CancellationSender,
        zreceiver_sender: Option<
            channel::LocalSender<(arena::Rc<zhttppacket::OwnedRequest>, usize)>,
        >,
        shared: Option<arena::Rc<StreamSharedData>>,
    ) -> Result<usize, ()> {
        let items = &mut *self.items.borrow_mut();
        let c = &mut *self.inner.borrow_mut();

        if items.nodes.len() == items.nodes.capacity() {
            return Err(());
        }

        let nkey = items.nodes.insert(list::Node::new(ConnectionItem {
            id: None,
            stop: Some(stop),
            zreceiver_sender,
            shared,
            batch_key: None,
        }));

        c.active.push_back(&mut items.nodes, nkey);
        c.count += 1;

        Ok(nkey)
    }

    // return
    // zreceiver_sender
    fn remove(
        &self,
        ckey: usize,
    ) -> Option<channel::LocalSender<(arena::Rc<zhttppacket::OwnedRequest>, usize)>> {
        let nkey = ckey;

        let items = &mut *self.items.borrow_mut();
        let c = &mut *self.inner.borrow_mut();

        let ci = &mut items.nodes[nkey].value;

        // clear active keep alive
        if let Some(bkey) = ci.batch_key.take() {
            items.batch.remove(bkey);
        }

        c.active.remove(&mut items.nodes, nkey);
        c.count -= 1;

        let ci = items.nodes.remove(nkey).value;

        if let Some(id) = &ci.id {
            items.nodes_by_id.remove(id);
        }

        ci.zreceiver_sender
    }

    fn set_id(&self, ckey: usize, id: Option<&SessionKey>) {
        let nkey = ckey;

        let items = &mut *self.items.borrow_mut();

        let ci = &mut items.nodes[nkey].value;

        // unset current id, if any
        if let Some(cur_id) = &ci.id {
            items.nodes_by_id.remove(cur_id);
            ci.id = None;
        }

        if let Some(id) = id.cloned() {
            ci.id = Some(id.clone());

            items.nodes_by_id.insert(id, nkey);
        } else {
            // clear active keep alive
            if let Some(bkey) = ci.batch_key.take() {
                items.batch.remove(bkey);
            }
        }
    }

    fn find_key(&self, id: &SessionKey) -> Option<usize> {
        let items = &*self.items.borrow();

        items.nodes_by_id.get(id).copied()
    }

    fn try_send(
        &self,
        ckey: usize,
        value: (arena::Rc<zhttppacket::OwnedRequest>, usize),
    ) -> Result<(), mpsc::TrySendError<(arena::Rc<zhttppacket::OwnedRequest>, usize)>> {
        let nkey = ckey;

        let items = &*self.items.borrow();

        let ci = &items.nodes[nkey].value;

        let sender = match &ci.zreceiver_sender {
            Some(s) => s,
            None => return Err(mpsc::TrySendError::Disconnected(value)),
        };

        sender.try_send(value)
    }

    fn stop_all<F>(&self, about_to_stop: F)
    where
        F: Fn(usize),
    {
        let items = &mut *self.items.borrow_mut();
        let cinner = &*self.inner.borrow_mut();

        let mut next = cinner.active.head;

        while let Some(nkey) = next {
            let n = &mut items.nodes[nkey];
            let ci = &mut n.value;

            about_to_stop(nkey);
            ci.stop = None;

            next = n.next;
        }
    }

    fn items_capacity(&self) -> usize {
        self.items.borrow().nodes.capacity()
    }

    fn can_stream(&self, ckey: usize) -> bool {
        let items = &*self.items.borrow();

        match items.nodes.get(ckey) {
            Some(n) => {
                let ci = &n.value;

                // is stream mode with an id
                ci.shared.is_some() && ci.id.is_some()
            }
            None => false,
        }
    }

    fn batch_is_empty(&self) -> bool {
        let items = &*self.items.borrow();

        items.batch.is_empty()
    }

    fn batch_len(&self) -> usize {
        let items = &*self.items.borrow();

        items.batch.len()
    }

    fn batch_capacity(&self) -> usize {
        let items = &*self.items.borrow();

        items.batch.capacity()
    }

    fn batch_clear(&self) {
        let items = &mut *self.items.borrow_mut();

        items.batch.clear();
    }

    fn batch_add(&self, ckey: usize) -> Result<(), ()> {
        let items = &mut *self.items.borrow_mut();

        let ci = &mut items.nodes[ckey].value;
        let cshared = ci.shared.as_ref().unwrap().get();

        // only batch connections with known handler addresses
        let addr_ref = cshared.to_addr();
        let addr = match addr_ref.get() {
            Some(addr) => addr,
            None => return Err(()),
        };

        let bkey = items.batch.add(addr, ckey)?;

        ci.batch_key = Some(bkey);

        Ok(())
    }

    fn next_batch_message(&self, from: &str, btype: BatchType) -> Option<(usize, zmq::Message)> {
        let items = &mut *self.items.borrow_mut();
        let nodes = &mut items.nodes;
        let batch = &mut items.batch;

        while !batch.is_empty() {
            let group = batch
                .take_group(|ckey| {
                    let ci = &nodes[ckey].value;
                    let cshared = ci.shared.as_ref().unwrap().get();

                    // item is guaranteed to have an id.
                    // only items with an
                    // id are added to a batch, and if an item's id is
                    // removed then the item is removed from the batch
                    let id = ci.id.as_ref().unwrap();

                    (&id.1, cshared.out_seq())
                })
                .unwrap();

            let count = group.ids().len();

            assert!(count <= zhttppacket::IDS_MAX);

            let zreq = zhttppacket::Request {
                from: from.as_bytes(),
                ids: group.ids(),
                multi: true,
                ptype: match btype {
                    BatchType::KeepAlive => zhttppacket::RequestPacket::KeepAlive,
                    BatchType::Cancel => zhttppacket::RequestPacket::Cancel,
                },
                ptype_str: "",
            };

            let mut data = [0; BULK_PACKET_SIZE_MAX];

            let size = match zreq.serialize(&mut data) {
                Ok(size) => size,
                Err(e) => {
                    error!(
                        "failed to serialize keep-alive packet with {} ids: {}",
                        zreq.ids.len(),
                        e
                    );
                    continue;
                }
            };

            let data = &data[..size];

            let addr = group.addr();

            let msg = {
                let mut v = vec![0; addr.len() + 1 + data.len()];

                v[..addr.len()].copy_from_slice(addr);
                v[addr.len()] = b' ';
                let pos = addr.len() + 1;
                v[pos..(pos + data.len())].copy_from_slice(data);

                // this takes over the vec's memory without copying
                zmq::Message::from(v)
            };

            drop(group);

            for &ckey in batch.last_group_ckeys() {
                let ci = &mut nodes[ckey].value;
                let cshared = ci.shared.as_ref().unwrap().get();

                cshared.inc_out_seq();
                ci.batch_key = None;
            }

            return Some((count, msg));
        }

        None
    }
}

#[derive(Clone)]
struct ConnectionOpts {
    instance_id: Rc<String>,
    buffer_size: usize,
    timeout: Duration,
    rb_tmp: Rc<TmpBuffer>,
    packet_buf: Rc<RefCell<Vec<u8>>>,
    tmp_buf: Rc<RefCell<Vec<u8>>>,
}

struct ConnectionReqOpts {
    body_buffer_size: usize,
    sender: channel::LocalSender<(MultipartHeader, zmq::Message)>,
}

struct ConnectionStreamOpts {
    messages_max: usize,
    allow_compression: bool,
    sender: channel::LocalSender<zmq::Message>,
}

struct Worker {
    thread: Option<thread::JoinHandle<()>>,
    stop: Option<channel::Sender<()>>,
}

impl Worker {
    #[allow(clippy::too_many_arguments)]
    fn new(
        instance_id: &str,
        id: usize,
        req_maxconn: usize,
        stream_maxconn: usize,
        buffer_size: usize,
        body_buffer_size: usize,
        messages_max: usize,
        req_timeout: Duration,
        stream_timeout: Duration,
        allow_compression: bool,
        deny: &[IpNet],
        resolver:
&Arc<Resolver>,
        pool: &Arc<ConnectionPool>,
        zsockman: &Arc<zhttpsocket::ServerSocketManager>,
        handle_bound: usize,
    ) -> Self {
        debug!("client worker {}: starting", id);

        let (stop, r_stop) = channel::channel(1);
        let (s_ready, ready) = channel::channel(1);

        let instance_id = String::from(instance_id);
        let deny = deny.to_vec();
        let resolver = Arc::clone(resolver);
        let pool = Arc::clone(pool);
        let zsockman = Arc::clone(zsockman);

        let thread = thread::Builder::new()
            .name(format!("client-worker-{}", id))
            .spawn(move || {
                let maxconn = req_maxconn + stream_maxconn;

                // 1 task per connection, plus a handful of supporting tasks
                let tasks_max = maxconn + WORKER_NON_CONNECTION_TASKS_MAX;

                let registrations_max = REGISTRATIONS_PER_TASK_MAX * tasks_max;

                let reactor = Reactor::new(registrations_max);

                let executor = Executor::new(tasks_max);

                {
                    let reactor = reactor.clone();

                    executor.set_pre_poll(move || {
                        reactor.set_budget(Some(REACTOR_BUDGET));
                    });
                }

                executor
                    .spawn(Self::run(
                        r_stop,
                        s_ready,
                        instance_id,
                        id,
                        req_maxconn,
                        stream_maxconn,
                        buffer_size,
                        body_buffer_size,
                        messages_max,
                        req_timeout,
                        stream_timeout,
                        allow_compression,
                        deny,
                        resolver,
                        pool,
                        zsockman,
                        handle_bound,
                    ))
                    .unwrap();

                executor.run(|timeout| reactor.poll(timeout)).unwrap();

                debug!("client worker {}: stopped", id);
            })
            .unwrap();

        ready.recv().unwrap();

        Self {
            thread: Some(thread),
            stop: Some(stop),
        }
    }

    fn stop(&mut self) {
        self.stop = None;
    }

    #[allow(clippy::too_many_arguments)]
    async fn run(
        stop: channel::Receiver<()>,
        ready: channel::Sender<()>,
        instance_id: String,
        id: usize,
        req_maxconn: usize,
        stream_maxconn: usize,
        buffer_size: usize,
        body_buffer_size: usize,
        messages_max: usize,
        req_timeout: Duration,
        stream_timeout: Duration,
        allow_compression: bool,
        deny: Vec<IpNet>,
        resolver: Arc<Resolver>,
        pool: Arc<ConnectionPool>,
        zsockman: Arc<zhttpsocket::ServerSocketManager>,
        handle_bound: usize,
    ) {
        let executor = Executor::current().unwrap();
        let reactor = Reactor::current().unwrap();
        let stop = AsyncReceiver::new(stop);

        debug!("client-worker {}: allocating buffers", id);

        let rb_tmp = Rc::new(TmpBuffer::new(buffer_size));

        // large
        // enough to fit anything
        let packet_buf = Rc::new(RefCell::new(vec![0; buffer_size + body_buffer_size + 4096]));

        // same size as working buffers
        let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size]));

        let instance_id = Rc::new(instance_id);

        let ka_batch = (stream_maxconn + (KEEP_ALIVE_BATCHES - 1)) / KEEP_ALIVE_BATCHES;

        let batch = Batch::new(ka_batch);

        let maxconn = req_maxconn + stream_maxconn;

        let conn_items = Rc::new(RefCell::new(ConnectionItems::new(maxconn, batch)));

        let req_conns = Rc::new(Connections::new(conn_items.clone(), req_maxconn));
        let stream_conns = Rc::new(Connections::new(conn_items.clone(), stream_maxconn));

        let (req_handle_stop, r_req_handle_stop) = async_local_channel(1, 1);
        let (stream_handle_stop, r_stream_handle_stop) = async_local_channel(1, 1);
        let (keep_alives_stop, r_keep_alives_stop) = async_local_channel(1, 1);

        let (s_req_handle_done, req_handle_done) = async_local_channel(1, 1);
        let (s_stream_handle_done, stream_handle_done) = async_local_channel(1, 1);
        let (s_keep_alives_done, keep_alives_done) = async_local_channel(1, 1);

        // max_senders is 1 per connection + 1 for the handle task + 1 for the keep alive task
        let (zstream_out_sender, zstream_out_receiver) =
            local_channel(handle_bound, stream_maxconn + 2);

        let zstream_out_receiver = AsyncLocalReceiver::new(zstream_out_receiver);

        let req_handle = zhttpsocket::AsyncServerReqHandle::new(zsockman.server_req_handle());

        let stream_handle =
            zhttpsocket::AsyncServerStreamHandle::new(zsockman.server_stream_handle());

        let deny = Rc::new(deny);

        executor
            .spawn(Self::req_handle_task(
                id,
                r_req_handle_stop,
                s_req_handle_done,
                executor.spawner(),
                Arc::clone(&resolver),
                Arc::clone(&pool),
                req_handle,
                req_maxconn,
                req_conns,
                body_buffer_size,
                Rc::clone(&deny),
                handle_bound,
                ConnectionOpts {
                    instance_id: instance_id.clone(),
                    buffer_size,
                    timeout: req_timeout,
                    rb_tmp: rb_tmp.clone(),
                    packet_buf: packet_buf.clone(),
                    tmp_buf: tmp_buf.clone(),
                },
            ))
            .unwrap();

        {
            let zstream_out_sender =
zstream_out_sender
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            executor
                .spawn(Self::stream_handle_task(
                    id,
                    r_stream_handle_stop,
                    s_stream_handle_done,
                    zstream_out_receiver,
                    zstream_out_sender,
                    executor.spawner(),
                    Arc::clone(&resolver),
                    Arc::clone(&pool),
                    stream_handle,
                    stream_maxconn,
                    stream_conns.clone(),
                    messages_max,
                    allow_compression,
                    Rc::clone(&deny),
                    ConnectionOpts {
                        instance_id: instance_id.clone(),
                        buffer_size,
                        timeout: stream_timeout,
                        rb_tmp: rb_tmp.clone(),
                        packet_buf: packet_buf.clone(),
                        tmp_buf: tmp_buf.clone(),
                    },
                ))
                .unwrap();
        }

        executor
            .spawn(Self::keep_alives_task(
                id,
                r_keep_alives_stop,
                s_keep_alives_done,
                instance_id.clone(),
                zstream_out_sender,
                stream_conns.clone(),
            ))
            .unwrap();

        debug!("client-worker {}: started", id);

        ready.send(()).unwrap();
        drop(ready);

        // wait for stop
        let _ = stop.recv().await;

        // stop keep alives
        drop(keep_alives_stop);
        let _ = keep_alives_done.recv().await;

        // stop remaining tasks
        drop(req_handle_stop);
        drop(stream_handle_stop);
        let _ = req_handle_done.recv().await;
        let stream_handle = stream_handle_done.recv().await.unwrap();

        // send cancels

        stream_conns.batch_clear();

        let now = reactor.now();
        let shutdown_timeout = Timeout::new(now + SHUTDOWN_TIMEOUT);

        let mut next_cancel_index = 0;

        'outer: while next_cancel_index < stream_conns.items_capacity() {
            while stream_conns.batch_len() < stream_conns.batch_capacity()
                && next_cancel_index < stream_conns.items_capacity()
            {
                let key = next_cancel_index;

                next_cancel_index += 1;

                if stream_conns.can_stream(key) {
                    // ignore errors
                    let _ = stream_conns.batch_add(key);
                }
            }

            while let Some((count, msg)) =
                stream_conns.next_batch_message(&instance_id, BatchType::Cancel)
            {
                debug!(
                    "client-worker {}: sending cancels for {} sessions",
                    id, count
                );

                match select_2(pin!(stream_handle.send(msg)), shutdown_timeout.elapsed()).await {
                    Select2::R1(r) => r.unwrap(),
                    Select2::R2(_) => break 'outer,
                }
            }

            stream_conns.batch_clear();
        }
    }

    #[allow(clippy::too_many_arguments)]
    async fn
req_handle_task(
        id: usize,
        stop: AsyncLocalReceiver<()>,
        _done: AsyncLocalSender<()>,
        spawner: Spawner,
        resolver: Arc<Resolver>,
        conn_pool: Arc<ConnectionPool>,
        req_handle: zhttpsocket::AsyncServerReqHandle,
        req_maxconn: usize,
        conns: Rc<Connections>,
        body_buffer_size: usize,
        deny: Rc<Vec<IpNet>>,
        handle_bound: usize,
        opts: ConnectionOpts,
    ) {
        let reactor = Reactor::current().unwrap();

        let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * req_maxconn);

        let req_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max));
        let req_req_mem = Rc::new(arena::RcMemory::new(msg_retained_max));

        // max_senders is 1 per connection + 1 for this task
        let (zreq_sender, zreq_receiver) = local_channel(handle_bound, req_maxconn + 1);

        let zreq_receiver = AsyncLocalReceiver::new(zreq_receiver);

        // bound is 1 per connection, so all connections can indicate done at once
        // max_senders is 1 per connection + 1 for this task
        let (s_cdone, r_cdone) = channel::local_channel::<ConnectionDone>(
            conns.max(),
            conns.max() + 1,
            &reactor.local_registration_memory(),
        );

        let r_cdone = AsyncLocalReceiver::new(r_cdone);

        debug!("client-worker {}: task started: req_handle", id);

        let mut handle_send = pin!(None);

        loop {
            let receiver_recv = if handle_send.is_none() {
                Some(zreq_receiver.recv())
            } else {
                None
            };

            let req_handle_recv = if conns.count() < conns.max() {
                Some(req_handle.recv())
            } else {
                None
            };

            match select_5(
                stop.recv(),
                select_option(receiver_recv),
                select_option(handle_send.as_mut().as_pin_mut()),
                r_cdone.recv(),
                select_option(pin!(req_handle_recv).as_pin_mut()),
            )
            .await
            {
                // stop.recv
                Select5::R1(_) => break,

                // receiver_recv
                Select5::R2(result) => match result {
                    Ok((header, msg)) => handle_send.set(Some(req_handle.send(header, msg))),
                    Err(e) => panic!("zreq_receiver channel error: {}", e),
                },

                // handle_send
                Select5::R3(result) => {
                    handle_send.set(None);

                    if let Err(e) = result {
                        error!("req send error: {}", e);
                    }
                }

                // r_cdone.recv
                Select5::R4(result) => match result {
                    Ok(done) => {
                        let ret = conns.remove(done.ckey);

                        // req mode doesn't have a
                        // sender
                        assert!(ret.is_none());
                    }
                    Err(e) => panic!("r_cdone channel error: {}", e),
                },

                // req_handle_recv
                Select5::R5(result) => match result {
                    Ok((header, msg)) => {
                        let scratch = arena::Rc::new(
                            RefCell::new(zhttppacket::ParseScratch::new()),
                            &req_scratch_mem,
                        )
                        .unwrap();

                        let zreq = match zhttppacket::OwnedRequest::parse(msg, 0, scratch) {
                            Ok(zreq) => zreq,
                            Err(e) => {
                                warn!("client-worker {}: zhttp parse error: {}", id, e);
                                continue;
                            }
                        };

                        let zreq_ref = zreq.get();

                        let ids = zreq_ref.ids;

                        if ids.len() > 1 {
                            warn!(
                                "client-worker {}: request contained more than one id, skipping",
                                id
                            );
                            continue;
                        }

                        let from: ArrayVec<u8, FROM_MAX> = match ArrayVec::try_from(zreq_ref.from) {
                            Ok(v) => v,
                            Err(_) => {
                                warn!("client-worker {}: from address too long, skipping", id);
                                continue;
                            }
                        };

                        let cid: Option<ArrayVec<u8, REQ_ID_MAX>> = if !ids.is_empty() {
                            match ArrayVec::try_from(ids[0].id) {
                                Ok(v) => Some(v),
                                Err(_) => {
                                    warn!("client-worker {}: request id too long, skipping", id);
                                    continue;
                                }
                            }
                        } else {
                            None
                        };

                        let zreq = arena::Rc::new(zreq, &req_req_mem).unwrap();

                        let (cstop, r_cstop) =
                            CancellationToken::new(&reactor.local_registration_memory());

                        let s_cdone = s_cdone
                            .try_clone(&reactor.local_registration_memory())
                            .unwrap();

                        let zreq_sender = zreq_sender
                            .try_clone(&reactor.local_registration_memory())
                            .unwrap();

                        let ckey = conns.add(cstop, None, None).unwrap();

                        if let Some(cid) = &cid {
                            let cid = (from, cid.clone());

                            conns.set_id(ckey, Some(&cid));
                        }

                        debug!(
                            "client-worker {}: req conn starting {} {}/{}",
                            id,
                            ckey,
                            conns.count(),
                            conns.max(),
                        );

                        if spawner
                            .spawn(Self::req_connection_task(
                                r_cstop,
                                s_cdone,
                                id,
                                ckey,
                                cid,
                                (header, zreq),
                                Arc::clone(&resolver),
                                Arc::clone(&conn_pool),
                                Rc::clone(&deny),
                                opts.clone(),
                                ConnectionReqOpts {
                                    body_buffer_size,
                                    sender: zreq_sender,
                                },
                            ))
                            .is_err()
                        {
                            // this should never happen.
                            // we only read a message
                            // if we know we can spawn
                            panic!("failed to spawn req_connection_task");
                        }
                    }
                    Err(e) => panic!("client-worker {}: handle read error {}", id, e),
                },
            }
        }

        drop(s_cdone);

        conns.stop_all(|ckey| debug!("client-worker {}: stopping {}", id, ckey));

        while r_cdone.recv().await.is_ok() {}

        debug!("client-worker {}: task stopped: req_handle", id);
    }

    #[allow(clippy::too_many_arguments)]
    async fn stream_handle_task(
        id: usize,
        stop: AsyncLocalReceiver<()>,
        done: AsyncLocalSender<zhttpsocket::AsyncServerStreamHandle>,
        zstream_out_receiver: AsyncLocalReceiver<zmq::Message>,
        zstream_out_sender: channel::LocalSender<zmq::Message>,
        spawner: Spawner,
        resolver: Arc<Resolver>,
        conn_pool: Arc<ConnectionPool>,
        stream_handle: zhttpsocket::AsyncServerStreamHandle,
        stream_maxconn: usize,
        conns: Rc<Connections>,
        messages_max: usize,
        allow_compression: bool,
        deny: Rc<Vec<IpNet>>,
        opts: ConnectionOpts,
    ) {
        let reactor = Reactor::current().unwrap();

        let stream_shared_mem = Rc::new(arena::RcMemory::new(stream_maxconn));

        let zreceiver_pool = Rc::new(ChannelPool::new(stream_maxconn));
        for _ in 0..stream_maxconn {
            zreceiver_pool.push(local_channel(REQ_SENDER_BOUND, 1));
        }

        let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * stream_maxconn);

        let stream_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max));
        let stream_req_mem = Rc::new(arena::RcMemory::new(msg_retained_max));

        // bound is 1 per connection, so all connections can indicate done at once
        // max_senders is 1 per connection + 1 for this task
        let (s_cdone, r_cdone) = channel::local_channel::<ConnectionDone>(
            conns.max(),
            conns.max() + 1,
            &reactor.local_registration_memory(),
        );

        let r_cdone = AsyncLocalReceiver::new(r_cdone);

        debug!("client-worker {}: task started: stream_handle", id);

        {
            let mut handle_send = pin!(None);

            loop {
                let receiver_recv = if handle_send.is_none() {
                    Some(zstream_out_receiver.recv())
                } else {
                    None
                };

                let stream_handle_recv_from_any = if conns.count() < conns.max() {
                    Some(stream_handle.recv_from_any())
                } else {
                    None
                };

                match select_6(
                    stop.recv(),
                    select_option(receiver_recv),
                    select_option(handle_send.as_mut().as_pin_mut()),
                    r_cdone.recv(),
                    select_option(pin!(stream_handle_recv_from_any).as_pin_mut()),
                    pin!(stream_handle.recv_directed()),
                )
                .await
                {
                    // stop.recv
                    Select6::R1(_) => break,

                    // receiver_recv
                    Select6::R2(result) => match result {
                        Ok(msg) => handle_send.set(Some(stream_handle.send(msg))),
                        Err(e) => panic!("zstream_out_receiver channel error: {}", e),
                    },

                    // handle_send
                    Select6::R3(result) => {
                        handle_send.set(None);

                        if let Err(e) = result {
                            error!("stream send error: {}", e);
                        }
                    }

                    // r_cdone.recv
                    Select6::R4(result) => match result {
                        Ok(done) => {
                            let zreceiver_sender = conns.remove(done.ckey).unwrap();

                            let zreceiver = zreceiver_sender
                                .make_receiver(&reactor.local_registration_memory())
                                .unwrap();

                            zreceiver.clear();

                            zreceiver_pool.push((zreceiver_sender, zreceiver));
                        }
                        Err(e) => panic!("r_cdone channel error: {}", e),
                    },

                    // stream_handle_recv_from_any
                    Select6::R5(result) => match result {
                        Ok(ret) => {
                            let (msg, session) = ret;

                            let scratch = arena::Rc::new(
                                RefCell::new(zhttppacket::ParseScratch::new()),
                                &stream_scratch_mem,
                            )
                            .unwrap();

                            let zreq = match zhttppacket::OwnedRequest::parse(msg, 0, scratch) {
                                Ok(zreq) => zreq,
                                Err(e) => {
                                    warn!("client-worker {}: zhttp parse error: {}", id, e);
                                    continue;
                                }
                            };

                            let zreq_ref = zreq.get();

                            let ids = zreq_ref.ids;

                            if ids.len() != 1 {
                                warn!(
                                    "client-worker {}: packet did not contain exactly one id, skipping",
                                    id
                                );
                                continue;
                            }

                            if ids[0].seq != Some(0) {
                                warn!(
                                    "client-worker {}: received message with seq != 0 as first message, skipping",
                                    id
                                );
                                continue;
                            }

                            if !zreq_ref.ptype_str.is_empty() {
                                warn!(
                                    "client-worker {}: received non-data message as first message, skipping",
                                    id
                                );
                                continue;
                            }

                            if zreq_ref.from.len() > FROM_MAX {
                                warn!("client-worker {}: from address too long, skipping", id);
                                continue;
                            }

                            let cid: ArrayVec<u8, REQ_ID_MAX> = match ArrayVec::try_from(ids[0].id) {
                                Ok(v) => v,
                                Err(_) => {
                                    warn!("client-worker {}: request id too long, skipping", id);
                                    continue;
                                }
                            };

                            let zreq = arena::Rc::new(zreq,
&stream_req_mem).unwrap(); let (cstop, r_cstop) = CancellationToken::new(&reactor.local_registration_memory()); let s_cdone = s_cdone .try_clone(&reactor.local_registration_memory()) .unwrap(); let zstream_out_sender = zstream_out_sender .try_clone(&reactor.local_registration_memory()) .unwrap(); let (zstream_receiver_sender, zstream_receiver) = zreceiver_pool.take().unwrap(); let shared = arena::Rc::new(StreamSharedData::new(), &stream_shared_mem) .unwrap(); let ckey = conns .add( cstop, Some(zstream_receiver_sender), Some(arena::Rc::clone(&shared)), ) .unwrap(); debug!( "client-worker {}: stream conn starting {} {}/{}", id, ckey, conns.count(), conns.max(), ); if spawner .spawn(Self::stream_connection_task( r_cstop, s_cdone, id, ckey, cid, arena::Rc::clone(&zreq), Arc::clone(&resolver), Arc::clone(&conn_pool), zstream_receiver, Rc::clone(&deny), Rc::clone(&conns), opts.clone(), ConnectionStreamOpts { messages_max, allow_compression, sender: zstream_out_sender, }, shared, Some(session), )) .is_err() { // this should never happen. 
we only read a message // if we know we can spawn panic!("failed to spawn stream_connection_task"); } } Err(e) => panic!("client-worker {}: handle read error {}", id, e), }, // stream_handle.recv_directed Select6::R6(result) => match result { Ok(msg) => { let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &stream_scratch_mem, ) .unwrap(); let zreq = match zhttppacket::OwnedRequest::parse(msg, 0, scratch) { Ok(zreq) => zreq, Err(e) => { warn!("client-worker {}: zhttp parse error: {}", id, e); continue; } }; let zreq = arena::Rc::new(zreq, &stream_req_mem).unwrap(); let zreq_ref = zreq.get().get(); let ids = zreq_ref.ids; if ids.is_empty() { warn!("client-worker {}: packet contained no ids, skipping", id); continue; } let from: ArrayVec = match ArrayVec::try_from(zreq_ref.from) { Ok(v) => v, Err(_) => { warn!( "client-worker {}: from address too long, skipping", id ); continue; } }; let mut count = 0; for (i, rid) in ids.iter().enumerate() { let cid: ArrayVec = match ArrayVec::try_from(rid.id) { Ok(v) => v, Err(_) => { warn!( "client-worker {}: request id too long, skipping", id ); continue; } }; let cid = (from.clone(), cid); let key = match conns.find_key(&cid) { Some(key) => key, None => continue, }; // this should always succeed, since afterwards we yield // to let the connection receive the message match conns.try_send(key, (arena::Rc::clone(&zreq), i)) { Ok(()) => count += 1, Err(mpsc::TrySendError::Full(_)) => error!( "client-worker {}: connection-{} cannot receive message", id, key ), Err(mpsc::TrySendError::Disconnected(_)) => {} // conn task ended } } debug!( "client-worker {}: queued zmq message for {} conns", id, count ); if count > 0 { yield_to_local_events().await; } } Err(e) => panic!("client-worker {}: handle read error {}", id, e), }, } } } drop(s_cdone); conns.stop_all(|ckey| debug!("client-worker {}: stopping {}", id, ckey)); while r_cdone.recv().await.is_ok() {} // give the handle back 
done.send(stream_handle).await.unwrap(); debug!("client-worker {}: task stopped: stream_handle", id); } #[allow(clippy::too_many_arguments)] async fn req_connection_task( token: CancellationToken, done: channel::LocalSender, worker_id: usize, ckey: usize, cid: Option>, zreq: (MultipartHeader, arena::Rc), resolver: Arc, pool: Arc, deny: Rc>, opts: ConnectionOpts, req_opts: ConnectionReqOpts, ) { let done = AsyncLocalSender::new(done); debug!( "client-worker {}: task started: connection-{}", worker_id, ckey ); let log_id = if let Some(cid) = &cid { // zhttp ids are pretty much always valid strings, but we'll // do a lossy conversion just in case let cid_str = String::from_utf8_lossy(cid); format!("{}-{}-{}", worker_id, ckey, cid_str) } else { format!("{}-{}", worker_id, ckey) }; client_req_connection( token, &log_id, cid.as_deref(), zreq, opts.buffer_size, req_opts.body_buffer_size, &opts.rb_tmp, opts.packet_buf, opts.timeout, &deny, &resolver, &pool, AsyncLocalSender::new(req_opts.sender), ) .await; done.send(ConnectionDone { ckey }).await.unwrap(); debug!( "client-worker {}: task stopped: connection-{}", worker_id, ckey ); } #[allow(clippy::too_many_arguments)] async fn stream_connection_task( token: CancellationToken, done: channel::LocalSender, worker_id: usize, ckey: usize, cid: ArrayVec, zreq: arena::Rc, resolver: Arc, pool: Arc, zreceiver: channel::LocalReceiver<(arena::Rc, usize)>, deny: Rc>, conns: Rc, opts: ConnectionOpts, stream_opts: ConnectionStreamOpts, shared: arena::Rc, session: Option, ) { let done = AsyncLocalSender::new(done); let zreceiver = AsyncLocalReceiver::new(zreceiver); debug!( "client-worker {}: task started: connection-{}", worker_id, ckey ); let log_id = { // zhttp ids are pretty much always valid strings, but we'll // do a lossy conversion just in case let cid_str = String::from_utf8_lossy(&cid); format!("{}-{}-{}", worker_id, ckey, cid_str) }; client_stream_connection( token, &log_id, &cid, arena::Rc::clone(&zreq), opts.buffer_size, 
stream_opts.messages_max, &opts.rb_tmp, opts.packet_buf, opts.tmp_buf, opts.timeout, stream_opts.allow_compression, &deny, &opts.instance_id, &resolver, &pool, zreceiver, AsyncLocalSender::new(stream_opts.sender), shared, &|| { // handle task limits addr to FROM_MAX so this is guaranteed to succeed let from: ArrayVec = ArrayVec::try_from(zreq.get().get().from).unwrap(); let cid = (from, cid.clone()); conns.set_id(ckey, Some(&cid)) }, ) .await; drop(session); done.send(ConnectionDone { ckey }).await.unwrap(); debug!( "client-worker {}: task stopped: connection-{}", worker_id, ckey ); } async fn keep_alives_task( id: usize, stop: AsyncLocalReceiver<()>, _done: AsyncLocalSender<()>, instance_id: Rc, sender: channel::LocalSender, conns: Rc, ) { debug!("client-worker {}: task started: keep_alives", id); let reactor = Reactor::current().unwrap(); let mut keep_alive_count = 0; let mut next_keep_alive_time = reactor.now() + KEEP_ALIVE_INTERVAL; let next_keep_alive_timeout = Timeout::new(next_keep_alive_time); let mut next_keep_alive_index = 0; let sender_registration = reactor .register_custom_local(sender.get_write_registration(), mio::Interest::WRITABLE) .unwrap(); sender_registration.set_readiness(Some(mio::Interest::WRITABLE)); 'main: loop { while conns.batch_is_empty() { // wait for next keep alive time match select_2(stop.recv(), next_keep_alive_timeout.elapsed()).await { Select2::R1(_) => break 'main, Select2::R2(_) => {} } for _ in 0..conns.batch_capacity() { if next_keep_alive_index >= conns.items_capacity() { break; } let key = next_keep_alive_index; next_keep_alive_index += 1; if conns.can_stream(key) { // ignore errors let _ = conns.batch_add(key); } } keep_alive_count += 1; if keep_alive_count >= KEEP_ALIVE_BATCHES { keep_alive_count = 0; next_keep_alive_index = 0; } // keep steady pace next_keep_alive_time += KEEP_ALIVE_INTERVAL; next_keep_alive_timeout.set_deadline(next_keep_alive_time); } match select_2( stop.recv(), pin!(event_wait(&sender_registration, 
mio::Interest::WRITABLE)), ) .await { Select2::R1(_) => break, Select2::R2(_) => {} } if !sender.check_send() { // if check_send returns false, we'll be on the waitlist for a notification sender_registration.clear_readiness(mio::Interest::WRITABLE); continue; } // if check_send returns true, we are guaranteed to be able to send match conns.next_batch_message(&instance_id, BatchType::KeepAlive) { Some((count, msg)) => { debug!( "client-worker {}: sending keep alives for {} sessions", id, count ); if let Err(e) = sender.try_send(msg) { error!("zhttp write error: {}", e); } } None => { // this could happen if message construction failed sender.cancel(); } } if conns.batch_is_empty() { conns.batch_clear(); let now = reactor.now(); if now >= next_keep_alive_time + KEEP_ALIVE_INTERVAL { // got really behind somehow. just skip ahead next_keep_alive_time = now + KEEP_ALIVE_INTERVAL; next_keep_alive_timeout.set_deadline(next_keep_alive_time); } } } debug!("client-worker {}: task stopped: keep_alives", id); } } impl Drop for Worker { fn drop(&mut self) { self.stop(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } pub struct Client { workers: Vec, } impl Client { #[allow(clippy::too_many_arguments)] pub fn new( instance_id: &str, worker_count: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, allow_compression: bool, deny: &[IpNet], zsockman: Arc, handle_bound: usize, ) -> Result { // 1 active query per connection let queries_max = req_maxconn + stream_maxconn; let resolver = Arc::new(Resolver::new(RESOLVER_THREADS, queries_max)); let pool_max = if can_move_mio_sockets_between_threads() { (req_maxconn + stream_maxconn) / 10 } else { // disable persistent connections 0 }; let pool = Arc::new(ConnectionPool::new(pool_max)); if !deny.is_empty() { info!("default policy: block outgoing connections to {:?}", deny); } let mut workers = Vec::new(); 
for i in 0..worker_count { let w = Worker::new( instance_id, i, req_maxconn / worker_count, stream_maxconn / worker_count, buffer_size, body_buffer_size, messages_max, req_timeout, stream_timeout, allow_compression, deny, &resolver, &pool, &zsockman, handle_bound, ); workers.push(w); } Ok(Self { workers }) } pub fn task_sizes() -> Vec<(String, usize)> { let req_task_size = { let reactor = Reactor::new(10); let (_, stop) = CancellationToken::new(&reactor.local_registration_memory()); let (done, _) = local_channel(1, 1); let (sender, _) = local_channel(1, 1); let req_scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_req_mem = Rc::new(arena::RcMemory::new(1)); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &req_scratch_mem, ) .unwrap(); let msg = concat!( "T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,3:uri", ",23:http://example.com/path,7:headers,34:30:12:Content-Type,1", "0:text/plain,]]4:body,5:hello,4:more,4:true!}", ); let msg = arena::Arc::new(zmq::Message::from(msg.as_bytes()), &msg_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_req_mem).unwrap(); let resolver = Arc::new(Resolver::new(1, 1)); let pool = Arc::new(ConnectionPool::new(0)); let fut = Worker::req_connection_task( stop, done, 0, 0, None, (MultipartHeader::new(), zreq), resolver, pool, Rc::new(Vec::new()), ConnectionOpts { instance_id: Rc::new("".to_string()), buffer_size: 0, timeout: Duration::from_millis(0), rb_tmp: Rc::new(TmpBuffer::new(1)), packet_buf: Rc::new(RefCell::new(Vec::new())), tmp_buf: Rc::new(RefCell::new(Vec::new())), }, ConnectionReqOpts { body_buffer_size: 0, sender, }, ); mem::size_of_val(&fut) }; let stream_task_size = { let reactor = Reactor::new(10); let (_, stop) = CancellationToken::new(&reactor.local_registration_memory()); let (done, _) = local_channel(1, 1); let (_, zreceiver) = local_channel(1, 1); let 
(sender, _) = local_channel(1, 1); let batch = Batch::new(1); let conn_items = Rc::new(RefCell::new(ConnectionItems::new(1, batch))); let conns = Rc::new(Connections::new(conn_items, 1)); let req_scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_req_mem = Rc::new(arena::RcMemory::new(1)); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &req_scratch_mem, ) .unwrap(); let msg = concat!( "T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,3:uri", ",23:http://example.com/path,7:headers,34:30:12:Content-Type,1", "0:text/plain,]]4:body,5:hello,4:more,4:true!}", ); let msg = arena::Arc::new(zmq::Message::from(msg.as_bytes()), &msg_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_req_mem).unwrap(); let resolver = Arc::new(Resolver::new(1, 1)); let pool = Arc::new(ConnectionPool::new(0)); let stream_shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &stream_shared_mem).unwrap(); let fut = Worker::stream_connection_task( stop, done, 0, 0, ArrayVec::new(), zreq, resolver, pool, zreceiver, Rc::new(Vec::new()), conns, ConnectionOpts { instance_id: Rc::new("".to_string()), buffer_size: 0, timeout: Duration::from_millis(0), rb_tmp: Rc::new(TmpBuffer::new(1)), packet_buf: Rc::new(RefCell::new(Vec::new())), tmp_buf: Rc::new(RefCell::new(Vec::new())), }, ConnectionStreamOpts { messages_max: 0, allow_compression: false, sender, }, shared, None, ); mem::size_of_val(&fut) }; vec![ ("client_req_connection_task".to_string(), req_task_size), ( "client_stream_connection_task".to_string(), stream_task_size, ), ] } } impl Drop for Client { fn drop(&mut self) { for w in self.workers.iter_mut() { w.stop(); } } } #[derive(Debug, Eq, PartialEq)] enum StatusMessage { Started, ReqFinished, StreamFinished, } enum ControlMessage { Stop, Req(zmq::Message), Stream(zmq::Message), } pub struct 
TestClient { _client: Client, thread: Option>, status: channel::Receiver, control: channel::Sender, next_id: Cell, } impl TestClient { pub fn new(workers: usize) -> Self { let zmq_context = Arc::new(zmq::Context::new()); let req_maxconn = 100; let stream_maxconn = 100; let maxconn = req_maxconn + stream_maxconn; let mut zsockman = zhttpsocket::ServerSocketManager::new( Arc::clone(&zmq_context), "test", (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * workers), 100, 100, 100, stream_maxconn, ); zsockman .set_server_req_specs(&[SpecInfo { spec: String::from("inproc://client-test"), bind: true, ipc_file_mode: 0, }]) .unwrap(); let zsockman = Arc::new(zsockman); let client = Client::new( "test", workers, req_maxconn, stream_maxconn, 1024, 1024, 10, Duration::from_secs(5), Duration::from_secs(5), false, &[], zsockman.clone(), 100, ) .unwrap(); zsockman .set_server_stream_specs( &[SpecInfo { spec: String::from("inproc://client-test-out"), bind: true, ipc_file_mode: 0, }], &[SpecInfo { spec: String::from("inproc://client-test-out-stream"), bind: true, ipc_file_mode: 0, }], &[SpecInfo { spec: String::from("inproc://client-test-in"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); let (status_s, status_r) = channel::channel(1000); let (control_s, control_r) = channel::channel(1000); let thread = thread::spawn(move || { Self::run(status_s, control_r, zmq_context); }); // wait for handler thread to start assert_eq!(status_r.recv().unwrap(), StatusMessage::Started); Self { _client: client, thread: Some(thread), status: status_r, control: control_s, next_id: Cell::new(0), } } pub fn do_req(&self, addr: std::net::SocketAddr) { let msg = self.make_req_message(addr).unwrap(); self.control.send(ControlMessage::Req(msg)).unwrap(); } pub fn do_stream_http(&self, addr: std::net::SocketAddr) { let msg = self.make_stream_message(addr, false).unwrap(); self.control.send(ControlMessage::Stream(msg)).unwrap(); } pub fn do_stream_ws(&self, addr: 
std::net::SocketAddr) { let msg = self.make_stream_message(addr, true).unwrap(); self.control.send(ControlMessage::Stream(msg)).unwrap(); } pub fn wait_req(&self) { assert_eq!(self.status.recv().unwrap(), StatusMessage::ReqFinished); } pub fn wait_stream(&self) { assert_eq!(self.status.recv().unwrap(), StatusMessage::StreamFinished); } fn make_req_message(&self, addr: std::net::SocketAddr) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; let mut tmp = [0u8; 1024]; let id = { let id = self.next_id.get(); self.next_id.set(id + 1); let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "{}", id)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"method")?; w.write_string(b"GET")?; let mut tmp = [0u8; 1024]; let uri = { let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "http://{}/path", addr)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"uri")?; w.write_string(uri)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn make_stream_message( &self, addr: std::net::SocketAddr, ws: bool, ) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; let mut tmp = [0u8; 1024]; let id = { let id = self.next_id.get(); self.next_id.set(id + 1); let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "{}", id)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(0)?; let mut tmp = [0u8; 1024]; let uri = if ws { let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "ws://{}/path", addr)?; let pos = 
cursor.position() as usize; &tmp[..pos] } else { w.write_string(b"method")?; w.write_string(b"GET")?; let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "http://{}/path", addr)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"uri")?; w.write_string(uri)?; w.write_string(b"credits")?; w.write_int(1024)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn respond_msg( id: &[u8], seq: u32, ptype: &str, content_type: &str, body: &[u8], code: Option, ) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(seq as isize)?; if ptype.is_empty() { w.write_string(b"content-type")?; w.write_string(content_type.as_bytes())?; } else { w.write_string(b"type")?; w.write_string(ptype.as_bytes())?; } if let Some(x) = code { w.write_string(b"code")?; w.write_int(x as isize)?; } w.write_string(b"body")?; w.write_string(body)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn run( status: channel::Sender, control: channel::Receiver, zmq_context: Arc, ) { let req_sock = zmq_context.socket(zmq::DEALER).unwrap(); req_sock.connect("inproc://client-test").unwrap(); let out_sock = zmq_context.socket(zmq::PUSH).unwrap(); out_sock.connect("inproc://client-test-out").unwrap(); let out_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); out_stream_sock .connect("inproc://client-test-out-stream") .unwrap(); let in_sock = zmq_context.socket(zmq::SUB).unwrap(); in_sock.set_subscribe(b"handler ").unwrap(); in_sock.connect("inproc://client-test-in").unwrap(); // ensure zsockman is subscribed thread::sleep(Duration::from_millis(100)); status.send(StatusMessage::Started).unwrap(); let mut 
poller = event::Poller::new(1).unwrap(); poller .register_custom( control.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&req_sock.get_fd().unwrap()), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&in_sock.get_fd().unwrap()), mio::Token(3), mio::Interest::READABLE, ) .unwrap(); let mut req_events = req_sock.get_events().unwrap(); let mut in_events = in_sock.get_events().unwrap(); 'main: loop { while req_events.contains(zmq::POLLIN) { let parts = match req_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { req_events = req_sock.get_events().unwrap(); break; } Err(e) => panic!("recv error: {:?}", e), }; req_events = req_sock.get_events().unwrap(); assert_eq!(parts.len(), 2); let msg = &parts[1]; assert_eq!(msg[0], b'T'); let mut ptype = ""; let mut code: u16 = 0; let mut reason = ""; let mut body = b"".as_slice(); for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "type" => { let s = tnetstring::parse_string(f.data).unwrap(); ptype = str::from_utf8(s).unwrap(); } "code" => { let x = tnetstring::parse_int(f.data).unwrap(); code = x as u16; } "reason" => { let s = tnetstring::parse_string(f.data).unwrap(); reason = str::from_utf8(s).unwrap(); } "body" => { let s = tnetstring::parse_string(f.data).unwrap(); body = s; } _ => {} } } assert_eq!(ptype, ""); assert_eq!(code, 200); assert_eq!(reason, "OK"); assert_eq!(str::from_utf8(body).unwrap(), "hello\n"); status.send(StatusMessage::ReqFinished).unwrap(); } while in_events.contains(zmq::POLLIN) { let parts = match in_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { in_events = in_sock.get_events().unwrap(); break; } Err(e) => panic!("recv error: {:?}", e), }; in_events = in_sock.get_events().unwrap(); assert_eq!(parts.len(), 1); let buf = &parts[0]; let mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == 
b' ' { pos = Some(i); break; } } let pos = pos.unwrap(); let msg = &buf[(pos + 1)..]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut seq = None; let mut ptype = ""; let mut code = None; let mut reason = ""; let mut content_type = ""; let mut body = &b""[..]; let mut more = false; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "seq" => { let x = tnetstring::parse_int(f.data).unwrap(); seq = Some(x as u32); } "type" => { let s = tnetstring::parse_string(f.data).unwrap(); ptype = str::from_utf8(s).unwrap(); } "code" => { let x = tnetstring::parse_int(f.data).unwrap(); code = Some(x as u16); } "reason" => { let s = tnetstring::parse_string(f.data).unwrap(); reason = str::from_utf8(s).unwrap(); } "content-type" => { let s = tnetstring::parse_string(f.data).unwrap(); content_type = str::from_utf8(s).unwrap(); } "body" => { let s = tnetstring::parse_string(f.data).unwrap(); body = s; } "more" => { let b = tnetstring::parse_bool(f.data).unwrap(); more = b; } _ => {} } } let seq = seq.unwrap() + 1; // as a hack to make the test server stateless, respond to every message // using the received sequence number. 
for messages we don't care about, // respond with keep-alive in order to keep the sequencing going if ptype.is_empty() || ptype == "ping" || ptype == "pong" || ptype == "close" { if ptype.is_empty() && content_type.is_empty() { // assume http/ws accept, or http body if !reason.is_empty() { // http/ws accept let code = code.unwrap(); assert!(code == 200 || code == 101); if code == 200 { assert_eq!(reason, "OK"); assert_eq!(body.len(), 0); assert!(more); } else { // 101 assert_eq!(reason, "Switching Protocols"); assert_eq!(body.len(), 0); assert!(!more); } let msg = Self::respond_msg(id.as_bytes(), seq, "keep-alive", "", b"", None) .unwrap(); out_stream_sock .send_multipart( [ zmq::Message::from(b"test".as_slice()), zmq::Message::new(), msg, ], 0, ) .unwrap(); } else { // http body assert_eq!(str::from_utf8(body).unwrap(), "hello\n"); assert!(!more); status.send(StatusMessage::StreamFinished).unwrap(); } } else { // assume ws message if ptype == "ping" { ptype = "pong"; } // echo let msg = Self::respond_msg(id.as_bytes(), seq, ptype, content_type, body, code) .unwrap(); out_stream_sock .send_multipart( [ zmq::Message::from(b"test".as_slice()), zmq::Message::new(), msg, ], 0, ) .unwrap(); if ptype == "close" { status.send(StatusMessage::StreamFinished).unwrap(); } } } else { let msg = Self::respond_msg(id.as_bytes(), seq, "keep-alive", "", b"", None).unwrap(); out_stream_sock .send_multipart( [ zmq::Message::from(b"test".as_slice()), zmq::Message::new(), msg, ], 0, ) .unwrap(); } } poller.poll(None).unwrap(); for event in poller.iter_events() { match event.token() { mio::Token(1) => { while let Ok(msg) = control.try_recv() { match msg { ControlMessage::Stop => break 'main, ControlMessage::Req(msg) => { req_sock .send_multipart([zmq::Message::new(), msg], 0) .unwrap(); req_events = req_sock.get_events().unwrap(); } ControlMessage::Stream(msg) => out_sock.send(msg, 0).unwrap(), } } } mio::Token(2) => req_events = req_sock.get_events().unwrap(), mio::Token(3) => 
                        in_events = in_sock.get_events().unwrap(),
                    _ => unreachable!(),
                }
            }
        }
    }
}

impl Drop for TestClient {
    fn drop(&mut self) {
        self.control.try_send(ControlMessage::Stop).unwrap();

        let thread = self.thread.take().unwrap();
        thread.join().unwrap();
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::connection::calculate_ws_accept;
    use crate::websocket;
    use std::io::Read;
    use test_log::test;

    fn recv_frame<R: Read>(
        stream: &mut R,
        buf: &mut Vec<u8>,
    ) -> Result<(bool, u8, Vec<u8>), io::Error> {
        loop {
            let fi = match websocket::read_header(buf) {
                Ok(fi) => fi,
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
                    let mut chunk = [0; 1024];
                    let size = stream.read(&mut chunk)?;
                    if size == 0 {
                        return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
                    }

                    buf.extend_from_slice(&chunk[..size]);
                    continue;
                }
                Err(e) => return Err(e),
            };

            while buf.len() < fi.payload_offset + fi.payload_size {
                let mut chunk = [0; 1024];
                let size = stream.read(&mut chunk)?;
                if size == 0 {
                    return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
                }

                buf.extend_from_slice(&chunk[..size]);
            }

            let mut content =
                Vec::from(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]);

            if let Some(mask) = fi.mask {
                websocket::apply_mask(&mut content, mask, 0);
            }

            *buf = buf.split_off(fi.payload_offset + fi.payload_size);

            return Ok((fi.fin, fi.opcode, content));
        }
    }

    #[test]
    fn test_batch() {
        let mut batch = Batch::new(3);

        assert_eq!(batch.capacity(), 3);
        assert_eq!(batch.len(), 0);
        assert_eq!(batch.last_group_ckeys(), &[]);

        assert!(batch.add(b"addr-a", 1).is_ok());
        assert!(batch.add(b"addr-a", 2).is_ok());
        assert!(batch.add(b"addr-b", 3).is_ok());
        assert_eq!(batch.len(), 3);

        assert!(batch.add(b"addr-c", 4).is_err());
        assert_eq!(batch.len(), 3);
        assert_eq!(batch.is_empty(), false);

        let ids = ["id-1", "id-2", "id-3"];

        let group = batch
            .take_group(|ckey| (ids[ckey - 1].as_bytes(), 0))
            .unwrap();
        assert_eq!(group.ids().len(), 2);
        assert_eq!(group.ids()[0].id, b"id-1");
        assert_eq!(group.ids()[0].seq, Some(0));
assert_eq!(group.ids()[1].id, b"id-2"); assert_eq!(group.ids()[1].seq, Some(0)); assert_eq!(group.addr(), b"addr-a"); drop(group); assert_eq!(batch.is_empty(), false); assert_eq!(batch.last_group_ckeys(), &[1, 2]); let group = batch .take_group(|ckey| (ids[ckey - 1].as_bytes(), 0)) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-3"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-b"); drop(group); assert_eq!(batch.is_empty(), true); assert_eq!(batch.last_group_ckeys(), &[3]); assert!(batch .take_group(|ckey| { (ids[ckey - 1].as_bytes(), 0) }) .is_none()); assert_eq!(batch.last_group_ckeys(), &[3]); } #[test] fn test_client() { let client = TestClient::new(1); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); // req client.do_req(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let expected = format!( concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"), addr ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); stream .write( b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n", ) .unwrap(); drop(stream); client.wait_req(); // stream (http) client.do_stream_http(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let expected = format!( concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"), addr ); 
assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); stream .write( b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n", ) .unwrap(); drop(stream); client.wait_stream(); // stream (ws) client.do_stream_ws(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let req_buf = &buf[..req_end]; // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; 32]; let mut req = httparse::Request::new(&mut headers); match req.parse(req_buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: {}\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "\r\n" ), addr, ws_key, ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); buf = buf.split_off(req_end); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "\r\n", ), ws_accept ); stream.write(resp_data.as_bytes()).unwrap(); // send message let mut data = vec![0; 1024]; let body = &b"hello"[..]; let size = websocket::write_header( true, false, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); stream.write(&data[..(size + body.len())]).unwrap(); // recv message let (fin, opcode, 
content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_TEXT); assert_eq!(str::from_utf8(&content).unwrap(), "hello"); } #[test] fn test_ws() { let client = TestClient::new(1); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); client.do_stream_ws(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let req_buf = &buf[..req_end]; // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; 32]; let mut req = httparse::Request::new(&mut headers); match req.parse(req_buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: {}\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "\r\n" ), addr, ws_key, ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); buf = buf.split_off(req_end); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "\r\n", ), ws_accept ); stream.write(resp_data.as_bytes()).unwrap(); // send binary let mut data = vec![0; 1024]; let body = &[1, 2, 3][..]; let size = websocket::write_header( true, false, websocket::OPCODE_BINARY, body.len(), None, &mut data, ) .unwrap(); data[size..(size + 
body.len())].copy_from_slice(body); stream.write(&data[..(size + body.len())]).unwrap(); // recv binary let (fin, opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_BINARY); assert_eq!(content, &[1, 2, 3][..]); buf.clear(); // send ping let mut data = vec![0; 1024]; let body = &b""[..]; let size = websocket::write_header( true, false, websocket::OPCODE_PING, body.len(), None, &mut data, ) .unwrap(); stream.write(&data[..size]).unwrap(); // recv pong let (fin, opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_PONG); assert_eq!(str::from_utf8(&content).unwrap(), ""); buf.clear(); // send close let mut data = vec![0; 1024]; let body = &b"\x03\xf0gone"[..]; let size = websocket::write_header( true, false, websocket::OPCODE_CLOSE, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); stream.write(&data[..(size + body.len())]).unwrap(); // recv close let (fin, opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_CLOSE); assert_eq!(&content, &b"\x03\xf0gone"[..]); // expect tcp close let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); assert_eq!(size, 0); client.wait_stream(); } #[cfg(target_arch = "x86_64")] #[cfg(debug_assertions)] #[test] fn test_task_sizes() { // sizes in debug mode at commit c0e4d161997e5c2880ba3409efe13afa3ec26fd7 const REQ_TASK_SIZE_BASE: usize = 6888; const STREAM_TASK_SIZE_BASE: usize = 12152; // cause tests to fail if sizes grow too much const GROWTH_LIMIT: usize = 1000; const REQ_TASK_SIZE_MAX: usize = REQ_TASK_SIZE_BASE + GROWTH_LIMIT; const STREAM_TASK_SIZE_MAX: usize = STREAM_TASK_SIZE_BASE + GROWTH_LIMIT; let sizes = Client::task_sizes(); assert_eq!(sizes[0].0, "client_req_connection_task"); assert!(sizes[0].1 <= REQ_TASK_SIZE_MAX); assert_eq!(sizes[1].0, 
"client_stream_connection_task"); assert!(sizes[1].1 <= STREAM_TASK_SIZE_MAX); } } condure-1.10.0/src/connection.rs000064400000000000000000011645141046102023000146520ustar 00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Note: Always Be Receiving (ABR) // // Connection handlers are expected to read ZHTTP messages as fast as // possible. If they don't, the whole thread could stall. This is by design, // to limit the number of to-be-processed messages in memory. They either // need to do something immediately with the messages, or discard them. // // Every await point must ensure messages keep getting read/processed, by // doing one of: // // - Directly awaiting a message. // - Awaiting a select that is awaiting a message. // - Wrapping other activity with discard_while(). // - Calling handle_other(), which itself will read messages. // - Awaiting something known to not block. 
#![allow(clippy::collapsible_if)] #![allow(clippy::collapsible_else_if)] use crate::arena; use crate::buffer::{ BaseRingBuffer, Buffer, LimitBufsMut, RefRead, RingBuffer, SliceRingBuffer, TmpBuffer, VECTORED_MAX, }; use crate::future::{ io_split, poll_async, select_2, select_3, select_4, select_option, AsyncLocalReceiver, AsyncLocalSender, AsyncRead, AsyncReadExt, AsyncResolver, AsyncTcpStream, AsyncTlsStream, AsyncWrite, AsyncWriteExt, CancellationToken, ReadHalf, Select2, Select3, Select4, StdWriteWrapper, Timeout, TlsWaker, WriteHalf, }; use crate::http1; use crate::net::SocketAddr; use crate::pool::Pool; use crate::reactor::Reactor; use crate::resolver; use crate::shuffle::random; use crate::tls::{TlsStream, VerifyMode}; use crate::track::{track_future, Track, TrackFlag, TrackedAsyncLocalReceiver, ValueActiveError}; use crate::waker::RefWakerData; use crate::websocket; use crate::zhttppacket; use crate::zmq::MultipartHeader; use crate::{pin, Defer}; use arrayvec::{ArrayString, ArrayVec}; use ipnet::IpNet; use log::{debug, log, warn, Level}; use sha1::{Digest, Sha1}; use std::cell::{Ref, RefCell}; use std::cmp; use std::collections::VecDeque; use std::convert::TryFrom; use std::future::Future; use std::io::{self, Read, Write}; use std::mem; use std::net::IpAddr; use std::pin::Pin; use std::rc::Rc; use std::str; use std::str::FromStr; use std::sync::{mpsc, Arc, Mutex}; use std::task::Context; use std::task::Poll; use std::thread; use std::time::{Duration, Instant}; const URI_SIZE_MAX: usize = 4096; const HEADERS_MAX: usize = 64; const WS_HASH_INPUT_MAX: usize = 256; const WS_KEY_MAX: usize = 24; // base64_encode([16 bytes]) = 24 bytes const WS_ACCEPT_MAX: usize = 28; // base64_encode(sha1_hash) = 28 bytes const REDIRECTS_MAX: usize = 8; const ZHTTP_SESSION_TIMEOUT: Duration = Duration::from_secs(60); const CONNECTION_POOL_TTL: Duration = Duration::from_secs(55); pub trait CidProvider { fn get_new_assigned_cid(&mut self) -> ArrayString<32>; } pub trait Identify { 
fn set_id(&mut self, id: &str); } #[derive(PartialEq)] enum Mode { HttpReq, HttpStream, WebSocket, } fn get_host<'a>(headers: &'a [httparse::Header]) -> &'a str { for h in headers.iter() { if h.name.eq_ignore_ascii_case("Host") { match str::from_utf8(h.value) { Ok(s) => return s, Err(_) => break, } } } "localhost" } fn gen_ws_key() -> ArrayString<WS_KEY_MAX> { let mut nonce = [0; 16]; for b in nonce.iter_mut() { *b = (random() % 256) as u8; } let mut output = [0; WS_KEY_MAX]; let size = base64::encode_config_slice(nonce, base64::STANDARD, &mut output); let output = str::from_utf8(&output[..size]).unwrap(); ArrayString::from_str(output).unwrap() } #[allow(clippy::result_unit_err)] pub fn calculate_ws_accept(key: &[u8]) -> Result<ArrayString<WS_ACCEPT_MAX>, ()> { let input_len = key.len() + websocket::WS_GUID.len(); if input_len > WS_HASH_INPUT_MAX { return Err(()); } let mut input = [0; WS_HASH_INPUT_MAX]; input[..key.len()].copy_from_slice(key); input[key.len()..input_len].copy_from_slice(websocket::WS_GUID.as_bytes()); let input = &input[..input_len]; let mut hasher = Sha1::new(); hasher.update(input); let digest = hasher.finalize(); let mut output = [0; WS_ACCEPT_MAX]; let size = base64::encode_config_slice(digest, base64::STANDARD, &mut output); let output = match str::from_utf8(&output[..size]) { Ok(s) => s, Err(_) => return Err(()), }; Ok(ArrayString::from_str(output).unwrap()) } fn validate_ws_request( req: &http1::Request, ws_version: Option<&[u8]>, ws_key: Option<&[u8]>, ) -> Result<ArrayString<WS_ACCEPT_MAX>, ()> { // a websocket request must not have a body. // some clients send "Content-Length: 0", which we'll allow. // chunked encoding will be rejected.
if req.method == "GET" && (req.body_size == http1::BodySize::NoBody || req.body_size == http1::BodySize::Known(0)) && ws_version == Some(b"13") { if let Some(ws_key) = ws_key { return calculate_ws_accept(ws_key); } } Err(()) } fn validate_ws_response(ws_key: &[u8], ws_accept: Option<&[u8]>) -> Result<(), ()> { if let Some(ws_accept) = ws_accept { if calculate_ws_accept(ws_key)?.as_bytes() == ws_accept { return Ok(()); } } Err(()) } fn gen_mask() -> [u8; 4] { let mut out = [0; 4]; for b in out.iter_mut() { *b = (random() % 256) as u8; } out } fn write_ws_ext_header_value<W: Write>( config: &websocket::PerMessageDeflateConfig, dest: &mut W, ) -> Result<(), io::Error> { write!(dest, "permessage-deflate")?; config.serialize(dest) } #[allow(clippy::too_many_arguments)] fn make_zhttp_request( instance: &str, ids: &[zhttppacket::Id], method: &str, path: &str, headers: &[httparse::Header], body: &[u8], more: bool, mode: Mode, credits: u32, peer_addr: Option<&SocketAddr>, secure: bool, packet_buf: &mut [u8], ) -> Result<zmq::Message, io::Error> { let mut data = zhttppacket::RequestData::new(); data.method = method; let host = get_host(headers); let mut zheaders = [zhttppacket::EMPTY_HEADER; HEADERS_MAX]; let mut zheaders_len = 0; for h in headers.iter() { zheaders[zheaders_len] = zhttppacket::Header { name: h.name, value: h.value, }; zheaders_len += 1; } data.headers = &zheaders[..zheaders_len]; let scheme = match mode { Mode::HttpReq | Mode::HttpStream => { if secure { "https" } else { "http" } } Mode::WebSocket => { if secure { "wss" } else { "ws" } } }; let mut uri = [0; URI_SIZE_MAX]; let mut c = io::Cursor::new(&mut uri[..]); write!(&mut c, "{}://{}{}", scheme, host, path)?; let size = c.position() as usize; data.uri = match str::from_utf8(&uri[..size]) { Ok(s) => s, Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidData)), }; data.body = body; data.more = more; if mode == Mode::HttpStream { data.stream = true; } data.credits = credits; let mut addr = [0; 128]; if let
Some(SocketAddr::Ip(peer_addr)) = peer_addr { let mut c = io::Cursor::new(&mut addr[..]); write!(&mut c, "{}", peer_addr.ip()).unwrap(); let size = c.position() as usize; data.peer_address = str::from_utf8(&addr[..size]).unwrap(); data.peer_port = peer_addr.port(); } let mut zreq = zhttppacket::Request::new_data(instance.as_bytes(), ids, data); zreq.multi = true; let size = zreq.serialize(packet_buf)?; Ok(zmq::Message::from(&packet_buf[..size])) } #[derive(Debug)] enum Error { Io(io::Error), Utf8(str::Utf8Error), Http(http1::Error), WebSocket(websocket::Error), InvalidWebSocketRequest, InvalidWebSocketResponse, CompressionError, BadMessage, HandlerError, HandlerCancel, BufferExceeded, Unusable, BadFrame, BadRequest, TlsError, PolicyViolation, TooManyRedirects, ValueActive, StreamTimeout, SessionTimeout, Stopped, } impl Error { fn to_condition(&self) -> &'static str { match self { Error::Io(e) if e.kind() == io::ErrorKind::ConnectionRefused => { "remote-connection-failed" } Error::Io(e) if e.kind() == io::ErrorKind::TimedOut => "connection-timeout", Error::BadRequest => "bad-request", Error::StreamTimeout => "connection-timeout", Error::TlsError => "tls-error", Error::PolicyViolation => "policy-violation", Error::TooManyRedirects => "too-many-redirects", _ => "undefined-condition", } } } impl From<io::Error> for Error { fn from(e: io::Error) -> Self { Self::Io(e) } } impl From<str::Utf8Error> for Error { fn from(e: str::Utf8Error) -> Self { Self::Utf8(e) } } impl<T> From<mpsc::SendError<T>> for Error { fn from(_e: mpsc::SendError<T>) -> Self { Self::Io(io::Error::from(io::ErrorKind::BrokenPipe)) } } impl<T> From<mpsc::TrySendError<T>> for Error { fn from(e: mpsc::TrySendError<T>) -> Self { let kind = match e { mpsc::TrySendError::Full(_) => io::ErrorKind::WriteZero, mpsc::TrySendError::Disconnected(_) => io::ErrorKind::BrokenPipe, }; Self::Io(io::Error::from(kind)) } } impl From<mpsc::RecvError> for Error { fn from(_e: mpsc::RecvError) -> Self { Self::Io(io::Error::from(io::ErrorKind::UnexpectedEof)) } } impl From<http1::Error> for Error { fn from(e: http1::Error) -> Self {
Self::Http(e) } } impl From<websocket::Error> for Error { fn from(e: websocket::Error) -> Self { Self::WebSocket(e) } } impl From<ValueActiveError> for Error { fn from(_e: ValueActiveError) -> Self { Self::ValueActive } } #[derive(Clone, Copy)] struct MessageItem { mtype: u8, avail: usize, } struct MessageTracker { items: VecDeque<MessageItem>, last_partial: bool, } impl MessageTracker { fn new(max_messages: usize) -> Self { Self { items: VecDeque::with_capacity(max_messages), last_partial: false, } } fn in_progress(&self) -> bool { self.last_partial } fn start(&mut self, mtype: u8) -> Result<(), ()> { if self.last_partial || self.items.len() == self.items.capacity() { return Err(()); } self.items.push_back(MessageItem { mtype, avail: 0 }); self.last_partial = true; Ok(()) } fn extend(&mut self, amt: usize) { assert!(self.last_partial); self.items.back_mut().unwrap().avail += amt; } fn done(&mut self) { self.last_partial = false; } // type, avail, done fn current(&self) -> Option<(u8, usize, bool)> { #[allow(clippy::comparison_chain)] if self.items.len() > 1 { let m = self.items.front().unwrap(); Some((m.mtype, m.avail, true)) } else if self.items.len() == 1 { let m = self.items.front().unwrap(); Some((m.mtype, m.avail, !self.last_partial)) } else { None } } fn consumed(&mut self, amt: usize, done: bool) { assert!(amt <= self.items[0].avail); self.items[0].avail -= amt; if done { assert_eq!(self.items[0].avail, 0); self.items.pop_front().unwrap(); } } } pub struct AddrRef<'a> { s: Ref<'a, Option>>, } impl<'a> AddrRef<'a> { pub fn get(&self) -> Option<&[u8]> { match &*self.s { Some(s) => Some(s.as_slice()), None => None, } } } struct StreamSharedDataInner { to_addr: Option>, out_seq: u32, } pub struct StreamSharedData { inner: RefCell<StreamSharedDataInner>, } #[allow(clippy::new_without_default)] impl StreamSharedData { pub fn new() -> Self { Self { inner: RefCell::new(StreamSharedDataInner { to_addr: None, out_seq: 0, }), } } fn reset(&self) { let s = &mut *self.inner.borrow_mut(); s.to_addr = None; s.out_seq = 0; } fn
set_to_addr(&self, addr: Option>) { let s = &mut *self.inner.borrow_mut(); s.to_addr = addr; } pub fn to_addr(&self) -> AddrRef { AddrRef { s: Ref::map(self.inner.borrow(), |s| &s.to_addr), } } pub fn out_seq(&self) -> u32 { self.inner.borrow().out_seq } pub fn inc_out_seq(&self) { let s = &mut *self.inner.borrow_mut(); s.out_seq += 1; } } fn make_zhttp_req_response( id: Option<&[u8]>, ptype: zhttppacket::ResponsePacket, scratch: &mut [u8], ) -> Result<zmq::Message, io::Error> { let mut ids_mem = [zhttppacket::Id { id: b"", seq: None }]; let ids = if let Some(id) = id { ids_mem[0].id = id; ids_mem.as_slice() } else { &[] }; let zresp = zhttppacket::Response { from: b"", ids, multi: false, ptype, ptype_str: "", }; let size = zresp.serialize(scratch)?; let payload = &scratch[..size]; Ok(zmq::Message::from(payload)) } fn make_zhttp_response( addr: &[u8], zresp: zhttppacket::Response, scratch: &mut [u8], ) -> Result<zmq::Message, io::Error> { let size = zresp.serialize(scratch)?; let payload = &scratch[..size]; let mut v = vec![0; addr.len() + 1 + payload.len()]; v[..addr.len()].copy_from_slice(addr); v[addr.len()] = b' '; let pos = addr.len() + 1; v[pos..(pos + payload.len())].copy_from_slice(payload); // this takes over the vec's memory without copying Ok(zmq::Message::from(v)) } async fn recv_nonzero<R: AsyncRead>(r: &mut R, buf: &mut RingBuffer) -> Result<(), io::Error> { if buf.write_avail() == 0 { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let size = match r.read(buf.write_buf()).await { Ok(size) => size, Err(e) => return Err(e), }; buf.write_commit(size); if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } Ok(()) } struct LimitedRingBuffer<'a> { inner: &'a mut RingBuffer, limit: usize, } impl AsRef<[u8]> for LimitedRingBuffer<'_> { fn as_ref(&self) -> &[u8] { let buf = BaseRingBuffer::read_buf(self.inner); let limit = cmp::min(buf.len(), self.limit); &buf[..limit] } } struct HttpRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf1: &'a mut RingBuffer, buf2: &'a mut RingBuffer, } struct
HttpWrite<'a, W: AsyncWrite> { stream: WriteHalf<'a, W>, } struct RequestHandler<'a, R: AsyncRead, W: AsyncWrite> { r: HttpRead<'a, R>, w: HttpWrite<'a, W>, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestHandler<'a, R, W> { fn new( stream: (ReadHalf<'a, R>, WriteHalf<'a, W>), buf1: &'a mut RingBuffer, buf2: &'a mut RingBuffer, ) -> Self { buf1.align(); buf2.clear(); Self { r: HttpRead { stream: stream.0, buf1, buf2, }, w: HttpWrite { stream: stream.1 }, } } // read from stream into buf, and parse buf as a request header async fn recv_request<'b: 'c, 'c, const N: usize>( mut self, mut scratch: &'b mut http1::ParseScratch, req_mem: &'c mut Option>, ) -> Result<RequestHeader<'a, 'b, 'c, R, W, N>, Error> { let mut protocol = http1::ServerProtocol::new(); assert_eq!(protocol.state(), http1::ServerState::ReceivingRequest); loop { { let hbuf = self.r.buf1.take_inner(); match protocol.recv_request_owned(hbuf, scratch) { http1::ParseStatus::Complete(req) => { assert!([ http1::ServerState::ReceivingBody, http1::ServerState::AwaitingResponse ] .contains(&protocol.state())); *req_mem = Some(req); break Ok(RequestHeader { r: self.r, w: self.w, protocol, req_mem, }); } http1::ParseStatus::Incomplete((), hbuf, ret_scratch) => { // NOTE: after polonius it may not be necessary for // scratch to be returned scratch = ret_scratch; self.r.buf1.set_inner(hbuf); } http1::ParseStatus::Error(e, hbuf, _) => { self.r.buf1.set_inner(hbuf); return Err(e.into()); } } } if let Err(e) = recv_nonzero(&mut self.r.stream, self.r.buf1).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::BufferExceeded); } return Err(e.into()); } } } } struct RequestHeader<'a, 'b, 'c, R: AsyncRead, W: AsyncWrite, const N: usize> { r: HttpRead<'a, R>, w: HttpWrite<'a, W>, protocol: http1::ServerProtocol, req_mem: &'c mut Option>, } impl<'a, 'b, 'c, R: AsyncRead, W: AsyncWrite, const N: usize> RequestHeader<'a, 'b, 'c, R, W, N> { fn request(&self) -> http1::Request { self.req_mem.as_ref().unwrap().get() } async fn start_recv_body(mut
self) -> Result<RequestRecvBody<'a, R, W>, Error> { self.handle_expect().await?; // restore the read ringbuffer self.discard_request(); Ok(self.into_recv_body().0) } async fn start_recv_body_and_keep_header( mut self, ) -> Result<RequestRecvBodyKeepHeader<'a, 'b, 'c, R, W, N>, Error> { self.handle_expect().await?; // we're keeping the request, so put any remaining bytes into buf2 // and swap the inner buffers. those bytes will then become readable // from buf1. we'll plan to give the request's inner buffer to buf2 // after the request is no longer needed let req = self.req_mem.as_ref().unwrap(); self.r.buf2.write_all(req.remaining_bytes())?; self.r.buf1.swap_inner(self.r.buf2); let (recv_body, req_mem) = self.into_recv_body(); Ok(RequestRecvBodyKeepHeader { inner: recv_body, req_mem, }) } fn recv_done(mut self) -> Result<RequestStartResponse<'a, R, W>, Error> { // restore the read ringbuffer self.discard_request(); Ok(RequestStartResponse::new(self.r, self.w, self.protocol)) } // this method requires the request to exist async fn handle_expect(&mut self) -> Result<(), Error> { if !self.request().expect_100 { return Ok(()); } let mut cont = [0; 32]; let cont = { let mut c = io::Cursor::new(&mut cont[..]); if let Err(e) = self.protocol.send_100_continue(&mut c) { return Err(e.into()); } let size = c.position() as usize; &cont[..size] }; let mut left = cont.len(); while left > 0 { let pos = cont.len() - left; let size = match self.w.stream.write(&cont[pos..]).await { Ok(size) => size, Err(e) => return Err(e.into()), }; left -= size; } Ok(()) } // consumes request and gives the inner buffer back to buf1 fn discard_request(&mut self) { let req = self.req_mem.take().unwrap(); let remaining_len = req.remaining_bytes().len(); let inner_buf = req.into_buf(); let hsize = inner_buf.filled_len() - remaining_len; self.r.buf1.set_inner(inner_buf); self.r.buf1.read_commit(hsize); } fn into_recv_body( self, ) -> ( RequestRecvBody<'a, R, W>, &'c mut Option>, ) { ( RequestRecvBody { r: RefCell::new(RecvBodyRead { stream: self.r.stream, buf: self.r.buf1, }), wstream: self.w.stream,
buf2: self.r.buf2, protocol: RefCell::new(self.protocol), }, self.req_mem, ) } } struct RecvBodyRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut RingBuffer, } struct RequestRecvBody<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell<RecvBodyRead<'a, R>>, wstream: WriteHalf<'a, W>, buf2: &'a mut RingBuffer, protocol: RefCell<http1::ServerProtocol>, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestRecvBody<'a, R, W> { fn more(&self) -> bool { self.protocol.borrow().state() == http1::ServerState::ReceivingBody } #[allow(clippy::await_holding_refcell_ref)] async fn add_to_recv_buffer(&self) -> Result<(), Error> { let r = &mut *self.r.borrow_mut(); if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::BufferExceeded); } return Err(e.into()); } Ok(()) } fn try_recv_body(&self, dest: &mut [u8]) -> Option<Result<usize, Error>> { let r = &mut *self.r.borrow_mut(); let protocol = &mut *self.protocol.borrow_mut(); if protocol.state() == http1::ServerState::ReceivingBody { loop { let (size, read_size) = { let mut buf = io::Cursor::new(BaseRingBuffer::read_buf(r.buf)); let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let (size, _) = match protocol.recv_body(&mut buf, dest, &mut headers) { Ok(ret) => ret, Err(e) => return Some(Err(e.into())), }; let read_size = buf.position() as usize; (size, read_size) }; if protocol.state() == http1::ServerState::ReceivingBody && read_size == 0 { if !r.buf.is_readable_contiguous() { r.buf.align(); continue; } return None; } r.buf.read_commit(read_size); return Some(Ok(size)); } } assert_eq!(protocol.state(), http1::ServerState::AwaitingResponse); Some(Ok(0)) } async fn recv_body(&self, dest: &mut [u8]) -> Result<usize, Error> { loop { if let Some(ret) = self.try_recv_body(dest) { return ret; } self.add_to_recv_buffer().await?; } } fn recv_done(self) -> RequestStartResponse<'a, R, W> { let r = self.r.into_inner(); RequestStartResponse::new( HttpRead { stream: r.stream, buf1: r.buf, buf2: self.buf2, }, HttpWrite { stream: self.wstream, },
self.protocol.into_inner(), ) } } struct RequestRecvBodyKeepHeader<'a, 'b, 'c, R: AsyncRead, W: AsyncWrite, const N: usize> { inner: RequestRecvBody<'a, R, W>, req_mem: &'c mut Option>, } impl<'a, 'b, 'c, R: AsyncRead, W: AsyncWrite, const N: usize> RequestRecvBodyKeepHeader<'a, 'b, 'c, R, W, N> { fn request(&self) -> http1::Request { self.req_mem.as_ref().unwrap().get() } async fn recv_body(&self, dest: &mut [u8]) -> Result<usize, Error> { self.inner.recv_body(dest).await } fn recv_done(self) -> RequestStartResponse<'a, R, W> { // the request is no longer needed, so give its inner buffer to buf2 // and clear it let buf = self.req_mem.take().unwrap().into_buf(); self.inner.buf2.set_inner(buf); self.inner.buf2.clear(); self.inner.recv_done() } } struct RequestStartResponse<'a, R: AsyncRead, W: AsyncWrite> { r: HttpRead<'a, R>, w: HttpWrite<'a, W>, protocol: http1::ServerProtocol, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestStartResponse<'a, R, W> { fn new(r: HttpRead<'a, R>, w: HttpWrite<'a, W>, protocol: http1::ServerProtocol) -> Self { Self { r, w, protocol } } async fn fill_recv_buffer(&mut self) -> Error { loop { if let Err(e) = recv_nonzero(&mut self.r.stream, self.r.buf1).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever let () = std::future::pending().await; } return e.into(); } } } fn prepare_response( mut self, code: u16, reason: &str, headers: &[http1::Header<'_>], body_size: http1::BodySize, ) -> Result<RequestSendHeader<'a, R, W>, Error> { self.r.buf2.clear(); let mut hbuf = io::Cursor::new(self.r.buf2.write_buf()); if let Err(e) = self .protocol .send_response(&mut hbuf, code, reason, headers, body_size) { return Err(e.into()); } let size = hbuf.position() as usize; self.r.buf2.write_commit(size); let (stream, buf1, buf2) = ((self.r.stream, self.w.stream), self.r.buf1, self.r.buf2); Ok(RequestSendHeader::new( stream, buf1, buf2, self.protocol, size, )) } } struct SendHeaderRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut RingBuffer,
} struct EarlyBody { overflow: Option<Buffer>, done: bool, } struct RequestSendHeader<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell<SendHeaderRead<'a, R>>, wstream: RefCell<WriteHalf<'a, W>>, wbuf: RefCell<LimitedRingBuffer<'a>>, protocol: http1::ServerProtocol, early_body: RefCell<EarlyBody>, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestSendHeader<'a, R, W> { fn new( stream: (ReadHalf<'a, R>, WriteHalf<'a, W>), buf1: &'a mut RingBuffer, buf2: &'a mut RingBuffer, protocol: http1::ServerProtocol, header_size: usize, ) -> Self { Self { r: RefCell::new(SendHeaderRead { stream: stream.0, buf: buf1, }), wstream: RefCell::new(stream.1), wbuf: RefCell::new(LimitedRingBuffer { inner: buf2, limit: header_size, }), protocol, early_body: RefCell::new(EarlyBody { overflow: None, done: false, }), } } #[allow(clippy::await_holding_refcell_ref)] async fn send_header(&self) -> Result<(), Error> { let mut stream = self.wstream.borrow_mut(); // limit = header bytes left while self.wbuf.borrow().limit > 0 { let size = stream.write_shared(&self.wbuf).await?; let mut wbuf = self.wbuf.borrow_mut(); wbuf.inner.read_commit(size); wbuf.limit -= size; } let mut wbuf = self.wbuf.borrow_mut(); let mut early_body = self.early_body.borrow_mut(); if let Some(overflow) = &mut early_body.overflow { wbuf.inner.write_all(Buffer::read_buf(overflow))?; early_body.overflow = None; } Ok(()) } fn append_body(&self, body: &[u8], more: bool, id: &str) -> Result<(), Error> { let mut wbuf = self.wbuf.borrow_mut(); let mut early_body = self.early_body.borrow_mut(); // limit = header bytes left if wbuf.limit > 0 { // if there are still header bytes in the buffer, then we may // need to overflow into a separate buffer if there's not enough // room let accepted = if early_body.overflow.is_none() { wbuf.inner.write(body)?
} else { 0 }; if accepted < body.len() { debug!( "server-conn {}: overflowing {} bytes", id, body.len() - accepted ); if early_body.overflow.is_none() { // only allow overflowing as much as there are header // bytes left early_body.overflow = Some(Buffer::new(wbuf.limit)); } let overflow = early_body.overflow.as_mut().unwrap(); overflow.write_all(&body[accepted..])?; } } else { // if the header has been fully cleared from the buffer, then // always write directly to the buffer wbuf.inner.write_all(body)?; } early_body.done = !more; Ok(()) } fn send_header_done(self) -> RequestSendBody<'a, R, W> { let r = self.r.into_inner(); let wstream = self.wstream.into_inner(); let wbuf = self.wbuf.into_inner(); let early_body = self.early_body.borrow(); assert_eq!(wbuf.limit, 0); assert!(early_body.overflow.is_none()); let (stream, buf1, buf2) = { ((r.stream, wstream), r.buf, wbuf.inner) }; RequestSendBody { r: RefCell::new(HttpSendBodyRead { stream: stream.0, buf: buf1, }), w: RefCell::new(HttpSendBodyWrite { stream: stream.1, buf: buf2, body_done: early_body.done, }), protocol: RefCell::new(self.protocol), } } } struct HttpSendBodyRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut RingBuffer, } struct HttpSendBodyWrite<'a, W: AsyncWrite> { stream: WriteHalf<'a, W>, buf: &'a mut RingBuffer, body_done: bool, } struct SendBodyFuture<'a, 'b, W: AsyncWrite> { w: &'a RefCell<HttpSendBodyWrite<'b, W>>, protocol: &'a RefCell<http1::ServerProtocol>, } impl<'a, 'b, W: AsyncWrite> Future for SendBodyFuture<'a, 'b, W> { type Output = Result<usize, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let f = &*self; let w = &mut *f.w.borrow_mut(); let stream = &mut w.stream; if !stream.is_writable() { return Poll::Pending; } let protocol = &mut *f.protocol.borrow_mut(); let mut buf_arr = [&b""[..]; VECTORED_MAX - 2]; let bufs = w.buf.get_ref_vectored(&mut buf_arr); match protocol.send_body( &mut StdWriteWrapper::new(Pin::new(&mut w.stream), cx), bufs, w.body_done, None, ) { Ok(size) => Poll::Ready(Ok(size)),
Err(http1::Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e.into())), } } } impl<W: AsyncWrite> Drop for SendBodyFuture<'_, '_, W> { fn drop(&mut self) { self.w.borrow_mut().stream.cancel(); } } struct RequestSendBody<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell<HttpSendBodyRead<'a, R>>, w: RefCell<HttpSendBodyWrite<'a, W>>, protocol: RefCell<http1::ServerProtocol>, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestSendBody<'a, R, W> { fn append_body(&self, body: &[u8], more: bool) -> Result<(), Error> { let w = &mut *self.w.borrow_mut(); w.buf.write_all(body)?; w.body_done = !more; Ok(()) } fn can_flush(&self) -> bool { let w = &*self.w.borrow(); w.buf.read_avail() > 0 || w.body_done } async fn flush_body(&self) -> Result<(usize, bool), Error> { { let protocol = &*self.protocol.borrow(); assert_eq!(protocol.state(), http1::ServerState::SendingBody); let w = &*self.w.borrow(); if w.buf.read_avail() == 0 && !w.body_done { return Ok((0, false)); } } let size = SendBodyFuture { w: &self.w, protocol: &self.protocol, } .await?; let w = &mut *self.w.borrow_mut(); let protocol = &*self.protocol.borrow(); w.buf.read_commit(size); if w.buf.read_avail() > 0 || !w.body_done || protocol.state() == http1::ServerState::SendingBody { return Ok((size, false)); } assert_eq!(protocol.state(), http1::ServerState::Finished); Ok((size, true)) } #[allow(clippy::await_holding_refcell_ref)] async fn send_body(&self, body: &[u8], more: bool) -> Result<usize, Error> { let w = &mut *self.w.borrow_mut(); let protocol = &mut *self.protocol.borrow_mut(); assert_eq!(protocol.state(), http1::ServerState::SendingBody); Ok(protocol .send_body_async(&mut w.stream, &[body], !more, None) .await?)
} #[allow(clippy::await_holding_refcell_ref)] async fn fill_recv_buffer(&self) -> Error { let r = &mut *self.r.borrow_mut(); loop { if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever let () = std::future::pending().await; } return e.into(); } } } fn finish(self) -> bool { self.protocol.borrow().is_persistent() } } struct WebSocketRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut RingBuffer, } struct WebSocketWrite<'a, W: AsyncWrite, M> { stream: WriteHalf<'a, W>, buf: &'a mut BaseRingBuffer<M>, } struct SendMessageContentFuture<'a, 'b, W: AsyncWrite, M> { w: &'a RefCell<WebSocketWrite<'b, W, M>>, protocol: &'a websocket::Protocol<M>, avail: usize, done: bool, } impl<'a, 'b, W: AsyncWrite, M: AsRef<[u8]> + AsMut<[u8]>> Future for SendMessageContentFuture<'a, 'b, W, M> { type Output = Result<(usize, bool), Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let f = &*self; let w = &mut *f.w.borrow_mut(); let stream = &mut w.stream; if !stream.is_writable() { return Poll::Pending; } // protocol.send_message_content may add 1 element to vector let mut buf_arr = mem::MaybeUninit::<[&mut [u8]; VECTORED_MAX - 1]>::uninit(); let mut bufs = w.buf.get_mut_vectored(&mut buf_arr).limit(f.avail); match f.protocol.send_message_content( &mut StdWriteWrapper::new(Pin::new(&mut w.stream), cx), bufs.as_slice(), f.done, ) { Ok(ret) => Poll::Ready(Ok(ret)), Err(websocket::Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e.into())), } } } impl<W: AsyncWrite, M> Drop for SendMessageContentFuture<'_, '_, W, M> { fn drop(&mut self) { self.w.borrow_mut().stream.cancel(); } } struct WebSocketHandler<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell<WebSocketRead<'a, R>>, w: RefCell<WebSocketWrite<'a, W, &'a mut [u8]>>, protocol: websocket::Protocol<&'a mut [u8]>, } impl<'a, R: AsyncRead, W: AsyncWrite> WebSocketHandler<'a, R, W> { fn new( stream: (ReadHalf<'a, R>, WriteHalf<'a, W>), buf1: &'a mut RingBuffer, buf2: &'a mut
SliceRingBuffer<'a>, deflate_config: Option<(bool, SliceRingBuffer<'a>)>, ) -> Self { buf2.clear(); Self { r: RefCell::new(WebSocketRead { stream: stream.0, buf: buf1, }), w: RefCell::new(WebSocketWrite { stream: stream.1, buf: buf2, }), protocol: websocket::Protocol::new(deflate_config), } } fn state(&self) -> websocket::State { self.protocol.state() } #[allow(clippy::await_holding_refcell_ref)] async fn add_to_recv_buffer(&self) -> Result<(), Error> { let r = &mut *self.r.borrow_mut(); if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::BufferExceeded); } return Err(e.into()); } Ok(()) } fn try_recv_message_content( &self, dest: &mut [u8], ) -> Option> { let r = &mut *self.r.borrow_mut(); loop { match self.protocol.recv_message_content(r.buf, dest) { Some(Ok(ret)) => return Some(Ok(ret)), Some(Err(e)) => return Some(Err(e.into())), None => { if !r.buf.is_readable_contiguous() { r.buf.align(); continue; } return None; } } } } fn accept_avail(&self) -> usize { self.w.borrow().buf.write_avail() } fn accept_body(&self, body: &[u8]) -> Result<(), Error> { let w = &mut *self.w.borrow_mut(); w.buf.write_all(body)?; Ok(()) } fn is_sending_message(&self) -> bool { self.protocol.is_sending_message() } fn send_message_start(&self, opcode: u8, mask: Option<[u8; 4]>) { self.protocol.send_message_start(opcode, mask); } async fn send_message_content<F>( &self, avail: usize, done: bool, bytes_sent: &F, ) -> Result<(usize, bool), Error> where F: Fn(), { loop { let (size, done) = SendMessageContentFuture { w: &self.w, protocol: &self.protocol, avail, done, } .await?; let w = &mut *self.w.borrow_mut(); if size == 0 && !done { continue; } w.buf.read_commit(size); bytes_sent(); return Ok((size, done)); } } } struct ZhttpStreamSessionOut<'a> { instance_id: &'a str, id: &'a str, packet_buf: &'a RefCell>, sender_stream: &'a AsyncLocalSender<(ArrayVec, zmq::Message)>, shared: &'a StreamSharedData, } impl<'a>
ZhttpStreamSessionOut<'a> {
    fn new(
        instance_id: &'a str,
        id: &'a str,
        packet_buf: &'a RefCell<Vec<u8>>,
        sender_stream: &'a AsyncLocalSender<(ArrayVec, zmq::Message)>,
        shared: &'a StreamSharedData,
    ) -> Self {
        Self {
            instance_id,
            id,
            packet_buf,
            sender_stream,
            shared,
        }
    }

    async fn check_send(&self) {
        self.sender_stream.check_send().await
    }

    fn cancel_send(&self) {
        self.sender_stream.cancel();
    }

    // this method is non-blocking, in order to increment the sequence number
    // and send the message in one shot, without concurrent activity
    // interfering with the sequencing. to send asynchronously, first await
    // on check_send and then call this method
    fn try_send_msg(&self, zreq: zhttppacket::Request) -> Result<(), Error> {
        let msg = {
            let mut zreq = zreq;

            let ids = [zhttppacket::Id {
                id: self.id.as_bytes(),
                seq: Some(self.shared.out_seq()),
            }];

            zreq.from = self.instance_id.as_bytes();
            zreq.ids = &ids;
            zreq.multi = true;

            let packet_buf = &mut *self.packet_buf.borrow_mut();

            let size = zreq.serialize(packet_buf)?;

            zmq::Message::from(&packet_buf[..size])
        };

        let mut addr = ArrayVec::new();
        if addr
            .try_extend_from_slice(self.shared.to_addr().get().unwrap())
            .is_err()
        {
            return Err(io::Error::from(io::ErrorKind::InvalidInput).into());
        }

        self.sender_stream.try_send((addr, msg))?;

        self.shared.inc_out_seq();

        Ok(())
    }
}

struct ZhttpServerStreamSessionOut<'a> {
    instance_id: &'a str,
    id: &'a [u8],
    packet_buf: &'a RefCell<Vec<u8>>,
    sender: &'a AsyncLocalSender<zmq::Message>,
    shared: &'a StreamSharedData,
}

impl<'a> ZhttpServerStreamSessionOut<'a> {
    fn new(
        instance_id: &'a str,
        id: &'a [u8],
        packet_buf: &'a RefCell<Vec<u8>>,
        sender: &'a AsyncLocalSender<zmq::Message>,
        shared: &'a StreamSharedData,
    ) -> Self {
        Self {
            instance_id,
            id,
            packet_buf,
            sender,
            shared,
        }
    }

    async fn check_send(&self) {
        self.sender.check_send().await
    }

    fn cancel_send(&self) {
        self.sender.cancel();
    }

    // this method is non-blocking, in order to increment the sequence number
    // and send the message in one shot, without concurrent activity
    // interfering with the
    // sequencing. to send asynchronously, first await
    // on check_send and then call this method
    fn try_send_msg(&self, zresp: zhttppacket::Response) -> Result<(), Error> {
        let msg = {
            let mut zresp = zresp;

            let ids = [zhttppacket::Id {
                id: self.id,
                seq: Some(self.shared.out_seq()),
            }];

            zresp.from = self.instance_id.as_bytes();
            zresp.ids = &ids;
            zresp.multi = true;

            let addr = self.shared.to_addr();
            let addr = addr.get().unwrap();

            let packet_buf = &mut *self.packet_buf.borrow_mut();

            make_zhttp_response(addr, zresp, packet_buf)?
        };

        self.sender.try_send(msg)?;

        self.shared.inc_out_seq();

        Ok(())
    }
}

struct ZhttpStreamSessionIn<'a, 'b, R> {
    id: &'a str,
    send_buf_size: usize,
    websocket: bool,
    receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>,
    shared: &'a StreamSharedData,
    msg_read: &'a R,
    next: Option<(Track<'b, arena::Rc>, usize)>,
    seq: u32,
    credits: u32,
    first_data: bool,
}

impl<'a, 'b: 'a, R> ZhttpStreamSessionIn<'a, 'b, R>
where
    R: Fn(),
{
    fn new(
        id: &'a str,
        send_buf_size: usize,
        websocket: bool,
        receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>,
        shared: &'a StreamSharedData,
        msg_read: &'a R,
    ) -> Self {
        Self {
            id,
            send_buf_size,
            websocket,
            receiver,
            shared,
            msg_read,
            next: None,
            seq: 0,
            credits: 0,
            first_data: true,
        }
    }

    fn credits(&self) -> u32 {
        self.credits
    }

    fn subtract_credits(&mut self, amount: u32) {
        self.credits -= amount;
    }

    async fn peek_msg(&mut self) -> Result<&arena::Rc, Error> {
        if self.next.is_none() {
            let (r, id_index) = loop {
                let (r, id_index) = Track::map_first(self.receiver.recv().await?);

                let zresp = r.get().get();

                if zresp.ids[id_index].id != self.id.as_bytes() {
                    // skip messages addressed to old ids
                    continue;
                }

                break (r, id_index);
            };

            let zresp = r.get().get();

            if !zresp.ptype_str.is_empty() {
                debug!(
                    "server-conn {}: handle packet: {}",
                    self.id, zresp.ptype_str
                );
            } else {
                debug!("server-conn {}: handle packet: (data)", self.id);
            }

            if zresp.ids.is_empty() {
                return Err(Error::BadMessage);
            }

            if let Some(seq) =
                zresp.ids[id_index].seq
            {
                if seq != self.seq {
                    debug!(
                        "server-conn {}: bad seq (expected {}, got {}), skipping",
                        self.id, self.seq, seq
                    );
                    return Err(Error::BadMessage);
                }

                self.seq += 1;
            }

            let mut addr = ArrayVec::new();
            if addr.try_extend_from_slice(zresp.from).is_err() {
                return Err(Error::BadMessage);
            }

            self.shared.set_to_addr(Some(addr));

            (self.msg_read)();

            match &zresp.ptype {
                zhttppacket::ResponsePacket::Data(rdata) => {
                    let mut credits = rdata.credits;

                    if self.first_data {
                        self.first_data = false;

                        if self.websocket && credits == 0 {
                            // workaround for pushpin-proxy, which doesn't
                            // send credits on websocket accept
                            credits = self.send_buf_size as u32;
                            debug!(
                                "server-conn {}: no credits in websocket accept, assuming {}",
                                self.id, credits
                            );
                        }
                    }

                    self.credits += credits;
                }
                zhttppacket::ResponsePacket::Error(edata) => {
                    debug!(
                        "server-conn {}: zhttp error condition={}",
                        self.id, edata.condition
                    );
                }
                zhttppacket::ResponsePacket::Credit(cdata) => {
                    self.credits += cdata.credits;
                }
                zhttppacket::ResponsePacket::Ping(pdata) => {
                    self.credits += pdata.credits;
                }
                zhttppacket::ResponsePacket::Pong(pdata) => {
                    self.credits += pdata.credits;
                }
                _ => {}
            }

            self.next = Some((r, id_index));
        }

        Ok(&self.next.as_ref().unwrap().0)
    }

    async fn recv_msg(
        &mut self,
    ) -> Result>, Error> {
        self.peek_msg().await?;

        Ok(self.next.take().unwrap().0)
    }
}

struct ZhttpServerStreamSessionIn<'a, 'b, R> {
    log_id: &'a str,
    id: &'a [u8],
    receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>,
    shared: &'a StreamSharedData,
    msg_read: &'a R,
    next: Option<(Track<'b, arena::Rc>, usize)>,
    seq: u32,
    credits: u32,
}

impl<'a, 'b: 'a, R> ZhttpServerStreamSessionIn<'a, 'b, R>
where
    R: Fn(),
{
    fn new(
        log_id: &'a str,
        id: &'a [u8],
        credits: u32,
        receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>,
        shared: &'a StreamSharedData,
        msg_read: &'a R,
    ) -> Self {
        Self {
            log_id,
            id,
            receiver,
            shared,
            msg_read,
            next: None,
            seq: 1,
            credits,
        }
    }

    fn credits(&self) -> u32 {
        self.credits
    }

    fn
    subtract_credits(&mut self, amount: u32) {
        self.credits -= amount;
    }

    async fn peek_msg(&mut self) -> Result<&arena::Rc, Error> {
        if self.next.is_none() {
            let (r, id_index) = loop {
                let (r, id_index) = Track::map_first(self.receiver.recv().await?);

                let zreq = r.get().get();

                if zreq.ids[id_index].id != self.id {
                    // skip messages addressed to old ids
                    continue;
                }

                break (r, id_index);
            };

            let zreq = r.get().get();

            if !zreq.ptype_str.is_empty() {
                debug!(
                    "client-conn {}: handle packet: {}",
                    self.log_id, zreq.ptype_str
                );
            } else {
                debug!("client-conn {}: handle packet: (data)", self.log_id);
            }

            if zreq.ids.is_empty() {
                return Err(Error::BadMessage);
            }

            if let Some(seq) = zreq.ids[id_index].seq {
                if seq != self.seq {
                    debug!(
                        "client-conn {}: bad seq (expected {}, got {}), skipping",
                        self.log_id, self.seq, seq
                    );
                    return Err(Error::BadMessage);
                }

                self.seq += 1;
            }

            let mut addr = ArrayVec::new();
            if addr.try_extend_from_slice(zreq.from).is_err() {
                return Err(Error::BadMessage);
            }

            self.shared.set_to_addr(Some(addr));

            (self.msg_read)();

            match &zreq.ptype {
                zhttppacket::RequestPacket::Data(rdata) => {
                    self.credits += rdata.credits;
                }
                zhttppacket::RequestPacket::Error(edata) => {
                    debug!(
                        "client-conn {}: zhttp error condition={}",
                        self.log_id, edata.condition
                    );
                }
                zhttppacket::RequestPacket::Credit(cdata) => {
                    self.credits += cdata.credits;
                }
                zhttppacket::RequestPacket::Ping(pdata) => {
                    self.credits += pdata.credits;
                }
                zhttppacket::RequestPacket::Pong(pdata) => {
                    self.credits += pdata.credits;
                }
                _ => {}
            }

            self.next = Some((r, id_index));
        }

        Ok(&self.next.as_ref().unwrap().0)
    }

    async fn recv_msg(&mut self) -> Result>, Error> {
        self.peek_msg().await?;

        Ok(self.next.take().unwrap().0)
    }
}

async fn send_msg(
    sender: &AsyncLocalSender<zmq::Message>,
    msg: zmq::Message,
) -> Result<(), Error> {
    Ok(sender.send(msg).await?)
}

async fn discard_while<F, T>(
    receiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
    fut: F,
) -> F::Output
where
    F: Future<Output = Result<T, Error>> + Unpin,
{
    match select_2(fut, pin!(receiver.recv())).await {
        Select2::R1(v) => v,
        Select2::R2(_) => Err(Error::BadMessage), // unexpected message in current state
    }
}

async fn server_discard_while<F, T>(
    receiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
    fut: F,
) -> F::Output
where
    F: Future<Output = Result<T, Error>> + Unpin,
{
    match select_2(fut, pin!(receiver.recv())).await {
        Select2::R1(v) => v,
        Select2::R2(_) => Err(Error::BadMessage), // unexpected message in current state
    }
}

// return true if persistent
#[allow(clippy::too_many_arguments)]
async fn server_req_handler<S: AsyncRead + AsyncWrite>(
    id: &str,
    stream: &mut S,
    peer_addr: Option<&SocketAddr>,
    secure: bool,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    body_buf: &mut Buffer,
    packet_buf: &RefCell<Vec<u8>>,
    zsender: &AsyncLocalSender<zmq::Message>,
    zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
) -> Result<bool, Error> {
    let stream = RefCell::new(stream);

    let handler = RequestHandler::new(io_split(&stream), buf1, buf2);
    let mut scratch = http1::ParseScratch::<HEADERS_MAX>::new();
    let mut req_mem = None;

    // receive request header

    // ABR: discard_while
    let handler = match discard_while(
        zreceiver,
        pin!(handler.recv_request(&mut scratch, &mut req_mem)),
    )
    .await
    {
        Ok(handler) => handler,
        Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(false),
        Err(e) => return Err(e),
    };

    // log request

    {
        let req = handler.request();
        let host = get_host(req.headers);
        let scheme = if secure { "https" } else { "http" };

        debug!(
            "server-conn {}: request: {} {}://{}{}",
            id, req.method, scheme, host, req.uri
        );
    }

    // receive request body

    // ABR: discard_while
    let handler =
        discard_while(zreceiver, pin!(handler.start_recv_body_and_keep_header())).await?;

    loop {
        // ABR: discard_while
        let size = discard_while(zreceiver, pin!(handler.recv_body(body_buf.write_buf()))).await?;

        if size == 0 {
            break;
        }

        body_buf.write_commit(size);
    }

    // determine how to respond
    let msg = {
        let req = handler.request();

        let mut websocket = false;

        for h in req.headers.iter() {
            if h.name.eq_ignore_ascii_case("Upgrade") && h.value == b"websocket" {
                websocket = true;
                break;
            }
        }

        if websocket {
            // websocket requests are not supported in req mode

            // toss the request body
            body_buf.clear();

            None
        } else {
            // regular http requests we can handle

            // prepare zmq message

            let ids = [zhttppacket::Id {
                id: id.as_bytes(),
                seq: None,
            }];

            let msg = make_zhttp_request(
                "",
                &ids,
                req.method,
                req.uri,
                req.headers,
                Buffer::read_buf(body_buf),
                false,
                Mode::HttpReq,
                0,
                peer_addr,
                secure,
                &mut packet_buf.borrow_mut(),
            )?;

            // body consumed
            body_buf.clear();

            Some(msg)
        }
    };

    let (handler, websocket) = if let Some(msg) = msg {
        // handle as http

        let handler = handler.recv_done();

        // send message

        // ABR: discard_while
        discard_while(zreceiver, pin!(send_msg(zsender, msg))).await?;

        // receive message

        let zresp = loop {
            // ABR: direct read
            let (zresp, id_index) = Track::map_first(zreceiver.recv().await?);

            let zresp_ref = zresp.get().get();

            if zresp_ref.ids[id_index].id != id.as_bytes() {
                // skip messages addressed to old ids
                continue;
            }

            if !zresp_ref.ptype_str.is_empty() {
                debug!("server-conn {}: handle packet: {}", id, zresp_ref.ptype_str);
            } else {
                debug!("server-conn {}: handle packet: (data)", id);
            }

            // skip non-data messages

            match &zresp_ref.ptype {
                zhttppacket::ResponsePacket::Data(_) => break zresp,
                _ => debug!(
                    "server-conn {}: unexpected packet in req mode: {}",
                    id, zresp_ref.ptype_str
                ),
            }
        };

        let handler = {
            let zresp = zresp.get().get();

            let rdata = match &zresp.ptype {
                zhttppacket::ResponsePacket::Data(rdata) => rdata,
                _ => unreachable!(), // we confirmed the type above
            };

            // send response header

            let mut headers = [http1::EMPTY_HEADER; HEADERS_MAX];
            let mut headers_len = 0;

            for h in rdata.headers.iter() {
                if headers_len >= headers.len() {
                    return Err(Error::BadMessage);
                }

                headers[headers_len] = http1::Header {
                    name: h.name,
                    value: h.value,
                };

                headers_len += 1;
            }

            let
                headers = &headers[..headers_len];

            let handler = handler.prepare_response(
                rdata.code,
                rdata.reason,
                headers,
                http1::BodySize::Known(rdata.body.len()),
            )?;

            body_buf.write_all(rdata.body)?;

            handler
        };

        drop(zresp);

        // ABR: discard_while
        discard_while(zreceiver, pin!(handler.send_header())).await?;

        (handler.send_header_done(), false)
    } else {
        // handle as websocket

        // send response header

        let headers = &[http1::Header {
            name: "Content-Type",
            value: b"text/plain",
        }];

        let body = "WebSockets not supported on req mode interface.\n";

        let handler = handler.recv_done();

        let handler = handler.prepare_response(
            400,
            "Bad Request",
            headers,
            http1::BodySize::Known(body.len()),
        )?;

        // ABR: discard_while
        discard_while(zreceiver, pin!(handler.send_header())).await?;

        let handler = handler.send_header_done();

        body_buf.write_all(body.as_bytes())?;

        (handler, true)
    };

    // send response body

    while body_buf.read_avail() > 0 {
        // ABR: discard_while
        let size = discard_while(
            zreceiver,
            pin!(handler.send_body(Buffer::read_buf(body_buf), false)),
        )
        .await?;

        body_buf.read_commit(size);
    }

    let persistent = handler.finish();

    if websocket {
        return Ok(false);
    }

    Ok(persistent)
}

#[allow(clippy::too_many_arguments)]
async fn server_req_connection_inner(
    token: CancellationToken,
    cid: &mut ArrayString<32>,
    cid_provider: &mut P,
    mut stream: S,
    peer_addr: Option<&SocketAddr>,
    secure: bool,
    buffer_size: usize,
    body_buffer_size: usize,
    rb_tmp: &Rc,
    packet_buf: Rc<RefCell<Vec<u8>>>,
    timeout: Duration,
    zsender: AsyncLocalSender<zmq::Message>,
    zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
) -> Result<(), Error> {
    let reactor = Reactor::current().unwrap();

    let mut buf1 = RingBuffer::new(buffer_size, rb_tmp);
    let mut buf2 = RingBuffer::new(buffer_size, rb_tmp);
    let mut body_buf = Buffer::new(body_buffer_size);

    loop {
        stream.set_id(cid);

        // this was originally logged when starting the non-async state
        // machine, so we'll keep doing that
        debug!("server-conn {}: assigning id", cid);

        let reuse = {
            let handler =
                server_req_handler(
                    cid.as_ref(),
                    &mut stream,
                    peer_addr,
                    secure,
                    &mut buf1,
                    &mut buf2,
                    &mut body_buf,
                    &packet_buf,
                    &zsender,
                    zreceiver,
                );

            let timeout = Timeout::new(reactor.now() + timeout);

            match select_3(pin!(handler), timeout.elapsed(), token.cancelled()).await {
                Select3::R1(ret) => ret?,
                Select3::R2(_) => return Err(Error::StreamTimeout),
                Select3::R3(_) => return Err(Error::Stopped),
            }
        };

        if !reuse {
            break;
        }

        // note: buf1 is not cleared as there may be data to read

        buf2.clear();
        body_buf.clear();

        *cid = cid_provider.get_new_assigned_cid();
    }

    // ABR: discard_while
    discard_while(zreceiver, pin!(async { Ok(stream.close().await?) })).await?;

    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub async fn server_req_connection(
    token: CancellationToken,
    mut cid: ArrayString<32>,
    cid_provider: &mut P,
    stream: S,
    peer_addr: Option<&SocketAddr>,
    secure: bool,
    buffer_size: usize,
    body_buffer_size: usize,
    rb_tmp: &Rc,
    packet_buf: Rc<RefCell<Vec<u8>>>,
    timeout: Duration,
    zsender: AsyncLocalSender<zmq::Message>,
    zreceiver: AsyncLocalReceiver<(arena::Rc, usize)>,
) {
    let value_active = TrackFlag::default();

    let zreceiver = TrackedAsyncLocalReceiver::new(zreceiver, &value_active);

    match track_future(
        server_req_connection_inner(
            token,
            &mut cid,
            cid_provider,
            stream,
            peer_addr,
            secure,
            buffer_size,
            body_buffer_size,
            rb_tmp,
            packet_buf,
            timeout,
            zsender,
            &zreceiver,
        ),
        &value_active,
    )
    .await
    {
        Ok(()) => debug!("server-conn {}: finished", cid),
        Err(e) => {
            let level = match e {
                Error::ValueActive => Level::Error,
                _ => Level::Debug,
            };

            log!(level, "server-conn {}: process error: {:?}", cid, e);
        }
    }
}

async fn accept_handoff<R>(
    zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R>,
    zsess_out: &ZhttpStreamSessionOut<'_>,
) -> Result<(), Error>
where
    R: Fn(),
{
    // discarding here is fine.
    // the sender should cease sending
    // messages until we've replied with proceed
    discard_while(
        zsess_in.receiver,
        pin!(async {
            zsess_out.check_send().await;
            Ok(())
        }),
    )
    .await?;

    let zreq = zhttppacket::Request::new_handoff_proceed(b"", &[]);

    // check_send just finished, so this should succeed
    zsess_out.try_send_msg(zreq)?;

    // pause until we get a msg
    zsess_in.peek_msg().await?;

    Ok(())
}

async fn server_accept_handoff<R>(
    zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R>,
    zsess_out: &ZhttpServerStreamSessionOut<'_>,
) -> Result<(), Error>
where
    R: Fn(),
{
    // discarding here is fine. the sender should cease sending
    // messages until we've replied with proceed
    server_discard_while(
        zsess_in.receiver,
        pin!(async {
            zsess_out.check_send().await;
            Ok(())
        }),
    )
    .await?;

    let zresp = zhttppacket::Response::new_handoff_proceed(b"", &[]);

    // check_send just finished, so this should succeed
    zsess_out.try_send_msg(zresp)?;

    // pause until we get a msg
    zsess_in.peek_msg().await?;

    Ok(())
}

// this function will either return immediately or await messages
async fn handle_other<R>(
    zresp: Track<'_, arena::Rc>,
    zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R>,
    zsess_out: &ZhttpStreamSessionOut<'_>,
) -> Result<(), Error>
where
    R: Fn(),
{
    match &zresp.get().get().ptype {
        zhttppacket::ResponsePacket::KeepAlive => Ok(()),
        zhttppacket::ResponsePacket::Credit(_) => Ok(()),
        zhttppacket::ResponsePacket::HandoffStart => {
            drop(zresp);

            accept_handoff(zsess_in, zsess_out).await?;

            Ok(())
        }
        zhttppacket::ResponsePacket::Error(_) => Err(Error::HandlerError),
        zhttppacket::ResponsePacket::Cancel => Err(Error::HandlerCancel),
        _ => Err(Error::BadMessage), // unexpected type
    }
}

// this function will either return immediately or await messages
async fn server_handle_other<R>(
    zreq: Track<'_, arena::Rc>,
    zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R>,
    zsess_out: &ZhttpServerStreamSessionOut<'_>,
) -> Result<(), Error>
where
    R: Fn(),
{
    match &zreq.get().get().ptype {
        zhttppacket::RequestPacket::KeepAlive =>
            Ok(()),
        zhttppacket::RequestPacket::Credit(_) => Ok(()),
        zhttppacket::RequestPacket::HandoffStart => {
            drop(zreq);

            server_accept_handoff(zsess_in, zsess_out).await?;

            Ok(())
        }
        zhttppacket::RequestPacket::Error(_) => Err(Error::HandlerError),
        zhttppacket::RequestPacket::Cancel => Err(Error::HandlerCancel),
        _ => Err(Error::BadMessage), // unexpected type
    }
}

async fn stream_recv_body<'a, 'b, 'c, R1, R2, R, W, const N: usize>(
    tmp_buf: &RefCell<Vec<u8>>,
    bytes_read: &R1,
    handler: RequestHeader<'a, 'b, 'c, R, W, N>,
    zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpStreamSessionOut<'_>,
) -> Result, Error>
where
    R1: Fn(),
    R2: Fn(),
    R: AsyncRead,
    W: AsyncWrite,
{
    let handler = {
        let mut start_recv_body = pin!(handler.start_recv_body());

        // ABR: poll_async doesn't block
        match poll_async(start_recv_body.as_mut()).await {
            Poll::Ready(ret) => ret?,
            Poll::Pending => {
                // if we get here, then the send buffer with the client is full

                // keep trying to process while reading messages
                loop {
                    // ABR: select contains read
                    let ret = select_2(start_recv_body.as_mut(), pin!(zsess_in.recv_msg())).await;

                    match ret {
                        Select2::R1(ret) => break ret?,
                        Select2::R2(ret) => {
                            let zresp = ret?;

                            // note: if we get a data message, handle_other will
                            // error out. technically a data message should be
                            // allowed here, but we're not in a position to do
                            // anything with it, so we error.
                            //
                            // fortunately, the conditions to hit this are unusual:
                            // * we need to receive a subsequent request over
                            //   a persistent connection
                            // * that request needs to be one for which a body
                            //   would be expected, and the request needs to
                            //   include an expect header
                            // * the send buffer to that connection needs to be
                            //   full
                            // * the handler needs to provide an early response
                            //   before receiving the request body
                            //
                            // in other words, a client needs to send a large
                            // pipelined POST over a reused connection, before it
                            // has read the previous response, and the handler
                            // needs to reject the request

                            // ABR: handle_other
                            handle_other(zresp, zsess_in, zsess_out).await?;
                        }
                    }
                }
            }
        }
    };

    {
        let mut check_send = pin!(None);
        let mut add_to_recv_buffer = pin!(None);

        loop {
            if zsess_in.credits() > 0 && add_to_recv_buffer.is_none() && check_send.is_none() {
                check_send.set(Some(zsess_out.check_send()));
            }

            // ABR: select contains read
            let ret = select_3(
                select_option(check_send.as_mut().as_pin_mut()),
                select_option(add_to_recv_buffer.as_mut().as_pin_mut()),
                pin!(zsess_in.peek_msg()),
            )
            .await;

            match ret {
                Select3::R1(()) => {
                    check_send.set(None);

                    let _defer = Defer::new(|| zsess_out.cancel_send());

                    assert!(zsess_in.credits() > 0);
                    assert!(add_to_recv_buffer.is_none());

                    let tmp_buf = &mut *tmp_buf.borrow_mut();
                    let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize);

                    let size = match handler.try_recv_body(&mut tmp_buf[..max_read]) {
                        Some(ret) => ret?,
                        None => {
                            add_to_recv_buffer.set(Some(handler.add_to_recv_buffer()));
                            continue;
                        }
                    };

                    bytes_read();

                    let body = &tmp_buf[..size];

                    zsess_in.subtract_credits(size as u32);

                    let mut rdata = zhttppacket::RequestData::new();
                    rdata.body = body;
                    rdata.more = handler.more();

                    let zreq = zhttppacket::Request::new_data(b"", &[], rdata);

                    // check_send just finished, so this should succeed
                    zsess_out.try_send_msg(zreq)?;

                    if !handler.more() {
                        break;
                    }
                }
                Select3::R2(ret) => {
                    ret?;

                    add_to_recv_buffer.set(None);
                }
                Select3::R3(ret) => {
                    let r =
                        ret?;

                    let zresp_ref = r.get().get();

                    match &zresp_ref.ptype {
                        zhttppacket::ResponsePacket::Data(_) => break,
                        _ => {
                            // ABR: direct read
                            let zresp = zsess_in.recv_msg().await?;

                            // ABR: handle_other
                            handle_other(zresp, zsess_in, zsess_out).await?;
                        }
                    }
                }
            }
        }
    }

    Ok(handler.recv_done())
}

async fn server_stream_recv_body<'a, R1, R2, R>(
    tmp_buf: &RefCell<Vec<u8>>,
    bytes_read: &R1,
    resp_body: ClientResponseBody<'a, R>,
    zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpServerStreamSessionOut<'_>,
) -> Result
where
    R1: Fn(),
    R2: Fn(),
    R: AsyncRead,
{
    let mut check_send = pin!(None);
    let mut add_to_buffer = pin!(None);

    loop {
        if zsess_in.credits() > 0 && add_to_buffer.is_none() && check_send.is_none() {
            check_send.set(Some(zsess_out.check_send()));
        }

        // ABR: select contains read
        let ret = select_3(
            select_option(check_send.as_mut().as_pin_mut()),
            select_option(add_to_buffer.as_mut().as_pin_mut()),
            pin!(zsess_in.recv_msg()),
        )
        .await;

        match ret {
            Select3::R1(()) => {
                check_send.set(None);

                let _defer = Defer::new(|| zsess_out.cancel_send());

                assert!(zsess_in.credits() > 0);
                assert!(add_to_buffer.is_none());

                let tmp_buf = &mut *tmp_buf.borrow_mut();
                let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize);

                let (size, mut finished) = match resp_body.try_recv(&mut tmp_buf[..max_read])?
                {
                    RecvStatus::Complete(finished, written) => (written, Some(finished)),
                    RecvStatus::Read((), written) => {
                        if written == 0 {
                            add_to_buffer.set(Some(resp_body.add_to_buffer()));
                            continue;
                        }

                        (written, None)
                    }
                };

                bytes_read();

                let body = &tmp_buf[..size];

                zsess_in.subtract_credits(size as u32);

                let mut rdata = zhttppacket::ResponseData::new();
                rdata.body = body;
                rdata.more = finished.is_none();

                let zresp = zhttppacket::Response::new_data(b"", &[], rdata);

                // check_send just finished, so this should succeed
                zsess_out.try_send_msg(zresp)?;

                if let Some(finished) = finished.take() {
                    return Ok(finished);
                }
            }
            Select3::R2(ret) => {
                ret?;

                add_to_buffer.set(None);
            }
            Select3::R3(ret) => {
                let zreq = ret?;

                // ABR: handle_other
                server_handle_other(zreq, zsess_in, zsess_out).await?;
            }
        }
    }
}

async fn stream_send_body<'a, R1, R2, R, W>(
    bytes_read: &R1,
    handler: &RequestSendBody<'a, R, W>,
    zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpStreamSessionOut<'_>,
) -> Result<(), Error>
where
    R1: Fn(),
    R2: Fn(),
    R: AsyncRead,
    W: AsyncWrite,
{
    let mut out_credits = 0;

    let mut flush_body = pin!(None);
    let mut check_send = pin!(None);

    'main: loop {
        let ret = {
            if flush_body.is_none() && handler.can_flush() {
                flush_body.set(Some(handler.flush_body()));
            }

            if out_credits > 0 && check_send.is_none() {
                check_send.set(Some(zsess_out.check_send()));
            }

            // ABR: select contains read
            select_4(
                select_option(flush_body.as_mut().as_pin_mut()),
                select_option(check_send.as_mut().as_pin_mut()),
                pin!(zsess_in.recv_msg()),
                pin!(handler.fill_recv_buffer()),
            )
            .await
        };

        match ret {
            Select4::R1(ret) => {
                flush_body.set(None);

                let (size, done) = ret?;

                if done {
                    break;
                }

                out_credits += size as u32;

                if size > 0 {
                    bytes_read();
                }
            }
            Select4::R2(()) => {
                check_send.set(None);

                let zreq = zhttppacket::Request::new_credit(b"", &[], out_credits);
                out_credits = 0;

                // check_send just finished, so this should succeed
                zsess_out.try_send_msg(zreq)?;
            }
            Select4::R3(ret) => {
                let zresp = ret?;

                match
                    &zresp.get().get().ptype
                {
                    zhttppacket::ResponsePacket::Data(rdata) => {
                        handler.append_body(rdata.body, rdata.more)?;
                    }
                    zhttppacket::ResponsePacket::HandoffStart => {
                        drop(zresp);

                        // if handoff requested, flush what we can before accepting
                        // so that the data is not delayed while we wait

                        if flush_body.is_none() && handler.can_flush() {
                            flush_body.set(Some(handler.flush_body()));
                        }

                        while let Some(fut) = flush_body.as_mut().as_pin_mut() {
                            // ABR: poll_async doesn't block
                            let ret = match poll_async(fut).await {
                                Poll::Ready(ret) => ret,
                                Poll::Pending => break,
                            };

                            flush_body.set(None);

                            let (size, done) = ret?;

                            if done {
                                break 'main;
                            }

                            out_credits += size as u32;

                            if size > 0 {
                                bytes_read();
                            }

                            if handler.can_flush() {
                                flush_body.set(Some(handler.flush_body()));
                            }
                        }

                        // ABR: function contains read
                        accept_handoff(zsess_in, zsess_out).await?;
                    }
                    _ => {
                        // ABR: handle_other
                        handle_other(zresp, zsess_in, zsess_out).await?;
                    }
                }
            }
            Select4::R4(e) => return Err(e),
        }
    }

    Ok(())
}

struct Overflow {
    buf: Buffer,
    end: bool,
}

async fn server_stream_send_body<'a, R1, R2, R, W>(
    bytes_read: &R1,
    req_body: ClientRequestBody<'a, R, W>,
    mut overflow: Option<Overflow>,
    recv_buf_size: usize,
    zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpServerStreamSessionOut<'_>,
) -> Result, Error>
where
    R1: Fn(),
    R2: Fn(),
    R: AsyncRead,
    W: AsyncWrite,
{
    // send initial body, including overflow, before offering credits

    let mut send = pin!(None);

    while send.is_some() || req_body.can_send() {
        if send.is_none() {
            send.set(Some(req_body.send()));
        }

        // ABR: select contains read
        let result = select_2(
            select_option(send.as_mut().as_pin_mut()),
            pin!(zsess_in.recv_msg()),
        )
        .await;

        match result {
            Select2::R1(ret) => {
                send.set(None);

                match ret {
                    SendStatus::Complete(resp) => return Ok(resp),
                    SendStatus::EarlyResponse(resp) => return Ok(resp),
                    SendStatus::Partial((), _) => {
                        if !req_body.can_send() {
                            if let Some(overflow) = &mut overflow {
                                let size = req_body.prepare(
                                    overflow.buf.read_buf(),
                                    overflow.end,
                                )?;
                                overflow.buf.read_commit(size);
                            }
                        }
                    }
                    SendStatus::Error((), e) => return Err(e),
                }
            }
            Select2::R2(ret) => {
                let zreq = ret?;

                // ABR: handle_other
                server_handle_other(zreq, zsess_in, zsess_out).await?;
            }
        }
    }

    assert!(!req_body.can_send());

    let mut out_credits = recv_buf_size as u32;

    let mut send = pin!(None);
    let mut check_send = pin!(None);
    let mut prepare_done = false;

    let resp = 'main: loop {
        let ret = {
            if send.is_none() && req_body.can_send() {
                send.set(Some(req_body.send()));
            }

            if !prepare_done && out_credits > 0 && check_send.is_none() {
                check_send.set(Some(zsess_out.check_send()));
            }

            // ABR: select contains read
            select_3(
                select_option(send.as_mut().as_pin_mut()),
                select_option(check_send.as_mut().as_pin_mut()),
                pin!(zsess_in.recv_msg()),
            )
            .await
        };

        match ret {
            Select3::R1(ret) => {
                send.set(None);

                match ret {
                    SendStatus::Complete(resp) => break resp,
                    SendStatus::EarlyResponse(resp) => break resp,
                    SendStatus::Partial((), size) => {
                        out_credits += size as u32;

                        if size > 0 {
                            bytes_read();
                        }
                    }
                    SendStatus::Error(_, e) => return Err(e),
                }
            }
            Select3::R2(()) => {
                check_send.set(None);

                let zresp = zhttppacket::Response::new_credit(b"", &[], out_credits);
                out_credits = 0;

                // check_send just finished, so this should succeed
                zsess_out.try_send_msg(zresp)?;
            }
            Select3::R3(ret) => {
                let zreq = ret?;

                match &zreq.get().get().ptype {
                    zhttppacket::RequestPacket::Data(rdata) => {
                        let size = req_body.prepare(rdata.body, !rdata.more)?;

                        if size < rdata.body.len() {
                            return Err(Error::BufferExceeded);
                        }

                        if !rdata.more {
                            prepare_done = true;
                        }
                    }
                    zhttppacket::RequestPacket::HandoffStart => {
                        drop(zreq);

                        // if handoff requested, flush what we can before accepting
                        // so that the data is not delayed while we wait

                        if send.is_none() && req_body.can_send() {
                            send.set(Some(req_body.send()));
                        }

                        while let Some(fut) = send.as_mut().as_pin_mut() {
                            // ABR: poll_async doesn't block
                            let ret = match poll_async(fut).await {
                                Poll::Ready(ret) => ret,
                                Poll::Pending => break,
                            };
                            send.set(None);

                            match ret {
                                SendStatus::Complete(resp) => break 'main resp,
                                SendStatus::EarlyResponse(resp) => break 'main resp,
                                SendStatus::Partial((), size) => {
                                    out_credits += size as u32;

                                    if size > 0 {
                                        bytes_read();
                                    }
                                }
                                SendStatus::Error((), e) => return Err(e),
                            }

                            if req_body.can_send() {
                                send.set(Some(req_body.send()));
                            }
                        }

                        // ABR: function contains read
                        server_accept_handoff(zsess_in, zsess_out).await?;
                    }
                    _ => {
                        // ABR: handle_other
                        server_handle_other(zreq, zsess_in, zsess_out).await?;
                    }
                }
            }
        }
    };

    Ok(resp)
}

#[allow(clippy::too_many_arguments)]
async fn stream_websocket<S, R1, R2>(
    log_id: &str,
    stream: RefCell<&mut S>,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    messages_max: usize,
    tmp_buf: &RefCell<Vec<u8>>,
    bytes_read: &R1,
    deflate_config: Option<(websocket::PerMessageDeflateConfig, usize)>,
    zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpStreamSessionOut<'_>,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite,
    R1: Fn(),
    R2: Fn(),
{
    // buf2 must be empty since we will repurpose the memory
    assert_eq!(buf2.read_avail(), 0);

    let rb_tmp = buf2.get_tmp().clone();

    let mut wbuf = buf2.take_inner().into_inner();

    let (mut wbuf, deflate_config) = match deflate_config {
        Some((config, write_buf_size)) => {
            let (wbuf, ebuf) = wbuf.split_at_mut(write_buf_size);

            let wbuf = SliceRingBuffer::new(wbuf, &rb_tmp);
            let ebuf = SliceRingBuffer::new(ebuf, &rb_tmp);

            (wbuf, Some((!config.server_no_context_takeover, ebuf)))
        }
        None => (SliceRingBuffer::new(&mut wbuf, &rb_tmp), None),
    };

    let handler = WebSocketHandler::new(io_split(&stream), buf1, &mut wbuf, deflate_config);

    let mut ws_in_tracker = MessageTracker::new(messages_max);

    let mut out_credits = 0;

    let mut check_send = pin!(None);
    let mut add_to_recv_buffer = pin!(None);
    let mut send_content = pin!(None);

    loop {
        let (do_send, do_recv) = match handler.state() {
            websocket::State::Connected => (true, true),
            websocket::State::PeerClosed => (true, false),
            websocket::State::Closing => (false, true),
            websocket::State::Finished
                => break,
        };

        if out_credits > 0
            || (do_recv && zsess_in.credits() > 0 && add_to_recv_buffer.is_none())
                && check_send.is_none()
        {
            check_send.set(Some(zsess_out.check_send()));
        }

        if do_send && send_content.is_none() {
            if let Some((mtype, avail, done)) = ws_in_tracker.current() {
                if !handler.is_sending_message() {
                    handler.send_message_start(mtype, None);
                }

                if avail > 0 || done {
                    send_content.set(Some(handler.send_message_content(avail, done, bytes_read)));
                }
            }
        }

        // ABR: select contains read
        let ret = select_4(
            select_option(check_send.as_mut().as_pin_mut()),
            select_option(add_to_recv_buffer.as_mut().as_pin_mut()),
            select_option(send_content.as_mut().as_pin_mut()),
            pin!(zsess_in.recv_msg()),
        )
        .await;

        match ret {
            Select4::R1(()) => {
                check_send.set(None);

                let _defer = Defer::new(|| zsess_out.cancel_send());

                if out_credits > 0 {
                    let zreq = zhttppacket::Request::new_credit(b"", &[], out_credits);
                    out_credits = 0;

                    // check_send just finished, so this should succeed
                    zsess_out.try_send_msg(zreq)?;
                    continue;
                }

                assert!(zsess_in.credits() > 0);
                assert!(add_to_recv_buffer.is_none());

                let tmp_buf = &mut *tmp_buf.borrow_mut();
                let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize);

                let (opcode, size, end) =
                    match handler.try_recv_message_content(&mut tmp_buf[..max_read]) {
                        Some(ret) => ret?,
                        None => {
                            add_to_recv_buffer.set(Some(handler.add_to_recv_buffer()));
                            continue;
                        }
                    };

                bytes_read();

                let body = &tmp_buf[..size];

                let zreq = match opcode {
                    websocket::OPCODE_TEXT | websocket::OPCODE_BINARY => {
                        if body.is_empty() && !end {
                            // don't bother sending empty message
                            continue;
                        }

                        let mut data = zhttppacket::RequestData::new();

                        data.body = body;

                        data.content_type = if opcode == websocket::OPCODE_TEXT {
                            Some(zhttppacket::ContentType::Text)
                        } else {
                            Some(zhttppacket::ContentType::Binary)
                        };

                        data.more = !end;

                        zhttppacket::Request::new_data(b"", &[], data)
                    }
                    websocket::OPCODE_CLOSE => {
                        let status = if body.len() >= 2 {
                            let mut arr = [0; 2];
                            arr[..].copy_from_slice(&body[..2]);
                            let code = u16::from_be_bytes(arr);

                            let reason = match str::from_utf8(&body[2..]) {
                                Ok(reason) => reason,
                                Err(e) => return Err(e.into()),
                            };

                            Some((code, reason))
                        } else {
                            None
                        };

                        zhttppacket::Request::new_close(b"", &[], status)
                    }
                    websocket::OPCODE_PING => zhttppacket::Request::new_ping(b"", &[], body),
                    websocket::OPCODE_PONG => zhttppacket::Request::new_pong(b"", &[], body),
                    opcode => {
                        debug!(
                            "server-conn {}: unsupported websocket opcode: {}",
                            log_id, opcode
                        );
                        return Err(Error::BadFrame);
                    }
                };

                zsess_in.subtract_credits(size as u32);

                // check_send just finished, so this should succeed
                zsess_out.try_send_msg(zreq)?;
            }
            Select4::R2(ret) => {
                ret?;

                add_to_recv_buffer.set(None);
            }
            Select4::R3(ret) => {
                send_content.set(None);

                let (size, done) = ret?;

                ws_in_tracker.consumed(size, done);

                if handler.state() == websocket::State::Connected
                    || handler.state() == websocket::State::PeerClosed
                {
                    out_credits += size as u32;
                }
            }
            Select4::R4(ret) => {
                let zresp = ret?;

                match &zresp.get().get().ptype {
                    zhttppacket::ResponsePacket::Data(rdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let avail = handler.accept_avail();

                            if let Err(e) = handler.accept_body(rdata.body) {
                                warn!(
                                    "received too much data from handler (size={}, credits={})",
                                    rdata.body.len(),
                                    avail,
                                );
                                return Err(e);
                            }

                            let opcode = match &rdata.content_type {
                                Some(zhttppacket::ContentType::Binary) => websocket::OPCODE_BINARY,
                                _ => websocket::OPCODE_TEXT,
                            };

                            if !ws_in_tracker.in_progress() {
                                if ws_in_tracker.start(opcode).is_err() {
                                    return Err(Error::BufferExceeded);
                                }
                            }

                            ws_in_tracker.extend(rdata.body.len());

                            if !rdata.more {
                                ws_in_tracker.done();
                            }
                        }
                        _ => {}
                    },
                    zhttppacket::ResponsePacket::Close(cdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let (code, reason) = cdata.status.unwrap_or((1000, ""));

                            let arr: [u8; 2] = code.to_be_bytes();

                            // close content isn't limited by credits.
                            // if we don't have space for it, just error out
                            handler.accept_body(&arr)?;
                            handler.accept_body(reason.as_bytes())?;

                            if ws_in_tracker.start(websocket::OPCODE_CLOSE).is_err() {
                                return Err(Error::BadFrame);
                            }

                            ws_in_tracker.extend(arr.len() + reason.len());
                            ws_in_tracker.done();
                        }
                        _ => {}
                    },
                    zhttppacket::ResponsePacket::Ping(pdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let avail = handler.accept_avail();

                            if let Err(e) = handler.accept_body(pdata.body) {
                                warn!(
                                    "received too much data from handler (size={}, credits={})",
                                    pdata.body.len(),
                                    avail,
                                );
                                return Err(e);
                            }

                            if ws_in_tracker.start(websocket::OPCODE_PING).is_err() {
                                return Err(Error::BadFrame);
                            }

                            ws_in_tracker.extend(pdata.body.len());
                            ws_in_tracker.done();
                        }
                        _ => {}
                    },
                    zhttppacket::ResponsePacket::Pong(pdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let avail = handler.accept_avail();

                            if let Err(e) = handler.accept_body(pdata.body) {
                                warn!(
                                    "received too much data from handler (size={}, credits={})",
                                    pdata.body.len(),
                                    avail,
                                );
                                return Err(e);
                            }

                            if ws_in_tracker.start(websocket::OPCODE_PONG).is_err() {
                                return Err(Error::BadFrame);
                            }

                            ws_in_tracker.extend(pdata.body.len());
                            ws_in_tracker.done();
                        }
                        _ => {}
                    },
                    zhttppacket::ResponsePacket::HandoffStart => {
                        drop(zresp);

                        // if handoff requested, flush what we can before accepting
                        // so that the data is not delayed while we wait

                        loop {
                            if send_content.is_none() {
                                if let Some((mtype, avail, done)) = ws_in_tracker.current() {
                                    if !handler.is_sending_message() {
                                        handler.send_message_start(mtype, None);
                                    }

                                    if avail > 0 || done {
                                        send_content.set(Some(
                                            handler.send_message_content(avail, done, bytes_read),
                                        ));
                                    }
                                }
                            }

                            if let Some(fut) = send_content.as_mut().as_pin_mut() {
                                // ABR: poll_async doesn't block
                                let ret = match poll_async(fut).await {
                                    Poll::Ready(ret) => ret,
                                    Poll::Pending => break,
                                };

                                send_content.set(None);

                                let (size, done) = ret?;
                                ws_in_tracker.consumed(
                                    size,
                                    done,
                                );

                                if handler.state() == websocket::State::Connected
                                    || handler.state() == websocket::State::PeerClosed
                                {
                                    out_credits += size as u32;
                                }
                            } else {
                                break;
                            }
                        }

                        // ABR: function contains read
                        accept_handoff(zsess_in, zsess_out).await?;
                    }
                    _ => {
                        // ABR: handle_other
                        handle_other(zresp, zsess_in, zsess_out).await?;
                    }
                }
            }
        }
    }

    Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn server_stream_websocket<S, R1, R2>(
    log_id: &str,
    stream: RefCell<&mut S>,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    messages_max: usize,
    tmp_buf: &RefCell>,
    bytes_read: &R1,
    deflate_config: Option<(websocket::PerMessageDeflateConfig, usize)>,
    zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpServerStreamSessionOut<'_>,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite,
    R1: Fn(),
    R2: Fn(),
{
    // buf2 must be empty since we will repurpose the memory
    assert_eq!(buf2.read_avail(), 0);

    let rb_tmp = buf2.get_tmp().clone();
    let mut wbuf = buf2.take_inner().into_inner();

    let (mut wbuf, deflate_config) = match deflate_config {
        Some((config, write_buf_size)) => {
            let (wbuf, ebuf) = wbuf.split_at_mut(write_buf_size);

            let wbuf = SliceRingBuffer::new(wbuf, &rb_tmp);
            let ebuf = SliceRingBuffer::new(ebuf, &rb_tmp);

            (wbuf, Some((!config.client_no_context_takeover, ebuf)))
        }
        None => (SliceRingBuffer::new(&mut wbuf, &rb_tmp), None),
    };

    let handler = WebSocketHandler::new(io_split(&stream), buf1, &mut wbuf, deflate_config);
    let mut ws_in_tracker = MessageTracker::new(messages_max);

    let mut out_credits = 0;

    let mut check_send = pin!(None);
    let mut add_to_recv_buffer = pin!(None);
    let mut send_content = pin!(None);

    loop {
        let (do_send, do_recv) = match handler.state() {
            websocket::State::Connected => (true, true),
            websocket::State::PeerClosed => (true, false),
            websocket::State::Closing => (false, true),
            websocket::State::Finished => break,
        };

        if out_credits > 0
            || (do_recv && zsess_in.credits() > 0 && add_to_recv_buffer.is_none())
                && check_send.is_none()
        {
            check_send.set(Some(zsess_out.check_send()));
        }

        if do_send && send_content.is_none() {
            if let Some((mtype, avail, done)) = ws_in_tracker.current() {
                if !handler.is_sending_message() {
                    handler.send_message_start(mtype, Some(gen_mask()));
                }

                if avail > 0 || done {
                    send_content.set(Some(handler.send_message_content(avail, done, bytes_read)));
                }
            }
        }

        // ABR: select contains read
        let ret = select_4(
            select_option(check_send.as_mut().as_pin_mut()),
            select_option(add_to_recv_buffer.as_mut().as_pin_mut()),
            select_option(send_content.as_mut().as_pin_mut()),
            pin!(zsess_in.recv_msg()),
        )
        .await;

        match ret {
            Select4::R1(()) => {
                check_send.set(None);

                let _defer = Defer::new(|| zsess_out.cancel_send());

                if out_credits > 0 {
                    let zresp = zhttppacket::Response::new_credit(b"", &[], out_credits);
                    out_credits = 0;

                    // check_send just finished, so this should succeed
                    zsess_out.try_send_msg(zresp)?;
                    continue;
                }

                assert!(zsess_in.credits() > 0);
                assert!(add_to_recv_buffer.is_none());

                let tmp_buf = &mut *tmp_buf.borrow_mut();
                let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize);

                let (opcode, size, end) =
                    match handler.try_recv_message_content(&mut tmp_buf[..max_read]) {
                        Some(ret) => ret?,
                        None => {
                            add_to_recv_buffer.set(Some(handler.add_to_recv_buffer()));
                            continue;
                        }
                    };

                bytes_read();

                let body = &tmp_buf[..size];

                let zresp = match opcode {
                    websocket::OPCODE_TEXT | websocket::OPCODE_BINARY => {
                        if body.is_empty() && !end {
                            // don't bother sending empty message
                            continue;
                        }

                        let mut data = zhttppacket::ResponseData::new();

                        data.body = body;

                        data.content_type = if opcode == websocket::OPCODE_TEXT {
                            Some(zhttppacket::ContentType::Text)
                        } else {
                            Some(zhttppacket::ContentType::Binary)
                        };

                        data.more = !end;

                        zhttppacket::Response::new_data(b"", &[], data)
                    }
                    websocket::OPCODE_CLOSE => {
                        let status = if body.len() >= 2 {
                            let mut arr = [0; 2];
                            arr[..].copy_from_slice(&body[..2]);

                            let code = u16::from_be_bytes(arr);

                            let reason = match str::from_utf8(&body[2..]) {
                                Ok(reason) => reason,
                                Err(e) => return Err(e.into()),
                            };

                            Some((code, reason))
                        } else {
                            None
                        };

                        zhttppacket::Response::new_close(b"", &[], status)
                    }
                    websocket::OPCODE_PING => zhttppacket::Response::new_ping(b"", &[], body),
                    websocket::OPCODE_PONG => zhttppacket::Response::new_pong(b"", &[], body),
                    opcode => {
                        debug!(
                            "client-conn {}: unsupported websocket opcode: {}",
                            log_id, opcode
                        );
                        return Err(Error::BadFrame);
                    }
                };

                zsess_in.subtract_credits(size as u32);

                // check_send just finished, so this should succeed
                zsess_out.try_send_msg(zresp)?;
            }
            Select4::R2(ret) => {
                ret?;

                add_to_recv_buffer.set(None);
            }
            Select4::R3(ret) => {
                send_content.set(None);

                let (size, done) = ret?;

                ws_in_tracker.consumed(size, done);

                if handler.state() == websocket::State::Connected
                    || handler.state() == websocket::State::PeerClosed
                {
                    out_credits += size as u32;
                }
            }
            Select4::R4(ret) => {
                let zreq = ret?;

                match &zreq.get().get().ptype {
                    zhttppacket::RequestPacket::Data(rdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let avail = handler.accept_avail();

                            if let Err(e) = handler.accept_body(rdata.body) {
                                warn!(
                                    "received too much data from handler (size={}, credits={})",
                                    rdata.body.len(),
                                    avail,
                                );
                                return Err(e);
                            }

                            let opcode = match &rdata.content_type {
                                Some(zhttppacket::ContentType::Binary) => websocket::OPCODE_BINARY,
                                _ => websocket::OPCODE_TEXT,
                            };

                            if !ws_in_tracker.in_progress() {
                                if ws_in_tracker.start(opcode).is_err() {
                                    return Err(Error::BufferExceeded);
                                }
                            }

                            ws_in_tracker.extend(rdata.body.len());

                            if !rdata.more {
                                ws_in_tracker.done();
                            }
                        }
                        _ => {}
                    },
                    zhttppacket::RequestPacket::Close(cdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let (code, reason) = cdata.status.unwrap_or((1000, ""));

                            let arr: [u8; 2] = code.to_be_bytes();

                            // close content isn't limited by credits.
                            // if we don't have space for it, just error out
                            handler.accept_body(&arr)?;
                            handler.accept_body(reason.as_bytes())?;

                            if ws_in_tracker.start(websocket::OPCODE_CLOSE).is_err() {
                                return Err(Error::BadFrame);
                            }

                            ws_in_tracker.extend(arr.len() + reason.len());
                            ws_in_tracker.done();
                        }
                        _ => {}
                    },
                    zhttppacket::RequestPacket::Ping(pdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let avail = handler.accept_avail();

                            if let Err(e) = handler.accept_body(pdata.body) {
                                warn!(
                                    "received too much data from handler (size={}, credits={})",
                                    pdata.body.len(),
                                    avail,
                                );
                                return Err(e);
                            }

                            if ws_in_tracker.start(websocket::OPCODE_PING).is_err() {
                                return Err(Error::BadFrame);
                            }

                            ws_in_tracker.extend(pdata.body.len());
                            ws_in_tracker.done();
                        }
                        _ => {}
                    },
                    zhttppacket::RequestPacket::Pong(pdata) => match handler.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            let avail = handler.accept_avail();

                            if let Err(e) = handler.accept_body(pdata.body) {
                                warn!(
                                    "received too much data from handler (size={}, credits={})",
                                    pdata.body.len(),
                                    avail,
                                );
                                return Err(e);
                            }

                            if ws_in_tracker.start(websocket::OPCODE_PONG).is_err() {
                                return Err(Error::BadFrame);
                            }

                            ws_in_tracker.extend(pdata.body.len());
                            ws_in_tracker.done();
                        }
                        _ => {}
                    },
                    zhttppacket::RequestPacket::HandoffStart => {
                        drop(zreq);

                        // if handoff requested, flush what we can before accepting
                        // so that the data is not delayed while we wait

                        loop {
                            if send_content.is_none() {
                                if let Some((mtype, avail, done)) = ws_in_tracker.current() {
                                    if !handler.is_sending_message() {
                                        handler.send_message_start(mtype, Some(gen_mask()));
                                    }

                                    if avail > 0 || done {
                                        send_content.set(Some(
                                            handler.send_message_content(avail, done, bytes_read),
                                        ));
                                    }
                                }
                            }

                            if let Some(fut) = send_content.as_mut().as_pin_mut() {
                                // ABR: poll_async doesn't block
                                let ret = match poll_async(fut).await {
                                    Poll::Ready(ret) => ret,
                                    Poll::Pending => break,
                                };

                                send_content.set(None);

                                let (size, done) = ret?;
                                ws_in_tracker.consumed(size, done);

                                if handler.state() == websocket::State::Connected
                                    || handler.state() == websocket::State::PeerClosed
                                {
                                    out_credits += size as u32;
                                }
                            } else {
                                break;
                            }
                        }

                        // ABR: function contains read
                        server_accept_handoff(zsess_in, zsess_out).await?;
                    }
                    _ => {
                        // ABR: handle_other
                        server_handle_other(zreq, zsess_in, zsess_out).await?;
                    }
                }
            }
        }
    }

    Ok(())
}

// return true if persistent
#[allow(clippy::too_many_arguments)]
async fn server_stream_handler<S, R1, R2>(
    id: &str,
    stream: &mut S,
    peer_addr: Option<&SocketAddr>,
    secure: bool,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    messages_max: usize,
    allow_compression: bool,
    packet_buf: &RefCell>,
    tmp_buf: &RefCell>,
    instance_id: &str,
    zsender: &AsyncLocalSender,
    zsender_stream: &AsyncLocalSender<(ArrayVec, zmq::Message)>,
    zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
    shared: &StreamSharedData,
    refresh_stream_timeout: &R1,
    refresh_session_timeout: &R2,
) -> Result<bool, Error>
where
    S: AsyncRead + AsyncWrite,
    R1: Fn(),
    R2: Fn(),
{
    let stream = RefCell::new(stream);

    let send_buf_size = buf1.capacity(); // for sending to handler
    let recv_buf_size = buf2.capacity(); // for receiving from handler

    let handler = RequestHandler::new(io_split(&stream), buf1, buf2);
    let mut scratch = http1::ParseScratch::::new();
    let mut req_mem = None;

    let zsess_out = ZhttpStreamSessionOut::new(instance_id, id, packet_buf, zsender_stream, shared);

    // receive request header

    // ABR: discard_while
    let handler = match discard_while(
        zreceiver,
        pin!(handler.recv_request(&mut scratch, &mut req_mem)),
    )
    .await
    {
        Ok(handler) => handler,
        Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(false),
        Err(e) => return Err(e),
    };

    refresh_stream_timeout();

    let (body_size, ws_config, msg) = {
        let req = handler.request();

        let mut websocket = false;
        let mut ws_version = None;
        let mut ws_key = None;
        let mut ws_deflate_config = None;

        for h in req.headers.iter() {
            if h.name.eq_ignore_ascii_case("Upgrade") && h.value ==
                b"websocket"
            {
                websocket = true;
            }

            if h.name.eq_ignore_ascii_case("Sec-WebSocket-Version") {
                ws_version = Some(h.value);
            }

            if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") {
                ws_key = Some(h.value);
            }

            if h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") {
                for value in http1::parse_header_value(h.value) {
                    let (name, params) = match value {
                        Ok(v) => v,
                        Err(_) => return Err(Error::InvalidWebSocketRequest),
                    };

                    match name {
                        "permessage-deflate" => {
                            // the client can present multiple offers. take
                            // the first that works. if none work, it's not
                            // an error. we'll just not use compression
                            if allow_compression && ws_deflate_config.is_none() {
                                if let Ok(config) =
                                    websocket::PerMessageDeflateConfig::from_params(params)
                                {
                                    if let Ok(resp_config) = config.create_response() {
                                        // split the original recv buffer memory:
                                        // 75% for a new recv buffer, 25% for an encoded buffer
                                        let recv_buf_size = recv_buf_size * 3 / 4;

                                        ws_deflate_config = Some((resp_config, recv_buf_size));
                                    }
                                }
                            }
                        }
                        name => {
                            debug!("ignoring unsupported websocket extension: {}", name);
                            continue;
                        }
                    }
                }
            }
        }

        // log request

        let host = get_host(req.headers);

        let scheme = if websocket {
            if secure {
                "wss"
            } else {
                "ws"
            }
        } else {
            if secure {
                "https"
            } else {
                "http"
            }
        };

        debug!(
            "server-conn {}: request: {} {}://{}{}",
            id, req.method, scheme, host, req.uri
        );

        let ws_config: Option<(
            ArrayString,
            Option<(websocket::PerMessageDeflateConfig, usize)>,
        )> = if websocket {
            let accept = match validate_ws_request(&req, ws_version, ws_key) {
                Ok(s) => s,
                Err(_) => return Err(Error::InvalidWebSocketRequest),
            };

            Some((accept, ws_deflate_config))
        } else {
            None
        };

        let ids = [zhttppacket::Id {
            id: id.as_bytes(),
            seq: Some(shared.out_seq()),
        }];

        let (mode, more) = if websocket {
            (Mode::WebSocket, false)
        } else {
            let more = match req.body_size {
                http1::BodySize::NoBody => false,
                http1::BodySize::Known(x) => x > 0,
                http1::BodySize::Unknown => true,
            };

            (Mode::HttpStream, more)
        };

        let credits = if let Some((_, Some((_, recv_buf_size))))
            = &ws_config
        {
            *recv_buf_size
        } else {
            recv_buf_size
        };

        let msg = make_zhttp_request(
            instance_id,
            &ids,
            req.method,
            req.uri,
            req.headers,
            b"",
            more,
            mode,
            credits as u32,
            peer_addr,
            secure,
            &mut packet_buf.borrow_mut(),
        )?;

        shared.inc_out_seq();

        (req.body_size, ws_config, msg)
    };

    // send request message

    // ABR: discard_while
    discard_while(zreceiver, pin!(send_msg(zsender, msg))).await?;

    let mut zsess_in = ZhttpStreamSessionIn::new(
        id,
        send_buf_size,
        ws_config.is_some(),
        zreceiver,
        shared,
        refresh_session_timeout,
    );

    // receive any message, in order to get a handler address
    // ABR: direct read
    zsess_in.peek_msg().await?;

    let mut handler = if body_size != http1::BodySize::NoBody {
        // receive request body and send to handler

        // ABR: function contains read
        stream_recv_body(
            tmp_buf,
            refresh_stream_timeout,
            handler,
            &mut zsess_in,
            &zsess_out,
        )
        .await?
    } else {
        handler.recv_done()?
    };

    // receive response message

    let zresp = loop {
        // ABR: select contains read
        let ret = select_2(pin!(zsess_in.recv_msg()), pin!(handler.fill_recv_buffer())).await;

        match ret {
            Select2::R1(ret) => {
                let zresp = ret?;

                match zresp.get().get().ptype {
                    zhttppacket::ResponsePacket::Data(_)
                    | zhttppacket::ResponsePacket::Error(_) => break zresp,
                    _ => {
                        // ABR: handle_other
                        handle_other(zresp, &mut zsess_in, &zsess_out).await?;
                    }
                }
            }
            Select2::R2(e) => return Err(e),
        }
    };

    // determine how to respond

    let (handler, ws_config) = {
        let rdata = match &zresp.get().get().ptype {
            zhttppacket::ResponsePacket::Data(rdata) => rdata,
            zhttppacket::ResponsePacket::Error(edata) => {
                if ws_config.is_some() && edata.condition == "rejected" {
                    // send websocket rejection

                    let rdata = edata.rejected_info.as_ref().unwrap();

                    let handler = {
                        let mut headers = [http1::EMPTY_HEADER; HEADERS_MAX];
                        let mut headers_len = 0;

                        for h in rdata.headers.iter() {
                            // don't send these headers
                            if h.name.eq_ignore_ascii_case("Upgrade")
                                || h.name.eq_ignore_ascii_case("Connection")
                                || h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept")
                                ||
                                h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions")
                            {
                                continue;
                            }

                            if headers_len >= headers.len() {
                                return Err(Error::BadMessage);
                            }

                            headers[headers_len] = http1::Header {
                                name: h.name,
                                value: h.value,
                            };

                            headers_len += 1;
                        }

                        let headers = &headers[..headers_len];

                        handler.prepare_response(
                            rdata.code,
                            rdata.reason,
                            headers,
                            http1::BodySize::Known(rdata.body.len()),
                        )?
                    };

                    // ABR: discard_while
                    discard_while(zreceiver, pin!(handler.send_header())).await?;

                    let handler = handler.send_header_done();

                    handler.append_body(rdata.body, false)?;

                    drop(zresp);

                    loop {
                        // ABR: discard_while
                        let (_, done) = discard_while(zreceiver, pin!(handler.flush_body())).await?;

                        if done {
                            break;
                        }
                    }

                    return Ok(false);
                } else {
                    // ABR: handle_other
                    return Err(handle_other(zresp, &mut zsess_in, &zsess_out)
                        .await
                        .unwrap_err());
                }
            }
            _ => unreachable!(), // we confirmed the type above
        };

        // send response header

        let handler = {
            let mut headers = [http1::EMPTY_HEADER; HEADERS_MAX];
            let mut headers_len = 0;

            let mut body_size = http1::BodySize::Unknown;

            for h in rdata.headers.iter() {
                if ws_config.is_some() {
                    // don't send these headers
                    if h.name.eq_ignore_ascii_case("Upgrade")
                        || h.name.eq_ignore_ascii_case("Connection")
                        || h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept")
                        || h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions")
                    {
                        continue;
                    }
                } else {
                    if h.name.eq_ignore_ascii_case("Content-Length") {
                        let s = str::from_utf8(h.value)?;

                        let clen: usize = match s.parse() {
                            Ok(clen) => clen,
                            Err(_) => {
                                return Err(io::Error::from(io::ErrorKind::InvalidInput).into())
                            }
                        };

                        body_size = http1::BodySize::Known(clen);
                    }
                }

                if headers_len >= headers.len() {
                    return Err(Error::BadMessage);
                }

                headers[headers_len] = http1::Header {
                    name: h.name,
                    value: h.value,
                };

                headers_len += 1;
            }

            if body_size == http1::BodySize::Unknown && !rdata.more {
                body_size = http1::BodySize::Known(rdata.body.len());
            }

            let mut ws_ext = ArrayVec::::new();

            if let Some(ws_config) = &ws_config {
                let accept_data = &ws_config.0;

                if
                    headers_len + 4 > headers.len()
                {
                    return Err(Error::BadMessage);
                }

                headers[headers_len] = http1::Header {
                    name: "Upgrade",
                    value: b"websocket",
                };
                headers_len += 1;

                headers[headers_len] = http1::Header {
                    name: "Connection",
                    value: b"Upgrade",
                };
                headers_len += 1;

                headers[headers_len] = http1::Header {
                    name: "Sec-WebSocket-Accept",
                    value: accept_data.as_bytes(),
                };
                headers_len += 1;

                if let Some((config, _)) = &ws_config.1 {
                    if write_ws_ext_header_value(config, &mut ws_ext).is_err() {
                        return Err(Error::CompressionError);
                    }

                    headers[headers_len] = http1::Header {
                        name: "Sec-WebSocket-Extensions",
                        value: ws_ext.as_ref(),
                    };
                    headers_len += 1;
                }
            }

            let headers = &headers[..headers_len];

            handler.prepare_response(rdata.code, rdata.reason, headers, body_size)?
        };

        handler.append_body(rdata.body, rdata.more, id)?;

        drop(zresp);

        {
            let mut send_header = pin!(handler.send_header());

            loop {
                // ABR: select contains read
                let ret = select_2(send_header.as_mut(), pin!(zsess_in.recv_msg())).await;

                match ret {
                    Select2::R1(ret) => {
                        ret?;

                        break;
                    }
                    Select2::R2(ret) => {
                        let zresp = ret?;

                        match &zresp.get().get().ptype {
                            zhttppacket::ResponsePacket::Data(rdata) => {
                                handler.append_body(rdata.body, rdata.more, id)?;
                            }
                            _ => {
                                // ABR: handle_other
                                handle_other(zresp, &mut zsess_in, &zsess_out).await?;
                            }
                        }
                    }
                }
            }
        }

        let handler = handler.send_header_done();

        refresh_stream_timeout();

        let ws_config = if let Some((_, ws_deflate_config)) = ws_config {
            Some(ws_deflate_config)
        } else {
            None
        };

        (handler, ws_config)
    };

    if let Some(deflate_config) = ws_config {
        // reduce size of future
        #[allow(clippy::drop_non_drop)]
        drop(handler);

        // handle as websocket connection

        // ABR: function contains read
        stream_websocket(
            id,
            stream,
            buf1,
            buf2,
            messages_max,
            tmp_buf,
            refresh_stream_timeout,
            deflate_config,
            &mut zsess_in,
            &zsess_out,
        )
        .await?;

        Ok(false)
    } else {
        // send response body

        // ABR: function contains read
        stream_send_body(refresh_stream_timeout, &handler, &mut zsess_in, &zsess_out).await?;

        let persistent
            = handler.finish();

        Ok(persistent)
    }
}

#[allow(clippy::too_many_arguments)]
async fn server_stream_connection_inner(
    token: CancellationToken,
    cid: &mut ArrayString<32>,
    cid_provider: &mut P,
    mut stream: S,
    peer_addr: Option<&SocketAddr>,
    secure: bool,
    buffer_size: usize,
    messages_max: usize,
    rb_tmp: &Rc,
    packet_buf: Rc>>,
    tmp_buf: Rc>>,
    stream_timeout_duration: Duration,
    allow_compression: bool,
    instance_id: &str,
    zsender: AsyncLocalSender,
    zsender_stream: AsyncLocalSender<(ArrayVec, zmq::Message)>,
    zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
    shared: arena::Rc,
) -> Result<(), Error> {
    let reactor = Reactor::current().unwrap();

    let mut buf1 = RingBuffer::new(buffer_size, rb_tmp);
    let mut buf2 = RingBuffer::new(buffer_size, rb_tmp);

    loop {
        stream.set_id(cid);

        // this was originally logged when starting the non-async state
        // machine, so we'll keep doing that
        debug!("server-conn {}: assigning id", cid);

        let reuse = {
            let stream_timeout = Timeout::new(reactor.now() + stream_timeout_duration);
            let session_timeout = Timeout::new(reactor.now() + ZHTTP_SESSION_TIMEOUT);

            let refresh_stream_timeout = || {
                stream_timeout.set_deadline(reactor.now() + stream_timeout_duration);
            };

            let refresh_session_timeout = || {
                session_timeout.set_deadline(reactor.now() + ZHTTP_SESSION_TIMEOUT);
            };

            let handler = pin!(server_stream_handler(
                cid.as_ref(),
                &mut stream,
                peer_addr,
                secure,
                &mut buf1,
                &mut buf2,
                messages_max,
                allow_compression,
                &packet_buf,
                &tmp_buf,
                instance_id,
                &zsender,
                &zsender_stream,
                zreceiver,
                shared.get(),
                &refresh_stream_timeout,
                &refresh_session_timeout,
            ));

            let ret = match select_4(
                handler,
                stream_timeout.elapsed(),
                session_timeout.elapsed(),
                token.cancelled(),
            )
            .await
            {
                Select4::R1(ret) => ret,
                Select4::R2(_) => Err(Error::StreamTimeout),
                Select4::R3(_) => return Err(Error::SessionTimeout),
                Select4::R4(_) => return Err(Error::Stopped),
            };

            match ret {
                Ok(reuse) => reuse,
                Err(e) => {
                    let handler_caused = matches!(
                        &e,
                        Error::BadMessage |
                            Error::HandlerError | Error::HandlerCancel
                    );

                    if !handler_caused {
                        let shared = shared.get();

                        let msg = if let Some(addr) = shared.to_addr().get() {
                            let id = cid.as_ref();

                            let mut zreq = zhttppacket::Request::new_cancel(b"", &[]);

                            let ids = [zhttppacket::Id {
                                id: id.as_bytes(),
                                seq: Some(shared.out_seq()),
                            }];

                            zreq.from = instance_id.as_bytes();
                            zreq.ids = &ids;
                            zreq.multi = true;

                            let packet_buf = &mut *packet_buf.borrow_mut();

                            let size = zreq.serialize(packet_buf)?;

                            let msg = zmq::Message::from(&packet_buf[..size]);

                            let addr = match ArrayVec::try_from(addr) {
                                Ok(v) => v,
                                Err(_) => {
                                    return Err(io::Error::from(io::ErrorKind::InvalidInput).into())
                                }
                            };

                            Some((addr, msg))
                        } else {
                            None
                        };

                        if let Some((addr, msg)) = msg {
                            // best effort
                            let _ = zsender_stream.try_send((addr, msg));

                            shared.inc_out_seq();
                        }
                    }

                    return Err(e);
                }
            }
        };

        if !reuse {
            break;
        }

        // note: buf1 is not cleared as there may be data to read

        buf2.clear();
        shared.get().reset();

        *cid = cid_provider.get_new_assigned_cid();
    }

    // ABR: discard_while
    discard_while(zreceiver, pin!(async {
        Ok(stream.close().await?)
    }))
    .await?;

    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub async fn server_stream_connection(
    token: CancellationToken,
    mut cid: ArrayString<32>,
    cid_provider: &mut P,
    stream: S,
    peer_addr: Option<&SocketAddr>,
    secure: bool,
    buffer_size: usize,
    messages_max: usize,
    rb_tmp: &Rc,
    packet_buf: Rc>>,
    tmp_buf: Rc>>,
    timeout: Duration,
    allow_compression: bool,
    instance_id: &str,
    zsender: AsyncLocalSender,
    zsender_stream: AsyncLocalSender<(ArrayVec, zmq::Message)>,
    zreceiver: AsyncLocalReceiver<(arena::Rc, usize)>,
    shared: arena::Rc,
) {
    let value_active = TrackFlag::default();

    let zreceiver = TrackedAsyncLocalReceiver::new(zreceiver, &value_active);

    match track_future(
        server_stream_connection_inner(
            token,
            &mut cid,
            cid_provider,
            stream,
            peer_addr,
            secure,
            buffer_size,
            messages_max,
            rb_tmp,
            packet_buf,
            tmp_buf,
            timeout,
            allow_compression,
            instance_id,
            zsender,
            zsender_stream,
            &zreceiver,
            shared,
        ),
        &value_active,
    )
    .await
    {
        Ok(()) => debug!("server-conn {}: finished", cid),
        Err(e) => {
            let level = match e {
                Error::ValueActive => Level::Error,
                _ => Level::Debug,
            };

            log!(level, "server-conn {}: process error: {:?}", cid, e);
        }
    }
}

struct AsyncOperation<O, C>
where
    C: FnMut(),
{
    op_fn: O,
    cancel_fn: C,
}

impl<O, C, R> AsyncOperation<O, C>
where
    O: FnMut(&mut Context) -> Option<R>,
    C: FnMut(),
{
    fn new(op_fn: O, cancel_fn: C) -> Self {
        Self { op_fn, cancel_fn }
    }
}

impl<O, C, R> Future for AsyncOperation<O, C>
where
    O: FnMut(&mut Context) -> Option<R> + Unpin,
    C: FnMut() + Unpin,
{
    type Output = R;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let s = Pin::into_inner(self);

        match (s.op_fn)(cx) {
            Some(ret) => Poll::Ready(ret),
            None => Poll::Pending,
        }
    }
}

impl<O, C> Drop for AsyncOperation<O, C>
where
    C: FnMut(),
{
    fn drop(&mut self) {
        (self.cancel_fn)();
    }
}

pub enum SendStatus<T, P, E> {
    Complete(T),
    EarlyResponse(T),
    Partial(P, usize),
    Error(P, E),
}

pub enum RecvStatus<T, C> {
    Read(T, usize),
    Complete(C, usize),
}

struct ClientRequest<'a, R: AsyncRead, W: AsyncWrite> {
    r: ReadHalf<'a, R>,
    w: WriteHalf<'a, W>,
    buf1: &'a mut
        RingBuffer,
    buf2: &'a mut RingBuffer,
}

impl<'a, R: AsyncRead, W: AsyncWrite> ClientRequest<'a, R, W> {
    fn new(
        stream: (ReadHalf<'a, R>, WriteHalf<'a, W>),
        buf1: &'a mut RingBuffer,
        buf2: &'a mut RingBuffer,
    ) -> Self {
        Self {
            r: stream.0,
            w: stream.1,
            buf1,
            buf2,
        }
    }

    #[allow(clippy::too_many_arguments)]
    fn prepare_header(
        self,
        method: &str,
        uri: &str,
        headers: &[http1::Header<'_>],
        body_size: http1::BodySize,
        websocket: bool,
        initial_body: &[u8],
        end: bool,
    ) -> Result<ClientRequestHeader<'a, R, W>, Error> {
        let req = http1::ClientRequest::new();

        let req_body = match req.send_header(self.buf1, method, uri, headers, body_size, websocket)
        {
            Ok(ret) => ret,
            Err(_) => return Err(Error::BufferExceeded),
        };

        if self.buf2.write_all(initial_body).is_err() {
            return Err(Error::BufferExceeded);
        }

        Ok(ClientRequestHeader {
            r: self.r,
            w: self.w,
            buf1: self.buf1,
            buf2: self.buf2,
            req_body,
            end,
        })
    }
}

struct ClientRequestHeader<'a, R: AsyncRead, W: AsyncWrite> {
    r: ReadHalf<'a, R>,
    w: WriteHalf<'a, W>,
    buf1: &'a mut RingBuffer,
    buf2: &'a mut RingBuffer,
    req_body: http1::ClientRequestBody,
    end: bool,
}

impl<'a, R: AsyncRead, W: AsyncWrite> ClientRequestHeader<'a, R, W> {
    async fn send(mut self) -> Result<ClientRequestBody<'a, R, W>, Error> {
        while self.buf1.read_avail() > 0 {
            let size = self.w.write(BaseRingBuffer::read_buf(self.buf1)).await?;
            self.buf1.read_commit(size);
        }

        Ok(ClientRequestBody {
            inner: RefCell::new(Some(ClientRequestBodyInner {
                r: RefCell::new(ClientRequestBodyRead {
                    stream: self.r,
                    buf: self.buf1,
                }),
                w: RefCell::new(ClientRequestBodyWrite {
                    stream: self.w,
                    buf: self.buf2,
                    req_body: Some(self.req_body),
                    end: self.end,
                }),
            })),
        })
    }
}

struct ClientRequestBodyRead<'a, R: AsyncRead> {
    stream: ReadHalf<'a, R>,
    buf: &'a mut RingBuffer,
}

struct ClientRequestBodyWrite<'a, W: AsyncWrite> {
    stream: WriteHalf<'a, W>,
    buf: &'a mut RingBuffer,
    req_body: Option<http1::ClientRequestBody>,
    end: bool,
}

struct ClientRequestBodyInner<'a, R: AsyncRead, W: AsyncWrite> {
    r: RefCell<ClientRequestBodyRead<'a, R>>,
    w: RefCell<ClientRequestBodyWrite<'a, W>>,
}

struct ClientRequestBody<'a, R: AsyncRead, W: AsyncWrite> {
    inner: RefCell<Option<ClientRequestBodyInner<'a, R, W>>>,
}

impl<'a, R: AsyncRead, W: AsyncWrite> ClientRequestBody<'a, R, W> {
    fn prepare(&self, src: &[u8], end: bool) -> Result<usize, Error> {
        if let Some(inner) = &*self.inner.borrow() {
            let w = &mut *inner.w.borrow_mut();

            // call not allowed if the end has already been indicated
            if w.end {
                return Err(Error::Io(io::Error::from(io::ErrorKind::InvalidInput)));
            }

            let size = w.buf.write(src)?;
            assert!(size <= src.len());

            if size == src.len() && end {
                w.end = true;
            }

            Ok(size)
        } else {
            Err(Error::Unusable)
        }
    }

    fn can_send(&self) -> bool {
        if let Some(inner) = &*self.inner.borrow() {
            let w = &*inner.w.borrow();

            w.buf.read_avail() > 0 || w.end
        } else {
            false
        }
    }

    async fn send(&self) -> SendStatus<ClientResponse<'a, R>, (), Error> {
        if self.inner.borrow().is_none() {
            return SendStatus::Error((), Error::Unusable);
        }

        let status = loop {
            if let Some(inner) = self.take_inner_if_early_response() {
                let r = inner.r.into_inner();
                let w = inner.w.into_inner();

                let resp = w.req_body.unwrap().into_early_response();

                w.buf.clear();

                return SendStatus::EarlyResponse(ClientResponse {
                    r: r.stream,
                    buf1: r.buf,
                    buf2: w.buf,
                    inner: resp,
                });
            }

            match self.process().await {
                Some(Ok(status)) => break status,
                Some(Err(e)) => return SendStatus::Error((), e),
                None => {} // received data.
                // loop and check for early response
            }
        };

        let mut inner = self.inner.borrow_mut();
        assert!(inner.is_some());

        match status {
            http1::SendStatus::Complete(resp, size) => {
                let inner = inner.take().unwrap();

                let r = inner.r.into_inner();
                let w = inner.w.into_inner();

                w.buf.read_commit(size);

                assert_eq!(w.buf.read_avail(), 0);

                SendStatus::Complete(ClientResponse {
                    r: r.stream,
                    buf1: r.buf,
                    buf2: w.buf,
                    inner: resp,
                })
            }
            http1::SendStatus::Partial(req_body, size) => {
                let inner = inner.as_ref().unwrap();

                let mut w = inner.w.borrow_mut();

                w.req_body = Some(req_body);
                w.buf.read_commit(size);

                SendStatus::Partial((), size)
            }
            http1::SendStatus::Error(req_body, e) => {
                let inner = inner.as_ref().unwrap();

                inner.w.borrow_mut().req_body = Some(req_body);

                SendStatus::Error((), e.into())
            }
        }
    }

    // assumes self.inner is Some
    #[allow(clippy::await_holding_refcell_ref)]
    async fn process(
        &self,
    ) -> Option<
        Result<
            http1::SendStatus,
            Error,
        >,
    > {
        let inner = self.inner.borrow();
        let inner = inner.as_ref().unwrap();

        let mut r = inner.r.borrow_mut();

        let result = select_2(
            AsyncOperation::new(
                |cx| {
                    let w = &mut *inner.w.borrow_mut();

                    if !w.stream.is_writable() {
                        return None;
                    }

                    let req_body = w.req_body.take().unwrap();

                    let mut buf_arr = [&b""[..]; VECTORED_MAX - 2];
                    let bufs = w.buf.get_ref_vectored(&mut buf_arr);

                    match req_body.send(
                        &mut StdWriteWrapper::new(Pin::new(&mut w.stream), cx),
                        bufs,
                        w.end,
                        None,
                    ) {
                        http1::SendStatus::Error(req_body, http1::Error::Io(e))
                            if e.kind() == io::ErrorKind::WouldBlock =>
                        {
                            w.req_body = Some(req_body);

                            None
                        }
                        ret => Some(ret),
                    }
                },
                || inner.w.borrow_mut().stream.cancel(),
            ),
            pin!(async {
                let r = &mut *r;

                if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await {
                    if e.kind() == io::ErrorKind::WriteZero {
                        // if there's no more space, suspend forever
                        let () = std::future::pending().await;
                    }

                    return Err(Error::from(e));
                }

                Ok(())
            }),
        )
        .await;

        match result {
            Select2::R1(ret) => match ret {
                http1::SendStatus::Error(req_body, http1::Error::Io(e))
                    if
                    e.kind() == io::ErrorKind::BrokenPipe =>
                {
                    // if we get an error when trying to send, it could be
                    // due to the server closing the connection after sending
                    // an early response. here we'll check if the server left
                    // us any data to read

                    let w = &mut *inner.w.borrow_mut();

                    w.req_body = Some(req_body);

                    if r.buf.read_avail() == 0 {
                        let r = &mut *r;

                        match recv_nonzero(&mut r.stream, r.buf).await {
                            Ok(()) => None, // received data
                            Err(e) => Some(Err(e.into())), // error while receiving data
                        }
                    } else {
                        None // we already received data
                    }
                }
                ret => Some(Ok(ret)),
            },
            Select2::R2(ret) => match ret {
                Ok(()) => None, // received data
                Err(e) => Some(Err(e)), // error while receiving data
            },
        }
    }

    // assumes self.inner is Some
    fn take_inner_if_early_response(&self) -> Option> {
        let mut inner = self.inner.borrow_mut();
        let inner_mut = inner.as_mut().unwrap();

        if inner_mut.r.borrow().buf.read_avail() > 0 {
            Some(inner.take().unwrap())
        } else {
            None
        }
    }
}

struct ClientResponse<'a, R: AsyncRead> {
    r: ReadHalf<'a, R>,
    buf1: &'a mut RingBuffer,
    buf2: &'a mut RingBuffer,
    inner: http1::ClientResponse,
}

impl<'a, R: AsyncRead> ClientResponse<'a, R> {
    async fn recv_header<'b, const N: usize>(
        mut self,
        mut scratch: &'b mut http1::ParseScratch,
    ) -> Result<
        (
            http1::OwnedResponse<'b, N>,
            ClientResponseBodyKeepHeader<'a, R>,
        ),
        Error,
    > {
        let mut resp = self.inner;

        let (resp, resp_body) = loop {
            {
                let hbuf = self.buf1.take_inner();

                resp = match resp.recv_header(hbuf, scratch) {
                    http1::ParseStatus::Complete(ret) => break ret,
                    http1::ParseStatus::Incomplete(resp, hbuf, ret_scratch) => {
                        // NOTE: after polonius it may not be necessary for
                        // scratch to be returned
                        scratch = ret_scratch;
                        self.buf1.set_inner(hbuf);

                        resp
                    }
                    http1::ParseStatus::Error(e, hbuf, _) => {
                        self.buf1.set_inner(hbuf);

                        return Err(e.into());
                    }
                }
            }

            if !self.buf1.is_readable_contiguous() {
                self.buf1.align();
                continue;
            }

            if let Err(e) = recv_nonzero(&mut self.r, self.buf1).await {
                if e.kind() == io::ErrorKind::WriteZero {
                    return
                        Err(Error::BufferExceeded);
                }

                return Err(e.into());
            }
        };

        // at this point, resp has taken buf1's inner buffer, such that
        // buf1 has no inner buffer

        // put remaining readable bytes in buf2
        self.buf2.write_all(resp.remaining_bytes())?;

        // swap inner buffers, such that buf1 now contains the remaining
        // readable bytes, and buf2 is now the one with no inner buffer
        self.buf1.swap_inner(self.buf2);

        Ok((
            resp,
            ClientResponseBodyKeepHeader {
                inner: ClientResponseBody {
                    inner: RefCell::new(Some(ClientResponseBodyInner {
                        r: self.r,
                        buf1: self.buf1,
                        resp_body,
                    })),
                },
                buf2: RefCell::new(Some(self.buf2)),
            },
        ))
    }
}

struct ClientResponseBodyInner<'a, R: AsyncRead> {
    r: ReadHalf<'a, R>,
    buf1: &'a mut RingBuffer,
    resp_body: http1::ClientResponseBody,
}

struct ClientResponseBody<'a, R: AsyncRead> {
    inner: RefCell>>,
}

impl<'a, R: AsyncRead> ClientResponseBody<'a, R> {
    #[allow(clippy::await_holding_refcell_ref)]
    async fn add_to_buffer(&self) -> Result<(), Error> {
        if let Some(inner) = &mut *self.inner.borrow_mut() {
            if let Err(e) = recv_nonzero(&mut inner.r, inner.buf1).await {
                if e.kind() == io::ErrorKind::WriteZero {
                    return Err(Error::BufferExceeded);
                }

                return Err(e.into());
            }

            Ok(())
        } else {
            Err(Error::Unusable)
        }
    }

    fn try_recv(&self, dest: &mut [u8]) -> Result, Error> {
        loop {
            let mut b_inner = self.inner.borrow_mut();

            if let Some(inner) = b_inner.take() {
                let mut scratch =
                    mem::MaybeUninit::<[httparse::Header; HEADERS_MAX]>::uninit();

                match inner.resp_body.recv(
                    BaseRingBuffer::read_buf(inner.buf1),
                    dest,
                    &mut scratch,
                )?
                {
                    http1::RecvStatus::Complete(finished, read, written) => {
                        inner.buf1.read_commit(read);

                        *b_inner = None;

                        break Ok(RecvStatus::Complete(
                            ClientFinished { inner: finished },
                            written,
                        ));
                    }
                    http1::RecvStatus::Read(resp_body, read, written) => {
                        *b_inner = Some(ClientResponseBodyInner {
                            r: inner.r,
                            buf1: inner.buf1,
                            resp_body,
                        });

                        let inner = b_inner.as_mut().unwrap();

                        if read == 0 && written == 0 {
                            if !inner.buf1.is_readable_contiguous() {
                                inner.buf1.align();
                                continue;
                            }
                        }

                        inner.buf1.read_commit(read);

                        return Ok(RecvStatus::Read((), written));
                    }
                }
            } else {
                return Err(Error::Unusable);
            }
        }
    }
}

struct ClientResponseBodyKeepHeader<'a, R: AsyncRead> {
    inner: ClientResponseBody<'a, R>,
    buf2: RefCell>,
}

impl<'a, R: AsyncRead> ClientResponseBodyKeepHeader<'a, R> {
    fn discard_header(
        self,
        resp: http1::OwnedResponse,
    ) -> Result, Error> {
        if let Some(buf2) = self.buf2.borrow_mut().take() {
            buf2.set_inner(resp.into_buf());
            buf2.clear();

            Ok(self.inner)
        } else {
            Err(Error::Unusable)
        }
    }

    async fn add_to_buffer(&self) -> Result<(), Error> {
        self.inner.add_to_buffer().await
    }

    fn try_recv(
        &self,
        dest: &mut [u8],
    ) -> Result>, Error> {
        if !self.buf2.borrow().is_some() {
            return Err(Error::Unusable);
        }

        match self.inner.try_recv(dest)?
        {
            RecvStatus::Complete(finished, written) => Ok(RecvStatus::Complete(
                ClientFinishedKeepHeader {
                    inner: finished,
                    buf2: self.buf2.borrow_mut().take().unwrap(),
                },
                written,
            )),
            RecvStatus::Read((), written) => Ok(RecvStatus::Read((), written)),
        }
    }
}

struct ClientFinished {
    inner: http1::ClientFinished,
}

struct ClientFinishedKeepHeader<'a> {
    inner: ClientFinished,
    buf2: &'a mut RingBuffer,
}

impl<'a> ClientFinishedKeepHeader<'a> {
    fn discard_header(self, resp: http1::OwnedResponse) -> ClientFinished {
        self.buf2.set_inner(resp.into_buf());
        self.buf2.clear();

        self.inner
    }
}

enum Stream {
    Plain(std::net::TcpStream),
    Tls(TlsStream),
}

impl Read for Stream {
    fn read(&mut self, buf: &mut [u8]) -> Result {
        match self {
            Self::Plain(stream) => stream.read(buf),
            Self::Tls(stream) => stream.read(buf),
        }
    }
}

enum AsyncStream<'a> {
    Plain(AsyncTcpStream),
    Tls(AsyncTlsStream<'a>),
}

impl<'a> AsyncStream<'a> {
    fn into_inner(self) -> Stream {
        match self {
            Self::Plain(stream) => Stream::Plain(stream.into_std()),
            Self::Tls(stream) => Stream::Tls(stream.into_std()),
        }
    }
}

#[derive(Clone, Eq, Hash, PartialEq)]
struct ConnectionPoolKey {
    addr: std::net::SocketAddr,
    tls: bool,
    host: String,
}

impl ConnectionPoolKey {
    fn new(addr: std::net::SocketAddr, tls: bool, host: String) -> Self {
        Self { addr, tls, host }
    }
}

pub struct ConnectionPool {
    inner: Arc>>,
    thread: Option>,
    done: Option>,
}

impl ConnectionPool {
    pub fn new(capacity: usize) -> Self {
        let inner = Arc::new(Mutex::new(Pool::::new(capacity)));

        let (s, r) = mpsc::sync_channel(1);

        let thread = {
            let inner = Arc::clone(&inner);

            thread::Builder::new()
                .name("connection-pool".into())
                .spawn(move || {
                    while let Err(mpsc::RecvTimeoutError::Timeout) =
                        r.recv_timeout(Duration::from_secs(1))
                    {
                        let now = Instant::now();

                        while let Some((key, _)) = inner.lock().unwrap().expire(now) {
                            debug!("closing idle connection to {:?} for {}", key.addr, key.host);
                        }
                    }
                })
                .unwrap()
        };

        Self {
            inner,
            thread: Some(thread),
            done: Some(s),
        }
    }

    #[allow(clippy::result_large_err)]
    fn push(
        &self,
        addr: std::net::SocketAddr,
        tls: bool,
        host: String,
        stream: Stream,
        ttl: Duration,
    ) -> Result<(), Stream> {
        self.inner.lock().unwrap().add(
            ConnectionPoolKey::new(addr, tls, host),
            stream,
            Instant::now() + ttl,
        )
    }

    fn take(&self, addr: std::net::SocketAddr, tls: bool, host: &str) -> Option {
        let key = ConnectionPoolKey::new(addr, tls, host.to_string());

        // take the first connection that returns WouldBlock when attempting a read.
        // anything else is considered an error and the connection is discarded
        while let Some(mut stream) = self.inner.lock().unwrap().take(&key) {
            match stream.read(&mut [0]) {
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Some(stream),
                _ => {}
            }

            debug!(
                "discarding broken connection to {:?} for {}",
                key.addr, key.host
            );
        }

        None
    }
}

impl Drop for ConnectionPool {
    fn drop(&mut self) {
        self.done = None;

        let thread = self.thread.take().unwrap();
        thread.join().unwrap();
    }
}

fn is_allowed(addr: &IpAddr, deny: &[IpNet]) -> bool {
    for net in deny {
        if net.contains(addr) {
            return false;
        }
    }

    true
}

async fn client_connect<'a>(
    log_id: &str,
    rdata: &zhttppacket::RequestData<'_, '_>,
    uri: &url::Url,
    resolver: &resolver::Resolver,
    deny: &[IpNet],
    pool: &ConnectionPool,
    tls_waker_data: &'a RefWakerData,
) -> Result<(std::net::SocketAddr, bool, AsyncStream<'a>), Error> {
    let use_tls = ["https", "wss"].contains(&uri.scheme());

    let uri_host = match uri.host_str() {
        Some(s) => s,
        None => return Err(Error::BadRequest),
    };

    let default_port = if use_tls { 443 } else { 80 };

    let (connect_host, connect_port) = if !rdata.connect_host.is_empty() {
        (rdata.connect_host, rdata.connect_port)
    } else {
        (uri_host, uri.port().unwrap_or(default_port))
    };

    let resolver = AsyncResolver::new(resolver);

    debug!("client-conn {}: resolving: [{}]", log_id, connect_host);

    let resolver_results = resolver.resolve(connect_host).await?;

    let mut addrs = ArrayVec::::new();
    let mut denied = false;
    let mut reuse_stream = None;

    for addr in resolver_results {
        if !is_allowed(&addr, deny) {
            denied = true;
            continue;
        }

        let addr = std::net::SocketAddr::new(addr, connect_port);

        if let Some(stream) = pool.take(addr, use_tls, uri_host) {
            reuse_stream = Some((addr, stream));
            break;
        }

        addrs.push(addr);
    }

    let (peer_addr, mut stream, is_new) = if let Some((peer_addr, stream)) = reuse_stream {
        debug!(
            "client-conn {}: reusing connection to {:?}",
            log_id, peer_addr,
        );

        let stream = match stream {
            Stream::Plain(stream) => AsyncStream::Plain(AsyncTcpStream::from_std(stream)),
            Stream::Tls(stream) => {
                AsyncStream::Tls(AsyncTlsStream::from_std(stream, tls_waker_data))
            }
        };

        (peer_addr, stream, false)
    } else {
        if addrs.is_empty() && denied {
            return Err(Error::PolicyViolation);
        }

        debug!("client-conn {}: connecting to one of {:?}", log_id, addrs);

        let stream = AsyncTcpStream::connect(&addrs).await?;

        let peer_addr = stream.peer_addr()?;

        debug!("client-conn {}: connected to {}", log_id, peer_addr);

        let stream = if use_tls {
            let host = if rdata.trust_connect_host {
                connect_host
            } else {
                uri_host
            };

            let verify_mode = if rdata.ignore_tls_errors {
                VerifyMode::None
            } else {
                VerifyMode::Full
            };

            let stream =
                match AsyncTlsStream::connect(host, stream, verify_mode, tls_waker_data) {
                    Ok(stream) => stream,
                    Err(e) => {
                        debug!("client-conn {}: tls connect error: {}", log_id, e);

                        return Err(Error::TlsError);
                    }
                };

            AsyncStream::Tls(stream)
        } else {
            AsyncStream::Plain(stream)
        };

        (peer_addr, stream, true)
    };

    if let AsyncStream::Tls(stream) = &mut stream {
        if stream.inner().set_id(log_id).is_err() {
            warn!("client-conn {}: log id too long for TlsStream", log_id);
            return Err(Error::BadRequest);
        }

        if is_new {
            if let Err(e) = stream.ensure_handshake().await {
                debug!("client-conn {}: tls handshake error: {:?}", log_id, e);
                return Err(Error::TlsError);
            }
        }
    }

    Ok((peer_addr, use_tls, stream))
}

// return Some if fully valid redirect response, else return None.
fn check_redirect(
    method: &str,
    base_url: &url::Url,
    resp: &http1::Response,
    schemes: &[&str],
) -> Option<(url::Url, bool)> {
    if resp.code >= 300 && resp.code <= 399 {
        let mut location = None;

        for h in resp.headers.iter() {
            if h.name.eq_ignore_ascii_case("Location") {
                location = Some(h.value);
                break;
            }
        }

        // must have location header
        if let Some(s) = location {
            // must be UTF-8
            if let Ok(s) = str::from_utf8(s) {
                // must be valid URL
                if let Ok(url) = base_url.join(s) {
                    // must have an acceptable scheme
                    if schemes.contains(&url.scheme()) {
                        let use_get = resp.code >= 301 && resp.code <= 303 && method == "POST";

                        // all is well!
                        return Some((url, use_get));
                    }
                }
            }
        }
    }

    None
}

enum ClientHandlerDone {
    Complete(T, bool),
    Redirect(bool, url::Url, bool), // rare alloc
}

impl ClientHandlerDone {
    fn is_persistent(&self) -> bool {
        match self {
            ClientHandlerDone::Complete(_, persistent) => *persistent,
            ClientHandlerDone::Redirect(persistent, _, _) => *persistent,
        }
    }
}

// return (_, true) if persistent
#[allow(clippy::too_many_arguments)]
async fn client_req_handler(
    log_id: &str,
    id: Option<&[u8]>,
    stream: &mut S,
    zreq: &zhttppacket::Request<'_, '_, '_>,
    method: &str,
    url: &url::Url,
    include_body: bool,
    follow_redirects: bool,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    body_buf: &mut Buffer,
    packet_buf: &RefCell>,
) -> Result, Error>
where
    S: AsyncRead + AsyncWrite,
{
    let stream = RefCell::new(stream);

    let req = ClientRequest::new(io_split(&stream), buf1, buf2);

    let req_header = {
        let rdata = match &zreq.ptype {
            zhttppacket::RequestPacket::Data(data) => data,
            _ => return Err(Error::BadRequest),
        };

        let host_port = &url[url::Position::BeforeHost..url::Position::AfterPort];

        let mut headers = ArrayVec::::new();

        headers.push(http1::Header {
            name: "Host",
            value: host_port.as_bytes(),
        });

        for h in rdata.headers.iter() {
            if headers.remaining_capacity() == 0 {
                return Err(Error::BadRequest);
            }

            // host comes from the uri
            if h.name.eq_ignore_ascii_case("Host") {
                continue;
            }

            headers.push(http1::Header {
                name: h.name,
                value: h.value,
            });
        }

        let path = &url[url::Position::BeforePath..];

        let body_size = if include_body {
            body_buf.write_all(rdata.body)?;

            http1::BodySize::Known(rdata.body.len())
        } else {
            http1::BodySize::NoBody
        };

        req.prepare_header(method, path, &headers, body_size, false, &[], false)?
    };

    let resp = {
        // send request header

        let req_body = req_header.send().await?;

        // send request body

        loop {
            // fill the buffer as much as possible
            let size = req_body.prepare(Buffer::read_buf(body_buf), true)?;
            body_buf.read_commit(size);

            // send the buffer
            match req_body.send().await {
                SendStatus::Complete(resp) => break resp,
                SendStatus::EarlyResponse(resp) => {
                    body_buf.clear();
                    break resp;
                }
                SendStatus::Partial((), _) => {}
                SendStatus::Error((), e) => return Err(e),
            }
        }
    };

    assert_eq!(body_buf.read_avail(), 0);

    // receive response header

    let mut scratch = http1::ParseScratch::::new();
    let (resp, resp_body) = resp.recv_header(&mut scratch).await?;

    let (zresp, finished) = {
        let resp_ref = resp.get();

        debug!(
            "client-conn {}: response: {} {}",
            log_id, resp_ref.code, resp_ref.reason
        );

        // receive response body

        let finished = {
            loop {
                match resp_body.try_recv(body_buf.write_buf())?
                {
                    RecvStatus::Complete(finished, written) => {
                        body_buf.write_commit(written);

                        break finished;
                    }
                    RecvStatus::Read((), written) => {
                        body_buf.write_commit(written);

                        if written == 0 {
                            resp_body.add_to_buffer().await?;
                        }
                    }
                }
            }
        };

        if follow_redirects {
            if let Some((url, use_get)) =
                check_redirect(method, url, &resp_ref, &["http", "https"])
            {
                let finished = finished.discard_header(resp);

                debug!("client-conn {}: redirecting to {}", log_id, url);

                return Ok(ClientHandlerDone::Redirect(
                    finished.inner.persistent,
                    url,
                    use_get,
                ));
            }
        }

        let mut zheaders = ArrayVec::::new();

        for h in resp_ref.headers {
            zheaders.push(zhttppacket::Header {
                name: h.name,
                value: h.value,
            });
        }

        let rdata = zhttppacket::ResponseData {
            credits: 0,
            more: false,
            code: resp_ref.code,
            reason: resp_ref.reason,
            headers: &zheaders,
            content_type: None,
            body: Buffer::read_buf(body_buf),
        };

        let zresp = make_zhttp_req_response(
            id,
            zhttppacket::ResponsePacket::Data(rdata),
            &mut packet_buf.borrow_mut(),
        )?;

        (zresp, finished)
    };

    let finished = finished.discard_header(resp);

    Ok(ClientHandlerDone::Complete(
        zresp,
        finished.inner.persistent,
    ))
}

#[allow(clippy::too_many_arguments)]
async fn client_req_connect(
    log_id: &str,
    id: Option<&[u8]>,
    zreq: arena::Rc,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    body_buf: &mut Buffer,
    packet_buf: &RefCell>,
    deny: &[IpNet],
    resolver: &resolver::Resolver,
    pool: &ConnectionPool,
) -> Result {
    let zreq = zreq.get().get();

    let rdata = match &zreq.ptype {
        zhttppacket::RequestPacket::Data(data) => data,
        _ => return Err(Error::BadRequest),
    };

    let initial_url = match url::Url::parse(rdata.uri) {
        Ok(url) => url,
        Err(_) => return Err(Error::BadRequest),
    };

    // must be an http url
    if !["http", "https"].contains(&initial_url.scheme()) {
        return Err(Error::BadRequest);
    }

    // must have a method
    if rdata.method.is_empty() {
        return Err(Error::BadRequest);
    }

    debug!(
        "client-conn {}: request: {} {}",
        log_id, rdata.method, rdata.uri,
    );

    let deny = if rdata.ignore_policies { &[] } else { deny };

    let mut last_redirect: Option<(url::Url, bool)> = None;
    let mut redirect_count = 0;

    let zresp = loop {
        let (method, url, include_body) = match &last_redirect {
            Some((url, use_get)) => {
                let (method, include_body) = if *use_get {
                    ("GET", false)
                } else {
                    (rdata.method, true)
                };

                (method, url, include_body)
            }
            None => (rdata.method, &initial_url, true),
        };

        let url_host = match url.host_str() {
            Some(s) => s,
            None => return Err(Error::BadRequest),
        };

        let tls_waker_data = RefWakerData::new(TlsWaker::new());

        let (peer_addr, using_tls, mut stream) =
            client_connect(log_id, rdata, url, resolver, deny, pool, &tls_waker_data).await?;

        let done = match &mut stream {
            AsyncStream::Plain(stream) => {
                client_req_handler(
                    log_id,
                    id,
                    stream,
                    zreq,
                    method,
                    url,
                    include_body,
                    rdata.follow_redirects,
                    buf1,
                    buf2,
                    body_buf,
                    packet_buf,
                )
                .await?
            }
            AsyncStream::Tls(stream) => {
                client_req_handler(
                    log_id,
                    id,
                    stream,
                    zreq,
                    method,
                    url,
                    include_body,
                    rdata.follow_redirects,
                    buf1,
                    buf2,
                    body_buf,
                    packet_buf,
                )
                .await?
            }
        };

        if done.is_persistent() {
            if pool
                .push(
                    peer_addr,
                    using_tls,
                    url_host.to_string(),
                    stream.into_inner(),
                    CONNECTION_POOL_TTL,
                )
                .is_ok()
            {
                debug!("client-conn {}: leaving connection intact", log_id);
            }
        }

        match done {
            ClientHandlerDone::Complete(zresp, _) => break zresp,
            ClientHandlerDone::Redirect(_, url, mut use_get) => {
                if redirect_count >= REDIRECTS_MAX {
                    return Err(Error::TooManyRedirects);
                }

                redirect_count += 1;

                if let Some((_, b)) = &last_redirect {
                    use_get = use_get || *b;
                }

                last_redirect = Some((url, use_get));
            }
        }
    };

    Ok(zresp)
}

#[allow(clippy::too_many_arguments)]
async fn client_req_connection_inner(
    token: CancellationToken,
    log_id: &str,
    id: Option<&[u8]>,
    zreq: (MultipartHeader, arena::Rc),
    buffer_size: usize,
    body_buffer_size: usize,
    rb_tmp: &Rc,
    packet_buf: Rc>>,
    timeout: Duration,
    deny: &[IpNet],
    resolver: &resolver::Resolver,
    pool: &ConnectionPool,
    zsender: AsyncLocalSender<(MultipartHeader, zmq::Message)>,
) -> Result<(), Error> {
    let reactor = Reactor::current().unwrap();

    let (zheader, zreq) = zreq;

    let mut buf1 = RingBuffer::new(buffer_size, rb_tmp);
    let mut buf2 = RingBuffer::new(buffer_size, rb_tmp);
    let mut body_buf = Buffer::new(body_buffer_size);

    let handler = client_req_connect(
        log_id,
        id,
        zreq,
        &mut buf1,
        &mut buf2,
        &mut body_buf,
        &packet_buf,
        deny,
        resolver,
        pool,
    );

    let timeout = Timeout::new(reactor.now() + timeout);

    let ret = match select_3(pin!(handler), timeout.elapsed(), token.cancelled()).await {
        Select3::R1(ret) => ret,
        Select3::R2(_) => Err(Error::StreamTimeout),
        Select3::R3(_) => return Err(Error::Stopped),
    };

    match ret {
        Ok(zresp) => zsender.send((zheader, zresp)).await?,
        Err(e) => {
            let zresp = make_zhttp_req_response(
                id,
                zhttppacket::ResponsePacket::Error(zhttppacket::ResponseErrorData {
                    condition: e.to_condition(),
                    rejected_info: None,
                }),
                &mut packet_buf.borrow_mut(),
            )?;

            zsender.send((zheader, zresp)).await?;

            return Err(e);
        }
    }

    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub async fn client_req_connection(
    token: CancellationToken,
    log_id: &str,
    id: Option<&[u8]>,
    zreq: (MultipartHeader, arena::Rc),
    buffer_size: usize,
    body_buffer_size: usize,
    rb_tmp: &Rc,
    packet_buf: Rc>>,
    timeout: Duration,
    deny: &[IpNet],
    resolver: &resolver::Resolver,
    pool: &ConnectionPool,
    zsender: AsyncLocalSender<(MultipartHeader, zmq::Message)>,
) {
    match client_req_connection_inner(
        token,
        log_id,
        id,
        zreq,
        buffer_size,
        body_buffer_size,
        rb_tmp,
        packet_buf,
        timeout,
        deny,
        resolver,
        pool,
        zsender,
    )
    .await
    {
        Ok(()) => debug!("client-conn {}: finished", log_id),
        Err(e) => {
            let level = match e {
                Error::ValueActive => Level::Error,
                _ => Level::Debug,
            };

            log!(level, "client-conn {}: process error: {:?}", log_id, e);
        }
    }
}

// return true if persistent
#[allow(clippy::too_many_arguments)]
async fn client_stream_handler(
    log_id: &str,
    stream: &mut S,
    zreq: &zhttppacket::Request<'_, '_, '_>,
    method: &str,
    url: &url::Url,
    include_body: bool,
    mut follow_redirects: bool,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    messages_max: usize,
    allow_compression: bool,
    tmp_buf: &RefCell>,
    zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>,
    zsess_out: &ZhttpServerStreamSessionOut<'_>,
    response_received: &mut bool,
    refresh_stream_timeout: &R1,
) -> Result, Error>
where
    S: AsyncRead + AsyncWrite,
    R1: Fn(),
    R2: Fn(),
{
    let stream = RefCell::new(stream);

    let send_buf_size = buf1.capacity(); // for sending to handler
    let recv_buf_size = buf2.capacity(); // for receiving from handler

    let req = ClientRequest::new(io_split(&stream), buf1, buf2);

    let (req_header, ws_key, overflow) = {
        let rdata = match &zreq.ptype {
            zhttppacket::RequestPacket::Data(data) => data,
            _ => return Err(Error::BadRequest),
        };

        let websocket = ["wss", "ws"].contains(&url.scheme());

        let host_port = &url[url::Position::BeforeHost..url::Position::AfterPort];

        let ws_key = if websocket { Some(gen_ws_key()) } else { None };

        if !websocket && rdata.more {
            follow_redirects = false;
        }

        let mut ws_ext = ArrayVec::::new();

        let mut headers =
            ArrayVec::::new();

        headers.push(http1::Header {
            name: "Host",
            value: host_port.as_bytes(),
        });

        if let Some(ws_key) = &ws_key {
            headers.push(http1::Header {
                name: "Upgrade",
                value: b"websocket",
            });
            headers.push(http1::Header {
                name: "Connection",
                value: b"Upgrade",
            });
            headers.push(http1::Header {
                name: "Sec-WebSocket-Version",
                value: b"13",
            });
            headers.push(http1::Header {
                name: "Sec-WebSocket-Key",
                value: ws_key.as_bytes(),
            });

            if allow_compression {
                if write_ws_ext_header_value(
                    &websocket::PerMessageDeflateConfig::default(),
                    &mut ws_ext,
                )
                .is_err()
                {
                    return Err(Error::CompressionError);
                }

                headers.push(http1::Header {
                    name: "Sec-WebSocket-Extensions",
                    value: ws_ext.as_slice(),
                });
            }
        }

        let mut body_size = if websocket || !include_body {
            http1::BodySize::NoBody
        } else {
            http1::BodySize::Unknown
        };

        for h in rdata.headers.iter() {
            // host comes from the uri
            if h.name.eq_ignore_ascii_case("Host") {
                continue;
            }

            if websocket {
                // don't send these headers
                if h.name.eq_ignore_ascii_case("Connection")
                    || h.name.eq_ignore_ascii_case("Upgrade")
                    || h.name.eq_ignore_ascii_case("Sec-WebSocket-Version")
                    || h.name.eq_ignore_ascii_case("Sec-WebSocket-Key")
                {
                    continue;
                }
            } else {
                if h.name.eq_ignore_ascii_case("Content-Length") {
                    let s = str::from_utf8(h.value)?;

                    let clen: usize = match s.parse() {
                        Ok(clen) => clen,
                        Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidInput).into()),
                    };

                    body_size = http1::BodySize::Known(clen);
                }
            }

            if headers.remaining_capacity() == 0 {
                return Err(Error::BadRequest);
            }

            headers.push(http1::Header {
                name: h.name,
                value: h.value,
            });
        }

        let method = if websocket { "GET" } else { method };

        let path = &url[url::Position::BeforePath..];

        if body_size == http1::BodySize::Unknown && !rdata.more {
            body_size = http1::BodySize::Known(rdata.body.len());
        }

        let mut overflow = None;

        let req_header = if websocket {
            req.prepare_header(method, path, &headers, body_size, true, &[], true)?
        } else {
            let (initial_body, end) = if include_body {
                if rdata.body.len() > recv_buf_size {
                    let body = &rdata.body[..recv_buf_size];

                    let mut remainder = Buffer::new(rdata.body.len() - body.len());
                    remainder.write_all(&rdata.body[body.len()..])?;

                    debug!(
                        "initial={} overflow={} end={}",
                        body.len(),
                        remainder.read_avail(),
                        !rdata.more
                    );

                    overflow = Some(Overflow {
                        buf: remainder,
                        end: !rdata.more,
                    });

                    (body, false)
                } else {
                    (rdata.body, !rdata.more)
                }
            } else {
                (&[][..], true)
            };

            req.prepare_header(method, path, &headers, body_size, false, initial_body, end)?
        };

        (req_header, ws_key, overflow)
    };

    // send request header

    let req_body = {
        let mut send_header = pin!(req_header.send());

        loop {
            // ABR: select contains read
            let result = select_2(send_header.as_mut(), pin!(zsess_in.recv_msg())).await;

            match result {
                Select2::R1(ret) => break ret?,
                Select2::R2(ret) => {
                    let zreq = ret?;

                    // ABR: handle_other
                    server_handle_other(zreq, zsess_in, zsess_out).await?;
                }
            }
        }
    };

    refresh_stream_timeout();

    // send request body

    // ABR: function contains read
    let resp = server_stream_send_body(
        refresh_stream_timeout,
        req_body,
        overflow,
        recv_buf_size,
        zsess_in,
        zsess_out,
    )
    .await?;

    // receive response header

    let (resp_body, ws_config) = {
        let mut scratch = http1::ParseScratch::::new();

        let mut recv_header = pin!(resp.recv_header(&mut scratch));

        let (resp, resp_body) = loop {
            // ABR: select contains read
            let result = select_2(recv_header.as_mut(), pin!(zsess_in.recv_msg())).await;

            match result {
                Select2::R1(ret) => break ret?,
                Select2::R2(ret) => {
                    let zreq = ret?;

                    // ABR: handle_other
                    server_handle_other(zreq, zsess_in, zsess_out).await?;
                }
            }
        };

        let ws_config = {
            let resp_ref = resp.get();

            debug!(
                "client-conn {}: response: {} {}",
                log_id, resp_ref.code, resp_ref.reason
            );

            loop {
                // ABR: select contains read
                let result =
                    select_2(pin!(zsess_out.check_send()), pin!(zsess_in.recv_msg())).await;

                match result {
                    Select2::R1(()) => break,
                    Select2::R2(ret) => {
                        let zreq = ret?;

                        // ABR: handle_other
                        server_handle_other(zreq, zsess_in, zsess_out).await?;
                    }
                }
            }

            if follow_redirects {
                let schemes = if ws_key.is_some() {
                    ["ws", "wss"]
                } else {
                    ["http", "https"]
                };

                if let Some((url, use_get)) = check_redirect(method, url, &resp_ref, &schemes) {
                    // eat response body
                    let finished = loop {
                        let ret = {
                            let mut buf = [0; 4_096];
                            resp_body.try_recv(&mut buf)?
                        };

                        match ret {
                            RecvStatus::Complete(finished, _) => break finished,
                            RecvStatus::Read((), written) => {
                                if written == 0 {
                                    let mut add_to_buffer = pin!(resp_body.add_to_buffer());

                                    loop {
                                        // ABR: select contains read
                                        let result = select_2(
                                            add_to_buffer.as_mut(),
                                            pin!(zsess_in.recv_msg()),
                                        )
                                        .await;

                                        match result {
                                            Select2::R1(ret) => {
                                                ret?;
                                                break;
                                            }
                                            Select2::R2(ret) => {
                                                let zreq = ret?;

                                                // ABR: handle_other
                                                server_handle_other(zreq, zsess_in, zsess_out)
                                                    .await?;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    };

                    let finished = finished.discard_header(resp);

                    debug!("client-conn {}: redirecting to {}", log_id, url);

                    return Ok(ClientHandlerDone::Redirect(
                        finished.inner.persistent,
                        url,
                        use_get,
                    ));
                }
            }

            let mut zheaders = ArrayVec::::new();

            let mut ws_accept = None;
            let mut ws_deflate_config = None;

            for h in resp_ref.headers {
                if ws_key.is_some() {
                    if h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept") {
                        ws_accept = Some(h.value);
                    }

                    if h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") {
                        for value in http1::parse_header_value(h.value) {
                            let (name, params) = match value {
                                Ok(v) => v,
                                Err(_) => return Err(Error::InvalidWebSocketResponse),
                            };

                            match name {
                                "permessage-deflate" => {
                                    // we must have offered, and server must
                                    // provide one response at most
                                    if !allow_compression || ws_deflate_config.is_some() {
                                        return Err(Error::InvalidWebSocketResponse);
                                    }

                                    if let Ok(config) =
                                        websocket::PerMessageDeflateConfig::from_params(params)
                                    {
                                        if config.check_response().is_ok() {
                                            // split the original recv buffer memory:
                                            // 75% for a new recv buffer, 25% for an encoded buffer
                                            let recv_buf_size = recv_buf_size * 3 / 4;

                                            ws_deflate_config = Some((config,
                                                recv_buf_size));
                                        }
                                    }
                                }
                                name => {
                                    debug!("ignoring unsupported websocket extension: {}", name);
                                    continue;
                                }
                            }
                        }
                    }
                }

                zheaders.push(zhttppacket::Header {
                    name: h.name,
                    value: h.value,
                });
            }

            if let Some(ws_key) = &ws_key {
                if resp_ref.code == 101 {
                    if validate_ws_response(ws_key.as_bytes(), ws_accept).is_err() {
                        return Err(Error::InvalidWebSocketResponse);
                    }
                } else {
                    // websocket request rejected

                    // we need to allocate to collect the response body,
                    // since buf1 holds bytes read from the socket, and
                    // resp is using buf2's inner buffer
                    let mut body_buf = Buffer::new(send_buf_size);

                    // receive response body

                    let finished = loop {
                        match resp_body.try_recv(body_buf.write_buf())? {
                            RecvStatus::Complete(finished, written) => {
                                body_buf.write_commit(written);

                                break finished;
                            }
                            RecvStatus::Read((), written) => {
                                body_buf.write_commit(written);

                                if written == 0 {
                                    let mut add_to_buffer = pin!(resp_body.add_to_buffer());

                                    loop {
                                        // ABR: select contains read
                                        let result = select_2(
                                            add_to_buffer.as_mut(),
                                            pin!(zsess_in.recv_msg()),
                                        )
                                        .await;

                                        match result {
                                            Select2::R1(ret) => {
                                                ret?;
                                                break;
                                            }
                                            Select2::R2(ret) => {
                                                let zreq = ret?;

                                                // ABR: handle_other
                                                server_handle_other(zreq, zsess_in, zsess_out)
                                                    .await?;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    };

                    let edata = zhttppacket::ResponseErrorData {
                        condition: "rejected",
                        rejected_info: Some(zhttppacket::RejectedInfo {
                            code: resp_ref.code,
                            reason: resp_ref.reason,
                            headers: &zheaders,
                            body: body_buf.read_buf(),
                        }),
                    };

                    let zresp = zhttppacket::Response::new_error(b"", &[], edata);

                    // check_send just finished, so this should succeed
                    zsess_out.try_send_msg(zresp)?;

                    drop(zheaders);

                    let finished = finished.discard_header(resp);

                    return Ok(ClientHandlerDone::Complete((), finished.inner.persistent));
                }
            }

            let credits = if ws_key.is_some() {
                // for websockets, provide credits when sending response to handler
                if let Some((_, recv_buf_size)) = &ws_deflate_config {
                    *recv_buf_size as u32
                } else {
                    recv_buf_size as u32
                }
            } else {
                // for http, it is not necessary to provide credits when responding
                0
            };

            let rdata = zhttppacket::ResponseData {
                credits,
                more: ws_key.is_none(),
                code: resp_ref.code,
                reason: resp_ref.reason,
                headers: &zheaders,
                content_type: None,
                body: b"",
            };

            let zresp = zhttppacket::Response::new_data(b"", &[], rdata);

            // check_send just finished, so this should succeed
            zsess_out.try_send_msg(zresp)?;

            if ws_key.is_some() {
                Some(ws_deflate_config)
            } else {
                None
            }
        };

        let resp_body = resp_body.discard_header(resp)?;

        (resp_body, ws_config)
    };

    *response_received = true;

    if let Some(deflate_config) = ws_config {
        // handle as websocket connection

        // ABR: function contains read
        server_stream_websocket(
            log_id,
            stream,
            buf1,
            buf2,
            messages_max,
            tmp_buf,
            refresh_stream_timeout,
            deflate_config,
            zsess_in,
            zsess_out,
        )
        .await?;

        Ok(ClientHandlerDone::Complete((), false))
    } else {
        // receive response body

        // ABR: function contains read
        let finished = server_stream_recv_body(
            tmp_buf,
            refresh_stream_timeout,
            resp_body,
            zsess_in,
            zsess_out,
        )
        .await?;

        Ok(ClientHandlerDone::Complete((), finished.inner.persistent))
    }
}

#[allow(clippy::too_many_arguments)]
async fn client_stream_connect(
    log_id: &str,
    id: &[u8],
    zreq: arena::Rc,
    buf1: &mut RingBuffer,
    buf2: &mut RingBuffer,
    messages_max: usize,
    allow_compression: bool,
    packet_buf: &RefCell>,
    tmp_buf: &RefCell>,
    deny: &[IpNet],
    instance_id: &str,
    resolver: &resolver::Resolver,
    pool: &ConnectionPool,
    zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>,
    zsender: &AsyncLocalSender,
    shared: &StreamSharedData,
    enable_routing: &E,
    response_received: &mut bool,
    refresh_stream_timeout: &R1,
    refresh_session_timeout: &R2,
) -> Result<(), Error>
where
    E: Fn(),
    R1: Fn(),
    R2: Fn(),
{
    let zreq = zreq.get().get();

    // assign address so we can send replies
    let addr: ArrayVec = match ArrayVec::try_from(zreq.from) {
        Ok(v) => v,
        Err(_) => return Err(Error::BadRequest),
    };

    shared.set_to_addr(Some(addr));

    let rdata = match &zreq.ptype {
        zhttppacket::RequestPacket::Data(data) => data,
        _ => return
            Err(Error::BadRequest),
    };

    let initial_url = match url::Url::parse(rdata.uri) {
        Ok(url) => url,
        Err(_) => return Err(Error::BadRequest),
    };

    // must be an http or websocket url
    if !["http", "https", "ws", "wss"].contains(&initial_url.scheme()) {
        return Err(Error::BadRequest);
    }

    // http requests must have a method
    if ["http", "https"].contains(&initial_url.scheme()) && rdata.method.is_empty() {
        return Err(Error::BadRequest);
    }

    let method = if !rdata.method.is_empty() {
        rdata.method
    } else {
        "_"
    };

    debug!("client-conn {}: request: {} {}", log_id, method, rdata.uri);

    let zsess_out =
        ZhttpServerStreamSessionOut::new(instance_id, id, packet_buf, zsender, shared);

    // ack request

    // ABR: discard_while
    server_discard_while(
        zreceiver,
        pin!(async {
            zsess_out.check_send().await;
            Ok(())
        }),
    )
    .await?;

    zsess_out.try_send_msg(zhttppacket::Response::new_keep_alive(b"", &[]))?;

    let mut zsess_in = ZhttpServerStreamSessionIn::new(
        log_id,
        id,
        rdata.credits,
        zreceiver,
        shared,
        refresh_session_timeout,
    );

    // allow receiving subsequent messages
    enable_routing();

    let deny = if rdata.ignore_policies { &[] } else { deny };

    let mut last_redirect: Option<(url::Url, bool)> = None;
    let mut redirect_count = 0;

    loop {
        let (method, url, include_body) = match &last_redirect {
            Some((url, use_get)) => {
                let (method, include_body) = if *use_get {
                    ("GET", false)
                } else {
                    (rdata.method, true)
                };

                (method, url, include_body)
            }
            None => (rdata.method, &initial_url, true),
        };

        let url_host = match url.host_str() {
            Some(s) => s,
            None => return Err(Error::BadRequest),
        };

        let tls_waker_data = RefWakerData::new(TlsWaker::new());

        let (peer_addr, using_tls, mut stream) = {
            let mut client_connect = pin!(client_connect(
                log_id,
                rdata,
                url,
                resolver,
                deny,
                pool,
                &tls_waker_data
            ));

            loop {
                // ABR: select contains read
                let ret = select_2(client_connect.as_mut(), pin!(zsess_in.recv_msg())).await;

                match ret {
                    Select2::R1(ret) => break ret?,
                    Select2::R2(ret) => {
                        let zreq = ret?;

                        // ABR: handle_other
                        server_handle_other(zreq, &mut zsess_in, &zsess_out).await?;
                    }
                }
            }
        };

        let done = match &mut stream {
            AsyncStream::Plain(stream) => {
                client_stream_handler(
                    log_id,
                    stream,
                    zreq,
                    method,
                    url,
                    include_body,
                    rdata.follow_redirects,
                    buf1,
                    buf2,
                    messages_max,
                    allow_compression,
                    tmp_buf,
                    &mut zsess_in,
                    &zsess_out,
                    response_received,
                    refresh_stream_timeout,
                )
                .await?
            }
            AsyncStream::Tls(stream) => {
                client_stream_handler(
                    log_id,
                    stream,
                    zreq,
                    method,
                    url,
                    include_body,
                    rdata.follow_redirects,
                    buf1,
                    buf2,
                    messages_max,
                    allow_compression,
                    tmp_buf,
                    &mut zsess_in,
                    &zsess_out,
                    response_received,
                    refresh_stream_timeout,
                )
                .await?
            }
        };

        if done.is_persistent() {
            if pool
                .push(
                    peer_addr,
                    using_tls,
                    url_host.to_string(),
                    stream.into_inner(),
                    CONNECTION_POOL_TTL,
                )
                .is_ok()
            {
                debug!("client-conn {}: leaving connection intact", log_id);
            }
        }

        match done {
            ClientHandlerDone::Complete((), _) => break,
            ClientHandlerDone::Redirect(_, url, mut use_get) => {
                if redirect_count >= REDIRECTS_MAX {
                    return Err(Error::TooManyRedirects);
                }

                redirect_count += 1;

                if let Some((_, b)) = &last_redirect {
                    use_get = use_get || *b;
                }

                last_redirect = Some((url, use_get));
            }
        }
    }

    Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn client_stream_connection_inner<E>(
    token: CancellationToken,
    log_id: &str,
    id: &[u8],
    zreq: arena::Rc<zhttppacket::OwnedRequest>,
    buffer_size: usize,
    messages_max: usize,
    rb_tmp: &Rc<TmpBuffer>,
    packet_buf: Rc<RefCell<Vec<u8>>>,
    tmp_buf: Rc<RefCell<Vec<u8>>>,
    stream_timeout_duration: Duration,
    allow_compression: bool,
    deny: &[IpNet],
    instance_id: &str,
    resolver: &resolver::Resolver,
    pool: &ConnectionPool,
    zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc<zhttppacket::OwnedResponse>, usize)>,
    zsender: AsyncLocalSender,
    shared: arena::Rc<StreamSharedData>,
    enable_routing: &E,
) -> Result<(), Error>
where
    E: Fn(),
{
    let reactor = Reactor::current().unwrap();

    let mut buf1 = RingBuffer::new(buffer_size, rb_tmp);
    let mut buf2 = RingBuffer::new(buffer_size, rb_tmp);

    let stream_timeout = Timeout::new(reactor.now() + stream_timeout_duration);
    let session_timeout = Timeout::new(reactor.now() + ZHTTP_SESSION_TIMEOUT);

    let refresh_stream_timeout = || {
        stream_timeout.set_deadline(reactor.now() + stream_timeout_duration);
    };

    let refresh_session_timeout = || {
        session_timeout.set_deadline(reactor.now() + ZHTTP_SESSION_TIMEOUT);
    };

    let mut response_received = false;

    let ret = {
        let handler = pin!(client_stream_connect(
            log_id,
            id,
            zreq,
            &mut buf1,
            &mut buf2,
            messages_max,
            allow_compression,
            &packet_buf,
            &tmp_buf,
            deny,
            instance_id,
            resolver,
            pool,
            zreceiver,
            &zsender,
            shared.get(),
            enable_routing,
            &mut response_received,
            &refresh_stream_timeout,
            &refresh_session_timeout,
        ));

        match select_4(
            handler,
            stream_timeout.elapsed(),
            session_timeout.elapsed(),
            token.cancelled(),
        )
        .await
        {
            Select4::R1(ret) => ret,
            Select4::R2(_) => Err(Error::StreamTimeout),
            Select4::R3(_) => return Err(Error::SessionTimeout),
            Select4::R4(_) => return Err(Error::Stopped),
        }
    };

    match ret {
        Ok(()) => {}
        Err(e) => {
            let handler_caused = matches!(
                &e,
                Error::BadMessage | Error::HandlerError | Error::HandlerCancel
            );

            if !handler_caused {
                let shared = shared.get();

                let msg = if let Some(addr) = shared.to_addr().get() {
                    let mut zresp = if response_received {
                        zhttppacket::Response::new_cancel(b"", &[])
                    } else {
                        zhttppacket::Response::new_error(
                            b"",
                            &[],
                            zhttppacket::ResponseErrorData {
                                condition: e.to_condition(),
                                rejected_info: None,
                            },
                        )
                    };

                    let ids = [zhttppacket::Id {
                        id,
                        seq: Some(shared.out_seq()),
                    }];

                    zresp.from = instance_id.as_bytes();
                    zresp.ids = &ids;
                    zresp.multi = true;

                    let packet_buf = &mut *packet_buf.borrow_mut();

                    let msg = make_zhttp_response(addr, zresp, packet_buf)?;

                    Some(msg)
                } else {
                    None
                };

                if let Some(msg) = msg {
                    // best effort
                    let _ = zsender.try_send(msg);

                    shared.inc_out_seq();
                }
            }

            return Err(e);
        }
    }

    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub async fn client_stream_connection<E>(
    token: CancellationToken,
    log_id: &str,
    id: &[u8],
    zreq: arena::Rc<zhttppacket::OwnedRequest>,
    buffer_size: usize,
    messages_max: usize,
    rb_tmp: &Rc<TmpBuffer>,
    packet_buf: Rc<RefCell<Vec<u8>>>,
    tmp_buf: Rc<RefCell<Vec<u8>>>,
    timeout: Duration,
    allow_compression: bool,
    deny: &[IpNet],
    instance_id: &str,
    resolver: &resolver::Resolver,
    pool: &ConnectionPool,
    zreceiver: AsyncLocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
    zsender: AsyncLocalSender,
    shared: arena::Rc<StreamSharedData>,
    enable_routing: &E,
) where
    E: Fn(),
{
    let value_active = TrackFlag::default();

    let zreceiver = TrackedAsyncLocalReceiver::new(zreceiver, &value_active);

    match track_future(
        client_stream_connection_inner(
            token,
            log_id,
            id,
            zreq,
            buffer_size,
            messages_max,
            rb_tmp,
            packet_buf,
            tmp_buf,
            timeout,
            allow_compression,
            deny,
            instance_id,
            resolver,
            pool,
            &zreceiver,
            zsender,
            shared,
            enable_routing,
        ),
        &value_active,
    )
    .await
    {
        Ok(()) => debug!("client-conn {}: finished", log_id),
        Err(e) => {
            let level = match e {
                Error::ValueActive => Level::Error,
                _ => Level::Debug,
            };

            log!(level, "client-conn {}: process error: {:?}", log_id, e);
        }
    }
}

pub mod testutil {
    use super::*;
    use crate::buffer::TmpBuffer;
    use crate::channel;
    use crate::waker;
    use std::fmt;
    use std::future::Future;
    use std::io::Read;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::task::{Context, Poll, Waker};
    use std::time::Instant;

    pub struct NoopWaker {}

    #[allow(clippy::new_without_default)]
    impl NoopWaker {
        pub fn new() -> Self {
            Self {}
        }

        pub fn into_std(self: Rc<Self>) -> Waker {
            waker::into_std(self)
        }
    }

    impl waker::RcWake for NoopWaker {
        fn wake(self: Rc<Self>) {}
    }

    pub struct StepExecutor<'a, F> {
        reactor: &'a Reactor,
        fut: Pin<Box<F>>,
    }

    impl<'a, F> StepExecutor<'a, F>
    where
        F: Future,
    {
        pub fn new(reactor: &'a Reactor, fut: F) -> Self {
            Self {
                reactor,
                fut: Box::pin(fut),
            }
        }

        pub fn step(&mut self) -> Poll<F::Output> {
            self.reactor.poll_nonblocking(self.reactor.now()).unwrap();

            let waker = Rc::new(NoopWaker::new()).into_std();
            let mut cx = Context::from_waker(&waker);

            self.fut.as_mut().poll(&mut cx)
        }

        pub fn advance_time(&mut self, now: Instant) {
            self.reactor.poll_nonblocking(now).unwrap();
        }
    }

    #[track_caller]
    pub fn check_poll<T, E>(p: Poll<Result<T, E>>) -> Option<T>
    where
        E: fmt::Debug,
    {
        match p {
            Poll::Ready(v) => match v {
                Ok(t) => Some(t),
                Err(e) => panic!("check_poll error: {:?}", e),
            },
            Poll::Pending => None,
        }
    }

    pub struct FakeSock {
        inbuf: Vec<u8>,
        outbuf: Vec<u8>,
        out_allow: usize,
    }

    #[allow(clippy::new_without_default)]
    impl FakeSock {
        pub fn new() -> Self {
            Self {
                inbuf: Vec::with_capacity(16384),
                outbuf: Vec::with_capacity(16384),
                out_allow: 0,
            }
        }

        pub fn add_readable(&mut self, buf: &[u8]) {
            self.inbuf.extend_from_slice(buf);
        }

        pub fn take_writable(&mut self) -> Vec<u8> {
            mem::take(&mut self.outbuf)
        }

        pub fn allow_write(&mut self, size: usize) {
            self.out_allow += size;
        }
    }

    impl Read for FakeSock {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
            if self.inbuf.is_empty() {
                return Err(io::Error::from(io::ErrorKind::WouldBlock));
            }

            let size = cmp::min(buf.len(), self.inbuf.len());

            buf[..size].copy_from_slice(&self.inbuf[..size]);

            let mut rest = self.inbuf.split_off(size);
            mem::swap(&mut self.inbuf, &mut rest);

            Ok(size)
        }
    }

    impl Write for FakeSock {
        fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
            if !buf.is_empty() && self.out_allow == 0 {
                return Err(io::Error::from(io::ErrorKind::WouldBlock));
            }

            let size = cmp::min(buf.len(), self.out_allow);
            let buf = &buf[..size];

            self.outbuf.extend_from_slice(buf);
            self.out_allow -= size;

            Ok(buf.len())
        }

        fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result<usize, io::Error> {
            let mut total = 0;

            for buf in bufs {
                if self.out_allow == 0 {
                    break;
                }

                let size = cmp::min(buf.len(), self.out_allow);
                let buf = &buf[..size];

                self.outbuf.extend_from_slice(buf.as_ref());
                self.out_allow -= size;

                total += buf.len();
            }

            Ok(total)
        }

        fn flush(&mut self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    pub struct AsyncFakeSock {
        pub inner: Rc<RefCell<FakeSock>>,
    }

    impl AsyncFakeSock {
        pub fn new(sock: Rc<RefCell<FakeSock>>) -> Self {
            Self { inner: sock }
        }
    }

    impl AsyncRead for AsyncFakeSock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context,
            buf: &mut [u8],
        ) -> Poll<Result<usize, io::Error>> {
            let inner = &mut *self.inner.borrow_mut();

            match inner.read(buf) {
                Ok(usize) => Poll::Ready(Ok(usize)),
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
                Err(e) => Poll::Ready(Err(e)),
            }
        }

        fn cancel(&mut self) {}
    }

    impl AsyncWrite for AsyncFakeSock {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let inner = &mut *self.inner.borrow_mut();

            match inner.write(buf) {
                Ok(usize) => Poll::Ready(Ok(usize)),
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
                Err(e) => Poll::Ready(Err(e)),
            }
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            _cx: &mut Context,
            bufs: &[io::IoSlice],
        ) -> Poll<Result<usize, io::Error>> {
            let inner = &mut *self.inner.borrow_mut();

            match inner.write_vectored(bufs) {
                Ok(usize) => Poll::Ready(Ok(usize)),
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
                Err(e) => Poll::Ready(Err(e)),
            }
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn is_writable(&self) -> bool {
            true
        }

        fn cancel(&mut self) {}
    }

    impl Identify for AsyncFakeSock {
        fn set_id(&mut self, _id: &str) {
            // do nothing
        }
    }

    pub struct SimpleCidProvider {
        pub cid: ArrayString<32>,
    }

    impl CidProvider for SimpleCidProvider {
        fn get_new_assigned_cid(&mut self) -> ArrayString<32> {
            self.cid
        }
    }

    #[allow(clippy::too_many_arguments)]
    async fn server_req_handler_fut(
        sock: Rc<RefCell<FakeSock>>,
        secure: bool,
        s_from_conn: channel::LocalSender<zmq::Message>,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
        buf1: &mut RingBuffer,
        buf2: &mut RingBuffer,
        body_buf: &mut Buffer,
    ) -> Result<bool, Error> {
        let mut sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();
        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);

        server_req_handler(
            "1",
            &mut sock,
            None,
            secure,
            buf1,
            buf2,
            body_buf,
            &packet_buf,
            &s_from_conn,
            &r_to_conn,
        )
        .await
    }

    pub struct BenchServerReqHandlerArgs {
        sock: Rc<RefCell<FakeSock>>,
        buf1: RingBuffer,
        buf2: RingBuffer,
        body_buf: Buffer,
    }

    pub struct BenchServerReqHandler {
        reactor: Reactor,
        msg_mem: Arc<arena::ArcMemory<zmq::Message>>,
        scratch_mem: Rc<arena::RcMemory<RefCell<zhttppacket::ParseScratch<'static>>>>,
        resp_mem: Rc<arena::RcMemory<zhttppacket::OwnedResponse>>,
        rb_tmp: Rc<TmpBuffer>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
    }
    #[allow(clippy::new_without_default)]
    impl BenchServerReqHandler {
        pub fn new() -> Self {
            Self {
                reactor: Reactor::new(100),
                msg_mem: Arc::new(arena::ArcMemory::new(1)),
                scratch_mem: Rc::new(arena::RcMemory::new(1)),
                resp_mem: Rc::new(arena::RcMemory::new(1)),
                rb_tmp: Rc::new(TmpBuffer::new(1024)),
                packet_buf: Rc::new(RefCell::new(vec![0; 2048])),
            }
        }

        pub fn init(&self) -> BenchServerReqHandlerArgs {
            let buffer_size = 1024;

            BenchServerReqHandlerArgs {
                sock: Rc::new(RefCell::new(FakeSock::new())),
                buf1: RingBuffer::new(buffer_size, &self.rb_tmp),
                buf2: RingBuffer::new(buffer_size, &self.rb_tmp),
                body_buf: Buffer::new(buffer_size),
            }
        }

        pub fn run(&self, args: &mut BenchServerReqHandlerArgs) {
            let reactor = &self.reactor;
            let msg_mem = &self.msg_mem;
            let scratch_mem = &self.scratch_mem;
            let resp_mem = &self.resp_mem;
            let packet_buf = &self.packet_buf;
            let sock = &args.sock;

            let (s_to_conn, r_to_conn) =
                channel::local_channel(1, 1, &reactor.local_registration_memory());
            let (s_from_conn, r_from_conn) =
                channel::local_channel(1, 2, &reactor.local_registration_memory());

            let fut = {
                let sock = args.sock.clone();
                let s_from_conn = s_from_conn
                    .try_clone(&reactor.local_registration_memory())
                    .unwrap();

                server_req_handler_fut(
                    sock,
                    false,
                    s_from_conn,
                    r_to_conn,
                    packet_buf.clone(),
                    &mut args.buf1,
                    &mut args.buf2,
                    &mut args.body_buf,
                )
            };

            let mut executor = StepExecutor::new(reactor, fut);

            assert_eq!(check_poll(executor.step()), None);

            let req_data = concat!(
                "GET /path HTTP/1.1\r\n",
                "Host: example.com\r\n",
                "Connection: close\r\n",
                "\r\n"
            )
            .as_bytes();

            sock.borrow_mut().add_readable(req_data);
            sock.borrow_mut().allow_write(1024);

            assert_eq!(check_poll(executor.step()), None);

            // read message
            let _ = r_from_conn.try_recv().unwrap();

            let msg = concat!(
                "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h",
                "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell",
                "o\n,}",
            );

            let msg = zmq::Message::from(msg.as_bytes());
            let msg = arena::Arc::new(msg, msg_mem).unwrap();

            let scratch =
                arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem)
                    .unwrap();

            let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
            let resp = arena::Rc::new(resp, resp_mem).unwrap();

            assert!(s_to_conn.try_send((resp, 0)).is_ok());

            assert_eq!(check_poll(executor.step()), Some(false));

            let data = sock.borrow_mut().take_writable();

            let expected = concat!(
                "HTTP/1.1 200 OK\r\n",
                "Content-Type: text/plain\r\n",
                "Connection: close\r\n",
                "Content-Length: 6\r\n",
                "\r\n",
                "hello\n",
            );

            assert_eq!(str::from_utf8(&data).unwrap(), expected);
        }
    }

    async fn server_req_connection_inner_fut(
        token: CancellationToken,
        sock: Rc<RefCell<FakeSock>>,
        secure: bool,
        s_from_conn: channel::LocalSender<zmq::Message>,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
        rb_tmp: Rc<TmpBuffer>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
    ) -> Result<(), Error> {
        let mut cid = ArrayString::from_str("1").unwrap();
        let mut cid_provider = SimpleCidProvider { cid };

        let sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();
        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);

        let buffer_size = 1024;

        let timeout = Duration::from_millis(5_000);

        server_req_connection_inner(
            token,
            &mut cid,
            &mut cid_provider,
            sock,
            None,
            secure,
            buffer_size,
            buffer_size,
            &rb_tmp,
            packet_buf,
            timeout,
            s_from_conn,
            &r_to_conn,
        )
        .await
    }

    pub struct BenchServerReqConnection {
        reactor: Reactor,
        msg_mem: Arc<arena::ArcMemory<zmq::Message>>,
        scratch_mem: Rc<arena::RcMemory<RefCell<zhttppacket::ParseScratch<'static>>>>,
        resp_mem: Rc<arena::RcMemory<zhttppacket::OwnedResponse>>,
        rb_tmp: Rc<TmpBuffer>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
    }

    #[allow(clippy::new_without_default)]
    impl BenchServerReqConnection {
        pub fn new() -> Self {
            Self {
                reactor: Reactor::new(100),
                msg_mem: Arc::new(arena::ArcMemory::new(1)),
                scratch_mem: Rc::new(arena::RcMemory::new(1)),
                resp_mem: Rc::new(arena::RcMemory::new(1)),
                rb_tmp: Rc::new(TmpBuffer::new(1024)),
                packet_buf: Rc::new(RefCell::new(vec![0; 2048])),
            }
        }

        pub fn init(&self) -> Rc<RefCell<FakeSock>> {
            Rc::new(RefCell::new(FakeSock::new()))
        }

        pub fn run(&self, sock: &Rc<RefCell<FakeSock>>) {
            let reactor = &self.reactor;
            let msg_mem = &self.msg_mem;
            let scratch_mem = &self.scratch_mem;
            let resp_mem = &self.resp_mem;
            let rb_tmp = &self.rb_tmp;
            let packet_buf = &self.packet_buf;

            let (s_to_conn, r_to_conn) =
                channel::local_channel(1, 1, &reactor.local_registration_memory());
            let (s_from_conn, r_from_conn) =
                channel::local_channel(1, 2, &reactor.local_registration_memory());

            let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

            let fut = {
                let sock = sock.clone();
                let s_from_conn = s_from_conn
                    .try_clone(&reactor.local_registration_memory())
                    .unwrap();

                server_req_connection_inner_fut(
                    token,
                    sock,
                    false,
                    s_from_conn,
                    r_to_conn,
                    rb_tmp.clone(),
                    packet_buf.clone(),
                )
            };

            let mut executor = StepExecutor::new(reactor, fut);

            assert_eq!(check_poll(executor.step()), None);

            let req_data = concat!(
                "GET /path HTTP/1.1\r\n",
                "Host: example.com\r\n",
                "Connection: close\r\n",
                "\r\n"
            )
            .as_bytes();

            sock.borrow_mut().add_readable(req_data);
            sock.borrow_mut().allow_write(1024);

            assert_eq!(check_poll(executor.step()), None);

            // read message
            let _ = r_from_conn.try_recv().unwrap();

            let msg = concat!(
                "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h",
                "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell",
                "o\n,}",
            );

            let msg = zmq::Message::from(msg.as_bytes());
            let msg = arena::Arc::new(msg, msg_mem).unwrap();

            let scratch =
                arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem)
                    .unwrap();

            let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
            let resp = arena::Rc::new(resp, resp_mem).unwrap();

            assert!(s_to_conn.try_send((resp, 0)).is_ok());

            assert_eq!(check_poll(executor.step()), Some(()));

            let data = sock.borrow_mut().take_writable();

            let expected = concat!(
                "HTTP/1.1 200 OK\r\n",
                "Content-Type: text/plain\r\n",
                "Connection: close\r\n",
                "Content-Length: 6\r\n",
                "\r\n",
                "hello\n",
            );

            assert_eq!(str::from_utf8(&data).unwrap(), expected);
        }
    }

    #[allow(clippy::too_many_arguments)]
    async fn server_stream_handler_fut(
        sock: Rc<RefCell<FakeSock>>,
        secure: bool,
        s_from_conn: channel::LocalSender<zmq::Message>,
        s_stream_from_conn: channel::LocalSender<(ArrayVec, zmq::Message)>,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
        tmp_buf: Rc<RefCell<Vec<u8>>>,
        buf1: &mut RingBuffer,
        buf2: &mut RingBuffer,
        shared: arena::Rc<StreamSharedData>,
    ) -> Result<bool, Error> {
        let mut sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();
        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);
        let s_stream_from_conn = AsyncLocalSender::new(s_stream_from_conn);

        server_stream_handler(
            "1",
            &mut sock,
            None,
            secure,
            buf1,
            buf2,
            10,
            false,
            &packet_buf,
            &tmp_buf,
            "test",
            &s_from_conn,
            &s_stream_from_conn,
            &r_to_conn,
            shared.get(),
            &|| {},
            &|| {},
        )
        .await
    }

    pub struct BenchServerStreamHandlerArgs {
        sock: Rc<RefCell<FakeSock>>,
        buf1: RingBuffer,
        buf2: RingBuffer,
    }

    pub struct BenchServerStreamHandler {
        reactor: Reactor,
        msg_mem: Arc<arena::ArcMemory<zmq::Message>>,
        scratch_mem: Rc<arena::RcMemory<RefCell<zhttppacket::ParseScratch<'static>>>>,
        resp_mem: Rc<arena::RcMemory<zhttppacket::OwnedResponse>>,
        shared_mem: Rc<arena::RcMemory<StreamSharedData>>,
        rb_tmp: Rc<TmpBuffer>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
        tmp_buf: Rc<RefCell<Vec<u8>>>,
    }

    #[allow(clippy::new_without_default)]
    impl BenchServerStreamHandler {
        pub fn new() -> Self {
            Self {
                reactor: Reactor::new(100),
                msg_mem: Arc::new(arena::ArcMemory::new(1)),
                scratch_mem: Rc::new(arena::RcMemory::new(1)),
                resp_mem: Rc::new(arena::RcMemory::new(1)),
                shared_mem: Rc::new(arena::RcMemory::new(1)),
                rb_tmp: Rc::new(TmpBuffer::new(1024)),
                packet_buf: Rc::new(RefCell::new(vec![0; 2048])),
                tmp_buf: Rc::new(RefCell::new(vec![0; 1024])),
            }
        }

        pub fn init(&self) -> BenchServerStreamHandlerArgs {
            let buffer_size = 1024;

            BenchServerStreamHandlerArgs {
                sock: Rc::new(RefCell::new(FakeSock::new())),
                buf1: RingBuffer::new(buffer_size, &self.rb_tmp),
                buf2: RingBuffer::new(buffer_size, &self.rb_tmp),
            }
        }

        pub fn run(&self, args: &mut BenchServerStreamHandlerArgs) {
            let reactor = &self.reactor;
            let msg_mem = &self.msg_mem;
            let scratch_mem = &self.scratch_mem;
            let resp_mem = &self.resp_mem;
            let shared_mem = &self.shared_mem;
            let packet_buf = &self.packet_buf;
            let tmp_buf = &self.tmp_buf;
            let sock = &args.sock;

            let (s_to_conn, r_to_conn) =
                channel::local_channel(1, 1, &reactor.local_registration_memory());
            let (s_from_conn, r_from_conn) =
                channel::local_channel(1, 2, &reactor.local_registration_memory());
            let (s_stream_from_conn, _r_stream_from_conn) =
                channel::local_channel(1, 2, &reactor.local_registration_memory());

            let fut = {
                let sock = args.sock.clone();
                let s_from_conn = s_from_conn
                    .try_clone(&reactor.local_registration_memory())
                    .unwrap();
                let shared = arena::Rc::new(StreamSharedData::new(), shared_mem).unwrap();

                server_stream_handler_fut(
                    sock,
                    false,
                    s_from_conn,
                    s_stream_from_conn,
                    r_to_conn,
                    packet_buf.clone(),
                    tmp_buf.clone(),
                    &mut args.buf1,
                    &mut args.buf2,
                    shared,
                )
            };

            let mut executor = StepExecutor::new(reactor, fut);

            assert_eq!(check_poll(executor.step()), None);

            let req_data =
                concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes();

            sock.borrow_mut().add_readable(req_data);
            sock.borrow_mut().allow_write(1024);

            assert_eq!(check_poll(executor.step()), None);

            // read message
            let _ = r_from_conn.try_recv().unwrap();

            let msg = concat!(
                "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ",
                "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#",
                "4:body,6:hello\n,}",
            );

            let msg = zmq::Message::from(msg.as_bytes());
            let msg = arena::Arc::new(msg, msg_mem).unwrap();

            let scratch =
                arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem)
                    .unwrap();

            let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
            let resp = arena::Rc::new(resp, resp_mem).unwrap();

            assert!(s_to_conn.try_send((resp, 0)).is_ok());

            assert_eq!(check_poll(executor.step()), Some(true));

            let data = sock.borrow_mut().take_writable();

            let expected = concat!(
                "HTTP/1.1 200 OK\r\n",
                "Content-Type: text/plain\r\n",
                "Content-Length: 6\r\n",
                "\r\n",
                "hello\n",
            );

            assert_eq!(str::from_utf8(&data).unwrap(), expected);
        }
    }

    #[allow(clippy::too_many_arguments)]
    async fn server_stream_connection_inner_fut(
        token: CancellationToken,
        sock: Rc<RefCell<FakeSock>>,
        secure: bool,
        s_from_conn: channel::LocalSender<zmq::Message>,
        s_stream_from_conn: channel::LocalSender<(ArrayVec, zmq::Message)>,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
        rb_tmp: Rc<TmpBuffer>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
        tmp_buf: Rc<RefCell<Vec<u8>>>,
        shared: arena::Rc<StreamSharedData>,
    ) -> Result<(), Error> {
        let mut cid = ArrayString::from_str("1").unwrap();
        let mut cid_provider = SimpleCidProvider { cid };

        let sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();
        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);
        let s_stream_from_conn = AsyncLocalSender::new(s_stream_from_conn);

        let buffer_size = 1024;

        let timeout = Duration::from_millis(5_000);

        server_stream_connection_inner(
            token,
            &mut cid,
            &mut cid_provider,
            sock,
            None,
            secure,
            buffer_size,
            10,
            &rb_tmp,
            packet_buf,
            tmp_buf,
            timeout,
            false,
            "test",
            s_from_conn,
            s_stream_from_conn,
            &r_to_conn,
            shared,
        )
        .await
    }

    pub struct BenchServerStreamConnection {
        reactor: Reactor,
        msg_mem: Arc<arena::ArcMemory<zmq::Message>>,
        scratch_mem: Rc<arena::RcMemory<RefCell<zhttppacket::ParseScratch<'static>>>>,
        resp_mem: Rc<arena::RcMemory<zhttppacket::OwnedResponse>>,
        shared_mem: Rc<arena::RcMemory<StreamSharedData>>,
        rb_tmp: Rc<TmpBuffer>,
        packet_buf: Rc<RefCell<Vec<u8>>>,
        tmp_buf: Rc<RefCell<Vec<u8>>>,
    }

    #[allow(clippy::new_without_default)]
    impl BenchServerStreamConnection {
        pub fn new() -> Self {
            Self {
                reactor: Reactor::new(100),
                msg_mem: Arc::new(arena::ArcMemory::new(1)),
                scratch_mem: Rc::new(arena::RcMemory::new(1)),
                resp_mem: Rc::new(arena::RcMemory::new(1)),
                shared_mem: Rc::new(arena::RcMemory::new(1)),
                rb_tmp: Rc::new(TmpBuffer::new(1024)),
                packet_buf: Rc::new(RefCell::new(vec![0; 2048])),
                tmp_buf: Rc::new(RefCell::new(vec![0; 1024])),
            }
        }

        pub fn init(&self) -> Rc<RefCell<FakeSock>> {
            Rc::new(RefCell::new(FakeSock::new()))
        }

        pub fn run(&self, sock: &Rc<RefCell<FakeSock>>) {
            let reactor = &self.reactor;
            let msg_mem = &self.msg_mem;
            let scratch_mem = &self.scratch_mem;
            let resp_mem = &self.resp_mem;
            let shared_mem = &self.shared_mem;
            let rb_tmp = &self.rb_tmp;
            let packet_buf = &self.packet_buf;
            let tmp_buf = &self.tmp_buf;

            let (s_to_conn, r_to_conn) =
                channel::local_channel(1, 1, &reactor.local_registration_memory());
            let (s_from_conn, r_from_conn) =
                channel::local_channel(1, 2, &reactor.local_registration_memory());
            let (s_stream_from_conn, _r_stream_from_conn) =
                channel::local_channel(1, 2, &reactor.local_registration_memory());

            let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

            let fut = {
                let sock = sock.clone();
                let s_from_conn = s_from_conn
                    .try_clone(&reactor.local_registration_memory())
                    .unwrap();
                let shared = arena::Rc::new(StreamSharedData::new(), shared_mem).unwrap();

                server_stream_connection_inner_fut(
                    token,
                    sock,
                    false,
                    s_from_conn,
                    s_stream_from_conn,
                    r_to_conn,
                    rb_tmp.clone(),
                    packet_buf.clone(),
                    tmp_buf.clone(),
                    shared,
                )
            };

            let mut executor = StepExecutor::new(reactor, fut);

            assert_eq!(check_poll(executor.step()), None);

            let req_data =
                concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes();

            sock.borrow_mut().add_readable(req_data);
            sock.borrow_mut().allow_write(1024);

            assert_eq!(check_poll(executor.step()), None);

            // read message
            let _ = r_from_conn.try_recv().unwrap();

            let msg = concat!(
                "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ",
                "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#",
                "4:body,6:hello\n,}",
            );

            let msg = zmq::Message::from(msg.as_bytes());
            let msg = arena::Arc::new(msg, msg_mem).unwrap();

            let scratch =
                arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem)
                    .unwrap();

            let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
            let resp = arena::Rc::new(resp, resp_mem).unwrap();

            assert!(s_to_conn.try_send((resp, 0)).is_ok());

            // connection reusable
            assert_eq!(check_poll(executor.step()), None);

            let data = sock.borrow_mut().take_writable();

            let expected = concat!(
                "HTTP/1.1 200 OK\r\n",
                "Content-Type: text/plain\r\n",
                "Content-Length: 6\r\n",
                "\r\n",
                "hello\n",
            );

            assert_eq!(str::from_utf8(&data).unwrap(), expected);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::testutil::*;
    use super::*;
    use crate::buffer::TmpBuffer;
    use crate::channel;
    use crate::websocket::Decoder;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::task::Poll;
    use std::time::Instant;

    #[test]
    fn ws_ext_header() {
        let config = websocket::PerMessageDeflateConfig::default();
        let mut dest = ArrayVec::::new();
        write_ws_ext_header_value(&config, &mut dest).unwrap();
        let expected = "permessage-deflate";
        assert_eq!(str::from_utf8(&dest).unwrap(), expected);

        let mut config = websocket::PerMessageDeflateConfig::default();
        config.client_no_context_takeover = true;
        let mut dest = ArrayVec::::new();
        write_ws_ext_header_value(&config, &mut dest).unwrap();
        let expected = "permessage-deflate; client_no_context_takeover";
        assert_eq!(str::from_utf8(&dest).unwrap(), expected);
    }

    #[test]
    fn message_tracker() {
        let mut t = MessageTracker::new(2);

        assert_eq!(t.in_progress(), false);
        assert_eq!(t.current(), None);

        t.start(websocket::OPCODE_TEXT).unwrap();
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, false)));

        t.extend(5);
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 5, false)));

        t.consumed(2, false);
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, false)));

        t.done();
        assert_eq!(t.in_progress(), false);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, true)));

        t.start(websocket::OPCODE_TEXT).unwrap();
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, true)));

        t.consumed(3, false);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, true)));

        t.consumed(0, true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, false)));

        t.done();
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, true)));

        t.consumed(0, true);
        assert_eq!(t.current(), None);

        for _ in 0..t.items.capacity() {
            t.start(websocket::OPCODE_TEXT).unwrap();
            t.done();
        }
        let r = t.start(websocket::OPCODE_TEXT);
        assert!(r.is_err());
    }

    #[test]
    fn early_body() {
        let reactor = Reactor::new(100);

        let sock = Rc::new(RefCell::new(FakeSock::new()));
        sock.borrow_mut().allow_write(1024);

        let sock = RefCell::new(AsyncFakeSock::new(sock));

        let rb_tmp = Rc::new(TmpBuffer::new(12));

        let mut buf1 = RingBuffer::new(12, &rb_tmp);
        let mut buf2 = RingBuffer::new(12, &rb_tmp);

        buf2.write(b"foo").unwrap();

        let handler = RequestSendHeader::new(
            io_split(&sock),
            &mut buf1,
            &mut buf2,
            http1::ServerProtocol::new(),
            3,
        );
        assert_eq!(handler.early_body.borrow().overflow.is_none(), true);

        handler.append_body(b"hello", false, "").unwrap();
        assert_eq!(handler.early_body.borrow().overflow.is_none(), true);

        handler.append_body(b" world", false, "").unwrap();
        assert_eq!(handler.early_body.borrow().overflow.is_some(), true);

        handler.append_body(b"!", false, "").unwrap();
        handler.append_body(b"!", false, "").unwrap_err();

        {
            let mut executor = StepExecutor::new(&reactor, handler.send_header());
            assert_eq!(check_poll(executor.step()), Some(()));
        }

        assert_eq!(handler.early_body.borrow().overflow.is_none(), true);

        let handler = handler.send_header_done();

        let header = sock.borrow_mut().inner.borrow_mut().take_writable();
        assert_eq!(header, b"foo");

        let w = handler.w.borrow();
        let mut buf_arr = [&b""[..]; VECTORED_MAX - 2];
        let bufs = w.buf.get_ref_vectored(&mut buf_arr);
        assert_eq!(bufs[0], b"hello wor");
        assert_eq!(bufs[1], b"ld!");
    }

    async fn server_req_fut(
        token: CancellationToken,
        sock: Rc<RefCell<FakeSock>>,
        secure: bool,
        s_from_conn: channel::LocalSender<zmq::Message>,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
    ) -> Result<(), Error> {
        let mut cid = ArrayString::from_str("1").unwrap();
        let mut cid_provider = SimpleCidProvider { cid };

        let sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();
        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(1024));
        let packet_buf = Rc::new(RefCell::new(vec![0; 2048]));

        let timeout = Duration::from_millis(5_000);

        server_req_connection_inner(
            token,
            &mut cid,
            &mut cid_provider,
            sock,
            None,
            secure,
            buffer_size,
            buffer_size,
            &rb_tmp,
            packet_buf,
            timeout,
            s_from_conn,
            &r_to_conn,
        )
        .await
    }

    #[test]
    fn server_req_without_body() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let resp_mem = Rc::new(arena::RcMemory::new(1));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());

        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_req_fut(token, sock, false, s_from_conn, r_to_conn)
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data = concat!(
            "GET /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Connection: close\r\n",
            "\r\n"
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T148:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur",
            "i,23:http://example.com/path,7:headers,52:22:4:Host,11:exa",
            "mple.com,]22:10:Connection,5:close,]]}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h",
            "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell",
            "o\n,}",
        );

        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        assert_eq!(check_poll(executor.step()), Some(()));

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: close\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_req_with_body() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let resp_mem = Rc::new(arena::RcMemory::new(1));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());

        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_req_fut(token, sock, false, s_from_conn, r_to_conn)
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data = concat!(
            "POST /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Content-Length: 6\r\n",
            "Connection: close\r\n",
            "\r\n",
            "hello\n"
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T191:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,4:POST,3:u",
            "ri,23:http://example.com/path,7:headers,78:22:4:Host,11:ex",
            "ample.com,]22:14:Content-Length,1:6,]22:10:Connection,5:cl",
            "ose,]]4:body,6:hello\n,}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h",
            "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell",
            "o\n,}",
        );

        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);
assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), Some(())); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_req_timeout() { let now = Instant::now(); let reactor = Reactor::new_with_time(100, now); let sock = Rc::new(RefCell::new(FakeSock::new())); let (_s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, _r_from_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); server_req_fut(token, sock, false, s_from_conn, r_to_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); executor.advance_time(now + Duration::from_millis(5_000)); match executor.step() { Poll::Ready(Err(Error::StreamTimeout)) => {} _ => panic!("unexpected state"), } } #[test] fn server_req_pipeline() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_fut(token, sock, false, 
s_from_conn, r_to_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "GET /path1 HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n", "GET /path2 HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n", ) .as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T123:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:http://example.com/path1,7:headers,26:22:4:Host,11:ex", "ample.com,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); 
        sock.borrow_mut().allow_write(1024);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T123:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur",
            "i,24:http://example.com/path2,7:headers,26:22:4:Host,11:ex",
            "ample.com,]]}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h",
            "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell",
            "o\n,}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_req_secure() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let resp_mem = Rc::new(arena::RcMemory::new(1));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_req_fut(token, sock, true, s_from_conn, r_to_conn)
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data = concat!(
            "GET /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Connection: close\r\n",
            "\r\n"
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T149:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur",
            "i,24:https://example.com/path,7:headers,52:22:4:Host,11:ex",
            "ample.com,]22:10:Connection,5:close,]]}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h",
            "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell",
            "o\n,}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        assert_eq!(check_poll(executor.step()), Some(()));

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: close\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    async fn server_stream_fut(
        token: CancellationToken,
        sock: Rc<RefCell<FakeSock>>,
        secure: bool,
        allow_compression: bool,
        s_from_conn: channel::LocalSender<zmq::Message>,
        s_stream_from_conn: channel::LocalSender<(ArrayVec<u8, 64>, zmq::Message)>,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedResponse>, usize)>,
    ) -> Result<(), Error> {
        let mut cid = ArrayString::from_str("1").unwrap();
        let mut cid_provider = SimpleCidProvider { cid };

        let sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();

        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);
        let s_stream_from_conn = AsyncLocalSender::new(s_stream_from_conn);

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(1024));
        let packet_buf = Rc::new(RefCell::new(vec![0; 2048]));
        let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size]));

        let timeout = Duration::from_millis(5_000);

        let shared_mem = Rc::new(arena::RcMemory::new(1));
        let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap();

        server_stream_connection_inner(
            token,
            &mut cid,
            &mut cid_provider,
            sock,
            None,
            secure,
            buffer_size,
            10,
            &rb_tmp,
            packet_buf,
            tmp_buf,
            timeout,
            allow_compression,
            "test",
            s_from_conn,
            s_stream_from_conn,
            &r_to_conn,
            shared,
        )
        .await
    }

    #[test]
    fn server_stream_without_body() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let resp_mem = Rc::new(arena::RcMemory::new(1));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (s_stream_from_conn, _r_stream_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_stream_fut(
                token,
                sock,
                false,
                false,
                s_from_conn,
                s_stream_from_conn,
                r_to_conn,
            )
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data =
            concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T179:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t",
            "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea",
            "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre",
            "am,4:true!}",
        );
        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ",
            "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#",
            "4:body,6:hello\n,}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        // connection reusable
        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_stream_with_body() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let resp_mem = Rc::new(arena::RcMemory::new(1));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (s_stream_from_conn, r_stream_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_stream_fut(
                token,
                sock,
                false,
                false,
                s_from_conn,
                s_stream_from_conn,
                r_to_conn,
            )
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data = concat!(
            "POST /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n"
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T220:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t",
            "rue!}6:method,4:POST,3:uri,23:http://example.com/path,7:he",
            "aders,52:22:4:Host,11:example.com,]22:14:Content-Length,1:",
            "6,]]7:credits,4:1024#4:more,4:true!6:stream,4:true!}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!("T69:7:credits,4:1024#3:seq,1:0#2:id,1:1,4:from,7:handler,4:type,6:credit,}",);
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        // read message
        let (addr, msg) = r_stream_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        assert_eq!(addr.as_ref(), "handler".as_bytes());

        let buf = &msg[..];

        let expected = concat!(
            "T74:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr",
            "ue!}4:body,6:hello\n,}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ",
            "e,10:text/plain,]]3:seq,1:1#4:from,7:handler,4:code,3:200#",
            "4:body,6:hello\n,}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        // connection reusable
        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_stream_chunked() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(2));
        let scratch_mem = Rc::new(arena::RcMemory::new(2));
        let resp_mem = Rc::new(arena::RcMemory::new(2));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (s_stream_from_conn, _r_stream_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_stream_fut(
                token,
                sock,
                false,
                false,
                s_from_conn,
                s_stream_from_conn,
                r_to_conn,
            )
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data =
            concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T179:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t",
            "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea",
            "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre",
            "am,4:true!}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T125:4:more,4:true!2:id,1:1,6:reason,2:OK,7:headers,34:30:",
            "12:Content-Type,10:text/plain,]]3:seq,1:0#4:from,7:handler",
            ",4:code,3:200#}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();
        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let msg = concat!("T52:3:seq,1:1#2:id,1:1,4:from,7:handler,4:body,6:hello\n,}");
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        // connection reusable
        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: Transfer-Encoding\r\n",
            "Transfer-Encoding: chunked\r\n",
            "\r\n",
            "6\r\n",
            "hello\n",
            "\r\n",
            "0\r\n",
            "\r\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_stream_early_response() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let resp_mem = Rc::new(arena::RcMemory::new(1));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (s_stream_from_conn, _r_stream_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            server_stream_fut(
                token,
                sock,
                false,
                false,
                s_from_conn,
                s_stream_from_conn,
                r_to_conn,
            )
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        assert_eq!(check_poll(executor.step()), None);

        // no messages yet
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // fill the connection's outbound message queue
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true);
        assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true);
        drop(s_from_conn);

        let req_data = concat!(
            "POST /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        // connection won't be able to send a message yet
        assert_eq!(check_poll(executor.step()), None);

        // read bogus message
        let msg = r_from_conn.try_recv().unwrap();
        assert_eq!(msg.is_empty(), true);

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        // now connection will be able to send a message
        assert_eq!(check_poll(executor.step()), None);

        // read real message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T220:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t",
            "rue!}6:method,4:POST,3:uri,23:http://example.com/path,7:he",
            "aders,52:22:4:Host,11:example.com,]22:14:Content-Length,1:",
            "6,]]7:credits,4:1024#4:more,4:true!6:stream,4:true!}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T150:2:id,1:1,6:reason,11:Bad Request,7:headers,34:30:12:C",
            "ontent-Type,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:c",
            "ode,3:400#4:body,18:stopping this now\n,}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);
        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        assert_eq!(check_poll(executor.step()), Some(()));

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 400 Bad Request\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: close\r\n",
            "Content-Length: 18\r\n",
            "\r\n",
            "stopping this now\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_websocket() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(2));
        let scratch_mem = Rc::new(arena::RcMemory::new(2));
        let resp_mem = Rc::new(arena::RcMemory::new(2));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (s_stream_from_conn, r_stream_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();

            server_stream_fut(
                token,
                sock,
                false,
                false,
                s_from_conn,
                s_stream_from_conn,
                r_to_conn,
            )
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        let req_data = concat!(
            "GET /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Upgrade: websocket\r\n",
            "Sec-WebSocket-Version: 13\r\n",
            "Sec-WebSocket-Key: abcde\r\n",
            "\r\n"
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        assert_eq!(check_poll(executor.step()), None);

        // read message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T255:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t",
            "rue!}6:method,3:GET,3:uri,21:ws://example.com/path,7:heade",
            "rs,119:22:4:Host,11:example.com,]22:7:Upgrade,9:websocket,",
            "]30:21:Sec-WebSocket-Version,2:13,]29:17:Sec-WebSocket-Key",
            ",5:abcde,]]7:credits,4:1024#}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T98:2:id,1:1,6:reason,19:Switching Protocols,3:seq,1:0#4:f",
            "rom,7:handler,4:code,3:101#7:credits,4:1024#}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 101 Switching Protocols\r\n",
            "Upgrade: websocket\r\n",
            "Connection: Upgrade\r\n",
            "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n",
            "\r\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);

        // send message

        let mut data = vec![0; 1024];
        let body = b"hello";
        let size = websocket::write_header(
            true,
            false,
            websocket::OPCODE_TEXT,
            body.len(),
            None,
            &mut data,
        )
        .unwrap();
        data[size..(size + body.len())].copy_from_slice(body);
        let data = &data[..(size + body.len())];
        sock.borrow_mut().add_readable(data);

        assert_eq!(check_poll(executor.step()), None);

        // read message
        let (addr, msg) = r_stream_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        assert_eq!(addr.as_ref(), "handler".as_bytes());

        let buf = &msg[..];

        let expected = concat!(
            "T96:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr",
            "ue!}12:content-type,4:text,4:body,5:hello,}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        // recv message

        let msg = concat!(
            "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4",
            ":true!}12:content-type,4:text,4:body,5:world,}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let fi = websocket::read_header(&data).unwrap();
        assert_eq!(fi.fin, true);
        assert_eq!(fi.opcode, websocket::OPCODE_TEXT);

        assert!(data.len() >= fi.payload_offset + fi.payload_size);

        let content = &data[fi.payload_offset..(fi.payload_offset + fi.payload_size)];

        assert_eq!(str::from_utf8(content).unwrap(), "world");
    }

    #[test]
    fn server_websocket_with_deflate() {
        let reactor = Reactor::new(100);
        let msg_mem = Arc::new(arena::ArcMemory::new(2));
        let scratch_mem = Rc::new(arena::RcMemory::new(2));
        let resp_mem = Rc::new(arena::RcMemory::new(2));

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_to_conn, r_to_conn) =
            channel::local_channel(1, 1, &reactor.local_registration_memory());
        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (s_stream_from_conn, r_stream_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());
        let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();

            server_stream_fut(
                token,
                sock,
                false,
                true,
                s_from_conn,
                s_stream_from_conn,
                r_to_conn,
            )
        };

        let mut executor = StepExecutor::new(&reactor, fut);

        let req_data = concat!(
            "GET /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Upgrade: websocket\r\n",
            "Sec-WebSocket-Version: 13\r\n",
            "Sec-WebSocket-Key: abcde\r\n",
            "Sec-WebSocket-Extensions: permessage-deflate\r\n",
            "\r\n"
        )
        .as_bytes();

        sock.borrow_mut().add_readable(req_data);

        assert_eq!(check_poll(executor.step()), None);

        // read message
        let msg = r_from_conn.try_recv().unwrap();

        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T308:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t",
            "rue!}6:method,3:GET,3:uri,21:ws://example.com/path,7:heade",
            "rs,173:22:4:Host,11:example.com,]22:7:Upgrade,9:websocket,",
            "]30:21:Sec-WebSocket-Version,2:13,]29:17:Sec-WebSocket-Key",
            ",5:abcde,]50:24:Sec-WebSocket-Extensions,18:permessage-def",
            "late,]]7:credits,3:768#}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let msg = concat!(
            "T98:2:id,1:1,6:reason,19:Switching Protocols,3:seq,1:0#4:f",
            "rom,7:handler,4:code,3:101#7:credits,4:1024#}",
        );
        let msg = zmq::Message::from(msg.as_bytes());
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap();
        let resp = arena::Rc::new(resp, &resp_mem).unwrap();

        assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();
        assert_eq!(data.is_empty(), true);

        sock.borrow_mut().allow_write(1024);

        assert_eq!(check_poll(executor.step()), None);

        let data = sock.borrow_mut().take_writable();

        let expected = concat!(
            "HTTP/1.1 101 Switching Protocols\r\n",
            "Upgrade: websocket\r\n",
            "Connection: Upgrade\r\n",
            "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n",
            "Sec-WebSocket-Extensions: permessage-deflate\r\n",
            "\r\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);

        // send message

        let mut data = vec![0; 1024];
        let body = {
            let src = b"hello";
            let mut enc = websocket::DeflateEncoder::new();
            let mut dest = vec![0; 1024];
            let (read, written, output_end) = enc.encode(src, true, &mut dest).unwrap();
            assert_eq!(read, 5);
assert_eq!(output_end, true); dest.truncate(written); dest }; let size = websocket::write_header( true, true, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(&body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_stream_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); assert_eq!(addr.as_ref(), "handler".as_bytes()); let buf = &msg[..]; let expected = concat!( "T96:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", ":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = { let src = &data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; let mut dec = websocket::DeflateDecoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = dec.decode(src, true, &mut dest).unwrap(); assert_eq!(read, src.len()); assert_eq!(output_end, true); dest.truncate(written); dest }; assert_eq!(str::from_utf8(&content).unwrap(), "world"); } async fn client_req_fut( id: Option>, 
        zreq: arena::Rc<zhttppacket::OwnedRequest>,
        sock: Rc<RefCell<FakeSock>>,
        s_from_conn: channel::LocalSender<zmq::Message>,
    ) -> Result<(), Error> {
        let mut sock = AsyncFakeSock::new(sock);

        let s_from_conn = AsyncLocalSender::new(s_from_conn);

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(buffer_size));

        let mut buf1 = RingBuffer::new(buffer_size, &rb_tmp);
        let mut buf2 = RingBuffer::new(buffer_size, &rb_tmp);
        let mut body_buf = Buffer::new(buffer_size);
        let packet_buf = RefCell::new(vec![0; 2048]);

        let zreq = zreq.get().get();

        let rdata = match &zreq.ptype {
            zhttppacket::RequestPacket::Data(rdata) => rdata,
            _ => panic!("unexpected init packet"),
        };

        let url = url::Url::parse(rdata.uri).unwrap();

        let msg = match client_req_handler(
            "test",
            id.as_deref(),
            &mut sock,
            zreq,
            rdata.method,
            &url,
            true,
            false,
            &mut buf1,
            &mut buf2,
            &mut body_buf,
            &packet_buf,
        )
        .await?
        {
            ClientHandlerDone::Complete(r, _) => r,
            ClientHandlerDone::Redirect(_, _, _) => panic!("unexpected redirect"),
        };

        s_from_conn.send(msg).await?;

        Ok(())
    }

    #[test]
    fn client_req_without_id() {
        let reactor = Reactor::new(100);

        let msg_mem = Arc::new(arena::ArcMemory::new(1));
        let scratch_mem = Rc::new(arena::RcMemory::new(1));
        let req_mem = Rc::new(arena::RcMemory::new(1));

        let data = concat!(
            "T74:7:headers,16:12:3:Foo,3:Bar,]]3:uri,19:https://example.co",
            "m,6:method,3:GET,}",
        )
        .as_bytes();

        let msg = zmq::Message::from(data);
        let msg = arena::Arc::new(msg, &msg_mem).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap();

        let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap();
        let zreq = arena::Rc::new(zreq, &req_mem).unwrap();

        let sock = Rc::new(RefCell::new(FakeSock::new()));

        let (s_from_conn, r_from_conn) =
            channel::local_channel(1, 2, &reactor.local_registration_memory());

        let fut = {
            let sock = sock.clone();
            let s_from_conn = s_from_conn
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            client_req_fut(None, zreq, sock, s_from_conn)
        };

        let mut executor =
StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "GET / HTTP/1.1\r\n", "Host: example.com\r\n", "Foo: Bar\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); // now handler will be able to send a message and finish assert_eq!(check_poll(executor.step()), Some(())); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T117:4:code,3:200#6:reason,2:OK,7:headers,60:30:12:Content", "-Type,10:text/plain,]22:14:Content-Length,1:6,]]4:body,6:h", "ello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn client_req_with_id() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_mem = Rc::new(arena::RcMemory::new(1)); let data = concat!( "T83:7:headers,16:12:3:Foo,3:Bar,]]3:uri,19:https://example.co", "m,6:method,3:GET,2:id,1:1,}", ) .as_bytes(); let msg = 
zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); client_req_fut(Some(b"1".to_vec()), zreq, sock, s_from_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "GET / HTTP/1.1\r\n", "Host: example.com\r\n", "Foo: Bar\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); // now handler will be able to send a message and finish assert_eq!(check_poll(executor.step()), Some(())); // read real message let msg = r_from_conn.try_recv().unwrap(); // 
        // no other messages
        assert_eq!(r_from_conn.try_recv().is_err(), true);

        let buf = &msg[..];

        let expected = concat!(
            "T126:2:id,1:1,4:code,3:200#6:reason,2:OK,7:headers,60:30:1",
            "2:Content-Type,10:text/plain,]22:14:Content-Length,1:6,]]4",
            ":body,6:hello\n,}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);
    }

    async fn client_stream_fut(
        id: Vec<u8>,
        zreq: arena::Rc<zhttppacket::OwnedRequest>,
        sock: Rc<RefCell<FakeSock>>,
        allow_compression: bool,
        r_to_conn: channel::LocalReceiver<(arena::Rc<zhttppacket::OwnedRequest>, usize)>,
        s_from_conn: channel::LocalSender<zmq::Message>,
        shared: arena::Rc<StreamSharedData>,
    ) -> Result<(), Error> {
        let mut sock = AsyncFakeSock::new(sock);

        let f = TrackFlag::default();

        let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f);
        let s_from_conn = AsyncLocalSender::new(s_from_conn);

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(buffer_size));

        let mut buf1 = RingBuffer::new(buffer_size, &rb_tmp);
        let mut buf2 = RingBuffer::new(buffer_size, &rb_tmp);
        let packet_buf = RefCell::new(vec![0; 2048]);
        let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size]));

        let mut response_received = false;

        let refresh_stream_timeout = || {};
        let refresh_session_timeout = || {};

        let zreq = zreq.get().get();

        let rdata = match &zreq.ptype {
            zhttppacket::RequestPacket::Data(rdata) => rdata,
            _ => panic!("unexpected init packet"),
        };

        let url = url::Url::parse(rdata.uri).unwrap();

        let log_id = "test";
        let instance_id = "test";

        let zsess_out = ZhttpServerStreamSessionOut::new(
            instance_id,
            &id,
            &packet_buf,
            &s_from_conn,
            shared.get(),
        );

        zsess_out.check_send().await;
        zsess_out.try_send_msg(zhttppacket::Response::new_keep_alive(b"", &[]))?;

        let mut zsess_in = ZhttpServerStreamSessionIn::new(
            log_id,
            &id,
            rdata.credits,
            &r_to_conn,
            shared.get(),
            &refresh_session_timeout,
        );

        let _persistent = client_stream_handler(
            "test",
            &mut sock,
            zreq,
            rdata.method,
            &url,
            true,
            false,
            &mut buf1,
            &mut buf2,
            10,
            allow_compression,
            &tmp_buf,
            &mut zsess_in,
            &zsess_out,
            &mut response_received,
            &refresh_stream_timeout,
        )
        .await?;

        Ok(())
    }
#[test] fn client_stream() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T165:7:credits,4:1024#4:more,4:true!7:headers,34:30:12:Conten", "t-Type,10:text/plain,]]3:uri,24:https://example.com/path,6:me", "thod,4:POST,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, false, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler will be able to send a 
message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T91:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:mu", "lti,4:true!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!("T52:3:seq,1:1#2:id,1:1,4:from,7:handler,4:body,6:hello\n,}"); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let expected = concat!("6\r\nhello\n\r\n0\r\n\r\n",); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); assert_eq!(check_poll(executor.step()), None); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 
OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T173:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:m", "ulti,4:true!}4:code,3:200#6:reason,2:OK,7:headers,60:30:12", ":Content-Type,10:text/plain,]22:14:Content-Length,1:6,]]4:", "more,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); assert_eq!(check_poll(executor.step()), Some(())); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T74:4:from,4:test,2:id,1:1,3:seq,1:3#3:ext,15:5:mu", "lti,4:true!}4:body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn client_websocket() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T115:7:credits,4:1024#7:headers,16:12:3:Foo,3:Bar,]]3:uri,22:", "wss://example.com/path,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = 
s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, false, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let buf = sock.borrow_mut().take_writable(); // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let mut req = httparse::Request::new(&mut headers); match req.parse(&buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key 
}; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "Foo: Bar\r\n", "\r\n", ), ws_key ); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "\r\n", ), ws_accept ); sock.borrow_mut().add_readable(resp_data.as_bytes()); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = format!( concat!( "handler T249:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:m", "ulti,4:true!}}4:code,3:101#6:reason,19:Switching Protocols", ",7:headers,114:22:7:Upgrade,9:websocket,]24:10:Connection,", "7:Upgrade,]56:20:Sec-WebSocket-Accept,28:{},]]7:credits,4:", "1024#}}", ), ws_accept ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // send message let mut data = vec![0; 1024]; let body = b"hello"; let size = websocket::write_header( true, false, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); let buf = &msg[..]; let expected = concat!( "handler T96:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:mu", "lti,4:true!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", 
":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let mut data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = &mut data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; websocket::apply_mask(content, fi.mask.unwrap(), 0); assert_eq!(str::from_utf8(content).unwrap(), "world"); } #[test] fn client_websocket_with_deflate() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T115:7:credits,4:1024#7:headers,16:12:3:Foo,3:Bar,]]3:uri,22:", "wss://example.com/path,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = 
Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, true, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let buf = sock.borrow_mut().take_writable(); // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let mut req = httparse::Request::new(&mut headers); match req.parse(&buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", 
"Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "Sec-WebSocket-Extensions: permessage-deflate\r\n", "Foo: Bar\r\n", "\r\n", ), ws_key ); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "Sec-WebSocket-Extensions: permessage-deflate\r\n", "\r\n", ), ws_accept ); sock.borrow_mut().add_readable(resp_data.as_bytes()); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = format!( concat!( "handler T302:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:m", "ulti,4:true!}}4:code,3:101#6:reason,19:Switching Protocols", ",7:headers,168:22:7:Upgrade,9:websocket,]24:10:Connection,", "7:Upgrade,]56:20:Sec-WebSocket-Accept,28:{},]50:24:Sec-Web", "Socket-Extensions,18:permessage-deflate,]]7:credits,3:768#", "}}", ), ws_accept ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // send message let mut data = vec![0; 1024]; let body = { let src = b"hello"; let mut enc = websocket::DeflateEncoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = enc.encode(src, true, &mut dest).unwrap(); assert_eq!(read, 5); assert_eq!(output_end, true); dest.truncate(written); dest }; let size = websocket::write_header( true, true, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(&body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); let buf = &msg[..]; let 
expected = concat!( "handler T96:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:mu", "lti,4:true!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", ":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let mut data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = { let src = &mut data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; websocket::apply_mask(src, fi.mask.unwrap(), 0); let mut dec = websocket::DeflateDecoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = dec.decode(src, true, &mut dest).unwrap(); assert_eq!(read, src.len()); assert_eq!(output_end, true); dest.truncate(written); dest }; assert_eq!(str::from_utf8(&content).unwrap(), "world"); } #[test] fn bench_server_req_handler() { let t = BenchServerReqHandler::new(); t.run(&mut t.init()); } #[test] fn bench_server_req_connection() { let t = BenchServerReqConnection::new(); t.run(&mut t.init()); } #[test] fn bench_server_stream_handler() { let t = BenchServerStreamHandler::new(); t.run(&mut t.init()); } #[test] fn bench_server_stream_connection() { let t = BenchServerStreamConnection::new(); t.run(&mut t.init()); } } condure-1.10.0/src/event.rs000064400000000000000000000524621046102023000136310ustar 00000000000000/* * Copyright (C) 
2021-2023 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::list;
use mio::event::Source;
use mio::{Events, Interest, Poll, Token, Waker};
use slab::Slab;
use std::cell::{Cell, RefCell};
use std::io;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

const EVENTS_MAX: usize = 1024;
const LOCAL_BUDGET: u32 = 10;

pub type Readiness = Option<Interest>;

pub trait ReadinessExt {
    fn contains_any(&self, readiness: Interest) -> bool;

    fn merge(&mut self, readiness: Interest);
}

impl ReadinessExt for Readiness {
    fn contains_any(&self, readiness: Interest) -> bool {
        match *self {
            Some(cur) => {
                (readiness.is_readable() && cur.is_readable())
                    || (readiness.is_writable() && cur.is_writable())
            }
            None => false,
        }
    }

    fn merge(&mut self, readiness: Interest) {
        match *self {
            Some(cur) => *self = Some(cur.add(readiness)),
            None => *self = Some(readiness),
        }
    }
}

struct SourceItem {
    subtoken: Token,
    interests: Interest,
    readiness: Readiness,
}

struct RegisteredSources {
    nodes: Slab<list::Node<SourceItem>>,
    ready: list::List,
}

struct LocalSources {
    registered_sources: RefCell<RegisteredSources>,
}

impl LocalSources {
    fn new(max_sources: usize) -> Self {
        Self {
            registered_sources: RefCell::new(RegisteredSources {
                nodes: Slab::with_capacity(max_sources),
                ready: list::List::default(),
            }),
        }
    }

    fn register(&self, subtoken: Token, interests: Interest) -> Result<usize, io::Error> {
        let sources = &mut *self.registered_sources.borrow_mut();

        if sources.nodes.len() == sources.nodes.capacity() {
            return
                Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        Ok(sources.nodes.insert(list::Node::new(SourceItem {
            subtoken,
            interests,
            readiness: None,
        })))
    }

    fn deregister(&self, key: usize) -> Result<(), io::Error> {
        let sources = &mut *self.registered_sources.borrow_mut();

        if sources.nodes.contains(key) {
            sources.ready.remove(&mut sources.nodes, key);
            sources.nodes.remove(key);
        }

        Ok(())
    }

    fn set_readiness(&self, key: usize, readiness: Interest) -> Result<(), io::Error> {
        let sources = &mut *self.registered_sources.borrow_mut();

        if !sources.nodes.contains(key) {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        }

        let item = &mut sources.nodes[key].value;

        if !((item.interests.is_readable() && readiness.is_readable())
            || (item.interests.is_writable() && readiness.is_writable()))
        {
            // not of interest
            return Ok(());
        }

        let orig = item.readiness;

        item.readiness.merge(readiness);

        if item.readiness != orig {
            sources.ready.remove(&mut sources.nodes, key);
            sources.ready.push_back(&mut sources.nodes, key);
        }

        Ok(())
    }

    fn has_events(&self) -> bool {
        let sources = &*self.registered_sources.borrow();

        !sources.ready.is_empty()
    }

    fn next_event(&self) -> Option<(Token, Interest)> {
        let sources = &mut *self.registered_sources.borrow_mut();

        match sources.ready.pop_front(&mut sources.nodes) {
            Some(key) => {
                let item = &mut sources.nodes[key].value;

                let readiness = item.readiness.take().unwrap();

                Some((item.subtoken, readiness))
            }
            None => None,
        }
    }
}

struct SyncSources {
    registered_sources: Mutex<RegisteredSources>,
    waker: Waker,
}

impl SyncSources {
    fn new(max_sources: usize, waker: Waker) -> Self {
        Self {
            registered_sources: Mutex::new(RegisteredSources {
                nodes: Slab::with_capacity(max_sources),
                ready: list::List::default(),
            }),
            waker,
        }
    }

    fn register(&self, subtoken: Token, interests: Interest) -> Result<usize, io::Error> {
        let sources = &mut *self.registered_sources.lock().unwrap();

        if sources.nodes.len() == sources.nodes.capacity() {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }
        Ok(sources.nodes.insert(list::Node::new(SourceItem {
            subtoken,
            interests,
            readiness: None,
        })))
    }

    fn deregister(&self, key: usize) -> Result<(), io::Error> {
        let sources = &mut *self.registered_sources.lock().unwrap();

        if sources.nodes.contains(key) {
            sources.ready.remove(&mut sources.nodes, key);
            sources.nodes.remove(key);
        }

        Ok(())
    }

    fn set_readiness(&self, key: usize, readiness: Interest) -> Result<(), io::Error> {
        let sources = &mut *self.registered_sources.lock().unwrap();

        if !sources.nodes.contains(key) {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        }

        let item = &mut sources.nodes[key].value;

        if !((item.interests.is_readable() && readiness.is_readable())
            || (item.interests.is_writable() && readiness.is_writable()))
        {
            // not of interest
            return Ok(());
        }

        let orig = item.readiness;

        item.readiness.merge(readiness);

        if item.readiness != orig {
            let need_wake = sources.ready.is_empty();

            sources.ready.remove(&mut sources.nodes, key);
            sources.ready.push_back(&mut sources.nodes, key);

            if need_wake {
                self.waker.wake()?;
            }
        }

        Ok(())
    }

    fn has_events(&self) -> bool {
        let sources = &*self.registered_sources.lock().unwrap();

        !sources.ready.is_empty()
    }

    fn next_event(&self) -> Option<(Token, Interest)> {
        let sources = &mut *self.registered_sources.lock().unwrap();

        match sources.ready.pop_front(&mut sources.nodes) {
            Some(key) => {
                let item = &mut sources.nodes[key].value;

                let readiness = item.readiness.take().unwrap();

                Some((item.subtoken, readiness))
            }
            None => None,
        }
    }
}

struct CustomSources {
    local: Rc<LocalSources>,
    sync: Arc<SyncSources>,
    next_local_only: Cell<bool>,
}

impl CustomSources {
    fn new(poll: &Poll, token: Token, max_sources: usize) -> Result<Self, io::Error> {
        let waker = Waker::new(poll.registry(), token)?;

        Ok(Self {
            local: Rc::new(LocalSources::new(max_sources)),
            sync: Arc::new(SyncSources::new(max_sources, waker)),
            next_local_only: Cell::new(false),
        })
    }

    fn set_next_local_only(&self, enabled: bool) {
        self.next_local_only.set(enabled);
    }

    fn register_local(
        &self,
        registration: &LocalRegistration,
        subtoken: Token,
        interests: Interest,
    ) -> Result<(), io::Error> {
        let mut reg = registration.entry.get().data.borrow_mut();

        if reg.data.is_none() {
            let key = self.local.register(subtoken, interests)?;

            reg.data = Some((key, self.local.clone()));

            if let Some(readiness) = reg.readiness {
                self.local.set_readiness(key, readiness).unwrap();

                reg.readiness = None;
            }
        }

        Ok(())
    }

    fn deregister_local(&self, registration: &LocalRegistration) -> Result<(), io::Error> {
        let mut reg = registration.entry.get().data.borrow_mut();

        if let Some((key, _)) = reg.data {
            self.local.deregister(key)?;

            reg.data = None;
        }

        Ok(())
    }

    fn register(
        &self,
        registration: &Registration,
        subtoken: Token,
        interests: Interest,
    ) -> Result<(), io::Error> {
        let mut reg = registration.inner.lock().unwrap();

        if reg.data.is_none() {
            let key = self.sync.register(subtoken, interests)?;

            reg.data = Some((key, self.sync.clone()));

            if let Some(readiness) = reg.readiness {
                self.sync.set_readiness(key, readiness).unwrap();

                reg.readiness = None;
            }
        }

        Ok(())
    }

    fn deregister(&self, registration: &Registration) -> Result<(), io::Error> {
        let mut reg = registration.inner.lock().unwrap();

        if let Some((key, _)) = reg.data {
            self.sync.deregister(key)?;

            reg.data = None;
        }

        Ok(())
    }

    fn has_local_events(&self) -> bool {
        self.local.has_events()
    }

    fn has_events(&self) -> bool {
        if self.local.has_events() {
            return true;
        }

        if self.next_local_only.get() {
            return false;
        }

        self.sync.has_events()
    }

    fn next_event(&self) -> Option<(Token, Interest)> {
        if let Some(e) = self.local.next_event() {
            return Some(e);
        }

        if self.next_local_only.get() {
            return None;
        }

        if let Some(e) = self.sync.next_event() {
            return Some(e);
        }

        None
    }
}

struct RegistrationInner {
    data: Option<(usize, Arc<SyncSources>)>,
    readiness: Readiness,
}

pub struct Registration {
    inner: Arc<Mutex<RegistrationInner>>,
}

impl Registration {
    pub fn new() -> (Self, SetReadiness) {
        let reg = Arc::new(Mutex::new(RegistrationInner {
            data: None,
            readiness: None,
        }));

        let registration = Self { inner: reg.clone() };
        let
set_readiness = SetReadiness { inner: reg }; (registration, set_readiness) } } impl Drop for Registration { fn drop(&mut self) { let mut reg = self.inner.lock().unwrap(); if let Some((key, sources)) = ®.data { sources.deregister(*key).unwrap(); reg.data = None; } } } pub struct SetReadiness { inner: Arc>, } impl SetReadiness { pub fn set_readiness(&self, readiness: Interest) -> Result<(), io::Error> { let mut reg = self.inner.lock().unwrap(); match ®.data { Some((key, sources)) => sources.set_readiness(*key, readiness)?, None => reg.readiness.merge(readiness), } Ok(()) } } struct LocalRegistrationData { data: Option<(usize, Rc)>, readiness: Readiness, } pub struct LocalRegistrationEntry { data: RefCell, } pub struct LocalRegistration { entry: arena::Rc, } impl LocalRegistration { pub fn new(memory: &Rc>) -> (Self, LocalSetReadiness) { let reg = arena::Rc::new( LocalRegistrationEntry { data: RefCell::new(LocalRegistrationData { data: None, readiness: None, }), }, memory, ) .unwrap(); let registration = Self { entry: arena::Rc::clone(®), }; let set_readiness = LocalSetReadiness { entry: reg }; (registration, set_readiness) } } impl Drop for LocalRegistration { fn drop(&mut self) { let mut reg = self.entry.get().data.borrow_mut(); if let Some((key, sources)) = ®.data { sources.deregister(*key).unwrap(); reg.data = None; } } } pub struct LocalSetReadiness { entry: arena::Rc, } impl LocalSetReadiness { pub fn set_readiness(&self, readiness: Interest) -> Result<(), io::Error> { let mut reg = self.entry.get().data.borrow_mut(); match ®.data { Some((key, sources)) => sources.set_readiness(*key, readiness)?, None => reg.readiness.merge(readiness), } Ok(()) } } #[derive(Debug, PartialEq)] pub struct Event { token: Token, readiness: Interest, } impl Event { pub fn token(&self) -> Token { self.token } pub fn readiness(&self) -> Interest { self.readiness } pub fn is_readable(&self) -> bool { self.readiness.is_readable() } pub fn is_writable(&self) -> bool { 
        self.readiness.is_writable()
    }
}

pub struct Poller {
    poll: Poll,
    events: Events,
    custom_sources: CustomSources,
    local_registration_memory: Rc<arena::RcMemory<LocalRegistrationEntry>>,
    local_budget: u32,
}

impl Poller {
    pub fn new(max_custom_sources: usize) -> Result<Self, io::Error> {
        let poll = Poll::new()?;
        let events = Events::with_capacity(EVENTS_MAX);
        let custom_sources = CustomSources::new(&poll, Token(0), max_custom_sources)?;

        Ok(Self {
            poll,
            events,
            custom_sources,
            local_registration_memory: Rc::new(arena::RcMemory::new(max_custom_sources)),
            local_budget: LOCAL_BUDGET,
        })
    }

    pub fn register<S>(
        &self,
        source: &mut S,
        token: Token,
        interests: Interest,
    ) -> Result<(), io::Error>
    where
        S: Source + ?Sized,
    {
        if token == Token(0) {
            return Err(io::Error::from(io::ErrorKind::InvalidInput));
        }

        self.poll.registry().register(source, token, interests)
    }

    pub fn deregister<S>(&self, source: &mut S) -> Result<(), io::Error>
    where
        S: Source + ?Sized,
    {
        self.poll.registry().deregister(source)
    }

    pub fn register_custom(
        &self,
        registration: &Registration,
        token: Token,
        interests: Interest,
    ) -> Result<(), io::Error> {
        if token == Token(0) {
            return Err(io::Error::from(io::ErrorKind::InvalidInput));
        }

        self.custom_sources.register(registration, token, interests)
    }

    pub fn deregister_custom(&self, registration: &Registration) -> Result<(), io::Error> {
        self.custom_sources.deregister(registration)
    }

    pub fn local_registration_memory(&self) -> &Rc<arena::RcMemory<LocalRegistrationEntry>> {
        &self.local_registration_memory
    }

    pub fn register_custom_local(
        &self,
        registration: &LocalRegistration,
        token: Token,
        interests: Interest,
    ) -> Result<(), io::Error> {
        if token == Token(0) {
            return Err(io::Error::from(io::ErrorKind::InvalidInput));
        }

        self.custom_sources
            .register_local(registration, token, interests)
    }

    pub fn deregister_custom_local(
        &self,
        registration: &LocalRegistration,
    ) -> Result<(), io::Error> {
        self.custom_sources.deregister_local(registration)
    }

    pub fn poll(&mut self, timeout: Option<Duration>) -> Result<(), io::Error> {
        if self.custom_sources.has_local_events() &&
            self.local_budget > 0
        {
            self.local_budget -= 1;

            self.custom_sources.set_next_local_only(true);

            self.events.clear(); // don't reread previous mio events

            return Ok(());
        }

        self.local_budget = LOCAL_BUDGET;

        self.custom_sources.set_next_local_only(false);

        let timeout = if self.custom_sources.has_events() {
            Some(Duration::from_millis(0))
        } else {
            timeout
        };

        loop {
            match self.poll.poll(&mut self.events, timeout) {
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                ret => break ret,
            }
        }
    }

    pub fn iter_events(&self) -> EventsIterator<'_, '_> {
        EventsIterator {
            events: self.events.iter(),
            custom_sources: &self.custom_sources,
            custom_left: EVENTS_MAX,
        }
    }
}

pub struct EventsIterator<'a, 'b> {
    events: mio::event::Iter<'b>,
    custom_sources: &'a CustomSources,
    custom_left: usize,
}

impl Iterator for EventsIterator<'_, '_> {
    type Item = Event;

    fn next(&mut self) -> Option<Self::Item> {
        for event in self.events.by_ref() {
            if event.token() == Token(0) {
                continue;
            }

            let mut readiness = None;

            if event.is_readable() {
                readiness.merge(Interest::READABLE);
            }

            if event.is_writable() {
                readiness.merge(Interest::WRITABLE);
            }

            if let Some(readiness) = readiness {
                return Some(Event {
                    token: event.token(),
                    readiness,
                });
            }
        }

        if self.custom_left > 0 {
            self.custom_left -= 1;

            if let Some((token, readiness)) = self.custom_sources.next_event() {
                return Some(Event { token, readiness });
            }
        }

        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_readiness() {
        let token = Token(123);
        let subtoken = Token(456);

        let mut poll = Poll::new().unwrap();
        let sources = CustomSources::new(&poll, token, 1).unwrap();

        assert_eq!(sources.has_events(), false);
        assert_eq!(sources.next_event(), None);

        let (reg, sr) = Registration::new();

        sources
            .register(&reg, subtoken, Interest::READABLE)
            .unwrap();

        let mut events = Events::with_capacity(1024);
        poll.poll(&mut events, Some(Duration::from_millis(0)))
            .unwrap();
        assert!(events.is_empty());

        sr.set_readiness(Interest::READABLE).unwrap();

        'poll: loop {
            poll.poll(&mut events, None).unwrap();

            for event in &events {
                if event.token() == token {
                    break 'poll;
                }
            }
        }

        assert_eq!(sources.has_events(), true);
        assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE)));
        assert_eq!(sources.has_events(), false);
        assert_eq!(sources.next_event(), None);
    }

    #[test]
    fn test_readiness_early() {
        let token = Token(123);
        let subtoken = Token(456);

        let mut poll = Poll::new().unwrap();
        let sources = CustomSources::new(&poll, token, 1).unwrap();

        assert_eq!(sources.has_events(), false);
        assert_eq!(sources.next_event(), None);

        let (reg, sr) = Registration::new();

        sr.set_readiness(Interest::READABLE).unwrap();

        sources
            .register(&reg, subtoken, Interest::READABLE)
            .unwrap();

        let mut events = Events::with_capacity(1024);
        poll.poll(&mut events, Some(Duration::from_millis(0)))
            .unwrap();

        let event = events.iter().next();
        assert!(event.is_some());
        let event = event.unwrap();
        assert_eq!(event.token(), token);

        assert_eq!(sources.has_events(), true);
        assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE)));
        assert_eq!(sources.has_events(), false);
        assert_eq!(sources.next_event(), None);
    }

    #[test]
    fn test_readiness_local() {
        let poller = Poller::new(1).unwrap();

        let token = Token(123);
        let subtoken = Token(456);

        let mut poll = Poll::new().unwrap();
        let sources = CustomSources::new(&poll, token, 1).unwrap();

        assert_eq!(sources.has_events(), false);
        assert_eq!(sources.next_event(), None);

        let (reg, sr) = LocalRegistration::new(poller.local_registration_memory());

        sources
            .register_local(&reg, subtoken, Interest::READABLE)
            .unwrap();

        let mut events = Events::with_capacity(1024);
        poll.poll(&mut events, Some(Duration::from_millis(0)))
            .unwrap();
        assert!(events.is_empty());

        sr.set_readiness(Interest::READABLE).unwrap();

        assert_eq!(sources.has_events(), true);
        assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE)));
        assert_eq!(sources.has_events(), false);
        assert_eq!(sources.next_event(), None);
    }

    #[test]
    fn test_poller() {
        let token = Token(123);

        let mut poller = Poller::new(1).unwrap();

        assert_eq!(poller.iter_events().next(), None);

        let (reg, sr) = Registration::new();

        poller
            .register_custom(&reg, token, Interest::READABLE)
            .unwrap();

        poller.poll(Some(Duration::from_millis(0))).unwrap();

        assert_eq!(poller.iter_events().next(), None);

        sr.set_readiness(Interest::READABLE).unwrap();

        poller.poll(None).unwrap();

        let mut it = poller.iter_events();

        let event = it.next().unwrap();
        assert_eq!(event.token(), token);
        assert_eq!(event.is_readable(), true);
        assert_eq!(it.next(), None);
    }
}
condure-1.10.0/src/executor.rs
/*
 * Copyright (C) 2020-2023 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::list;
use crate::waker;
use log::debug;
use slab::Slab;
use std::cell::RefCell;
use std::future::Future;
use std::io;
use std::mem;
use std::pin::Pin;
use std::rc::{Rc, Weak};
use std::task::{Context, Waker};
use std::time::Duration;

thread_local! {
    static EXECUTOR: RefCell<Option<Weak<Tasks>>> = RefCell::new(None);
}

type BoxFuture = Pin<Box<dyn Future<Output = ()>>>;

struct TaskWaker {
    tasks: Weak<Tasks>,
    task_id: usize,
}

impl waker::RcWake for TaskWaker {
    fn wake(self: Rc<Self>) {
        if let Some(tasks) = self.tasks.upgrade() {
            tasks.wake(self.task_id);
        }
    }
}

fn poll_fut(fut: &mut BoxFuture, waker: Waker) -> bool {
    // convert from Pin<Box> to Pin<&mut>
    let fut: Pin<&mut dyn Future<Output = ()>> = fut.as_mut();

    let mut cx = Context::from_waker(&waker);

    fut.poll(&mut cx).is_ready()
}

struct Task {
    fut: Option<Pin<Box<dyn Future<Output = ()>>>>,
    wakeable: bool,
}

struct TasksData {
    nodes: Slab<list::Node<Task>>,
    next: list::List,
    wakers: Vec<Rc<TaskWaker>>,
}

struct Tasks {
    data: RefCell<TasksData>,
    pre_poll: RefCell<Option<Box<dyn FnMut()>>>,
}

impl Tasks {
    fn new(max: usize) -> Rc<Self> {
        let data = TasksData {
            nodes: Slab::with_capacity(max),
            next: list::List::default(),
            wakers: Vec::with_capacity(max),
        };

        let tasks = Rc::new(Self {
            data: RefCell::new(data),
            pre_poll: RefCell::new(None),
        });

        {
            let data = &mut *tasks.data.borrow_mut();

            for task_id in 0..data.nodes.capacity() {
                data.wakers.push(Rc::new(TaskWaker {
                    tasks: Rc::downgrade(&tasks),
                    task_id,
                }));
            }
        }

        tasks
    }

    fn is_empty(&self) -> bool {
        self.data.borrow().nodes.is_empty()
    }

    fn have_next(&self) -> bool {
        !self.data.borrow().next.is_empty()
    }

    fn add<F>(&self, fut: F) -> Result<(), ()>
    where
        F: Future<Output = ()> + 'static,
    {
        let data = &mut *self.data.borrow_mut();

        if data.nodes.len() == data.nodes.capacity() {
            return Err(());
        }

        let entry = data.nodes.vacant_entry();
        let nkey = entry.key();

        let task = Task {
            fut: Some(Box::pin(fut)),
            wakeable: false,
        };

        entry.insert(list::Node::new(task));

        data.next.push_back(&mut data.nodes, nkey);

        Ok(())
    }

    fn remove(&self, task_id: usize) {
        let nkey = task_id;

        let data = &mut *self.data.borrow_mut();

        let task = &mut data.nodes[nkey].value;

        // drop the future.
        // this should cause it to drop any owned wakers
        task.fut = None;

        // at this point, we should be the only remaining owner
        assert_eq!(Rc::strong_count(&data.wakers[nkey]), 1);

        data.next.remove(&mut data.nodes, nkey);
        data.nodes.remove(nkey);
    }

    fn take_next_list(&self) -> list::List {
        let data = &mut *self.data.borrow_mut();

        let mut l = list::List::default();
        l.concat(&mut data.nodes, &mut data.next);

        l
    }

    fn append_to_next_list(&self, mut l: list::List) {
        let data = &mut *self.data.borrow_mut();

        data.next.concat(&mut data.nodes, &mut l);
    }

    fn take_task(&self, l: &mut list::List) -> Option<(usize, BoxFuture, Waker)> {
        let nkey = match l.head {
            Some(nkey) => nkey,
            None => return None,
        };

        let data = &mut *self.data.borrow_mut();

        l.remove(&mut data.nodes, nkey);

        let task = &mut data.nodes[nkey].value;

        // both of these are cheap
        let fut = task.fut.take().unwrap();
        let waker = waker::into_std(data.wakers[nkey].clone());

        task.wakeable = true;

        Some((nkey, fut, waker))
    }

    fn process_next(&self) {
        let mut l = self.take_next_list();

        while let Some((task_id, mut fut, waker)) = self.take_task(&mut l) {
            self.pre_poll();

            let done = poll_fut(&mut fut, waker);

            // take_task() took the future out of the task, so we
            // could poll it without having to maintain a borrow of
            // the tasks set.
            // we'll put it back now
            self.set_fut(task_id, fut);

            if done {
                self.remove(task_id);
            }
        }
    }

    fn set_fut(&self, task_id: usize, fut: BoxFuture) {
        let nkey = task_id;

        let data = &mut *self.data.borrow_mut();

        let task = &mut data.nodes[nkey].value;

        task.fut = Some(fut);
    }

    fn wake(&self, task_id: usize) {
        let nkey = task_id;

        let data = &mut *self.data.borrow_mut();

        let node = &mut data.nodes[nkey];
        if !node.value.wakeable {
            return;
        }

        node.value.wakeable = false;

        data.next.push_back(&mut data.nodes, nkey);
    }

    fn set_pre_poll<F>(&self, pre_poll_fn: F)
    where
        F: FnMut() + 'static,
    {
        *self.pre_poll.borrow_mut() = Some(Box::new(pre_poll_fn));
    }

    fn pre_poll(&self) {
        let pre_poll = &mut *self.pre_poll.borrow_mut();

        if let Some(f) = pre_poll {
            f();
        }
    }
}

pub struct Executor {
    tasks: Rc<Tasks>,
}

impl Executor {
    pub fn new(tasks_max: usize) -> Self {
        let tasks = Tasks::new(tasks_max);

        EXECUTOR.with(|ex| {
            if ex.borrow().is_some() {
                panic!("thread already has an Executor");
            }

            ex.replace(Some(Rc::downgrade(&tasks)));
        });

        Self { tasks }
    }

    #[allow(clippy::result_unit_err)]
    pub fn spawn<F>(&self, fut: F) -> Result<(), ()>
    where
        F: Future<Output = ()> + 'static,
    {
        debug!("spawning future with size {}", mem::size_of::<F>());

        self.tasks.add(fut)
    }

    pub fn set_pre_poll<F>(&self, pre_poll_fn: F)
    where
        F: FnMut() + 'static,
    {
        self.tasks.set_pre_poll(pre_poll_fn);
    }

    pub fn have_tasks(&self) -> bool {
        !self.tasks.is_empty()
    }

    pub fn run_until_stalled(&self) {
        while self.tasks.have_next() {
            self.tasks.process_next()
        }
    }

    pub fn run<F>(&self, mut park: F) -> Result<(), io::Error>
    where
        F: FnMut(Option<Duration>) -> Result<(), io::Error>,
    {
        loop {
            self.tasks.process_next();

            if !self.have_tasks() {
                break;
            }

            let (timeout, low_priority_tasks) = if self.tasks.have_next() {
                // some tasks trigger their own waker and return Pending in
                // order to achieve a yielding effect. in that case they will
                // already be queued up for processing again.
                // move these
                // tasks aside so that they can be deprioritized, and use a
                // timeout of 0 when parking so we can quickly resume them
                let timeout = Duration::from_millis(0);
                let l = self.tasks.take_next_list();

                (Some(timeout), Some(l))
            } else {
                (None, None)
            };

            park(timeout)?;

            // requeue any tasks that had yielded
            if let Some(l) = low_priority_tasks {
                self.tasks.append_to_next_list(l);
            }
        }

        Ok(())
    }

    pub fn current() -> Option<Self> {
        EXECUTOR.with(|ex| {
            (*ex.borrow_mut()).as_mut().map(|tasks| Self {
                tasks: tasks.upgrade().unwrap(),
            })
        })
    }

    pub fn spawner(&self) -> Spawner {
        Spawner {
            tasks: Rc::downgrade(&self.tasks),
        }
    }
}

impl Drop for Executor {
    fn drop(&mut self) {
        EXECUTOR.with(|ex| {
            if Rc::strong_count(&self.tasks) == 1 {
                ex.replace(None);
            }
        });
    }
}

pub struct Spawner {
    tasks: Weak<Tasks>,
}

impl Spawner {
    #[allow(clippy::result_unit_err)]
    pub fn spawn<F>(&self, fut: F) -> Result<(), ()>
    where
        F: Future<Output = ()> + 'static,
    {
        let tasks = match self.tasks.upgrade() {
            Some(tasks) => tasks,
            None => return Err(()),
        };

        let ex = Executor { tasks };

        ex.spawn(fut)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::mem;
    use std::task::Poll;

    struct TestFutureData {
        ready: bool,
        waker: Option<Waker>,
    }

    struct TestFuture {
        data: Rc<RefCell<TestFutureData>>,
    }

    impl TestFuture {
        fn new() -> Self {
            let data = TestFutureData {
                ready: false,
                waker: None,
            };

            Self {
                data: Rc::new(RefCell::new(data)),
            }
        }

        fn handle(&self) -> TestHandle {
            TestHandle {
                data: Rc::clone(&self.data),
            }
        }
    }

    impl Future for TestFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let mut data = self.data.borrow_mut();

            match data.ready {
                true => Poll::Ready(()),
                false => {
                    data.waker = Some(cx.waker().clone());

                    Poll::Pending
                }
            }
        }
    }

    struct TestHandle {
        data: Rc<RefCell<TestFutureData>>,
    }

    impl TestHandle {
        fn set_ready(&self) {
            let data = &mut *self.data.borrow_mut();

            data.ready = true;

            if let Some(waker) = data.waker.take() {
                waker.wake();
            }
        }
    }

    struct EarlyWakeFuture {
        done: bool,
    }

    impl EarlyWakeFuture {
        fn new() -> Self {
            Self {
                done: false
            }
        }
    }

    impl Future for EarlyWakeFuture {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            if !self.done {
                self.done = true;
                cx.waker().wake_by_ref();

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }

    #[test]
    fn test_executor_step() {
        let executor = Executor::new(1);

        let fut1 = TestFuture::new();
        let fut2 = TestFuture::new();

        let handle1 = fut1.handle();
        let handle2 = fut2.handle();

        let started = Rc::new(Cell::new(false));
        let fut1_done = Rc::new(Cell::new(false));
        let finishing = Rc::new(Cell::new(false));

        {
            let started = Rc::clone(&started);
            let fut1_done = Rc::clone(&fut1_done);
            let finishing = Rc::clone(&finishing);

            executor
                .spawn(async move {
                    started.set(true);

                    fut1.await;

                    fut1_done.set(true);

                    fut2.await;

                    finishing.set(true);
                })
                .unwrap();
        }

        // not started yet, no progress
        assert_eq!(executor.have_tasks(), true);
        assert_eq!(started.get(), false);

        executor.run_until_stalled();

        // started, but fut1 not ready
        assert_eq!(executor.have_tasks(), true);
        assert_eq!(started.get(), true);
        assert_eq!(fut1_done.get(), false);

        handle1.set_ready();
        executor.run_until_stalled();

        // fut1 finished
        assert_eq!(executor.have_tasks(), true);
        assert_eq!(fut1_done.get(), true);
        assert_eq!(finishing.get(), false);

        handle2.set_ready();
        executor.run_until_stalled();

        // fut2 finished, and thus the task finished
        assert_eq!(finishing.get(), true);
        assert_eq!(executor.have_tasks(), false);
    }

    #[test]
    fn test_executor_run() {
        let executor = Executor::new(1);

        let fut = TestFuture::new();
        let handle = fut.handle();

        executor
            .spawn(async move {
                fut.await;
            })
            .unwrap();

        executor
            .run(|_| {
                handle.set_ready();

                Ok(())
            })
            .unwrap();

        assert_eq!(executor.have_tasks(), false);
    }

    #[test]
    fn test_executor_spawn_error() {
        let executor = Executor::new(1);

        assert!(executor.spawn(async {}).is_ok());
        assert!(executor.spawn(async {}).is_err());
    }

    #[test]
    fn test_executor_current() {
        assert!(Executor::current().is_none());

        let executor = Executor::new(2);

        let flag =
            Rc::new(Cell::new(false));

        {
            let flag = flag.clone();

            executor
                .spawn(async move {
                    Executor::current()
                        .unwrap()
                        .spawn(async move {
                            flag.set(true);
                        })
                        .unwrap();
                })
                .unwrap();
        }

        assert_eq!(flag.get(), false);

        executor.run(|_| Ok(())).unwrap();

        assert_eq!(flag.get(), true);

        let current = Executor::current().unwrap();

        assert_eq!(executor.have_tasks(), false);
        assert!(current.spawn(async {}).is_ok());
        assert_eq!(executor.have_tasks(), true);

        mem::drop(executor);

        assert!(Executor::current().is_some());

        mem::drop(current);

        assert!(Executor::current().is_none());
    }

    #[test]
    fn test_executor_spawner() {
        let executor = Executor::new(2);

        let flag = Rc::new(Cell::new(false));

        {
            let flag = flag.clone();
            let spawner = executor.spawner();

            executor
                .spawn(async move {
                    spawner
                        .spawn(async move {
                            flag.set(true);
                        })
                        .unwrap();
                })
                .unwrap();
        }

        assert_eq!(flag.get(), false);

        executor.run(|_| Ok(())).unwrap();

        assert_eq!(flag.get(), true);
    }

    #[test]
    fn test_executor_early_wake() {
        let executor = Executor::new(1);

        let fut = EarlyWakeFuture::new();

        executor
            .spawn(async move {
                fut.await;
            })
            .unwrap();

        let mut park_count = 0;

        executor
            .run(|_| {
                park_count += 1;

                Ok(())
            })
            .unwrap();

        assert_eq!(park_count, 1);
    }

    #[test]
    fn test_executor_pre_poll() {
        let executor = Executor::new(1);

        let flag = Rc::new(Cell::new(false));

        {
            let flag = flag.clone();

            executor.set_pre_poll(move || {
                flag.set(true);
            });
        }

        executor.spawn(async {}).unwrap();

        assert_eq!(flag.get(), false);

        executor.run(|_| Ok(())).unwrap();

        assert_eq!(flag.get(), true);
    }
}
condure-1.10.0/src/future.rs
/*
 * Copyright (C) 2020-2023 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::channel;
use crate::event::{self, ReadinessExt};
use crate::net::{NetListener, NetStream, SocketAddr};
use crate::reactor::{CustomEvented, FdEvented, IoEvented, Reactor, Registration, TimerEvented};
use crate::resolver;
use crate::shuffle::shuffle;
use crate::tls::{TlsStream, TlsStreamError, VerifyMode};
use crate::waker::{RefWake, RefWaker, RefWakerData};
use crate::zmq::{MultipartHeader, ZmqSocket};
use mio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
use openssl::ssl;
use paste::paste;
use std::cell::{Cell, Ref, RefCell};
use std::future::Future;
use std::io::{self, Read, Write};
use std::mem;
use std::os::fd::{FromRawFd, IntoRawFd};
use std::path::Path;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::mpsc;
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};

pub const REGISTRATIONS_PER_CHANNEL: usize = 1;

// 1 for the zmq fd, and potentially 1 for the retry timer
pub const REGISTRATIONS_PER_ZMQSOCKET: usize = 2;

pub struct PollFuture<F> {
    fut: F,
}

impl<F> Future for PollFuture<F>
where
    F: Future + Unpin,
{
    type Output = Poll<F::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let s = &mut *self;

        Poll::Ready(Pin::new(&mut s.fut).poll(cx))
    }
}

pub fn poll_async<F>(fut: F) -> PollFuture<F>
where
    F: Future + Unpin,
{
    PollFuture { fut }
}

fn range_unordered(dest: &mut [usize]) -> &[usize] {
    for (index, v) in dest.iter_mut().enumerate() {
        *v = index;
    }

    shuffle(dest);

    dest
}

fn map_poll<F, W, V>(cx: &mut Context, fut: &mut F, wrap_func: W) -> Poll<V>
where
    F: Future + Unpin,
    W: FnOnce(F::Output) -> V,
{
    match
Pin::new(fut).poll(cx) { Poll::Ready(v) => Poll::Ready(wrap_func(v)), Poll::Pending => Poll::Pending, } } macro_rules! declare_select { ($count: literal, ( $($num:literal),* )) => { paste! { pub enum []<$([], )*> { $( []: [], )* } impl<$([], )*> Future for []<$([]::Output, )*>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut indexes = [0; $count]; for i in range_unordered(&mut indexes) { let s = &mut *self; let p = match i + 1 { $( $num => map_poll(cx, &mut s.[], |v| []<$([], )*> where $( []: Future + Unpin, )* { [