condure-1.3.1/.cargo_vcs_info.json

{
  "git": {
    "sha1": "ad201732fa71ee160aefcfc49f9a058ca89f515b"
  }
}

condure-1.3.1/.gitignore

/target
**/*.rs.bk

condure-1.3.1/CHANGELOG.md

Condure Changelog
=================

v. 1.3.1 (2021-08-11)

* Fixes for high load.

v. 1.3.0 (2021-07-29)

* Port to async/await.

v. 1.2.0 (2021-05-04)

* Send PING/PONG frame data to clients.
* Port to mio 0.7.

v. 1.1.0 (2020-11-02)

* TLS support.
* Don't preallocate connection buffers.
* Start using async/await in some places.

v. 1.0.1 (2020-07-24)

* Remove some unsafe usage.

v. 1.0.0 (2020-07-21)

* Stable version.

condure-1.3.1/COPYING

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

condure-1.3.1/Cargo.lock

# This file is automatically @generated by Cargo. # It is not intended for manual editing. version = 3 [[package]] name = "ansi_term" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" dependencies = [ "winapi", ] [[package]] name = "arrayvec" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" [[package]] name = "atty" version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ "hermit-abi", "libc", "winapi", ] [[package]] name = "autocfg" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" [[package]] name = "base64" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" [[package]] name = "bitflags" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" [[package]] name = "bstr" version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" dependencies = [ "lazy_static", "memchr", "regex-automata", "serde", ] [[package]] name = "bumpalo" version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631" [[package]] name = "cast" version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c24dab4283a142afa2fdca129b80ad2c6284e073930f964c3a1293c225ee39a" dependencies = [ "rustc_version", ] [[package]] name = "cc" version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70cc2f62c6ce1868963827bd677764c62d07c3d9a3e1fb1177ee1a9ab199eb2" [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" version = "2.33.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002" dependencies = [ "ansi_term", "atty", "bitflags", "strsim", "term_size", "textwrap", "unicode-width", "vec_map", ] [[package]] name = "condure" version = "1.3.1" dependencies = [ "arrayvec", "base64", "clap", "criterion", "httparse", "iovec", "log", "mio", "openssl", "sha1", "signal-hook", "slab", "time", "zmq", ] [[package]] name = "criterion" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab327ed7354547cc2ef43cbe20ef68b988e70b4b593cbd66a2a61733123a3d23" dependencies = [ "atty", "cast", "clap", "criterion-plot", "csv", "itertools 0.10.1", "lazy_static", "num-traits", "oorandom", "plotters", "rayon", "regex", "serde", "serde_cbor", "serde_derive", "serde_json", "tinytemplate", "walkdir", ] [[package]] name = "criterion-plot" version = "0.4.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "e022feadec601fba1649cfa83586381a4ad31c6bf3a9ab7d408118b05dd9889d" dependencies = [ "cast", "itertools 0.9.0", ] [[package]] name = "crossbeam-channel" version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" dependencies = [ "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-deque" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" dependencies = [ "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" dependencies = [ "cfg-if", "crossbeam-utils", "lazy_static", "memoffset", "scopeguard", ] [[package]] name = "crossbeam-utils" version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" dependencies = [ "cfg-if", "lazy_static", ] [[package]] name = "csv" version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" dependencies = [ "bstr", "csv-core", "itoa", "ryu", "serde", ] [[package]] name = "csv-core" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" dependencies = [ "memchr", ] [[package]] name = "either" version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" [[package]] name = "error-chain" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "d9435d864e017c3c6afeac1654189b06cdb491cf2ff73dbf0d73b0f292f42ff8" [[package]] name = "foreign-types" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ "foreign-types-shared", ] [[package]] name = "foreign-types-shared" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "half" version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62aca2aba2d62b4a7f5b33f3712cb1b0692779a56fb510499d5c0aa594daeaf3" [[package]] name = "hermit-abi" version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" dependencies = [ "libc", ] [[package]] name = "httparse" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3a87b616e37e93c22fb19bcd386f02f3af5ea98a25670ad0fce773de23c5e68" [[package]] name = "iovec" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e" dependencies = [ "libc", ] [[package]] name = "itertools" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" dependencies = [ "either", ] [[package]] name = "itertools" version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf" dependencies = [ "either", ] [[package]] name = "itoa" version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" [[package]] name = "js-sys" version 
= "0.3.51" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83bdfbace3a0e81a4253f73b49e960b053e396a11012cbd49b9b74d6a2b67062" dependencies = [ "wasm-bindgen", ] [[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" version = "0.2.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" [[package]] name = "log" version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if", ] [[package]] name = "memchr" version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" [[package]] name = "memoffset" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9" dependencies = [ "autocfg", ] [[package]] name = "metadeps" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73b122901b3a675fac8cecf68dcb2f0d3036193bc861d1ac0e1c337f7d5254c2" dependencies = [ "error-chain", "pkg-config", "toml", ] [[package]] name = "mio" version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" dependencies = [ "libc", "log", "miow", "ntapi", "winapi", ] [[package]] name = "miow" version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" dependencies = [ "winapi", ] [[package]] name = "ntapi" version = "0.3.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" dependencies = [ "winapi", ] [[package]] name = "num-traits" version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" dependencies = [ "autocfg", ] [[package]] name = "num_cpus" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" dependencies = [ "hermit-abi", "libc", ] [[package]] name = "once_cell" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" [[package]] name = "oorandom" version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" [[package]] name = "openssl" version = "0.10.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "549430950c79ae24e6d02e0b7404534ecf311d94cc9f861e9e4020187d13d885" dependencies = [ "bitflags", "cfg-if", "foreign-types", "libc", "once_cell", "openssl-sys", ] [[package]] name = "openssl-sys" version = "0.9.65" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a7907e3bfa08bb85105209cdfcb6c63d109f8f6c1ed6ca318fff5c1853fbc1d" dependencies = [ "autocfg", "cc", "libc", "pkg-config", "vcpkg", ] [[package]] name = "pkg-config" version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" [[package]] name = "plotters" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a3fd9ec30b9749ce28cd91f255d569591cdf937fe280c312143e3c4bad6f2a" dependencies = [ "num-traits", "plotters-backend", 
"plotters-svg", "wasm-bindgen", "web-sys", ] [[package]] name = "plotters-backend" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d88417318da0eaf0fdcdb51a0ee6c3bed624333bff8f946733049380be67ac1c" [[package]] name = "plotters-svg" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "521fa9638fa597e1dc53e9412a4f9cefb01187ee1f7413076f9e6749e2885ba9" dependencies = [ "plotters-backend", ] [[package]] name = "proc-macro2" version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" dependencies = [ "unicode-xid", ] [[package]] name = "quote" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" dependencies = [ "proc-macro2", ] [[package]] name = "rayon" version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90" dependencies = [ "autocfg", "crossbeam-deque", "either", "rayon-core", ] [[package]] name = "rayon-core" version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" dependencies = [ "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", "lazy_static", "num_cpus", ] [[package]] name = "regex" version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" dependencies = [ "regex-syntax", ] [[package]] name = "regex-automata" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" [[package]] name = "regex-syntax" version = "0.6.25" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" [[package]] name = "rustc_version" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ "semver", ] [[package]] name = "ryu" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" [[package]] name = "same-file" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" dependencies = [ "winapi-util", ] [[package]] name = "scopeguard" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "semver" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f3aac57ee7f3272d8395c6e4f502f434f0e289fcd62876f70daa008c20dcabe" [[package]] name = "serde" version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03" [[package]] name = "serde_cbor" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e18acfa2f90e8b735b2836ab8d538de304cbb6729a7360729ea5a895d15a622" dependencies = [ "half", "serde", ] [[package]] name = "serde_derive" version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "963a7dbc9895aeac7ac90e74f34a5d5261828f79df35cbed41e10189d3804d43" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "serde_json" version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" dependencies = [ "itoa", "ryu", "serde", ] [[package]] name = "sha1" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2579985fda508104f7587689507983eadd6a6e84dd35d6d115361f530916fa0d" [[package]] name = "signal-hook" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "470c5a6397076fae0094aaf06a08e6ba6f37acb77d3b1b91ea92b4d6c8650c39" dependencies = [ "libc", "signal-hook-registry", ] [[package]] name = "signal-hook-registry" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" dependencies = [ "libc", ] [[package]] name = "slab" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f173ac3d1a7e3b28003f40de0b5ce7fe2710f9b9dc3fc38664cebee46b3b6527" [[package]] name = "strsim" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" [[package]] name = "syn" version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7" dependencies = [ "proc-macro2", "quote", "unicode-xid", ] [[package]] name = "term_size" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e4129646ca0ed8f45d09b929036bafad5377103edd06e50bf574b353d2b08d9" dependencies = [ "libc", "winapi", ] [[package]] name = "textwrap" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" dependencies = [ "term_size", "unicode-width", ] [[package]] name = "time" version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6db9e6914ab8b1ae1c260a4ae7a49b6c5611b40328a735b21862567685e73255" dependencies = [ "libc", "wasi", "winapi", ] [[package]] name = "tinytemplate" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ "serde", "serde_json", ] [[package]] name = "toml" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "736b60249cb25337bc196faa43ee12c705e426f3d55c214d73a4e7be06f92cb4" [[package]] name = "unicode-width" version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3" [[package]] name = "unicode-xid" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" [[package]] name = "vcpkg" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vec_map" version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" [[package]] name = "walkdir" version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" dependencies = [ "same-file", "winapi", "winapi-util", ] [[package]] name = "wasi" version = "0.10.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" [[package]] name = "wasm-bindgen" version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d54ee1d4ed486f78874278e63e4069fc1ab9f6a18ca492076ffb90c5eb2997fd" dependencies = [ "cfg-if", 
"wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b33f6a0694ccfea53d94db8b2ed1c3a8a4c86dd936b13b9f0a15ec4a451b900" dependencies = [ "bumpalo", "lazy_static", "log", "proc-macro2", "quote", "syn", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "088169ca61430fe1e58b8096c24975251700e7b1f6fd91cc9d59b04fb9b18bd4" dependencies = [ "quote", "wasm-bindgen-macro-support", ] [[package]] name = "wasm-bindgen-macro-support" version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be2241542ff3d9f241f5e2cb6dd09b37efe786df8851c54957683a49f0987a97" dependencies = [ "proc-macro2", "quote", "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7cff876b8f18eed75a66cf49b65e7f967cb354a7aa16003fb55dbfd25b44b4f" [[package]] name = "web-sys" version = "0.3.51" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e828417b379f3df7111d3a2a9e5753706cae29c41f7c4029ee9fd77f3e09e582" dependencies = [ "js-sys", "wasm-bindgen", ] [[package]] name = "winapi" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ "winapi-i686-pc-windows-gnu", "winapi-x86_64-pc-windows-gnu", ] [[package]] name = "winapi-i686-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" dependencies = [ "winapi", ] [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "zmq" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aad98a7a617d608cd9e1127147f630d24af07c7cd95ba1533246d96cbdd76c66" dependencies = [ "bitflags", "libc", "log", "zmq-sys", ] [[package]] name = "zmq-sys" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d33a2c51dde24d5b451a2ed4b488266df221a5eaee2ee519933dc46b9a9b3648" dependencies = [ "libc", "metadeps", ] condure-1.3.1/Cargo.toml0000644000000042210000000000000104710ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies # # If you believe there's an error in this file please file an # issue against the rust-lang/cargo repository. If you're # editing this file be aware that the upstream Cargo.toml # will likely look very different (and much more reasonable) [package] edition = "2018" name = "condure" version = "1.3.1" authors = ["Justin Karneges "] autobins = false description = "HTTP/WebSocket connection manager" readme = "README.md" license = "Apache-2.0" repository = "https://github.com/fanout/condure" [package.metadata.deb] build-depends = "pkg-config, libzmq3-dev" extended-description = "Condure is a service that manages network connections on behalf of server\napplications, in order to allow controlling the connections from multiple\nprocesses. Applications communicate with Condure over ZeroMQ.\n\nCondure can only manage connections for protocols it knows about. 
Currently\nthis is HTTP/1 and WebSockets.\n\nThe project was inspired by Mongrel2.\n" separate-debug-symbols = true [package.metadata.rpm] package = "condure" [package.metadata.rpm.cargo] buildflags = ["--release"] [package.metadata.rpm.targets.condure] path = "/usr/bin/condure" [profile.dev] panic = "abort" [profile.release] panic = "abort" [[bin]] name = "condure" [[bench]] name = "server" harness = false [dependencies.arrayvec] version = "0.5" [dependencies.base64] version = "0.12" [dependencies.clap] version = "2.33" features = ["wrap_help"] [dependencies.httparse] version = "1.3" [dependencies.iovec] version = "0.1" [dependencies.log] version = "0.4" [dependencies.mio] version = "0.7" features = ["os-poll", "os-ext", "net"] [dependencies.openssl] version = "0.10" [dependencies.sha1] version = "0.6" [dependencies.signal-hook] version = "0.3" [dependencies.slab] version = "0.4" [dependencies.time] version = "0.1" [dependencies.zmq] version = "0.9" [dev-dependencies.criterion] version = "0.3" condure-1.3.1/Cargo.toml.orig000064400000000000000000000025730000000000000141400ustar 00000000000000[package] name = "condure" version = "1.3.1" authors = ["Justin Karneges "] license = "Apache-2.0" description = "HTTP/WebSocket connection manager" repository = "https://github.com/fanout/condure" readme = "README.md" edition = "2018" autobins = false [[bin]] name = "condure" [profile.dev] panic = "abort" [profile.release] panic = "abort" [dependencies] arrayvec = "0.5" base64 = "0.12" clap = { version = "2.33", features = ["wrap_help"] } httparse = "1.3" iovec = "0.1" log = "0.4" mio = { version = "0.7", features = ["os-poll", "os-ext", "net"] } openssl = "0.10" sha1 = "0.6" signal-hook = "0.3" slab = "0.4" time = "0.1" zmq = "0.9" [dev-dependencies] criterion = "0.3" [[bench]] name = "server" harness = false [package.metadata.deb] extended-description = """\ Condure is a service that manages network connections on behalf of server applications, in order to allow controlling 
the connections from multiple
processes. Applications communicate with Condure over ZeroMQ.

Condure can only manage connections for protocols it knows about. Currently
this is HTTP/1 and WebSockets.

The project was inspired by Mongrel2.
"""
separate-debug-symbols = true
build-depends = "pkg-config, libzmq3-dev"

[package.metadata.rpm]
package = "condure"

[package.metadata.rpm.cargo]
buildflags = ["--release"]

[package.metadata.rpm.targets]
condure = { path = "/usr/bin/condure" }
condure-1.3.1/README.md
# Condure

Condure is a service that manages network connections on behalf of server applications, in order to allow controlling the connections from multiple processes. Applications communicate with Condure over [ZeroMQ](https://zeromq.org/).

Condure can only manage connections for protocols it knows about. Currently this is HTTP/1 and WebSockets. See [Supported protocols](#supported-protocols).

The project was inspired by [Mongrel2](https://mongrel2.org/).

## Use cases

* Pass connection ownership from one process to another.
* Restart an application without its connections getting disconnected.
* Balance connection ownership among multiple processes.
## Basic usage

Start the server:

```
$ condure --listen 8000 --zclient-stream ipc://client
```

Connect a handler to it, such as this simple Python program:

```py
# this handler responds to every request with "hello world"

import os
import time
import tnetstring
import zmq

instance_id = 'basichandler.{}'.format(os.getpid()).encode()

ctx = zmq.Context()
in_sock = ctx.socket(zmq.PULL)
in_sock.connect('ipc://client-out')
out_sock = ctx.socket(zmq.PUB)
out_sock.connect('ipc://client-in')

# await subscription
time.sleep(0.01)

while True:
    m_raw = in_sock.recv()
    req = tnetstring.loads(m_raw[1:])
    print('IN {}'.format(req))

    resp = {}
    resp[b'from'] = instance_id
    resp[b'id'] = req[b'id']
    resp[b'code'] = 200
    resp[b'reason'] = b'OK'
    resp[b'headers'] = [[b'Content-Type', b'text/plain']]
    resp[b'body'] = b'hello world\n'

    print('OUT {}'.format(resp))
    out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp))
```

A client request:

```
$ curl -i http://localhost:8000
HTTP/1.1 200 OK
Content-Type: text/plain
Content-Length: 12

hello world
```

The process that receives the request doesn't need to be the same one that responds!
For example, here's a program that outputs request IDs to stdout:

```py
# this handler just outputs the request ID

import tnetstring
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PULL)
sock.connect('ipc://client-out')

while True:
    m = sock.recv_multipart()
    req = tnetstring.loads(m[0][1:])
    print('{} {}'.format(req[b'from'].decode(), req[b'id'].decode()))
```

We can see request ID information when a client request is made:

```
$ python examples/printreq.py
condure 0-0-0
```

From another shell we can respond using a program like this:

```py
# this program sends a response to a certain request ID

import sys
import time
import tnetstring
import zmq

body = sys.argv[1]
addr = sys.argv[2].encode()
rid = sys.argv[3].encode()

ctx = zmq.Context()
sock = ctx.socket(zmq.PUB)
sock.connect('ipc://client-in')

# await subscription
time.sleep(0.01)

resp = {}
resp[b'from'] = b'sendresp'
resp[b'id'] = rid
resp[b'code'] = 200
resp[b'reason'] = b'OK'
resp[b'headers'] = [[b'Content-Type', b'text/plain']]
resp[b'body'] = '{}\n'.format(body).encode()

m = [addr + b' T' + tnetstring.dumps(resp)]

sock.send_multipart(m)
```

For example:

```
$ python examples/sendresp.py "responding from another process" condure 0-0-0
```

The client sees:

```
$ curl -i http://localhost:8000
HTTP/1.1 200 OK
Content-Type: text/plain
Content-Length: 32

responding from another process
```

For easy testing, the programs can be piped together:

```
$ python -u examples/printreq.py | xargs -n 2 python examples/sendresp.py "responding from another process"
```

## Suspending and resuming connections

When passing control of a connection from one process to another, it is important to suspend the connection first. This is done by sending a `handoff-start` message and waiting for a `handoff-proceed` message. At that point, the connection information can be given to another process, and the connection can be resumed by sending any message (such as `keep-alive`). See the [ZHTTP spec](https://rfc.zeromq.org/spec/33/).
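Messages such as `handoff-start` are framed the same way as the responses in the examples above: the address of the Condure instance that should receive the message (the request's `from` value), a space, the letter `T`, then a tnetstring-encoded map. The handlers strip the same one-byte `T` prefix from incoming messages with `m_raw[1:]`. The Rust sketch below shows that framing under the assumption that only the tnetstring types the examples use are needed (`,` byte strings, `#` integers, `]` lists, `}` maps). `Value`, `encode`, and `frame_for_pub` are hypothetical helpers, not Condure APIs, and the field set (`from`, `id`, `seq`, `type`) is illustrative only; check the ZHTTP spec for the fields a real handoff requires.

```rust
// A minimal tnetstring encoder plus ZHTTP-style framing, mirroring what the
// Python examples do with `addr + b' T' + tnetstring.dumps(resp)`.
// `Value`, `encode` and `frame_for_pub` are hypothetical helpers, not Condure APIs.

#[allow(dead_code)]
enum Value {
    Bytes(Vec<u8>),
    Int(i64),
    List(Vec<Value>),
    // Map entries keep insertion order, like a Python 3.7+ dict.
    Map(Vec<(Vec<u8>, Value)>),
}

// Append the tnetstring encoding of `v` to `out`: "<len>:<payload><tag>".
fn encode(v: &Value, out: &mut Vec<u8>) {
    let mut payload = Vec::new();
    let tag = match v {
        Value::Bytes(b) => {
            payload.extend_from_slice(b);
            b','
        }
        Value::Int(i) => {
            payload.extend_from_slice(i.to_string().as_bytes());
            b'#'
        }
        Value::List(items) => {
            for item in items {
                encode(item, &mut payload);
            }
            b']'
        }
        Value::Map(entries) => {
            for (key, val) in entries {
                // Keys are encoded as byte strings.
                encode(&Value::Bytes(key.clone()), &mut payload);
                encode(val, &mut payload);
            }
            b'}'
        }
    };
    out.extend_from_slice(payload.len().to_string().as_bytes());
    out.push(b':');
    out.extend_from_slice(&payload);
    out.push(tag);
}

// Build "<addr> T<tnetstring>", the single-part message the examples publish.
fn frame_for_pub(addr: &[u8], msg: &Value) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(addr);
    out.extend_from_slice(b" T");
    encode(msg, &mut out);
    out
}

fn main() {
    // "handler" and seq 1 are hypothetical; "condure" and "0-0-0" are the
    // address and request ID printed by printreq.py above.
    let msg = Value::Map(vec![
        (b"from".to_vec(), Value::Bytes(b"handler".to_vec())),
        (b"id".to_vec(), Value::Bytes(b"0-0-0".to_vec())),
        (b"seq".to_vec(), Value::Int(1)),
        (b"type".to_vec(), Value::Bytes(b"handoff-start".to_vec())),
    ]);
    let frame = frame_for_pub(b"condure", &msg);
    // prints: condure T64:4:from,7:handler,2:id,5:0-0-0,3:seq,1:1#4:type,13:handoff-start,}
    println!("{}", String::from_utf8_lossy(&frame));
}
```

Decoding reverses the process. Read the decimal length up to `:`, take that many bytes as the payload, then dispatch on the trailing type tag.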
## REQ mode

In addition to the stream mode which uses PUSH/ROUTER/SUB sockets, there is a "REQ" mode available which uses a DEALER socket. To enable it, set `req` as the mode on a listen port. This mode can be handy for implementing simple request/response servers using ZeroMQ.

## Supported protocols

Condure supports HTTP/1 and WebSockets.

Condure manages connections at layer 7 and only supports protocols it knows about. This is to simplify its usage. Handling arbitrary protocols would require applications to build protocol stacks capable of suspending/resuming sessions at arbitrary byte positions in TCP streams, making Condure usage prohibitive. Instead, Condure is protocol-aware, and provides parsed frames to applications, so that applications are only required to support suspending/resuming sessions at frame boundaries.

## Performance

Condure was built for high performance. It uses numerous optimization techniques, including minimal heap allocations, ring buffers, vectored I/O, hierarchical timing wheels, and fast data structures (e.g. slabs). Over 1M concurrent connections have been tested on a single instance using just 2 workers (4 threads total). More detailed benchmarks to come.

## Comparison to Mongrel2

* Condure supports multiple cores.
* Condure supports listening on multiple ports without requiring multiple processes.
* Condure does not support multiple routes and is not intended to be a shared server. Each application that wants to keep connections in a separate process should spawn its own Condure instance.
* Condure has no config file. Configuration is supplied using command line arguments.
* Condure uses a different ZeroMQ-based protocol, [ZHTTP](https://rfc.zeromq.org/spec/33/), which is easier to use than Mongrel2's protocol and more reliable.

## Future plans

* HTTP/2
* HTTP/3
condure-1.3.1/benches/server.rs
/*
 * Copyright (C) 2020 Fanout, Inc.
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use condure::server::TestServer; use criterion::{criterion_group, criterion_main, Criterion}; use std::io::{Read, Write}; use std::net::SocketAddr; use std::str; const REQS_PER_ITER: usize = 10; fn req(addr: SocketAddr) { let mut clients = Vec::new(); for _ in 0..REQS_PER_ITER { let mut client = std::net::TcpStream::connect(&addr).unwrap(); client .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n") .unwrap(); clients.push(client); } for client in clients.iter_mut() { let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!( str::from_utf8(&buf).unwrap(), "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n" ); } } fn criterion_benchmark(c: &mut Criterion) { { let server = TestServer::new(1); let req_addr = server.req_addr(); let stream_addr = server.stream_addr(); c.bench_function("req workers=1", |b| b.iter(|| req(req_addr))); c.bench_function("stream workers=1", |b| b.iter(|| req(stream_addr))); } { let server = TestServer::new(2); let req_addr = server.req_addr(); let stream_addr = server.stream_addr(); c.bench_function("req workers=2", |b| b.iter(|| req(req_addr))); c.bench_function("stream workers=2", |b| b.iter(|| req(stream_addr))); } } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); condure-1.3.1/examples/basichandler.py000064400000000000000000000014370000000000000160560ustar 00000000000000# this handler responds to every request with "hello world" import os import time 
import tnetstring import zmq instance_id = 'basichandler.{}'.format(os.getpid()).encode() ctx = zmq.Context() in_sock = ctx.socket(zmq.PULL) in_sock.connect('ipc://client-out') out_sock = ctx.socket(zmq.PUB) out_sock.connect('ipc://client-in') # await subscription time.sleep(0.01) while True: m_raw = in_sock.recv() req = tnetstring.loads(m_raw[1:]) print('IN {}'.format(req)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'code'] = 200 resp[b'reason'] = b'OK' resp[b'headers'] = [[b'Content-Type', b'text/plain']] resp[b'body'] = b'hello world\n' print('OUT {}'.format(resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) condure-1.3.1/examples/holdhandler.py000064400000000000000000000077400000000000000157260ustar 00000000000000# this handler holds all connections open import os import time import datetime import calendar import tnetstring import zmq CONN_TTL = 60000 EXPIRE_INTERVAL = 60000 instance_id = 'holdhandler.{}'.format(os.getpid()).encode('utf-8') ctx = zmq.Context() in_sock = ctx.socket(zmq.PULL) in_sock.connect('ipc://client-out') in_stream_sock = ctx.socket(zmq.ROUTER) in_stream_sock.identity = instance_id in_stream_sock.connect('ipc://client-out-stream') out_sock = ctx.socket(zmq.PUB) out_sock.connect('ipc://client-in') poller = zmq.Poller() poller.register(in_sock, zmq.POLLIN) poller.register(in_stream_sock, zmq.POLLIN) class Connection(object): def __init__(self, rid): self.rid = rid self.seq = 0 self.exp_time = None def send_msg(self, msg): msg[b'from'] = instance_id msg[b'id'] = self.rid[1] msg[b'seq'] = self.seq self.seq += 1 print('OUT {} {}'.format(self.rid[0], msg)) out_sock.send(self.rid[0] + b' T' + tnetstring.dumps(msg)) def send_header(self): msg = {} msg[b'code'] = 200 msg[b'reason'] = b'OK' msg[b'headers'] = [[b'Content-Type', b'text/plain']] msg[b'more'] = True self.send_msg(msg) def send_body(self, data): msg = {} msg[b'body'] = data msg[b'more'] = True self.send_msg(msg) def send_body(to_addr, conns, 
data): ids = [] for c in conns: ids.append({b'id': c.rid[1], b'seq': c.seq}) c.seq += 1 msg = {} msg[b'from'] = instance_id msg[b'id'] = ids msg[b'body'] = data msg[b'more'] = True print('OUT {} {}'.format(to_addr, msg)) out_sock.send(to_addr + b' T' + tnetstring.dumps(msg)) conns = {} last_exp_time = int(time.time()) while True: socks = dict(poller.poll(1000)) if socks.get(in_sock) == zmq.POLLIN: m_raw = in_sock.recv() elif socks.get(in_stream_sock) == zmq.POLLIN: m_list = in_stream_sock.recv_multipart() m_raw = m_list[2] else: m_raw = None now = int(time.time() * 1000) if m_raw is not None: req = tnetstring.loads(m_raw[1:]) print('IN {}'.format(req)) m_from = req[b'from'] m_id = req[b'id'] m_type = req.get(b'type', b'') ids = [] if isinstance(m_id, list): for id_seq in m_id: ids.append(id_seq[b'id']) else: ids.append(m_id) new_ids = [] known_conns = [] for i in ids: rid = (m_from, i) c = conns.get(rid) if c: c.exp_time = now + CONN_TTL known_conns.append(c) else: new_ids.append(rid) # data if not m_type: for rid in new_ids: c = Connection(rid) conns[rid] = c c.exp_time = now + CONN_TTL c.send_header() elif c: if m_type == b'keep-alive': dt = datetime.datetime.utcnow() ts = calendar.timegm(dt.timetuple()) body = ( 'id: TCPKaliMsgTS-{:016x}.\n' 'event: message\n' 'data: {:04}-{:02}-{:02}T{:02}:{:02}:{:02}\n\n' ).format( (ts * 1000000) + dt.microsecond, dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second ).encode() send_body(m_from, known_conns, body) elif m_type == b'cancel': for c in known_conns: del conns[c.rid] if now >= last_exp_time + EXPIRE_INTERVAL: last_exp_time = now to_remove = [] for rid, c in conns.items(): if now >= c.exp_time: to_remove.append(rid) for rid in to_remove: print('expired {}'.format(rid)) del conns[rid] condure-1.3.1/examples/printreq.py000064400000000000000000000004610000000000000152770ustar 00000000000000# this handler just outputs the request ID import tnetstring import zmq ctx = zmq.Context() sock = ctx.socket(zmq.PULL) 
sock.connect('ipc://client-out')

while True:
    m = sock.recv_multipart()
    req = tnetstring.loads(m[0][1:])
    print('{} {}'.format(req[b'from'].decode(), req[b'id'].decode()))
condure-1.3.1/examples/reqhandler.py
# this handler responds to every request with "hello world"

import tnetstring
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REP)
sock.connect('ipc://client')

while True:
    m_raw = sock.recv()
    req = tnetstring.loads(m_raw[1:])
    print('IN {}'.format(req))

    resp = {}
    resp[b'id'] = req[b'id']
    resp[b'code'] = 200
    resp[b'reason'] = b'OK'
    resp[b'headers'] = [[b'Content-Type', b'text/plain']]
    resp[b'body'] = b'hello world\n'

    print('OUT {}'.format(resp))
    sock.send(b'T' + tnetstring.dumps(resp))
condure-1.3.1/examples/sendresp.py
# this program sends a response to a certain request ID

import sys
import time
import tnetstring
import zmq

body = sys.argv[1]
addr = sys.argv[2].encode()
rid = sys.argv[3].encode()

ctx = zmq.Context()
sock = ctx.socket(zmq.PUB)
sock.connect('ipc://client-in')

# await subscription
time.sleep(0.01)

resp = {}
resp[b'from'] = b'sendresp'
resp[b'id'] = rid
resp[b'code'] = 200
resp[b'reason'] = b'OK'
resp[b'headers'] = [[b'Content-Type', b'text/plain']]
resp[b'body'] = '{}\n'.format(body).encode()

m = [addr + b' T' + tnetstring.dumps(resp)]

sock.send_multipart(m)
condure-1.3.1/examples/streamhandler.py
# this handler responds to every request with "hello world"

import os
import tnetstring
import zmq

instance_id = 'streamhandler.{}'.format(os.getpid()).encode('utf-8')

ctx = zmq.Context()
in_sock = ctx.socket(zmq.PULL)
in_sock.connect('ipc://client-out')
in_stream_sock = ctx.socket(zmq.ROUTER)
in_stream_sock.identity = instance_id
in_stream_sock.connect('ipc://client-out-stream')
out_sock = ctx.socket(zmq.PUB)
out_sock.connect('ipc://client-in') poller = zmq.Poller() poller.register(in_sock, zmq.POLLIN) poller.register(in_stream_sock, zmq.POLLIN) while True: socks = dict(poller.poll(None)) if socks.get(in_sock) == zmq.POLLIN: m_raw = in_sock.recv() elif socks.get(in_stream_sock) == zmq.POLLIN: m_list = in_stream_sock.recv_multipart() m_raw = m_list[2] else: continue req = tnetstring.loads(m_raw[1:]) print('IN {}'.format(req)) if req.get(b'type'): # skip all non-data messages continue if req.get(b'uri', b'').startswith(b'ws'): resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 0 resp[b'code'] = 101 resp[b'reason'] = b'Switching Protocols' resp[b'credits'] = 1024 print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 1 resp[b'body'] = b'hello world' print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 2 resp[b'type'] = b'close' print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) else: resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 0 resp[b'code'] = 200 resp[b'reason'] = b'OK' resp[b'headers'] = [[b'Content-Type', b'text/plain']] resp[b'more'] = True resp[b'credits'] = 1024 print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) resp = {} resp[b'from'] = instance_id resp[b'id'] = req[b'id'] resp[b'seq'] = 1 resp[b'body'] = b'hello world\n' print('OUT {} {}'.format(req[b'from'], resp)) out_sock.send(req[b'from'] + b' T' + tnetstring.dumps(resp)) condure-1.3.1/rustfmt.toml000064400000000000000000000000250000000000000136400ustar 00000000000000# using the defaults condure-1.3.1/src/app.rs000064400000000000000000000163430000000000000131660ustar 00000000000000/* 
* Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::server::{Server, MSG_RETAINED_PER_CONNECTION_MAX, MSG_RETAINED_PER_WORKER_MAX}; use crate::zhttpsocket; use crate::zmq::SpecInfo; use log::info; use signal_hook; use signal_hook::consts::TERM_SIGNALS; use signal_hook::iterator::Signals; use std::cmp; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::Duration; fn make_specs(base: &str) -> Result<(String, String, String), String> { if base.starts_with("ipc:") { Ok(( format!("{}-{}", base, "out"), format!("{}-{}", base, "out-stream"), format!("{}-{}", base, "in"), )) } else if base.starts_with("tcp:") { match base.rfind(':') { Some(pos) => match base[(pos + 1)..base.len()].parse::() { Ok(port) => Ok(( format!("{}:{}", &base[..pos], port), format!("{}:{}", &base[..pos], port + 1), format!("{}:{}", &base[..pos], port + 2), )), Err(e) => Err(format!("error parsing tcp port in base spec: {}", e)), }, None => Err("tcp base spec must specify port".into()), } } else { Err("base spec must be ipc or tcp".into()) } } pub struct ListenConfig { pub addr: SocketAddr, pub stream: bool, pub tls: bool, pub default_cert: Option, } pub struct Config { pub instance_id: String, pub workers: usize, pub req_maxconn: usize, pub stream_maxconn: usize, pub buffer_size: usize, pub body_buffer_size: usize, pub messages_max: usize, pub req_timeout: Duration, pub stream_timeout: Duration, pub 
listen: Vec, pub zclient_req: Vec, pub zclient_stream: Vec, pub zclient_connect: bool, pub ipc_file_mode: usize, pub certs_dir: PathBuf, } pub struct App { _server: Server, } impl App { pub fn new(config: &Config) -> Result { if config.req_maxconn < config.workers { return Err("req maxconn must be >= workers".into()); } if config.stream_maxconn < config.workers { return Err("stream maxconn must be >= workers".into()); } let zmq_context = Arc::new(zmq::Context::new()); // set hwm to 5% of maxconn let hwm = cmp::max((config.req_maxconn + config.stream_maxconn) / 20, 1); let handle_bound = cmp::max(hwm / config.workers, 1); let maxconn = config.req_maxconn + config.stream_maxconn; let mut zsockman = zhttpsocket::SocketManager::new( Arc::clone(&zmq_context), &config.instance_id, (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * config.workers), hwm, handle_bound, ); let mut any_req = false; let mut any_stream = false; for lc in config.listen.iter() { if lc.stream { any_stream = true; } else { any_req = true; } } if any_req { let mut specs = Vec::new(); for spec in config.zclient_req.iter() { if config.zclient_connect { info!("zhttp client connect {}", spec); } else { info!("zhttp client bind {}", spec); } specs.push(SpecInfo { spec: spec.clone(), bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_client_req_specs(&specs) { return Err(format!("failed to set zhttp client req specs: {}", e)); } } if any_stream { let mut out_specs = Vec::new(); let mut out_stream_specs = Vec::new(); let mut in_specs = Vec::new(); for spec in config.zclient_stream.iter() { let (out_spec, out_stream_spec, in_spec) = make_specs(spec)?; if config.zclient_connect { info!( "zhttp client connect {} {} {}", out_spec, out_stream_spec, in_spec ); } else { info!( "zhttp client bind {} {} {}", out_spec, out_stream_spec, in_spec ); } out_specs.push(SpecInfo { spec: out_spec, bind: !config.zclient_connect, ipc_file_mode: 
config.ipc_file_mode, }); out_stream_specs.push(SpecInfo { spec: out_stream_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); in_specs.push(SpecInfo { spec: in_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_client_stream_specs(&out_specs, &out_stream_specs, &in_specs) { return Err(format!("failed to set zhttp client stream specs: {}", e)); } } let server = Server::new( &config.instance_id, config.workers, config.req_maxconn, config.stream_maxconn, config.buffer_size, config.body_buffer_size, config.messages_max, config.req_timeout, config.stream_timeout, &config.listen, config.certs_dir.as_path(), zsockman, handle_bound, )?; Ok(Self { _server: server }) } pub fn wait_for_term(&self) { let mut signals = Signals::new(TERM_SIGNALS).unwrap(); let term_now = Arc::new(AtomicBool::new(false)); // ensure two term signals in a row causes the app to immediately exit for signal_type in TERM_SIGNALS { signal_hook::flag::register_conditional_shutdown( *signal_type, 1, // exit code Arc::clone(&term_now), ) .unwrap(); signal_hook::flag::register(*signal_type, Arc::clone(&term_now)).unwrap(); } // wait for termination for signal in &mut signals { match signal { signal_type if TERM_SIGNALS.contains(&signal_type) => break, _ => unreachable!(), } } } } condure-1.3.1/src/arena.rs000064400000000000000000000407350000000000000134760ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ use slab::Slab; use std::cell::{RefCell, RefMut}; use std::mem; use std::ops::{Deref, DerefMut}; use std::sync::{Mutex, MutexGuard}; pub struct EntryGuard<'a, T> { entries: RefMut<'a, Slab>, entry: &'a mut T, key: usize, } impl EntryGuard<'_, T> { fn remove(mut self) { self.entries.remove(self.key); } } impl Deref for EntryGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { self.entry } } impl DerefMut for EntryGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entry } } // this is essentially a sharable slab for use within a single thread. // operations are protected by a RefCell. when an element is retrieved for // reading or modification, it is wrapped in a EntryGuard which keeps the // entire slab borrowed until the caller is done working with the element pub struct Memory { entries: RefCell>, } impl Memory { pub fn new(capacity: usize) -> Self { // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: RefCell::new(s), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.borrow(); entries.len() } fn insert(&self, e: T) -> Result { let mut entries = self.entries.borrow_mut(); // out of capacity. 
by preventing inserts beyond the capacity, we // ensure the underlying memory won't get moved due to a realloc if entries.len() == entries.capacity() { return Err(()); } Ok(entries.insert(e)) } fn get<'a>(&'a self, key: usize) -> Option> { let mut entries = self.entries.borrow_mut(); let entry = entries.get_mut(key)?; // slab element addresses are guaranteed to be stable once created, // and the only place we remove the element is in EntryGuard's // remove method which consumes itself, therefore it is safe to // assume the element will live at least as long as the EntryGuard // and we can extend the lifetime of the reference beyond the // RefMut let entry = unsafe { mem::transmute::<&mut T, &'a mut T>(entry) }; Some(EntryGuard { entries, entry, key, }) } // for tests, as a way to confirm the memory isn't moving. be careful // with this. the very first element inserted will be at index 0, but // if the slab has been used and cleared, then the next element // inserted may not be at index 0 and calling this method afterward // will panic #[cfg(test)] fn entry0_ptr(&self) -> *const T { let entries = self.entries.borrow(); entries.get(0).unwrap() as *const T } } pub struct SyncEntryGuard<'a, T> { entries: MutexGuard<'a, Slab>, entry: &'a mut T, key: usize, } impl SyncEntryGuard<'_, T> { fn remove(mut self) { self.entries.remove(self.key); } } impl Deref for SyncEntryGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { self.entry } } impl DerefMut for SyncEntryGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entry } } // this is essentially a thread-safe slab. operations are protected by a // mutex. 
when an element is retrieved for reading or modification, it is // wrapped in a EntryGuard which keeps the entire slab locked until the // caller is done working with the element pub struct SyncMemory { entries: Mutex>, } impl SyncMemory { pub fn new(capacity: usize) -> Self { // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: Mutex::new(s), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.len() } fn insert(&self, e: T) -> Result { let mut entries = self.entries.lock().unwrap(); // out of capacity. by preventing inserts beyond the capacity, we // ensure the underlying memory won't get moved due to a realloc if entries.len() == entries.capacity() { return Err(()); } Ok(entries.insert(e)) } fn get<'a>(&'a self, key: usize) -> Option> { let mut entries = self.entries.lock().unwrap(); let entry = entries.get_mut(key)?; // slab element addresses are guaranteed to be stable once created, // and the only place we remove the element is in SyncEntryGuard's // remove method which consumes itself, therefore it is safe to // assume the element will live at least as long as the SyncEntryGuard // and we can extend the lifetime of the reference beyond the // MutexGuard let entry = unsafe { mem::transmute::<&mut T, &'a mut T>(entry) }; Some(SyncEntryGuard { entries, entry, key, }) } // for tests, as a way to confirm the memory isn't moving. be careful // with this. 
the very first element inserted will be at index 0, but // if the slab has been used and cleared, then the next element // inserted may not be at index 0 and calling this method afterward // will panic #[cfg(test)] fn entry0_ptr(&self) -> *const T { let entries = self.entries.lock().unwrap(); entries.get(0).unwrap() as *const T } } pub struct ReusableValue { reusable: std::sync::Arc>, value: *mut T, key: usize, } impl ReusableValue { // vec element addresses are guaranteed to be stable once created, // and elements are only removed when the Reusable is dropped, and // the Arc'd Reusable is guaranteed to live as long as // ReusableValue, therefore it is safe to assume the element will // live at least as long as the ReusableValue fn get<'a>(&'a self) -> &'a T { unsafe { &*self.value } } fn get_mut<'a>(&'a mut self) -> &'a mut T { unsafe { &mut *self.value } } } impl Drop for ReusableValue { fn drop(&mut self) { let mut entries = self.reusable.entries.lock().unwrap(); entries.0.remove(self.key); } } impl Deref for ReusableValue { type Target = T; fn deref(&self) -> &Self::Target { self.get() } } impl DerefMut for ReusableValue { fn deref_mut(&mut self) -> &mut Self::Target { self.get_mut() } } // like Memory, but for preinitializing each value and reusing pub struct Reusable { entries: Mutex<(Slab<()>, Vec)>, } impl Reusable { pub fn new(capacity: usize, init_fn: F) -> Self where F: Fn() -> T, { let mut values = Vec::with_capacity(capacity); for _ in 0..capacity { values.push(init_fn()); } // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: Mutex::new((s, values)), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.0.len() } pub fn reserve(self: &std::sync::Arc) -> Result, ()> { let mut entries = self.entries.lock().unwrap(); // out of capacity. 
        // the number of buffers is fixed
        if entries.0.len() == entries.0.capacity() {
            return Err(());
        }

        let key = entries.0.insert(());

        let value = &mut entries.1[key] as *mut T;

        Ok(ReusableValue {
            reusable: self.clone(),
            value,
            key,
        })
    }
}

pub struct RcEntry<T> {
    value: T,
    refs: usize,
}

pub type RcMemory<T> = Memory<RcEntry<T>>;

pub struct Rc<T> {
    memory: std::rc::Rc<RcMemory<T>>,
    key: usize,
}

impl<T> Rc<T> {
    pub fn new(v: T, memory: &std::rc::Rc<RcMemory<T>>) -> Result<Self, ()> {
        let key = memory.insert(RcEntry { value: v, refs: 1 })?;

        Ok(Self {
            memory: std::rc::Rc::clone(memory),
            key,
        })
    }

    pub fn clone(rc: &Rc<T>) -> Self {
        let mut e = rc.memory.get(rc.key).unwrap();

        e.refs += 1;

        Self {
            memory: rc.memory.clone(),
            key: rc.key,
        }
    }

    pub fn get<'a>(&'a self) -> &'a T {
        let e = self.memory.get(self.key).unwrap();

        // get a reference to the inner value
        let value = &e.value;

        // entry addresses are guaranteed to be stable once created, and the
        // entry managed by this Rc won't be dropped until this Rc drops,
        // therefore it is safe to assume the entry managed by this Rc will
        // live at least as long as this Rc, and we can extend the lifetime
        // of the reference beyond the EntryGuard
        unsafe { mem::transmute::<&T, &'a T>(value) }
    }
}

impl<T> Drop for Rc<T> {
    fn drop(&mut self) {
        let mut e = self.memory.get(self.key).unwrap();

        if e.refs == 1 {
            e.remove();
            return;
        }

        e.refs -= 1;
    }
}

pub type ArcMemory<T> = SyncMemory<RcEntry<T>>;

pub struct Arc<T> {
    memory: std::sync::Arc<ArcMemory<T>>,
    key: usize,
}

impl<T> Arc<T> {
    pub fn new(v: T, memory: &std::sync::Arc<ArcMemory<T>>) -> Result<Self, ()> {
        let key = memory.insert(RcEntry { value: v, refs: 1 })?;

        Ok(Self {
            memory: memory.clone(),
            key,
        })
    }

    pub fn clone(rc: &Arc<T>) -> Self {
        let mut e = rc.memory.get(rc.key).unwrap();

        e.refs += 1;

        Self {
            memory: rc.memory.clone(),
            key: rc.key,
        }
    }

    pub fn get<'a>(&'a self) -> &'a T {
        let e = self.memory.get(self.key).unwrap();

        // get a reference to the inner value
        let value = &e.value;

        // entry addresses are guaranteed to be stable once created, and the
        // entry managed by this Arc won't be dropped until this Arc drops,
        // therefore it is safe to assume the entry managed by this Arc will
        // live at least as long as this Arc, and we can extend the lifetime
        // of the reference beyond the SyncEntryGuard
        unsafe { mem::transmute::<&T, &'a T>(value) }
    }
}

impl<T> Drop for Arc<T> {
    fn drop(&mut self) {
        let mut e = self.memory.get(self.key).unwrap();

        if e.refs == 1 {
            e.remove();
            return;
        }

        e.refs -= 1;
    }
}

// adapted from https://github.com/rust-lang/rfcs/pull/2802
pub fn recycle_vec<T, U>(mut v: Vec<T>) -> Vec<U> {
    assert_eq!(core::mem::size_of::<T>(), core::mem::size_of::<U>());
    assert_eq!(core::mem::align_of::<T>(), core::mem::align_of::<U>());

    v.clear();

    let ptr = v.as_mut_ptr();
    let capacity = v.capacity();
    mem::forget(v);
    let ptr = ptr as *mut U;

    unsafe { Vec::from_raw_parts(ptr, 0, capacity) }
}

// ReusableVec inspired by recycle_vec

pub struct ReusableVecHandle<'a, T> {
    vec: &'a mut Vec<T>,
}

impl<T> ReusableVecHandle<'_, T> {
    pub fn get_ref(&self) -> &Vec<T> {
        &self.vec
    }

    pub fn get_mut(&mut self) -> &mut Vec<T> {
        &mut self.vec
    }
}

impl<T> Drop for ReusableVecHandle<'_, T> {
    fn drop(&mut self) {
        self.vec.clear();
    }
}

impl<T> Deref for ReusableVecHandle<'_, T> {
    type Target = Vec<T>;

    fn deref(&self) -> &Self::Target {
        self.get_ref()
    }
}

impl<T> DerefMut for ReusableVecHandle<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.get_mut()
    }
}

pub struct ReusableVec {
    vec: Vec<()>,
    size: usize,
    align: usize,
}

impl ReusableVec {
    pub fn new<T>(capacity: usize) -> Self {
        let size = mem::size_of::<T>();
        let align = mem::align_of::<T>();

        let vec: Vec<T> = Vec::with_capacity(capacity);

        // safety: we must cast to Vec<U> before using, where U has the same
        // size and alignment as T
        let vec: Vec<()> = unsafe { mem::transmute(vec) };

        Self { vec, size, align }
    }

    pub fn get_as_new<'a, U>(&'a mut self) -> ReusableVecHandle<'a, U> {
        let size = mem::size_of::<U>();
        let align = mem::align_of::<U>();

        // if these don't match, panic.
        // it's up to the user to ensure the type
        // is acceptable
        assert_eq!(self.size, size);
        assert_eq!(self.align, align);

        let vec: &mut Vec<()> = &mut self.vec;

        // safety: U has the expected size and alignment
        let vec: &mut Vec<U> = unsafe { mem::transmute(vec) };

        // the vec starts empty, and is always cleared when the handle drops.
        // get_as_new() borrows self mutably, so it's not possible to create
        // a handle when one already exists
        assert!(vec.is_empty());

        ReusableVecHandle { vec }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reusable() {
        let reusable = std::sync::Arc::new(Reusable::new(2, || vec![0; 128]));
        assert_eq!(reusable.len(), 0);

        let mut buf1 = reusable.reserve().unwrap();
        assert_eq!(reusable.len(), 1);

        let mut buf2 = reusable.reserve().unwrap();
        assert_eq!(reusable.len(), 2);

        // no room
        assert!(reusable.reserve().is_err());

        buf1[..5].copy_from_slice(b"hello");
        buf2[..5].copy_from_slice(b"world");

        assert_eq!(&buf1[..5], b"hello");
        assert_eq!(&buf2[..5], b"world");

        mem::drop(buf1);
        assert_eq!(reusable.len(), 1);

        mem::drop(buf2);
        assert_eq!(reusable.len(), 0);
    }

    #[test]
    fn test_rc() {
        let memory = std::rc::Rc::new(RcMemory::new(2));
        assert_eq!(memory.len(), 0);

        let e0a = Rc::new(123 as i32, &memory).unwrap();
        assert_eq!(memory.len(), 1);
        let p = memory.entry0_ptr();

        let e0b = Rc::clone(&e0a);
        assert_eq!(memory.len(), 1);
        assert_eq!(memory.entry0_ptr(), p);

        let e1a = Rc::new(456 as i32, &memory).unwrap();
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.entry0_ptr(), p);

        // no room
        assert!(Rc::new(789 as i32, &memory).is_err());

        assert_eq!(*e0a.get(), 123);
        assert_eq!(*e0b.get(), 123);
        assert_eq!(*e1a.get(), 456);

        mem::drop(e0b);
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.entry0_ptr(), p);

        mem::drop(e0a);
        assert_eq!(memory.len(), 1);

        mem::drop(e1a);
        assert_eq!(memory.len(), 0);
    }

    #[test]
    fn test_arc() {
        let memory = std::sync::Arc::new(ArcMemory::new(2));
        assert_eq!(memory.len(), 0);

        let e0a = Arc::new(123 as i32, &memory).unwrap();
        assert_eq!(memory.len(), 1);
        let p = memory.entry0_ptr();

        let e0b = Arc::clone(&e0a);
        assert_eq!(memory.len(), 1);
        assert_eq!(memory.entry0_ptr(), p);

        let e1a = Arc::new(456 as i32, &memory).unwrap();
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.entry0_ptr(), p);

        // no room
        assert!(Arc::new(789 as i32, &memory).is_err());

        assert_eq!(*e0a.get(), 123);
        assert_eq!(*e0b.get(), 123);
        assert_eq!(*e1a.get(), 456);

        mem::drop(e0b);
        assert_eq!(memory.len(), 2);
        assert_eq!(memory.entry0_ptr(), p);

        mem::drop(e0a);
        assert_eq!(memory.len(), 1);

        mem::drop(e1a);
        assert_eq!(memory.len(), 0);
    }

    #[test]
    fn test_reusable_vec() {
        let mut vec_mem = ReusableVec::new::<u32>(100);

        let mut vec = vec_mem.get_as_new::<u32>();

        assert_eq!(vec.capacity(), 100);
        assert_eq!(vec.len(), 0);

        vec.push(1);
        assert_eq!(vec.len(), 1);

        mem::drop(vec);

        let vec = vec_mem.get_as_new::<u32>();

        assert_eq!(vec.capacity(), 100);
        assert_eq!(vec.len(), 0);
    }
}
condure-1.3.1/src/buffer.rs000064400000000000000000000451620000000000000136600ustar 00000000000000
/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use std::cell::RefCell;
use std::cmp;
use std::io;
use std::io::{Read, Write};
use std::rc::Rc;

pub const VECTORED_MAX: usize = 8;

pub trait RefRead {
    fn get_ref(&self) -> &[u8];

    fn get_mut(&mut self) -> &mut [u8];

    fn consume(&mut self, amt: usize);

    fn get_ref_vectored<'data, 'bufs>(
        &'data self,
        bufs: &'bufs mut [&'data [u8]],
    ) -> &'bufs mut [&'data [u8]] {
        assert!(bufs.len() >= 1);

        bufs[0] = self.get_ref();

        &mut bufs[..1]
    }

    fn get_mut_vectored<'data, 'bufs>(
        &'data mut self,
        bufs: &'bufs mut [&'data mut [u8]],
    ) -> &'bufs mut [&'data mut [u8]] {
        assert!(bufs.len() >= 1);

        bufs[0] = self.get_mut();

        &mut bufs[..1]
    }
}

impl RefRead for io::Cursor<&mut [u8]> {
    fn get_ref(&self) -> &[u8] {
        let pos = self.position() as usize;

        &self.get_ref()[pos..]
    }

    fn get_mut(&mut self) -> &mut [u8] {
        let pos = self.position() as usize;

        &mut self.get_mut()[pos..]
    }

    fn consume(&mut self, amt: usize) {
        let pos = self.position();

        self.set_position(pos + (amt as u64));
    }
}

pub fn write_vectored_offset(
    writer: &mut dyn Write,
    bufs: &[&[u8]],
    offset: usize,
) -> Result<usize, io::Error> {
    if bufs.is_empty() {
        return Ok(0);
    }

    let mut offset = offset;
    let mut start = 0;

    while offset >= bufs[start].len() {
        // on the last buf?
        if start + 1 >= bufs.len() {
            // exceeding the last buf is an error
            if offset > bufs[start].len() {
                return Err(io::Error::from(io::ErrorKind::InvalidInput));
            }

            return Ok(0);
        }

        offset -= bufs[start].len();
        start += 1;
    }

    let mut arr = [io::IoSlice::new(&b""[..]); VECTORED_MAX];
    let mut arr_len = 0;

    for i in start..bufs.len() {
        let buf = if i == start {
            &bufs[i][offset..]
        } else {
            bufs[i]
        };

        arr[arr_len] = io::IoSlice::new(buf);
        arr_len += 1;
    }

    writer.write_vectored(&arr[..arr_len])
}

pub trait LimitBufs<'a> {
    fn limit(&mut self, size: usize) -> &mut [&'a [u8]];
}

impl<'a> LimitBufs<'a> for [&'a [u8]] {
    fn limit(&mut self, size: usize) -> &mut [&'a [u8]] {
        let mut want = size;

        for i in 0..self.len() {
            let buf_len = self[i].len();

            if buf_len >= want {
                self[i] = &self[i][..want];

                return &mut self[..(i + 1)];
            }

            want -= buf_len;
        }

        self
    }
}

pub struct Buffer {
    buf: Vec<u8>,
    start: usize,
    end: usize,
}

impl Buffer {
    pub fn new(size: usize) -> Buffer {
        let buf = vec![0; size];

        Buffer {
            buf,
            start: 0,
            end: 0,
        }
    }

    pub fn clear(&mut self) {
        self.start = 0;
        self.end = 0;
    }

    pub fn read_avail(&self) -> usize {
        self.end - self.start
    }

    pub fn read_buf(&self) -> &[u8] {
        &self.buf[self.start..self.end]
    }

    pub fn read_commit(&mut self, amount: usize) {
        assert!(self.start + amount <= self.end);

        self.start += amount;
    }

    pub fn write_avail(&self) -> usize {
        self.buf.len() - self.end
    }

    pub fn write_buf(&mut self) -> &mut [u8] {
        let len = self.buf.len();

        &mut self.buf[self.end..len]
    }

    pub fn write_commit(&mut self, amount: usize) {
        assert!(self.end + amount <= self.buf.len());

        self.end += amount;
    }
}

#[cfg(test)]
impl Read for Buffer {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        let src = self.read_buf();

        let size = cmp::min(src.len(), buf.len());

        buf[..size].copy_from_slice(&src[..size]);

        self.read_commit(size);

        Ok(size)
    }
}

impl Write for Buffer {
    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
        if !buf.is_empty() && self.write_avail() == 0 {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        let dest = self.write_buf();

        let size = cmp::min(dest.len(), buf.len());

        dest[..size].copy_from_slice(&buf[..size]);

        self.write_commit(size);

        Ok(size)
    }

    fn flush(&mut self) -> Result<(), io::Error> {
        Ok(())
    }
}

pub struct TmpBuffer(RefCell<Vec<u8>>);

impl TmpBuffer {
    pub fn new(size: usize) -> Self {
        Self(RefCell::new(vec![0; size]))
    }

    pub fn len(&self) -> usize {
        self.0.borrow().len()
    }
}

pub struct RingBuffer {
    buf: Vec<u8>,
    start: usize,
    end: usize,
    tmp: Rc<TmpBuffer>,
}

impl RingBuffer {
    pub fn new(size: usize, tmp: &Rc<TmpBuffer>) -> RingBuffer {
        assert!(size <= tmp.len());

        let buf = vec![0; size];

        RingBuffer {
            buf,
            start: 0,
            end: 0,
            tmp: Rc::clone(tmp),
        }
    }

    pub fn capacity(&self) -> usize {
        self.buf.len()
    }

    pub fn clear(&mut self) {
        self.start = 0;
        self.end = 0;
    }

    pub fn write_from(&mut self, r: &mut dyn Read) -> Result<usize, io::Error> {
        let size = r.read(self.write_buf())?;

        self.write_commit(size);

        Ok(size)
    }

    pub fn read_avail(&self) -> usize {
        self.end - self.start
    }

    pub fn read_buf(&self) -> &[u8] {
        let end = cmp::min(self.end, self.buf.len());

        &self.buf[self.start..end]
    }

    pub fn read_buf_mut(&mut self) -> &mut [u8] {
        let end = cmp::min(self.end, self.buf.len());

        &mut self.buf[self.start..end]
    }

    pub fn read_commit(&mut self, amount: usize) {
        assert!(self.start + amount <= self.end);

        self.start += amount;

        if self.start == self.end {
            self.start = 0;
            self.end = 0;
        } else if self.start >= self.buf.len() {
            self.start -= self.buf.len();
            self.end -= self.buf.len();
        }
    }

    pub fn write_avail(&self) -> usize {
        self.buf.len() - (self.end - self.start)
    }

    pub fn write_buf(&mut self) -> &mut [u8] {
        let (start, end) = if self.end < self.buf.len() {
            (self.end, self.buf.len())
        } else {
            (self.end - self.buf.len(), self.start)
        };

        &mut self.buf[start..end]
    }

    pub fn write_commit(&mut self, amount: usize) {
        assert!((self.end - self.start) + amount <= self.buf.len());

        self.end += amount;
    }

    pub fn align(&mut self) -> usize {
        if self.start == 0 {
            return 0;
        }

        let size = self.end - self.start;

        if self.end <= self.buf.len() {
            // if the buffer hasn't wrapped, simply copy down
            self.buf.copy_within(self.start.., 0);
        } else if size <= self.start {
            // if the buffer has wrapped, but the wrapped part can be copied
            // without overlapping, then copy the wrapped part followed by
            // initial part
            let left_size = self.end - self.buf.len();
            let
right_size = self.buf.len() - self.start;

            self.buf.copy_within(..left_size, right_size);
            self.buf
                .copy_within(self.start..(self.start + right_size), 0);
        } else {
            // if the buffer has wrapped and the wrapped part can't be copied
            // without overlapping, then use a temporary buffer to
            // facilitate. the smaller part is copied to the temp buffer, then
            // the larger and smaller parts (in that order) are copied into
            // their intended locations. in the worst case, up to 50% of
            // the buffer may be copied twice
            let left_size = self.end - self.buf.len();
            let right_size = self.buf.len() - self.start;

            let (lsize, lsrc, ldest, hsize, hsrc, hdest);

            if left_size < right_size {
                lsize = left_size;
                hsize = right_size;
                lsrc = 0;
                ldest = hsize;
                hsrc = self.start;
                hdest = 0;
            } else {
                lsize = right_size;
                hsize = left_size;
                lsrc = self.start;
                ldest = 0;
                hsrc = 0;
                hdest = lsize;
            }

            let mut tmp = self.tmp.0.borrow_mut();

            tmp[..lsize].copy_from_slice(&self.buf[lsrc..(lsrc + lsize)]);
            self.buf.copy_within(hsrc..(hsrc + hsize), hdest);
            self.buf[ldest..(ldest + lsize)].copy_from_slice(&tmp[..lsize]);
        }

        self.start = 0;
        self.end = size;

        size
    }
}

#[cfg(test)]
impl Read for RingBuffer {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        let mut pos = 0;

        while pos < buf.len() && self.read_avail() > 0 {
            let src = self.read_buf();

            let size = cmp::min(src.len(), buf.len() - pos);

            buf[pos..(pos + size)].copy_from_slice(&src[..size]);

            self.read_commit(size);

            pos += size;
        }

        Ok(pos)
    }
}

impl Write for RingBuffer {
    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
        if !buf.is_empty() && self.write_avail() == 0 {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        let mut pos = 0;

        while pos < buf.len() && self.write_avail() > 0 {
            let dest = self.write_buf();

            let size = cmp::min(dest.len(), buf.len() - pos);

            dest[..size].copy_from_slice(&buf[pos..(pos + size)]);

            self.write_commit(size);

            pos += size;
        }

        Ok(pos)
    }

    fn flush(&mut self) -> Result<(), io::Error> {
        Ok(())
    }
}

impl RefRead for RingBuffer {
    fn
get_ref(&self) -> &[u8] {
        self.read_buf()
    }

    fn get_mut(&mut self) -> &mut [u8] {
        self.read_buf_mut()
    }

    fn consume(&mut self, amt: usize) {
        self.read_commit(amt);
    }

    fn get_ref_vectored<'data, 'bufs>(
        &'data self,
        bufs: &'bufs mut [&'data [u8]],
    ) -> &'bufs mut [&'data [u8]] {
        assert!(bufs.len() >= 1);

        let buf_len = self.buf.len();

        if self.end > buf_len && bufs.len() >= 2 {
            let (part1, part2) = self.buf.split_at(self.start);

            bufs[0] = part2;
            bufs[1] = &part1[..(self.end - buf_len)];

            &mut bufs[..2]
        } else {
            // no wrap, or only one slot available: return the contiguous
            // readable region. slicing up to self.end directly would panic
            // if the data has wrapped past the end of the buffer
            bufs[0] = self.read_buf();

            &mut bufs[..1]
        }
    }

    fn get_mut_vectored<'data, 'bufs>(
        &'data mut self,
        bufs: &'bufs mut [&'data mut [u8]],
    ) -> &'bufs mut [&'data mut [u8]] {
        assert!(bufs.len() >= 1);

        let buf_len = self.buf.len();

        if self.end > buf_len && bufs.len() >= 2 {
            let (part1, part2) = self.buf.split_at_mut(self.start);

            bufs[0] = part2;
            bufs[1] = &mut part1[..(self.end - buf_len)];

            &mut bufs[..2]
        } else {
            // no wrap, or only one slot available: return the contiguous
            // readable region. slicing up to self.end directly would panic
            // if the data has wrapped past the end of the buffer
            bufs[0] = self.read_buf_mut();

            &mut bufs[..1]
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};

    #[test]
    fn test_write_vectored_offset() {
        struct MyWriter {
            bufs: Vec<String>,
        }

        impl MyWriter {
            fn new() -> Self {
                Self { bufs: Vec::new() }
            }
        }

        impl Write for MyWriter {
            fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
                self.bufs.push(String::from_utf8(buf.to_vec()).unwrap());

                Ok(buf.len())
            }

            fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result<usize, io::Error> {
                let mut total = 0;

                for buf in bufs {
                    total += buf.len();
                    self.bufs.push(String::from_utf8(buf.to_vec()).unwrap());
                }

                Ok(total)
            }

            fn flush(&mut self) -> Result<(), io::Error> {
                Ok(())
            }
        }

        // empty
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[], 0);
        assert_eq!(r.unwrap(), 0);

        // offset too large
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 6);
        assert!(r.is_err());

        // offset too large
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 12);
        assert!(r.is_err());

        // nothing to write
        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 5);
        assert_eq!(r.unwrap(), 0);

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 0);
        assert_eq!(r.unwrap(), 5);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "apple");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple"], 3);
        assert_eq!(r.unwrap(), 2);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "le");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 3);
        assert_eq!(r.unwrap(), 8);
        assert_eq!(w.bufs.len(), 2);
        assert_eq!(w.bufs[0], "le");
        assert_eq!(w.bufs[1], "banana");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 5);
        assert_eq!(r.unwrap(), 6);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "banana");

        let mut w = MyWriter::new();
        let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 6);
        assert_eq!(r.unwrap(), 5);
        assert_eq!(w.bufs.len(), 1);
        assert_eq!(w.bufs[0], "anana");
    }

    #[test]
    fn test_buffer() {
        let mut b = Buffer::new(8);

        assert_eq!(b.read_avail(), 0);
        assert_eq!(b.write_avail(), 8);

        let size = b.write(b"hello").unwrap();
        assert_eq!(size, 5);
        assert_eq!(b.read_avail(), 5);
        assert_eq!(b.write_avail(), 3);

        let size = b.write(b"world").unwrap();
        assert_eq!(size, 3);
        assert_eq!(b.read_avail(), 8);
        assert_eq!(b.write_avail(), 0);

        let mut tmp = [0; 16];
        let size = b.read(&mut tmp).unwrap();
        assert_eq!(&tmp[..size], b"hellowor");

        b.clear();
        assert_eq!(b.read_avail(), 0);
        assert_eq!(b.write_avail(), 8);
    }

    #[test]
    fn test_ringbuffer() {
        let mut buf = [0u8; 8];

        let tmp = Rc::new(TmpBuffer::new(8));

        let mut r = RingBuffer::new(8, &tmp);

        assert_eq!(r.read_avail(), 0);
        assert_eq!(r.write_avail(), 8);

        r.write(b"12345").unwrap();

        assert_eq!(r.read_avail(), 5);
        assert_eq!(r.write_avail(), 3);

        r.write(b"678").unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);
        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.write_avail(), 0);
        assert_eq!(r.read_buf(), b"12345678");
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], b"12345678");

        r.read(&mut buf[..5]).unwrap();

        assert_eq!(r.read_avail(), 3);
        assert_eq!(r.write_avail(), 5);
        assert_eq!(r.write_buf().len(), 5);

        r.write(b"9abcd").unwrap();

        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.write_avail(), 0);

        r.read(&mut buf[5..]).unwrap();

        assert_eq!(r.read_avail(), 5);
        assert_eq!(r.write_avail(), 3);

        r.read(&mut buf[..5]).unwrap();

        assert_eq!(r.read_avail(), 0);
        assert_eq!(r.write_avail(), 8);
        assert_eq!(&buf, b"9abcd678");

        r.write(b"12345").unwrap();
        r.read(&mut buf[..2]).unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);

        assert_eq!(r.read_avail(), 3);
        assert_eq!(r.read_buf(), b"345");
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], b"345");
        assert_eq!(r.write_avail(), 5);
        assert_eq!(r.write_buf().len(), 3);

        r.align();

        assert_eq!(r.read_avail(), 3);
        assert_eq!(r.read_buf(), b"345");
        assert_eq!(r.write_avail(), 5);
        assert_eq!(r.write_buf().len(), 5);

        r.write(b"6789a").unwrap();
        r.read(&mut buf[..2]).unwrap();
        r.write(b"bc").unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);

        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"56789a");
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], b"56789a");
        assert_eq!(bufs[1], b"bc");
        assert_eq!(r.write_avail(), 0);

        r.align();

        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"56789abc");
        assert_eq!(r.write_avail(), 0);

        r.read(&mut buf[..6]).unwrap();
        r.write(b"def123").unwrap();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut bufs_arr);

        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"bc");
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], b"bc");
        assert_eq!(bufs[1], b"def123");
        assert_eq!(r.write_avail(), 0);

        r.align();

        let mut bufs_arr = [&b""[..]; VECTORED_MAX];
        let bufs = r.get_ref_vectored(&mut
bufs_arr);

        assert_eq!(r.read_avail(), 8);
        assert_eq!(r.read_buf(), b"bcdef123");
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], b"bcdef123");
        assert_eq!(r.write_avail(), 0);

        r.clear();
        r.write(b"12345678").unwrap();
        r.read(&mut buf[..6]).unwrap();
        r.write(b"9abc").unwrap();

        assert_eq!(r.read_avail(), 6);
        assert_eq!(r.read_buf().len(), 2);

        r.align();

        assert_eq!(r.read_avail(), 6);
        assert_eq!(r.read_buf().len(), 6);
    }
}
condure-1.3.1/src/channel.rs000064400000000000000000000560770000000000000140260ustar 00000000000000
/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::event;
use crate::list;
use mio;
use slab::Slab;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::mem;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::sync::Arc;

pub struct Sender<T> {
    sender: Option<mpsc::SyncSender<T>>,
    read_set_readiness: event::SetReadiness,
    write_registration: event::Registration,
    cts: Option<Arc<AtomicBool>>,
}

impl<T> Sender<T> {
    // NOTE: only makes sense for rendezvous channels
    pub fn can_send(&self) -> bool {
        match &self.cts {
            Some(cts) => cts.load(Ordering::Relaxed),
            None => true,
        }
    }

    pub fn get_write_registration(&self) -> &event::Registration {
        &self.write_registration
    }

    pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError<T>> {
        if let Some(cts) = &self.cts {
            if cts
                .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed)
                .is_err()
            {
                return Err(mpsc::TrySendError::Full(t));
            }

            // cts will only be true if a read was performed while the queue
            // was empty, and this function is the only place where the queue
            // is written to.
            // this means the try_send call below will only
            // fail if the receiver disconnected
        }

        match self.sender.as_ref().unwrap().try_send(t) {
            Ok(_) => {
                self.read_set_readiness
                    .set_readiness(mio::Interest::READABLE)
                    .unwrap();

                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    pub fn send(&self, t: T) -> Result<(), mpsc::SendError<T>> {
        if self.cts.is_some() {
            panic!("blocking send with rendezvous channel not supported")
        }

        match self.sender.as_ref().unwrap().send(t) {
            Ok(_) => {
                self.read_set_readiness
                    .set_readiness(mio::Interest::READABLE)
                    .unwrap();

                Ok(())
            }
            Err(e) => Err(e),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        mem::drop(self.sender.take().unwrap());

        self.read_set_readiness
            .set_readiness(mio::Interest::READABLE)
            .unwrap();
    }
}

pub struct Receiver<T> {
    receiver: mpsc::Receiver<T>,
    read_registration: event::Registration,
    write_set_readiness: event::SetReadiness,
    cts: Option<Arc<AtomicBool>>,
}

impl<T> Receiver<T> {
    pub fn get_read_registration(&self) -> &event::Registration {
        &self.read_registration
    }

    pub fn try_recv(&self) -> Result<T, mpsc::TryRecvError> {
        match self.receiver.try_recv() {
            Ok(t) => {
                if self.cts.is_none() {
                    self.write_set_readiness
                        .set_readiness(mio::Interest::WRITABLE)
                        .unwrap();
                }

                Ok(t)
            }
            Err(mpsc::TryRecvError::Empty) if self.cts.is_some() => {
                let cts = self.cts.as_ref().unwrap();

                if cts
                    .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                    .is_ok()
                {
                    self.write_set_readiness
                        .set_readiness(mio::Interest::WRITABLE)
                        .unwrap();
                }

                Err(mpsc::TryRecvError::Empty)
            }
            Err(e) => Err(e),
        }
    }

    pub fn recv(&self) -> Result<T, mpsc::RecvError> {
        let t = self.receiver.recv()?;

        if self.cts.is_none() {
            self.write_set_readiness
                .set_readiness(mio::Interest::WRITABLE)
                .unwrap();
        }

        Ok(t)
    }
}

pub fn channel<T>(bound: usize) -> (Sender<T>, Receiver<T>) {
    let (read_reg, read_sr) = event::Registration::new();
    let (write_reg, write_sr) = event::Registration::new();

    // rendezvous channel
    if bound == 0 {
        let (s, r) = mpsc::sync_channel::<T>(1);

        let cts = Arc::new(AtomicBool::new(false));

        let sender = Sender {
            sender: Some(s),
            read_set_readiness:
read_sr,
            write_registration: write_reg,
            cts: Some(Arc::clone(&cts)),
        };

        let receiver = Receiver {
            receiver: r,
            read_registration: read_reg,
            write_set_readiness: write_sr,
            cts: Some(Arc::clone(&cts)),
        };

        (sender, receiver)
    } else {
        let (s, r) = mpsc::sync_channel::<T>(bound);

        let sender = Sender {
            sender: Some(s),
            read_set_readiness: read_sr,
            write_registration: write_reg,
            cts: None,
        };

        let receiver = Receiver {
            receiver: r,
            read_registration: read_reg,
            write_set_readiness: write_sr,
            cts: None,
        };

        // channel is immediately writable
        receiver
            .write_set_readiness
            .set_readiness(mio::Interest::WRITABLE)
            .unwrap();

        (sender, receiver)
    }
}

struct LocalSenderData {
    notified: bool,
    write_set_readiness: event::LocalSetReadiness,
}

struct LocalSenders {
    nodes: Slab<list::Node<LocalSenderData>>,
    waiting: list::List,
}

struct LocalChannel<T> {
    queue: RefCell<VecDeque<T>>,
    senders: RefCell<LocalSenders>,
    read_set_readiness: RefCell<Option<event::LocalSetReadiness>>,
}

impl<T> LocalChannel<T> {
    fn senders_is_empty(&self) -> bool {
        self.senders.borrow().nodes.is_empty()
    }

    fn add_sender(&self, write_sr: event::LocalSetReadiness) -> Result<usize, ()> {
        let mut senders = self.senders.borrow_mut();

        if senders.nodes.len() == senders.nodes.capacity() {
            return Err(());
        }

        let key = senders.nodes.insert(list::Node::new(LocalSenderData {
            notified: false,
            write_set_readiness: write_sr,
        }));

        Ok(key)
    }

    fn remove_sender(&self, key: usize) {
        let senders = &mut *self.senders.borrow_mut();

        senders.waiting.remove(&mut senders.nodes, key);
        senders.nodes.remove(key);

        if senders.nodes.is_empty() {
            if let Some(read_sr) = &*self.read_set_readiness.borrow() {
                // notify for disconnect
                read_sr.set_readiness(mio::Interest::READABLE).unwrap();
            }
        }
    }

    fn set_sender_waiting(&self, key: usize) {
        let senders = &mut *self.senders.borrow_mut();

        // add if not already present
        if senders.nodes[key].prev.is_none() && senders.waiting.head != Some(key) {
            senders.waiting.push_back(&mut senders.nodes, key);
        }
    }

    fn notify_one_sender(&self) {
        let senders = &mut *self.senders.borrow_mut();

        // notify next waiting sender, if any
        if let
Some(key) = senders.waiting.pop_front(&mut senders.nodes) {
            let sender = &mut senders.nodes[key].value;

            sender.notified = true;
            sender
                .write_set_readiness
                .set_readiness(mio::Interest::WRITABLE)
                .unwrap();
        }
    }

    fn sender_is_notified(&self, key: usize) -> bool {
        self.senders.borrow().nodes[key].value.notified
    }

    fn clear_sender_notified(&self, key: usize) {
        self.senders.borrow_mut().nodes[key].value.notified = false;
    }
}

pub struct LocalSender<T> {
    channel: Rc<LocalChannel<T>>,
    key: usize,
    write_registration: event::LocalRegistration,
}

impl<T> LocalSender<T> {
    pub fn get_write_registration(&self) -> &event::LocalRegistration {
        &self.write_registration
    }

    // if this returns true, then the next call to try_send() by any sender
    // is guaranteed to not return TrySendError::Full.
    // if this returns false, the sender is added to the wait list
    pub fn check_send(&self) -> bool {
        let queue = self.channel.queue.borrow();

        let can_send = queue.len() < queue.capacity();

        if !can_send {
            self.channel.set_sender_waiting(self.key);
        }

        can_send
    }

    pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError<T>> {
        // we are acting, so clear the notified flag
        self.channel.clear_sender_notified(self.key);

        let read_sr = &*self.channel.read_set_readiness.borrow();

        let read_sr = match read_sr {
            Some(sr) => sr,
            None => {
                // receiver is disconnected
                return Err(mpsc::TrySendError::Disconnected(t));
            }
        };

        let mut queue = self.channel.queue.borrow_mut();

        if queue.len() < queue.capacity() {
            queue.push_back(t);

            read_sr.set_readiness(mio::Interest::READABLE).unwrap();

            Ok(())
        } else {
            self.channel.set_sender_waiting(self.key);

            Err(mpsc::TrySendError::Full(t))
        }
    }

    pub fn cancel(&self) {
        // if we were notified but never acted on it, notify the next waiting sender, if any
        if self.channel.sender_is_notified(self.key) {
            self.channel.clear_sender_notified(self.key);
            self.channel.notify_one_sender();
        }
    }

    pub fn try_clone(
        &self,
        memory: &Rc<arena::RcMemory<event::LocalRegistrationEntry>>,
    ) -> Result<Self, ()> {
        let (write_reg, write_sr) = event::LocalRegistration::new(memory);

        let key =
self.channel.add_sender(write_sr)?;

        Ok(Self {
            channel: self.channel.clone(),
            key,
            write_registration: write_reg,
        })
    }
}

impl<T> Drop for LocalSender<T> {
    fn drop(&mut self) {
        self.cancel();

        self.channel.remove_sender(self.key);
    }
}

pub struct LocalReceiver<T> {
    channel: Rc<LocalChannel<T>>,
    read_registration: event::LocalRegistration,
}

impl<T> LocalReceiver<T> {
    pub fn get_read_registration(&self) -> &event::LocalRegistration {
        &self.read_registration
    }

    pub fn try_recv(&self) -> Result<T, mpsc::TryRecvError> {
        let mut queue = self.channel.queue.borrow_mut();

        if queue.is_empty() {
            if self.channel.senders_is_empty() {
                return Err(mpsc::TryRecvError::Disconnected);
            }

            return Err(mpsc::TryRecvError::Empty);
        }

        let value = queue.pop_front().unwrap();

        self.channel.notify_one_sender();

        Ok(value)
    }
}

impl<T> Drop for LocalReceiver<T> {
    fn drop(&mut self) {
        *self.channel.read_set_readiness.borrow_mut() = None;
    }
}

pub fn local_channel<T>(
    bound: usize,
    max_senders: usize,
    memory: &Rc<arena::RcMemory<event::LocalRegistrationEntry>>,
) -> (LocalSender<T>, LocalReceiver<T>) {
    let (read_reg, read_sr) = event::LocalRegistration::new(memory);
    let (write_reg, write_sr) = event::LocalRegistration::new(memory);

    // no support for rendezvous channels
    assert!(bound > 0);

    // need to support at least one sender
    assert!(max_senders > 0);

    let channel = Rc::new(LocalChannel {
        queue: RefCell::new(VecDeque::with_capacity(bound)),
        senders: RefCell::new(LocalSenders {
            nodes: Slab::with_capacity(max_senders),
            waiting: list::List::default(),
        }),
        read_set_readiness: RefCell::new(Some(read_sr)),
    });

    let key = channel.add_sender(write_sr).unwrap();

    let sender = LocalSender {
        channel: channel.clone(),
        key,
        write_registration: write_reg,
    };

    let receiver = LocalReceiver {
        channel: channel.clone(),
        read_registration: read_reg,
    };

    (sender, receiver)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time;

    #[test]
    fn test_send_recv_bound0() {
        let (sender, receiver) = channel(0);

        assert_eq!(sender.can_send(), false);

        let result = sender.try_send(42);
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(),
mpsc::TrySendError::Full(42));

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        assert_eq!(sender.can_send(), true);

        let result = sender.try_send(42);
        assert_eq!(result.is_ok(), true);

        assert_eq!(sender.can_send(), false);

        let result = receiver.try_recv();
        assert_eq!(result.is_ok(), true);
        let v = result.unwrap();
        assert_eq!(v, 42);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        mem::drop(sender);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Disconnected);
    }

    #[test]
    fn test_send_recv_bound1() {
        let (sender, receiver) = channel(1);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        let result = sender.try_send(42);
        assert_eq!(result.is_ok(), true);

        let result = sender.try_send(42);
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TrySendError::Full(42));

        let result = receiver.try_recv();
        assert_eq!(result.is_ok(), true);
        let v = result.unwrap();
        assert_eq!(v, 42);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        mem::drop(sender);

        let result = receiver.try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Disconnected);
    }

    #[test]
    fn test_notify_bound0() {
        let (sender, receiver) = channel(0);

        let mut poller = event::Poller::new(2).unwrap();

        poller
            .register_custom(
                sender.get_write_registration(),
                mio::Token(1),
                mio::Interest::WRITABLE,
            )
            .unwrap();

        poller
            .register_custom(
                receiver.get_read_registration(),
                mio::Token(2),
                mio::Interest::READABLE,
            )
            .unwrap();

        assert_eq!(sender.can_send(), false);

        poller.poll(Some(time::Duration::from_millis(0))).unwrap();

        assert_eq!(poller.iter_events().next(), None);

        let
result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(1)); assert_eq!(event.is_writable(), true); assert_eq!(it.next(), None); assert_eq!(sender.can_send(), true); sender.try_send(42).unwrap(); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let v = receiver.try_recv().unwrap(); assert_eq!(v, 42); mem::drop(sender); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let e = receiver.try_recv().unwrap_err(); assert_eq!(e, mpsc::TryRecvError::Disconnected); } #[test] fn test_notify_bound1() { let (sender, receiver) = channel(1); let mut poller = event::Poller::new(2).unwrap(); poller .register_custom( sender.get_write_registration(), mio::Token(1), mio::Interest::WRITABLE, ) .unwrap(); poller .register_custom( receiver.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); poller.poll(Some(time::Duration::from_millis(0))).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(1)); assert_eq!(event.is_writable(), true); assert_eq!(it.next(), None); sender.try_send(42).unwrap(); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let v = receiver.try_recv().unwrap(); assert_eq!(v, 42); mem::drop(sender); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), 
mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let e = receiver.try_recv().unwrap_err(); assert_eq!(e, mpsc::TryRecvError::Disconnected); } #[test] fn test_local_send_recv() { let poller = event::Poller::new(6).unwrap(); let (sender1, receiver) = local_channel(1, 2, poller.local_registration_memory()); assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Empty)); assert_eq!(sender1.try_send(1), Ok(())); assert_eq!(receiver.try_recv(), Ok(1)); let sender2 = sender1 .try_clone(poller.local_registration_memory()) .unwrap(); assert_eq!(sender1.try_send(2), Ok(())); let channel = sender2.channel.clone(); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(sender2.try_send(3), Err(mpsc::TrySendError::Full(3))); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(2)); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, true ); assert_eq!(sender2.try_send(3), Ok(())); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(3)); mem::drop(sender1); mem::drop(sender2); assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Disconnected)); } #[test] fn test_local_send_disc() { let poller = event::Poller::new(4).unwrap(); let (sender, receiver) = local_channel(1, 1, poller.local_registration_memory()); mem::drop(receiver); assert_eq!(sender.try_send(1), Err(mpsc::TrySendError::Disconnected(1))); } #[test] fn test_local_cancel() { let poller = event::Poller::new(6).unwrap(); let (sender1, receiver) = local_channel(1, 2, poller.local_registration_memory()); let sender2 = sender1 .try_clone(poller.local_registration_memory()) .unwrap(); let channel = 
sender2.channel.clone(); assert_eq!(sender1.try_send(1), Ok(())); assert_eq!(sender2.try_send(2), Err(mpsc::TrySendError::Full(2))); assert_eq!(sender1.try_send(3), Err(mpsc::TrySendError::Full(3))); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, false ); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(1)); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, false ); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, true ); sender2.cancel(); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, true ); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(sender1.try_send(3), Ok(())); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(3)); } #[test] fn test_local_check_send() { let poller = event::Poller::new(4).unwrap(); let (sender, receiver) = local_channel(1, 1, poller.local_registration_memory()); assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Empty)); let channel = sender.channel.clone(); assert_eq!(sender.check_send(), true); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(sender.try_send(1), Ok(())); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(sender.check_send(), false); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(1)); 
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            true
        );

        assert_eq!(sender.try_send(2), Ok(()));
        assert_eq!(channel.senders.borrow().waiting.is_empty(), true);
        assert_eq!(
            channel.senders.borrow().nodes[sender.key].value.notified,
            false
        );

        assert_eq!(receiver.try_recv(), Ok(2));
    }
}
condure-1.3.1/src/connection.rs
/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::buffer::{Buffer, LimitBufs, RefRead, RingBuffer, TmpBuffer, VECTORED_MAX};
use crate::http1;
use crate::websocket;
use crate::zhttppacket;
use crate::zhttpsocket;
use arrayvec::{ArrayString, ArrayVec};
use log::debug;
use std::cell::{Ref, RefCell};
use std::cmp;
use std::collections::VecDeque;
use std::io;
use std::io::{Read, Write};
use std::net::SocketAddr;
use std::rc::Rc;
use std::str;
use std::str::FromStr;
use std::time::{Duration, Instant};

const URI_SIZE_MAX: usize = 4096;
const WS_HASH_INPUT_MAX: usize = 256;
const ZHTTP_SESSION_TIMEOUT: Duration = Duration::from_secs(60);

pub trait Shutdown {
    fn shutdown(&mut self) -> Result<(), io::Error>;
}

pub trait ZhttpSender {
    fn can_send_to(&self) -> bool;

    fn send(&mut self, message: zmq::Message) -> Result<(), zhttpsocket::SendError>;

    fn send_to(&mut self, addr: &[u8], message: zmq::Message) -> Result<(), zhttpsocket::SendError>;
}

#[derive(PartialEq)]
enum Mode {
    HttpReq,
    HttpStream,
    WebSocket,
}

fn get_host<'a>(headers: &'a [httparse::Header]) -> &'a str {
    for h in headers.iter() {
        if h.name.eq_ignore_ascii_case("Host") {
            match str::from_utf8(h.value) {
                Ok(s) => return s,
                Err(_) => break,
            }
        }
    }

    "localhost"
}

fn make_zhttp_request(
    instance: &str,
    ids: &[zhttppacket::Id],
    method: &str,
    path: &str,
    headers: &[httparse::Header],
    body: &[u8],
    more: bool,
    mode: Mode,
    credits: u32,
    peer_addr: Option<SocketAddr>,
    secure: bool,
    packet_buf: &mut [u8],
) -> Result<zmq::Message, io::Error> {
    let mut data = zhttppacket::RequestData::new();

    data.method = method;

    let host = get_host(headers);

    let mut zheaders = [zhttppacket::EMPTY_HEADER; http1::HEADERS_MAX];
    let mut zheaders_len = 0;

    for h in headers.iter() {
        zheaders[zheaders_len] = zhttppacket::Header {
            name: h.name,
            value: h.value,
        };
        zheaders_len += 1;
    }
    data.headers = &zheaders[..zheaders_len];

    let scheme = match mode {
        Mode::HttpReq | Mode::HttpStream => {
            if secure {
                "https"
            } else {
                "http"
            }
        }
        Mode::WebSocket => {
            if secure {
                "wss"
            } else {
                "ws"
            }
        }
    };

    let mut uri = [0; URI_SIZE_MAX];
    let mut c = io::Cursor::new(&mut uri[..]);

    write!(&mut c, "{}://{}{}", scheme, host, path)?;

    let size = c.position() as usize;

    data.uri = match str::from_utf8(&uri[..size]) {
        Ok(s) => s,
        Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidData)),
    };

    data.body = body;
    data.more = more;

    if mode == Mode::HttpStream {
        data.stream = true;
    }

    data.credits = credits;

    let mut addr = [0; 128];

    if let Some(peer_addr) = &peer_addr {
        let mut c = io::Cursor::new(&mut addr[..]);
        write!(&mut c, "{}", peer_addr.ip()).unwrap();
        let size = c.position() as usize;

        data.peer_address = str::from_utf8(&addr[..size]).unwrap();
        data.peer_port = peer_addr.port();
    }

    let mut zreq = zhttppacket::Request::new_data(instance.as_bytes(), &ids, data);
    zreq.multi = true;

    let size = zreq.serialize(packet_buf)?;

    Ok(zmq::Message::from(&packet_buf[..size]))
}

pub struct Want {
    pub sock_read: bool,
    pub sock_write: bool,
    pub zhttp_read: bool,
    pub zhttp_write: bool,
    pub zhttp_write_to: bool,
    pub timeout: Option<Instant>,
}

impl Want {
    pub fn nothing() -> Self {
        Self {
            sock_read: false,
            sock_write: false,
            zhttp_read: false,
            zhttp_write: false,
            zhttp_write_to: false,
            timeout: None,
        }
    }

    pub fn merge(&self, other: &Want) -> Want {
        let timeout = if self.timeout.is_some() && other.timeout.is_some() {
            let a = self.timeout.unwrap();
            let b = other.timeout.unwrap();

            Some(cmp::min(a, b))
        } else if self.timeout.is_some() && other.timeout.is_none() {
            self.timeout
        } else if self.timeout.is_none() && other.timeout.is_some() {
            other.timeout
        } else {
            // both none
            None
        };

        Want {
            sock_read: self.sock_read || other.sock_read,
            sock_write: self.sock_write || other.sock_write,
            zhttp_read: self.zhttp_read || other.zhttp_read,
            zhttp_write: self.zhttp_write || other.zhttp_write,
            zhttp_write_to: self.zhttp_write_to || other.zhttp_write_to,
            timeout,
        }
    }
}

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum ServerState {
    // call: start
    Ready,

    // call: process, apply_zhttp_response
    // next: Connected, Ready, Finished
    Connected,

    // connection should be closed
    Finished,
}

#[derive(Debug)]
pub enum ServerError {
    Io(io::Error),
    Utf8(str::Utf8Error),
    Http(http1::ServerError),
    WebSocket(websocket::Error),
    InvalidWebSocketRequest,
    BadMessage,
    BufferExceeded,
    BadFrame,
}

impl From<io::Error> for ServerError {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

impl From<str::Utf8Error> for ServerError {
    fn from(e: str::Utf8Error) -> Self {
        Self::Utf8(e)
    }
}

impl From<http1::ServerError> for ServerError {
    fn from(e: http1::ServerError) -> Self {
        Self::Http(e)
    }
}

impl From<websocket::Error> for ServerError {
    fn from(e: websocket::Error) -> Self {
        Self::WebSocket(e)
    }
}

// our own range-like struct that supports copying
#[derive(Clone, Copy)]
struct Range {
    start: usize,
    end: usize,
}

fn slice_to_range<T: AsRef<[u8]>>(base: &[u8], s: T) -> Range {
    let sref = s.as_ref();
    let start = (sref.as_ptr() as usize) - (base.as_ptr() as usize);

    Range {
        start,
        end: start + sref.len(),
    }
}

fn range_to_slice(base: &[u8], range: Range) -> &[u8] {
    &base[range.start..range.end]
}

unsafe fn range_to_str_unchecked(base: &[u8], range: Range) -> &str {
    str::from_utf8_unchecked(&base[range.start..range.end])
}

#[derive(Clone, Copy)]
struct HeaderRanges {
    name: Range,
    value: Range,
}

const EMPTY_HEADER_RANGES: HeaderRanges = HeaderRanges {
    name: Range { start: 0, end: 0 },
    value: Range { start: 0, end: 0 },
};

#[derive(Clone, Copy)]
struct RequestHeaderRanges {
    method: Range,
    uri: Range,
    headers: [HeaderRanges; http1::HEADERS_MAX],
    headers_count: usize,
}

#[derive(Clone, Copy)]
struct MessageItem {
    mtype: u8,
    avail: usize,
}

pub struct MessageTracker {
    items: VecDeque<MessageItem>,
    last_partial: bool,
}

impl MessageTracker {
    pub fn new(max_messages: usize) -> Self {
        Self {
            items: VecDeque::with_capacity(max_messages),
            last_partial: false,
        }
    }

    pub fn clear(&mut self) {
        self.items.clear();
        self.last_partial = false;
    }

    pub fn in_progress(&self) -> bool {
        self.last_partial
    }

    pub fn start(&mut self, mtype: u8) -> Result<(), ()> {
        if self.last_partial || self.items.len() == self.items.capacity() {
            return Err(());
        }

        self.items.push_back(MessageItem { mtype, avail: 0 });

        self.last_partial = true;

        Ok(())
    }

    pub fn extend(&mut self, amt: usize) {
        assert_eq!(self.last_partial, true);

        self.items.back_mut().unwrap().avail += amt;
    }

    pub fn done(&mut self) {
        self.last_partial = false;
    }

    // type, avail, done
    pub fn current(&self) -> Option<(u8, usize, bool)> {
        if self.items.len() > 1 {
            let m = self.items.front().unwrap();
            Some((m.mtype, m.avail, true))
        } else if self.items.len() == 1 {
            let m = self.items.front().unwrap();
            Some((m.mtype, m.avail, !self.last_partial))
        } else {
            None
        }
    }

    pub fn consumed(&mut self, amt: usize, done: bool) {
        assert!(amt <= self.items[0].avail);

        self.items[0].avail -= amt;

        if done {
            assert_eq!(self.items[0].avail, 0);

            self.items.pop_front().unwrap();
        }
    }
}

struct ServerProcessArgs<'a, S, Z>
where
    S: Read + Write + Shutdown,
    Z: ZhttpSender,
{
    now: Instant,
    instance: &'a str,
    sock: &'a mut S,
    zsender: &'a mut Z,
    packet_buf: &'a mut [u8],
    tmp_buf: &'a mut [u8],
}

#[derive(Debug, PartialEq)]
enum ServerReqState {
    Ready,
    Active,
    ShuttingDown,
    Finishing,
    Finished,
}

pub struct ServerReqConnection {
    id: ArrayString<[u8; 32]>,
    peer_addr: Option<SocketAddr>,
    secure: bool,
    timeout: Duration,
    state: ServerReqState,
    protocol: http1::ServerProtocol,
    exp_time: Option<Instant>,
    req: Option<RequestHeaderRanges>,
    buf1: RingBuffer,
    buf2: RingBuffer,
    body_buf: Buffer,
    cont: [u8; 32],
    cont_len: usize,
    cont_left: usize,
    pending_msg: Option<zmq::Message>,
    sock_readable: bool,
}

impl ServerReqConnection {
    pub fn new(
        now: Instant,
        peer_addr: Option<SocketAddr>,
        secure: bool,
        buffer_size: usize,
        body_buffer_size: usize,
        rb_tmp: &Rc<TmpBuffer>,
        timeout: Duration,
    ) -> Self {
        let buf1 = RingBuffer::new(buffer_size, rb_tmp);
        let buf2 = RingBuffer::new(buffer_size, rb_tmp);
        let body_buf = Buffer::new(body_buffer_size);

        Self {
            id: ArrayString::new(),
            peer_addr,
            secure,
            timeout,
            state: ServerReqState::Ready,
            protocol: http1::ServerProtocol::new(),
            exp_time: Some(now + timeout),
            req: None,
            buf1,
            buf2,
            body_buf,
            cont: [0; 32],
            cont_len: 0,
            cont_left: 0,
            pending_msg: None,
            sock_readable: true,
        }
    }

    fn reset(&mut self, now: Instant) {
        // note: buf1 is not cleared as there may be data to read

        self.state = ServerReqState::Ready;
        self.protocol = http1::ServerProtocol::new();
        self.exp_time = Some(now + self.timeout);
        self.req = None;
        self.buf2.clear();
        self.body_buf.clear();
        self.pending_msg = None;
        self.sock_readable = true;
    }

    pub fn state(&self) -> ServerState {
        match self.state {
            ServerReqState::Ready => ServerState::Ready,
            ServerReqState::Finished => ServerState::Finished,
            _ => ServerState::Connected,
        }
    }

    pub fn start(&mut self, id: &str) {
        self.id = ArrayString::from_str(id).unwrap();
        self.state = ServerReqState::Active;
    }

    pub fn set_sock_readable(&mut self) {
        self.sock_readable = true;
    }

    pub fn process<S, Z>(
        &mut self,
        now: Instant,
        sock: &mut S,
        zsender: &mut Z,
        packet_buf: &mut [u8],
    ) -> Result<Want, ServerError>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        loop {
            let args = ServerProcessArgs {
                now,
                instance: "",
                sock,
                zsender,
                packet_buf,
                tmp_buf: &mut [0; 0],
            };

            if let Some(r) = self.process_step(args) {
                if let Err(e) = &r {
                    match self.state {
                        ServerReqState::Finishing | ServerReqState::Finished => {}
                        _ => {
                            debug!("conn {}: error: {:?}", self.id, e);

                            self.state = ServerReqState::Finishing;
                            continue;
                        }
                    }
                }

                return r;
            }
        }
    }

    fn try_recv(&mut self, sock: &mut dyn io::Read) -> Result<bool, io::Error> {
        if self.buf1.write_avail() == 0 {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }

        if !self.sock_readable {
            return Err(io::Error::from(io::ErrorKind::WouldBlock));
        }

        let size = match self.buf1.write_from(sock) {
            Ok(size) => size,
            Err(e) => {
                if e.kind() == io::ErrorKind::WouldBlock {
                    self.sock_readable = false;
                }

                return Err(e);
            }
        };

        let closed = size == 0;

        if closed {
            self.state = ServerReqState::ShuttingDown;
        }

        Ok(closed)
    }

    fn after_request<S, Z>(&mut self, args: ServerProcessArgs<'_, S, Z>) -> Result<(), ServerError>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        let proto = &mut self.protocol;

        let hbuf = self.buf2.read_buf();
        let ranges = self.req.unwrap();

        let method = unsafe { range_to_str_unchecked(hbuf, ranges.method) };
        let path = unsafe { range_to_str_unchecked(hbuf, ranges.uri) };

        let mut headers = [httparse::EMPTY_HEADER; http1::HEADERS_MAX];

        for (i, h) in ranges.headers.iter().enumerate() {
            headers[i].name = unsafe { range_to_str_unchecked(hbuf, h.name) };
            headers[i].value = range_to_slice(hbuf, h.value);
        }

        let headers = &headers[..ranges.headers_count];

        let mut websocket = false;

        for h in headers.iter() {
            if h.name.eq_ignore_ascii_case("Upgrade") && h.value == b"websocket" {
                websocket = true;
                break;
            }
        }

        if websocket {
            // header consumed
            self.buf2.clear();

            // body sent
            self.body_buf.clear();

            let mut hbuf = io::Cursor::new(self.buf2.write_buf());

            let headers = &[http1::Header {
                name: "Content-Type",
                value: b"text/plain",
            }];

            let body = "WebSockets not supported on req mode interface.\n";

            if let Err(e) = proto.send_response(
                &mut hbuf,
                400,
                "Bad Request",
                headers,
                http1::BodySize::Known(body.len()),
            ) {
                return Err(e.into());
            }

            let size = hbuf.position() as usize;
            self.buf2.write_commit(size);

            if let Err(e) = self.body_buf.write_all(body.as_bytes()) {
                return Err(ServerError::Io(e));
            }

            return Ok(());
        }

        let ids = [zhttppacket::Id {
            id: self.id.as_bytes(),
            seq: None,
        }];

        let msg = match make_zhttp_request(
            "",
            &ids,
            method,
            path,
            headers,
            self.body_buf.read_buf(),
            false,
            Mode::HttpReq,
            0,
            self.peer_addr,
            self.secure,
            args.packet_buf,
        ) {
            Ok(msg) => msg,
            Err(e) => return Err(e.into()),
        };

        // header and body consumed
        self.buf2.clear();
        self.body_buf.clear();

        self.pending_msg = Some(msg);

        Ok(())
    }

    fn process_step<S, Z>(
        &mut self,
        args: ServerProcessArgs<'_, S, Z>,
    ) -> Option<Result<Want, ServerError>>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        // check expiration if not already shutting down
        match self.state {
            ServerReqState::Finishing | ServerReqState::Finished => {}
            _ => {
                if self.exp_time.is_some() && args.now >= self.exp_time.unwrap() {
                    self.state = ServerReqState::Finishing;
                }
            }
        }

        match self.state {
ServerReqState::Active => { return self.process_http(args); } ServerReqState::ShuttingDown => { match args.sock.shutdown() { Ok(()) => {} Err(e) if e.kind() == io::ErrorKind::WouldBlock => { let mut want = Want::nothing(); want.sock_read = true; want.sock_write = true; want.timeout = self.exp_time; return Some(Ok(want)); } Err(e) => return Some(Err(e.into())), } self.state = ServerReqState::Finishing; return None; } ServerReqState::Finishing => { self.state = ServerReqState::Finished; return None; } ServerReqState::Ready | ServerReqState::Finished => { return Some(Ok(Want::nothing())); } } } fn process_http( &mut self, args: ServerProcessArgs<'_, S, Z>, ) -> Option> where S: Read + Write + Shutdown, Z: ZhttpSender, { let mut want = Want::nothing(); want.sock_read = true; want.timeout = self.exp_time; // always read if possible, to detect disconnects match self.try_recv(args.sock) { Ok(closed) => { if closed { return None; } } Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} // ok Err(e) if e.kind() == io::ErrorKind::WriteZero => want.sock_read = false, Err(e) => return Some(Err(e.into())), } let proto = &mut self.protocol; match proto.state() { http1::ServerState::ReceivingRequest => { self.buf1.align(); let mut hbuf = io::Cursor::new(self.buf1.read_buf()); let mut headers = [httparse::EMPTY_HEADER; http1::HEADERS_MAX]; let req = match proto.recv_request(&mut hbuf, &mut headers) { Some(Ok(req)) => req, Some(Err(e)) => return Some(Err(e.into())), None => match self.try_recv(args.sock) { Ok(_) => return None, Err(e) if e.kind() == io::ErrorKind::WriteZero => { return Some(Err(ServerError::BufferExceeded)); } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { want.sock_read = true; return Some(Ok(want)); } Err(e) => return Some(Err(e.into())), }, }; let hsize = hbuf.position() as usize; let host = get_host(req.headers); let scheme = if self.secure { "https" } else { "http" }; debug!( "conn {}: request: {} {}://{}{}", self.id, req.method, scheme, host, req.uri 
); let hbuf = self.buf1.read_buf(); let mut ranges = RequestHeaderRanges { method: slice_to_range(hbuf, req.method), uri: slice_to_range(hbuf, req.uri), headers: [EMPTY_HEADER_RANGES; http1::HEADERS_MAX], headers_count: req.headers.len(), }; for (i, h) in req.headers.iter().enumerate() { ranges.headers[i].name = slice_to_range(hbuf, h.name); ranges.headers[i].value = slice_to_range(hbuf, h.value); } self.req = Some(ranges); // move header data to buf2 if let Err(e) = self.buf2.write_all(&hbuf[..hsize]) { return Some(Err(e.into())); } if req.expect_100 { let mut cont = io::Cursor::new(&mut self.cont[..]); if let Err(e) = proto.send_100_continue(&mut cont) { return Some(Err(e.into())); } self.cont_len = cont.position() as usize; self.cont_left = self.cont_len; } self.buf1.read_commit(hsize); if proto.state() == http1::ServerState::AwaitingResponse { if let Err(e) = self.after_request(args) { return Some(Err(e)); } } } http1::ServerState::ReceivingBody => { if self.cont_left > 0 { let pos = self.cont_len - self.cont_left; let size = match args.sock.write(&self.cont[pos..self.cont_len]) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { want.sock_write = true; return Some(Ok(want)); } Err(e) => return Some(Err(e.into())), }; self.cont_left -= size; return None; } self.buf1.align(); let mut buf = io::Cursor::new(self.buf1.read_buf()); let mut headers = [httparse::EMPTY_HEADER; http1::HEADERS_MAX]; let (size, _) = match proto.recv_body(&mut buf, self.body_buf.write_buf(), &mut headers) { Ok((size, headers)) => (size, headers), Err(e) => return Some(Err(e.into())), }; let read_size = buf.position() as usize; if proto.state() == http1::ServerState::ReceivingBody && read_size == 0 { match self.try_recv(args.sock) { Ok(_) => return None, Err(e) if e.kind() == io::ErrorKind::WriteZero => { return Some(Err(ServerError::BufferExceeded)); } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { want.sock_read = true; return Some(Ok(want)); } Err(e) => return 
Some(Err(e.into())), } } self.buf1.read_commit(read_size); self.body_buf.write_commit(size); if proto.state() == http1::ServerState::ReceivingBody && self.body_buf.write_avail() == 0 { return Some(Err(ServerError::BufferExceeded)); } if proto.state() == http1::ServerState::AwaitingResponse { if let Err(e) = self.after_request(args) { return Some(Err(e)); } } } http1::ServerState::AwaitingResponse => { if let Some(msg) = self.pending_msg.take() { match args.zsender.send(msg) { Ok(()) => {} Err(zhttpsocket::SendError::Full(msg)) => { self.pending_msg = Some(msg); want.zhttp_write = true; return Some(Ok(want)); } Err(zhttpsocket::SendError::Io(e)) => return Some(Err(e.into())), } } else { want.zhttp_read = true; return Some(Ok(want)); } } http1::ServerState::SendingBody => { if self.buf2.read_avail() > 0 { let size = match args.sock.write(self.buf2.read_buf()) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { want.sock_write = true; return Some(Ok(want)); } Err(e) => return Some(Err(e.into())), }; self.buf2.read_commit(size); return None; } let size = match proto.send_body(args.sock, &[self.body_buf.read_buf()], true, None) { Ok(size) => size, Err(http1::ServerError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => { want.sock_write = true; return Some(Ok(want)); } Err(e) => return Some(Err(e.into())), }; self.body_buf.read_commit(size); } http1::ServerState::Finished => { debug!("conn {}: finished", self.id); if proto.is_persistent() { self.reset(args.now); } else { self.state = ServerReqState::ShuttingDown; } } } None } pub fn apply_zhttp_response( &mut self, zresp: &zhttppacket::Response, ) -> Result<(), ServerError> { let proto = &mut self.protocol; if proto.state() != http1::ServerState::AwaitingResponse || self.pending_msg.is_some() { // not expecting anything return Ok(()); } match &zresp.ptype { zhttppacket::ResponsePacket::Data(rdata) => { let mut hbuf = io::Cursor::new(self.buf2.write_buf()); let mut headers = [http1::EMPTY_HEADER; 
http1::HEADERS_MAX]; let mut headers_len = 0; for h in rdata.headers.iter() { headers[headers_len] = http1::Header { name: h.name, value: h.value, }; headers_len += 1; } if let Err(e) = proto.send_response( &mut hbuf, rdata.code, rdata.reason, &headers[..headers_len], http1::BodySize::Known(rdata.body.len()), ) { self.state = ServerReqState::Finishing; return Err(e.into()); } let size = hbuf.position() as usize; self.buf2.write_commit(size); if let Err(e) = self.body_buf.write_all(&rdata.body) { self.state = ServerReqState::Finishing; return Err(ServerError::Io(e)); } } _ => debug!( "conn {}: unexpected packet in req mode: {}", self.id, zresp.ptype_str ), } Ok(()) } } enum ServerProtocol { Http(http1::ServerProtocol), WebSocket(websocket::Protocol), } #[derive(Debug, PartialEq)] enum ServerStreamState { Ready, Active, Paused, ShuttingDown, Finishing, Finished, } struct ServerStreamSharedDataInner { to_addr: Option>, out_seq: u32, } pub struct AddrRef<'a> { s: Ref<'a, ServerStreamSharedDataInner>, } impl<'a> AddrRef<'a> { pub fn get(&self) -> Option<&[u8]> { match &self.s.to_addr { Some(addr) => Some(addr.as_ref()), None => None, } } } pub struct ServerStreamSharedData { inner: RefCell, } impl ServerStreamSharedData { pub fn new() -> Self { Self { inner: RefCell::new(ServerStreamSharedDataInner { to_addr: None, out_seq: 0, }), } } fn reset(&self) { let s = &mut *self.inner.borrow_mut(); s.to_addr = None; s.out_seq = 0; } fn set_to_addr(&self, addr: Option>) { let s = &mut *self.inner.borrow_mut(); s.to_addr = addr; } pub fn to_addr(&self) -> AddrRef { AddrRef { s: self.inner.borrow(), } } pub fn out_seq(&self) -> u32 { self.inner.borrow().out_seq } pub fn inc_out_seq(&self) { let s = &mut *self.inner.borrow_mut(); s.out_seq += 1; } } struct ServerStreamData { id: ArrayString<[u8; 32]>, peer_addr: Option, secure: bool, client_timeout: Duration, state: ServerStreamState, client_exp_time: Option, zhttp_exp_time: Option, expect_100: bool, cont: [u8; 32], cont_len: 
usize, cont_left: usize, websocket: bool, ws_accept: Option>, // base64_encode(sha1_hash) = 28 bytes in_seq: u32, in_credits: u32, out_credits: u32, resp_header_left: usize, resp_body_done: bool, ws_in_tracker: MessageTracker, in_overflow_allow: usize, sock_readable: bool, pending_msg: Option, handoff_requested: bool, } pub struct ServerStreamConnection { d: ServerStreamData, shared: arena::Rc, protocol: ServerProtocol, buf1: RingBuffer, buf2: RingBuffer, in_overflow: Option, } impl ServerStreamConnection { pub fn new( now: Instant, peer_addr: Option, secure: bool, buffer_size: usize, messages_max: usize, rb_tmp: &Rc, timeout: Duration, shared: arena::Rc, ) -> Self { let buf1 = RingBuffer::new(buffer_size, &rb_tmp); let buf2 = RingBuffer::new(buffer_size, &rb_tmp); let ws_in_tracker = MessageTracker::new(messages_max); let mut s = Self { d: ServerStreamData { id: ArrayString::new(), peer_addr, secure, client_timeout: timeout, state: ServerStreamState::Ready, client_exp_time: None, zhttp_exp_time: None, expect_100: false, cont: [0; 32], cont_len: 0, cont_left: 0, websocket: false, ws_accept: None, in_seq: 0, in_credits: 0, out_credits: 0, resp_header_left: 0, resp_body_done: false, ws_in_tracker, in_overflow_allow: 0, sock_readable: true, pending_msg: None, handoff_requested: false, }, shared, protocol: ServerProtocol::Http(http1::ServerProtocol::new()), buf1, buf2, in_overflow: None, }; Self::refresh_client_timeout(&mut s.d, now); s } fn reset(&mut self, now: Instant) { // note: buf1 is not cleared as there may be data to read self.d.state = ServerStreamState::Ready; self.d.zhttp_exp_time = None; self.d.websocket = false; self.d.ws_accept = None; self.d.in_seq = 0; self.d.in_credits = 0; self.d.out_credits = 0; self.d.resp_header_left = 0; self.d.resp_body_done = false; self.d.ws_in_tracker.clear(); self.d.in_overflow_allow = 0; self.d.sock_readable = true; self.d.pending_msg = None; self.d.handoff_requested = false; self.shared.get().reset(); 
Self::refresh_client_timeout(&mut self.d, now); self.protocol = ServerProtocol::Http(http1::ServerProtocol::new()); self.buf2.clear(); } pub fn state(&self) -> ServerState { match self.d.state { ServerStreamState::Ready => ServerState::Ready, ServerStreamState::Finished => ServerState::Finished, _ => ServerState::Connected, } } pub fn start(&mut self, id: &str) { self.d.id = ArrayString::from_str(id).unwrap(); self.d.state = ServerStreamState::Active; } pub fn set_sock_readable(&mut self) { self.d.sock_readable = true; } pub fn process( &mut self, now: Instant, instance: &str, sock: &mut S, zsender: &mut Z, packet_buf: &mut [u8], tmp_buf: &mut [u8], ) -> Result where S: Read + Write + Shutdown, Z: ZhttpSender, { loop { let args = ServerProcessArgs { now, instance, sock, zsender, packet_buf, tmp_buf, }; if let Some(r) = self.process_step(args) { if let Err(e) = &r { match self.d.state { ServerStreamState::Finishing | ServerStreamState::Finished => {} _ => { debug!("conn {}: error: {:?}", self.d.id, e); self.d.state = ServerStreamState::Finishing; continue; } } } return r; } } } fn try_recv(&mut self, sock: &mut dyn io::Read) -> Result { if self.buf1.write_avail() == 0 { return Err(io::Error::from(io::ErrorKind::WriteZero)); } if !self.d.sock_readable { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } let size = match self.buf1.write_from(sock) { Ok(size) => size, Err(e) => { if e.kind() == io::ErrorKind::WouldBlock { self.d.sock_readable = false; } return Err(e); } }; let closed = size == 0; if closed { self.d.state = ServerStreamState::ShuttingDown; } Ok(closed) } fn zsend( d: &mut ServerStreamData, shared: &ServerStreamSharedData, args: &mut ServerProcessArgs<'_, S, Z>, zreq: zhttppacket::Request, ) -> Result<(), io::Error> where S: Read + Write + Shutdown, Z: ZhttpSender, { if !args.zsender.can_send_to() { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } let msg = { let mut zreq = zreq; let ids = [zhttppacket::Id { id: d.id.as_bytes(), seq: 
Some(shared.out_seq()), }]; zreq.from = args.instance.as_bytes(); zreq.ids = &ids; zreq.multi = true; let size = zreq.serialize(args.packet_buf)?; zmq::Message::from(&args.packet_buf[..size]) }; match args.zsender.send_to(shared.to_addr().get().unwrap(), msg) { Ok(()) => {} Err(zhttpsocket::SendError::Full(_)) => { return Err(io::Error::from(io::ErrorKind::WriteZero)); } Err(zhttpsocket::SendError::Io(e)) => return Err(e), } shared.inc_out_seq(); Ok(()) } fn refresh_client_timeout(d: &mut ServerStreamData, now: Instant) { d.client_exp_time = Some(now + d.client_timeout); } fn refresh_zhttp_timeout(d: &mut ServerStreamData, now: Instant) { d.zhttp_exp_time = Some(now + ZHTTP_SESSION_TIMEOUT); } fn timeout(d: &ServerStreamData) -> Option { if d.client_exp_time.is_some() && d.zhttp_exp_time.is_some() { Some(cmp::min( d.client_exp_time.unwrap(), d.zhttp_exp_time.unwrap(), )) } else if d.client_exp_time.is_some() && d.zhttp_exp_time.is_none() { d.client_exp_time } else if d.client_exp_time.is_none() && d.zhttp_exp_time.is_some() { d.zhttp_exp_time } else { None } } fn send_resp_header( &mut self, args: ServerProcessArgs<'_, S, Z>, mut want: Want, ) -> Option> where S: Read + Write + Shutdown, Z: ZhttpSender, { let size = match args .sock .write(&self.buf2.read_buf()[..self.d.resp_header_left]) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { want.sock_write = true; return Some(Ok(want)); } Err(e) => return Some(Err(e.into())), }; self.buf2.read_commit(size); self.d.in_overflow_allow -= size; if let Some(overflow) = &mut self.in_overflow { let osize = match self.buf2.write(overflow.read_buf()) { Ok(size) => size, Err(e) => return Some(Err(e.into())), }; overflow.read_commit(osize); } self.d.resp_header_left -= size; if self.d.resp_header_left == 0 { if let Some(overflow) = &self.in_overflow { assert_eq!(overflow.read_avail(), 0); self.in_overflow = None; } Self::refresh_client_timeout(&mut self.d, args.now); want.timeout = Self::timeout(&self.d); 
        }

        None
    }

    fn accept_body(&mut self, body: &[u8]) -> Result<(), ServerError> {
        if self.d.resp_header_left > 0 {
            // the response header hasn't been fully written yet, so body bytes
            // that don't fit in buf2 are parked in an overflow buffer, which
            // send_resp_header() drains as header bytes are flushed
            let have_overflow = if let Some(overflow) = &mut self.in_overflow {
                overflow.read_avail() > 0
            } else {
                false
            };

            // once anything has overflowed, append behind it to preserve ordering
            let accepted = if !have_overflow {
                self.buf2.write(body)?
            } else {
                0
            };

            if accepted < body.len() {
                debug!(
                    "conn {}: overflowing {} bytes",
                    self.d.id,
                    body.len() - accepted
                );

                // only start a fresh overflow buffer if nothing is parked yet;
                // replacing a non-empty one would drop the bytes already in it
                if !have_overflow {
                    self.in_overflow = Some(Buffer::new(self.d.in_overflow_allow));
                }

                let overflow = self.in_overflow.as_mut().unwrap();

                overflow.write_all(&body[accepted..])?;
            }
        } else {
            self.buf2.write_all(body)?;
        }

        Ok(())
    }

    fn process_step(
        &mut self,
        mut args: ServerProcessArgs<'_, S, Z>,
    ) -> Option<Result<Want, ServerError>>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        // check expiration if not already shutting down
        match self.d.state {
            ServerStreamState::Finishing | ServerStreamState::Finished => {}
            _ => {
                let exp_time = Self::timeout(&self.d);

                if exp_time.is_some() && args.now >= exp_time.unwrap() {
                    debug!("conn {}: timed out", self.d.id);
                    self.d.state = ServerStreamState::Finishing;
                }
            }
        }

        match self.d.state {
            ServerStreamState::Active => {
                if self.d.out_credits > 0 {
                    if !args.zsender.can_send_to() {
                        let mut want = Want::nothing();
                        want.zhttp_write_to = true;
                        return Some(Ok(want));
                    }

                    let zreq =
                        zhttppacket::Request::new_credit(b"", &[], self.d.out_credits);

                    if let Err(e) = Self::zsend(&mut self.d, self.shared.get(), &mut args, zreq) {
                        return Some(Err(e.into()));
                    }

                    self.d.out_credits = 0;
                }

                if self.d.handoff_requested && self.buf2.read_avail() == 0 {
                    if !args.zsender.can_send_to() {
                        let mut want = Want::nothing();
                        want.zhttp_write_to = true;
                        return Some(Ok(want));
                    }

                    let zreq = zhttppacket::Request::new_handoff_proceed(b"", &[]);

                    if let Err(e) = Self::zsend(&mut self.d, self.shared.get(), &mut args, zreq) {
                        return Some(Err(e.into()));
                    }

                    self.d.state = ServerStreamState::Paused;
                    self.shared.get().set_to_addr(None);
                    self.d.handoff_requested = false;

                    return None;
                }

                match &self.protocol {
                    ServerProtocol::Http(_) => {
                        return self.process_http(args);
                    }
                    ServerProtocol::WebSocket(_) => {
                        return self.process_websocket(args);
                    }
                }
            }
            ServerStreamState::Paused => {
                let mut want = Want::nothing();
                want.zhttp_read = true;
                want.timeout = Self::timeout(&self.d);

                return Some(Ok(want));
            }
            ServerStreamState::ShuttingDown => {
                match args.sock.shutdown() {
                    Ok(()) => {}
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                        let mut want = Want::nothing();
                        want.sock_read = true;
                        want.sock_write = true;
                        want.timeout = Self::timeout(&self.d);

                        return Some(Ok(want));
                    }
                    Err(e) => return Some(Err(e.into())),
                }

                self.d.state = ServerStreamState::Finishing;

                return None;
            }
            ServerStreamState::Finishing => {
                if self.shared.get().to_addr().get().is_some() {
                    if !args.zsender.can_send_to() {
                        let mut want = Want::nothing();
                        want.zhttp_write_to = true;
                        return Some(Ok(want));
                    }

                    let zreq = zhttppacket::Request::new_cancel(b"", &[]);

                    if let Err(e) = Self::zsend(&mut self.d, self.shared.get(), &mut args, zreq) {
                        return Some(Err(e.into()));
                    }
                }

                self.d.state = ServerStreamState::Finished;

                return None;
            }
            ServerStreamState::Ready | ServerStreamState::Finished => {
                return Some(Ok(Want::nothing()));
            }
        }
    }

    fn process_http(
        &mut self,
        mut args: ServerProcessArgs<'_, S, Z>,
    ) -> Option<Result<Want, ServerError>>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        let mut want = Want::nothing();
        want.sock_read = true;
        want.zhttp_read = true;
        want.timeout = Self::timeout(&self.d);

        // always read if possible, to detect disconnects
        match self.try_recv(args.sock) {
            Ok(closed) => {
                if closed {
                    return None;
                }
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} // ok
            Err(e) if e.kind() == io::ErrorKind::WriteZero => want.sock_read = false,
            Err(e) => return Some(Err(e.into())),
        }

        let proto = match &mut self.protocol {
            ServerProtocol::Http(proto) => proto,
            _ => unreachable!(),
        };

        if let Some(msg) = self.d.pending_msg.take() {
            match args.zsender.send(msg) {
                Ok(()) => {}
                Err(zhttpsocket::SendError::Full(msg)) => {
                    self.d.pending_msg = Some(msg);
                    want.zhttp_write = true;
                    return Some(Ok(want));
                }
                Err(zhttpsocket::SendError::Io(e)) => return Some(Err(e.into())),
            }

            self.shared.get().inc_out_seq();

            Self::refresh_zhttp_timeout(&mut self.d, args.now);

            if self.d.expect_100 {
                let mut cont = io::Cursor::new(&mut self.d.cont[..]);

                if let Err(e) = proto.send_100_continue(&mut cont) {
                    return Some(Err(e.into()));
                }

                self.d.cont_len = cont.position() as usize;
                self.d.cont_left = self.d.cont_len;
            }

            return None;
        }

        match proto.state() {
            http1::ServerState::ReceivingRequest => {
                self.buf1.align();
                let mut hbuf = io::Cursor::new(self.buf1.read_buf());
                let mut headers = [httparse::EMPTY_HEADER; http1::HEADERS_MAX];

                let req = match proto.recv_request(&mut hbuf, &mut headers) {
                    Some(Ok(req)) => req,
                    Some(Err(e)) => return Some(Err(e.into())),
                    None => match self.try_recv(args.sock) {
                        Ok(_) => return None,
                        Err(e) if e.kind() == io::ErrorKind::WriteZero => {
                            return Some(Err(ServerError::BufferExceeded));
                        }
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                            want.sock_read = true;
                            return Some(Ok(want));
                        }
                        Err(e) => return Some(Err(e.into())),
                    },
                };

                let hsize = hbuf.position() as usize;

                Self::refresh_client_timeout(&mut self.d, args.now);
                want.timeout = Self::timeout(&self.d);

                let mut ws_key = None;

                for h in req.headers.iter() {
                    // compare raw bytes: a client-supplied value that isn't
                    // valid UTF-8 must not panic the connection handler
                    if h.name.eq_ignore_ascii_case("Upgrade") && h.value == &b"websocket"[..] {
                        self.d.websocket = true;
                    }

                    if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") {
                        ws_key = Some(h.value);
                    }
                }

                let host = get_host(req.headers);

                let scheme = match (self.d.websocket, self.d.secure) {
                    (true, true) => "wss",
                    (true, false) => "ws",
                    (false, true) => "https",
                    (false, false) => "http",
                };

                debug!(
                    "conn {}: request: {} {}://{}{}",
                    self.d.id, req.method, scheme, host, req.uri
                );

                if self.d.websocket {
                    if req.method != "GET"
                        || req.body_size != http1::BodySize::NoBody
                        || ws_key.is_none()
                    {
                        return Some(Err(ServerError::InvalidWebSocketRequest));
                    }

                    let ws_key = ws_key.unwrap();

                    if self.d.ws_accept.is_none() {
                        // Sec-WebSocket-Accept = base64(SHA-1(key + GUID))
                        let mut input = [0; WS_HASH_INPUT_MAX];

                        let input_len = ws_key.len() + websocket::WS_GUID.len();

                        if input_len > WS_HASH_INPUT_MAX {
                            return Some(Err(ServerError::InvalidWebSocketRequest));
                        }

                        input[..ws_key.len()].copy_from_slice(ws_key);
                        input[ws_key.len()..input_len]
                            .copy_from_slice(websocket::WS_GUID.as_bytes());

                        let input = &input[..input_len];

                        let digest = sha1::Sha1::from(input).digest();

                        let mut output = [0; 28];

                        let size = base64::encode_config_slice(
                            &digest.bytes(),
                            base64::STANDARD,
                            &mut output,
                        );

                        let output = str::from_utf8(&output[..size]).unwrap();

                        self.d.ws_accept = Some(ArrayString::from_str(output).unwrap());
                    }
                }

                let ids = [zhttppacket::Id {
                    id: self.d.id.as_bytes(),
                    seq: Some(self.shared.get().out_seq()),
                }];

                let (mode, more) = if self.d.websocket {
                    (Mode::WebSocket, false)
                } else {
                    let more = match req.body_size {
                        http1::BodySize::NoBody => false,
                        http1::BodySize::Known(x) => x > 0,
                        http1::BodySize::Unknown => true,
                    };

                    (Mode::HttpStream, more)
                };

                let msg = match make_zhttp_request(
                    args.instance,
                    &ids,
                    req.method,
                    &req.uri,
                    &req.headers,
                    b"",
                    more,
                    mode,
                    self.buf2.capacity() as u32,
                    self.d.peer_addr,
                    self.d.secure,
                    args.packet_buf,
                ) {
                    Ok(msg) => msg,
                    Err(e) => return Some(Err(e.into())),
                };

                self.d.expect_100 = req.expect_100;

                self.buf1.read_commit(hsize);

                self.d.pending_msg = Some(msg);
            }
            http1::ServerState::ReceivingBody => {
                if self.d.cont_left > 0 {
                    let pos = self.d.cont_len - self.d.cont_left;

                    let size = match args.sock.write(&self.d.cont[pos..self.d.cont_len]) {
                        Ok(size) => size,
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                            want.sock_write = true;
                            return Some(Ok(want));
                        }
                        Err(e) => return Some(Err(e.into())),
                    };

                    self.d.cont_left -= size;

                    return None;
                }

                if self.shared.get().to_addr().get().is_none() || self.d.in_credits == 0 {
                    return Some(Ok(want));
                }

                if !args.zsender.can_send_to() {
                    want.zhttp_write_to = true;
                    return Some(Ok(want));
                }

                self.buf1.align();
                let mut buf = io::Cursor::new(self.buf1.read_buf());
                let mut headers =
                    [httparse::EMPTY_HEADER; http1::HEADERS_MAX];

                // pull tmp_buf out of args so we can borrow it and still pass args along
                let tmp_buf = args.tmp_buf;
                args.tmp_buf = &mut [0; 0];

                let max_read = cmp::min(self.d.in_credits as usize, tmp_buf.len());

                let (size, _) =
                    match proto.recv_body(&mut buf, &mut tmp_buf[..max_read], &mut headers) {
                        Ok((size, headers)) => (size, headers),
                        Err(e) => return Some(Err(e.into())),
                    };

                let read_size = buf.position() as usize;

                self.buf1.read_commit(read_size);

                if read_size > 0 {
                    Self::refresh_client_timeout(&mut self.d, args.now);
                    want.timeout = Self::timeout(&self.d);
                }

                if proto.state() == http1::ServerState::ReceivingBody && read_size == 0 {
                    match self.try_recv(args.sock) {
                        Ok(_) => return None,
                        Err(e) if e.kind() == io::ErrorKind::WriteZero => {
                            return Some(Err(ServerError::BufferExceeded));
                        }
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                            want.sock_read = true;
                            return Some(Ok(want));
                        }
                        Err(e) => return Some(Err(e.into())),
                    }
                }

                let body = &tmp_buf[..size];

                self.d.in_credits -= size as u32;

                let mut rdata = zhttppacket::RequestData::new();
                rdata.body = body;

                if proto.state() == http1::ServerState::ReceivingBody {
                    rdata.more = true;
                }

                let zreq = zhttppacket::Request::new_data(b"", &[], rdata);

                if let Err(e) = Self::zsend(&mut self.d, self.shared.get(), &mut args, zreq) {
                    return Some(Err(e.into()));
                }
            }
            http1::ServerState::AwaitingResponse => {
                return Some(Ok(want));
            }
            http1::ServerState::SendingBody => {
                if self.d.resp_header_left > 0 {
                    return self.send_resp_header(args, want);
                }

                if self.buf2.read_avail() == 0 && !self.d.resp_body_done {
                    return Some(Ok(want));
                }

                let mut buf_arr = [&b""[..]; VECTORED_MAX - 2];
                let bufs = self.buf2.get_ref_vectored(&mut buf_arr);

                let size = match proto.send_body(args.sock, bufs, self.d.resp_body_done, None) {
                    Ok(size) => size,
                    Err(http1::ServerError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => {
                        want.sock_write = true;
                        return Some(Ok(want));
                    }
                    Err(e) => return Some(Err(e.into())),
                };
                self.buf2.read_commit(size);

                if size > 0 {
                    Self::refresh_client_timeout(&mut self.d, args.now);
                    want.timeout = Self::timeout(&self.d);
                }

                if size > 0 && !self.d.resp_body_done {
                    self.d.out_credits += size as u32;
                }
            }
            http1::ServerState::Finished => {
                debug!("conn {}: finished", self.d.id);

                if proto.is_persistent() {
                    self.reset(args.now);
                } else {
                    // don't send cancel
                    self.shared.get().set_to_addr(None);

                    self.d.state = ServerStreamState::ShuttingDown;
                }
            }
        }

        None
    }

    fn process_websocket(
        &mut self,
        mut args: ServerProcessArgs<'_, S, Z>,
    ) -> Option<Result<Want, ServerError>>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        let mut want = Want::nothing();
        want.sock_read = true;
        want.zhttp_read = true;
        want.timeout = Self::timeout(&self.d);

        // always read if possible, to detect disconnects
        match self.try_recv(args.sock) {
            Ok(closed) => {
                if closed {
                    return None;
                }
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} // ok
            Err(e) if e.kind() == io::ErrorKind::WriteZero => want.sock_read = false,
            Err(e) => return Some(Err(e.into())),
        }

        if self.d.resp_header_left > 0 {
            return self.send_resp_header(args, want);
        }

        let proto = match &mut self.protocol {
            ServerProtocol::WebSocket(proto) => proto,
            _ => unreachable!(),
        };

        let tmp_buf = args.tmp_buf;
        args.tmp_buf = &mut [0; 0];

        match proto.state() {
            websocket::State::Connected => {
                let r = self.recv_frames(&mut args, tmp_buf);

                if let Some(Err(e)) = r {
                    return Some(Err(e));
                }

                if self.d.state == ServerStreamState::Finished {
                    return r;
                }

                let w = self.send_frames(&mut args);

                if let Some(Err(e)) = w {
                    return Some(Err(e));
                }

                if r.is_none() || w.is_none() {
                    return None;
                }

                let r = r.unwrap().unwrap();
                let w = w.unwrap().unwrap();

                Some(Ok(r.merge(&w.merge(&want))))
            }
            websocket::State::PeerClosed => {
                let w = self.send_frames(&mut args);

                if let Some(Err(e)) = w {
                    return Some(Err(e));
                }

                if w.is_none() {
                    return None;
                }

                let w = w.unwrap().unwrap();

                Some(Ok(w.merge(&want)))
            }
            websocket::State::Closing => {
                let r = self.recv_frames(&mut args, tmp_buf);

                if let Some(Err(e)) = r {
                    return Some(Err(e));
                }

                if r.is_none() {
                    return None;
                }

                let r = r.unwrap().unwrap();

                Some(Ok(r.merge(&want)))
            }
            websocket::State::Finished => {
                // don't send cancel
                self.shared.get().set_to_addr(None);

                self.d.state = ServerStreamState::ShuttingDown;

                None
            }
        }
    }

    fn recv_frames(
        &mut self,
        args: &mut ServerProcessArgs<'_, S, Z>,
        tmp_buf: &mut [u8],
    ) -> Option<Result<Want, ServerError>>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        let proto = match &mut self.protocol {
            ServerProtocol::WebSocket(proto) => proto,
            _ => unreachable!(),
        };

        let mut want = Want::nothing();
        want.zhttp_read = true;
        want.timeout = Self::timeout(&self.d);

        if self.d.in_credits == 0 {
            return Some(Ok(want));
        }

        if !args.zsender.can_send_to() {
            want.zhttp_write_to = true;
            return Some(Ok(want));
        }

        let max_read = cmp::min(self.d.in_credits as usize, tmp_buf.len());

        self.buf1.align();

        match proto.recv_message_content(&mut self.buf1, &mut tmp_buf[..max_read]) {
            Some(Ok((opcode, size, end))) => {
                let body = &tmp_buf[..size];

                let zreq = match opcode {
                    websocket::OPCODE_TEXT | websocket::OPCODE_BINARY => {
                        if body.is_empty() && !end {
                            // process again instead of sending empty message
                            return None;
                        }

                        let mut data = zhttppacket::RequestData::new();

                        data.body = body;

                        data.content_type = if opcode == websocket::OPCODE_TEXT {
                            Some(zhttppacket::ContentType::Text)
                        } else {
                            Some(zhttppacket::ContentType::Binary)
                        };

                        data.more = !end;

                        zhttppacket::Request::new_data(b"", &[], data)
                    }
                    websocket::OPCODE_CLOSE => {
                        // close payload: optional 2-byte big-endian status code + UTF-8 reason
                        let status = if body.len() >= 2 {
                            let mut arr = [0; 2];
                            arr.copy_from_slice(&body[..2]);

                            let code = u16::from_be_bytes(arr);

                            let reason = match str::from_utf8(&body[2..]) {
                                Ok(reason) => reason,
                                Err(e) => return Some(Err(e.into())),
                            };

                            Some((code, reason))
                        } else {
                            None
                        };

                        zhttppacket::Request::new_close(b"", &[], status)
                    }
                    websocket::OPCODE_PING => zhttppacket::Request::new_ping(b"", &[], body),
                    websocket::OPCODE_PONG => zhttppacket::Request::new_pong(b"", &[], body),
                    opcode => {
                        debug!(
                            "conn {}: unsupported websocket opcode: {}",
                            self.d.id, opcode
                        );
                        return Some(Err(ServerError::BadFrame));
                    }
                };

                self.d.in_credits -= size as u32;

                if let Err(e) = Self::zsend(&mut self.d, self.shared.get(), args, zreq) {
                    return Some(Err(e.into()));
                }

                Self::refresh_client_timeout(&mut self.d, args.now);
                want.timeout = Self::timeout(&self.d);
            }
            Some(Err(e)) => return Some(Err(e.into())),
            None => match self.try_recv(args.sock) {
                Ok(_) => return None,
                Err(e) if e.kind() == io::ErrorKind::WriteZero => {
                    return Some(Err(ServerError::BufferExceeded));
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    want.sock_read = true;
                    return Some(Ok(want));
                }
                Err(e) => return Some(Err(e.into())),
            },
        }

        None
    }

    fn send_frames(
        &mut self,
        args: &mut ServerProcessArgs<'_, S, Z>,
    ) -> Option<Result<Want, ServerError>>
    where
        S: Read + Write + Shutdown,
        Z: ZhttpSender,
    {
        let proto = match &mut self.protocol {
            ServerProtocol::WebSocket(proto) => proto,
            _ => unreachable!(),
        };

        let mut want = Want::nothing();
        want.zhttp_read = true;
        want.timeout = Self::timeout(&self.d);

        if let Some((mtype, avail, done)) = self.d.ws_in_tracker.current() {
            if !proto.is_sending_message() {
                proto.send_message_start(mtype, None);
            }

            if avail == 0 && !done {
                return Some(Ok(want));
            }

            let mut buf_arr = [&b""[..]; VECTORED_MAX - 1];
            let bufs = self.buf2.get_ref_vectored(&mut buf_arr).limit(avail);

            let (size, done) = match proto.send_message_content(args.sock, bufs, done) {
                Ok(r) => r,
                Err(e) => return Some(Err(e.into())),
            };

            if size == 0 && !done {
                want.sock_write = true;
                return Some(Ok(want));
            }

            self.buf2.read_commit(size);

            self.d.ws_in_tracker.consumed(size, done);

            Self::refresh_client_timeout(&mut self.d, args.now);
            want.timeout = Self::timeout(&self.d);

            if proto.state() == websocket::State::Connected
                || proto.state() == websocket::State::PeerClosed
            {
                self.d.out_credits += size as u32;
            }

            None
        } else {
            Some(Ok(want))
        }
    }

    pub fn apply_zhttp_response(
        &mut self,
        now: Instant,
        zresp: &zhttppacket::Response,
        seq: Option<u32>,
    ) -> Result<(), ServerError> {
        if zresp.ids.is_empty() {
            return Err(ServerError::BadMessage);
        }

        if let Some(seq) = seq {
            if seq != self.d.in_seq {
                debug!(
                    "conn {}: bad seq (expected {}, got {}), skipping",
                    self.d.id, self.d.in_seq, seq
                );
                return Err(ServerError::BadMessage);
            }

            self.d.in_seq += 1;
        }

        match self.d.state {
            ServerStreamState::Ready
            | ServerStreamState::ShuttingDown
            | ServerStreamState::Finishing
            | ServerStreamState::Finished => {
                debug!(
                    "conn {}: unexpected message while in state {:?}",
                    self.d.id, self.d.state
                );
                return Err(ServerError::BadMessage);
            }
            ServerStreamState::Active => {}
            ServerStreamState::Paused => self.d.state = ServerStreamState::Active,
        }

        if self.d.handoff_requested {
            debug!(
                "conn {}: unexpected message after handoff requested",
                self.d.id
            );
            return Err(ServerError::BadMessage);
        }

        let mut addr = ArrayVec::new();
        if addr.try_extend_from_slice(zresp.from).is_err() {
            return Err(ServerError::BadMessage);
        }

        self.shared.get().set_to_addr(Some(addr));

        Self::refresh_zhttp_timeout(&mut self.d, now);

        match &zresp.ptype {
            zhttppacket::ResponsePacket::Data(rdata) => {
                match &mut self.protocol {
                    ServerProtocol::Http(proto) => match proto.state() {
                        http1::ServerState::AwaitingResponse
                        | http1::ServerState::ReceivingBody => {
                            let mut hbuf = io::Cursor::new(self.buf2.write_buf());

                            let mut headers = [http1::EMPTY_HEADER; http1::HEADERS_MAX];
                            let mut headers_len = 0;

                            let mut body_size = http1::BodySize::Unknown;

                            for h in rdata.headers.iter() {
                                if self.d.websocket {
                                    // don't send these headers
                                    if h.name.eq_ignore_ascii_case("Upgrade")
                                        || h.name.eq_ignore_ascii_case("Connection")
                                        || h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept")
                                    {
                                        continue;
                                    }
                                } else {
                                    if h.name.eq_ignore_ascii_case("Content-Length") {
                                        let s = match str::from_utf8(h.value) {
                                            Ok(s) => s,
                                            Err(e) => {
                                                self.d.state = ServerStreamState::Finishing;
                                                return Err(e.into());
                                            }
                                        };

                                        let clen: usize = match s.parse() {
                                            Ok(clen) => clen,
                                            Err(_) => {
                                                self.d.state = ServerStreamState::Finishing;
                                                return Err(io::Error::from(
                                                    io::ErrorKind::InvalidInput,
                                                )
                                                .into());
                                            }
                                        };

                                        body_size = http1::BodySize::Known(clen);
                                    }
                                }

                                headers[headers_len] = http1::Header {
                                    name: h.name,
                                    value: h.value,
                                };
                                headers_len += 1;
                            }

                            if body_size == http1::BodySize::Unknown && !rdata.more {
                                body_size = http1::BodySize::Known(rdata.body.len());
                            }

                            if self.d.websocket {
                                headers[headers_len] = http1::Header {
                                    name: "Upgrade",
                                    value: b"websocket",
                                };
                                headers_len += 1;

                                headers[headers_len] = http1::Header {
                                    name: "Connection",
                                    value: b"Upgrade",
                                };
                                headers_len += 1;

                                headers[headers_len] = http1::Header {
                                    name: "Sec-WebSocket-Accept",
                                    value: self.d.ws_accept.as_ref().unwrap().as_bytes(),
                                };
                                headers_len += 1;
                            }

                            if let Err(e) = proto.send_response(
                                &mut hbuf,
                                rdata.code,
                                rdata.reason,
                                &headers[..headers_len],
                                body_size,
                            ) {
                                self.d.state = ServerStreamState::Finishing;
                                return Err(e.into());
                            }

                            let pos = hbuf.position() as usize;
                            self.buf2.write_commit(pos);
                            self.d.resp_header_left = pos;
                            self.d.in_overflow_allow = pos;

                            if self.d.websocket && rdata.credits == 0 {
                                // workaround for pushpin-proxy, which doesn't
                                // send credits on websocket accept
                                let credits = self.buf1.capacity() as u32;
                                debug!("no credits in websocket accept, assuming {}", credits);
                                self.d.in_credits = credits;
                            } else {
                                self.d.in_credits = rdata.credits;
                            }

                            if self.d.websocket {
                                self.protocol =
                                    ServerProtocol::WebSocket(websocket::Protocol::new());
                            } else {
                                if let Err(e) = self.accept_body(rdata.body) {
                                    self.d.state = ServerStreamState::Finishing;
                                    return Err(e);
                                }

                                self.d.resp_body_done = !rdata.more;
                            }
                        }
                        http1::ServerState::SendingBody => {
                            if let Err(e) = self.accept_body(rdata.body) {
                                self.d.state = ServerStreamState::Finishing;
                                return Err(e);
                            }

                            self.d.in_credits += rdata.credits;
                            self.d.resp_body_done = !rdata.more;
                        }
                        _ => {}
                    },
                    ServerProtocol::WebSocket(proto) => match proto.state() {
                        websocket::State::Connected | websocket::State::PeerClosed => {
                            if let Err(e) = self.accept_body(rdata.body) {
                                self.d.state = ServerStreamState::Finishing;
                                return Err(e);
                            }
                            self.d.in_credits += rdata.credits;

                            let opcode = match &rdata.content_type {
                                Some(zhttppacket::ContentType::Binary) => websocket::OPCODE_BINARY,
                                _ => websocket::OPCODE_TEXT,
                            };

                            if !self.d.ws_in_tracker.in_progress() {
                                if self.d.ws_in_tracker.start(opcode).is_err() {
                                    self.d.state = ServerStreamState::Finishing;
                                    return Err(ServerError::BufferExceeded);
                                }
                            }

                            self.d.ws_in_tracker.extend(rdata.body.len());

                            if !rdata.more {
                                self.d.ws_in_tracker.done();
                            }
                        }
                        _ => {}
                    },
                }
            }
            zhttppacket::ResponsePacket::Error(edata) => {
                debug!(
                    "conn {}: zhttp error condition={}",
                    self.d.id, edata.condition
                );

                match &mut self.protocol {
                    ServerProtocol::Http(proto) => match proto.state() {
                        http1::ServerState::AwaitingResponse => {
                            if self.d.websocket && edata.condition == "rejected" {
                                let rdata = edata.rejected_info.as_ref().unwrap();

                                let mut hbuf = io::Cursor::new(self.buf2.write_buf());

                                let mut headers = [http1::EMPTY_HEADER; http1::HEADERS_MAX];
                                let mut headers_len = 0;

                                for h in rdata.headers.iter() {
                                    if self.d.websocket {
                                        // don't send these headers
                                        if h.name.eq_ignore_ascii_case("Upgrade")
                                            || h.name.eq_ignore_ascii_case("Connection")
                                            || h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept")
                                        {
                                            continue;
                                        }
                                    }

                                    headers[headers_len] = http1::Header {
                                        name: h.name,
                                        value: h.value,
                                    };
                                    headers_len += 1;
                                }

                                if let Err(e) = proto.send_response(
                                    &mut hbuf,
                                    rdata.code,
                                    rdata.reason,
                                    &headers[..headers_len],
                                    http1::BodySize::Known(rdata.body.len()),
                                ) {
                                    self.d.state = ServerStreamState::Finishing;
                                    return Err(e.into());
                                }

                                let pos = hbuf.position() as usize;
                                self.buf2.write_commit(pos);
                                self.d.resp_header_left = pos;
                                self.d.in_overflow_allow = pos;

                                if let Err(e) = self.accept_body(rdata.body) {
                                    self.d.state = ServerStreamState::Finishing;
                                    return Err(e);
                                }

                                self.d.resp_body_done = true;

                                return Ok(());
                            }
                        }
                        _ => {}
                    },
                    _ => {}
                }

                self.d.state = ServerStreamState::Finished;
            }
            zhttppacket::ResponsePacket::Credit(cdata) => {
                self.d.in_credits += cdata.credits;
            }
            zhttppacket::ResponsePacket::KeepAlive => {}
            zhttppacket::ResponsePacket::Cancel => {
                self.d.state = ServerStreamState::Finished;
            }
            zhttppacket::ResponsePacket::HandoffStart => {
                self.d.handoff_requested = true;
            }
            zhttppacket::ResponsePacket::Close(cdata) => match &mut self.protocol {
                ServerProtocol::WebSocket(proto) => match proto.state() {
                    websocket::State::Connected | websocket::State::PeerClosed => {
                        // default to 1000 (normal closure) if no status was given
                        let (code, reason) = match cdata.status {
                            Some(v) => v,
                            None => (1000, ""),
                        };

                        let arr: [u8; 2] = code.to_be_bytes();

                        if let Err(e) = self.accept_body(&arr) {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(e);
                        }

                        if let Err(e) = self.accept_body(reason.as_bytes()) {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(e);
                        }

                        if self.d.ws_in_tracker.start(websocket::OPCODE_CLOSE).is_err() {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(ServerError::BadFrame);
                        }

                        self.d.ws_in_tracker.extend(arr.len() + reason.len());
                        self.d.ws_in_tracker.done();
                    }
                    _ => {}
                },
                _ => {}
            },
            zhttppacket::ResponsePacket::Ping(pdata) => match &mut self.protocol {
                ServerProtocol::WebSocket(proto) => match proto.state() {
                    websocket::State::Connected | websocket::State::PeerClosed => {
                        if let Err(e) = self.accept_body(pdata.body) {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(e);
                        }

                        self.d.in_credits += pdata.credits;

                        if self.d.ws_in_tracker.start(websocket::OPCODE_PING).is_err() {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(ServerError::BadFrame);
                        }

                        self.d.ws_in_tracker.extend(pdata.body.len());
                        self.d.ws_in_tracker.done();
                    }
                    _ => {}
                },
                _ => {}
            },
            zhttppacket::ResponsePacket::Pong(pdata) => match &mut self.protocol {
                ServerProtocol::WebSocket(proto) => match proto.state() {
                    websocket::State::Connected | websocket::State::PeerClosed => {
                        if let Err(e) = self.accept_body(pdata.body) {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(e);
                        }

                        self.d.in_credits += pdata.credits;

                        if self.d.ws_in_tracker.start(websocket::OPCODE_PONG).is_err() {
                            self.d.state = ServerStreamState::Finishing;
                            return Err(ServerError::BadFrame);
                        }

                        self.d.ws_in_tracker.extend(pdata.body.len());
                        self.d.ws_in_tracker.done();
                    }
                    _ => {}
                },
                _ => {}
            },
            _ => debug!("conn {}: unsupported type: {}", self.d.id, zresp.ptype_str),
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::TmpBuffer;
    use std::mem;
    use std::rc::Rc;

    // in-memory socket: reads drain `inbuf`, writes are capped by `out_allow`
    struct FakeSock {
        inbuf: Vec<u8>,
        outbuf: Vec<u8>,
        out_allow: usize,
    }

    impl FakeSock {
        fn new() -> Self {
            Self {
                inbuf: Vec::new(),
                outbuf: Vec::new(),
                out_allow: 0,
            }
        }

        fn add_readable(&mut self, buf: &[u8]) {
            self.inbuf.extend_from_slice(buf);
        }

        fn take_writable(&mut self) -> Vec<u8> {
            self.outbuf.split_off(0)
        }

        fn allow_write(&mut self, size: usize) {
            self.out_allow += size;
        }
    }

    impl Read for FakeSock {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
            if self.inbuf.is_empty() {
                return Err(io::Error::from(io::ErrorKind::WouldBlock));
            }

            let size = cmp::min(buf.len(), self.inbuf.len());

            buf[..size].copy_from_slice(&self.inbuf[..size]);

            let mut rest = self.inbuf.split_off(size);
            mem::swap(&mut self.inbuf, &mut rest);

            Ok(size)
        }
    }

    impl Write for FakeSock {
        fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
            if !buf.is_empty() && self.out_allow == 0 {
                return Err(io::Error::from(io::ErrorKind::WouldBlock));
            }

            let size = cmp::min(buf.len(), self.out_allow);
            let buf = &buf[..size];

            self.outbuf.extend_from_slice(buf);
            self.out_allow -= size;

            Ok(buf.len())
        }

        fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result<usize, io::Error> {
            let mut total = 0;

            for buf in bufs {
                if self.out_allow == 0 {
                    break;
                }

                let size = cmp::min(buf.len(), self.out_allow);
                let buf = &buf[..size];

                self.outbuf.extend_from_slice(buf.as_ref());
                self.out_allow -= size;

                total += buf.len();
            }

            Ok(total)
        }

        fn flush(&mut self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    impl Shutdown for FakeSock {
        fn shutdown(&mut self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    // records outgoing zhttp messages; each send consumes one unit of `allow`
    struct FakeSender {
        msgs: Vec<(Option<String>, zmq::Message)>,
        allow: usize,
    }

    impl FakeSender {
        fn new() -> Self {
            Self {
                msgs: Vec::new(),
                allow: 0,
            }
        }

        fn take(&mut self) -> (Option<String>, zmq::Message) {
            self.msgs.remove(0)
        }

        fn allow(&mut self, size: usize) {
            self.allow += size;
        }
    }

    impl ZhttpSender for FakeSender {
        fn can_send_to(&self) -> bool {
            self.allow > 0
        }

        fn send(&mut self, message: zmq::Message) -> Result<(), zhttpsocket::SendError> {
            if self.allow == 0 {
                return Err(zhttpsocket::SendError::Full(message));
            }

            self.msgs.push((None, message));
            self.allow -= 1;

            Ok(())
        }

        fn send_to(
            &mut self,
            addr: &[u8],
            message: zmq::Message,
        ) -> Result<(), zhttpsocket::SendError> {
            if self.allow == 0 {
                return Err(zhttpsocket::SendError::Full(message));
            }

            self.msgs
                .push((Some(String::from_utf8(addr.to_vec()).unwrap()), message));
            self.allow -= 1;

            Ok(())
        }
    }

    #[test]
    fn message_tracker() {
        let mut t = MessageTracker::new(2);

        assert_eq!(t.in_progress(), false);
        assert_eq!(t.current(), None);

        t.start(websocket::OPCODE_TEXT).unwrap();
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, false)));

        t.extend(5);
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 5, false)));

        t.consumed(2, false);
        assert_eq!(t.in_progress(), true);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, false)));

        t.done();
        assert_eq!(t.in_progress(), false);
        assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, true)));

        t.consumed(3, true);
        assert_eq!(t.current(), None);

        for _ in 0..t.items.capacity() {
            t.start(websocket::OPCODE_TEXT).unwrap();
            t.done();
        }
        let r = t.start(websocket::OPCODE_TEXT);
        assert!(r.is_err());
    }

    #[test]
    fn server_req_without_body() {
        let mut sock = FakeSock::new();
        let mut sender = FakeSender::new();

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(1024));

        let mut packet_buf = vec![0; 2048];

        let timeout = Duration::from_millis(5_000);

        let mut c = ServerReqConnection::new(
            Instant::now(),
            None,
            false,
            buffer_size,
            buffer_size,
            &rb_tmp,
            timeout,
        );

        c.start("1");
        assert_eq!(c.state(), ServerState::Connected);

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.sock_read, true);

        let req_data = concat!(
            "GET /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Connection: close\r\n",
            "\r\n"
        )
        .as_bytes();

        sock.add_readable(req_data);
        c.set_sock_readable();

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.zhttp_write, true);
        assert_eq!(sender.msgs.len(), 0);

        sender.allow(1);

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.zhttp_read, true);
        assert_eq!(sender.msgs.len(), 1);

        let (_, buf) = sender.take();
        let buf = &buf[..];

        let expected = concat!(
            "T148:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur",
            "i,23:http://example.com/path,7:headers,52:22:4:Host,11:exa",
            "mple.com,]22:10:Connection,5:close,]]}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let ids = [zhttppacket::Id {
            id: b"1",
            seq: None,
        }];

        let rdata = zhttppacket::ResponseData {
            credits: 0,
            more: false,
            code: 200,
            reason: "OK",
            headers: &[zhttppacket::Header {
                name: "Content-Type",
                value: b"text/plain",
            }],
            content_type: None,
            body: b"hello\n",
        };

        let zresp = zhttppacket::Response {
            from: b"",
            ids: &ids,
            multi: false,
            ptype: zhttppacket::ResponsePacket::Data(rdata),
            ptype_str: "",
        };

        c.apply_zhttp_response(&zresp).unwrap();

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.sock_write, true);

        sock.allow_write(1024);

        c.process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Finished);

        let data = sock.take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: close\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );
        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_req_with_body() {
        let mut sock = FakeSock::new();
        let mut sender = FakeSender::new();

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(1024));

        let mut packet_buf = vec![0; 2048];

        let timeout = Duration::from_millis(5_000);

        let mut c = ServerReqConnection::new(
            Instant::now(),
            None,
            false,
            buffer_size,
            buffer_size,
            &rb_tmp,
            timeout,
        );

        c.start("1");
        assert_eq!(c.state(), ServerState::Connected);

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.sock_read, true);

        let req_data = concat!(
            "POST /path HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Content-Length: 6\r\n",
            "Connection: close\r\n",
            "\r\n",
            "hello\n"
        )
        .as_bytes();

        sock.add_readable(req_data);
        c.set_sock_readable();

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.zhttp_write, true);
        assert_eq!(sender.msgs.len(), 0);

        sender.allow(1);

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.zhttp_read, true);
        assert_eq!(sender.msgs.len(), 1);

        let (_, buf) = sender.take();
        let buf = &buf[..];

        let expected = concat!(
            "T191:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,4:POST,3:u",
            "ri,23:http://example.com/path,7:headers,78:22:4:Host,11:ex",
            "ample.com,]22:14:Content-Length,1:6,]22:10:Connection,5:cl",
            "ose,]]4:body,6:hello\n,}",
        );

        assert_eq!(str::from_utf8(buf).unwrap(), expected);

        let ids = [zhttppacket::Id {
            id: b"1",
            seq: None,
        }];

        let rdata = zhttppacket::ResponseData {
            credits: 0,
            more: false,
            code: 200,
            reason: "OK",
            headers: &[zhttppacket::Header {
                name: "Content-Type",
                value: b"text/plain",
            }],
            content_type: None,
            body: b"hello\n",
        };

        let zresp = zhttppacket::Response {
            from: b"",
            ids: &ids,
            multi: false,
            ptype: zhttppacket::ResponsePacket::Data(rdata),
            ptype_str: "",
        };

        c.apply_zhttp_response(&zresp).unwrap();

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.sock_write, true);

        sock.allow_write(1024);

        c.process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Finished);

        let data = sock.take_writable();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: close\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&data).unwrap(), expected);
    }

    #[test]
    fn server_req_timeout() {
        let mut sock = FakeSock::new();
        let mut sender = FakeSender::new();

        let now = Instant::now();

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(1024));

        let mut packet_buf = vec![0; 2048];

        let timeout = Duration::from_millis(5_000);

        let mut c =
            ServerReqConnection::new(now, None, false, buffer_size, buffer_size, &rb_tmp, timeout);

        c.start("1");
        assert_eq!(c.state(), ServerState::Connected);

        let want = c
            .process(now, &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
        assert_eq!(want.sock_read, true);
        assert_eq!(want.timeout, Some(now + timeout));

        c.process(now + timeout, &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Finished);
    }

    #[test]
    fn server_req_pipeline() {
        let mut sock = FakeSock::new();
        let mut sender = FakeSender::new();

        let buffer_size = 1024;

        let rb_tmp = Rc::new(TmpBuffer::new(1024));

        let mut packet_buf = vec![0; 2048];

        let timeout = Duration::from_millis(5_000);

        let mut c = ServerReqConnection::new(
            Instant::now(),
            None,
            false,
            buffer_size,
            buffer_size,
            &rb_tmp,
            timeout,
        );

        c.start("1");
        assert_eq!(c.state(), ServerState::Connected);

        let want = c
            .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf)
            .unwrap();
        assert_eq!(c.state(), ServerState::Connected);
assert_eq!(want.sock_read, true); let req_data = concat!( "GET /path1 HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n", "GET /path2 HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n", ) .as_bytes(); sock.add_readable(req_data); c.set_sock_readable(); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(1); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (_, buf) = sender.take(); let buf = &buf[..]; let expected = concat!( "T123:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:http://example.com/path1,7:headers,26:22:4:Host,11:ex", "ample.com,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: None, }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 200, reason: "OK", headers: &[zhttppacket::Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"hello\n", }; let zresp = zhttppacket::Response { from: b"", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(&zresp).unwrap(); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_write, true); sock.allow_write(1024); c.process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Ready); c.start("1"); assert_eq!(c.state(), ServerState::Connected); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); let want = c 
.process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(1); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (_, buf) = sender.take(); let buf = &buf[..]; let expected = concat!( "T123:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:http://example.com/path2,7:headers,26:22:4:Host,11:ex", "ample.com,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: None, }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 200, reason: "OK", headers: &[zhttppacket::Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"hello\n", }; let zresp = zhttppacket::Response { from: b"", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(&zresp).unwrap(); c.process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Ready); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_req_secure() { let mut sock = FakeSock::new(); let mut sender = FakeSender::new(); let buffer_size = 1024; let rb_tmp = Rc::new(TmpBuffer::new(1024)); let mut packet_buf = vec![0; 2048]; let timeout = Duration::from_millis(5_000); let mut c = ServerReqConnection::new( Instant::now(), None, true, buffer_size, buffer_size, &rb_tmp, timeout, ); c.start("1"); assert_eq!(c.state(), ServerState::Connected); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); 
assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_read, true); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Connection: close\r\n", "\r\n" ) .as_bytes(); sock.add_readable(req_data); c.set_sock_readable(); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(1); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (_, buf) = sender.take(); let buf = &buf[..]; let expected = concat!( "T149:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:https://example.com/path,7:headers,52:22:4:Host,11:ex", "ample.com,]22:10:Connection,5:close,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: None, }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 200, reason: "OK", headers: &[zhttppacket::Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"hello\n", }; let zresp = zhttppacket::Response { from: b"", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(&zresp).unwrap(); let want = c .process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_write, true); sock.allow_write(1024); c.process(Instant::now(), &mut sock, &mut sender, &mut packet_buf) .unwrap(); assert_eq!(c.state(), ServerState::Finished); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn 
server_stream_without_body() { let mut sock = FakeSock::new(); let mut sender = FakeSender::new(); let buffer_size = 1024; let messages_max = 10; let rb_tmp = Rc::new(TmpBuffer::new(buffer_size)); let mut packet_buf = vec![0; buffer_size * 2]; let mut tmp_buf = vec![0; buffer_size]; let timeout = Duration::from_millis(5_000); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(ServerStreamSharedData::new(), &shared_mem).unwrap(); let mut c = ServerStreamConnection::new( Instant::now(), None, false, buffer_size, messages_max, &rb_tmp, timeout, shared, ); c.start("1"); assert_eq!(c.state(), ServerState::Connected); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_read, true); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.add_readable(req_data); c.set_sock_readable(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(1); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (addr, buf) = sender.take(); let buf = &buf[..]; assert_eq!(addr, None); let expected = concat!( "T179:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea", "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre", "am,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: Some(0), }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 200, reason: "OK", headers: 
&[zhttppacket::Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"hello\n", }; let zresp = zhttppacket::Response { from: b"handler", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(Instant::now(), &zresp, ids[0].seq) .unwrap(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_write, true); let data = sock.take_writable(); assert!(data.is_empty()); sock.allow_write(1024); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Ready); assert_eq!(want.sock_write, false); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_with_body() { let mut sock = FakeSock::new(); let mut sender = FakeSender::new(); let buffer_size = 1024; let messages_max = 10; let rb_tmp = Rc::new(TmpBuffer::new(buffer_size)); let mut packet_buf = vec![0; buffer_size * 2]; let mut tmp_buf = vec![0; buffer_size]; let timeout = Duration::from_millis(5_000); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(ServerStreamSharedData::new(), &shared_mem).unwrap(); let mut c = ServerStreamConnection::new( Instant::now(), None, false, buffer_size, messages_max, &rb_tmp, timeout, shared, ); c.start("1"); assert_eq!(c.state(), ServerState::Connected); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_read, true); let req_data = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n" ) 
.as_bytes(); sock.add_readable(req_data); c.set_sock_readable(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(2); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (addr, buf) = sender.take(); let buf = &buf[..]; assert_eq!(addr, None); let expected = concat!( "T220:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,4:POST,3:uri,23:http://example.com/path,7:he", "aders,52:22:4:Host,11:example.com,]22:14:Content-Length,1:", "6,]]7:credits,4:1024#4:more,4:true!6:stream,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: Some(0), }]; let zresp = zhttppacket::Response { from: b"handler", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Credit(zhttppacket::CreditData { credits: buffer_size as u32, }), ptype_str: "credit", }; c.apply_zhttp_response(Instant::now(), &zresp, ids[0].seq) .unwrap(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (addr, buf) = sender.take(); let buf = &buf[..]; assert_eq!(addr, Some(String::from("handler"))); let expected = concat!( "T74:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}4:body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: Some(1), }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 200, reason: "OK", headers: &[zhttppacket::Header { name: "Content-Type", value: b"text/plain", 
}], content_type: None, body: b"hello\n", }; let zresp = zhttppacket::Response { from: b"handler", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(Instant::now(), &zresp, ids[0].seq) .unwrap(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_write, true); let data = sock.take_writable(); assert!(data.is_empty()); sock.allow_write(1024); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Ready); assert_eq!(want.sock_write, false); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_chunked() { let mut sock = FakeSock::new(); let mut sender = FakeSender::new(); let buffer_size = 1024; let messages_max = 10; let rb_tmp = Rc::new(TmpBuffer::new(buffer_size)); let mut packet_buf = vec![0; buffer_size * 2]; let mut tmp_buf = vec![0; buffer_size]; let timeout = Duration::from_millis(5_000); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(ServerStreamSharedData::new(), &shared_mem).unwrap(); let mut c = ServerStreamConnection::new( Instant::now(), None, false, buffer_size, messages_max, &rb_tmp, timeout, shared, ); c.start("1"); assert_eq!(c.state(), ServerState::Connected); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_read, true); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.add_readable(req_data); c.set_sock_readable(); let want = c .process( Instant::now(), "test", &mut 
sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(1); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (addr, buf) = sender.take(); let buf = &buf[..]; assert_eq!(addr, None); let expected = concat!( "T179:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea", "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre", "am,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: Some(0), }]; let rdata = zhttppacket::ResponseData { credits: 0, more: true, code: 200, reason: "OK", headers: &[zhttppacket::Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"", }; let zresp = zhttppacket::Response { from: b"handler", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(Instant::now(), &zresp, ids[0].seq) .unwrap(); let ids = [zhttppacket::Id { id: b"1", seq: Some(1), }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 200, reason: "OK", headers: &[], content_type: None, body: b"hello\n", }; let zresp = zhttppacket::Response { from: b"handler", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(Instant::now(), &zresp, ids[0].seq) .unwrap(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_write, true); let data = sock.take_writable(); assert!(data.is_empty()); sock.allow_write(1024); let want = c .process( Instant::now(), 
"test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Ready); assert_eq!(want.sock_write, false); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\n", "hello\n", "\r\n", "0\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_early_response() { let mut sock = FakeSock::new(); let mut sender = FakeSender::new(); let buffer_size = 1024; let messages_max = 10; let rb_tmp = Rc::new(TmpBuffer::new(buffer_size)); let mut packet_buf = vec![0; buffer_size * 2]; let mut tmp_buf = vec![0; buffer_size]; let timeout = Duration::from_millis(5_000); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(ServerStreamSharedData::new(), &shared_mem).unwrap(); let mut c = ServerStreamConnection::new( Instant::now(), None, false, buffer_size, messages_max, &rb_tmp, timeout, shared, ); c.start("1"); assert_eq!(c.state(), ServerState::Connected); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_read, true); let req_data = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n" ) .as_bytes(); sock.add_readable(req_data); c.set_sock_readable(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_write, true); assert_eq!(sender.msgs.len(), 0); sender.allow(2); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.zhttp_read, true); assert_eq!(sender.msgs.len(), 1); let (addr, buf) = 
sender.take(); let buf = &buf[..]; assert_eq!(addr, None); let expected = concat!( "T220:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,4:POST,3:uri,23:http://example.com/path,7:he", "aders,52:22:4:Host,11:example.com,]22:14:Content-Length,1:", "6,]]7:credits,4:1024#4:more,4:true!6:stream,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let ids = [zhttppacket::Id { id: b"1", seq: Some(0), }]; let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: 400, reason: "Bad Request", headers: &[ zhttppacket::Header { name: "Content-Type", value: b"text/plain", }, zhttppacket::Header { name: "Content-Length", value: b"18", }, ], content_type: None, body: b"stopping this now\n", }; let zresp = zhttppacket::Response { from: b"handler", ids: &ids, multi: false, ptype: zhttppacket::ResponsePacket::Data(rdata), ptype_str: "", }; c.apply_zhttp_response(Instant::now(), &zresp, ids[0].seq) .unwrap(); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Connected); assert_eq!(want.sock_write, true); let data = sock.take_writable(); assert!(data.is_empty()); sock.allow_write(1024); let want = c .process( Instant::now(), "test", &mut sock, &mut sender, &mut packet_buf, &mut tmp_buf, ) .unwrap(); assert_eq!(c.state(), ServerState::Finished); assert_eq!(want.sock_write, false); let data = sock.take_writable(); let expected = concat!( "HTTP/1.1 400 Bad Request\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 18\r\n", "\r\n", "stopping this now\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } } condure-1.3.1/src/event.rs000064400000000000000000000522300000000000000135220ustar 00000000000000/* * Copyright (C) 2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::arena; use crate::list; use mio::event::Source; use mio::{Events, Interest, Poll, Token, Waker}; use slab::Slab; use std::cell::{Cell, RefCell}; use std::io; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::time::Duration; const EVENTS_MAX: usize = 1024; const LOCAL_BUDGET: u32 = 10; pub type Readiness = Option<Interest>; pub trait ReadinessExt { fn contains_any(&self, readiness: Interest) -> bool; fn merge(&mut self, readiness: Interest); } impl ReadinessExt for Readiness { fn contains_any(&self, readiness: Interest) -> bool { match *self { Some(cur) => { (readiness.is_readable() && cur.is_readable()) || (readiness.is_writable() && cur.is_writable()) } None => false, } } fn merge(&mut self, readiness: Interest) { match *self { Some(cur) => *self = Some(cur.add(readiness)), None => *self = Some(readiness), } } } struct SourceItem { subtoken: Token, interests: Interest, readiness: Readiness, } struct RegisteredSources { nodes: Slab<list::Node<SourceItem>>, ready: list::List, } struct LocalSources { registered_sources: RefCell<RegisteredSources>, } impl LocalSources { fn new(max_sources: usize) -> Self { Self { registered_sources: RefCell::new(RegisteredSources { nodes: Slab::with_capacity(max_sources), ready: list::List::default(), }), } } fn register(&self, subtoken: Token, interests: Interest) -> Result<usize, io::Error> { let sources = &mut *self.registered_sources.borrow_mut(); if sources.nodes.len() == sources.nodes.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } Ok(sources.nodes.insert(list::Node::new(SourceItem { subtoken, interests, 
readiness: None, }))) } fn deregister(&self, key:
usize) -> Result<(), io::Error> { let sources = &mut *self.registered_sources.borrow_mut(); if sources.nodes.contains(key) { sources.ready.remove(&mut sources.nodes, key); sources.nodes.remove(key); } Ok(()) } fn set_readiness(&self, key: usize, readiness: Interest) -> Result<(), io::Error> { let sources = &mut *self.registered_sources.borrow_mut(); if !sources.nodes.contains(key) { return Err(io::Error::from(io::ErrorKind::NotFound)); } let item = &mut sources.nodes[key].value; if !(item.interests.is_readable() && readiness.is_readable()) && !(item.interests.is_writable() && readiness.is_writable()) { // not of interest return Ok(()); } let orig = item.readiness; item.readiness.merge(readiness); if item.readiness != orig { sources.ready.remove(&mut sources.nodes, key); sources.ready.push_back(&mut sources.nodes, key); } Ok(()) } fn has_events(&self) -> bool { let sources = &*self.registered_sources.borrow(); !sources.ready.is_empty() } fn next_event(&self) -> Option<(Token, Interest)> { let sources = &mut *self.registered_sources.borrow_mut(); match sources.ready.pop_front(&mut sources.nodes) { Some(key) => { let item = &mut sources.nodes[key].value; let readiness = item.readiness.take().unwrap(); Some((item.subtoken, readiness)) } None => None, } } } struct SyncSources { registered_sources: Mutex<RegisteredSources>, waker: Waker, } impl SyncSources { fn new(max_sources: usize, waker: Waker) -> Self { Self { registered_sources: Mutex::new(RegisteredSources { nodes: Slab::with_capacity(max_sources), ready: list::List::default(), }), waker, } } fn register(&self, subtoken: Token, interests: Interest) -> Result<usize, io::Error> { let sources = &mut *self.registered_sources.lock().unwrap(); if sources.nodes.len() == sources.nodes.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } Ok(sources.nodes.insert(list::Node::new(SourceItem { subtoken, interests, readiness: None, }))) } fn deregister(&self, key: usize) -> Result<(), io::Error> { let 
sources = &mut
*self.registered_sources.lock().unwrap(); if sources.nodes.contains(key) { sources.ready.remove(&mut sources.nodes, key); sources.nodes.remove(key); } Ok(()) } fn set_readiness(&self, key: usize, readiness: Interest) -> Result<(), io::Error> { let sources = &mut *self.registered_sources.lock().unwrap(); if !sources.nodes.contains(key) { return Err(io::Error::from(io::ErrorKind::NotFound)); } let item = &mut sources.nodes[key].value; if !(item.interests.is_readable() && readiness.is_readable()) && !(item.interests.is_writable() && readiness.is_writable()) { // not of interest return Ok(()); } let orig = item.readiness; item.readiness.merge(readiness); if item.readiness != orig { let need_wake = sources.ready.is_empty(); sources.ready.remove(&mut sources.nodes, key); sources.ready.push_back(&mut sources.nodes, key); if need_wake { self.waker.wake()?; } } Ok(()) } fn has_events(&self) -> bool { let sources = &*self.registered_sources.lock().unwrap(); !sources.ready.is_empty() } fn next_event(&self) -> Option<(Token, Interest)> { let sources = &mut *self.registered_sources.lock().unwrap(); match sources.ready.pop_front(&mut sources.nodes) { Some(key) => { let item = &mut sources.nodes[key].value; let readiness = item.readiness.take().unwrap(); Some((item.subtoken, readiness)) } None => None, } } } struct CustomSources { local: Rc<LocalSources>, sync: Arc<SyncSources>, next_local_only: Cell<bool>, } impl CustomSources { fn new(poll: &Poll, token: Token, max_sources: usize) -> Result<Self, io::Error> { let waker = Waker::new(poll.registry(), token)?; Ok(Self { local: Rc::new(LocalSources::new(max_sources)), sync: Arc::new(SyncSources::new(max_sources, waker)), next_local_only: Cell::new(false), }) } fn set_next_local_only(&self, enabled: bool) { self.next_local_only.set(enabled); } fn register_local( &self, registration: &LocalRegistration, subtoken: Token, interests: Interest, ) -> Result<(), io::Error> { let mut reg = registration.entry.get().data.borrow_mut(); if 
reg.data.is_none() { let key =
self.local.register(subtoken, interests)?; reg.data = Some((key, self.local.clone())); if let Some(readiness) = reg.readiness { self.local.set_readiness(key, readiness).unwrap(); reg.readiness = None; } } Ok(()) } fn deregister_local(&self, registration: &LocalRegistration) -> Result<(), io::Error> { let mut reg = registration.entry.get().data.borrow_mut(); if let Some((key, _)) = reg.data { self.local.deregister(key)?; reg.data = None; } Ok(()) } fn register( &self, registration: &Registration, subtoken: Token, interests: Interest, ) -> Result<(), io::Error> { let mut reg = registration.inner.lock().unwrap(); if reg.data.is_none() { let key = self.sync.register(subtoken, interests)?; reg.data = Some((key, self.sync.clone())); if let Some(readiness) = reg.readiness { self.sync.set_readiness(key, readiness).unwrap(); reg.readiness = None; } } Ok(()) } fn deregister(&self, registration: &Registration) -> Result<(), io::Error> { let mut reg = registration.inner.lock().unwrap(); if let Some((key, _)) = reg.data { self.sync.deregister(key)?; reg.data = None; } Ok(()) } fn has_local_events(&self) -> bool { self.local.has_events() } fn has_events(&self) -> bool { if self.local.has_events() { return true; } if self.next_local_only.get() { return false; } self.sync.has_events() } fn next_event(&self) -> Option<(Token, Interest)> { if let Some(e) = self.local.next_event() { return Some(e); } if self.next_local_only.get() { return None; } if let Some(e) = self.sync.next_event() { return Some(e); } None } } struct RegistrationInner { data: Option<(usize, Arc<SyncSources>)>, readiness: Readiness, } pub struct Registration { inner: Arc<Mutex<RegistrationInner>>, } impl Registration { pub fn new() -> (Self, SetReadiness) { let reg = Arc::new(Mutex::new(RegistrationInner { data: None, readiness: None, })); let registration = Self { inner: reg.clone() }; let set_readiness = SetReadiness { inner: reg }; (registration, set_readiness) } } impl Drop for Registration { fn drop(&mut self) { let mut reg =
self.inner.lock().unwrap(); if let Some((key, sources)) = &reg.data { sources.deregister(*key).unwrap(); reg.data = None; } } } pub struct SetReadiness { inner: Arc<Mutex<RegistrationInner>>, } impl SetReadiness { pub fn set_readiness(&self, readiness: Interest) -> Result<(), io::Error> { let mut reg = self.inner.lock().unwrap(); match &reg.data { Some((key, sources)) => sources.set_readiness(*key, readiness)?, None => reg.readiness.merge(readiness), } Ok(()) } } struct LocalRegistrationData { data: Option<(usize, Rc<LocalSources>)>, readiness: Readiness, } pub struct LocalRegistrationEntry { data: RefCell<LocalRegistrationData>, } pub struct LocalRegistration { entry: arena::Rc<LocalRegistrationEntry>, } impl LocalRegistration { pub fn new(memory: &Rc<arena::RcMemory<LocalRegistrationEntry>>) -> (Self, LocalSetReadiness) { let reg = arena::Rc::new( LocalRegistrationEntry { data: RefCell::new(LocalRegistrationData { data: None, readiness: None, }), }, memory, ) .unwrap(); let registration = Self { entry: arena::Rc::clone(&reg), }; let set_readiness = LocalSetReadiness { entry: reg }; (registration, set_readiness) } } impl Drop for LocalRegistration { fn drop(&mut self) { let mut reg = self.entry.get().data.borrow_mut(); if let Some((key, sources)) = &reg.data { sources.deregister(*key).unwrap(); reg.data = None; } } } pub struct LocalSetReadiness { entry: arena::Rc<LocalRegistrationEntry>, } impl LocalSetReadiness { pub fn set_readiness(&self, readiness: Interest) -> Result<(), io::Error> { let mut reg = self.entry.get().data.borrow_mut(); match &reg.data { Some((key, sources)) => sources.set_readiness(*key, readiness)?, None => reg.readiness.merge(readiness), } Ok(()) } } #[derive(Debug, PartialEq)] pub struct Event { token: Token, readiness: Interest, } impl Event { pub fn token(&self) -> Token { self.token } pub fn readiness(&self) -> Interest { self.readiness } pub fn is_readable(&self) -> bool { self.readiness.is_readable() } pub fn is_writable(&self) -> bool { 
self.readiness.is_writable() } } pub struct Poller { poll: Poll, events: Events, custom_sources: CustomSources, local_registration_memory: Rc<arena::RcMemory<LocalRegistrationEntry>>, local_budget: u32,
} impl Poller { pub fn new(max_custom_sources: usize) -> Result<Self, io::Error> { let poll = Poll::new()?; let events = Events::with_capacity(EVENTS_MAX); let custom_sources = CustomSources::new(&poll, Token(0), max_custom_sources)?; Ok(Self { poll, events, custom_sources, local_registration_memory: Rc::new(arena::RcMemory::new(max_custom_sources)), local_budget: LOCAL_BUDGET, }) } pub fn register<S>( &self, source: &mut S, token: Token, interests: Interest, ) -> Result<(), io::Error> where S: Source + ?Sized, { if token == Token(0) { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } self.poll.registry().register(source, token, interests) } pub fn deregister<S>(&self, source: &mut S) -> Result<(), io::Error> where S: Source + ?Sized, { self.poll.registry().deregister(source) } pub fn register_custom( &self, registration: &Registration, token: Token, interests: Interest, ) -> Result<(), io::Error> { if token == Token(0) { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } self.custom_sources.register(registration, token, interests) } pub fn deregister_custom(&self, registration: &Registration) -> Result<(), io::Error> { self.custom_sources.deregister(registration) } pub fn local_registration_memory(&self) -> &Rc<arena::RcMemory<LocalRegistrationEntry>> { &self.local_registration_memory } pub fn register_custom_local( &self, registration: &LocalRegistration, token: Token, interests: Interest, ) -> Result<(), io::Error> { if token == Token(0) { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } self.custom_sources .register_local(registration, token, interests) } pub fn deregister_custom_local( &self, registration: &LocalRegistration, ) -> Result<(), io::Error> { self.custom_sources.deregister_local(registration) } pub fn poll(&mut self, timeout: Option<Duration>) -> Result<(), io::Error> { if self.custom_sources.has_local_events() && self.local_budget > 0 { self.local_budget -= 1; self.custom_sources.set_next_local_only(true); 
self.events.clear(); // don't reread previous mio events return Ok(()); }
self.local_budget = LOCAL_BUDGET; self.custom_sources.set_next_local_only(false); let timeout = if self.custom_sources.has_events() { Some(Duration::from_millis(0)) } else { timeout }; self.poll.poll(&mut self.events, timeout) } pub fn iter_events(&self) -> EventsIterator<'_, '_> { EventsIterator { events: self.events.iter(), custom_sources: &self.custom_sources, custom_left: EVENTS_MAX, } } } pub struct EventsIterator<'a, 'b> { events: mio::event::Iter<'b>, custom_sources: &'a CustomSources, custom_left: usize, } impl Iterator for EventsIterator<'_, '_> { type Item = Event; fn next(&mut self) -> Option<Self::Item> { while let Some(event) = self.events.next() { if event.token() == Token(0) { continue; } let mut readiness = None; if event.is_readable() { readiness.merge(Interest::READABLE); } if event.is_writable() { readiness.merge(Interest::WRITABLE); } if let Some(readiness) = readiness { return Some(Event { token: event.token(), readiness, }); } } if self.custom_left > 0 { self.custom_left -= 1; if let Some((token, readiness)) = self.custom_sources.next_event() { return Some(Event { token, readiness }); } } None } } #[cfg(test)] mod tests { use super::*; use std::time::Duration; #[test] fn test_readiness() { let token = Token(123); let subtoken = Token(456); let mut poll = Poll::new().unwrap(); let sources = CustomSources::new(&poll, token, 1).unwrap(); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); let (reg, sr) = Registration::new(); sources .register(&reg, subtoken, Interest::READABLE) .unwrap(); let mut events = Events::with_capacity(1024); poll.poll(&mut events, Some(Duration::from_millis(0))) .unwrap(); assert!(events.is_empty()); sr.set_readiness(Interest::READABLE).unwrap(); 'poll: loop { poll.poll(&mut events, None).unwrap(); for event in &events { if event.token() == token { break 'poll; } } } assert_eq!(sources.has_events(), true); assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE)));
assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); } #[test] fn test_readiness_early() { let token = Token(123); let subtoken = Token(456); let mut poll = Poll::new().unwrap(); let sources = CustomSources::new(&poll, token, 1).unwrap(); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); let (reg, sr) = Registration::new(); sr.set_readiness(Interest::READABLE).unwrap(); sources .register(&reg, subtoken, Interest::READABLE) .unwrap(); let mut events = Events::with_capacity(1024); poll.poll(&mut events, Some(Duration::from_millis(0))) .unwrap(); let event = events.iter().next(); assert!(event.is_some()); let event = event.unwrap(); assert_eq!(event.token(), token); assert_eq!(sources.has_events(), true); assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE))); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); } #[test] fn test_readiness_local() { let poller = Poller::new(1).unwrap(); let token = Token(123); let subtoken = Token(456); let mut poll = Poll::new().unwrap(); let sources = CustomSources::new(&poll, token, 1).unwrap(); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); let (reg, sr) = LocalRegistration::new(poller.local_registration_memory()); sources .register_local(&reg, subtoken, Interest::READABLE) .unwrap(); let mut events = Events::with_capacity(1024); poll.poll(&mut events, Some(Duration::from_millis(0))) .unwrap(); assert!(events.is_empty()); sr.set_readiness(Interest::READABLE).unwrap(); assert_eq!(sources.has_events(), true); assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE))); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); } #[test] fn test_poller() { let token = Token(123); let mut poller = Poller::new(1).unwrap(); assert_eq!(poller.iter_events().next(), None); let (reg, sr) = Registration::new(); poller .register_custom(&reg, token, Interest::READABLE) .unwrap(); 
poller.poll(Some(Duration::from_millis(0))).unwrap(); assert_eq!(poller.iter_events().next(), None); sr.set_readiness(Interest::READABLE).unwrap(); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), token); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); } } condure-1.3.1/src/executor.rs /* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::list; use crate::waker; use slab::Slab; use std::cell::RefCell; use std::future::Future; use std::io; use std::pin::Pin; use std::rc::{Rc, Weak}; use std::task::{Context, Poll, Waker}; use std::time::Duration; thread_local!
{ static EXECUTOR: RefCell>> = RefCell::new(None); } struct TaskWaker { tasks: Weak, task_id: usize, } impl waker::RcWake for TaskWaker { fn wake(self: Rc) { if let Some(tasks) = self.tasks.upgrade() { tasks.wake(self.task_id); } } } fn poll_fut(fut: &mut Pin>>, waker: Waker) -> bool { // convert from Pin to Pin<&mut> let fut: Pin<&mut dyn Future> = fut.as_mut(); let mut cx = Context::from_waker(&waker); match fut.poll(&mut cx) { Poll::Ready(_) => true, Poll::Pending => false, } } struct Task { fut: Option>>>, processing: bool, } struct TasksData { nodes: Slab>, next: list::List, wakers: Vec>, } struct Tasks { data: RefCell, pre_poll: RefCell>>, } impl Tasks { fn new(max: usize) -> Rc { let data = TasksData { nodes: Slab::with_capacity(max), next: list::List::default(), wakers: Vec::with_capacity(max), }; let tasks = Rc::new(Self { data: RefCell::new(data), pre_poll: RefCell::new(None), }); { let data = &mut *tasks.data.borrow_mut(); for task_id in 0..data.nodes.capacity() { data.wakers.push(Rc::new(TaskWaker { tasks: Rc::downgrade(&tasks), task_id, })); } } tasks } fn is_empty(&self) -> bool { self.data.borrow().nodes.is_empty() } fn have_next(&self) -> bool { !self.data.borrow().next.is_empty() } fn add(&self, fut: F) -> Result<(), ()> where F: Future + 'static, { let data = &mut *self.data.borrow_mut(); if data.nodes.len() == data.nodes.capacity() { return Err(()); } let entry = data.nodes.vacant_entry(); let nkey = entry.key(); let task = Task { fut: Some(Box::pin(fut)), processing: false, }; entry.insert(list::Node::new(task)); data.next.push_back(&mut data.nodes, nkey); Ok(()) } fn remove(&self, task_id: usize) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); let task = &mut data.nodes[nkey].value; // drop the future. 
this should cause it to drop any owned wakers task.fut = None; // at this point, we should be the only remaining owner assert_eq!(Rc::strong_count(&data.wakers[nkey]), 1); data.next.remove(&mut data.nodes, nkey); data.nodes.remove(nkey); } fn take_next_list(&self) -> list::List { let data = &mut *self.data.borrow_mut(); let mut l = list::List::default(); l.concat(&mut data.nodes, &mut data.next); let mut cur = l.head; while let Some(nkey) = cur { let node = &mut data.nodes[nkey]; node.value.processing = true; cur = node.next; } l } fn take_task( &self, l: &mut list::List, ) -> Option<(usize, Pin>>, Waker)> { let nkey = match l.head { Some(nkey) => nkey, None => return None, }; let data = &mut *self.data.borrow_mut(); l.remove(&mut data.nodes, nkey); let task = &mut data.nodes[nkey].value; // both of these are cheap let fut = task.fut.take().unwrap(); let waker = waker::into_std(data.wakers[nkey].clone()); task.processing = false; Some((nkey, fut, waker)) } fn process_next(&self) { let mut l = self.take_next_list(); loop { match self.take_task(&mut l) { Some((task_id, mut fut, waker)) => { self.pre_poll(); let done = poll_fut(&mut fut, waker); // take_task() took the future out of the task, so we // could poll it without having to maintain a borrow of // the tasks set. 
we'll put it back now self.set_fut(task_id, fut); if done { self.remove(task_id); } } None => break, } } } fn set_fut(&self, task_id: usize, fut: Pin>>) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); let task = &mut data.nodes[nkey].value; task.fut = Some(fut); } fn wake(&self, task_id: usize) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); let node = &mut data.nodes[nkey]; // if the task is already in the processing list, don't do anything if node.value.processing { return; } // if the task is already queued up in the next list, don't do anything if node.prev.is_some() || data.next.head == Some(nkey) { return; } data.next.push_back(&mut data.nodes, nkey); } fn set_pre_poll(&self, pre_poll_fn: F) where F: FnMut() + 'static, { *self.pre_poll.borrow_mut() = Some(Box::new(pre_poll_fn)); } fn pre_poll(&self) { let pre_poll = &mut *self.pre_poll.borrow_mut(); if let Some(f) = pre_poll { f(); } } } pub struct Executor { tasks: Rc, } impl Executor { pub fn new(tasks_max: usize) -> Self { let tasks = Tasks::new(tasks_max); EXECUTOR.with(|ex| { if ex.borrow().is_some() { panic!("thread already has an Executor"); } ex.replace(Some(Rc::downgrade(&tasks))); }); Self { tasks } } pub fn spawn(&self, fut: F) -> Result<(), ()> where F: Future + 'static, { self.tasks.add(fut) } pub fn set_pre_poll(&self, pre_poll_fn: F) where F: FnMut() + 'static, { self.tasks.set_pre_poll(pre_poll_fn); } pub fn have_tasks(&self) -> bool { !self.tasks.is_empty() } pub fn run_until_stalled(&self) { while self.tasks.have_next() { self.tasks.process_next() } } pub fn run(&self, mut park: F) -> Result<(), io::Error> where F: FnMut(Option) -> Result<(), io::Error>, { loop { self.tasks.process_next(); if !self.have_tasks() { break; } let timeout = if self.tasks.have_next() { Some(Duration::from_millis(0)) } else { None }; park(timeout)?; } Ok(()) } pub fn current() -> Option { EXECUTOR.with(|ex| match &mut *ex.borrow_mut() { Some(tasks) => Some(Self { tasks: 
tasks.upgrade().unwrap(), }), None => None, }) } pub fn spawner(&self) -> Spawner { Spawner { tasks: Rc::downgrade(&self.tasks), } } } impl Drop for Executor { fn drop(&mut self) { EXECUTOR.with(|ex| { if Rc::strong_count(&self.tasks) == 1 { ex.replace(None); } }); } } pub struct Spawner { tasks: Weak, } impl Spawner { pub fn spawn(&self, fut: F) -> Result<(), ()> where F: Future + 'static, { let tasks = match self.tasks.upgrade() { Some(tasks) => tasks, None => return Err(()), }; let ex = Executor { tasks }; ex.spawn(fut) } } #[cfg(test)] mod tests { use super::*; use std::cell::Cell; use std::mem; struct TestFutureData { ready: bool, waker: Option, } struct TestFuture { data: Rc>, } impl TestFuture { fn new() -> Self { let data = TestFutureData { ready: false, waker: None, }; Self { data: Rc::new(RefCell::new(data)), } } fn handle(&self) -> TestHandle { TestHandle { data: Rc::clone(&self.data), } } } impl Future for TestFuture { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut data = self.data.borrow_mut(); match data.ready { true => Poll::Ready(()), false => { data.waker = Some(cx.waker().clone()); Poll::Pending } } } } struct TestHandle { data: Rc>, } impl TestHandle { fn set_ready(&self) { let data = &mut *self.data.borrow_mut(); data.ready = true; if let Some(waker) = data.waker.take() { waker.wake(); } } } struct EarlyWakeFuture { done: bool, } impl EarlyWakeFuture { fn new() -> Self { Self { done: false } } } impl Future for EarlyWakeFuture { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if !self.done { self.done = true; cx.waker().wake_by_ref(); return Poll::Pending; } Poll::Ready(()) } } #[test] fn test_executor_step() { let executor = Executor::new(1); let fut1 = TestFuture::new(); let fut2 = TestFuture::new(); let handle1 = fut1.handle(); let handle2 = fut2.handle(); let started = Rc::new(Cell::new(false)); let fut1_done = Rc::new(Cell::new(false)); let finishing = 
Rc::new(Cell::new(false)); { let started = Rc::clone(&started); let fut1_done = Rc::clone(&fut1_done); let finishing = Rc::clone(&finishing); executor .spawn(async move { started.set(true); fut1.await; fut1_done.set(true); fut2.await; finishing.set(true); }) .unwrap(); } // not started yet, no progress assert_eq!(executor.have_tasks(), true); assert_eq!(started.get(), false); executor.run_until_stalled(); // started, but fut1 not ready assert_eq!(executor.have_tasks(), true); assert_eq!(started.get(), true); assert_eq!(fut1_done.get(), false); handle1.set_ready(); executor.run_until_stalled(); // fut1 finished assert_eq!(executor.have_tasks(), true); assert_eq!(fut1_done.get(), true); assert_eq!(finishing.get(), false); handle2.set_ready(); executor.run_until_stalled(); // fut2 finished, and thus the task finished assert_eq!(finishing.get(), true); assert_eq!(executor.have_tasks(), false); } #[test] fn test_executor_run() { let executor = Executor::new(1); let fut = TestFuture::new(); let handle = fut.handle(); executor .spawn(async move { fut.await; }) .unwrap(); executor .run(|_| { handle.set_ready(); Ok(()) }) .unwrap(); assert_eq!(executor.have_tasks(), false); } #[test] fn test_executor_spawn_error() { let executor = Executor::new(1); assert!(executor.spawn(async {}).is_ok()); assert!(executor.spawn(async {}).is_err()); } #[test] fn test_executor_current() { assert!(Executor::current().is_none()); let executor = Executor::new(2); let flag = Rc::new(Cell::new(false)); { let flag = flag.clone(); executor .spawn(async move { Executor::current() .unwrap() .spawn(async move { flag.set(true); }) .unwrap(); }) .unwrap(); } assert_eq!(flag.get(), false); executor.run(|_| Ok(())).unwrap(); assert_eq!(flag.get(), true); let current = Executor::current().unwrap(); assert_eq!(executor.have_tasks(), false); assert!(current.spawn(async {}).is_ok()); assert_eq!(executor.have_tasks(), true); mem::drop(executor); assert!(Executor::current().is_some()); mem::drop(current); 
assert!(Executor::current().is_none()); } #[test] fn test_executor_spawner() { let executor = Executor::new(2); let flag = Rc::new(Cell::new(false)); { let flag = flag.clone(); let spawner = executor.spawner(); executor .spawn(async move { spawner .spawn(async move { flag.set(true); }) .unwrap(); }) .unwrap(); } assert_eq!(flag.get(), false); executor.run(|_| Ok(())).unwrap(); assert_eq!(flag.get(), true); } #[test] fn test_executor_early_wake() { let executor = Executor::new(1); let fut = EarlyWakeFuture::new(); executor .spawn(async move { fut.await; }) .unwrap(); let mut park_count = 0; executor .run(|_| { park_count += 1; Ok(()) }) .unwrap(); assert_eq!(park_count, 1); } #[test] fn test_executor_pre_poll() { let executor = Executor::new(1); let flag = Rc::new(Cell::new(false)); { let flag = flag.clone(); executor.set_pre_poll(move || { flag.set(true); }); } executor.spawn(async {}).unwrap(); assert_eq!(flag.get(), false); executor.run(|_| Ok(())).unwrap(); assert_eq!(flag.get(), true); } } condure-1.3.1/src/future.rs /* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::channel; use crate::event::ReadinessExt; use crate::reactor::{CustomEvented, FdEvented, IoEvented, Reactor, Registration, TimerEvented}; use crate::shuffle::shuffle; use crate::zmq::{MultipartHeader, ZmqSocket}; use mio; use mio::net::{TcpListener, TcpStream}; use std::cell::Cell; use std::future::Future; use std::io; use std::io::{Read, Write}; use std::net::SocketAddr; use std::pin::Pin; use std::sync::mpsc; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; pub const REGISTRATIONS_PER_CHANNEL: usize = 1; // 1 for the zmq fd, and potentially 1 for the retry timer pub const REGISTRATIONS_PER_ZMQSOCKET: usize = 2; fn range_unordered(dest: &mut [usize]) -> &[usize] { for i in 0..dest.len() { dest[i] = i; } shuffle(dest); dest } fn map_poll(p: Pin<&mut T>, cx: &mut Context, map_func: M, wrap_func: W) -> Poll where F: Future, M: FnOnce(&mut T) -> &mut F, W: FnOnce(F::Output) -> V, { let f = unsafe { p.map_unchecked_mut(map_func) }; match f.poll(cx) { Poll::Ready(v) => Poll::Ready(wrap_func(v)), Poll::Pending => Poll::Pending, } } pub enum Select2 { R1(O1), R2(O2), } pub struct Select2Future { f1: F1, f2: F2, } impl Future for Select2Future where F1: Future, F2: Future, { type Output = Select2; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut indexes = [0; 2]; for i in range_unordered(&mut indexes) { let s = self.as_mut(); let p = match i { 0 => map_poll(s, cx, |s| &mut s.f1, |v| Select2::R1(v)), 1 => map_poll(s, cx, |s| &mut s.f2, |v| Select2::R2(v)), _ => unreachable!(), }; if p.is_ready() { return p; } } Poll::Pending } } pub enum Select3 { R1(O1), R2(O2), R3(O3), } pub struct Select3Future { f1: F1, f2: F2, f3: F3, } impl Future for Select3Future where F1: Future, F2: Future, F3: Future, { type Output = Select3; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut indexes = [0; 3]; for i in range_unordered(&mut indexes) { let s = self.as_mut(); let p = match i { 0 => map_poll(s, cx, |s| &mut s.f1, 
|v| Select3::R1(v)), 1 => map_poll(s, cx, |s| &mut s.f2, |v| Select3::R2(v)), 2 => map_poll(s, cx, |s| &mut s.f3, |v| Select3::R3(v)), _ => unreachable!(), }; if p.is_ready() { return p; } } Poll::Pending } } pub enum Select4 { R1(O1), R2(O2), R3(O3), R4(O4), } pub struct Select4Future { f1: F1, f2: F2, f3: F3, f4: F4, } impl Future for Select4Future where F1: Future, F2: Future, F3: Future, F4: Future, { type Output = Select4; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut indexes = [0; 4]; for i in range_unordered(&mut indexes) { let s = self.as_mut(); let p = match i { 0 => map_poll(s, cx, |s| &mut s.f1, |v| Select4::R1(v)), 1 => map_poll(s, cx, |s| &mut s.f2, |v| Select4::R2(v)), 2 => map_poll(s, cx, |s| &mut s.f3, |v| Select4::R3(v)), 3 => map_poll(s, cx, |s| &mut s.f4, |v| Select4::R4(v)), _ => unreachable!(), }; if p.is_ready() { return p; } } Poll::Pending } } pub enum Select5 { R1(O1), R2(O2), R3(O3), R4(O4), R5(O5), } pub struct Select5Future { f1: F1, f2: F2, f3: F3, f4: F4, f5: F5, } impl Future for Select5Future where F1: Future, F2: Future, F3: Future, F4: Future, F5: Future, { type Output = Select5; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut indexes = [0; 5]; for i in range_unordered(&mut indexes) { let s = self.as_mut(); let p = match i { 0 => map_poll(s, cx, |s| &mut s.f1, |v| Select5::R1(v)), 1 => map_poll(s, cx, |s| &mut s.f2, |v| Select5::R2(v)), 2 => map_poll(s, cx, |s| &mut s.f3, |v| Select5::R3(v)), 3 => map_poll(s, cx, |s| &mut s.f4, |v| Select5::R4(v)), 4 => map_poll(s, cx, |s| &mut s.f5, |v| Select5::R5(v)), _ => unreachable!(), }; if p.is_ready() { return p; } } Poll::Pending } } pub enum Select6 { R1(O1), R2(O2), R3(O3), R4(O4), R5(O5), R6(O6), } pub struct Select6Future { f1: F1, f2: F2, f3: F3, f4: F4, f5: F5, f6: F6, } impl Future for Select6Future where F1: Future, F2: Future, F3: Future, F4: Future, F5: Future, F6: Future, { type Output = Select6; fn poll(mut self: Pin<&mut Self>, 
cx: &mut Context) -> Poll { let mut indexes = [0; 6]; for i in range_unordered(&mut indexes) { let s = self.as_mut(); let p = match i { 0 => map_poll(s, cx, |s| &mut s.f1, |v| Select6::R1(v)), 1 => map_poll(s, cx, |s| &mut s.f2, |v| Select6::R2(v)), 2 => map_poll(s, cx, |s| &mut s.f3, |v| Select6::R3(v)), 3 => map_poll(s, cx, |s| &mut s.f4, |v| Select6::R4(v)), 4 => map_poll(s, cx, |s| &mut s.f5, |v| Select6::R5(v)), 5 => map_poll(s, cx, |s| &mut s.f6, |v| Select6::R6(v)), _ => unreachable!(), }; if p.is_ready() { return p; } } Poll::Pending } } pub enum Select9 { R1(O1), R2(O2), R3(O3), R4(O4), R5(O5), R6(O6), R7(O7), R8(O8), R9(O9), } pub struct Select9Future { f1: F1, f2: F2, f3: F3, f4: F4, f5: F5, f6: F6, f7: F7, f8: F8, f9: F9, } impl Future for Select9Future where F1: Future, F2: Future, F3: Future, F4: Future, F5: Future, F6: Future, F7: Future, F8: Future, F9: Future, { type Output = Select9< F1::Output, F2::Output, F3::Output, F4::Output, F5::Output, F6::Output, F7::Output, F8::Output, F9::Output, >; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut indexes = [0; 9]; for i in range_unordered(&mut indexes) { let s = self.as_mut(); let p = match i { 0 => map_poll(s, cx, |s| &mut s.f1, |v| Select9::R1(v)), 1 => map_poll(s, cx, |s| &mut s.f2, |v| Select9::R2(v)), 2 => map_poll(s, cx, |s| &mut s.f3, |v| Select9::R3(v)), 3 => map_poll(s, cx, |s| &mut s.f4, |v| Select9::R4(v)), 4 => map_poll(s, cx, |s| &mut s.f5, |v| Select9::R5(v)), 5 => map_poll(s, cx, |s| &mut s.f6, |v| Select9::R6(v)), 6 => map_poll(s, cx, |s| &mut s.f7, |v| Select9::R7(v)), 7 => map_poll(s, cx, |s| &mut s.f8, |v| Select9::R8(v)), 8 => map_poll(s, cx, |s| &mut s.f9, |v| Select9::R9(v)), _ => unreachable!(), }; if p.is_ready() { return p; } } Poll::Pending } } pub fn select_2(f1: F1, f2: F2) -> Select2Future where F1: Future, F2: Future, { Select2Future { f1, f2 } } pub fn select_3(f1: F1, f2: F2, f3: F3) -> Select3Future where F1: Future, F2: Future, F3: Future, { 
Select3Future { f1, f2, f3 } } pub fn select_4(f1: F1, f2: F2, f3: F3, f4: F4) -> Select4Future where F1: Future, F2: Future, F3: Future, F4: Future, { Select4Future { f1, f2, f3, f4 } } pub fn select_5( f1: F1, f2: F2, f3: F3, f4: F4, f5: F5, ) -> Select5Future where F1: Future, F2: Future, F3: Future, F4: Future, F5: Future, { Select5Future { f1, f2, f3, f4, f5 } } pub fn select_6( f1: F1, f2: F2, f3: F3, f4: F4, f5: F5, f6: F6, ) -> Select6Future where F1: Future, F2: Future, F3: Future, F4: Future, F5: Future, F6: Future, { Select6Future { f1, f2, f3, f4, f5, f6, } } pub fn select_9( f1: F1, f2: F2, f3: F3, f4: F4, f5: F5, f6: F6, f7: F7, f8: F8, f9: F9, ) -> Select9Future where F1: Future, F2: Future, F3: Future, F4: Future, F5: Future, F6: Future, F7: Future, F8: Future, F9: Future, { Select9Future { f1, f2, f3, f4, f5, f6, f7, f8, f9, } } pub struct SelectSliceFuture<'a, F> { futures: &'a mut [F], scratch: &'a mut Vec, } impl Future for SelectSliceFuture<'_, F> where F: Future, { type Output = (usize, F::Output); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let s = &mut *self; let indexes = &mut s.scratch; indexes.resize(s.futures.len(), 0); for i in range_unordered(&mut indexes[..s.futures.len()]) { let f = unsafe { Pin::new_unchecked(&mut s.futures[*i]) }; if let Poll::Ready(v) = f.poll(cx) { return Poll::Ready((*i, v)); } } Poll::Pending } } pub fn select_slice<'a, F, O>( futures: &'a mut [F], scratch: &'a mut Vec, ) -> SelectSliceFuture<'a, F> where F: Future, { if futures.len() > scratch.capacity() { panic!( "select_slice scratch is not large enough: {}, need {}", scratch.capacity(), futures.len() ); } SelectSliceFuture { futures, scratch } } pub struct SelectOptionFuture { fut: Option, } impl Future for SelectOptionFuture where F: Future, { type Output = O; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.fut) }; match f.as_pin_mut() { Some(f) => f.poll(cx), None 
=> Poll::Pending, } } } pub fn select_option(fut: Option) -> SelectOptionFuture where F: Future, { SelectOptionFuture { fut } } pub struct SelectOptionRefFuture<'a, F> { fut: Option<&'a mut F>, } impl<'a, F, O> Future for SelectOptionRefFuture<'a, F> where F: Future, { type Output = O; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { match self.fut.as_mut() { Some(f) => { let f: Pin<&mut F> = unsafe { Pin::new_unchecked(*f) }; f.poll(cx) } None => Poll::Pending, } } } pub fn select_option_ref(fut: Option<&mut F>) -> SelectOptionRefFuture where F: Future, { SelectOptionRefFuture { fut } } #[track_caller] fn get_reactor() -> Reactor { Reactor::current().expect("no reactor in thread") } pub struct AsyncSender { evented: CustomEvented, inner: channel::Sender, } impl AsyncSender { pub fn new(s: channel::Sender) -> Self { let evented = CustomEvented::new( s.get_write_registration(), mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); // assume we can write, unless can_send() returns false. 
note that // if can_send() returns true, it doesn't mean we can actually write evented.registration().set_ready(s.can_send()); Self { evented, inner: s } } pub fn is_writable(&self) -> bool { self.evented.registration().is_ready() } pub fn wait_writable<'a>(&'a self) -> WaitWritableFuture<'a, T> { WaitWritableFuture { s: self } } pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { match self.inner.try_send(t) { Ok(_) => { // if can_send() returns false, then we know we can't write if !self.inner.can_send() { self.evented.registration().set_ready(false); } Ok(()) } Err(mpsc::TrySendError::Full(t)) => { self.evented.registration().set_ready(false); Err(mpsc::TrySendError::Full(t)) } Err(mpsc::TrySendError::Disconnected(t)) => Err(mpsc::TrySendError::Disconnected(t)), } } pub fn send<'a>(&'a self, t: T) -> SendFuture<'a, T> { SendFuture { s: self, t: Some(t), } } } pub struct AsyncReceiver { evented: CustomEvented, inner: channel::Receiver, } impl AsyncReceiver { pub fn new(r: channel::Receiver) -> Self { let evented = CustomEvented::new( r.get_read_registration(), mio::Interest::READABLE, &get_reactor(), ) .unwrap(); evented.registration().set_ready(true); Self { evented, inner: r } } pub fn recv<'a>(&'a self) -> RecvFuture<'a, T> { RecvFuture { r: self } } } pub struct AsyncLocalSender { evented: CustomEvented, inner: channel::LocalSender, } impl AsyncLocalSender { pub fn new(s: channel::LocalSender) -> Self { let evented = CustomEvented::new_local( s.get_write_registration(), mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); evented.registration().set_ready(true); Self { evented, inner: s } } pub fn send<'a>(&'a self, t: T) -> LocalSendFuture<'a, T> { LocalSendFuture { s: self, t: Some(t), } } } pub struct AsyncLocalReceiver { evented: CustomEvented, inner: channel::LocalReceiver, } impl AsyncLocalReceiver { pub fn new(r: channel::LocalReceiver) -> Self { let evented = CustomEvented::new_local( r.get_read_registration(), mio::Interest::READABLE, 
&get_reactor(), ) .unwrap(); evented.registration().set_ready(true); Self { evented, inner: r } } pub fn recv<'a>(&'a self) -> LocalRecvFuture<'a, T> { LocalRecvFuture { r: self } } } pub struct AsyncTcpListener { evented: IoEvented, } impl AsyncTcpListener { pub fn new(l: TcpListener) -> Self { let evented = IoEvented::new(l, mio::Interest::READABLE, &get_reactor()).unwrap(); evented.registration().set_ready(true); Self { evented } } pub fn bind(addr: SocketAddr) -> Result { let listener = TcpListener::bind(addr)?; Ok(Self::new(listener)) } pub fn local_addr(&self) -> Result { self.evented.io().local_addr() } pub fn accept<'a>(&'a self) -> AcceptFuture<'a> { AcceptFuture { l: self } } } pub struct AsyncTcpStream { evented: IoEvented, } impl AsyncTcpStream { pub fn new(s: TcpStream) -> Self { let evented = IoEvented::new( s, mio::Interest::READABLE | mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); // when constructing via new(), assume I/O operations are ready to be // attempted evented.registration().set_ready(true); Self { evented } } pub async fn connect<'a>(addr: SocketAddr) -> Result { let stream = TcpStream::connect(addr)?; let mut stream = Self::new(stream); // when constructing via connect(), the ready state should start out // false because we need to wait for a writability indication stream.evented.registration().set_ready(false); let fut = TcpConnectFuture { s: &mut stream }; fut.await?; Ok(stream) } pub fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> TcpReadFuture<'a> { TcpReadFuture { s: self, buf } } pub fn write<'a>(&'a mut self, buf: &'a [u8]) -> TcpWriteFuture<'a> { TcpWriteFuture { s: self, buf, pos: 0, } } } pub struct AsyncSleep { evented: TimerEvented, } impl AsyncSleep { pub fn new(expires: Instant) -> Self { let evented = TimerEvented::new(expires, &get_reactor()).unwrap(); evented.registration().set_ready(true); Self { evented } } pub fn sleep<'a>(&'a mut self) -> SleepFuture<'a> { SleepFuture { s: self } } } pub struct AsyncZmqSocket 
{ evented: FdEvented, inner: ZmqSocket, timeout: Cell>, } impl AsyncZmqSocket { pub fn new(s: ZmqSocket) -> Self { let evented = FdEvented::new( s.inner().get_fd().unwrap(), mio::Interest::READABLE, &get_reactor(), ) .unwrap(); // zmq events are used for readiness, and registration readiness is // used to tell us when to call update_events(). we'll call that // below, so registration readiness can start out false evented.registration().set_ready(false); s.update_events(); Self { evented, inner: s, timeout: Cell::new(None), } } pub fn inner(&self) -> &ZmqSocket { &self.inner } pub fn set_retry_timeout(&self, timeout: Option) { self.timeout.set(timeout); } pub fn send<'a>(&'a self, msg: zmq::Message) -> ZmqSendFuture<'a> { ZmqSendFuture { s: self, msg } } pub fn send_to<'a>( &'a self, header: MultipartHeader, content: zmq::Message, ) -> ZmqSendToFuture<'a> { ZmqSendToFuture { s: self, header, content, timer_evented: None, } } pub fn recv<'a>(&'a self) -> ZmqRecvFuture<'a> { ZmqRecvFuture { s: self } } pub fn recv_routed<'a>(&'a self) -> ZmqRecvRoutedFuture<'a> { ZmqRecvRoutedFuture { s: self } } } pub struct EventWaiter<'a> { registration: &'a Registration, } impl<'a> EventWaiter<'a> { pub fn new(registration: &'a Registration) -> Self { Self { registration } } pub fn wait(&'a self, interest: mio::Interest) -> WaitFuture<'a> { WaitFuture { w: self, interest } } } pub struct WaitWritableFuture<'a, T> { s: &'a AsyncSender, } impl Future for WaitWritableFuture<'_, T> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &*self; f.s.evented .registration() .set_waker(cx.waker().clone(), mio::Interest::WRITABLE); // if can_send() returns false, then we know we can't write. 
this // check prevents spurious wakeups of a rendezvous channel from // indicating writability when the channel is not actually writable if !f.s.inner.can_send() { f.s.evented.registration().set_ready(false); } if !f.s.evented.registration().is_ready() { return Poll::Pending; } Poll::Ready(()) } } impl<T> Drop for WaitWritableFuture<'_, T> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } pub struct SendFuture<'a, T> { s: &'a AsyncSender<T>, t: Option<T>, } impl<T> Future for SendFuture<'_, T> where T: Unpin, { type Output = Result<(), mpsc::SendError<T>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker().clone(), mio::Interest::WRITABLE); if !f.s.evented.registration().is_ready() { return Poll::Pending; } if !f.s.evented.registration().pull_from_budget() { return Poll::Pending; } let t = f.t.take().unwrap(); // try_send will update the registration readiness, so we don't need // to do that here match f.s.try_send(t) { Ok(()) => Poll::Ready(Ok(())), Err(mpsc::TrySendError::Full(t)) => { f.t = Some(t); Poll::Pending } Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), } } } impl<T> Drop for SendFuture<'_, T> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } pub struct RecvFuture<'a, T> { r: &'a AsyncReceiver<T>, } impl<T> Future for RecvFuture<'_, T> { type Output = Result<T, mpsc::RecvError>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let f = &*self; f.r.evented .registration() .set_waker(cx.waker().clone(), mio::Interest::READABLE); if !f.r.evented.registration().is_ready() { return Poll::Pending; } if !f.r.evented.registration().pull_from_budget() { return Poll::Pending; } match f.r.inner.try_recv() { Ok(v) => Poll::Ready(Ok(v)), Err(mpsc::TryRecvError::Empty) => { f.r.evented.registration().set_ready(false); Poll::Pending } Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), } } } impl<T> Drop for RecvFuture<'_, T> { 
fn drop(&mut self) { self.r.evented.registration().clear_waker(); } } pub struct LocalSendFuture<'a, T> { s: &'a AsyncLocalSender, t: Option, } impl Future for LocalSendFuture<'_, T> where T: Unpin, { type Output = Result<(), mpsc::SendError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker().clone(), mio::Interest::WRITABLE); if !f.s.evented.registration().is_ready() { return Poll::Pending; } if !f.s.evented.registration().pull_from_budget() { return Poll::Pending; } let t = f.t.take().unwrap(); match f.s.inner.try_send(t) { Ok(()) => Poll::Ready(Ok(())), Err(mpsc::TrySendError::Full(t)) => { f.s.evented.registration().set_ready(false); f.t = Some(t); Poll::Pending } Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), } } } impl Drop for LocalSendFuture<'_, T> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } pub struct LocalRecvFuture<'a, T> { r: &'a AsyncLocalReceiver, } impl Future for LocalRecvFuture<'_, T> { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &*self; f.r.evented .registration() .set_waker(cx.waker().clone(), mio::Interest::READABLE); if !f.r.evented.registration().is_ready() { return Poll::Pending; } if !f.r.evented.registration().pull_from_budget() { return Poll::Pending; } match f.r.inner.try_recv() { Ok(v) => Poll::Ready(Ok(v)), Err(mpsc::TryRecvError::Empty) => { f.r.evented.registration().set_ready(false); Poll::Pending } Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), } } } impl Drop for LocalRecvFuture<'_, T> { fn drop(&mut self) { self.r.evented.registration().clear_waker(); } } pub struct AcceptFuture<'a> { l: &'a AsyncTcpListener, } impl Future for AcceptFuture<'_> { type Output = Result<(TcpStream, SocketAddr), io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.l.evented .registration() 
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if !f.l.evented.registration().is_ready() {
            return Poll::Pending;
        }

        if !f.l.evented.registration().pull_from_budget() {
            return Poll::Pending;
        }

        match f.l.evented.io().accept() {
            Ok((stream, peer_addr)) => Poll::Ready(Ok((stream, peer_addr))),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                f.l.evented.registration().set_ready(false);

                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for AcceptFuture<'_> {
    fn drop(&mut self) {
        self.l.evented.registration().clear_waker();
    }
}

pub struct TcpConnectFuture<'a> {
    s: &'a mut AsyncTcpStream,
}

impl Future for TcpConnectFuture<'_> {
    type Output = Result<(), io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::WRITABLE);

        if !f.s.evented.registration().is_ready() {
            return Poll::Pending;
        }

        let maybe_error = match f.s.evented.io().take_error() {
            Ok(me) => me,
            Err(e) => return Poll::Ready(Err(e)),
        };

        if let Some(e) = maybe_error {
            return Poll::Ready(Err(e));
        }

        Poll::Ready(Ok(()))
    }
}

impl Drop for TcpConnectFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct TcpReadFuture<'a> {
    s: &'a mut AsyncTcpStream,
    buf: &'a mut [u8],
}

impl Future for TcpReadFuture<'_> {
    type Output = Result<usize, io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if !f
            .s
            .evented
            .registration()
            .readiness()
            .contains_any(mio::Interest::READABLE)
        {
            return Poll::Pending;
        }

        if !f.s.evented.registration().pull_from_budget() {
            return Poll::Pending;
        }

        match f.s.evented.io().read(f.buf) {
            Ok(size) => Poll::Ready(Ok(size)),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                f.s.evented
                    .registration()
                    .clear_readiness(mio::Interest::READABLE);

                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for TcpReadFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct TcpWriteFuture<'a> {
    s: &'a mut AsyncTcpStream,
    buf: &'a [u8],
    pos: usize,
}

impl Future for TcpWriteFuture<'_> {
    type Output = Result<usize, io::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::WRITABLE);

        if !f
            .s
            .evented
            .registration()
            .readiness()
            .contains_any(mio::Interest::WRITABLE)
        {
            return Poll::Pending;
        }

        // try to write all the data before producing a result, the same as
        // what a blocking write would do
        loop {
            if !f.s.evented.registration().pull_from_budget() {
                return Poll::Pending;
            }

            match f.s.evented.io().write(&f.buf[f.pos..]) {
                Ok(size) => {
                    f.pos += size;

                    if f.pos >= f.buf.len() {
                        break;
                    }
                }
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                    f.s.evented
                        .registration()
                        .clear_readiness(mio::Interest::WRITABLE);

                    return Poll::Pending;
                }
                Err(e) => return Poll::Ready(Err(e)),
            }
        }

        Poll::Ready(Ok(f.buf.len()))
    }
}

impl Drop for TcpWriteFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct SleepFuture<'a> {
    s: &'a mut AsyncSleep,
}

impl Future for SleepFuture<'_> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if !f.s.evented.registration().is_ready() {
            return Poll::Pending;
        }

        let now = get_reactor().now();

        if now >= f.s.evented.expires() {
            Poll::Ready(())
        } else {
            f.s.evented.registration().set_ready(false);

            Poll::Pending
        }
    }
}

impl Drop for SleepFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub async fn sleep(duration: Duration) {
    let now = get_reactor().now();

    AsyncSleep::new(now + duration).sleep().await
}

pub struct ZmqSendFuture<'a> {
    s: &'a AsyncZmqSocket,
    msg: zmq::Message,
}

impl Future for ZmqSendFuture<'_> {
    type Output = Result<(), zmq::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if f.s.evented.registration().is_ready() {
            f.s.inner.update_events();
            f.s.evented.registration().set_ready(false);
        }

        if !f.s.inner.events().contains(zmq::POLLOUT) {
            return Poll::Pending;
        }

        // NOTE: when rust-zmq allows resending messages we can
        // avoid this copy
        let msg = zmq::Message::from(&f.msg[..]);

        match f.s.inner.send(msg, zmq::DONTWAIT) {
            Ok(()) => Poll::Ready(Ok(())),
            Err(zmq::Error::EAGAIN) => Poll::Pending,
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for ZmqSendFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct ZmqSendToFuture<'a> {
    s: &'a AsyncZmqSocket,
    header: MultipartHeader,
    content: zmq::Message,
    timer_evented: Option<TimerEvented>,
}

impl Future for ZmqSendToFuture<'_> {
    type Output = Result<(), zmq::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        let reactor = f.s.evented.registration().reactor();

        if let Some(timer_evented) = &f.timer_evented {
            timer_evented
                .registration()
                .set_waker(cx.waker().clone(), mio::Interest::READABLE);

            if reactor.now() < timer_evented.expires() {
                timer_evented.registration().set_ready(false);

                return Poll::Pending;
            }

            f.timer_evented = None;
        }

        assert!(f.timer_evented.is_none());

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if f.s.evented.registration().is_ready() {
            f.s.inner.update_events();
            f.s.evented.registration().set_ready(false);
        }

        if !f.s.inner.events().contains(zmq::POLLOUT) {
            return Poll::Pending;
        }

        // NOTE: when rust-zmq allows resending messages we can
        // avoid this copy
        let content = zmq::Message::from(&f.content[..]);

        match f.s.inner.send_to(&f.header, content, zmq::DONTWAIT) {
            Ok(()) => Poll::Ready(Ok(())),
            Err(zmq::Error::EAGAIN) => {
                if let Some(timeout) = f.s.timeout.get() {
                    let expires = reactor.now() + timeout;
                    let timer_evented = TimerEvented::new(expires, &reactor).unwrap();

                    f.s.evented.registration().clear_waker();

                    timer_evented.registration().set_ready(true);
                    timer_evented
                        .registration()
                        .set_waker(cx.waker().clone(), mio::Interest::READABLE);

                    f.timer_evented = Some(timer_evented);
                }

                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for ZmqSendToFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct ZmqRecvFuture<'a> {
    s: &'a AsyncZmqSocket,
}

impl Future for ZmqRecvFuture<'_> {
    type Output = Result<zmq::Message, zmq::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if f.s.evented.registration().is_ready() {
            f.s.inner.update_events();
            f.s.evented.registration().set_ready(false);
        }

        if !f.s.inner.events().contains(zmq::POLLIN) {
            return Poll::Pending;
        }

        match f.s.inner.recv(zmq::DONTWAIT) {
            Ok(msg) => Poll::Ready(Ok(msg)),
            Err(zmq::Error::EAGAIN) => Poll::Pending,
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for ZmqRecvFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct ZmqRecvRoutedFuture<'a> {
    s: &'a AsyncZmqSocket,
}

impl Future for ZmqRecvRoutedFuture<'_> {
    type Output = Result<zmq::Message, zmq::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &mut *self;

        f.s.evented
            .registration()
            .set_waker(cx.waker().clone(), mio::Interest::READABLE);

        if f.s.evented.registration().is_ready() {
            f.s.inner.update_events();
            f.s.evented.registration().set_ready(false);
        }

        if !f.s.inner.events().contains(zmq::POLLIN) {
            return Poll::Pending;
        }

        match f.s.inner.recv_routed(zmq::DONTWAIT) {
            Ok(msg) => Poll::Ready(Ok(msg)),
            Err(zmq::Error::EAGAIN) => Poll::Pending,
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for ZmqRecvRoutedFuture<'_> {
    fn drop(&mut self) {
        self.s.evented.registration().clear_waker();
    }
}

pub struct WaitFuture<'a> {
    w: &'a EventWaiter<'a>,
    interest: mio::Interest,
}

impl Future for WaitFuture<'_> {
    type Output = mio::Interest;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let f = &*self;

        f.w.registration.set_waker(cx.waker().clone(), f.interest);

        if !f.w.registration.readiness().contains_any(f.interest) {
            return Poll::Pending;
        }

        let readiness = f.w.registration.readiness().unwrap();

        // mask with the interest
        let readable = readiness.is_readable() && f.interest.is_readable();
        let writable = readiness.is_writable() && f.interest.is_writable();
        let readiness = if readable && writable {
            mio::Interest::READABLE | mio::Interest::WRITABLE
        } else if readable {
            mio::Interest::READABLE
        } else {
            mio::Interest::WRITABLE
        };

        Poll::Ready(readiness)
    }
}

impl Drop for WaitFuture<'_> {
    fn drop(&mut self) {
        self.w.registration.clear_waker();
    }
}

pub async fn event_wait(registration: &Registration, interest: mio::Interest) -> mio::Interest {
    EventWaiter::new(registration).wait(interest).await
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::Executor;
    use crate::zmq::SpecInfo;
    use std::mem;
    use std::rc::Rc;
    use std::str;
    use std::task::Context;
    use std::thread;

    struct PollFuture<'a, F> {
        fut: &'a mut F,
    }

    impl<F> Future for PollFuture<'_, F>
    where
        F: Future,
    {
        type Output = Poll<F::Output>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let f = unsafe { self.as_mut().map_unchecked_mut(|s| s.fut) };

            Poll::Ready(f.poll(cx))
        }
    }

    fn poll_fut_async<'a, F>(fut: &'a mut F) -> PollFuture<'a, F> {
        PollFuture { fut }
    }

    #[test]
    fn test_channel_send_bound0() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let (s, r) = channel::channel::<u32>(0);

        let s = AsyncSender::new(s);
        let r = AsyncReceiver::new(r);

        executor
            .spawn(async move {
                s.send(1).await.unwrap();

                assert_eq!(s.is_writable(), false);
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(1));
                assert_eq!(r.recv().await, Err(mpsc::RecvError));
            })
            .unwrap();
        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_channel_send_bound1() {
        let executor = Executor::new(1);
        let reactor = Reactor::new(2);

        let (s, r) = channel::channel::<u32>(1);

        let s = AsyncSender::new(s);
        let r = AsyncReceiver::new(r);

        executor
            .spawn(async move {
                s.send(1).await.unwrap();

                assert_eq!(s.is_writable(), true);
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), false);

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(1));
                assert_eq!(r.recv().await, Err(mpsc::RecvError));
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_channel_recv() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let (s, r) = channel::channel::<u32>(0);

        let s = AsyncSender::new(s);
        let r = AsyncReceiver::new(r);

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(1));
                assert_eq!(r.recv().await, Err(mpsc::RecvError));
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        executor
            .spawn(async move {
                s.send(1).await.unwrap();
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_channel_writable() {
        let executor = Executor::new(1);
        let reactor = Reactor::new(1);

        let (s, r) = channel::channel::<u32>(0);

        let s = AsyncSender::new(s);

        executor
            .spawn(async move {
                assert_eq!(s.is_writable(), false);

                s.wait_writable().await;
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        // attempting to receive on a rendezvous channel will make the
        // sender writable
        assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty));

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_local_channel() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let (s, r) =
            channel::local_channel::<u32>(1, 1, &reactor.local_registration_memory());

        let s = AsyncLocalSender::new(s);
        let r = AsyncLocalReceiver::new(r);

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(1));
                assert_eq!(r.recv().await, Err(mpsc::RecvError));
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        executor
            .spawn(async move {
                s.send(1).await.unwrap();
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_tcpstream() {
        let executor = Executor::new(2); // 2 tasks
        let reactor = Reactor::new(3); // 3 registrations

        let spawner = executor.spawner();

        executor
            .spawn(async move {
                let addr = "127.0.0.1:0".parse().unwrap();
                let listener = AsyncTcpListener::bind(addr).expect("failed to bind");
                let addr = listener.local_addr().unwrap();

                spawner
                    .spawn(async move {
                        let mut stream = AsyncTcpStream::connect(addr).await.unwrap();

                        let size = stream.write("hello".as_bytes()).await.unwrap();
                        assert_eq!(size, 5);
                    })
                    .unwrap();

                let (stream, _) = listener.accept().await.unwrap();
                let mut stream = AsyncTcpStream::new(stream);

                let mut resp = [0u8; 1024];
                let mut resp = io::Cursor::new(&mut resp[..]);

                loop {
                    let mut buf = [0; 1024];

                    let size = stream.read(&mut buf).await.unwrap();
                    if size == 0 {
                        break;
                    }

                    resp.write(&buf[..size]).unwrap();
                }

                let size = resp.position() as usize;
                let resp = str::from_utf8(&resp.get_ref()[..size]).unwrap();

                assert_eq!(resp, "hello");
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_zmq() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let spec = "inproc://futures::tests::test_zmq";

        let context = zmq::Context::new();

        let s = AsyncZmqSocket::new(ZmqSocket::new(&context, zmq::PUSH));
        let r = AsyncZmqSocket::new(ZmqSocket::new(&context, zmq::PULL));

        s.inner().inner().bind(spec).unwrap();
        s.inner().inner().set_sndhwm(1).unwrap();

        executor
            .spawn(async move {
                s.send(zmq::Message::from(&b"1"[..])).await.unwrap();
                s.send(zmq::Message::from(&b"2"[..])).await.unwrap();
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        r.inner().inner().connect(spec).unwrap();

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(zmq::Message::from(&b"1"[..])));
                assert_eq!(r.recv().await, Ok(zmq::Message::from(&b"2"[..])));
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_zmq_routable() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let spec = "inproc://futures::tests::test_zmq_routable";

        let context = zmq::Context::new();

        let s = AsyncZmqSocket::new(ZmqSocket::new(&context, zmq::DEALER));
        let r = AsyncZmqSocket::new(ZmqSocket::new(&context, zmq::ROUTER));

        s.inner().inner().bind(spec).unwrap();
        s.inner().inner().set_sndhwm(1).unwrap();

        executor
            .spawn(async move {
                s.send_to(MultipartHeader::new(), zmq::Message::from(&b"1"[..]))
                    .await
                    .unwrap();
                s.send_to(MultipartHeader::new(), zmq::Message::from(&b"2"[..]))
                    .await
                    .unwrap();
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        r.inner().inner().connect(spec).unwrap();

        executor
            .spawn(async move {
                assert_eq!(r.recv_routed().await, Ok(zmq::Message::from(&b"1"[..])));
                assert_eq!(r.recv_routed().await, Ok(zmq::Message::from(&b"2"[..])));
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }

    #[test]
    fn test_zmq_retry_timeout() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let spec = "inproc://futures::tests::test_zmq_routable";

        let context = zmq::Context::new();

        let s = ZmqSocket::new(&context, zmq::ROUTER);

        let r1 = ZmqSocket::new(&context, zmq::ROUTER);
        let r2 = ZmqSocket::new(&context, zmq::ROUTER);

        s.inner().set_sndhwm(1).unwrap();
        s.inner().set_router_mandatory(true).unwrap();
        s.apply_specs(&[SpecInfo {
            spec: spec.to_string(),
            bind: true,
            ipc_file_mode: 0,
        }])
        .unwrap();

        r1.inner().set_rcvhwm(1).unwrap();
        r1.inner().set_identity(b"test1").unwrap();
        r1.apply_specs(&[SpecInfo {
            spec: spec.to_string(),
            bind: false,
            ipc_file_mode: 0,
        }])
        .unwrap();

        r2.inner().set_rcvhwm(1).unwrap();
        r2.inner().set_identity(b"test2").unwrap();
        r2.apply_specs(&[SpecInfo {
            spec: spec.to_string(),
            bind: false,
            ipc_file_mode: 0,
        }])
        .unwrap();

        // ensure both peers are connected

        loop {
            let mut h = MultipartHeader::new();
            h.push(zmq::Message::from(&b"test1"[..]));

            match s.send_to(&h, zmq::Message::from(&b"1"[..]), 0) {
                Ok(()) => break,
                Err(zmq::Error::EHOSTUNREACH) => thread::sleep(Duration::from_millis(10)),
                Err(e) => panic!("{}", e),
            }
        }

        loop {
            let mut h = MultipartHeader::new();
            h.push(zmq::Message::from(&b"test2"[..]));

            match s.send_to(&h, zmq::Message::from(&b"1"[..]), 0) {
                Ok(()) => break,
                Err(zmq::Error::EHOSTUNREACH) => thread::sleep(Duration::from_millis(10)),
                Err(e) => panic!("{}", e),
            }
        }

        // we can clear out r1
        assert_eq!(r1.recv_routed(0), Ok(zmq::Message::from(&b"1"[..])));

        // wrap in Rc so the inproc sender is not dropped until after the
        // messages have been received
        let s = Rc::new(AsyncZmqSocket::new(s));

        s.set_retry_timeout(Some(Duration::from_millis(0)));

        {
            let s = s.clone();

            executor
                .spawn(async move {
                    // second write will succeed immediately

                    let mut h = MultipartHeader::new();
                    h.push(zmq::Message::from(&b"test2"[..]));
                    s.send_to(h, zmq::Message::from(&b"2"[..])).await.unwrap();

                    // third write will block

                    let mut h = MultipartHeader::new();
                    h.push(zmq::Message::from(&b"test2"[..]));
                    let mut fut = s.send_to(h, zmq::Message::from(&b"3"[..]));

                    assert_eq!(poll_fut_async(&mut fut).await, Poll::Pending);
                    assert_eq!(fut.timer_evented.is_some(), true);

                    fut.await.unwrap();
                })
                .unwrap();
        }

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        // this will allow the third write to go through
        assert_eq!(r2.recv_routed(0), Ok(zmq::Message::from(&b"1"[..])));

        executor.run(|timeout| reactor.poll(timeout)).unwrap();

        assert_eq!(r2.recv_routed(0), Ok(zmq::Message::from(&b"2"[..])));
        assert_eq!(r2.recv_routed(0), Ok(zmq::Message::from(&b"3"[..])));
    }

    #[test]
    fn test_budget_unlimited() {
        let executor = Executor::new(1);
        let reactor = Reactor::new(1);

        let (s, r) = channel::channel::<u32>(3);

        s.send(1).unwrap();
        s.send(2).unwrap();
        s.send(3).unwrap();
        mem::drop(s);

        let r = AsyncReceiver::new(r);

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(1));
                assert_eq!(r.recv().await, Ok(2));
                assert_eq!(r.recv().await, Ok(3));
                assert_eq!(r.recv().await, Err(mpsc::RecvError));
            })
            .unwrap();

        let mut park_count = 0;

        executor
            .run(|timeout| {
                park_count += 1;

                reactor.poll(timeout)
            })
            .unwrap();

        assert_eq!(park_count, 0);
    }

    #[test]
    fn test_budget_1() {
        let executor = Executor::new(1);
        let reactor = Reactor::new(1);

        {
            let reactor = reactor.clone();

            executor.set_pre_poll(move || {
                reactor.set_budget(Some(1));
            });
        }

        let (s, r) = channel::channel::<u32>(3);

        s.send(1).unwrap();
        s.send(2).unwrap();
        s.send(3).unwrap();
        mem::drop(s);

        let r = AsyncReceiver::new(r);

        executor
            .spawn(async move {
                assert_eq!(r.recv().await, Ok(1));
                assert_eq!(r.recv().await, Ok(2));
                assert_eq!(r.recv().await, Ok(3));
                assert_eq!(r.recv().await, Err(mpsc::RecvError));
            })
            .unwrap();

        let mut park_count = 0;

        executor
            .run(|timeout| {
                park_count += 1;

                reactor.poll(timeout)
            })
            .unwrap();

        assert_eq!(park_count, 3);
    }

    #[test]
    fn test_sleep() {
        let now = Instant::now();

        let executor = Executor::new(1);
        let reactor = Reactor::new_with_time(1, now);

        executor.spawn(sleep(Duration::from_millis(100))).unwrap();

        executor.run_until_stalled();

        reactor
            .poll_nonblocking(now + Duration::from_millis(200))
            .unwrap();

        executor.run(|_| Ok(())).unwrap();
    }

    #[test]
    fn test_sleep_ready() {
        let now = Instant::now();

        let executor = Executor::new(1);
        let _reactor = Reactor::new_with_time(1, now);

        executor.spawn(sleep(Duration::from_millis(0))).unwrap();

        executor.run(|_| Ok(())).unwrap();
    }

    #[test]
    fn test_event_wait() {
        let executor = Executor::new(2);
        let reactor = Reactor::new(2);

        let (s, r) =
            channel::local_channel::<u32>(1, 1, &reactor.local_registration_memory());

        let s = AsyncLocalSender::new(s);

        executor
            .spawn(async move {
                let reactor = Reactor::current().unwrap();

                let reg = reactor
                    .register_custom_local(r.get_read_registration(), mio::Interest::READABLE)
                    .unwrap();
                assert_eq!(
                    event_wait(&reg, mio::Interest::READABLE).await,
                    mio::Interest::READABLE
                );
            })
            .unwrap();

        executor.run_until_stalled();

        assert_eq!(executor.have_tasks(), true);

        executor
            .spawn(async move {
                s.send(1).await.unwrap();
            })
            .unwrap();

        executor.run(|timeout| reactor.poll(timeout)).unwrap();
    }
}
condure-1.3.1/src/http1.rs000064400000000000000000002354250000000000000134520ustar 00000000000000
/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::buffer::{write_vectored_offset, LimitBufs, VECTORED_MAX};
use std::cmp;
use std::convert::TryFrom;
use std::io;
use std::io::{Read, Write};
use std::str;

pub const HEADERS_MAX: usize = 32;

const CHUNK_SIZE_MAX: usize = 0xffff;
const CHUNK_HEADER_SIZE_MAX: usize = 6; // ffff\r\n
const CHUNK_FOOTER: &[u8] = b"\r\n";

fn parse_as_int(src: &[u8]) -> Result<usize, io::Error> {
    let int_str = str::from_utf8(src);
    let int_str = match int_str {
        Ok(int_str) => int_str,
        Err(_) => {
            return Err(io::Error::from(io::ErrorKind::InvalidData));
        }
    };

    let x = int_str.parse();
    let x = match x {
        Ok(x) => x,
        Err(_) => {
            return Err(io::Error::from(io::ErrorKind::InvalidData));
        }
    };

    Ok(x)
}

fn header_contains_param(value: &[u8], param: &[u8], ignore_case: bool) -> bool {
    let param_str = str::from_utf8(param);
    let param_str = match param_str {
        Ok(param_str) => param_str,
        Err(_) => {
            return false;
        }
    };

    for part in value.split(|b| *b == b',') {
        let part_str = str::from_utf8(part);
        let part_str = match part_str {
            Ok(part_str) => part_str,
            Err(_) => {
                continue;
            }
        };

        let part_str = part_str.trim();

        if ignore_case {
            if part_str.eq_ignore_ascii_case(param_str) {
                return true;
            }
        } else {
            if part_str == param_str {
                return true;
            }
        }
    }

    false
}

#[derive(Debug, PartialEq, Clone, Copy)]
struct Chunk {
    header: [u8; CHUNK_HEADER_SIZE_MAX],
    header_len: usize,
    size: usize,
    sent: usize,
}

// writes src to dest as chunks. current chunk state is passed in
fn write_chunk(
    content: &[&[u8]],
    footer: &[u8],
    dest: &mut dyn Write,
    chunk: &mut Option<Chunk>,
    max_size: usize,
) -> Result<usize, io::Error> {
    assert!(max_size <= CHUNK_SIZE_MAX);

    let mut content_len = 0;
    for buf in content.iter() {
        content_len += buf.len();
    }

    if chunk.is_none() {
        let size = cmp::min(content_len, max_size);

        let mut h = [0; CHUNK_HEADER_SIZE_MAX];

        let h_len = {
            let mut c = io::Cursor::new(&mut h[..]);
            write!(&mut c, "{:x}\r\n", size).unwrap();
            c.position() as usize
        };

        *chunk = Some(Chunk {
            header: h,
            header_len: h_len,
            size,
            sent: 0,
        });
    }

    let chunkv = chunk.as_mut().unwrap();

    let cheader = &chunkv.header[..chunkv.header_len];
    let data_size = chunkv.size;

    let total = cheader.len() + data_size + footer.len();

    let mut content_arr = [&b""[..]; VECTORED_MAX - 2];
    for (i, buf) in content.iter().enumerate() {
        content_arr[i] = buf;
    }

    let trim_content = (&mut content_arr[..content.len()]).limit(data_size);

    let mut out_arr = [&b""[..]; VECTORED_MAX];
    let mut out_arr_len = 0;

    out_arr[0] = cheader;
    out_arr_len += 1;

    for buf in trim_content.iter() {
        out_arr[out_arr_len] = buf;
        out_arr_len += 1;
    }

    out_arr[out_arr_len] = footer;
    out_arr_len += 1;

    let size = write_vectored_offset(dest, &out_arr[..out_arr_len], chunkv.sent)?;

    chunkv.sent += size;

    if chunkv.sent < total {
        return Ok(0);
    }

    *chunk = None;

    Ok(data_size)
}

#[cfg(test)]
pub fn write_headers(writer: &mut dyn io::Write, headers: &[Header]) -> Result<(), io::Error> {
    for h in headers.iter() {
        write!(writer, "{}: ", h.name)?;
        writer.write(h.value)?;
        writer.write(b"\r\n")?;
    }

    writer.write(b"\r\n")?;

    Ok(())
}
#[derive(Clone, Copy)]
pub struct Header<'a> {
    pub name: &'a str,
    pub value: &'a [u8],
}

pub const EMPTY_HEADER: Header<'static> = Header {
    name: "",
    value: b"",
};

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum BodySize {
    NoBody,
    Known(usize),
    Unknown,
}

#[derive(Debug, PartialEq)]
pub struct Request<'buf, 'headers> {
    pub method: &'buf str,
    pub uri: &'buf str,
    pub headers: &'headers [httparse::Header<'buf>],
    pub body_size: BodySize,
    pub expect_100: bool,
}

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum ServerState {
    // call: recv_request
    // next: ReceivingRequest, ReceivingBody, AwaitingResponse
    ReceivingRequest,

    // call: recv_body
    // next: ReceivingBody, AwaitingResponse
    ReceivingBody,

    // call: send_response
    // next: SendingBody
    AwaitingResponse,

    // call: send_body
    // next: SendingBody, Finished
    SendingBody,

    // request/response has completed
    Finished,
}

#[derive(Debug)]
pub enum ServerError {
    ParseError(httparse::Error),
    InvalidContentLength,
    UnsupportedTransferEncoding,
    Io(io::Error),
    InvalidChunkSize,
    ChunkTooLarge,
    InvalidChunkSuffix,
}

impl From<io::Error> for ServerError {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

pub struct ServerProtocol {
    state: ServerState,
    ver_min: u8,
    body_size: BodySize,
    chunk_left: Option<usize>,
    chunk_size: usize,
    persistent: bool,
    chunked: bool,
    sending_chunk: Option<Chunk>,
}

impl<'buf, 'headers> ServerProtocol {
    pub fn new() -> Self {
        Self {
            state: ServerState::ReceivingRequest,
            ver_min: 0,
            body_size: BodySize::NoBody,
            chunk_left: None,
            chunk_size: 0,
            persistent: false,
            chunked: false,
            sending_chunk: None,
        }
    }

    pub fn state(&self) -> ServerState {
        self.state
    }

    pub fn is_persistent(&self) -> bool {
        self.persistent
    }

    pub fn recv_request(
        &mut self,
        rbuf: &mut io::Cursor<&'buf [u8]>,
        headers: &'headers mut [httparse::Header<'buf>],
    ) -> Option<Result<Request<'buf, 'headers>, ServerError>> {
        assert_eq!(self.state, ServerState::ReceivingRequest);

        let mut req = httparse::Request::new(headers);

        let buf = &rbuf.get_ref()[(rbuf.position() as usize)..];

        let size = match req.parse(buf) {
            Ok(httparse::Status::Complete(size)) => size,
            Ok(httparse::Status::Partial) => return None,
            Err(e) => return Some(Err(ServerError::ParseError(e))),
        };

        let version = req.version.unwrap();

        let mut content_len = None;
        let mut chunked = false;
        let mut keep_alive = false;
        let mut close = false;
        let mut expect_100 = false;

        for i in 0..req.headers.len() {
            let h = req.headers[i];

            if h.name.eq_ignore_ascii_case("Content-Length") {
                let len = parse_as_int(h.value);
                let len = match len {
                    Ok(len) => len,
                    Err(_) => {
                        return Some(Err(ServerError::InvalidContentLength));
                    }
                };

                content_len = Some(len);
            } else if h.name.eq_ignore_ascii_case("Transfer-Encoding") {
                if h.value == b"chunked" {
                    chunked = true;
                } else {
                    // unknown transfer encoding
                    return Some(Err(ServerError::UnsupportedTransferEncoding));
                }
            } else if h.name.eq_ignore_ascii_case("Connection") {
                if !keep_alive && header_contains_param(h.value, b"keep-alive", true) {
                    keep_alive = true;
                }

                if !close && header_contains_param(h.value, b"close", false) {
                    close = true;
                }
            } else if h.name.eq_ignore_ascii_case("Expect") {
                if header_contains_param(h.value, b"100-continue", false) && version >= 1 {
                    expect_100 = true;
                }
            }
        }

        self.ver_min = version;

        if chunked {
            self.body_size = BodySize::Unknown;
        } else if let Some(len) = content_len {
            self.body_size = BodySize::Known(len);
            self.chunk_left = Some(len);
        } else {
            self.body_size = BodySize::NoBody;
        }

        if version >= 1 {
            self.persistent = !close;
        } else {
            self.persistent = keep_alive && !close;
        }

        rbuf.set_position(rbuf.position() + (size as u64));

        self.state = match self.body_size {
            BodySize::Unknown | BodySize::Known(_) => ServerState::ReceivingBody,
            BodySize::NoBody => ServerState::AwaitingResponse,
        };

        Some(Ok(Request {
            method: req.method.unwrap(),
            uri: req.path.unwrap(),
            headers: req.headers,
            body_size: self.body_size,
            expect_100: expect_100 && self.body_size != BodySize::NoBody,
        }))
    }

    pub fn recv_body(
        &mut self,
        rbuf: &mut io::Cursor<&'buf [u8]>,
        dest: &mut [u8],
        headers: &'headers mut [httparse::Header<'buf>],
    ) -> Result<(usize, Option<&'headers [httparse::Header<'buf>]>), ServerError> {
        assert_eq!(self.state, ServerState::ReceivingBody);

        match self.body_size {
            BodySize::Known(_) => {
                let mut chunk_left = self.chunk_left.unwrap();

                let read_size = cmp::min(chunk_left, dest.len());

                // rbuf holds body as-is
                let size = rbuf.read(&mut dest[..read_size])?;

                chunk_left -= size;

                if chunk_left == 0 {
                    self.chunk_left = None;
                    self.state = ServerState::AwaitingResponse;
                } else {
                    self.chunk_left = Some(chunk_left);
                }

                Ok((size, None))
            }
            BodySize::Unknown => {
                if self.chunk_left.is_none() {
                    let buf = &rbuf.get_ref()[(rbuf.position() as usize)..];

                    match httparse::parse_chunk_size(buf) {
                        Ok(httparse::Status::Complete((pos, size))) => {
                            let size = match u32::try_from(size) {
                                Ok(size) => size,
                                Err(_) => return Err(ServerError::ChunkTooLarge),
                            };

                            let size = size as usize;

                            rbuf.set_position(rbuf.position() + (pos as u64));

                            self.chunk_left = Some(size);
                            self.chunk_size = size;
                        }
                        Ok(httparse::Status::Partial) => {
                            return Ok((0, None));
                        }
                        Err(_) => {
                            return Err(ServerError::InvalidChunkSize);
                        }
                    }
                }

                let mut chunk_left = self.chunk_left.unwrap();

                let size;

                if chunk_left > 0 {
                    let read_size = cmp::min(chunk_left, dest.len());

                    size = rbuf.read(&mut dest[..read_size])?;

                    chunk_left -= size;

                    self.chunk_left = Some(chunk_left);
                } else {
                    size = 0;
                }

                let mut trailing_headers = None;

                if chunk_left == 0 {
                    let buf = &rbuf.get_ref()[(rbuf.position() as usize)..];

                    if self.chunk_size == 0 {
                        // trailing headers
                        match httparse::parse_headers(buf, headers) {
                            Ok(httparse::Status::Complete((pos, headers))) => {
                                rbuf.set_position(rbuf.position() + (pos as u64));

                                trailing_headers = Some(headers);
                            }
                            Ok(httparse::Status::Partial) => {
                                return Ok((size, None));
                            }
                            Err(e) => {
                                return Err(ServerError::ParseError(e));
                            }
                        }

                        self.state = ServerState::AwaitingResponse;
                    } else {
                        if buf.len() < 2 {
                            return Ok((size, None));
                        }

                        if &buf[..2] != b"\r\n" {
                            return Err(ServerError::InvalidChunkSuffix);
                        }

                        rbuf.set_position(rbuf.position() + 2);
                    }

                    self.chunk_left = None;
                    self.chunk_size = 0;
                }

                Ok((size, trailing_headers))
            }
            BodySize::NoBody => unreachable!(),
        }
    }

    pub fn send_100_continue(&mut self, writer: &mut dyn Write) -> Result<(), ServerError> {
        writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")?;

        Ok(())
    }

    pub fn send_response(
        &mut self,
        writer: &mut dyn Write,
        code: u32,
        reason: &str,
        headers: &[Header],
        body_size: BodySize,
    ) -> Result<(), ServerError> {
        assert!(
            self.state == ServerState::AwaitingResponse
                || self.state == ServerState::ReceivingBody
        );

        if self.state == ServerState::ReceivingBody {
            // when responding early, input stream may be broken
            self.persistent = false;
        }

        let mut body_size = body_size;

        // certain responses have no body
        match code {
            100..=199 | 204 | 304 => {
                body_size = BodySize::NoBody;
            }
            _ => {}
        }

        let chunked = if body_size == BodySize::Unknown && self.ver_min >= 1 {
            true
        } else {
            false
        };

        if self.ver_min >= 1 {
            writer.write(b"HTTP/1.1 ")?;
        } else {
            writer.write(b"HTTP/1.0 ")?;
        }

        write!(writer, "{} {}\r\n", code, reason)?;

        for h in headers.iter() {
            // we'll override these headers
            if (h.name.eq_ignore_ascii_case("Connection") && code != 101)
                || h.name.eq_ignore_ascii_case("Content-Length")
                || h.name.eq_ignore_ascii_case("Transfer-Encoding")
            {
                continue;
            }

            write!(writer, "{}: ", h.name)?;
            writer.write(h.value)?;
            writer.write(b"\r\n")?;
        }

        // Connection header

        if self.persistent && self.ver_min == 0 {
            writer.write(b"Connection: keep-alive\r\n")?;
        } else if !self.persistent && self.ver_min >= 1 {
            writer.write(b"Connection: close\r\n")?;
        }

        if chunked {
            writer.write(b"Connection: Transfer-Encoding\r\n")?;
        }

        // Content-Length header

        if let BodySize::Known(x) = body_size {
            write!(writer, "Content-Length: {}\r\n", x)?;
        }

        // Transfer-Encoding header

        if chunked {
            writer.write(b"Transfer-Encoding: chunked\r\n")?;
        }

        writer.write(b"\r\n")?;

        self.state = ServerState::SendingBody;
        self.body_size = body_size;
        self.chunked = chunked;

        Ok(())
    }

    pub fn send_body(
        &mut self,
        writer: &mut dyn Write,
        src: &[&[u8]],
        end: bool,
        headers: Option<&[u8]>,
    ) -> Result<usize, ServerError> {
        assert_eq!(self.state, ServerState::SendingBody);

        let mut src_len = 0;
        for buf in src.iter() {
            src_len += buf.len();
        }

        if let BodySize::NoBody = self.body_size {
            // ignore the data

            if end {
                self.state = ServerState::Finished;
            }

            return Ok(src_len);
        }

        if !self.chunked {
            let size = write_vectored_offset(writer, src, 0)?;

            if end && size >= src_len {
                self.state = ServerState::Finished;
            }

            return Ok(size);
        }

        // chunked

        let mut content_written = 0;

        if src_len > 0 {
            content_written = write_chunk(
                src,
                CHUNK_FOOTER,
                writer,
                &mut self.sending_chunk,
                CHUNK_SIZE_MAX,
            )?;
        }

        // if all content is written then we can send the closing chunk
        if end && content_written >= src_len {
            let footer = if let Some(headers) = headers {
                headers
            } else {
                CHUNK_FOOTER
            };

            write_chunk(
                &[b""],
                footer,
                writer,
                &mut self.sending_chunk,
                CHUNK_SIZE_MAX,
            )?;

            if self.sending_chunk.is_none() {
                self.state = ServerState::Finished;
            }
        }

        Ok(content_written)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::mem;

    struct MyBuffer {
        data: Vec<u8>,
        max: usize,
        allow_partial: bool,
    }

    impl MyBuffer {
        fn new(cap: usize, allow_partial: bool) -> Self {
            Self {
                data: Vec::new(),
                max: cap,
                allow_partial,
            }
        }
    }

    impl Write for MyBuffer {
        fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
            let size = cmp::min(buf.len(), self.max - self.data.len());

            if (size == 0 && !buf.is_empty()) || (size < buf.len() && !self.allow_partial) {
                return Err(io::Error::from(io::ErrorKind::WriteZero));
            }

            self.data.extend_from_slice(&buf[..size]);

            Ok(size)
        }

        fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result<usize, io::Error> {
            let mut total = 0;

            for buf in bufs {
                let size = match self.write(buf.as_ref()) {
                    Ok(size) => size,
                    Err(e) => {
                        if e.kind() == io::ErrorKind::WriteZero && total > 0 {
                            return Ok(total);
                        }

                        return Err(e);
                    }
                };

                total += size;

                if size < buf.len() {
                    break;
                }
            }

            Ok(total)
        }

        fn flush(&mut self) -> Result<(), io::Error> {
            Ok(())
        }
    }

    struct TestRequest {
        pub method: String,
        pub uri: String,
        pub headers: Vec<(String, Vec<u8>)>,
        pub body: Vec<u8>,
        pub trailing_headers: Vec<(String, Vec<u8>)>,
        pub persistent: bool,
    }

    impl TestRequest {
        fn new() -> Self {
            Self {
                method: String::new(),
                uri: String::new(),
                headers: Vec::new(),
                body: Vec::new(),
                trailing_headers: Vec::new(),
                persistent: false,
            }
        }
    }

    struct TestResponse {
        pub code: u32,
        pub reason: String,
        pub headers: Vec<(String, Vec<u8>)>,
        pub body: Vec<u8>,
        pub chunked: bool,
        pub trailing_headers: Vec<(String, Vec<u8>)>,
    }

    impl TestResponse {
        fn new() -> Self {
            Self {
                code: 0,
                reason: String::new(),
                headers: Vec::new(),
                body: Vec::new(),
                chunked: false,
                trailing_headers: Vec::new(),
            }
        }
    }

    fn read_req(p: &mut ServerProtocol, src: &[u8], read_size: usize) -> TestRequest {
        const READ_SIZE_MAX: usize = 1024;
        const LOOPS_MAX: u32 = 20;

        assert!(read_size <= READ_SIZE_MAX);

        assert_eq!(p.state(), ServerState::ReceivingRequest);

        let mut rbuf = io::Cursor::new(src);

        let mut result = TestRequest::new();

        assert_eq!(p.state(), ServerState::ReceivingRequest);

        let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX];

        let req = p.recv_request(&mut rbuf, &mut headers).unwrap().unwrap();

        result.method = String::from(req.method);
        result.uri = String::from(req.uri);

        for h in req.headers {
            let name = String::from(h.name);
            let value = Vec::from(h.value);
            result.headers.push((name, value));
        }

        for _ in 0..LOOPS_MAX {
            if p.state() != ServerState::ReceivingBody {
                break;
            }

            let mut buf = [0; READ_SIZE_MAX];

            let (size, trailing_headers) = p
                .recv_body(&mut rbuf, &mut buf[..read_size], &mut headers)
                .unwrap();

            result.body.extend_from_slice(&buf[..size]);

            if let Some(trailing_headers) = trailing_headers {
                for h in trailing_headers {
                    let name = String::from(h.name);
                    let value = Vec::from(h.value);
                    result.trailing_headers.push((name, value));
                }
            }
        }

        result.persistent = p.is_persistent();

        assert_eq!(p.state(), ServerState::AwaitingResponse);

        result
    }

    fn write_resp(p: &mut ServerProtocol, resp: TestResponse, write_size: usize) -> Vec<u8> {
        const WRITE_SIZE_MAX: usize = 1024;
        const LOOPS_MAX: u32 = 20;

        assert!(write_size <= WRITE_SIZE_MAX);

        assert_eq!(p.state(), ServerState::AwaitingResponse);

        let mut header_out = [0; 1024];
        let mut wbuf = io::Cursor::new(&mut header_out[..]);

        let mut headers = Vec::new();
        for h in resp.headers.iter() {
            headers.push(Header {
                name: &h.0,
                value: &h.1,
            });
        }

        let body_size = if resp.chunked {
            BodySize::Unknown
        } else {
            BodySize::Known(resp.body.len())
        };

        p.send_response(&mut wbuf, resp.code, &resp.reason, &headers, body_size)
            .unwrap();

        let size = wbuf.position() as usize;
        let header_out = &header_out[..size];

        let mut body_out = MyBuffer::new(0, true);

        let mut sent = 0;

        for _ in 0..LOOPS_MAX {
            if p.state() != ServerState::SendingBody {
                break;
            }

            body_out.max += write_size;

            let size = match p.send_body(&mut body_out, &[&resp.body[sent..]], true, None) {
                Ok(size) => size,
                Err(ServerError::Io(e)) if e.kind() == io::ErrorKind::WriteZero => 0,
                Err(_) => panic!("send_body failed"),
            };

            sent += size;
        }

        assert_eq!(p.state(), ServerState::Finished);

        let mut out = Vec::new();
        out.extend_from_slice(header_out);
        out.append(&mut body_out.data);

        out
    }

    #[test]
    fn test_parse_as_int() {
        // invalid utf8
        assert!(parse_as_int(b"\xa0\xa1").is_err());

        // not an integer
        assert!(parse_as_int(b"bogus").is_err());

        // not a non-negative integer
        assert!(parse_as_int(b"-123").is_err());

        // success
        assert_eq!(parse_as_int(b"0").unwrap(), 0);
        assert_eq!(parse_as_int(b"123").unwrap(), 123);
    }

    #[test]
    fn test_header_contains_param() {
        // param invalid utf8
        assert_eq!(header_contains_param(b"", b"\xa0\xa1", false), false);

        // skip invalid utf8 part
        assert_eq!(header_contains_param(b"\xa0\xa1,a", b"a", false), true);

        // not found
        assert_eq!(header_contains_param(b"", b"a", false), false);
        assert_eq!(header_contains_param(b"a", b"b", false), false);
        assert_eq!(header_contains_param(b"a,b", b"c", false), false);

        // success
        assert_eq!(header_contains_param(b"a", b"a", false), true);
        assert_eq!(header_contains_param(b"a,b", b"a", false), true);
        assert_eq!(header_contains_param(b"a,b", b"b", false), true);
        assert_eq!(header_contains_param(b" a ,b", b"a", false), true);
        assert_eq!(header_contains_param(b"a, b ", b"b", false), true);
        assert_eq!(header_contains_param(b"A", b"a", true), true);
    }

    #[test]
    fn test_write_chunk() {
        struct Test {
            name: &'static str,
            write_space: usize,
            data: &'static [&'static [u8]],
            footer: &'static str,
            chunk: Option<Chunk>,
            max_size: usize,
            result: Result<usize, io::Error>,
            chunk_after: Option<Chunk>,
            written: &'static str,
        }

        let tests = [
            Test {
                name: "new-partial",
                write_space: 2,
                data: &[b"hello"],
                footer: "\r\n",
                chunk: None,
                max_size: CHUNK_SIZE_MAX,
                result: Ok(0),
                chunk_after: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 2,
                }),
                written: "5\r",
            },
            Test {
                name: "resume-partial",
                write_space: 2,
                data: &[b"hello"],
                footer: "\r\n",
                chunk: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 2,
                }),
                max_size: CHUNK_SIZE_MAX,
                result: Ok(0),
                chunk_after: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 4,
                }),
                written: "\nh",
            },
            Test {
                name: "error",
                write_space: 0,
                data: &[b"hello"],
                footer: "\r\n",
                chunk: None,
                max_size: CHUNK_SIZE_MAX,
                result: Err(io::Error::from(io::ErrorKind::WriteZero)),
                chunk_after: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 0,
                }),
                written: "",
            },
            Test {
                name: "complete",
                write_space: 1024,
                data: &[b"hello"],
                footer: "\r\n",
                chunk: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 4,
                }),
                max_size: CHUNK_SIZE_MAX,
                result: Ok(5),
                chunk_after: None,
                written: "ello\r\n",
            },
            Test {
                name: "partial-content",
                write_space: 1024,
                data: &[b"hel", b"lo world"],
                footer: "\r\n",
                chunk: Some(Chunk {
                    header: [b'7', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 7,
                    sent: 0,
                }),
                max_size: 7,
                result: Ok(7),
                chunk_after: None,
                written: "7\r\nhello w\r\n",
            },
        ];

        for test in tests.iter() {
            let mut w = MyBuffer::new(test.write_space, true);
            let mut chunk = test.chunk.clone();

            let r = write_chunk(
                test.data,
                test.footer.as_bytes(),
                &mut w,
                &mut chunk,
                test.max_size,
            );

            match r {
                Ok(size) => {
                    let expected = match &test.result {
                        Ok(size) => size,
                        _ => panic!("result mismatch: test={}", test.name),
                    };

                    assert_eq!(size, *expected, "test={}", test.name);
                }
                Err(_) => {
                    assert!(test.result.is_err(), "test={}", test.name);
                }
            }

            assert_eq!(chunk, test.chunk_after, "test={}", test.name);
            assert_eq!(
                str::from_utf8(&w.data).unwrap(),
                test.written,
                "test={}",
                test.name
            );
        }
    }

    #[test]
    fn test_write_headers() {
        struct Test<'buf, 'headers> {
            name: &'static str,
            write_space: usize,
            headers: &'headers [Header<'buf>],
            err: bool,
            written: &'static str,
        }

        let tests = [
            Test {
                name: "cant-write-header-name",
                write_space: 2,
                headers: &[
                    Header {
                        name: "A",
                        value: b"a",
                    },
                    Header {
                        name: "B",
                        value: b"b",
                    },
                ],
                err: true,
                written: "A",
            },
            Test {
                name: "cant-write-header-value",
                write_space: 3,
                headers: &[
                    Header {
                        name: "A",
                        value: b"a",
                    },
                    Header {
                        name: "B",
                        value: b"b",
                    },
                ],
                err: true,
                written: "A: ",
            },
            Test {
                name: "cant-write-header-eol",
                write_space: 4,
                headers: &[
                    Header {
                        name: "A",
                        value: b"a",
                    },
                    Header {
                        name: "B",
                        value: b"b",
                    },
                ],
                err: true,
                written: "A: a",
            },
            Test {
                name: "cant-write-eol",
                write_space: 13,
                headers: &[
                    Header {
                        name: "A",
                        value: b"a",
                    },
                    Header {
                        name: "B",
                        value: b"b",
                    },
                ],
                err: true,
                written: "A: a\r\nB: b\r\n",
            },
            Test {
                name: "success",
                write_space: 1024,
                headers: &[
                    Header {
                        name: "A",
                        value: b"a",
                    },
                    Header {
                        name: "B",
                        value: b"b",
                    },
                ],
                err: false,
                written: "A: a\r\nB: b\r\n\r\n",
            },
        ];

        for test in tests.iter() {
            let mut w = MyBuffer::new(test.write_space, false);

            let r = write_headers(&mut w, test.headers);

            assert_eq!(r.is_err(), test.err, "test={}", test.name);
            assert_eq!(
                str::from_utf8(&w.data).unwrap(),
                test.written,
                "test={}",
                test.name
            );
        }
    }

    #[test]
    fn test_recv_request() {
        struct Test<'buf, 'headers> {
            name: &'static str,
            data: &'buf str,
            result: Option<Result<Request<'buf, 'headers>, ServerError>>,
            state: ServerState,
            ver_min: u8,
            chunk_left: Option<usize>,
            persistent: bool,
            rbuf_position: u64,
        }

        let tests = [
            Test {
                name: "partial",
                data: "G",
                result: None,
                state: ServerState::ReceivingRequest,
                ver_min: 0,
                chunk_left: None,
                persistent: false,
                rbuf_position: 0,
            },
            Test {
                name: "parse-error",
                data: "G\n",
                result: Some(Err(ServerError::ParseError(httparse::Error::Token))),
                state: ServerState::ReceivingRequest,
                ver_min: 0,
                chunk_left: None,
                persistent: false,
                rbuf_position: 0,
            },
            Test {
                name: "invalid-content-length",
                data: "GET / HTTP/1.0\r\nContent-Length: a\r\n\r\n",
                result: Some(Err(ServerError::InvalidContentLength)),
                state: ServerState::ReceivingRequest,
                ver_min: 0,
                chunk_left: None,
                persistent: false,
                rbuf_position: 0,
            },
            Test {
                name: "unsupported-transfer-encoding",
                data: "GET / HTTP/1.0\r\nTransfer-Encoding: bogus\r\n\r\n",
                result: Some(Err(ServerError::UnsupportedTransferEncoding)),
                state: ServerState::ReceivingRequest,
                ver_min: 0,
                chunk_left: None,
                persistent: false,
                rbuf_position: 0,
            },
            Test {
                name: "no-body",
                data: "GET / HTTP/1.0\r\nFoo: Bar\r\n\r\n",
                result: Some(Ok(Request {
                    method: "GET",
                    uri: "/",
                    headers: &[httparse::Header {
                        name: "Foo",
                        value: b"Bar",
                    }],
                    body_size: BodySize::NoBody,
                    expect_100: false,
                })),
                state: ServerState::AwaitingResponse,
                ver_min: 0,
                chunk_left: None,
                persistent: false,
                rbuf_position: 28,
            },
            Test {
                name: "body-size-known",
                data: "GET / HTTP/1.0\r\nContent-Length: 42\r\n\r\n",
                result: Some(Ok(Request {
                    method: "GET",
                    uri: "/",
                    headers: &[httparse::Header {
                        name: "Content-Length",
                        value: b"42",
                    }],
                    body_size: BodySize::Known(42),
                    expect_100: false,
                })),
                state: ServerState::ReceivingBody,
                ver_min: 0,
                chunk_left: Some(42),
                persistent: false,
                rbuf_position: 38,
            },
            Test {
                name: "body-size-unknown",
                data: "GET / HTTP/1.0\r\nTransfer-Encoding: chunked\r\n\r\n",
                result: Some(Ok(Request {
                    method: "GET",
                    uri: "/",
                    headers: &[httparse::Header {
                        name: "Transfer-Encoding",
                        value: b"chunked",
                    }],
                    body_size: BodySize::Unknown,
                    expect_100: false,
                })),
                state: ServerState::ReceivingBody,
                ver_min: 0,
                chunk_left: None,
                persistent: false,
                rbuf_position: 46,
            },
            Test {
                name: "1.0-persistent",
                data: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
                result: Some(Ok(Request {
                    method: "GET",
                    uri: "/",
                    headers: &[httparse::Header {
                        name: "Connection",
                        value: b"keep-alive",
                    }],
                    body_size: BodySize::NoBody,
                    expect_100: false,
                })),
                state: ServerState::AwaitingResponse,
                ver_min: 0,
                chunk_left: None,
                persistent: true,
                rbuf_position: 42,
            },
            Test {
                name: "1.1-persistent",
                data: "GET / HTTP/1.1\r\n\r\n",
                result: Some(Ok(Request {
                    method: "GET",
                    uri: "/",
                    headers: &[],
                    body_size: BodySize::NoBody,
                    expect_100: false,
                })),
                state: ServerState::AwaitingResponse,
                ver_min: 1,
                chunk_left: None,
                persistent: true,
                rbuf_position: 18,
            },
            Test {
                name: "1.1-non-persistent",
                data: "GET / HTTP/1.1\r\nConnection: close\r\n\r\n",
                result: Some(Ok(Request {
                    method: "GET",
                    uri: "/",
                    headers: &[httparse::Header {
                        name: "Connection",
                        value: b"close",
                    }],
                    body_size: BodySize::NoBody,
                    expect_100: false,
                })),
                state: ServerState::AwaitingResponse,
                ver_min: 1,
                chunk_left: None,
                persistent: false,
                rbuf_position: 37,
            },
            Test {
                name: "expect-100",
                data: "POST / HTTP/1.1\r\nContent-Length: 10\r\nExpect: 100-continue\r\n\r\n",
                result: Some(Ok(Request {
                    method: "POST",
                    uri: "/",
                    headers: &[
                        httparse::Header {
                            name: "Content-Length",
                            value: b"10",
                        },
                        httparse::Header {
                            name: "Expect",
                            value: b"100-continue",
                        },
                    ],
                    body_size: BodySize::Known(10),
                    expect_100: true,
                })),
                state: ServerState::ReceivingBody,
                ver_min: 1,
                chunk_left: Some(10),
                persistent: true,
                rbuf_position: 61,
            },
        ];

        for test in tests.iter() {
            let mut p = ServerProtocol::new();
            let mut c = io::Cursor::new(test.data.as_bytes());
            let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX];

            let r = p.recv_request(&mut c, &mut headers);

            match r {
                None => {
                    assert!(test.result.is_none(), "test={}", test.name);
                }
                Some(Ok(req)) => {
                    let expected = match &test.result {
                        Some(Ok(req)) => req,
                        _ => panic!("result mismatch: test={}", test.name),
                    };

                    assert_eq!(req, *expected, "test={}", test.name);
                }
                Some(Err(e)) => {
                    let expected = match &test.result {
                        Some(Err(e)) => e,
                        _ => panic!("result mismatch: test={}", test.name),
                    };

                    assert_eq!(
                        mem::discriminant(&e),
                        mem::discriminant(expected),
                        "test={}",
                        test.name
                    );
                }
            }

            assert_eq!(p.state(), test.state, "test={}", test.name);
            assert_eq!(p.ver_min, test.ver_min, "test={}", test.name);
            assert_eq!(p.chunk_left, test.chunk_left, "test={}", test.name);
            assert_eq!(p.is_persistent(), test.persistent, "test={}", test.name);
            assert_eq!(c.position(), test.rbuf_position, "test={}", test.name);
        }
    }

    #[test]
    fn test_recv_body() {
        struct Test<'buf, 'headers> {
            name: &'static str,
            data: &'buf str,
            body_size: BodySize,
            chunk_left: Option<usize>,
            chunk_size: usize,
            result: Result<(usize, Option<&'headers [httparse::Header<'buf>]>), ServerError>,
            state: ServerState,
            chunk_left_after: Option<usize>,
            chunk_size_after: usize,
            rbuf_position: u64,
            dest_data: &'static str,
        }

        let tests = [
            Test {
                name: "partial",
                data: "hel",
                body_size: BodySize::Known(5),
                chunk_left: Some(5),
                chunk_size: 0,
                result: Ok((3, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(2),
                chunk_size_after: 0,
                rbuf_position: 3,
                dest_data: "hel",
            },
            Test {
                name: "complete",
                data: "hello",
                body_size: BodySize::Known(5),
                chunk_left: Some(5),
                chunk_size: 0,
                result: Ok((5, None)),
                state: ServerState::AwaitingResponse,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 5,
                dest_data: "hello",
            },
            Test {
                name: "chunked-header-partial",
                data: "5",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((0, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 0,
                dest_data: "",
            },
            Test {
                name: "chunked-header-parse-error",
                data: "z",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Err(ServerError::InvalidChunkSize),
                state: ServerState::ReceivingBody,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 0,
                dest_data: "",
            },
            Test {
                name: "chunked-too-large",
                data: "ffffffffff\r\n",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Err(ServerError::ChunkTooLarge),
                state: ServerState::ReceivingBody,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 0,
                dest_data: "",
            },
            Test {
                name: "chunked-header-ok",
                data: "5\r\n",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((0, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(5),
                chunk_size_after: 5,
                rbuf_position: 3,
                dest_data: "",
            },
            Test {
                name: "chunked-content-partial",
                data: "5\r\nhel",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((3, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(2),
                chunk_size_after: 5,
                rbuf_position: 6,
                dest_data: "hel",
            },
            Test {
                name: "chunked-footer-partial-full-none",
                data: "5\r\nhello",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((5, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(0),
                chunk_size_after: 5,
                rbuf_position: 8,
                dest_data: "hello",
            },
            Test {
                name: "chunked-footer-partial-full-r",
                data: "5\r\nhello\r",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((5, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(0),
                chunk_size_after: 5,
                rbuf_position: 8,
                dest_data: "hello",
            },
            Test {
                name: "chunked-footer-partial-mid-r",
                data: "\r",
                body_size: BodySize::Unknown,
                chunk_left: Some(0),
                chunk_size: 5,
                result: Ok((0, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(0),
                chunk_size_after: 5,
                rbuf_position: 0,
                dest_data: "",
            },
            Test {
                name: "chunked-footer-parse-error",
                data: "5\r\nhelloXX",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Err(ServerError::InvalidChunkSuffix),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(0),
                chunk_size_after: 5,
                rbuf_position: 8,
                dest_data: "",
            },
            Test {
                name: "chunked-complete-full",
                data: "5\r\nhello\r\n",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((5, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 10,
                dest_data: "hello",
            },
            Test {
                name: "chunked-complete-mid",
                data: "lo\r\n",
                body_size: BodySize::Unknown,
                chunk_left: Some(2),
                chunk_size: 5,
                result: Ok((2, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 4,
                dest_data: "lo",
            },
            Test {
                name: "chunked-complete-end",
                data: "\r\n",
                body_size: BodySize::Unknown,
                chunk_left: Some(0),
                chunk_size: 5,
                result: Ok((0, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 2,
                dest_data: "",
            },
            Test {
                name: "chunked-empty",
                data: "0\r\n\r\n",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((0, Some(&[]))),
                state: ServerState::AwaitingResponse,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 5,
                dest_data: "",
            },
            Test {
                name: "trailing-headers-partial",
                data: "0\r\nhelloXX",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((0, None)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(0),
                chunk_size_after: 0,
                rbuf_position: 3,
                dest_data: "",
            },
            Test {
                name: "trailing-headers-parse-error",
                data: "0\r\nhelloXX\n",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Err(ServerError::ParseError(httparse::Error::Token)),
                state: ServerState::ReceivingBody,
                chunk_left_after: Some(0),
                chunk_size_after: 0,
                rbuf_position: 3,
                dest_data: "",
            },
            Test {
                name: "trailing-headers-complete",
                data: "0\r\nFoo: Bar\r\n\r\n",
                body_size: BodySize::Unknown,
                chunk_left: None,
                chunk_size: 0,
                result: Ok((
                    0,
                    Some(&[httparse::Header {
                        name: "Foo",
                        value: b"Bar",
                    }]),
                )),
                state: ServerState::AwaitingResponse,
                chunk_left_after: None,
                chunk_size_after: 0,
                rbuf_position: 15,
                dest_data: "",
            },
        ];

        for test in tests.iter() {
            let mut p = ServerProtocol {
                state: ServerState::ReceivingBody,
                ver_min: 0,
                body_size: test.body_size,
                chunk_left: test.chunk_left,
                chunk_size: test.chunk_size,
                persistent: false,
                chunked: test.body_size == BodySize::Unknown,
                sending_chunk: None,
            };

            let mut c = io::Cursor::new(test.data.as_bytes());
            let mut dest = [0; 1024];
            let mut dest_size = 0;
            let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX];

            let r = p.recv_body(&mut c, &mut dest, &mut headers);

            match r {
                Ok((size, headers)) => {
                    let (expected_size, expected_headers) = match &test.result {
                        Ok((size, headers)) => (size, headers),
                        _ => panic!("result mismatch: test={}", test.name),
                    };

                    assert_eq!(size, *expected_size, "test={}", test.name);
                    assert_eq!(headers, *expected_headers, "test={}", test.name);

                    dest_size = size;
                }
                Err(e) => {
                    let expected = match &test.result {
                        Err(e) => e,
                        _ => panic!("result mismatch: test={}", test.name),
                    };

                    assert_eq!(
                        mem::discriminant(&e),
                        mem::discriminant(expected),
                        "test={}",
                        test.name
                    );
                }
            }

            assert_eq!(p.state(), test.state, "test={}", test.name);
            assert_eq!(p.chunk_left, test.chunk_left_after, "test={}", test.name);
            assert_eq!(p.chunk_size, test.chunk_size_after, "test={}", test.name);
            assert_eq!(c.position(), test.rbuf_position, "test={}", test.name);
            assert_eq!(
                str::from_utf8(&dest[..dest_size]).unwrap(),
                test.dest_data,
                "test={}",
                test.name
            );
        }
    }

    #[test]
    fn test_send_response() {
        struct Test<'buf, 'headers> {
            name: &'static str,
            write_space: usize,
            code: u32,
            reason: &'static str,
            headers: &'headers [Header<'buf>],
            body_size: BodySize,
            ver_min: u8,
            persistent: bool,
            result: Result<(), ServerError>,
            state: ServerState,
            body_size_after: BodySize,
            chunked: bool,
            written: &'static str,
        }

        let tests = [
            Test {
                name: "cant-write-1.1",
                write_space: 5,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(0),
                ver_min: 1,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "",
            },
            Test {
                name: "cant-write-1.0",
                write_space: 5,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "",
            },
            Test {
                name: "cant-write-status-line",
                write_space: 12,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200",
            },
            Test {
                name: "cant-write-header-name",
                write_space: 20,
                code: 200,
                reason: "OK",
                headers: &[
                    Header { name: "Foo", value: b"Bar" },
                ],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\nFoo",
            },
            Test {
                name: "cant-write-header-value",
                write_space: 24,
                code: 200,
                reason: "OK",
                headers: &[
                    Header { name: "Foo", value: b"Bar" },
                ],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\nFoo: ",
            },
            Test {
                name: "cant-write-header-eol",
                write_space: 26,
                code: 200,
                reason: "OK",
                headers: &[
                    Header { name: "Foo", value: b"Bar" },
                ],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\nFoo: Bar",
            },
            Test {
                name: "cant-write-keep-alive",
                write_space: 26,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: true,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\n",
            },
            Test {
                name: "cant-write-close",
                write_space: 26,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(0),
                ver_min: 1,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.1 200 OK\r\n",
            },
            Test {
                name: "cant-write-transfer-encoding",
                write_space: 26,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Unknown,
                ver_min: 1,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.1 200 OK\r\n",
            },
            Test {
                name: "cant-write-content-length",
                write_space: 26,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(0),
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\n",
            },
            Test {
                name: "cant-write-te-chunked",
                write_space: 50,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Unknown,
                ver_min: 1,
                persistent: true,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.1 200 OK\r\nConnection: Transfer-Encoding\r\n",
            },
            Test {
                name: "cant-write-eol",
                write_space: 18,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Unknown,
                ver_min: 0,
                persistent: false,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::AwaitingResponse,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\n",
            },
            Test {
                name: "exclude-headers",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[
                    Header { name: "Connection", value: b"X" },
                    Header { name: "Foo", value: b"Bar" },
                    Header { name: "Content-Length", value: b"X" },
                    Header { name: "Transfer-Encoding", value: b"X" },
                ],
                body_size: BodySize::Unknown,
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::Unknown,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\nFoo: Bar\r\n\r\n",
            },
            Test {
                name: "exclude-headers-101",
                write_space: 1024,
                code: 101,
                reason: "Switching Protocols",
                headers: &[
                    Header { name: "Connection", value: b"X" },
                    Header { name: "Foo", value: b"Bar" },
                    Header { name: "Content-Length", value: b"X" },
                    Header { name: "Transfer-Encoding", value: b"X" },
                ],
                body_size: BodySize::NoBody,
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 101 Switching Protocols\r\nConnection: X\r\nFoo: Bar\r\n\r\n",
            },
            Test {
                name: "1.0-no-body",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::NoBody,
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\n\r\n",
            },
            Test {
                name: "1.0-len",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(42),
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::Known(42),
                chunked: false,
                written: "HTTP/1.0 200 OK\r\nContent-Length: 42\r\n\r\n",
            },
            Test {
                name: "1.0-no-len",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Unknown,
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::Unknown,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\n\r\n",
            },
            Test {
                name: "1.1-no-body",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::NoBody,
                ver_min: 1,
                persistent: true,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.1 200 OK\r\n\r\n",
            },
            Test {
                name: "1.1-len",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Known(42),
                ver_min: 1,
                persistent: true,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::Known(42),
                chunked: false,
                written: "HTTP/1.1 200 OK\r\nContent-Length: 42\r\n\r\n",
            },
            Test {
                name: "1.1-no-len",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::Unknown,
                ver_min: 1,
                persistent: true,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::Unknown,
                chunked: true,
                written: "HTTP/1.1 200 OK\r\nConnection: Transfer-Encoding\r\nTransfer-Encoding: chunked\r\n\r\n",
            },
            Test {
                name: "1.0-persistent",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::NoBody,
                ver_min: 0,
                persistent: true,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\nConnection: keep-alive\r\n\r\n",
            },
            Test {
                name: "1.0-non-persistent",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::NoBody,
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 200 OK\r\n\r\n",
            },
            Test {
                name: "1.1-persistent",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::NoBody,
                ver_min: 1,
                persistent: true,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.1 200 OK\r\n\r\n",
            },
            Test {
                name: "1.1-non-persistent",
                write_space: 1024,
                code: 200,
                reason: "OK",
                headers: &[],
                body_size: BodySize::NoBody,
                ver_min: 1,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
            },
            Test {
                name: "force-no-body",
                write_space: 1024,
                code: 101,
                reason: "Switching Protocols",
                headers: &[],
                body_size: BodySize::Known(42),
                ver_min: 0,
                persistent: false,
                result: Ok(()),
                state: ServerState::SendingBody,
                body_size_after: BodySize::NoBody,
                chunked: false,
                written: "HTTP/1.0 101 Switching Protocols\r\n\r\n",
            },
        ];

        for test in tests.iter() {
            let mut p = ServerProtocol {
                state: ServerState::AwaitingResponse,
                ver_min: test.ver_min,
                body_size: BodySize::NoBody,
                chunk_left: None,
                chunk_size: 0,
                persistent: test.persistent,
                chunked: false,
                sending_chunk: None,
            };

            let mut w = MyBuffer::new(test.write_space, false);

            let r = p.send_response(&mut w, test.code, test.reason, test.headers, test.body_size);

            match r {
                Ok(_) => {
                    match &test.result {
                        Ok(_) => {}
                        _ => panic!("result mismatch: test={}", test.name),
                    };
                }
                Err(e) => {
                    let expected = match &test.result {
                        Err(e) => e,
                        _ => panic!("result mismatch: test={}", test.name),
                    };

                    assert_eq!(
                        mem::discriminant(&e),
                        mem::discriminant(expected),
                        "test={}",
                        test.name
                    );
                }
            }

            assert_eq!(p.state(), test.state, "test={}", test.name);
            assert_eq!(p.body_size, test.body_size_after, "test={}", test.name);
            assert_eq!(p.chunked, test.chunked, "test={}", test.name);
            assert_eq!(
                str::from_utf8(&w.data).unwrap(),
                test.written,
                "test={}",
                test.name
            );
        }
    }

    #[test]
    fn test_send_body() {
        struct Test {
            name: &'static str,
            write_space: usize,
            src: &'static str,
            end: bool,
            headers: Option<&'static [u8]>,
            body_size: BodySize,
            chunked: bool,
            sending_chunk: Option<Chunk>,
            result: Result<usize, ServerError>,
            state: ServerState,
            sending_chunk_after: Option<Chunk>,
            written: &'static str,
        }

        let tests = [
            Test {
                name: "no-body",
                write_space: 1024,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::NoBody,
                chunked: false,
                sending_chunk: None,
                result: Ok(5),
                state: ServerState::SendingBody,
                sending_chunk_after: None,
                written: "",
            },
            Test {
                name: "no-body-end",
                write_space: 1024,
                src: "",
                end: true,
                headers: None,
                body_size: BodySize::NoBody,
                chunked: false,
                sending_chunk: None,
                result: Ok(0),
                state: ServerState::Finished,
                sending_chunk_after: None,
                written: "",
            },
            Test {
                name: "non-chunked-partial",
                write_space: 3,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: false,
                sending_chunk: None,
                result: Ok(3),
                state: ServerState::SendingBody,
                sending_chunk_after: None,
                written: "hel",
            },
            Test {
                name: "non-chunked-error",
                write_space: 0,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: false,
                sending_chunk: None,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::SendingBody,
                sending_chunk_after: None,
                written: "",
            },
            Test {
                name: "non-chunked",
                write_space: 1024,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: false,
                sending_chunk: None,
                result: Ok(5),
                state: ServerState::SendingBody,
                sending_chunk_after: None,
                written: "hello",
            },
            Test {
                name: "non-chunked-end",
                write_space: 1024,
                src: "",
                end: true,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: false,
                sending_chunk: None,
                result: Ok(0),
                state: ServerState::Finished,
                sending_chunk_after: None,
                written: "",
            },
            Test {
                name: "chunked-partial",
                write_space: 2,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Ok(0),
                state: ServerState::SendingBody,
                sending_chunk_after: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 2,
                }),
                written: "5\r",
            },
            Test {
                name: "chunked-error",
                write_space: 0,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::SendingBody,
                sending_chunk_after: Some(Chunk {
                    header: [b'5', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 5,
                    sent: 0,
                }),
                written: "",
            },
            Test {
                name: "chunked-complete",
                write_space: 1024,
                src: "hello",
                end: false,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Ok(5),
                state: ServerState::SendingBody,
                sending_chunk_after: None,
                written: "5\r\nhello\r\n",
            },
            Test {
                name: "end-partial",
                write_space: 2,
                src: "",
                end: true,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Ok(0),
                state: ServerState::SendingBody,
                sending_chunk_after: Some(Chunk {
                    header: [b'0', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 0,
                    sent: 2,
                }),
                written: "0\r",
            },
            Test {
                name: "end-error",
                write_space: 0,
                src: "",
                end: true,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Err(ServerError::Io(io::Error::from(io::ErrorKind::WriteZero))),
                state: ServerState::SendingBody,
                sending_chunk_after: Some(Chunk {
                    header: [b'0', b'\r', b'\n', 0, 0, 0],
                    header_len: 3,
                    size: 0,
                    sent: 0,
                }),
                written: "",
            },
            Test {
                name: "end-complete",
                write_space: 1024,
                src: "",
                end: true,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Ok(0),
                state: ServerState::Finished,
                sending_chunk_after: None,
                written: "0\r\n\r\n",
            },
            Test {
                name: "end-headers",
                write_space: 1024,
                src: "",
                end: true,
                headers: Some(b"Foo: Bar\r\n\r\n"),
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Ok(0),
                state: ServerState::Finished,
                sending_chunk_after: None,
                written: "0\r\nFoo: Bar\r\n\r\n",
            },
            Test {
                name: "content-and-end",
                write_space: 1024,
                src: "hello",
                end: true,
                headers: None,
                body_size: BodySize::Unknown,
                chunked: true,
                sending_chunk: None,
                result: Ok(5),
                state: ServerState::Finished,
                sending_chunk_after: None,
                written: "5\r\nhello\r\n0\r\n\r\n",
            },
        ];

        for test in tests.iter() {
            let mut p = ServerProtocol {
                state: ServerState::SendingBody,
ver_min: 0, body_size: test.body_size, chunk_left: None, chunk_size: 0, persistent: false, chunked: test.chunked, sending_chunk: test.sending_chunk, }; let mut w = MyBuffer::new(test.write_space, true); let r = p.send_body(&mut w, &[test.src.as_bytes()], test.end, test.headers); match r { Ok(size) => { let expected_size = match &test.result { Ok(size) => size, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(size, *expected_size, "test={}", test.name); } Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(p.state(), test.state, "test={}", test.name); assert_eq!( p.sending_chunk, test.sending_chunk_after, "test={}", test.name ); assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_req() { let data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n".as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.method, "GET"); assert_eq!(req.uri, "/foo"); assert_eq!(req.headers.len(), 1); assert_eq!(req.headers[0].0, "Host"); assert_eq!(req.headers[0].1, b"example.com"); assert_eq!(req.body.len(), 0); assert_eq!(req.trailing_headers.len(), 0); assert_eq!(req.persistent, true); let data = concat!( "POST /foo HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n" ) .as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.method, "POST"); assert_eq!(req.uri, "/foo"); assert_eq!(req.headers.len(), 2); assert_eq!(req.headers[0].0, "Host"); assert_eq!(req.headers[0].1, b"example.com"); assert_eq!(req.body, b"hello\n"); assert_eq!(req.trailing_headers.len(), 0); assert_eq!(req.persistent, true); let data = concat!( "POST /foo HTTP/1.1\r\n", "Host: example.com\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\nhello\n\r\n", "0\r\n\r\n" ) 
        .as_bytes();
        let mut p = ServerProtocol::new();
        let req = read_req(&mut p, data, 2);
        assert_eq!(req.method, "POST");
        assert_eq!(req.uri, "/foo");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.headers[0].0, "Host");
        assert_eq!(req.headers[0].1, b"example.com");
        assert_eq!(req.body, b"hello\n");
        assert_eq!(req.trailing_headers.len(), 0);
        assert_eq!(req.persistent, true);

        let data = concat!(
            "POST /foo HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Transfer-Encoding: chunked\r\n",
            "\r\n",
            "6\r\nhello\n\r\n",
            "0\r\n",
            "Foo: bar\r\n",
            "\r\n"
        )
        .as_bytes();
        let mut p = ServerProtocol::new();
        let req = read_req(&mut p, data, 2);
        assert_eq!(req.method, "POST");
        assert_eq!(req.uri, "/foo");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.headers[0].0, "Host");
        assert_eq!(req.headers[0].1, b"example.com");
        assert_eq!(req.body, b"hello\n");
        assert_eq!(req.trailing_headers.len(), 1);
        assert_eq!(req.trailing_headers[0].0, "Foo");
        assert_eq!(req.trailing_headers[0].1, b"bar");
        assert_eq!(req.persistent, true);
    }

    #[test]
    fn test_resp() {
        let data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut p = ServerProtocol::new();
        read_req(&mut p, data.as_bytes(), 2);

        let mut resp = TestResponse::new();
        resp.code = 200;
        resp.reason = String::from("OK");
        resp.headers = vec![(String::from("Content-Type"), b"text/plain".to_vec())];
        resp.body = b"hello\n".to_vec();

        let out = write_resp(&mut p, resp, 2);

        let data = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length: 6\r\n",
            "\r\n",
            "hello\n",
        );

        assert_eq!(str::from_utf8(&out).unwrap(), data);

        let data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut p = ServerProtocol::new();
        read_req(&mut p, data.as_bytes(), 2);

        let mut resp = TestResponse::new();
        resp.code = 200;
        resp.reason = String::from("OK");
        resp.headers = vec![(String::from("Content-Type"), b"text/plain".to_vec())];
        resp.body = b"hello\n".to_vec();
        resp.chunked = true;

        let out = write_resp(&mut p, resp, 2);

        let data = concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Type: text/plain\r\n",
            "Connection: Transfer-Encoding\r\n",
            "Transfer-Encoding: chunked\r\n",
            "\r\n",
            "6\r\nhello\n\r\n0\r\n\r\n",
        );

        assert_eq!(str::from_utf8(&out).unwrap(), data);
    }

    #[test]
    fn test_persistent() {
        // http 1.0 without keep alive
        let data = concat!("GET /foo HTTP/1.0\r\n", "Host: example.com\r\n", "\r\n").as_bytes();
        let mut p = ServerProtocol::new();
        let req = read_req(&mut p, data, 2);
        assert_eq!(req.persistent, false);

        // http 1.0 with keep alive
        let data = concat!(
            "GET /foo HTTP/1.0\r\n",
            "Host: example.com\r\n",
            "Connection: keep-alive\r\n",
            "\r\n"
        )
        .as_bytes();
        let mut p = ServerProtocol::new();
        let req = read_req(&mut p, data, 2);
        assert_eq!(req.persistent, true);

        // http 1.1 without keep alive
        let data = concat!(
            "GET /foo HTTP/1.1\r\n",
            "Host: example.com\r\n",
            "Connection: close\r\n",
            "\r\n"
        )
        .as_bytes();
        let mut p = ServerProtocol::new();
        let req = read_req(&mut p, data, 2);
        assert_eq!(req.persistent, false);

        // http 1.1 with keep alive
        let data = concat!("GET /foo HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes();
        let mut p = ServerProtocol::new();
        let req = read_req(&mut p, data, 2);
        assert_eq!(req.persistent, true);
    }
}
condure-1.3.1/src/lib.rs000064400000000000000000000025760000000000000131570ustar 00000000000000/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

pub mod app;
pub mod arena;
pub mod buffer;
pub mod channel;
pub mod connection;
pub mod event;
pub mod executor;
pub mod future;
pub mod http1;
pub mod list;
pub mod listener;
pub mod reactor;
pub mod server;
pub mod shuffle;
pub mod timer;
pub mod tls;
pub mod tnetstring;
pub mod waker;
pub mod websocket;
pub mod zhttppacket;
pub mod zhttpsocket;
pub mod zmq;

use app::Config;
use log::info;
use std::error::Error;

pub fn run(config: &Config) -> Result<(), Box<dyn Error>> {
    info!("starting...");

    {
        let a = match app::App::new(config) {
            Ok(a) => a,
            Err(e) => {
                return Err(e.into());
            }
        };

        info!("started");

        a.wait_for_term();

        info!("stopping...");
    }

    info!("stopped");

    Ok(())
}
condure-1.3.1/src/list.rs000064400000000000000000000205620000000000000133570ustar 00000000000000/*
 * Copyright (C) 2020 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use std::ops::IndexMut;

pub struct Node<T> {
    pub prev: Option<usize>,
    pub next: Option<usize>,
    pub value: T,
}

impl<T> Node<T> {
    pub fn new(value: T) -> Self {
        Self {
            prev: None,
            next: None,
            value: value,
        }
    }
}

#[derive(Default, Clone, Copy)]
pub struct List {
    pub head: Option<usize>,
    pub tail: Option<usize>,
}

impl List {
    pub fn is_empty(&self) -> bool {
        !self.head.is_some()
    }

    pub fn insert<S, T>(&mut self, nodes: &mut S, after: Option<usize>, key: usize)
    where
        S: IndexMut<usize, Output = Node<T>>,
    {
        let next = if let Some(pkey) = after {
            let pn = &mut nodes[pkey];

            let next = pn.next;
            pn.next = Some(key);

            let n = &mut nodes[key];
            n.prev = Some(pkey);

            next
        } else {
            let next = self.head;
            self.head = Some(key);

            let n = &mut nodes[key];
            n.prev = None;

            next
        };

        let n = &mut nodes[key];
        n.next = next;

        if let Some(nkey) = next {
            let nn = &mut nodes[nkey];

            nn.prev = Some(key);
        } else {
            self.tail = Some(key);
        }
    }

    pub fn remove<S, T>(&mut self, nodes: &mut S, key: usize)
    where
        S: IndexMut<usize, Output = Node<T>>,
    {
        let n = &mut nodes[key];

        let prev = n.prev.take();
        let next = n.next.take();

        if let Some(pkey) = prev {
            let pn = &mut nodes[pkey];
            pn.next = next;
        }

        if let Some(nkey) = next {
            let nn = &mut nodes[nkey];
            nn.prev = prev;
        }

        if let Some(hkey) = self.head {
            if hkey == key {
                self.head = next;
            }
        }

        if let Some(tkey) = self.tail {
            if tkey == key {
                self.tail = prev;
            }
        }
    }

    pub fn pop_front<S, T>(&mut self, nodes: &mut S) -> Option<usize>
    where
        S: IndexMut<usize, Output = Node<T>>,
    {
        match self.head {
            Some(key) => {
                self.remove(nodes, key);

                Some(key)
            }
            None => None,
        }
    }

    pub fn push_back<S, T>(&mut self, nodes: &mut S, key: usize)
    where
        S: IndexMut<usize, Output = Node<T>>,
    {
        self.insert(nodes, self.tail, key);
    }

    pub fn concat<S, T>(&mut self, nodes: &mut S, other: &mut Self)
    where
        S: IndexMut<usize, Output = Node<T>>,
    {
        if other.is_empty() {
            // nothing to do
            return;
        }

        // other is non-empty so this is guaranteed to succeed
        let hkey = other.head.unwrap();

        let next = nodes[hkey].next;

        // since we're inserting after the tail, this will set next=None
        self.insert(nodes, self.tail, hkey);

        // restore the node's next key
        nodes[hkey].next = next;

        self.tail = other.tail;
        other.head = None;
        other.tail = None;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use slab::Slab;

    #[test]
    fn test_list_push_pop() {
        let mut nodes = Slab::new();
        let n1 = nodes.insert(Node::new("n1"));
        let n2 = nodes.insert(Node::new("n2"));
        let n3 = nodes.insert(Node::new("n3"));

        // prevent unused warning on data field
        assert_eq!(nodes[n1].value, "n1");

        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, None);
        assert_eq!(nodes[n2].prev, None);
        assert_eq!(nodes[n2].next, None);
        assert_eq!(nodes[n3].prev, None);
        assert_eq!(nodes[n3].next, None);

        let mut l = List::default();
        assert_eq!(l.is_empty(), true);
        assert_eq!(l.head, None);
        assert_eq!(l.tail, None);
        assert_eq!(l.pop_front(&mut nodes), None);

        l.push_back(&mut nodes, n1);
        assert_eq!(l.is_empty(), false);
        assert_eq!(l.head, Some(n1));
        assert_eq!(l.tail, Some(n1));
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, None);

        l.push_back(&mut nodes, n2);
        assert_eq!(l.is_empty(), false);
        assert_eq!(l.head, Some(n1));
        assert_eq!(l.tail, Some(n2));
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, Some(n2));
        assert_eq!(nodes[n2].prev, Some(n1));
        assert_eq!(nodes[n2].next, None);

        l.push_back(&mut nodes, n3);
        assert_eq!(l.is_empty(), false);
        assert_eq!(l.head, Some(n1));
        assert_eq!(l.tail, Some(n3));
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, Some(n2));
        assert_eq!(nodes[n2].prev, Some(n1));
        assert_eq!(nodes[n2].next, Some(n3));
        assert_eq!(nodes[n3].prev, Some(n2));
        assert_eq!(nodes[n3].next, None);

        let key = l.pop_front(&mut nodes);
        assert_eq!(key, Some(n1));
        assert_eq!(l.is_empty(), false);
        assert_eq!(l.head, Some(n2));
        assert_eq!(l.tail, Some(n3));
        assert_eq!(nodes[n2].prev, None);
        assert_eq!(nodes[n2].next, Some(n3));
        assert_eq!(nodes[n3].prev, Some(n2));
        assert_eq!(nodes[n3].next, None);

        let key = l.pop_front(&mut nodes);
        assert_eq!(key, Some(n2));
        assert_eq!(l.is_empty(), false);
        assert_eq!(l.head, Some(n3));
        assert_eq!(l.tail, Some(n3));
        assert_eq!(nodes[n3].prev, None);
        assert_eq!(nodes[n3].next, None);

        let key = l.pop_front(&mut nodes);
        assert_eq!(key, Some(n3));
        assert_eq!(l.is_empty(), true);
        assert_eq!(l.head, None);
        assert_eq!(l.tail, None);

        assert_eq!(l.pop_front(&mut nodes), None);
    }

    #[test]
    fn test_remove() {
        let mut nodes = Slab::new();
        let n1 = nodes.insert(Node::new("n1"));

        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, None);

        let mut l = List::default();
        assert_eq!(l.is_empty(), true);
        assert_eq!(l.head, None);
        assert_eq!(l.tail, None);

        l.push_back(&mut nodes, n1);
        assert_eq!(l.is_empty(), false);
        assert_eq!(l.head, Some(n1));
        assert_eq!(l.tail, Some(n1));
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, None);

        l.remove(&mut nodes, n1);
        assert_eq!(l.is_empty(), true);
        assert_eq!(l.head, None);
        assert_eq!(l.tail, None);
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, None);

        // already removed
        l.remove(&mut nodes, n1);
        assert_eq!(l.is_empty(), true);
        assert_eq!(l.head, None);
        assert_eq!(l.tail, None);
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, None);
    }

    #[test]
    fn test_list_concat() {
        let mut nodes = Slab::new();
        let n1 = nodes.insert(Node::new("n1"));
        let n2 = nodes.insert(Node::new("n2"));

        let mut a = List::default();
        let mut b = List::default();

        a.concat(&mut nodes, &mut b);
        assert_eq!(a.is_empty(), true);
        assert_eq!(a.head, None);
        assert_eq!(a.tail, None);
        assert_eq!(b.is_empty(), true);
        assert_eq!(b.head, None);
        assert_eq!(b.tail, None);

        a.push_back(&mut nodes, n1);
        b.push_back(&mut nodes, n2);

        a.concat(&mut nodes, &mut b);
        assert_eq!(a.is_empty(), false);
        assert_eq!(a.head, Some(n1));
        assert_eq!(a.tail, Some(n2));
        assert_eq!(b.is_empty(), true);
        assert_eq!(b.head, None);
        assert_eq!(b.tail, None);
        assert_eq!(nodes[n1].prev, None);
        assert_eq!(nodes[n1].next, Some(n2));
        assert_eq!(nodes[n2].prev, Some(n1));
        assert_eq!(nodes[n2].next, None);
    }
}
condure-1.3.1/src/listener.rs000064400000000000000000000172330000000000000142320ustar 00000000000000/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena::recycle_vec;
use crate::channel;
use crate::executor::Executor;
use crate::future::{
    select_2, select_slice, AcceptFuture, AsyncReceiver, AsyncSender, AsyncTcpListener, Select2,
    WaitWritableFuture,
};
use crate::reactor::Reactor;
use log::{debug, error};
use mio;
use mio::net::{TcpListener, TcpStream};
use std::cmp;
use std::net::SocketAddr;
use std::sync::mpsc;
use std::thread;

const REACTOR_REGISTRATIONS_MAX: usize = 128;
const EXECUTOR_TASKS_MAX: usize = 1;

pub struct Listener {
    thread: Option<thread::JoinHandle<()>>,
    stop: channel::Sender<()>,
}

impl Listener {
    pub fn new(
        listeners: Vec<TcpListener>,
        senders: Vec<channel::Sender<(usize, TcpStream, SocketAddr)>>,
    ) -> Listener {
        let (s, r) = channel::channel(1);

        let thread = thread::spawn(move || {
            let reactor = Reactor::new(REACTOR_REGISTRATIONS_MAX);
            let executor = Executor::new(EXECUTOR_TASKS_MAX);

            executor.spawn(Self::run(r, listeners, senders)).unwrap();

            executor.run(|timeout| reactor.poll(timeout)).unwrap();
        });

        Self {
            thread: Some(thread),
            stop: s,
        }
    }

    async fn run(
        stop: channel::Receiver<()>,
        listeners: Vec<TcpListener>,
        senders: Vec<channel::Sender<(usize, TcpStream, SocketAddr)>>,
    ) {
        let stop = AsyncReceiver::new(stop);

        let mut listeners: Vec<AsyncTcpListener> = listeners
            .into_iter()
            .map(|l| AsyncTcpListener::new(l))
            .collect();

        let mut senders: Vec<AsyncSender<(usize, TcpStream, SocketAddr)>> =
            senders.into_iter().map(|s| AsyncSender::new(s)).collect();

        let mut listeners_pos = 0;
        let mut senders_pos = 0;

        let mut sender_tasks_mem: Vec<WaitWritableFuture<(usize, TcpStream, SocketAddr)>> =
            Vec::with_capacity(senders.len());

        let mut listener_tasks_mem: Vec<AcceptFuture> = Vec::with_capacity(listeners.len());

        let mut slice_scratch = Vec::with_capacity(cmp::max(senders.len(), listeners.len()));

        let mut stop_recv = stop.recv();

        'accept: loop {
            // wait for a sender to become writable

            let mut sender_tasks = recycle_vec(sender_tasks_mem);

            for s in senders.iter_mut() {
                sender_tasks.push(s.wait_writable());
            }

            let result = select_2(
                &mut stop_recv,
                select_slice(&mut sender_tasks, &mut slice_scratch),
            )
            .await;

            sender_tasks_mem = recycle_vec(sender_tasks);

            match result {
                Select2::R1(_) => break,
                Select2::R2(_) => {}
            }

            // accept a connection

            let mut listener_tasks = recycle_vec(listener_tasks_mem);

            let (b, a) = listeners.split_at_mut(listeners_pos);

            for l in a.iter_mut().chain(b.iter_mut()) {
                listener_tasks.push(l.accept());
            }

            let (pos, stream, peer_addr) = loop {
                match select_2(
                    &mut stop_recv,
                    select_slice(&mut listener_tasks, &mut slice_scratch),
                )
                .await
                {
                    Select2::R1(_) => break 'accept,
                    Select2::R2((pos, result)) => match result {
                        Ok((stream, peer_addr)) => break (pos, stream, peer_addr),
                        Err(e) => error!("accept error: {:?}", e),
                    },
                }
            };

            listener_tasks_mem = recycle_vec(listener_tasks);

            let pos = (listeners_pos + pos) % listeners.len();

            debug!("accepted connection from {}", peer_addr);

            listeners_pos = (pos + 1) % listeners.len();

            // write connection to sender

            let mut pending_sock = Some((pos, stream, peer_addr));

            for _ in 0..senders.len() {
                let sender = &mut senders[senders_pos];

                if !sender.is_writable() {
                    senders_pos = (senders_pos + 1) % senders.len();
                    continue;
                }

                let s = pending_sock.take().unwrap();

                match sender.try_send(s) {
                    Ok(()) => {}
                    Err(mpsc::TrySendError::Full(s)) => pending_sock = Some(s),
                    Err(mpsc::TrySendError::Disconnected(_)) => {
                        // this could happen during shutdown
                        debug!("receiver disconnected");
                    }
                }

                senders_pos = (senders_pos + 1) % senders.len();

                if pending_sock.is_none() {
                    break;
                }
            }
        }
    }
}

impl Drop for Listener {
    fn drop(&mut self) {
        // this should never fail. receiver won't disconnect unless
        // we tell it to
        self.stop.send(()).unwrap();

        let thread = self.thread.take().unwrap();
        thread.join().unwrap();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::event;
    use std::io::{Read, Write};
    use std::mem;
    use std::sync::mpsc;

    #[test]
    fn test_accept() {
        let mut addrs = Vec::new();
        let mut listeners = Vec::new();
        let mut senders = Vec::new();
        let mut receivers = Vec::new();

        for _ in 0..2 {
            let addr = "127.0.0.1:0".parse().unwrap();
            let l = TcpListener::bind(addr).unwrap();
            addrs.push(l.local_addr().unwrap());
            listeners.push(l);

            let (sender, receiver) = channel::channel(0);
            senders.push(sender);
            receivers.push(receiver);
        }

        let _l = Listener::new(listeners, senders);

        let mut poller = event::Poller::new(1024).unwrap();

        let mut client = std::net::TcpStream::connect(&addrs[0]).unwrap();

        poller
            .register_custom(
                receivers[0].get_read_registration(),
                mio::Token(1),
                mio::Interest::READABLE,
            )
            .unwrap();

        let result = receivers[0].try_recv();
        assert_eq!(result.is_err(), true);
        assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty);

        loop {
            poller.poll(None).unwrap();

            let mut done = false;
            for event in poller.iter_events() {
                match event.token() {
                    mio::Token(1) => {
                        assert_eq!(event.is_readable(), true);
                        done = true;
                        break;
                    }
                    _ => unreachable!(),
                }
            }

            if done {
                break;
            }
        }

        let (lnum, mut peer_client, _) = receivers[0].try_recv().unwrap();

        assert_eq!(lnum, 0);

        peer_client.write(b"hello").unwrap();
        mem::drop(peer_client);

        let mut buf = Vec::new();
        client.read_to_end(&mut buf).unwrap();

        assert_eq!(&buf, b"hello");
    }
}
condure-1.3.1/src/main.rs000064400000000000000000000327440000000000000133350ustar 00000000000000/*
 * Copyright (C) 2020 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use clap::{crate_version, App, Arg};
use condure::app;
use log::{error, Level, LevelFilter, Metadata, Record};
use std::error::Error;
use std::path::PathBuf;
use std::process;
use std::time::Duration;

// safety values
const WORKERS_MAX: usize = 1024;
const CONNS_MAX: usize = 10_000_000;

struct SimpleLogger;

impl log::Log for SimpleLogger {
    fn enabled(&self, metadata: &Metadata) -> bool {
        metadata.level() <= Level::Trace
    }

    fn log(&self, record: &Record) {
        if !self.enabled(record.metadata()) {
            return;
        }

        let now = time::now();
        let tm_fmt = now.strftime("%Y-%m-%d %H:%M:%S").unwrap();
        let millis = now.tm_nsec / 1_000_000;

        let lname = match record.level() {
            log::Level::Error => "ERR",
            log::Level::Warn => "WARN",
            log::Level::Info => "INFO",
            log::Level::Debug => "DEBUG",
            log::Level::Trace => "TRACE",
        };

        println!(
            "[{}] {}.{:03} [{}] {}",
            lname,
            tm_fmt,
            millis,
            record.target(),
            record.args()
        );
    }

    fn flush(&self) {}
}

static LOGGER: SimpleLogger = SimpleLogger;

struct Args {
    id: String,
    workers: usize,
    req_maxconn: usize,
    stream_maxconn: usize,
    buffer_size: usize,
    body_buffer_size: usize,
    messages_max: usize,
    req_timeout: usize,
    stream_timeout: usize,
    listen: Vec<String>,
    zclient_req_specs: Vec<String>,
    zclient_stream_specs: Vec<String>,
    zclient_connect: bool,
    ipc_file_mode: usize,
    tls_identities_dir: String,
}

fn process_args_and_run(args: Args) -> Result<(), Box<dyn Error>> {
    if args.id.is_empty() || args.id.contains(" ") {
        return Err("failed to parse id: value cannot be empty or contain a space".into());
    }

    if args.workers > WORKERS_MAX {
        return Err("failed to parse workers: value too large".into());
    }

    if args.req_maxconn + args.stream_maxconn > CONNS_MAX {
        return Err("total maxconn is too large".into());
    }

    let mut config = app::Config {
        instance_id: args.id,
        workers: args.workers,
        req_maxconn: args.req_maxconn,
        stream_maxconn: args.stream_maxconn,
        buffer_size: args.buffer_size,
        body_buffer_size: args.body_buffer_size,
        messages_max: args.messages_max,
        req_timeout: Duration::from_secs(args.req_timeout as u64),
        stream_timeout: Duration::from_secs(args.stream_timeout as u64),
        listen: Vec::new(),
        zclient_req: args.zclient_req_specs,
        zclient_stream: args.zclient_stream_specs,
        zclient_connect: args.zclient_connect,
        ipc_file_mode: args.ipc_file_mode,
        certs_dir: PathBuf::from(args.tls_identities_dir),
    };

    for v in args.listen.iter() {
        let mut parts = v.split(',');

        // there's always a first part
        let part1 = parts.next().unwrap();

        let port_pos = match part1.rfind(':') {
            Some(pos) => pos + 1,
            None => 0,
        };

        let port = &part1[port_pos..];
        if port.parse::<u16>().is_err() {
            return Err(format!("failed to parse listen: invalid port {}", port).into());
        }

        let addr = if port_pos > 0 {
            String::from(part1)
        } else {
            format!("0.0.0.0:{}", part1)
        };

        let addr = match addr.parse() {
            Ok(addr) => addr,
            Err(e) => {
                return Err(format!("failed to parse listen: {}", e).into());
            }
        };

        let mut stream = true;
        let mut tls = false;
        let mut default_cert = None;

        for part in parts {
            let (k, v) = match part.find('=') {
                Some(pos) => (&part[..pos], &part[(pos + 1)..]),
                None => (part, ""),
            };

            match k {
                "req" => stream = false,
                "stream" => stream = true,
                "tls" => tls = true,
                "default-cert" => default_cert = Some(String::from(v)),
                _ => return Err(format!("failed to parse listen: invalid param: {}", part).into()),
            }
        }

        config.listen.push(app::ListenConfig {
            addr,
            stream,
            tls,
            default_cert,
        });
    }

    condure::run(&config)
}

fn main() {
    let matches = App::new("condure")
        .version(crate_version!())
        .about("HTTP/WebSocket connection manager")
        .arg(
            Arg::with_name("log-level")
                .long("log-level")
                .takes_value(true)
                .value_name("N")
.help("Log level") .default_value("2"), ) .arg( Arg::with_name("id") .long("id") .takes_value(true) .value_name("ID") .help("Instance ID") .default_value("condure"), ) .arg( Arg::with_name("workers") .long("workers") .takes_value(true) .value_name("N") .help("Number of worker threads") .default_value("2"), ) .arg( Arg::with_name("req-maxconn") .long("req-maxconn") .takes_value(true) .value_name("N") .help("Maximum number of concurrent connections in req mode") .default_value("100"), ) .arg( Arg::with_name("stream-maxconn") .long("stream-maxconn") .takes_value(true) .value_name("N") .help("Maximum number of concurrent connections in stream mode") .default_value("10000"), ) .arg( Arg::with_name("buffer-size") .long("buffer-size") .takes_value(true) .value_name("N") .help("Connection buffer size (two buffers per connection)") .default_value("8192"), ) .arg( Arg::with_name("body-buffer-size") .long("body-buffer-size") .takes_value(true) .value_name("N") .help("Body buffer size for connections in req mode") .default_value("100000"), ) .arg( Arg::with_name("messages-max") .long("messages-max") .takes_value(true) .value_name("N") .help("Maximum number of queued WebSocket messages per connection") .default_value("100"), ) .arg( Arg::with_name("req-timeout") .long("req-timeout") .takes_value(true) .value_name("N") .help("Client timeout in req mode (seconds)") .default_value("30"), ) .arg( Arg::with_name("stream-timeout") .long("stream-timeout") .takes_value(true) .value_name("N") .help("Client timeout in stream mode (seconds)") .default_value("1800"), ) .arg( Arg::with_name("listen") .long("listen") .takes_value(true) .value_name("[addr:]port[,params...]") .multiple(true) .help("Port to listen on") .default_value("0.0.0.0:8000,stream"), ) .arg( Arg::with_name("zclient-req") .long("zclient-req") .takes_value(true) .value_name("spec") .multiple(true) .help("ZeroMQ client REQ spec") .default_value("ipc://client"), ) .arg( Arg::with_name("zclient-stream") 
.long("zclient-stream") .takes_value(true) .value_name("spec-base") .multiple(true) .help("ZeroMQ client PUSH/ROUTER/SUB spec base") .default_value("ipc://client"), ) .arg( Arg::with_name("zclient-connect") .long("zclient-connect") .help("ZeroMQ client sockets should connect instead of bind"), ) .arg( Arg::with_name("ipc-file-mode") .long("ipc-file-mode") .takes_value(true) .value_name("octal") .help("Permissions for ZeroMQ IPC binds"), ) .arg( Arg::with_name("tls-identities-dir") .long("tls-identities-dir") .takes_value(true) .value_name("directory") .help("Directory containing certificates and private keys") .default_value("."), ) .get_matches(); log::set_logger(&LOGGER).unwrap(); log::set_max_level(LevelFilter::Info); let level = matches.value_of("log-level").unwrap(); let level: usize = match level.parse() { Ok(x) => x, Err(e) => { error!("failed to parse log-level: {}", e); process::exit(1); } }; let level = match level { 0 => LevelFilter::Error, 1 => LevelFilter::Warn, 2 => LevelFilter::Info, 3 => LevelFilter::Debug, 4..=core::usize::MAX => LevelFilter::Trace, _ => unreachable!(), }; log::set_max_level(level); let id = matches.value_of("id").unwrap(); let workers = matches.value_of("workers").unwrap(); let workers: usize = match workers.parse() { Ok(x) => x, Err(e) => { error!("failed to parse workers: {}", e); process::exit(1); } }; let req_maxconn = matches.value_of("req-maxconn").unwrap(); let req_maxconn: usize = match req_maxconn.parse() { Ok(x) => x, Err(e) => { error!("failed to parse req-maxconn: {}", e); process::exit(1); } }; let stream_maxconn = matches.value_of("stream-maxconn").unwrap(); let stream_maxconn: usize = match stream_maxconn.parse() { Ok(x) => x, Err(e) => { error!("failed to parse stream-maxconn: {}", e); process::exit(1); } }; let buffer_size = matches.value_of("buffer-size").unwrap(); let buffer_size: usize = match buffer_size.parse() { Ok(x) => x, Err(e) => { error!("failed to parse buffer-size: {}", e); process::exit(1); } }; let 
body_buffer_size = matches.value_of("body-buffer-size").unwrap(); let body_buffer_size: usize = match body_buffer_size.parse() { Ok(x) => x, Err(e) => { error!("failed to parse body-buffer-size: {}", e); process::exit(1); } }; let messages_max = matches.value_of("messages-max").unwrap(); let messages_max: usize = match messages_max.parse() { Ok(x) => x, Err(e) => { error!("failed to parse messages-max: {}", e); process::exit(1); } }; let req_timeout = matches.value_of("req-timeout").unwrap(); let req_timeout: usize = match req_timeout.parse() { Ok(x) => x, Err(e) => { error!("failed to parse req-timeout: {}", e); process::exit(1); } }; let stream_timeout = matches.value_of("stream-timeout").unwrap(); let stream_timeout: usize = match stream_timeout.parse() { Ok(x) => x, Err(e) => { error!("failed to parse stream-timeout: {}", e); process::exit(1); } }; let listen = matches .values_of("listen") .unwrap() .map(String::from) .collect(); let zclient_req_specs = matches .values_of("zclient-req") .unwrap() .map(String::from) .collect(); let zclient_stream_specs = matches .values_of("zclient-stream") .unwrap() .map(String::from) .collect(); let zclient_connect = matches.is_present("zclient-connect"); let ipc_file_mode = matches.value_of("ipc-file-mode").unwrap_or("0"); let ipc_file_mode: usize = match ipc_file_mode.parse() { Ok(x) => x, Err(e) => { error!("failed to parse ipc-file-mode: {}", e); process::exit(1); } }; let tls_identities_dir = matches.value_of("tls-identities-dir").unwrap(); let args = Args { id: id.to_string(), workers, req_maxconn, stream_maxconn, buffer_size, body_buffer_size, messages_max, req_timeout, stream_timeout, listen, zclient_req_specs, zclient_stream_specs, zclient_connect, ipc_file_mode, tls_identities_dir: tls_identities_dir.to_string(), }; if let Err(e) = process_args_and_run(args) { error!("{}", e); process::exit(1); } } condure-1.3.1/src/reactor.rs000064400000000000000000000546330000000000000140510ustar 00000000000000/* * Copyright (C) 
2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::arena; use crate::event; use crate::event::ReadinessExt; use crate::timer::TimerWheel; use mio; use slab::Slab; use std::cell::RefCell; use std::cmp; use std::io; use std::os::unix::io::RawFd; use std::rc::{Rc, Weak}; use std::task::Waker; use std::time::{Duration, Instant}; const TICK_DURATION_MS: u64 = 10; thread_local! { static REACTOR: RefCell>> = RefCell::new(None); } fn duration_to_ticks_round_down(d: Duration) -> u64 { (d.as_millis() / (TICK_DURATION_MS as u128)) as u64 } fn duration_to_ticks_round_up(d: Duration) -> u64 { ((d.as_millis() + (TICK_DURATION_MS as u128) - 1) / (TICK_DURATION_MS as u128)) as u64 } fn ticks_to_duration(t: u64) -> Duration { Duration::from_millis(t * TICK_DURATION_MS) } pub struct Registration { reactor: Weak, key: usize, } impl Registration { pub fn reactor(&self) -> Reactor { let reactor = self.reactor.upgrade().expect("reactor is gone"); Reactor { inner: reactor } } pub fn readiness(&self) -> event::Readiness { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &*reactor.registrations.borrow(); let reg_data = ®istrations[self.key]; reg_data.readiness } pub fn set_readiness(&self, readiness: event::Readiness) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; reg_data.readiness = readiness; } pub fn 
clear_readiness(&self, readiness: mio::Interest) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; if let Some(cur) = reg_data.readiness.take() { reg_data.readiness = cur.remove(readiness); } } pub fn is_ready(&self) -> bool { self.readiness().is_some() } pub fn set_ready(&self, ready: bool) { let readiness = if ready { Some(mio::Interest::READABLE) } else { None }; self.set_readiness(readiness); } pub fn set_waker(&self, waker: Waker, interest: mio::Interest) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; reg_data.waker = Some((waker, interest)); } pub fn clear_waker(&self) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; reg_data.waker = None; } pub fn deregister_io(&self, source: &mut S) -> Result<(), io::Error> { let reactor = self.reactor.upgrade().expect("reactor is gone"); let poll = &reactor.poll.borrow(); poll.deregister(source) } pub fn pull_from_budget(&self) -> bool { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; if reg_data.waker.is_none() { panic!("pull_from_budget requires a waker to be set"); } let budget = &mut *reactor.budget.borrow_mut(); let ok = match budget { Some(budget) => { if *budget > 0 { *budget -= 1; true } else { false } } None => true, }; // if no budget left, trigger the waker to try again soon if !ok { let (waker, _) = reg_data.waker.take().unwrap(); waker.wake(); } ok } } impl Drop for Registration { fn drop(&mut self) { if let Some(reactor) = self.reactor.upgrade() { let registrations = &mut *reactor.registrations.borrow_mut(); if let 
Some(timer_key) = registrations[self.key].timer_key { let timer = &mut *reactor.timer.borrow_mut(); timer.wheel.remove(timer_key); } registrations.remove(self.key); } } } struct RegistrationData { readiness: event::Readiness, waker: Option<(Waker, mio::Interest)>, timer_key: Option<usize>, } struct TimerData { wheel: TimerWheel, start: Instant, current_ticks: u64, } struct ReactorData { registrations: RefCell<Slab<RegistrationData>>, poll: RefCell<event::Poller>, timer: RefCell<TimerData>, budget: RefCell<Option<u32>>, } #[derive(Clone)] pub struct Reactor { inner: Rc<ReactorData>, } impl Reactor { pub fn new(registrations_max: usize) -> Self { Self::new_with_time(registrations_max, Instant::now()) } pub fn new_with_time(registrations_max: usize, start_time: Instant) -> Self { let timer_data = TimerData { wheel: TimerWheel::new(registrations_max), start: start_time, current_ticks: 0, }; let inner = Rc::new(ReactorData { registrations: RefCell::new(Slab::with_capacity(registrations_max)), poll: RefCell::new(event::Poller::new(registrations_max).unwrap()), timer: RefCell::new(timer_data), budget: RefCell::new(None), }); REACTOR.with(|r| { if r.borrow().is_some() { panic!("thread already has a Reactor"); } r.replace(Some(Rc::downgrade(&inner))); }); Self { inner } } pub fn register_io<S>( &self, source: &mut S, interest: mio::Interest, ) -> Result<Registration, io::Error> where S: mio::event::Source + ?Sized, { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, }); if let Err(e) = self .inner .poll .borrow() .register(source, mio::Token(key + 1), interest) { registrations.remove(key); return Err(e); } Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } pub fn register_custom( &self, handle: &event::Registration, interest: mio::Interest, ) -> Result<Registration, io::Error> { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() ==
registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, }); if let Err(e) = self.inner .poll .borrow() .register_custom(handle, mio::Token(key + 1), interest) { registrations.remove(key); return Err(e); } Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } pub fn register_custom_local( &self, handle: &event::LocalRegistration, interest: mio::Interest, ) -> Result<Registration, io::Error> { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, }); if let Err(e) = self.inner .poll .borrow() .register_custom_local(handle, mio::Token(key + 1), interest) { registrations.remove(key); return Err(e); } Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } pub fn register_timer(&self, expires: Instant) -> Result<Registration, io::Error> { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, }); let timer = &mut *self.inner.timer.borrow_mut(); let expires_ticks = duration_to_ticks_round_up(expires - timer.start); let timer_key = match timer.wheel.add(expires_ticks, key) { Ok(timer_key) => timer_key, Err(_) => { registrations.remove(key); return Err(io::Error::from(io::ErrorKind::Other)); } }; registrations[key].timer_key = Some(timer_key); Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } // we advance time after polling. this way, Reactor::now() is accurate // during task processing.
we assume the actual time doesn't change much // between task processing and the next poll pub fn poll(&self, timeout: Option<Duration>) -> Result<(), io::Error> { self.poll_for_events(self.next_timeout(timeout))?; self.advance_time(Instant::now()); self.process_events(); Ok(()) } // return the timeout that would have been used for a blocking poll pub fn poll_nonblocking(&self, current_time: Instant) -> Result<Option<Duration>, io::Error> { let timeout = self.next_timeout(None); self.poll_for_events(Some(Duration::from_millis(0)))?; self.advance_time(current_time); self.process_events(); Ok(timeout) } pub fn now(&self) -> Instant { let timer = &*self.inner.timer.borrow(); timer.start + ticks_to_duration(timer.current_ticks) } pub fn set_budget(&self, budget: Option<u32>) { *self.inner.budget.borrow_mut() = budget; } pub fn current() -> Option<Self> { REACTOR.with(|r| match &mut *r.borrow_mut() { Some(inner) => Some(Self { inner: inner.upgrade().unwrap(), }), None => None, }) } pub fn local_registration_memory(&self) -> Rc> { self.inner.poll.borrow().local_registration_memory().clone() } fn next_timeout(&self, user_timeout: Option<Duration>) -> Option<Duration> { let timer = &mut *self.inner.timer.borrow_mut(); let timer_timeout = match timer.wheel.timeout() { Some(ticks) => Some(ticks_to_duration(ticks)), None => None, }; match user_timeout { Some(user_timeout) => Some(match timer_timeout { Some(timer_timeout) => cmp::min(user_timeout, timer_timeout), None => user_timeout, }), None => timer_timeout, } } fn poll_for_events(&self, timeout: Option<Duration>) -> Result<(), io::Error> { let poll = &mut *self.inner.poll.borrow_mut(); poll.poll(timeout) } fn advance_time(&self, current_time: Instant) { let timer = &mut *self.inner.timer.borrow_mut(); timer.current_ticks = duration_to_ticks_round_down(current_time - timer.start); timer.wheel.update(timer.current_ticks); } fn process_events(&self) { let poll = &mut *self.inner.poll.borrow_mut(); let registrations = &mut *self.inner.registrations.borrow_mut(); for event in poll.iter_events()
{ let key = usize::from(event.token()); assert!(key > 0); let key = key - 1; if let Some(event_reg) = registrations.get_mut(key) { let changed = { let readiness = event_reg.readiness; event_reg.readiness.merge(event.readiness()); readiness != event_reg.readiness }; if changed { if let Some((_, interest)) = &event_reg.waker { let readiness = event.readiness(); if (readiness.is_readable() && interest.is_readable()) || (readiness.is_writable() && interest.is_writable()) { let (waker, _) = event_reg.waker.take().unwrap(); waker.wake(); } } } } } let timer = &mut *self.inner.timer.borrow_mut(); while let Some((_, key)) = timer.wheel.take_expired() { if let Some(event_reg) = registrations.get_mut(key) { event_reg.readiness = Some(mio::Interest::READABLE); event_reg.timer_key = None; if let Some((waker, _)) = event_reg.waker.take() { waker.wake(); } } } } } impl Drop for Reactor { fn drop(&mut self) { REACTOR.with(|r| { if Rc::strong_count(&self.inner) == 1 { r.replace(None); } }); } } pub struct IoEvented<S: mio::event::Source> { registration: Registration, io: Option<S>, } impl<S: mio::event::Source> IoEvented<S> { pub fn new(mut io: S, interest: mio::Interest, reactor: &Reactor) -> Result<Self, io::Error> { let registration = reactor.register_io(&mut io, interest)?; Ok(Self { registration, io: Some(io), }) } pub fn registration(&self) -> &Registration { &self.registration } pub fn io(&self) -> &S { &self.io.as_ref().unwrap() } } impl<S: mio::event::Source> Drop for IoEvented<S> { fn drop(&mut self) { if let Some(mut io) = self.io.take() { self.registration().deregister_io(&mut io).unwrap(); } } } pub struct FdEvented { registration: Registration, fd: RawFd, } impl FdEvented { pub fn new(fd: RawFd, interest: mio::Interest, reactor: &Reactor) -> Result<Self, io::Error> { let registration = reactor.register_io(&mut mio::unix::SourceFd(&fd), interest)?; Ok(Self { registration, fd }) } pub fn registration(&self) -> &Registration { &self.registration } pub fn fd(&self) -> &RawFd { &self.fd } } impl Drop for FdEvented { fn drop(&mut self) { self.registration() .deregister_io(&mut
mio::unix::SourceFd(&self.fd)) .unwrap(); } } pub struct CustomEvented { registration: Registration, } impl CustomEvented { pub fn new( event_reg: &event::Registration, interest: mio::Interest, reactor: &Reactor, ) -> Result<Self, io::Error> { let registration = reactor.register_custom(event_reg, interest)?; Ok(Self { registration }) } pub fn new_local( event_reg: &event::LocalRegistration, interest: mio::Interest, reactor: &Reactor, ) -> Result<Self, io::Error> { let registration = reactor.register_custom_local(event_reg, interest)?; Ok(Self { registration }) } pub fn registration(&self) -> &Registration { &self.registration } } pub struct TimerEvented { registration: Registration, expires: Instant, } impl TimerEvented { pub fn new(expires: Instant, reactor: &Reactor) -> Result<Self, io::Error> { let registration = reactor.register_timer(expires)?; Ok(Self { registration, expires, }) } pub fn registration(&self) -> &Registration { &self.registration } pub fn expires(&self) -> Instant { self.expires } } #[cfg(test)] mod tests { use super::*; use crate::waker; use std::cell::Cell; use std::mem; use std::os::unix::io::AsRawFd; use std::rc::Rc; use std::thread; struct TestWaker { waked: Cell<bool>, } impl TestWaker { fn new() -> Self { Self { waked: Cell::new(false), } } fn into_std(self: Rc<Self>) -> Waker { waker::into_std(self) } fn was_waked(&self) -> bool { self.waked.get() } } impl waker::RcWake for TestWaker { fn wake(self: Rc<Self>) { self.waked.set(true); } } #[test] fn test_reactor_io() { let reactor = Reactor::new(1); let addr = "127.0.0.1:0".parse().unwrap(); let listener = mio::net::TcpListener::bind(addr).unwrap(); let evented = IoEvented::new(listener, mio::Interest::READABLE, &reactor).unwrap(); let addr = evented.io().local_addr().unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(waker.clone().into_std(), mio::Interest::READABLE); let thread = thread::spawn(move || { std::net::TcpStream::connect(addr).unwrap(); }); assert_eq!(waker.was_waked(), false); reactor.poll(None).unwrap();
assert_eq!(waker.was_waked(), true); thread.join().unwrap(); } #[test] fn test_reactor_fd() { let reactor = Reactor::new(1); let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); let listener = std::net::TcpListener::bind(addr).unwrap(); let evented = FdEvented::new(listener.as_raw_fd(), mio::Interest::READABLE, &reactor).unwrap(); let addr = listener.local_addr().unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(waker.clone().into_std(), mio::Interest::READABLE); let thread = thread::spawn(move || { std::net::TcpStream::connect(addr).unwrap(); }); assert_eq!(waker.was_waked(), false); reactor.poll(None).unwrap(); assert_eq!(waker.was_waked(), true); thread.join().unwrap(); } #[test] fn test_reactor_custom() { let reactor = Reactor::new(1); let (reg, sr) = event::Registration::new(); let evented = CustomEvented::new(&reg, mio::Interest::READABLE, &reactor).unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(waker.clone().into_std(), mio::Interest::READABLE); let thread = thread::spawn(move || { sr.set_readiness(mio::Interest::READABLE).unwrap(); }); assert_eq!(waker.was_waked(), false); reactor.poll(None).unwrap(); assert_eq!(waker.was_waked(), true); thread.join().unwrap(); } #[test] fn test_reactor_timer() { let now = Instant::now(); let reactor = Reactor::new_with_time(1, now); let evented = TimerEvented::new(now + Duration::from_millis(100), &reactor).unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(waker.clone().into_std(), mio::Interest::READABLE); assert_eq!(waker.was_waked(), false); assert_eq!(reactor.now(), now); let timeout = reactor .poll_nonblocking(now + Duration::from_millis(20)) .unwrap(); assert_eq!(timeout, Some(Duration::from_millis(100))); assert_eq!(reactor.now(), now + Duration::from_millis(20)); assert_eq!(waker.was_waked(), false); let timeout = reactor .poll_nonblocking(now + Duration::from_millis(40)) .unwrap(); assert_eq!(timeout,
Some(Duration::from_millis(80))); assert_eq!(reactor.now(), now + Duration::from_millis(40)); assert_eq!(waker.was_waked(), false); let timeout = reactor .poll_nonblocking(now + Duration::from_millis(100)) .unwrap(); assert_eq!(timeout, Some(Duration::from_millis(60))); assert_eq!(waker.was_waked(), true); assert_eq!(reactor.now(), now + Duration::from_millis(100)); } #[test] fn test_reactor_current() { assert!(Reactor::current().is_none()); let reactor = Reactor::new(1); let current = Reactor::current().unwrap(); mem::drop(reactor); assert!(Reactor::current().is_some()); mem::drop(current); assert!(Reactor::current().is_none()); } #[test] fn test_reactor_budget() { let reactor = Reactor::new(1); let (reg, _) = event::Registration::new(); let evented = CustomEvented::new(&reg, mio::Interest::READABLE, &reactor).unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(waker.clone().into_std(), mio::Interest::READABLE); assert_eq!(evented.registration().pull_from_budget(), true); assert_eq!(waker.was_waked(), false); reactor.set_budget(Some(0)); assert_eq!(evented.registration().pull_from_budget(), false); assert_eq!(waker.was_waked(), true); let waker = Rc::new(TestWaker::new()); reactor.set_budget(Some(1)); evented .registration() .set_waker(waker.clone().into_std(), mio::Interest::READABLE); assert_eq!(evented.registration().pull_from_budget(), true); assert_eq!(waker.was_waked(), false); assert_eq!(evented.registration().pull_from_budget(), false); assert_eq!(waker.was_waked(), true); } } condure-1.3.1/src/server.rs000064400000000000000000003034230000000000000137120ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::app::ListenConfig; use crate::arena; use crate::buffer::TmpBuffer; use crate::channel; use crate::connection::{ ServerReqConnection, ServerState, ServerStreamConnection, ServerStreamSharedData, Shutdown, Want, ZhttpSender, }; use crate::event; use crate::executor::{Executor, Spawner}; use crate::future::{ event_wait, select_2, select_3, select_4, select_6, select_option, select_option_ref, AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver, AsyncSleep, Select2, Select3, Select4, Select6, }; use crate::list; use crate::listener::Listener; use crate::reactor::Reactor; use crate::tls::{IdentityCache, TlsAcceptor, TlsStream}; use crate::tnetstring; use crate::zhttppacket; use crate::zhttpsocket; use crate::zmq::SpecInfo; use arrayvec::{ArrayString, ArrayVec}; use log::{debug, error, info, warn}; use mio; use mio::net::{TcpListener, TcpSocket, TcpStream}; use mio::unix::SourceFd; use slab::Slab; use std::cell::{Cell, RefCell}; use std::io; use std::io::{Read, Write}; use std::net::SocketAddr; use std::os::unix::io::{FromRawFd, IntoRawFd}; use std::path::Path; use std::rc::Rc; use std::str; use std::str::FromStr; use std::sync::mpsc; use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; const RESP_SENDER_BOUND: usize = 1; // we read and process each response message one at a time, wrapping it in an // rc, and sending it to connections via channels. on the other side of each // channel, the message is received and processed immediately. 
this means the // max number of messages retained per connection is the channel bound per // connection pub const MSG_RETAINED_PER_CONNECTION_MAX: usize = RESP_SENDER_BOUND; // the max number of messages retained outside of connections is one per // handle we read from (req and stream), in preparation for sending to any // connections pub const MSG_RETAINED_PER_WORKER_MAX: usize = 2; // run x1 // accept_task x2 // req_handle_task x1 // stream_handle_task x1 // keep_alives_task x1 const WORKER_NON_CONNECTION_TASKS_MAX: usize = 10; // note: individual tasks are not (and must not be) capped to this number. // this is because accept_task makes a registration for every connection // task, which means each instance of accept_task could end up making // thousands of registrations. however, such registrations are associated // with the spawning of connection_task, so we can still estimate // registrations relative to the number of tasks const REGISTRATIONS_PER_TASK_MAX: usize = 32; const REACTOR_BUDGET: u32 = 100; const KEEP_ALIVE_TIMEOUT_MS: usize = 45_000; const KEEP_ALIVE_BATCH_MS: usize = 100; const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(KEEP_ALIVE_BATCH_MS as u64); const KEEP_ALIVE_BATCHES: usize = KEEP_ALIVE_TIMEOUT_MS / KEEP_ALIVE_BATCH_MS; const BULK_PACKET_SIZE_MAX: usize = 65_000; fn get_addr_and_offset(msg: &[u8]) -> Result<(&str, usize), ()> { let mut pos = None; for (i, b) in msg.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = match pos { Some(pos) => pos, None => return Err(()), }; let addr = match str::from_utf8(&msg[..pos]) { Ok(addr) => addr, Err(_) => return Err(()), }; Ok((addr, pos + 1)) } fn get_key(id: &[u8]) -> Result { let mut start = None; let mut end = None; for (i, b) in id.iter().enumerate() { if *b == b'-' { if start.is_none() { start = Some(i + 1); } else { end = Some(i); break; } } } let start = match start { Some(start) => start, None => return Err(()), }; let end = match end { Some(end) => end, None 
=> return Err(()), }; let key = match str::from_utf8(&id[start..end]) { Ok(key) => key, Err(_) => return Err(()), }; let key = match key.parse() { Ok(key) => key, Err(_) => return Err(()), }; Ok(key) } fn set_socket_opts(stream: TcpStream) -> TcpStream { if let Err(e) = stream.set_nodelay(true) { error!("set nodelay failed: {:?}", e); } let socket = unsafe { TcpSocket::from_raw_fd(stream.into_raw_fd()) }; if let Err(e) = socket.set_keepalive(true) { error!("set keepalive failed: {:?}", e); } unsafe { TcpStream::from_raw_fd(socket.into_raw_fd()) } } fn local_channel( bound: usize, max_senders: usize, ) -> (channel::LocalSender, channel::LocalReceiver) { let (s, r) = channel::local_channel( bound, max_senders, &Reactor::current().unwrap().local_registration_memory(), ); (s, r) } fn async_local_channel( bound: usize, max_senders: usize, ) -> (AsyncLocalSender, AsyncLocalReceiver) { let (s, r) = local_channel(bound, max_senders); let s = AsyncLocalSender::new(s); let r = AsyncLocalReceiver::new(r); (s, r) } fn gen_id(id: usize, ckey: usize, next_cid: &mut u32) -> ArrayString<[u8; 32]> { let mut buf = [0; 32]; let mut c = io::Cursor::new(&mut buf[..]); write!(&mut c, "{}-{}-{:x}", id, ckey, next_cid).unwrap(); let size = c.position() as usize; let s = str::from_utf8(&buf[..size]).unwrap(); *next_cid += 1; ArrayString::from_str(s).unwrap() } impl Shutdown for TcpStream { fn shutdown(&mut self) -> Result<(), io::Error> { Ok(()) } } impl Shutdown for TlsStream { fn shutdown(&mut self) -> Result<(), io::Error> { self.shutdown() } } impl ZhttpSender for channel::LocalSender { fn can_send_to(&self) -> bool { // req mode doesn't use this unimplemented!(); } fn send(&mut self, message: zmq::Message) -> Result<(), zhttpsocket::SendError> { match self.try_send(message) { Ok(()) => Ok(()), Err(mpsc::TrySendError::Full(msg)) => Err(zhttpsocket::SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => Err(zhttpsocket::SendError::Io( 
io::Error::from(io::ErrorKind::BrokenPipe), )), } } fn send_to( &mut self, _addr: &[u8], _message: zmq::Message, ) -> Result<(), zhttpsocket::SendError> { // req mode doesn't use this unimplemented!(); } } struct StreamLocalSenders { out: channel::LocalSender, out_stream: channel::LocalSender<(ArrayVec<[u8; 64]>, zmq::Message)>, out_stream_can_write: Cell, } impl StreamLocalSenders { fn new( out: channel::LocalSender, out_stream: channel::LocalSender<(ArrayVec<[u8; 64]>, zmq::Message)>, ) -> Self { Self { out, out_stream, out_stream_can_write: Cell::new(true), } } fn set_out_stream_can_write(&self) { self.out_stream_can_write.set(true); } } impl ZhttpSender for StreamLocalSenders { fn can_send_to(&self) -> bool { if self.out_stream_can_write.get() { if self.out_stream.check_send() { return true; } self.out_stream_can_write.set(false); } false } fn send(&mut self, message: zmq::Message) -> Result<(), zhttpsocket::SendError> { match self.out.try_send(message) { Ok(()) => Ok(()), Err(mpsc::TrySendError::Full(msg)) => Err(zhttpsocket::SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => Err(zhttpsocket::SendError::Io( io::Error::from(io::ErrorKind::BrokenPipe), )), } } fn send_to( &mut self, addr: &[u8], message: zmq::Message, ) -> Result<(), zhttpsocket::SendError> { let mut a = ArrayVec::new(); if a.try_extend_from_slice(addr).is_err() { return Err(zhttpsocket::SendError::Io(io::Error::from( io::ErrorKind::InvalidInput, ))); } match self.out_stream.try_send((a, message)) { Ok(()) => Ok(()), Err(mpsc::TrySendError::Full((_, msg))) => Err(zhttpsocket::SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => Err(zhttpsocket::SendError::Io( io::Error::from(io::ErrorKind::BrokenPipe), )), } } } enum ServerConnection { Req(ServerReqConnection, channel::LocalSender), Stream(ServerStreamConnection, StreamLocalSenders), } enum Stream { Plain(TcpStream), Tls(TlsStream), } impl Stream { fn get_tcp(&mut self) -> &mut TcpStream { match self { 
Stream::Plain(stream) => stream, Stream::Tls(stream) => stream.get_tcp(), } } } struct Connection { id: ArrayString<[u8; 32]>, stream: Stream, conn: ServerConnection, want: Want, timer: Option, zreceiver: channel::LocalReceiver<(arena::Rc, Option)>, } impl Connection { fn new_req( now: Instant, stream: Stream, peer_addr: SocketAddr, buffer_size: usize, body_buffer_size: usize, rb_tmp: &Rc, timeout: Duration, sender: channel::LocalSender, zreceiver: channel::LocalReceiver<(arena::Rc, Option)>, ) -> Self { let secure = match &stream { Stream::Plain(_) => false, Stream::Tls(_) => true, }; Self { id: ArrayString::new(), stream, conn: ServerConnection::Req( ServerReqConnection::new( now, Some(peer_addr), secure, buffer_size, body_buffer_size, rb_tmp, timeout, ), sender, ), want: Want::nothing(), timer: None, zreceiver, } } fn new_stream( now: Instant, stream: Stream, peer_addr: SocketAddr, buffer_size: usize, messages_max: usize, rb_tmp: &Rc, timeout: Duration, senders: StreamLocalSenders, zreceiver: channel::LocalReceiver<(arena::Rc, Option)>, shared: arena::Rc, ) -> Self { let secure = match &stream { Stream::Plain(_) => false, Stream::Tls(_) => true, }; Self { id: ArrayString::new(), stream, conn: ServerConnection::Stream( ServerStreamConnection::new( now, Some(peer_addr), secure, buffer_size, messages_max, rb_tmp, timeout, shared, ), senders, ), want: Want::nothing(), timer: None, zreceiver, } } fn state(&self) -> ServerState { match &self.conn { ServerConnection::Req(conn, _) => conn.state(), ServerConnection::Stream(conn, _) => conn.state(), } } fn get_tcp(&mut self) -> &mut TcpStream { self.stream.get_tcp() } fn set_sock_readable(&mut self) { match &mut self.conn { ServerConnection::Req(conn, _) => conn.set_sock_readable(), ServerConnection::Stream(conn, _) => conn.set_sock_readable(), } } fn set_out_stream_can_write(&self) { match &self.conn { ServerConnection::Req(_, _) => panic!("not stream conn"), ServerConnection::Stream(_, senders) => 
senders.set_out_stream_can_write(), } } fn start(&mut self, id: &str) { self.id = ArrayString::from_str(id).unwrap(); if let Stream::Tls(stream) = &mut self.stream { stream.set_id(id); } debug!("conn {}: assigning id", self.id); match &mut self.conn { ServerConnection::Req(conn, _) => conn.start(self.id.as_ref()), ServerConnection::Stream(conn, _) => conn.start(self.id.as_ref()), } } fn handle_packet( &mut self, now: Instant, zresp: &zhttppacket::Response, seq: Option, ) -> Result<(), ()> { if !zresp.ptype_str.is_empty() { debug!("conn {}: handle packet: {}", self.id, zresp.ptype_str); } else { debug!("conn {}: handle packet: (data)", self.id); } match &mut self.conn { ServerConnection::Req(conn, _) => { if let Err(e) = conn.apply_zhttp_response(zresp) { debug!("conn {}: apply error {:?}", self.id, e); return Err(()); } } ServerConnection::Stream(conn, _) => { if let Err(e) = conn.apply_zhttp_response(now, zresp, seq) { debug!("conn {}: apply error {:?}", self.id, e); return Err(()); } } } Ok(()) } fn process( &mut self, now: Instant, instance_id: &str, packet_buf: &mut [u8], tmp_buf: &mut [u8], ) -> bool { while let Ok((resp, seq)) = self.zreceiver.try_recv() { // if error, keep going let _ = self.handle_packet(now, resp.get().get(), seq); } match &mut self.stream { Stream::Plain(stream) => Self::process_with_stream( &self.id, &mut self.conn, &mut self.want, stream, now, instance_id, packet_buf, tmp_buf, ), Stream::Tls(stream) => { let done = Self::process_with_stream( &self.id, &mut self.conn, &mut self.want, stream, now, instance_id, packet_buf, tmp_buf, ); // for TLS, wake on all socket events if self.want.sock_read || self.want.sock_write { self.want.sock_read = true; self.want.sock_write = true; } done } } } fn process_with_stream( id: &ArrayString<[u8; 32]>, conn: &mut ServerConnection, want: &mut Want, stream: &mut S, now: Instant, instance_id: &str, packet_buf: &mut [u8], tmp_buf: &mut [u8], ) -> bool { match conn { ServerConnection::Req(conn, sender) => { 
match conn.process(now, stream, sender, packet_buf) { Ok(w) => *want = w, Err(e) => { debug!("conn {}: process error: {:?}", id, e); return true; } } if conn.state() == ServerState::Finished { return true; } } ServerConnection::Stream(conn, senders) => { match conn.process(now, instance_id, stream, senders, packet_buf, tmp_buf) { Ok(w) => *want = w, Err(e) => { debug!("conn {}: process error: {:?}", id, e); return true; } } if conn.state() == ServerState::Finished { return true; } } } false } } struct BatchKey { addr_index: usize, nkey: usize, } struct BatchGroup<'a, 'b> { addr: &'b [u8], ids: arena::ReusableVecHandle<'b, zhttppacket::Id<'a>>, } impl<'a> BatchGroup<'a, '_> { fn addr(&self) -> &[u8] { self.addr } fn ids(&self) -> &[zhttppacket::Id<'a>] { &*self.ids } } struct Batch { nodes: Slab>, addrs: Vec<(ArrayVec<[u8; 64]>, list::List)>, addr_index: usize, group_ids: arena::ReusableVec, last_group_ckeys: Vec, } impl Batch { fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), addrs: Vec::with_capacity(capacity), addr_index: 0, group_ids: arena::ReusableVec::new::(capacity), last_group_ckeys: Vec::with_capacity(capacity), } } fn len(&self) -> usize { self.nodes.len() } fn capacity(&self) -> usize { self.nodes.capacity() } fn is_empty(&self) -> bool { self.nodes.is_empty() } fn clear(&mut self) { self.addrs.clear(); self.nodes.clear(); self.addr_index = 0; } fn add(&mut self, to_addr: &[u8], ckey: usize) -> Result { let mut pos = self.addrs.len(); for (i, a) in self.addrs.iter().enumerate() { if a.0.as_ref() == to_addr { pos = i; } } if pos == self.addrs.len() { // connection limits to_addr to 64 so this is guaranteed to succeed let mut a = ArrayVec::new(); a.try_extend_from_slice(to_addr).unwrap(); self.addrs.push((a, list::List::default())); } if self.nodes.len() == self.nodes.capacity() { return Err(()); } let nkey = self.nodes.insert(list::Node::new(ckey)); self.addrs[pos].1.push_back(&mut self.nodes, nkey); Ok(BatchKey { addr_index: 
pos, nkey, }) } fn remove(&mut self, key: BatchKey) { self.addrs[key.addr_index] .1 .remove(&mut self.nodes, key.nkey); self.nodes.remove(key.nkey); } fn take_group<'a, 'b: 'a, F>(&'a mut self, get_ids: F) -> Option where F: Fn(usize) -> (&'b [u8], u32), { // find the next addr with items while self.addr_index < self.addrs.len() && self.addrs[self.addr_index].1.is_empty() { self.addr_index += 1; } // if all are empty, we're done if self.addr_index == self.addrs.len() { return None; } let (addr, keys) = &mut self.addrs[self.addr_index]; self.last_group_ckeys.clear(); let mut ids = self.group_ids.get_as_new(); // get ids/seqs while ids.len() < zhttppacket::IDS_MAX { let nkey = match keys.pop_front(&mut self.nodes) { Some(nkey) => nkey, None => break, }; let ckey = self.nodes[nkey].value; self.nodes.remove(nkey); let (id, seq) = get_ids(ckey); self.last_group_ckeys.push(ckey); ids.push(zhttppacket::Id { id, seq: Some(seq) }); } Some(BatchGroup { addr, ids }) } fn last_group_ckeys(&self) -> &[usize] { &self.last_group_ckeys } } enum BatchType { KeepAlive, Cancel, } struct ConnectionItem { id: ArrayString<[u8; 32]>, stop: Option>, zreceiver_sender: Option, Option)>>, shared: Option>, batch_key: Option, } struct ConnectionItems { nodes: Slab>, next_cid: u32, batch: Batch, } impl ConnectionItems { fn new(capacity: usize, batch: Batch) -> Self { Self { nodes: Slab::with_capacity(capacity), next_cid: 0, batch, } } } struct ConnectionsInner { active: list::List, count: usize, max: usize, } struct Connections { items: Rc>, inner: RefCell, } impl Connections { fn new(items: Rc>, max: usize) -> Self { Self { items, inner: RefCell::new(ConnectionsInner { active: list::List::default(), count: 0, max, }), } } fn count(&self) -> usize { self.inner.borrow().count } fn max(&self) -> usize { self.inner.borrow().max } fn add( &self, worker_id: usize, stop: AsyncLocalSender<()>, zreceiver_sender: AsyncLocalSender<(arena::Rc, Option)>, shared: Option>, ) -> Result<(usize, 
ArrayString<[u8; 32]>), ()> { let items = &mut *self.items.borrow_mut(); let c = &mut *self.inner.borrow_mut(); if items.nodes.len() == items.nodes.capacity() { return Err(()); } let nkey = items.nodes.insert(list::Node::new(ConnectionItem { id: ArrayString::new(), stop: Some(stop), zreceiver_sender: Some(zreceiver_sender), shared, batch_key: None, })); items.nodes[nkey].value.id = gen_id(worker_id, nkey, &mut items.next_cid); c.active.push_back(&mut items.nodes, nkey); c.count += 1; Ok((nkey, items.nodes[nkey].value.id)) } fn remove(&self, ckey: usize) { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let c = &mut *self.inner.borrow_mut(); let ci = &mut items.nodes[nkey].value; // clear active keep alive if let Some(bkey) = ci.batch_key.take() { items.batch.remove(bkey); } c.active.remove(&mut items.nodes, nkey); c.count -= 1; items.nodes.remove(nkey); } fn regen_id(&self, worker_id: usize, ckey: usize) -> ArrayString<[u8; 32]> { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[nkey].value; // clear active keep alive if let Some(bkey) = ci.batch_key.take() { items.batch.remove(bkey); } ci.id = gen_id(worker_id, nkey, &mut items.next_cid); ci.id } fn check_id(&self, ckey: usize, id: &[u8]) -> bool { let nkey = ckey; let items = &*self.items.borrow(); let ci = match items.nodes.get(nkey) { Some(n) => &n.value, None => return false, }; ci.id.as_bytes() == id } fn take_zreceiver_sender( &self, ckey: usize, ) -> Option, Option)>> { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[nkey].value; ci.zreceiver_sender.take() } fn set_zreceiver_sender( &self, ckey: usize, sender: AsyncLocalSender<(arena::Rc, Option)>, ) { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[nkey].value; ci.zreceiver_sender = Some(sender); } fn stop_all(&self, about_to_stop: F) where F: Fn(usize), { let items = &mut *self.items.borrow_mut(); let cinner = 
&*self.inner.borrow_mut(); let mut next = cinner.active.head; while let Some(nkey) = next { let n = &mut items.nodes[nkey]; let ci = &mut n.value; about_to_stop(nkey); ci.stop = None; next = n.next; } } fn items_capacity(&self) -> usize { self.items.borrow().nodes.capacity() } fn is_item_stream(&self, ckey: usize) -> bool { let items = &*self.items.borrow(); match items.nodes.get(ckey) { Some(n) => { let ci = &n.value; ci.shared.is_some() } None => false, } } fn batch_is_empty(&self) -> bool { let items = &*self.items.borrow(); items.batch.is_empty() } fn batch_len(&self) -> usize { let items = &*self.items.borrow(); items.batch.len() } fn batch_capacity(&self) -> usize { let items = &*self.items.borrow(); items.batch.capacity() } fn batch_clear(&self) { let items = &mut *self.items.borrow_mut(); items.batch.clear(); } fn batch_add(&self, ckey: usize) -> Result<(), ()> { let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); // only batch connections with known handler addresses let addr_ref = cshared.to_addr(); let addr = match addr_ref.get() { Some(addr) => addr, None => return Err(()), }; let bkey = items.batch.add(addr, ckey)?; ci.batch_key = Some(bkey); Ok(()) } fn next_batch_message( &self, from: &str, btype: BatchType, ) -> Option<(usize, ArrayVec<[u8; 64]>, zmq::Message)> { let items = &mut *self.items.borrow_mut(); let nodes = &mut items.nodes; let batch = &mut items.batch; while !batch.is_empty() { let group = batch .take_group(|ckey| { let ci = &nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); (ci.id.as_bytes(), cshared.out_seq()) }) .unwrap(); let count = group.ids().len(); assert!(count <= zhttppacket::IDS_MAX); let zreq = zhttppacket::Request { from: from.as_bytes(), ids: group.ids(), multi: true, ptype: match btype { BatchType::KeepAlive => zhttppacket::RequestPacket::KeepAlive, BatchType::Cancel => zhttppacket::RequestPacket::Cancel, }, }; let mut data = 
                [0; BULK_PACKET_SIZE_MAX];

            let size = match zreq.serialize(&mut data) {
                Ok(size) => size,
                Err(e) => {
                    error!(
                        "failed to serialize keep-alive packet with {} ids: {}",
                        zreq.ids.len(),
                        e
                    );
                    continue;
                }
            };

            let data = &data[..size];

            let mut addr = ArrayVec::<[u8; 64]>::new();
            if addr.try_extend_from_slice(group.addr()).is_err() {
                error!("failed to prepare addr");
                continue;
            }

            let msg = zmq::Message::from(data);

            drop(group);

            for &ckey in batch.last_group_ckeys() {
                let ci = &mut nodes[ckey].value;
                let cshared = ci.shared.as_ref().unwrap().get();

                cshared.inc_out_seq();
                ci.batch_key = None;
            }

            return Some((count, addr, msg));
        }

        None
    }
}

#[derive(Clone)]
struct ConnectionOpts {
    instance_id: Rc,
    buffer_size: usize,
    timeout: Duration,
    rb_tmp: Rc,
    packet_buf: Rc>>,
    tmp_buf: Rc>>,
}

struct ConnectionReqOpts {
    body_buffer_size: usize,
    sender: channel::LocalSender,
}

struct ConnectionStreamOpts {
    messages_max: usize,
    sender: channel::LocalSender,
    sender_stream: channel::LocalSender<(ArrayVec<[u8; 64]>, zmq::Message)>,
    stream_shared_mem: Rc>,
}

enum ConnectionModeOpts {
    Req(ConnectionReqOpts),
    Stream(ConnectionStreamOpts),
}

struct Worker {
    thread: Option>,
    stop: Option>,
}

impl Worker {
    fn new(
        instance_id: &str,
        id: usize,
        req_maxconn: usize,
        stream_maxconn: usize,
        buffer_size: usize,
        body_buffer_size: usize,
        messages_max: usize,
        req_timeout: Duration,
        stream_timeout: Duration,
        req_acceptor: channel::Receiver<(usize, TcpStream, SocketAddr)>,
        stream_acceptor: channel::Receiver<(usize, TcpStream, SocketAddr)>,
        req_acceptor_tls: &Vec<(bool, Option)>,
        stream_acceptor_tls: &Vec<(bool, Option)>,
        identities: &Arc,
        zsockman: &Arc,
        handle_bound: usize,
    ) -> Self {
        debug!("worker {}: starting", id);

        let (stop, r_stop) = channel::channel(1);
        let (s_ready, ready) = channel::channel(1);

        let instance_id = String::from(instance_id);
        let req_acceptor_tls = req_acceptor_tls.clone();
        let stream_acceptor_tls = stream_acceptor_tls.clone();
        let identities = Arc::clone(identities);
        let zsockman =
            Arc::clone(zsockman);

        let thread = thread::spawn(move || {
            let maxconn = req_maxconn + stream_maxconn;

            // 1 task per connection, plus a handful of supporting tasks
            let tasks_max = maxconn + WORKER_NON_CONNECTION_TASKS_MAX;

            let registrations_max = REGISTRATIONS_PER_TASK_MAX * tasks_max;

            let reactor = Reactor::new(registrations_max);

            let executor = Executor::new(tasks_max);

            {
                let reactor = reactor.clone();

                executor.set_pre_poll(move || {
                    reactor.set_budget(Some(REACTOR_BUDGET));
                });
            }

            executor
                .spawn(Self::run(
                    r_stop,
                    s_ready,
                    instance_id,
                    id,
                    req_maxconn,
                    stream_maxconn,
                    buffer_size,
                    body_buffer_size,
                    messages_max,
                    req_timeout,
                    stream_timeout,
                    req_acceptor,
                    stream_acceptor,
                    req_acceptor_tls,
                    stream_acceptor_tls,
                    identities,
                    zsockman,
                    handle_bound,
                ))
                .unwrap();

            executor.run(|timeout| reactor.poll(timeout)).unwrap();

            debug!("worker {}: stopped", id);
        });

        ready.recv().unwrap();

        Self {
            thread: Some(thread),
            stop: Some(stop),
        }
    }

    async fn run(
        stop: channel::Receiver<()>,
        ready: channel::Sender<()>,
        instance_id: String,
        id: usize,
        req_maxconn: usize,
        stream_maxconn: usize,
        buffer_size: usize,
        body_buffer_size: usize,
        messages_max: usize,
        req_timeout: Duration,
        stream_timeout: Duration,
        req_acceptor: channel::Receiver<(usize, TcpStream, SocketAddr)>,
        stream_acceptor: channel::Receiver<(usize, TcpStream, SocketAddr)>,
        req_acceptor_tls: Vec<(bool, Option)>,
        stream_acceptor_tls: Vec<(bool, Option)>,
        identities: Arc,
        zsockman: Arc,
        handle_bound: usize,
    ) {
        let executor = Executor::current().unwrap();

        let stop = AsyncReceiver::new(stop);
        let req_acceptor = AsyncReceiver::new(req_acceptor);
        let stream_acceptor = AsyncReceiver::new(stream_acceptor);

        debug!("worker {}: allocating buffers", id);

        let rb_tmp = Rc::new(TmpBuffer::new(buffer_size));

        // large enough to fit anything
        let packet_buf = Rc::new(RefCell::new(vec![0; buffer_size + body_buffer_size + 4096]));

        // same size as working buffers
        let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size]));

        let instance_id =
            Rc::new(instance_id);

        let ka_batch = (stream_maxconn + (KEEP_ALIVE_BATCHES - 1)) / KEEP_ALIVE_BATCHES;

        let batch = Batch::new(ka_batch);

        let conn_items = Rc::new(RefCell::new(ConnectionItems::new(
            req_maxconn + stream_maxconn,
            batch,
        )));

        let req_conns = Rc::new(Connections::new(conn_items.clone(), req_maxconn));
        let stream_conns = Rc::new(Connections::new(conn_items.clone(), stream_maxconn));

        let (req_accept_stop, r_req_accept_stop) = async_local_channel(1, 1);
        let (stream_accept_stop, r_stream_accept_stop) = async_local_channel(1, 1);
        let (req_handle_stop, r_req_handle_stop) = async_local_channel(1, 1);
        let (stream_handle_stop, r_stream_handle_stop) = async_local_channel(1, 1);
        let (keep_alives_stop, r_keep_alives_stop) = async_local_channel(1, 1);

        let (s_req_accept_done, req_accept_done) = async_local_channel(1, 1);
        let (s_stream_accept_done, stream_accept_done) = async_local_channel(1, 1);
        let (s_req_handle_done, req_handle_done) = async_local_channel(1, 1);
        let (s_stream_handle_done, stream_handle_done) = async_local_channel(1, 1);
        let (s_keep_alives_done, keep_alives_done) = async_local_channel(1, 1);

        // max_senders is 1 per connection + 1 for the accept task
        let (zreq_sender, zreq_receiver) = local_channel(handle_bound, req_maxconn + 1);

        // max_senders is 1 per connection + 1 for the accept task
        let (zstream_out_sender, zstream_out_receiver) =
            local_channel(handle_bound, stream_maxconn + 1);

        // max_senders is 1 per connection + 1 for the accept task + 1 for the keep alives task
        let (zstream_out_stream_sender, zstream_out_stream_receiver) =
            local_channel(handle_bound, stream_maxconn + 2);

        let zreq_receiver = AsyncLocalReceiver::new(zreq_receiver);
        let zstream_out_receiver = AsyncLocalReceiver::new(zstream_out_receiver);
        let zstream_out_stream_receiver = AsyncLocalReceiver::new(zstream_out_stream_receiver);

        let req_handle = zhttpsocket::AsyncClientReqHandle::new(
            zsockman.client_req_handle(format!("{}-", id).as_bytes()),
        );

        let stream_handle =
            zhttpsocket::AsyncClientStreamHandle::new(
                zsockman.client_stream_handle(format!("{}-", id).as_bytes()),
            );

        let stream_shared_mem = Rc::new(arena::RcMemory::new(stream_maxconn));

        executor
            .spawn(Self::accept_task(
                "req_accept",
                id,
                r_req_accept_stop,
                s_req_accept_done,
                req_acceptor,
                req_acceptor_tls,
                identities.clone(),
                executor.spawner(),
                req_conns.clone(),
                ConnectionOpts {
                    instance_id: instance_id.clone(),
                    buffer_size,
                    timeout: req_timeout,
                    rb_tmp: rb_tmp.clone(),
                    packet_buf: packet_buf.clone(),
                    tmp_buf: tmp_buf.clone(),
                },
                ConnectionModeOpts::Req(ConnectionReqOpts {
                    body_buffer_size,
                    sender: zreq_sender,
                }),
            ))
            .unwrap();

        {
            let zstream_out_stream_sender = zstream_out_stream_sender
                .try_clone(&Reactor::current().unwrap().local_registration_memory())
                .unwrap();

            executor
                .spawn(Self::accept_task(
                    "stream_accept",
                    id,
                    r_stream_accept_stop,
                    s_stream_accept_done,
                    stream_acceptor,
                    stream_acceptor_tls,
                    identities.clone(),
                    executor.spawner(),
                    stream_conns.clone(),
                    ConnectionOpts {
                        instance_id: instance_id.clone(),
                        buffer_size,
                        timeout: stream_timeout,
                        rb_tmp: rb_tmp.clone(),
                        packet_buf: packet_buf.clone(),
                        tmp_buf: tmp_buf.clone(),
                    },
                    ConnectionModeOpts::Stream(ConnectionStreamOpts {
                        messages_max,
                        sender: zstream_out_sender,
                        sender_stream: zstream_out_stream_sender,
                        stream_shared_mem,
                    }),
                ))
                .unwrap();
        }

        executor
            .spawn(Self::req_handle_task(
                id,
                r_req_handle_stop,
                s_req_handle_done,
                zreq_receiver,
                req_handle,
                req_maxconn,
                req_conns.clone(),
            ))
            .unwrap();

        executor
            .spawn(Self::stream_handle_task(
                id,
                r_stream_handle_stop,
                s_stream_handle_done,
                instance_id.clone(),
                zstream_out_receiver,
                zstream_out_stream_receiver,
                stream_handle,
                stream_maxconn,
                stream_conns.clone(),
            ))
            .unwrap();

        executor
            .spawn(Self::keep_alives_task(
                id,
                r_keep_alives_stop,
                s_keep_alives_done,
                instance_id.clone(),
                zstream_out_stream_sender,
                stream_conns.clone(),
            ))
            .unwrap();

        debug!("worker {}: started", id);

        ready.send(()).unwrap();
        drop(ready);

        // wait for stop
        let _ = stop.recv().await;

        // stop all tasks
        drop(req_accept_stop);
        drop(stream_accept_stop);
        drop(req_handle_stop);
        drop(stream_handle_stop);
        drop(keep_alives_stop);

        // wait for all to stop
        let _ = req_accept_done.recv().await;
        let _ = stream_accept_done.recv().await;
        let _ = req_handle_done.recv().await;
        let stream_handle = stream_handle_done.recv().await.unwrap();
        let _ = keep_alives_done.recv().await;

        // send cancels

        stream_conns.batch_clear();

        let mut next_cancel_index = 0;

        while next_cancel_index < stream_conns.items_capacity() {
            while stream_conns.batch_len() < stream_conns.batch_capacity()
                && next_cancel_index < stream_conns.items_capacity()
            {
                let key = next_cancel_index;

                next_cancel_index += 1;

                if stream_conns.is_item_stream(key) {
                    // ignore errors
                    let _ = stream_conns.batch_add(key);
                }
            }

            while let Some((count, addr, msg)) =
                stream_conns.next_batch_message(&instance_id, BatchType::Cancel)
            {
                debug!("worker {}: sending cancels for {} sessions", id, count);

                stream_handle.send_to_addr(addr, msg).await.unwrap();
            }

            stream_conns.batch_clear();
        }
    }

    async fn accept_task(
        name: &str,
        id: usize,
        stop: AsyncLocalReceiver<()>,
        _done: AsyncLocalSender<()>,
        acceptor: AsyncReceiver<(usize, TcpStream, SocketAddr)>,
        acceptor_tls: Vec<(bool, Option)>,
        identities: Arc,
        spawner: Spawner,
        conns: Rc,
        opts: ConnectionOpts,
        mode_opts: ConnectionModeOpts,
    ) {
        let mut tls_acceptors = Vec::new();

        for config in acceptor_tls {
            if config.0 {
                let default_cert = config.1.as_ref().map(|s| s.as_str());
                tls_acceptors.push(Some(TlsAcceptor::new(&identities, default_cert)));
            } else {
                tls_acceptors.push(None);
            }
        }

        let reactor = Reactor::current().unwrap();

        // bound is 1 per connection, so all connections can indicate done at once
        // max_senders is 1 per connection + 1 for this task
        let (s_cdone, cdone) = channel::local_channel(
            conns.max(),
            conns.max() + 1,
            &reactor.local_registration_memory(),
        );

        let cdone = AsyncLocalReceiver::new(cdone);

        debug!("worker {}: task started: {}", id, name);

        loop {
            let acceptor_recv = if conns.count() < conns.max() {
                Some(acceptor.recv())
            } else {
                None
            };

            let (pos, stream, peer_addr) =
                match select_3(stop.recv(), cdone.recv(), select_option(acceptor_recv)).await {
                    // stop.recv
                    Select3::R1(_) => break,
                    // cdone.recv
                    Select3::R2(result) => match result {
                        Ok(cid) => {
                            conns.remove(cid);
                            continue;
                        }
                        Err(e) => panic!("cdone channel error: {}", e),
                    },
                    // acceptor_recv
                    Select3::R3(result) => match result {
                        Ok(ret) => ret,
                        Err(_) => continue, // ignore errors
                    },
                };

            let stream = set_socket_opts(stream);

            let stream = match &tls_acceptors[pos] {
                Some(tls_acceptor) => match tls_acceptor.accept(stream) {
                    Ok(stream) => {
                        debug!("worker {}: tls accept", id);

                        Stream::Tls(stream)
                    }
                    Err(e) => {
                        error!("worker {}: tls accept: {}", id, e);

                        break;
                    }
                },
                None => {
                    debug!("worker {}: plain accept", id);

                    Stream::Plain(stream)
                }
            };

            let (cstop, r_cstop) = async_local_channel(1, 1);

            let s_cdone = s_cdone
                .try_clone(&reactor.local_registration_memory())
                .unwrap();

            let (ckey, conn_id, zreceiver, mode_opts, shared) = match &mode_opts {
                ConnectionModeOpts::Req(req_opts) => {
                    let zreq_sender = req_opts
                        .sender
                        .try_clone(&reactor.local_registration_memory())
                        .unwrap();

                    let (zreq_receiver_sender, zreq_receiver) =
                        local_channel(RESP_SENDER_BOUND, 1);

                    let zreq_receiver_sender = AsyncLocalSender::new(zreq_receiver_sender);

                    let (ckey, conn_id) = conns.add(id, cstop, zreq_receiver_sender, None).unwrap();

                    debug!(
                        "worker {}: req conn starting {} {}/{}",
                        id,
                        ckey,
                        conns.count(),
                        conns.max(),
                    );

                    let mode_opts = ConnectionModeOpts::Req(ConnectionReqOpts {
                        body_buffer_size: req_opts.body_buffer_size,
                        sender: zreq_sender,
                    });

                    (ckey, conn_id, zreq_receiver, mode_opts, None)
                }
                ConnectionModeOpts::Stream(stream_opts) => {
                    let zstream_out_sender = stream_opts
                        .sender
                        .try_clone(&reactor.local_registration_memory())
                        .unwrap();

                    let zstream_out_stream_sender = stream_opts
                        .sender_stream
                        .try_clone(&reactor.local_registration_memory())
                        .unwrap();

                    let (zstream_receiver_sender, zstream_receiver) =
                        local_channel(RESP_SENDER_BOUND, 1);

                    let zstream_receiver_sender = AsyncLocalSender::new(zstream_receiver_sender);

                    let shared = arena::Rc::new(
                        ServerStreamSharedData::new(),
                        &stream_opts.stream_shared_mem,
                    )
                    .unwrap();

                    let (ckey, conn_id) = conns
                        .add(
                            id,
                            cstop,
                            zstream_receiver_sender,
                            Some(arena::Rc::clone(&shared)),
                        )
                        .unwrap();

                    debug!(
                        "worker {}: stream conn starting {} {}/{}",
                        id,
                        ckey,
                        conns.count(),
                        conns.max(),
                    );

                    let mode_opts = ConnectionModeOpts::Stream(ConnectionStreamOpts {
                        messages_max: stream_opts.messages_max,
                        sender: zstream_out_sender,
                        sender_stream: zstream_out_stream_sender,
                        stream_shared_mem: stream_opts.stream_shared_mem.clone(),
                    });

                    (ckey, conn_id, zstream_receiver, mode_opts, Some(shared))
                }
            };

            if spawner
                .spawn(Self::connection_task(
                    r_cstop,
                    s_cdone,
                    id,
                    ckey,
                    conn_id,
                    stream,
                    peer_addr,
                    zreceiver,
                    conns.clone(),
                    opts.clone(),
                    mode_opts,
                    shared,
                ))
                .is_err()
            {
                // this should never happen. we only accept a connection if
                // we know we can spawn
                panic!("failed to spawn connection_task");
            }
        }

        drop(s_cdone);

        conns.stop_all(|ckey| debug!("worker {}: stopping {}", id, ckey));

        while cdone.recv().await.is_ok() {}

        debug!("worker {}: task stopped: {}", id, name);
    }

    async fn req_handle_task(
        id: usize,
        stop: AsyncLocalReceiver<()>,
        _done: AsyncLocalSender<()>,
        zreq_receiver: AsyncLocalReceiver,
        req_handle: zhttpsocket::AsyncClientReqHandle,
        req_maxconn: usize,
        conns: Rc,
    ) {
        let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * req_maxconn);

        let req_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max));
        let req_resp_mem = Rc::new(arena::RcMemory::new(msg_retained_max));

        debug!("worker {}: task started: req_handle", id);

        let mut handle_send = None;

        'main: loop {
            let receiver_recv = if handle_send.is_none() {
                Some(zreq_receiver.recv())
            } else {
                None
            };

            match select_4(
                stop.recv(),
                select_option(receiver_recv),
                select_option_ref(handle_send.as_mut()),
                req_handle.recv(),
            )
            .await
            {
                // stop.recv
                Select4::R1(_) => break,
                // receiver_recv
                Select4::R2(result) => match result {
                    Ok(msg) => handle_send = Some(req_handle.send(msg)),
                    Err(mpsc::RecvError) => break, // this can happen if accept+conns end first
                },
                // handle_send
                Select4::R3(result) => {
                    handle_send = None;

                    if let Err(e) = result {
                        error!("req send error: {}", e);
                    }
                }
                // req_handle.recv
                Select4::R4(result) => match result {
                    Ok(msg) => {
                        let scratch = arena::Rc::new(
                            RefCell::new(zhttppacket::ResponseScratch::new()),
                            &req_scratch_mem,
                        )
                        .unwrap();

                        let zresp = match zhttppacket::OwnedResponse::parse(msg, 0, scratch) {
                            Ok(zresp) => zresp,
                            Err(e) => {
                                warn!("worker {}: zhttp parse error: {}", id, e);
                                continue;
                            }
                        };

                        let zresp = arena::Rc::new(zresp, &req_resp_mem).unwrap();

                        let mut count = 0;

                        for id in zresp.get().get().ids {
                            let key = match get_key(&id.id) {
                                Ok(key) => key,
                                Err(_) => continue,
                            };

                            if !conns.check_id(key, id.id) {
                                // key found but cid mismatch
                                continue;
                            }

                            if let Some(sender) = conns.take_zreceiver_sender(key) {
                                match select_2(
                                    stop.recv(),
                                    sender.send((arena::Rc::clone(&zresp), None)),
                                )
                                .await
                                {
                                    Select2::R1(_) => break 'main,
                                    Select2::R2(result) => match result {
                                        Ok(()) => count += 1,
                                        Err(_) => {}
                                    },
                                }

                                // need to re-check for validity after await
                                if conns.check_id(key, id.id) {
                                    conns.set_zreceiver_sender(key, sender);
                                }
                            }
                        }

                        debug!("worker {}: queued zmq message for {} conns", id, count);
                    }
                    Err(e) => panic!("worker {}: handle read error {}", id, e),
                },
            }
        }

        debug!("worker {}: task stopped: req_handle", id);
    }

    async fn stream_handle_task(
        id: usize,
        stop: AsyncLocalReceiver<()>,
        done: AsyncLocalSender,
        instance_id: Rc,
        zstream_out_receiver: AsyncLocalReceiver,
        zstream_out_stream_receiver: AsyncLocalReceiver<(ArrayVec<[u8; 64]>, zmq::Message)>,
        stream_handle: zhttpsocket::AsyncClientStreamHandle,
        stream_maxconn: usize,
        conns: Rc,
    ) {
        let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * stream_maxconn);

        let stream_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max));
        let stream_resp_mem =
            Rc::new(arena::RcMemory::new(msg_retained_max));

        debug!("worker {}: task started: stream_handle", id);

        let mut handle_send_to_any = None;
        let mut handle_send_to_addr = None;

        'main: loop {
            let receiver_recv = if handle_send_to_any.is_none() {
                Some(zstream_out_receiver.recv())
            } else {
                None
            };

            let stream_receiver_recv = if handle_send_to_addr.is_none() {
                Some(zstream_out_stream_receiver.recv())
            } else {
                None
            };

            match select_6(
                stop.recv(),
                select_option(receiver_recv),
                select_option_ref(handle_send_to_any.as_mut()),
                select_option(stream_receiver_recv),
                select_option_ref(handle_send_to_addr.as_mut()),
                stream_handle.recv(),
            )
            .await
            {
                // stop.recv
                Select6::R1(_) => break,
                // receiver_recv
                Select6::R2(result) => match result {
                    Ok(msg) => handle_send_to_any = Some(stream_handle.send_to_any(msg)),
                    Err(mpsc::RecvError) => break, // this can happen if accept+conns end first
                },
                // handle_send_to_any
                Select6::R3(result) => {
                    handle_send_to_any = None;

                    if let Err(e) = result {
                        error!("stream out send error: {}", e);
                    }
                }
                // stream_receiver_recv
                Select6::R4(result) => match result {
                    Ok((addr, msg)) => {
                        handle_send_to_addr = Some(stream_handle.send_to_addr(addr, msg))
                    }
                    Err(mpsc::RecvError) => break, // this can happen if accept+conns end first
                },
                // handle_send_to_addr
                Select6::R5(result) => {
                    handle_send_to_addr = None;

                    if let Err(e) = result {
                        error!("stream out stream send error: {}", e);
                    }
                }
                // stream_handle.recv
                Select6::R6(result) => match result {
                    Ok(msg) => {
                        let msg_data = &msg.get()[..];

                        let (addr, offset) = match get_addr_and_offset(msg_data) {
                            Ok(ret) => ret,
                            Err(_) => {
                                warn!("worker {}: packet has unexpected format", id);
                                continue;
                            }
                        };

                        if addr != &*instance_id {
                            warn!("worker {}: packet not for us", id);
                            continue;
                        }

                        let scratch = arena::Rc::new(
                            RefCell::new(zhttppacket::ResponseScratch::new()),
                            &stream_scratch_mem,
                        )
                        .unwrap();

                        let zresp = match zhttppacket::OwnedResponse::parse(msg, offset, scratch) {
                            Ok(zresp) => zresp,
                            Err(e) => {
                                warn!("worker {}: zhttp parse error: {}", id, e);
                                continue;
                            }
                        };

                        let zresp = arena::Rc::new(zresp, &stream_resp_mem).unwrap();

                        let mut count = 0;

                        for id in zresp.get().get().ids {
                            let key = match get_key(&id.id) {
                                Ok(key) => key,
                                Err(_) => continue,
                            };

                            if !conns.check_id(key, id.id) {
                                // key found but cid mismatch
                                continue;
                            }

                            if let Some(sender) = conns.take_zreceiver_sender(key) {
                                match select_2(
                                    stop.recv(),
                                    sender.send((arena::Rc::clone(&zresp), id.seq)),
                                )
                                .await
                                {
                                    Select2::R1(_) => break 'main,
                                    Select2::R2(result) => match result {
                                        Ok(()) => count += 1,
                                        Err(_) => {}
                                    },
                                }

                                // need to re-check for validity after await
                                if conns.check_id(key, id.id) {
                                    conns.set_zreceiver_sender(key, sender);
                                }
                            }
                        }

                        debug!("worker {}: queued zmq message for {} conns", id, count);
                    }
                    Err(e) => panic!("worker {}: handle read error {}", id, e),
                },
            }
        }

        drop(handle_send_to_any);
        drop(handle_send_to_addr);

        // give the handle back
        done.send(stream_handle).await.unwrap();

        debug!("worker {}: task stopped: stream_handle", id);
    }

    async fn connection_task(
        stop: AsyncLocalReceiver<()>,
        done: channel::LocalSender,
        worker_id: usize,
        ckey: usize,
        mut cid: ArrayString<[u8; 32]>,
        mut stream: Stream,
        peer_addr: SocketAddr,
        zreceiver: channel::LocalReceiver<(arena::Rc, Option)>,
        conns: Rc,
        opts: ConnectionOpts,
        mode_opts: ConnectionModeOpts,
        shared: Option>,
    ) {
        let done = AsyncLocalSender::new(done);

        debug!("worker {}: task started: connection-{}", worker_id, ckey);

        let reactor = Reactor::current().unwrap();

        let stream_registration = reactor
            .register_io(
                stream.get_tcp(),
                mio::Interest::READABLE | mio::Interest::WRITABLE,
            )
            .unwrap();

        let zreceiver_registration = reactor
            .register_custom_local(zreceiver.get_read_registration(), mio::Interest::READABLE)
            .unwrap();

        let (zsender1_registration, zsender2_registration, mut c) = {
            match mode_opts {
                ConnectionModeOpts::Req(req_opts) => {
                    let zsender_registration = reactor
                        .register_custom_local(
                            req_opts.sender.get_write_registration(),
                            mio::Interest::WRITABLE,
                        )
                        .unwrap();

                    let c =
                        Connection::new_req(
                            reactor.now(),
                            stream,
                            peer_addr,
                            opts.buffer_size,
                            req_opts.body_buffer_size,
                            &opts.rb_tmp,
                            opts.timeout,
                            req_opts.sender,
                            zreceiver,
                        );

                    (zsender_registration, None, c)
                }
                ConnectionModeOpts::Stream(stream_opts) => {
                    let zsender_registration = reactor
                        .register_custom_local(
                            stream_opts.sender.get_write_registration(),
                            mio::Interest::WRITABLE,
                        )
                        .unwrap();

                    let zsender_stream_registration = reactor
                        .register_custom_local(
                            stream_opts.sender_stream.get_write_registration(),
                            mio::Interest::WRITABLE,
                        )
                        .unwrap();

                    let c = Connection::new_stream(
                        reactor.now(),
                        stream,
                        peer_addr,
                        opts.buffer_size,
                        stream_opts.messages_max,
                        &opts.rb_tmp,
                        opts.timeout,
                        StreamLocalSenders::new(stream_opts.sender, stream_opts.sender_stream),
                        zreceiver,
                        shared.unwrap(),
                    );

                    (zsender_registration, Some(zsender_stream_registration), c)
                }
            }
        };

        let using_tls = match &c.stream {
            Stream::Tls(_) => true,
            _ => false,
        };

        c.start(cid.as_ref());

        let mut sleep = None;

        'main: loop {
            debug!("conn {}: process", c.id);

            if c.process(
                reactor.now(),
                &opts.instance_id,
                &mut *opts.packet_buf.borrow_mut(),
                &mut *opts.tmp_buf.borrow_mut(),
            ) {
                break;
            }

            if c.state() == ServerState::Ready {
                cid = conns.regen_id(worker_id, ckey);
                c.start(cid.as_ref());
                continue;
            }

            if let Some(want_exp_time) = c.want.timeout {
                let mut add = false;

                if let Some(exp_time) = c.timer {
                    if want_exp_time != exp_time {
                        add = true;
                    }
                } else {
                    add = true;
                }

                if add {
                    sleep = Some(AsyncSleep::new(want_exp_time));
                    c.timer = Some(want_exp_time);
                }
            } else {
                if c.timer.is_some() {
                    sleep = None;
                    c.timer = None;
                }
            }

            loop {
                let stream_wait = if c.want.sock_read || c.want.sock_write {
                    let interest = if c.want.sock_read && c.want.sock_write {
                        mio::Interest::READABLE | mio::Interest::WRITABLE
                    } else if c.want.sock_read {
                        mio::Interest::READABLE
                    } else {
                        mio::Interest::WRITABLE
                    };

                    Some(event_wait(&stream_registration, interest))
                } else {
                    None
                };

                // always read zhttp response packets so they can be applied immediately,
                // even if c.want.zhttp_read is false
                let zreceiver_wait = event_wait(&zreceiver_registration, mio::Interest::READABLE);

                let zsender1_wait = if c.want.zhttp_write {
                    Some(event_wait(&zsender1_registration, mio::Interest::WRITABLE))
                } else {
                    None
                };

                let zsender2_wait = if let Some(reg) = &zsender2_registration {
                    if c.want.zhttp_write_to {
                        Some(event_wait(reg, mio::Interest::WRITABLE))
                    } else {
                        None
                    }
                } else {
                    None
                };

                let sleep = if let Some(sleep) = &mut sleep {
                    Some(sleep.sleep())
                } else {
                    None
                };

                match select_6(
                    stop.recv(),
                    select_option(stream_wait),
                    zreceiver_wait,
                    select_option(zsender1_wait),
                    select_option(zsender2_wait),
                    select_option(sleep),
                )
                .await
                {
                    // stop.recv
                    Select6::R1(_) => break 'main,
                    // stream_wait
                    Select6::R2(readiness) => {
                        stream_registration.set_ready(false);

                        let readable = readiness.is_readable();
                        let writable = readiness.is_writable();

                        if readable {
                            debug!("conn {}: sock read event", c.id);
                        }

                        // for TLS, set readable on all events
                        if readable || using_tls {
                            c.set_sock_readable();
                        }

                        if writable {
                            debug!("conn {}: sock write event", c.id);
                        }

                        if (readable && c.want.sock_read) || (writable && c.want.sock_write) {
                            break;
                        }
                    }
                    // zreceiver_wait
                    Select6::R3(_) => {
                        debug!("conn {}: zreceiver event", c.id);
                        zreceiver_registration.set_ready(false);
                        break;
                    }
                    // zsender1_wait
                    Select6::R4(_) => {
                        debug!("conn {}: zsender1 event", c.id);
                        zsender1_registration.set_ready(false);
                        break;
                    }
                    // zsender2_wait
                    Select6::R5(_) => {
                        debug!("conn {}: zsender2 event", c.id);
                        zsender2_registration.as_ref().unwrap().set_ready(false);
                        c.set_out_stream_can_write();
                        break;
                    }
                    // sleep
                    Select6::R6(_) => {
                        debug!("conn {}: timeout", c.id);
                        break;
                    }
                }
            }
        }

        stream_registration.deregister_io(c.get_tcp()).unwrap();

        done.send(ckey).await.unwrap();

        debug!("worker {}: task stopped: connection-{}", worker_id, ckey);
    }

    async fn keep_alives_task(
        id: usize,
        stop: AsyncLocalReceiver<()>,
        _done: AsyncLocalSender<()>,
        instance_id: Rc,
        sender: channel::LocalSender<(ArrayVec<[u8; 64]>, zmq::Message)>,
        conns: Rc,
    ) {
        debug!("worker {}: task started: keep_alives", id);

        let reactor = Reactor::current().unwrap();

        let mut keep_alive_count = 0;
        let mut next_keep_alive_time = reactor.now() + KEEP_ALIVE_INTERVAL;
        let mut next_keep_alive = AsyncSleep::new(next_keep_alive_time);
        let mut next_keep_alive_index = 0;

        let sender_registration = reactor
            .register_custom_local(sender.get_write_registration(), mio::Interest::WRITABLE)
            .unwrap();

        sender_registration.set_readiness(Some(mio::Interest::WRITABLE));

        'main: loop {
            while conns.batch_is_empty() {
                // wait for next keep alive time
                match select_2(stop.recv(), next_keep_alive.sleep()).await {
                    Select2::R1(_) => break 'main,
                    Select2::R2(_) => {}
                }

                for _ in 0..conns.batch_capacity() {
                    if next_keep_alive_index >= conns.items_capacity() {
                        break;
                    }

                    let key = next_keep_alive_index;

                    next_keep_alive_index += 1;

                    if conns.is_item_stream(key) {
                        // ignore errors
                        let _ = conns.batch_add(key);
                    }
                }

                keep_alive_count += 1;

                if keep_alive_count >= KEEP_ALIVE_BATCHES {
                    keep_alive_count = 0;
                    next_keep_alive_index = 0;
                }

                // keep steady pace
                next_keep_alive_time += KEEP_ALIVE_INTERVAL;
                next_keep_alive = AsyncSleep::new(next_keep_alive_time);
            }

            match select_2(
                stop.recv(),
                event_wait(&sender_registration, mio::Interest::WRITABLE),
            )
            .await
            {
                Select2::R1(_) => break,
                Select2::R2(_) => {}
            }

            if !sender.check_send() {
                // if check_send returns false, we'll be on the waitlist for a notification
                sender_registration.clear_readiness(mio::Interest::WRITABLE);
                continue;
            }

            // if check_send returns true, we are guaranteed to be able to send

            match conns.next_batch_message(&instance_id, BatchType::KeepAlive) {
                Some((count, addr, msg)) => {
                    debug!("worker {}: sending keep alives for {} sessions", id, count);

                    if let Err(e) = sender.try_send((addr, msg)) {
                        error!("zhttp write error: {}", e);
                    }
                }
                None => {
                    // this could happen if message construction failed
                    sender.cancel();
                }
            }

            if conns.batch_is_empty() {
                conns.batch_clear();

                let now = reactor.now();

                if now >= next_keep_alive_time + KEEP_ALIVE_INTERVAL {
                    // got really behind somehow. just skip ahead
                    next_keep_alive_time = now + KEEP_ALIVE_INTERVAL;
                    next_keep_alive = AsyncSleep::new(next_keep_alive_time);
                }
            }
        }

        debug!("worker {}: task stopped: keep_alives", id);
    }
}

impl Drop for Worker {
    fn drop(&mut self) {
        self.stop = None;

        let thread = self.thread.take().unwrap();
        thread.join().unwrap();
    }
}

pub struct Server {
    addrs: Vec,

    // underscore-prefixed because we never reference after construction
    _workers: Vec,
    _req_listener: Listener,
    _stream_listener: Listener,
}

impl Server {
    pub fn new(
        instance_id: &str,
        worker_count: usize,
        req_maxconn: usize,
        stream_maxconn: usize,
        buffer_size: usize,
        body_buffer_size: usize,
        messages_max: usize,
        req_timeout: Duration,
        stream_timeout: Duration,
        listen_addrs: &[ListenConfig],
        certs_dir: &Path,
        zsockman: zhttpsocket::SocketManager,
        handle_bound: usize,
    ) -> Result {
        let identities = Arc::new(IdentityCache::new(certs_dir));

        let mut req_tcp_listeners = Vec::new();
        let mut stream_tcp_listeners = Vec::new();

        let mut req_acceptor_tls = Vec::new();
        let mut stream_acceptor_tls = Vec::new();

        let zsockman = Arc::new(zsockman);

        let mut addrs = Vec::new();

        for lc in listen_addrs.iter() {
            let l = match TcpListener::bind(lc.addr) {
                Ok(l) => l,
                Err(e) => return Err(format!("failed to bind {}: {}", lc.addr, e)),
            };

            let addr = l.local_addr().unwrap();

            info!("listening on {}", addr);

            addrs.push(addr);

            if lc.stream {
                stream_tcp_listeners.push(l);
                stream_acceptor_tls.push((lc.tls, lc.default_cert.clone()));
            } else {
                req_tcp_listeners.push(l);
                req_acceptor_tls.push((lc.tls, lc.default_cert.clone()));
            };
        }

        let mut workers = Vec::new();
        let mut req_lsenders = Vec::new();
        let mut stream_lsenders = Vec::new();

        for i in 0..worker_count {
            // rendezvous channels
            let (s, req_r) = channel::channel(0);
            req_lsenders.push(s);
            let (s, stream_r) = channel::channel(0);
            stream_lsenders.push(s);

            let w = Worker::new(
                instance_id,
                i,
                req_maxconn / worker_count,
                stream_maxconn / worker_count,
                buffer_size,
                body_buffer_size,
                messages_max,
                req_timeout,
                stream_timeout,
                req_r,
                stream_r,
                &req_acceptor_tls,
                &stream_acceptor_tls,
                &identities,
                &zsockman,
                handle_bound,
            );
            workers.push(w);
        }

        let req_listener = Listener::new(req_tcp_listeners, req_lsenders);
        let stream_listener = Listener::new(stream_tcp_listeners, stream_lsenders);

        Ok(Self {
            addrs: addrs,
            _workers: workers,
            _req_listener: req_listener,
            _stream_listener: stream_listener,
        })
    }

    pub fn addrs(&self) -> &[SocketAddr] {
        &self.addrs
    }
}

pub struct TestServer {
    server: Server,
    thread: Option>,
    stop: channel::Sender<()>,
}

impl TestServer {
    pub fn new(workers: usize) -> Self {
        let zmq_context = Arc::new(zmq::Context::new());

        let req_maxconn = 100;
        let stream_maxconn = 100;

        let maxconn = req_maxconn + stream_maxconn;

        let mut zsockman = zhttpsocket::SocketManager::new(
            Arc::clone(&zmq_context),
            "test",
            (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * workers),
            100,
            100,
        );

        zsockman
            .set_client_req_specs(&vec![SpecInfo {
                spec: String::from("inproc://server-test"),
                bind: true,
                ipc_file_mode: 0,
            }])
            .unwrap();

        zsockman
            .set_client_stream_specs(
                &vec![SpecInfo {
                    spec: String::from("inproc://server-test-out"),
                    bind: true,
                    ipc_file_mode: 0,
                }],
                &vec![SpecInfo {
                    spec: String::from("inproc://server-test-out-stream"),
                    bind: true,
                    ipc_file_mode: 0,
                }],
                &vec![SpecInfo {
                    spec: String::from("inproc://server-test-in"),
                    bind: true,
                    ipc_file_mode: 0,
                }],
            )
            .unwrap();

        let addr1 = "127.0.0.1:0".parse().unwrap();
        let addr2 = "127.0.0.1:0".parse().unwrap();

        let server = Server::new(
            "test",
            workers,
            req_maxconn,
            stream_maxconn,
            1024,
            1024,
            10,
            Duration::from_secs(5),
            Duration::from_secs(5),
            &vec![
                ListenConfig {
                    addr: addr1,
                    stream: false,
                    tls: false,
                    default_cert: None,
                },
                ListenConfig {
                    addr: addr2,
                    stream: true,
                    tls: false,
                    default_cert: None,
                },
            ],
            Path::new("."),
            zsockman,
            100,
        )
        .unwrap();

        let (started_s, started_r) = channel::channel(1);
        let (stop_s, stop_r) =
            channel::channel(1);

        let thread = thread::spawn(move || {
            Self::run(started_s, stop_r, zmq_context);
        });

        // wait for handler thread to start
        started_r.recv().unwrap();

        Self {
            server,
            thread: Some(thread),
            stop: stop_s,
        }
    }

    pub fn req_addr(&self) -> SocketAddr {
        self.server.addrs()[0]
    }

    pub fn stream_addr(&self) -> SocketAddr {
        self.server.addrs()[1]
    }

    fn respond(id: &[u8]) -> Result {
        let mut dest = [0; 1024];

        let mut cursor = io::Cursor::new(&mut dest[..]);

        cursor.write(b"T")?;

        let mut w = tnetstring::Writer::new(&mut cursor);

        w.start_map()?;

        w.write_string(b"id")?;
        w.write_string(id)?;

        w.write_string(b"code")?;
        w.write_int(200)?;

        w.write_string(b"reason")?;
        w.write_string(b"OK")?;

        w.write_string(b"body")?;
        w.write_string(b"world\n")?;

        w.end_map()?;

        w.flush()?;

        let size = cursor.position() as usize;

        Ok(zmq::Message::from(&dest[..size]))
    }

    fn respond_stream(id: &[u8]) -> Result {
        let mut dest = [0; 1024];

        let mut cursor = io::Cursor::new(&mut dest[..]);

        cursor.write(b"test T")?;

        let mut w = tnetstring::Writer::new(&mut cursor);

        w.start_map()?;

        w.write_string(b"from")?;
        w.write_string(b"handler")?;

        w.write_string(b"id")?;
        w.write_string(id)?;

        w.write_string(b"seq")?;
        w.write_int(0)?;

        w.write_string(b"code")?;
        w.write_int(200)?;

        w.write_string(b"reason")?;
        w.write_string(b"OK")?;

        w.write_string(b"headers")?;

        w.start_array()?;

        w.start_array()?;
        w.write_string(b"Content-Length")?;
        w.write_string(b"6")?;
        w.end_array()?;

        w.end_array()?;

        w.write_string(b"body")?;
        w.write_string(b"world\n")?;

        w.end_map()?;

        w.flush()?;

        let size = cursor.position() as usize;

        Ok(zmq::Message::from(&dest[..size]))
    }

    fn respond_ws(id: &[u8]) -> Result {
        let mut dest = [0; 1024];

        let mut cursor = io::Cursor::new(&mut dest[..]);

        cursor.write(b"test T")?;

        let mut w = tnetstring::Writer::new(&mut cursor);

        w.start_map()?;

        w.write_string(b"from")?;
        w.write_string(b"handler")?;

        w.write_string(b"id")?;
        w.write_string(id)?;

        w.write_string(b"seq")?;
        w.write_int(0)?;

        w.write_string(b"code")?;
        w.write_int(101)?;

        w.write_string(b"reason")?;
        w.write_string(b"Switching Protocols")?;

        w.end_map()?;

        w.flush()?;

        let size = cursor.position() as usize;

        Ok(zmq::Message::from(&dest[..size]))
    }

    fn respond_msg(
        id: &[u8],
        seq: u32,
        ptype: &str,
        content_type: &str,
        body: &[u8],
        code: Option,
    ) -> Result {
        let mut dest = [0; 1024];

        let mut cursor = io::Cursor::new(&mut dest[..]);

        cursor.write(b"test T")?;

        let mut w = tnetstring::Writer::new(&mut cursor);

        w.start_map()?;

        w.write_string(b"from")?;
        w.write_string(b"handler")?;

        w.write_string(b"id")?;
        w.write_string(id)?;

        w.write_string(b"seq")?;
        w.write_int(seq as isize)?;

        if ptype.is_empty() {
            w.write_string(b"content-type")?;
            w.write_string(content_type.as_bytes())?;
        } else {
            w.write_string(b"type")?;
            w.write_string(ptype.as_bytes())?;
        }

        if let Some(x) = code {
            w.write_string(b"code")?;
            w.write_int(x as isize)?;
        }

        w.write_string(b"body")?;
        w.write_string(body)?;

        w.end_map()?;

        w.flush()?;

        let size = cursor.position() as usize;

        Ok(zmq::Message::from(&dest[..size]))
    }

    fn run(
        started: channel::Sender<()>,
        stop: channel::Receiver<()>,
        zmq_context: Arc,
    ) {
        let rep_sock = zmq_context.socket(zmq::REP).unwrap();
        rep_sock.connect("inproc://server-test").unwrap();

        let in_sock = zmq_context.socket(zmq::PULL).unwrap();
        in_sock.connect("inproc://server-test-out").unwrap();

        let in_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap();
        in_stream_sock.set_identity(b"handler").unwrap();
        in_stream_sock
            .connect("inproc://server-test-out-stream")
            .unwrap();

        let out_sock = zmq_context.socket(zmq::XPUB).unwrap();
        out_sock.connect("inproc://server-test-in").unwrap();

        // ensure zsockman is subscribed
        let msg = out_sock.recv_msg(0).unwrap();
        assert_eq!(&msg[..], b"\x01test ");

        started.send(()).unwrap();

        let mut poller = event::Poller::new(1).unwrap();

        poller
            .register_custom(
                stop.get_read_registration(),
                mio::Token(1),
                mio::Interest::READABLE,
            )
            .unwrap();

        poller
            .register(
                &mut SourceFd(&rep_sock.get_fd().unwrap()),
                mio::Token(2),
mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&in_sock.get_fd().unwrap()), mio::Token(3), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&in_stream_sock.get_fd().unwrap()), mio::Token(4), mio::Interest::READABLE, ) .unwrap(); let mut rep_events = rep_sock.get_events().unwrap(); let mut in_events = in_sock.get_events().unwrap(); let mut in_stream_events = in_stream_sock.get_events().unwrap(); loop { while rep_events.contains(zmq::POLLIN) { let parts = match rep_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { break; } Err(e) => panic!("recv error: {:?}", e), }; assert_eq!(parts.len(), 1); let msg = &parts[0]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut method = ""; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(&f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "method" => { let s = tnetstring::parse_string(&f.data).unwrap(); method = str::from_utf8(s).unwrap(); } _ => {} } } assert_eq!(method, "GET"); let msg = Self::respond(id.as_bytes()).unwrap(); rep_sock.send(msg, 0).unwrap(); rep_events = rep_sock.get_events().unwrap(); } while in_events.contains(zmq::POLLIN) { let parts = match in_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { break; } Err(e) => panic!("recv error: {:?}", e), }; in_events = in_sock.get_events().unwrap(); assert_eq!(parts.len(), 1); let msg = &parts[0]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut method = ""; let mut uri = ""; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(&f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "method" => { let s = tnetstring::parse_string(&f.data).unwrap(); method = str::from_utf8(s).unwrap(); } "uri" => { let s = tnetstring::parse_string(&f.data).unwrap(); uri = str::from_utf8(s).unwrap(); } _ => {} } } 
assert_eq!(method, "GET"); if uri.starts_with("ws:") { let msg = Self::respond_ws(id.as_bytes()).unwrap(); out_sock.send(msg, 0).unwrap(); } else { let msg = Self::respond_stream(id.as_bytes()).unwrap(); out_sock.send(msg, 0).unwrap(); } } while in_stream_events.contains(zmq::POLLIN) { let parts = match in_stream_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { break; } Err(e) => panic!("recv error: {:?}", e), }; in_stream_events = in_stream_sock.get_events().unwrap(); assert_eq!(parts.len(), 3); assert_eq!(parts[1].len(), 0); let msg = &parts[2]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut seq = None; let mut ptype = ""; let mut content_type = ""; let mut body = &b""[..]; let mut code = None; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(&f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "seq" => { seq = Some(tnetstring::parse_int(&f.data).unwrap() as u32); } "type" => { let s = tnetstring::parse_string(&f.data).unwrap(); ptype = str::from_utf8(s).unwrap(); } "content-type" => { let s = tnetstring::parse_string(&f.data).unwrap(); content_type = str::from_utf8(s).unwrap(); } "body" => { body = tnetstring::parse_string(&f.data).unwrap(); } "code" => { code = Some(tnetstring::parse_int(&f.data).unwrap() as u16); } _ => {} } } let seq = seq.unwrap(); // as a hack to make the test server stateless, respond to every message // using the received sequence number. 
for messages we don't care about, // respond with keep-alive in order to keep the sequencing going if ptype.is_empty() || ptype == "ping" || ptype == "pong" || ptype == "close" { if ptype == "ping" { ptype = "pong"; } let msg = Self::respond_msg(id.as_bytes(), seq, ptype, content_type, body, code) .unwrap(); out_sock.send(msg, 0).unwrap(); } else { let msg = Self::respond_msg(id.as_bytes(), seq, "keep-alive", "", &b""[..], None) .unwrap(); out_sock.send(msg, 0).unwrap(); } } poller.poll(None).unwrap(); let mut done = false; for event in poller.iter_events() { match event.token() { mio::Token(1) => { if stop.try_recv().is_ok() { done = true; break; } } mio::Token(2) => { rep_events = rep_sock.get_events().unwrap(); } mio::Token(3) => { in_events = in_sock.get_events().unwrap(); } mio::Token(4) => { in_stream_events = in_stream_sock.get_events().unwrap(); } _ => unreachable!(), } } if done { break; } } } } impl Drop for TestServer { fn drop(&mut self) { self.stop.try_send(()).unwrap(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } #[cfg(test)] pub mod tests { use super::*; use crate::websocket; use std::io::Read; #[test] fn test_batch() { let mut batch = Batch::new(3); assert_eq!(batch.capacity(), 3); assert_eq!(batch.len(), 0); assert_eq!(batch.last_group_ckeys(), &[]); assert!(batch.add(b"addr-a", 1).is_ok()); assert!(batch.add(b"addr-a", 2).is_ok()); assert!(batch.add(b"addr-b", 3).is_ok()); assert_eq!(batch.len(), 3); assert!(batch.add(b"addr-c", 4).is_err()); assert_eq!(batch.len(), 3); assert_eq!(batch.is_empty(), false); let ids = ["id-1", "id-2", "id-3"]; let group = batch .take_group(|ckey| (ids[ckey - 1].as_bytes(), 0)) .unwrap(); assert_eq!(group.ids().len(), 2); assert_eq!(group.ids()[0].id, b"id-1"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.ids()[1].id, b"id-2"); assert_eq!(group.ids()[1].seq, Some(0)); assert_eq!(group.addr(), b"addr-a"); drop(group); assert_eq!(batch.is_empty(), false); 
assert_eq!(batch.last_group_ckeys(), &[1, 2]); let group = batch .take_group(|ckey| (ids[ckey - 1].as_bytes(), 0)) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-3"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-b"); drop(group); assert_eq!(batch.is_empty(), true); assert_eq!(batch.last_group_ckeys(), &[3]); assert!(batch .take_group(|ckey| { (ids[ckey - 1].as_bytes(), 0) }) .is_none()); assert_eq!(batch.last_group_ckeys(), &[3]); } #[test] fn test_server() { let server = TestServer::new(1); // req let mut client = std::net::TcpStream::connect(&server.req_addr()).unwrap(); client .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n") .unwrap(); let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!( str::from_utf8(&buf).unwrap(), "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n" ); // stream (http) let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); client .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n") .unwrap(); let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!( str::from_utf8(&buf).unwrap(), "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n" ); // stream (ws) let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); let req = concat!( "GET /hello HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Key: abcde\r\n", "\r\n", ); client.write(req.as_bytes()).unwrap(); let mut buf = Vec::new(); let mut resp_end = 0; loop { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { resp_end = i + 4; break; } } if resp_end > 0 { break; } } let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "\r\n", ); 
assert_eq!(str::from_utf8(&buf[..resp_end]).unwrap(), expected); buf = buf.split_off(resp_end); // send message let mut data = vec![0; 1024]; let body = &b"hello"[..]; let size = websocket::write_header(true, websocket::OPCODE_TEXT, body.len(), None, &mut data) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); client.write(&data[..(size + body.len())]).unwrap(); // recv message let mut msg = Vec::new(); loop { let fi = match websocket::read_header(&buf) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); assert!(size > 0); buf.extend_from_slice(&chunk[..size]); continue; } Err(e) => panic!("{:?}", e), }; assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); msg.extend_from_slice(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]); break; } assert_eq!(str::from_utf8(&msg).unwrap(), "hello"); } #[test] fn test_ws() { let server = TestServer::new(1); let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); let req = concat!( "GET /hello HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Key: abcde\r\n", "\r\n", ); client.write(req.as_bytes()).unwrap(); let mut buf = Vec::new(); let mut resp_end = 0; loop { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { resp_end = i + 4; break; } } if resp_end > 0 { break; } } let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "\r\n", ); assert_eq!(str::from_utf8(&buf[..resp_end]).unwrap(), expected); buf = buf.split_off(resp_end); // send binary let mut data = vec![0; 1024]; let body = &[1, 2, 3][..]; let size = websocket::write_header(true, websocket::OPCODE_BINARY, body.len(), None, &mut
data) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); client.write(&data[..(size + body.len())]).unwrap(); // recv binary let mut msg = Vec::new(); loop { let fi = match websocket::read_header(&buf) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); assert!(size > 0); buf.extend_from_slice(&chunk[..size]); continue; } Err(e) => panic!("{:?}", e), }; assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_BINARY); msg.extend_from_slice(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]); break; } assert_eq!(msg, &[1, 2, 3][..]); buf.clear(); // send ping let mut data = vec![0; 1024]; let body = &b""[..]; let size = websocket::write_header(true, websocket::OPCODE_PING, body.len(), None, &mut data) .unwrap(); client.write(&data[..size]).unwrap(); // recv pong let mut msg = Vec::new(); loop { let fi = match websocket::read_header(&buf) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); assert!(size > 0); buf.extend_from_slice(&chunk[..size]); continue; } Err(e) => panic!("{:?}", e), }; assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_PONG); msg.extend_from_slice(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]); break; } assert_eq!(str::from_utf8(&msg).unwrap(), ""); buf.clear(); // send close let mut data = vec![0; 1024]; let body = &b"\x03\xf0gone"[..]; let size = websocket::write_header(true, websocket::OPCODE_CLOSE, body.len(), None, &mut data) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); client.write(&data[..(size + body.len())]).unwrap(); // recv close let mut msg = Vec::new(); loop { let fi = match websocket::read_header(&buf) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); 
assert!(size > 0); buf.extend_from_slice(&chunk[..size]); continue; } Err(e) => panic!("{:?}", e), }; assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_CLOSE); msg.extend_from_slice(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]); break; } assert_eq!(msg, &b"\x03\xf0gone"[..]); // expect tcp close let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); assert_eq!(size, 0); } } condure-1.3.1/src/shuffle.rs000064400000000000000000000037110000000000000140350ustar 00000000000000/* * Copyright (c) 2016 Alex Crichton * Copyright (c) 2017 The Tokio Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use std::{ cell::Cell, collections::hash_map::DefaultHasher, hash::Hasher, num::Wrapping, sync::atomic::{AtomicUsize, Ordering}, }; // Based on [Fisher–Yates shuffle]. // // [Fisher–Yates shuffle]: https://en.wikipedia.org/wiki/Fisher–Yates_shuffle pub fn shuffle<T>(slice: &mut [T]) { for i in (1..slice.len()).rev() { slice.swap(i, gen_index(i + 1)); } } /// Return a value from `0..n`. fn gen_index(n: usize) -> usize { (random() % n as u64) as usize } /// Pseudorandom number generator based on [xorshift*]. /// /// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift* fn random() -> u64 { thread_local!
{ static RNG: Cell<Wrapping<u64>> = Cell::new(Wrapping(prng_seed())); } fn prng_seed() -> u64 { static COUNTER: AtomicUsize = AtomicUsize::new(0); // Any non-zero seed will do let mut seed = 0; while seed == 0 { let mut hasher = DefaultHasher::new(); hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed)); seed = hasher.finish(); } seed } RNG.with(|rng| { let mut x = rng.get(); debug_assert_ne!(x.0, 0); x ^= x >> 12; x ^= x << 25; x ^= x >> 27; rng.set(x); x.0.wrapping_mul(0x2545_f491_4f6c_dd1d) }) } condure-1.3.1/src/timer.rs000064400000000000000000000406450000000000000135300ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ // adapted from http://25thandclement.com/~william/projects/timeout.c.html (MIT licensed) use crate::list; use slab::Slab; use std::cmp; const WHEEL_BITS: usize = 6; const WHEEL_NUM: usize = 4; const WHEEL_LEN: usize = 1 << WHEEL_BITS; const WHEEL_MAX: usize = WHEEL_LEN - 1; const WHEEL_MASK: u64 = (WHEEL_LEN as u64) - 1; const TIMEOUT_MAX: u64 = (1 << (WHEEL_BITS * WHEEL_NUM)) - 1; // find last set fn fls64(x: u64) -> u32 { 64 - x.leading_zeros() } fn need_resched(curtime: u64, newtime: u64) -> [u64; WHEEL_NUM] { let mut result = [0; WHEEL_NUM]; // no time elapsed if newtime <= curtime { return result; } let mut elapsed = newtime - curtime; for wheel in 0..WHEEL_NUM { // we only care about the highest bits let trunc_bits = (wheel * WHEEL_BITS) as u64; let pending; if (elapsed >> trunc_bits) > (WHEEL_MAX as u64) { // all slots need processing pending = !0; } else { let old_slot = (curtime >> trunc_bits) & WHEEL_MASK; let new_slot = (newtime >> trunc_bits) & WHEEL_MASK; let d = if new_slot > old_slot { new_slot - old_slot } else { (WHEEL_LEN as u64) - old_slot + new_slot }; pending = if d >= WHEEL_LEN as u64 { !0 } else if wheel > 0 { ((1 << d) - 1u64).rotate_left(old_slot as u32) } else { ((1 << d) - 1u64).rotate_left((old_slot + 1) as u32) }; } result[wheel] = pending; let finished_bit = if wheel > 0 { // higher wheels have completed a full rotation when slot 63 is processed 1 << (WHEEL_LEN - 1) } else { // lowest wheel has completed a full rotation when slot 0 is processed 1 }; // if the current wheel didn't finish a full rotation then we don't need to look // at higher wheels if pending & finished_bit == 0 { break; } // ensure the elapsed time includes the current slot of the next wheel elapsed = cmp::max(elapsed, (WHEEL_LEN << (wheel * WHEEL_BITS)) as u64); } result } #[cfg(test)] fn need_resched_simple(curtime: u64, newtime: u64) -> [u64; WHEEL_NUM] { let mut result = [0; WHEEL_NUM]; // no time elapsed if newtime <= curtime { return result; } for curtime in 
curtime..newtime { for wheel in 0..WHEEL_NUM { let trunc_bits = (wheel * WHEEL_BITS) as u64; let old_slot = (curtime >> trunc_bits) & WHEEL_MASK; let new_slot = ((curtime + 1) >> trunc_bits) & WHEEL_MASK; if old_slot != new_slot { if wheel > 0 { result[wheel] |= 1 << old_slot; } else { result[wheel] |= 1 << new_slot; } } } } result } enum InList { Wheel(usize, usize), Expired, } struct Timer { expires: u64, list: Option<InList>, user_data: usize, } pub struct TimerWheel { nodes: Slab<list::Node<Timer>>, wheel: [[list::List; WHEEL_LEN]; WHEEL_NUM], expired: list::List, pending: [u64; WHEEL_NUM], curtime: u64, } impl TimerWheel { pub fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), wheel: [[list::List::default(); WHEEL_LEN]; WHEEL_NUM], expired: list::List::default(), pending: [0; WHEEL_NUM], curtime: 0, } } pub fn add(&mut self, expires: u64, user_data: usize) -> Result<usize, ()> { if self.nodes.len() == self.nodes.capacity() { return Err(()); } let t = Timer { expires, list: None, user_data, }; let key = self.nodes.insert(list::Node::new(t)); self.sched(key); Ok(key) } pub fn remove(&mut self, key: usize) { let n = match self.nodes.get(key) { Some(n) => n, None => return, }; match n.value.list { Some(InList::Wheel(wheel, slot)) => { let l = &mut self.wheel[wheel][slot]; l.remove(&mut self.nodes, key); if l.is_empty() { self.pending[wheel] &= !(1 << slot); } } Some(InList::Expired) => { self.expired.remove(&mut self.nodes, key); } None => {} } self.nodes.remove(key); } pub fn timeout(&self) -> Option<u64> { if !self.expired.is_empty() { return Some(0); } let mut timeout = None; let mut relmask = 0; for wheel in 0..WHEEL_NUM { // we only care about the highest bits let trunc_bits = (wheel * WHEEL_BITS) as u64; if self.pending[wheel] != 0 { let slot = ((self.curtime >> trunc_bits) & WHEEL_MASK) as usize; let pending = self.pending[wheel].rotate_right(slot as u32); // for higher order wheels, timeouts are one step in the future let offset = if wheel > 0 { 1 } else { 0 }; // 
pending is guaranteed to be non-zero let t = ((pending.trailing_zeros() as u64) + offset) << trunc_bits; // reduce by how much lower wheels have progressed let t = t - (relmask & self.curtime); timeout = Some(match timeout { Some(best) => cmp::min(best, t), None => t, }); } relmask <<= WHEEL_BITS; relmask |= WHEEL_MASK; } timeout } pub fn update(&mut self, curtime: u64) { // time must go forward if curtime <= self.curtime { return; } let need = need_resched(self.curtime, curtime); let mut l = list::List::default(); for wheel in 0..WHEEL_NUM { let pending = need[wheel]; // loop as long as we still have slots to process while pending & self.pending[wheel] != 0 { // get rightmost (earliest) slot that needs processing let slot = (pending & self.pending[wheel]).trailing_zeros() as usize; // move the timers out l.concat(&mut self.nodes, &mut self.wheel[wheel][slot]); self.pending[wheel] &= !(1 << slot); } } self.curtime = curtime; while let Some(key) = l.head { l.remove(&mut self.nodes, key); let n = &mut self.nodes[key]; n.value.list = None; self.sched(key); } } pub fn take_expired(&mut self) -> Option<(usize, usize)> { match self.expired.pop_front(&mut self.nodes) { Some(key) => { let n = &self.nodes[key]; let user_data = n.value.user_data; self.nodes.remove(key); Some((key, user_data)) } None => None, } } fn sched(&mut self, key: usize) { let n = &self.nodes[key]; let expires = n.value.expires; if expires > self.curtime { // get relative timeout, capped let t = cmp::min(expires - self.curtime, TIMEOUT_MAX); assert!(t > 0); // wheel is selected by relative time // t = 0 = not valid // t = 1 = 0b0_000000_000001 -> fls 1, wheel 0 // t = 63 = 0b0_000000_111111 -> fls 6, wheel 0 // t = 64 = 0b0_000001_000000 -> fls 7, wheel 1 // t = 4032 = 0b0_111111_000000 -> fls 12, wheel 1 // t = 4095 = 0b0_111111_111111 -> fls 12, wheel 1 // t = 4096 = 0b1_000000_000000 -> fls 13, wheel 2 let wheel = ((fls64(t) - 1) as usize) / WHEEL_BITS; assert!(wheel < WHEEL_NUM); // we only care 
about the highest bits let trunc_bits = (wheel * WHEEL_BITS) as u64; // for higher order wheels, schedule 1 slot early. this way, fractional // time remaining can be rescheduled to a lower wheel let offset = if wheel > 0 { 1 } else { 0 }; // slot is selected by absolute time let slot = (((expires >> trunc_bits) - offset) & WHEEL_MASK) as usize; self.wheel[wheel][slot].push_back(&mut self.nodes, key); self.pending[wheel] |= 1 << slot; let n = &mut self.nodes[key]; n.value.list = Some(InList::Wheel(wheel, slot)); } else { self.expired.push_back(&mut self.nodes, key); let n = &mut self.nodes[key]; n.value.list = Some(InList::Expired); } } } #[cfg(test)] mod tests { use super::*; // convert string time of the form "x:x:x:x", where each part is a number between 0-63 fn ts(s: &str) -> u64 { let mut result = 0; for (i, part) in s.rsplit(":").enumerate() { let x: u64 = part.parse().unwrap(); assert!(x <= (WHEEL_MAX as u64)); result |= x << (i * WHEEL_BITS); } result } // convert string range to bits fn r2b(s: &str) -> u64 { let mut it = s.split("-"); let start = it.next().unwrap(); let end = it.next().unwrap(); assert_eq!(it.next(), None); let mut pos: u64 = start.parse().unwrap(); let end: u64 = end.parse().unwrap(); let mut result = 0; loop { result |= 1 << pos; if pos == end { break; } pos = (pos + 1) & WHEEL_MASK; } result } // convert wheel ranges of the form "x:x:x:x", where each part is a range fn r2w(s: &str) -> [u64; WHEEL_NUM] { let mut result = [0; WHEEL_NUM]; for (i, part) in s.rsplit(":").enumerate() { if !part.is_empty() { result[i] = r2b(part); } } result } #[test] fn test_fls() { assert_eq!(fls64(0), 0); assert_eq!(fls64(0b1), 1); assert_eq!(fls64(0b10), 2); assert_eq!(fls64(0x4000000000000000), 63); assert_eq!(fls64(0x8000000000000000), 64); } #[test] fn test_sched() { let mut w = TimerWheel::new(10); w.update(7); // expired let t1 = w.add(0b0_000000, 1).unwrap(); // wheel 0 slot 8 (1 tick away) let t2 = w.add(0b0_001000, 1).unwrap(); // wheel 0 slot 63 
(56 ticks away) let t3 = w.add(0b0_111111, 1).unwrap(); // wheel 0 slot 0 (57 ticks away) let t4 = w.add(0b1_000000, 1).unwrap(); // wheel 1 slot 0 let t5 = w.add(0b1_001000, 1).unwrap(); // wheel 3 slot 63 let t6 = w.add(0b1_000000_000000_000000_000000, 1).unwrap(); assert_eq!(w.expired.head, Some(t1)); assert_eq!(w.wheel[0][8].head, Some(t2)); assert_eq!(w.wheel[0][63].head, Some(t3)); assert_eq!(w.wheel[0][0].head, Some(t4)); assert_eq!(w.wheel[1][0].head, Some(t5)); assert_eq!(w.wheel[3][63].head, Some(t6)); } #[test] fn test_need_resched() { struct Test { curtime: &'static str, newtime: &'static str, expected: &'static str, } fn t(curtime: &'static str, newtime: &'static str, expected: &'static str) -> Test { Test { curtime, newtime, expected, } } let table = [ t("00:00", "00:00", ""), t("00:00", "00:01", "01-01"), t("00:01", "00:02", "02-02"), t("00:02", "00:63", "03-63"), t("00:63", "01:00", "00-00:00-00"), t("01:00", "01:02", "01-02"), t("01:02", "05:01", "01-04:00-63"), t("05:01", "05:02", "02-02"), t("05:02", "06:01", "05-05:03-01"), t("00:63:63", "01:00:00", "00-00:63-63:00-00"), t("08:00:00", "08:01:00", "00-00:00-63"), t("04:00:02", "05:01:00", "04-04:00-63:00-63"), t("04:01:02", "05:00:00", "04-04:01-63:00-63"), t("04:00:03", "05:00:00", "04-04:00-63:00-63"), t("04:00:02", "05:00:00", "04-04:00-63:00-63"), t("08:00:19", "08:62:63", "00-61:00-63"), t("08:00:19", "08:63:63", "00-62:00-63"), t("09:00:00", "09:63:62", "00-62:00-63"), ]; for (row, t) in table.iter().enumerate() { let curtime = ts(t.curtime); let newtime = ts(t.newtime); let expected = r2w(t.expected); // ensure the simple algorithm returns what we expect assert_eq!( need_resched_simple(curtime, newtime), expected, "row={} curtime={} newtime={}", row, curtime, newtime ); // ensure the optimized algorithm returns matching results assert_eq!( need_resched(curtime, newtime), expected, "row={} curtime={} newtime={}", row, curtime, newtime ); } } #[test] fn test_rotate() { // test full rotations 
through wheels 0 and 1, and one step of wheel 2 let count = (64 * 64) + 1; let mut w = TimerWheel::new(count); for i in 0..count { w.add(i as u64, i).unwrap(); } for i in 0..count { let (_, v) = w.take_expired().unwrap(); assert_eq!(v, i); assert_eq!(w.take_expired(), None); w.update((i + 1) as u64); } assert_eq!(w.take_expired(), None); } #[test] fn test_wheel() { let mut w = TimerWheel::new(10); assert_eq!(w.timeout(), None); assert_eq!(w.take_expired(), None); let t1 = w.add(4, 1).unwrap(); assert_eq!(w.timeout(), Some(4)); w.remove(t1); assert_eq!(w.timeout(), None); w.update(5); assert_eq!(w.take_expired(), None); let t2 = w.add(8, 2).unwrap(); assert_eq!(w.timeout(), Some(3)); w.update(7); assert_eq!(w.timeout(), Some(1)); assert_eq!(w.take_expired(), None); w.update(8); assert_eq!(w.timeout(), Some(0)); assert_eq!(w.take_expired(), Some((t2, 2))); assert_eq!(w.take_expired(), None); for i in 0..2 { let base = i * 20_000_000; let t1 = w.add(base + 1, 1).unwrap(); let t2 = w.add(base + 10, 2).unwrap(); let t3 = w.add(base + 1_000, 3).unwrap(); let t4 = w.add(base + 100_000, 4).unwrap(); let t5 = w.add(base + 10_000_000, 5).unwrap(); w.update(base + 100); assert_eq!(w.timeout(), Some(0)); assert_eq!(w.take_expired(), Some((t1, 1))); assert_eq!(w.take_expired(), Some((t2, 2))); assert_eq!(w.take_expired(), None); assert!(w.timeout().unwrap() <= 900); w.update(base + 2_000); assert_eq!(w.timeout(), Some(0)); assert_eq!(w.take_expired(), Some((t3, 3))); assert_eq!(w.take_expired(), None); assert!(w.timeout().unwrap() <= 98_000); w.update(base + 200_000); assert_eq!(w.timeout(), Some(0)); assert_eq!(w.take_expired(), Some((t4, 4))); assert_eq!(w.take_expired(), None); assert!(w.timeout().unwrap() <= 9_800_000); w.update(base + 12_000_000); assert_eq!(w.timeout(), Some(0)); assert_eq!(w.take_expired(), Some((t5, 5))); assert_eq!(w.take_expired(), None); assert_eq!(w.timeout(), None); } } } 
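The wheel/slot table in the `TimerWheel::sched` comment can be checked by hand with a standalone sketch like the one below. It copies the timer.rs constants and nothing else; `select` is a hypothetical helper, not part of condure, and it returns `None` where `sched` would push the timer onto the expired list.

```rust
// Hypothetical standalone re-derivation of TimerWheel::sched's placement logic.
// Constants mirror timer.rs; `select` is illustrative, not a condure API.
const WHEEL_BITS: usize = 6;
const WHEEL_NUM: usize = 4;
const WHEEL_MASK: u64 = (1 << WHEEL_BITS) - 1;
const TIMEOUT_MAX: u64 = (1 << (WHEEL_BITS * WHEEL_NUM)) - 1;

/// Returns (wheel, slot) for a timer expiring at `expires` when the current
/// time is `curtime`, or None if it would go straight to the expired list.
fn select(curtime: u64, expires: u64) -> Option<(usize, usize)> {
    if expires <= curtime {
        return None;
    }
    // relative timeout, capped at the largest value the wheels can represent
    let t = std::cmp::min(expires - curtime, TIMEOUT_MAX);
    // wheel is chosen by the position of the highest set bit of t ("find last set")
    let fls = 64 - t.leading_zeros();
    let wheel = ((fls - 1) as usize) / WHEEL_BITS;
    let trunc_bits = (wheel * WHEEL_BITS) as u64;
    // higher wheels schedule one slot early so the remainder can move to a lower wheel
    let offset = if wheel > 0 { 1 } else { 0 };
    // slot is chosen by absolute time
    let slot = (((expires >> trunc_bits) - offset) & WHEEL_MASK) as usize;
    Some((wheel, slot))
}

fn main() {
    // same expiry values as test_sched, with the clock at 7
    for &expires in &[0u64, 8, 63, 64, 72, 1 << 24] {
        println!("{} -> {:?}", expires, select(7, expires));
    }
}
```

With `curtime = 7` it reproduces the placements asserted in `test_sched`; for example `72` lands in wheel 1 slot 0 rather than slot 1 because of the one-slot-early offset.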
condure-1.3.1/src/tls.rs000064400000000000000000000326170000000000000132120ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use arrayvec::ArrayString; use log::debug; use mio::net::TcpStream; use openssl::error::ErrorStack; use openssl::ssl::{ HandshakeError, MidHandshakeSslStream, NameType, SniError, SslAcceptor, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslStream, }; use std::cmp; use std::collections::HashMap; use std::fmt; use std::fs; use std::io; use std::io::{Read, Write}; use std::mem; use std::path; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::{Arc, Mutex, MutexGuard}; use std::time::SystemTime; const DOMAIN_LEN_MAX: usize = 253; enum IdentityError { InvalidName, CertMetadata(PathBuf, io::Error), KeyMetadata(PathBuf, io::Error), SslContext(ErrorStack), CertContent(PathBuf, ErrorStack), KeyContent(PathBuf, ErrorStack), CertCheck(ErrorStack), } impl fmt::Display for IdentityError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidName => write!(f, "invalid name"), Self::CertMetadata(fname, e) => { write!(f, "failed to read cert file metadata {:?}: {}", fname, e) } Self::KeyMetadata(fname, e) => { write!(f, "failed to read key file metadata {:?}: {}", fname, e) } Self::SslContext(e) => write!(f, "failed to create SSL context: {}", e), Self::CertContent(fname, e) => { write!(f, "failed to read cert content {:?}: {}", fname, e) } 
Self::KeyContent(fname, e) => { write!(f, "failed to read key content {:?}: {}", fname, e) } Self::CertCheck(e) => write!(f, "failed to check private key: {}", e), } } } struct Identity { ssl_context: SslContext, cert_fname: PathBuf, key_fname: PathBuf, modified: Option<SystemTime>, } impl Identity { fn from_name(dir: &Path, name: &str) -> Result<Self, IdentityError> { // forbid long names if name.len() > DOMAIN_LEN_MAX { return Err(IdentityError::InvalidName); } // forbid control chars and '/', for safe filesystem usage for c in name.chars() { if (c as u32) < 0x20 || path::is_separator(c) { return Err(IdentityError::InvalidName); } } let cert_fname = dir.join(Path::new(&format!("{}.crt", name))); let cert_metadata = match fs::metadata(&cert_fname) { Ok(md) => md, Err(e) => return Err(IdentityError::CertMetadata(cert_fname, e)), }; let key_fname = dir.join(Path::new(&format!("{}.key", name))); let key_metadata = match fs::metadata(&key_fname) { Ok(md) => md, Err(e) => return Err(IdentityError::KeyMetadata(key_fname, e)), }; let cert_modified = cert_metadata.modified(); let key_modified = key_metadata.modified(); let modified = if cert_modified.is_ok() && key_modified.is_ok() { Some(cmp::max(cert_modified.unwrap(), key_modified.unwrap())) } else { None }; let mut ctx = match SslContextBuilder::new(SslMethod::tls()) { Ok(ctx) => ctx, Err(e) => return Err(IdentityError::SslContext(e)), }; if let Err(e) = ctx.set_certificate_chain_file(&cert_fname) { return Err(IdentityError::CertContent(cert_fname, e)); } if let Err(e) = ctx.set_private_key_file(&key_fname, SslFiletype::PEM) { return Err(IdentityError::KeyContent(key_fname, e)); } if let Err(e) = ctx.check_private_key() { return Err(IdentityError::CertCheck(e)); } Ok(Self { ssl_context: ctx.build(), cert_fname, key_fname, modified, }) } } fn modified_after(fnames: &[&Path], t: SystemTime) -> Result<bool, io::Error> { for fname in fnames { match fs::metadata(fname)?.modified() { Ok(modified) if modified > t => return Ok(true), _ => {} } } Ok(false) } struct
IdentityRef<'a> { _data: MutexGuard<'a, HashMap<String, Identity>>, name: &'a str, value: &'a Identity, } pub struct IdentityCache { dir: PathBuf, data: Mutex<HashMap<String, Identity>>, } impl IdentityCache { pub fn new(certs_dir: &Path) -> Self { Self { dir: certs_dir.to_path_buf(), data: Mutex::new(HashMap::new()), } } fn get_by_domain<'a>(&'a self, domain: &str) -> Option<IdentityRef<'a>> { let name = domain.to_lowercase(); // try to find a file named after the exact host, then try with a // wildcard pattern at the same subdomain level. the filename // format uses underscores instead of asterisks. so, a domain of // www.example.com will attempt to be matched against a file named // www.example.com.crt and _.example.com.crt. wildcards at other // levels are not supported if let Some(identity) = self.get_by_name(&name) { return Some(identity); } let pos = match name.find('.') { Some(pos) => pos, None => return None, }; let name = format!("_{}", &name[pos..]); if let Some(identity) = self.get_by_name(&name) { return Some(identity); } None } fn get_by_name<'a>(&'a self, name: &str) -> Option<IdentityRef<'a>> { self.ensure_updated(name); let data = self.data.lock().unwrap(); if let Some((name, value)) = data.get_key_value(name) { // extending the lifetimes is safe because we keep the owning MutexGuard let name = unsafe { mem::transmute::<&String, &'a String>(name) }; let value = unsafe { mem::transmute::<&Identity, &'a Identity>(value) }; Some(IdentityRef { _data: data, name: name.as_str(), value, }) } else { None } } fn ensure_updated(&self, name: &str) { let mut data = self.data.lock().unwrap(); let mut update = false; if let Some(value) = data.get(name) { if let Some(modified) = value.modified { update = match modified_after(&[&value.cert_fname, &value.key_fname], modified) { Ok(b) => b, Err(_) => true, }; } } else { update = true; } if update { let identity = match Identity::from_name(&self.dir, name) { Ok(identity) => identity, Err(e) => { debug!("failed to load cert {}: {}", name, e); return; } }; data.insert(String::from(name),
identity); debug!("loaded cert: {}", name); } } } enum Stream<'a> { Ssl(SslStream<&'a mut TcpStream>), MidHandshakeSsl(MidHandshakeSslStream<&'a mut TcpStream>), NoSsl, } pub struct TlsAcceptor { acceptor: SslAcceptor, } impl TlsAcceptor { pub fn new(cache: &Arc<IdentityCache>, default_cert: Option<&str>) -> Self { let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let cache = Arc::clone(cache); let default_cert: Option<String> = default_cert.map(|s| s.to_owned()); acceptor.set_servername_callback(move |ssl, _| { let identity = match ssl.servername(NameType::HOST_NAME) { Some(name) => { debug!("tls server name: {}", name); match cache.get_by_domain(name) { Some(ctx) => ctx, None => match &default_cert { Some(default_cert) => match cache.get_by_name(default_cert) { Some(ctx) => ctx, None => return Err(SniError::ALERT_FATAL), }, None => return Err(SniError::ALERT_FATAL), }, } } None => match &default_cert { Some(default_cert) => match cache.get_by_name(default_cert) { Some(ctx) => ctx, None => return Err(SniError::ALERT_FATAL), }, None => return Err(SniError::ALERT_FATAL), }, }; debug!("using cert: {}", identity.name); if ssl.set_ssl_context(&identity.value.ssl_context).is_err() { return Err(SniError::ALERT_FATAL); } Ok(()) }); Self { acceptor: acceptor.build(), } } pub fn accept(&self, stream: TcpStream) -> Result<TlsStream, ErrorStack> { let mut tcp_stream_boxed = Box::new(stream); let tcp_stream: &mut TcpStream = &mut tcp_stream_boxed; // safety: TlsStream will take ownership of tcp_stream, and the value // referred to by tcp_stream is on the heap, and tcp_stream will not // be dropped until TlsStream is dropped, so the value referred to // by tcp_stream will remain valid for the lifetime of TlsStream.
// further, tcp_stream is a mutable reference, and will only ever // be exclusively mutably accessed, either when wrapped by SslStream // or MidHandshakeSslStream, or when known to be not wrapped let tcp_stream: &'static mut TcpStream = unsafe { mem::transmute(tcp_stream) }; let stream = match self.acceptor.accept(tcp_stream) { Ok(stream) => Stream::Ssl(stream), Err(HandshakeError::SetupFailure(e)) => return Err(e), Err(HandshakeError::Failure(stream)) => Stream::MidHandshakeSsl(stream), Err(HandshakeError::WouldBlock(stream)) => Stream::MidHandshakeSsl(stream), }; Ok(TlsStream { stream, tcp_stream: tcp_stream_boxed, id: ArrayString::new(), }) } } pub struct TlsStream { stream: Stream<'static>, tcp_stream: Box<TcpStream>, id: ArrayString<[u8; 32]>, } impl TlsStream { pub fn get_tcp(&mut self) -> &mut TcpStream { match &mut self.stream { Stream::Ssl(stream) => stream.get_mut(), Stream::MidHandshakeSsl(stream) => stream.get_mut(), Stream::NoSsl => &mut self.tcp_stream, } } pub fn set_id(&mut self, id: &str) { self.id = ArrayString::from_str(id).unwrap(); } pub fn shutdown(&mut self) -> Result<(), io::Error> { match &mut self.stream { Stream::Ssl(stream) => match stream.shutdown() { Ok(_) => { debug!("conn {}: tls shutdown sent", self.id); Ok(()) } Err(e) => Err(match e.into_io_error() { Ok(e) => e, Err(_) => io::Error::from(io::ErrorKind::Other), }), }, _ => Err(io::Error::from(io::ErrorKind::Other)), } } fn ensure_handshake(&mut self) -> Result<(), io::Error> { match &self.stream { Stream::Ssl(_) => Ok(()), Stream::MidHandshakeSsl(_) => match mem::replace(&mut self.stream, Stream::NoSsl) { Stream::MidHandshakeSsl(stream) => match stream.handshake() { Ok(stream) => { debug!("conn {}: tls handshake success", self.id); self.stream = Stream::Ssl(stream); Ok(()) } Err(HandshakeError::SetupFailure(_)) => { Err(io::Error::from(io::ErrorKind::Other)) } Err(HandshakeError::Failure(_)) => Err(io::Error::from(io::ErrorKind::Other)), Err(HandshakeError::WouldBlock(stream)) => {
self.stream = Stream::MidHandshakeSsl(stream); Err(io::Error::from(io::ErrorKind::WouldBlock)) } }, _ => unreachable!(), }, Stream::NoSsl => Err(io::Error::from(io::ErrorKind::Other)), } } } impl Read for TlsStream { fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> { self.ensure_handshake()?; match &mut self.stream { Stream::Ssl(stream) => SslStream::read(stream, buf), _ => unreachable!(), } } } impl Write for TlsStream { fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> { self.ensure_handshake()?; match &mut self.stream { Stream::Ssl(stream) => SslStream::write(stream, buf), _ => unreachable!(), } } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } condure-1.3.1/src/tnetstring.rs000064400000000000000000000701650000000000000146110ustar 00000000000000/* * Copyright (C) 2020 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
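The framing handled by this file is tnetstring: every value is written as `<payload length>:<payload><type marker>`, where the marker is `~` (null), `!` (bool), `#` (int), `^` (float), `,` (string), `]` (array) or `}` (map), and arrays and maps carry their elements as nested frames. As a minimal sketch of that layout, assuming a hypothetical `Value` tree and `encode` helper that are not part of condure (floats are omitted for brevity), an encoder could look like this:

```rust
/// Hypothetical value tree for this sketch; condure's `Writer` records a
/// flat list of ops instead of building a tree.
enum Value<'a> {
    Null,
    Bool(bool),
    Int(isize),
    Str(&'a [u8]),
    Array(Vec<Value<'a>>),
    Map(Vec<(&'a [u8], Value<'a>)>),
}

/// Type marker byte that terminates each frame.
fn type_byte(v: &Value) -> u8 {
    match v {
        Value::Null => b'~',
        Value::Bool(_) => b'!',
        Value::Int(_) => b'#',
        Value::Str(_) => b',',
        Value::Array(_) => b']',
        Value::Map(_) => b'}',
    }
}

/// Appends the payload of `v`: the bytes between ':' and the type marker.
fn write_payload(v: &Value, out: &mut Vec<u8>) {
    match v {
        Value::Null => {}
        Value::Bool(b) => out.extend_from_slice(if *b { &b"true"[..] } else { &b"false"[..] }),
        Value::Int(x) => out.extend_from_slice(x.to_string().as_bytes()),
        Value::Str(s) => out.extend_from_slice(s),
        Value::Array(items) => {
            // Each element is a complete, self-delimiting frame.
            for item in items {
                encode(item, out);
            }
        }
        Value::Map(entries) => {
            // Keys are string frames, each followed by its value frame.
            for (key, val) in entries {
                encode(&Value::Str(*key), out);
                encode(val, out);
            }
        }
    }
}

/// Appends `v` to `out` as "<payload length>:<payload><type marker>".
fn encode(v: &Value, out: &mut Vec<u8>) {
    // Buffer the payload first so its length is known before the prefix.
    let mut payload = Vec::new();
    write_payload(v, &mut payload);
    out.extend_from_slice(payload.len().to_string().as_bytes());
    out.push(b':');
    out.extend_from_slice(&payload);
    out.push(type_byte(v));
}

/// Convenience wrapper returning a freshly allocated encoding.
fn encode_to_vec(v: &Value) -> Vec<u8> {
    let mut out = Vec::new();
    encode(v, &mut out);
    out
}
```

For example, encoding an array of the strings `foo` and `bar` yields `12:3:foo,3:bar,]`, the same bytes `test_array` expects. The sketch buffers each nested payload in a temporary `Vec` so its length is known before the prefix is written; condure's `Writer` instead records ops in a fixed-size array and precomputes every nested length with `calc_len` (each child contributes `usize_len(len) + len + 2` bytes), which lets it serialize in a single pass into a caller-supplied cursor without heap allocation.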
*/ use std::ascii; use std::fmt; use std::io; use std::io::Write; use std::str; const F64_SIZE_MAX: usize = 64; const OPS_MAX: usize = 1_000; const TRUE_BYTES: &[u8] = b"true"; const FALSE_BYTES: &[u8] = b"false"; fn usize_len(value: usize) -> usize { let mut x = value; let mut len = 1; while x >= 10 { x /= 10; len += 1; } len } fn isize_len(value: isize) -> usize { let mut x = value.unsigned_abs(); let mut len: usize = 1; if value < 0 { len += 1; } while x >= 10 { x /= 10; len += 1; } len } fn f64_len(value: f64) -> usize { let mut buf = [0; F64_SIZE_MAX]; let mut cursor = io::Cursor::new(&mut buf[..]); write!(&mut cursor, "{}", value).unwrap(); cursor.position() as usize } fn bool_bytes(value: bool) -> &'static [u8] { if value { TRUE_BYTES } else { FALSE_BYTES } } fn bool_len(value: bool) -> usize { bool_bytes(value).len() } fn write_exact(w: &mut io::Cursor<&mut [u8]>, data: &[u8]) -> Result<(), io::Error> { let size = w.write(data)?; if size < data.len() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } Ok(()) } #[derive(Copy, Clone, Debug, PartialEq)] pub enum FrameType { Null, Bool, Int, Float, String, Array, Map, } impl fmt::Display for FrameType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s: &str = (*self).into(); write!(f, "{}", s) } } impl From<FrameType> for &str { fn from(ftype: FrameType) -> &'static str { match ftype { FrameType::Null => "null", FrameType::Bool => "bool", FrameType::Int => "int", FrameType::Float => "float", FrameType::String => "string", FrameType::Array => "array", FrameType::Map => "map", } } } #[derive(Debug, PartialEq)] pub enum ParseError { UnexpectedEof, InvalidData, WrongType(FrameType, FrameType), // got, expected InvalidKey, } impl fmt::Display for ParseError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::UnexpectedEof => write!(f, "unexpected eof"), Self::InvalidData => write!(f, "invalid data"), Self::WrongType(got, expected) => { write!(f, "wrong type {}, expected {}", got,
expected) } Self::InvalidKey => write!(f, "map key must be a utf-8 string"), } } } #[derive(Copy, Clone)] enum Op<'a> { Invalid, #[allow(dead_code)] Null, Bool(bool), Int(isize), #[allow(dead_code)] Float(f64), String(&'a [u8]), StartArray, EndArray, StartMap, EndMap, } impl Op<'_> { fn serialize(&self, w: &mut io::Cursor<&mut [u8]>, len: usize) -> Result<(), io::Error> { match self { Op::Invalid => unreachable!(), Op::Null => { write!(w, "{}:~", len)?; } Op::Bool(b) => { write!(w, "{}:", len)?; write_exact(w, bool_bytes(*b))?; write_exact(w, b"!")?; } Op::Int(x) => { write!(w, "{}:{}#", len, *x)?; } Op::Float(x) => { write!(w, "{}:{}^", len, *x)?; } Op::String(s) => { write!(w, "{}:", len)?; write_exact(w, s)?; write_exact(w, b",")?; } Op::StartArray | Op::StartMap => { write!(w, "{}:", len)?; } Op::EndArray => { write_exact(w, b"]")?; } Op::EndMap => { write_exact(w, b"}")?; } } Ok(()) } } // calculate the length of the first op and any dependants // return the number of ops processed fn calc_len(ops: &[Op], lens: &mut [usize]) -> usize { assert!(ops.len() > 0); assert_eq!(ops.len(), lens.len()); let (len, count) = match ops[0] { Op::Invalid => unreachable!(), Op::Null => (0, 1), Op::Bool(b) => (bool_len(b), 1), Op::Int(x) => (isize_len(x), 1), Op::Float(x) => (f64_len(x), 1), Op::String(s) => (s.len(), 1), Op::StartArray => { let mut total = 0; let mut end = None; let mut i = 1; while i < ops.len() { match ops[i] { Op::EndArray => { end = Some(i + 1); break; } Op::EndMap => panic!("unexpected EndMap"), _ => { let count = calc_len(&ops[i..], &mut lens[i..]); let len = lens[i]; total += usize_len(len) + len + 2; i += count; } } } if end.is_none() { panic!("expected EndArray"); } (total, end.unwrap()) } Op::EndArray => panic!("EndArray without StartArray"), Op::StartMap => { let mut total = 0; let mut end = None; let mut i = 1; while i < ops.len() { match ops[i] { Op::EndMap => { end = Some(i + 1); break; } Op::EndArray => panic!("unexpected EndArray"), _ => { let 
count = calc_len(&ops[i..], &mut lens[i..]); let len = lens[i]; total += usize_len(len) + len + 2; i += count; } } } if end.is_none() { panic!("expected EndMap"); } (total, end.unwrap()) } Op::EndMap => panic!("EndMap without StartMap"), }; lens[0] = len; count } pub struct Writer<'a, 'b> { ops: [Op<'a>; OPS_MAX], len: usize, dest: &'a mut io::Cursor<&'b mut [u8]>, } impl<'a, 'b> Writer<'a, 'b> { pub fn new(dest: &'a mut io::Cursor<&'b mut [u8]>) -> Self { Self { ops: [Op::Invalid; OPS_MAX], len: 0, dest: dest, } } fn append(&mut self, op: Op<'a>) -> Result<(), io::Error> { if self.len >= self.ops.len() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } self.ops[self.len] = op; self.len += 1; Ok(()) } #[cfg(test)] pub fn write_null(&mut self) -> Result<(), io::Error> { self.append(Op::Null) } pub fn write_bool(&mut self, b: bool) -> Result<(), io::Error> { self.append(Op::Bool(b)) } pub fn write_int(&mut self, x: isize) -> Result<(), io::Error> { self.append(Op::Int(x)) } #[cfg(test)] pub fn write_float(&mut self, x: f64) -> Result<(), io::Error> { self.append(Op::Float(x)) } pub fn write_string(&mut self, s: &'a [u8]) -> Result<(), io::Error> { self.append(Op::String(s)) } pub fn start_array(&mut self) -> Result<(), io::Error> { self.append(Op::StartArray) } pub fn end_array(&mut self) -> Result<(), io::Error> { self.append(Op::EndArray) } pub fn start_map(&mut self) -> Result<(), io::Error> { self.append(Op::StartMap) } pub fn end_map(&mut self) -> Result<(), io::Error> { self.append(Op::EndMap) } pub fn flush(&mut self) -> Result<(), io::Error> { let mut lens = [0; OPS_MAX]; let mut i = 0; while i < self.len { let count = calc_len(&self.ops[i..self.len], &mut lens[i..self.len]); assert!(i + count <= self.len); for _ in 0..count { self.ops[i].serialize(self.dest, lens[i])?; self.ops[i] = Op::Invalid; i += 1; } } self.len = 0; Ok(()) } } pub struct Frame<'a> { pub ftype: FrameType, pub data: &'a [u8], } pub fn parse_frame(src: &[u8]) -> Result<(Frame, 
usize), ParseError> { let mut size_end: Option<usize> = None; // find ':' for (i, &c) in src.iter().enumerate() { if c == b':' { size_end = Some(i); break; } else if !(c as char).is_digit(10) { return Err(ParseError::InvalidData); } } if size_end.is_none() { return Err(ParseError::UnexpectedEof); } let size_end = size_end.unwrap(); let size = match str::from_utf8(&src[..size_end]) { Ok(size) => size, Err(_) => { return Err(ParseError::InvalidData); } }; let size: usize = match size.parse() { Ok(size) => size, Err(_) => { return Err(ParseError::InvalidData); } }; if size >= src.len() || size_end + size + 2 > src.len() { return Err(ParseError::UnexpectedEof); } let type_byte = src[size_end + 1 + size]; let frame_type = match type_byte { b'~' => FrameType::Null, b'!' => FrameType::Bool, b'#' => FrameType::Int, b'^' => FrameType::Float, b',' => FrameType::String, b']' => FrameType::Array, b'}' => FrameType::Map, _ => { return Err(ParseError::InvalidData); } }; let end = size_end + size + 2; Ok(( Frame { ftype: frame_type, data: &src[(size_end + 1)..(size_end + 1 + size)], }, end, )) } #[cfg(test)] pub fn parse_null(src: &[u8]) -> Result<(), ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::Null => Ok(()), _ => Err(ParseError::WrongType(frame.ftype, FrameType::Null)), } } pub fn parse_bool(src: &[u8]) -> Result<bool, ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::Bool => {} _ => { return Err(ParseError::WrongType(frame.ftype, FrameType::Bool)); } } match frame.data { TRUE_BYTES => Ok(true), FALSE_BYTES => Ok(false), _ => Err(ParseError::InvalidData), } } pub fn parse_int(src: &[u8]) -> Result<isize, ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::Int => {} _ => { return Err(ParseError::WrongType(frame.ftype, FrameType::Int)); } } let x = match str::from_utf8(frame.data) { Ok(x) => x, Err(_) => { return Err(ParseError::InvalidData); } }; match x.parse() { Ok(x) => Ok(x), Err(_) => Err(ParseError::InvalidData), } } #[cfg(test)] pub fn
parse_float(src: &[u8]) -> Result<f64, ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::Float => {} _ => { return Err(ParseError::WrongType(frame.ftype, FrameType::Float)); } } let x = match str::from_utf8(frame.data) { Ok(x) => x, Err(_) => { return Err(ParseError::InvalidData); } }; match x.parse() { Ok(x) => Ok(x), Err(_) => Err(ParseError::InvalidData), } } pub fn parse_string(src: &[u8]) -> Result<&[u8], ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::String => Ok(frame.data), _ => Err(ParseError::WrongType(frame.ftype, FrameType::String)), } } #[derive(Debug)] pub struct SequenceItem<'a> { pub ftype: FrameType, pub data: &'a [u8], } #[derive(Debug)] pub struct SequenceIterator<'a> { src: &'a [u8], pos: usize, } impl<'a> SequenceIterator<'a> { pub fn new(src: &'a [u8]) -> Self { Self { src, pos: 0 } } } impl<'a> Iterator for SequenceIterator<'a> { type Item = Result<SequenceItem<'a>, ParseError>; fn next(&mut self) -> Option<Self::Item> { if self.pos >= self.src.len() { return None; } match parse_frame(&self.src[self.pos..]) { Ok((frame, end)) => { let next_pos = self.pos + end; let idata = &self.src[self.pos..next_pos]; self.pos = next_pos; Some(Ok(SequenceItem { ftype: frame.ftype, data: idata, })) } Err(e) => { // make this the last iteration self.pos = self.src.len(); Some(Err(e)) } } } } pub fn parse_array(src: &[u8]) -> Result<SequenceIterator, ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::Array => Ok(SequenceIterator::new(frame.data)), _ => Err(ParseError::WrongType(frame.ftype, FrameType::Array)), } } #[derive(Debug)] pub struct MapItem<'a> { pub key: &'a str, pub ftype: FrameType, pub data: &'a [u8], } #[derive(Copy, Clone, Debug)] pub struct MapIterator<'a> { src: &'a [u8], pos: usize, } impl<'a> MapIterator<'a> { pub fn new(src: &'a [u8]) -> Self { Self { src, pos: 0 } } } impl<'a> Iterator for MapIterator<'a> { type Item = Result<MapItem<'a>, ParseError>; fn next(&mut self) -> Option<Self::Item> { if self.pos >= self.src.len() { return None; }
let (kframe, kend) = match parse_frame(&self.src[self.pos..]) { Ok(frame) => frame, Err(e) => { // make this the last iteration self.pos = self.src.len(); return Some(Err(e)); } }; let kdata = match kframe.ftype { FrameType::String => kframe.data, _ => { // make this the last iteration self.pos = self.src.len(); return Some(Err(ParseError::InvalidKey)); } }; let kstr = match str::from_utf8(kdata) { Ok(s) => s, Err(_) => { // make this the last iteration self.pos = self.src.len(); return Some(Err(ParseError::InvalidKey)); } }; let vpos = self.pos + kend; let (vframe, vend) = match parse_frame(&self.src[vpos..]) { Ok(frame) => frame, Err(e) => { // make this the last iteration self.pos = self.src.len(); return Some(Err(e)); } }; let next_pos = vpos + vend; let vdata = &self.src[vpos..next_pos]; self.pos = next_pos; Some(Ok(MapItem { key: kstr, ftype: vframe.ftype, data: vdata, })) } } pub fn parse_map(src: &[u8]) -> Result<MapIterator, ParseError> { let (frame, _) = parse_frame(src)?; match frame.ftype { FrameType::Map => Ok(MapIterator::new(frame.data)), _ => Err(ParseError::WrongType(frame.ftype, FrameType::Map)), } } impl fmt::Display for Frame<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.ftype { FrameType::Map => { let it = MapIterator::new(self.data); write!(f, "{{ ")?; for (i, mi) in it.enumerate() { let mi = match mi { Ok(mi) => mi, Err(_) => return Ok(()), }; // can't fail let (frame, _) = parse_frame(mi.data).unwrap(); if i > 0 { write!(f, ", ")?; } write!(f, "\"{}\": {}", mi.key, frame)?; } write!(f, " }}") } FrameType::Array => { let it = SequenceIterator::new(self.data); write!(f, "[ ")?; for (i, si) in it.enumerate() { let si = match si { Ok(si) => si, Err(_) => return Ok(()), }; // can't fail let (frame, _) = parse_frame(si.data).unwrap(); if i > 0 { write!(f, ", ")?; } write!(f, "{}", frame)?; } write!(f, " ]") } FrameType::Null => write!(f, "null"), FrameType::Bool => match self.data { TRUE_BYTES => write!(f, "true"), FALSE_BYTES => write!(f,
"false"), _ => write!(f, ""), }, FrameType::Int => { let x = match str::from_utf8(self.data) { Ok(x) => x, Err(_) => return write!(f, ""), }; let x: isize = match x.parse() { Ok(x) => x, Err(_) => return write!(f, ""), }; write!(f, "{}", x) } FrameType::Float => { let x = match str::from_utf8(self.data) { Ok(x) => x, Err(_) => return write!(f, ""), }; let x: f64 = match x.parse() { Ok(x) => x, Err(_) => return write!(f, ""), }; write!(f, "{}", x) } FrameType::String => { write!(f, "\"")?; for b in self.data { write!(f, "{}", ascii::escape_default(*b))?; } write!(f, "\"") } } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_length() { assert_eq!(usize_len(0), 1); assert_eq!(usize_len(1), 1); assert_eq!(usize_len(9), 1); assert_eq!(usize_len(10), 2); assert_eq!(usize_len(99), 2); assert_eq!(usize_len(100), 3); assert_eq!(isize_len(0), 1); assert_eq!(isize_len(1), 1); assert_eq!(isize_len(9), 1); assert_eq!(isize_len(10), 2); assert_eq!(isize_len(99), 2); assert_eq!(isize_len(100), 3); assert_eq!(isize_len(-1), 2); assert_eq!(isize_len(-9), 2); assert_eq!(isize_len(-10), 3); } #[test] fn test_null() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_null().unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"0:~"); let e = parse_null(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_null(b"5:hello,").unwrap_err(); assert_eq!(e, ParseError::WrongType(FrameType::String, FrameType::Null)); let r = parse_null(b"0:~").unwrap(); assert_eq!(r, ()); } #[test] fn test_bool() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_bool(true).unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"4:true!"); let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_bool(false).unwrap(); 
w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"5:false!"); let e = parse_bool(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_bool(b"5:hello,").unwrap_err(); assert_eq!(e, ParseError::WrongType(FrameType::String, FrameType::Bool)); let e = parse_bool(b"5:bogus!").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let b = parse_bool(b"4:true!").unwrap(); assert_eq!(b, true); let b = parse_bool(b"5:false!").unwrap(); assert_eq!(b, false); } #[test] fn test_int() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_int(0).unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"1:0#"); let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_int(42).unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"2:42#"); let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_int(-42).unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"3:-42#"); let e = parse_int(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_int(b"5:hello,").unwrap_err(); assert_eq!(e, ParseError::WrongType(FrameType::String, FrameType::Int)); let e = parse_int(b"5:bogus#").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let x = parse_int(b"1:0#").unwrap(); assert_eq!(x, 0); let x = parse_int(b"2:42#").unwrap(); assert_eq!(x, 42); let x = parse_int(b"3:-42#").unwrap(); assert_eq!(x, -42); } #[test] fn test_float() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_float(0.0).unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"1:0^"); let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); 
w.write_float(-0.5).unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"4:-0.5^"); let e = parse_float(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_float(b"5:hello,").unwrap_err(); assert_eq!( e, ParseError::WrongType(FrameType::String, FrameType::Float) ); let e = parse_float(b"5:bogus^").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let x = parse_float(b"1:0^").unwrap(); assert_eq!(x, 0.0); let x = parse_float(b"4:-0.5^").unwrap(); assert_eq!(x, -0.5); } #[test] fn test_string() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_string(b"").unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"0:,"); let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_string(b"hello").unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"5:hello,"); let e = parse_string(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_string(b"1:0#").unwrap_err(); assert_eq!(e, ParseError::WrongType(FrameType::Int, FrameType::String)); let s = parse_string(b"5:hello,").unwrap(); assert_eq!(s, b"hello"); } #[test] fn test_array() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.start_array().unwrap(); w.write_string(b"foo").unwrap(); w.write_string(b"bar").unwrap(); w.end_array().unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"12:3:foo,3:bar,]"); let e = parse_array(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_array(b"5:hello,").unwrap_err(); assert_eq!( e, ParseError::WrongType(FrameType::String, FrameType::Array) ); let mut it = parse_array(b"5:inner]").unwrap(); let e = it.next().unwrap().unwrap_err(); assert_eq!(e, ParseError::InvalidData); let mut it = 
parse_array(b"8:5:hello,]").unwrap(); let si = it.next().unwrap().unwrap(); assert_eq!(si.ftype, FrameType::String); assert_eq!(si.data, b"5:hello,"); assert!(it.next().is_none()); } #[test] fn test_map() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.start_map().unwrap(); w.write_string(b"foo").unwrap(); w.write_string(b"bar").unwrap(); w.end_map().unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"12:3:foo,3:bar,}"); let e = parse_map(b"bogus").unwrap_err(); assert_eq!(e, ParseError::InvalidData); let e = parse_map(b"5:hello,").unwrap_err(); assert_eq!(e, ParseError::WrongType(FrameType::String, FrameType::Map)); let mut it = parse_map(b"5:inner}").unwrap(); let e = it.next().unwrap().unwrap_err(); assert_eq!(e, ParseError::InvalidData); let mut it = parse_map(b"4:1:0#}").unwrap(); let e = it.next().unwrap().unwrap_err(); assert_eq!(e, ParseError::InvalidKey); let mut it = parse_map(b"4:1:\x80,}").unwrap(); let e = it.next().unwrap().unwrap_err(); assert_eq!(e, ParseError::InvalidKey); let mut it = parse_map(b"7:4:name,}").unwrap(); let e = it.next().unwrap().unwrap_err(); assert_eq!(e, ParseError::UnexpectedEof); let mut it = parse_map(b"8:4:name,X}").unwrap(); let e = it.next().unwrap().unwrap_err(); assert_eq!(e, ParseError::InvalidData); let mut it = parse_map(b"15:4:name,5:alice,}").unwrap(); let mi = it.next().unwrap().unwrap(); assert_eq!(mi.key, "name"); assert_eq!(mi.ftype, FrameType::String); assert_eq!(mi.data, b"5:alice,"); assert!(it.next().is_none()); } #[test] fn test_sequence() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); w.write_string(b"apple").unwrap(); w.write_string(b"banana").unwrap(); w.flush().unwrap(); let end = cursor.position() as usize; assert_eq!(&buf[..end], b"5:apple,6:banana,"); let mut buf = [0; 10]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w 
= Writer::new(&mut cursor); w.write_string(b"apple").unwrap(); w.write_string(b"banana").unwrap(); // won't fit let e = w.flush().unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::WriteZero); let mut it = SequenceIterator::new(b"5:apple,6:banana,"); let s = parse_string(it.next().unwrap().unwrap().data).unwrap(); assert_eq!(s, b"apple"); let s = parse_string(it.next().unwrap().unwrap().data).unwrap(); assert_eq!(s, b"banana"); assert!(it.next().is_none()); } #[test] fn test_overflow() { let mut buf = [0; 256]; let mut cursor = io::Cursor::new(&mut buf[..]); let mut w = Writer::new(&mut cursor); for _ in 0..OPS_MAX { w.write_string(b"foo").unwrap(); } // won't fit let e = w.write_string(b"foo").unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::WriteZero); } } condure-1.3.1/src/waker.rs000064400000000000000000000064200000000000000135120ustar 00000000000000/* * Copyright (C) 2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use std::mem; use std::rc::Rc; use std::task::{RawWaker, RawWakerVTable, Waker}; // adapted from std::task::Wake pub trait RcWake { fn wake(self: Rc<Self>); fn wake_by_ref(self: &Rc<Self>) { self.clone().wake(); } } pub fn into_std<W: RcWake>(waker: Rc<W>) -> Waker { // SAFETY: This is safe because raw_waker safely constructs // a RawWaker from Rc<W>.
unsafe { Waker::from_raw(raw_waker(waker)) } } #[inline(always)] fn raw_waker<W: RcWake>(waker: Rc<W>) -> RawWaker { unsafe fn clone_waker<W: RcWake>(waker: *const ()) -> RawWaker { let waker = mem::ManuallyDrop::new(Rc::from_raw(waker as *const W)); let waker = Rc::clone(&waker); RawWaker::new( Rc::into_raw(waker) as *const (), &RawWakerVTable::new( clone_waker::<W>, wake::<W>, wake_by_ref::<W>, drop_waker::<W>, ), ) } unsafe fn wake<W: RcWake>(waker: *const ()) { let waker = Rc::from_raw(waker as *const W); <W as RcWake>::wake(waker); } unsafe fn wake_by_ref<W: RcWake>(waker: *const ()) { let waker = mem::ManuallyDrop::new(Rc::from_raw(waker as *const W)); <W as RcWake>::wake_by_ref(&waker); } unsafe fn drop_waker<W: RcWake>(waker: *const ()) { mem::drop(Rc::from_raw(waker as *const W)); } RawWaker::new( Rc::into_raw(waker) as *const (), &RawWakerVTable::new( clone_waker::<W>, wake::<W>, wake_by_ref::<W>, drop_waker::<W>, ), ) } #[cfg(test)] mod tests { use super::*; use std::cell::Cell; struct TestWaker { waked: Cell<u32>, } impl TestWaker { fn new() -> Self { TestWaker { waked: Cell::new(0), } } fn waked(&self) -> u32 { self.waked.get() } } impl RcWake for TestWaker { fn wake(self: Rc<Self>) { self.waked.set(self.waked.get() + 1); } } #[test] fn test_waker() { let data = Rc::new(TestWaker::new()); assert_eq!(Rc::strong_count(&data), 1); let waker = into_std(data.clone()); assert_eq!(Rc::strong_count(&data), 2); let waker2 = waker.clone(); assert_eq!(Rc::strong_count(&data), 3); assert_eq!(data.waked(), 0); waker2.wake(); assert_eq!(Rc::strong_count(&data), 2); assert_eq!(data.waked(), 1); waker.wake_by_ref(); assert_eq!(Rc::strong_count(&data), 2); assert_eq!(data.waked(), 2); mem::drop(waker); assert_eq!(Rc::strong_count(&data), 1); } } condure-1.3.1/src/websocket.rs000064400000000000000000000516250000000000000143760ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::buffer::{write_vectored_offset, RefRead, VECTORED_MAX}; use std::cmp; use std::io; use std::io::Write; pub const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // 1 byte flags + 9 bytes payload size + 4 bytes mask pub const HEADER_SIZE_MAX: usize = 14; const PSIZE_3BYTE: usize = 126; const PSIZE_9BYTE: usize = 65536; pub const OPCODE_CONTINUATION: u8 = 0; pub const OPCODE_TEXT: u8 = 1; pub const OPCODE_BINARY: u8 = 2; pub const OPCODE_CLOSE: u8 = 8; pub const OPCODE_PING: u8 = 9; pub const OPCODE_PONG: u8 = 10; pub const CONTROL_FRAME_PAYLOAD_MAX: usize = 125; #[derive(Clone, Copy)] pub struct FrameInfo { pub fin: bool, pub opcode: u8, pub mask: Option<[u8; 4]>, pub payload_offset: usize, pub payload_size: usize, } fn header_size(payload_size: usize, masked: bool) -> usize { let size = if payload_size < PSIZE_3BYTE { 1 + 1 } else if payload_size < PSIZE_9BYTE { 1 + 3 } else { 1 + 9 }; if masked { size + 4 } else { size } } pub fn read_header(buf: &[u8]) -> Result<FrameInfo, io::Error> { if buf.len() < 2 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let b1 = buf[1] & 0x7f; let (mut hsize, psize) = if b1 < (PSIZE_3BYTE as u8) { (2, b1 as usize) } else if b1 == (PSIZE_3BYTE as u8) { if buf.len() < 2 + 2 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let mut arr = [0; 2]; arr.copy_from_slice(&buf[2..4]); (4, u16::from_be_bytes(arr) as usize) } else { if buf.len() < 2 + 8 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let mut arr = [0; 8]; arr.copy_from_slice(&buf[2..10]); (10, u64::from_be_bytes(arr) as usize) }; let
mask = if buf[1] & 0x80 != 0 { if buf.len() < hsize + 4 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let mut mask = [0; 4]; mask.copy_from_slice(&buf[hsize..hsize + 4]); hsize += 4; Some(mask) } else { None }; Ok(FrameInfo { fin: buf[0] & 0x80 != 0, opcode: buf[0] & 0x0f, mask, payload_offset: hsize, payload_size: psize, }) } // return payload offset pub fn write_header( fin: bool, opcode: u8, payload_size: usize, mask: Option<[u8; 4]>, buf: &mut [u8], ) -> Result<usize, io::Error> { let hsize = header_size(payload_size, mask.is_some()); if buf.len() < hsize { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let mut b0 = 0; if fin { b0 |= 0x80; } b0 |= opcode & 0x0f; buf[0] = b0; let hsize = if payload_size < PSIZE_3BYTE { buf[1] = payload_size as u8; 2 } else if payload_size < PSIZE_9BYTE { buf[1] = PSIZE_3BYTE as u8; let arr = (payload_size as u16).to_be_bytes(); buf[2..4].copy_from_slice(&arr); 4 } else { buf[1] = 0x7f; let arr = (payload_size as u64).to_be_bytes(); buf[2..10].copy_from_slice(&arr); 10 }; if let Some(mask) = mask { buf[1] |= 0x80; buf[hsize..hsize + 4].copy_from_slice(&mask); Ok(hsize + 4) } else { Ok(hsize) } } fn apply_mask(buf: &mut [u8], mask: [u8; 4], offset: usize) { for (i, c) in buf.iter_mut().enumerate() { *c ^= mask[(offset + i) % 4]; } } #[cfg(test)] pub struct Frame<'a> { pub opcode: u8, pub data: &'a [u8], pub fin: bool, } #[derive(Debug, PartialEq, Clone, Copy)] pub enum State { // call: send_frame, recv_frame // next: Connected, PeerClosed, Closing Connected, // call: send_frame // next: PeerClosed, Finished PeerClosed, // call: recv_frame // next: Closing, Finished Closing, // session has completed Finished, } #[derive(Debug)] pub enum Error { Io(io::Error), InvalidControlFrame, UnexpectedOpcode, } impl From<io::Error> for Error { fn from(e: io::Error) -> Self { Self::Io(e) } } struct SendingFrame { header: [u8; HEADER_SIZE_MAX], header_len: usize, sent: usize, } struct SendingMessage { opcode: u8, mask: Option<[u8; 4]>,
frame_sent: bool, } struct ReceivingMessage { opcode: u8, frame_payload_read: usize, } pub struct Protocol { state: State, sending_frame: Option<SendingFrame>, receiving_frame: Option<FrameInfo>, sending_message: Option<SendingMessage>, receiving_message: Option<ReceivingMessage>, } impl<'buf> Protocol { pub fn new() -> Self { Self { state: State::Connected, sending_frame: None, receiving_frame: None, sending_message: None, receiving_message: None, } } pub fn state(&self) -> State { self.state } pub fn send_frame( &mut self, writer: &mut dyn Write, opcode: u8, src: &[&[u8]], fin: bool, mask: Option<[u8; 4]>, ) -> Result<usize, Error> { assert!(self.state == State::Connected || self.state == State::PeerClosed); let mut src_len = 0; for buf in src.iter() { src_len += buf.len(); } if self.sending_frame.is_none() { let mut h = [0; HEADER_SIZE_MAX]; let size = write_header(fin, opcode, src_len, mask, &mut h[..])?; self.sending_frame = Some(SendingFrame { header: h, header_len: size, sent: 0, }); } let sending_frame = self.sending_frame.as_mut().unwrap(); let header = &sending_frame.header[..sending_frame.header_len]; let total = header.len() + src_len; let mut out_arr = [&b""[..]; VECTORED_MAX]; let mut out_arr_len = 0; out_arr[0] = header; out_arr_len += 1; for buf in src.iter() { out_arr[out_arr_len] = buf; out_arr_len += 1; } let size = write_vectored_offset(writer, &out_arr[..out_arr_len], sending_frame.sent)?; sending_frame.sent += size; if sending_frame.sent < total { return Ok(0); } self.sending_frame = None; if opcode == OPCODE_CLOSE { if self.state == State::PeerClosed { self.state = State::Finished; } else { self.state = State::Closing; } } Ok(src_len) } // on success, it's up to the caller to advance the buffer by frame.data.len() #[cfg(test)] pub fn recv_frame( &mut self, rbuf: &'buf mut dyn RefRead, ) -> Option<Result<Frame<'buf>, Error>> { assert!(self.state == State::Connected || self.state == State::Closing); if self.receiving_frame.is_none() { let fi = match read_header(rbuf.get_ref()) { Ok(fi) => fi, Err(e) if e.kind() ==
io::ErrorKind::UnexpectedEof => return None, Err(e) => return Some(Err(e.into())), }; rbuf.consume(fi.payload_offset); self.receiving_frame = Some(fi); } let fi = self.receiving_frame.unwrap(); if rbuf.get_ref().len() < fi.payload_size { return None; } if fi.opcode == OPCODE_CLOSE { if self.state == State::Closing { self.state = State::Finished; } else { self.state = State::PeerClosed; } } let buf = rbuf.get_mut(); if let Some(mask) = fi.mask { apply_mask(&mut buf[..fi.payload_size], mask, 0); } self.receiving_frame = None; Some(Ok(Frame { opcode: fi.opcode, data: &buf[..fi.payload_size], fin: fi.fin, })) } pub fn is_sending_message(&self) -> bool { self.sending_message.is_some() } pub fn send_message_start(&mut self, opcode: u8, mask: Option<[u8; 4]>) { assert!(self.state == State::Connected || self.state == State::PeerClosed); assert!(self.sending_message.is_none()); self.sending_message = Some(SendingMessage { opcode, mask, frame_sent: false, }); } pub fn send_message_content( &mut self, writer: &mut dyn Write, src: &[&[u8]], end: bool, ) -> Result<(usize, bool), Error> { assert!(self.state == State::Connected || self.state == State::PeerClosed); let msg = self.sending_message.as_ref().unwrap(); let mut src_len = 0; for buf in src.iter() { src_len += buf.len(); } // control frames (ping, pong, close) must have a small payload length // and must not be fragmented if msg.opcode & 0x08 != 0 && (src_len > CONTROL_FRAME_PAYLOAD_MAX || !end) { return Err(Error::InvalidControlFrame); } let opcode = if msg.frame_sent { OPCODE_CONTINUATION } else { msg.opcode }; let fin = if let Some(f) = &self.sending_frame { f.header[0] & 0x80 != 0 } else { end }; let mask = msg.mask; let size = self.send_frame(writer, opcode, src, fin, mask)?; if self.sending_frame.is_none() && fin { self.sending_message = None; } else { let msg = self.sending_message.as_mut().unwrap(); msg.frame_sent = true; } let done = self.sending_message.is_none(); Ok((size, done)) } pub fn recv_message_content( &mut self,
rbuf: &mut dyn RefRead, dest: &mut [u8], ) -> Option<Result<(u8, usize, bool), Error>> { assert!(self.state == State::Connected || self.state == State::Closing); if self.receiving_frame.is_none() { let fi = match read_header(rbuf.get_ref()) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return None, Err(e) => return Some(Err(e.into())), }; rbuf.consume(fi.payload_offset); self.receiving_frame = Some(fi); if let Some(msg) = &mut self.receiving_message { if fi.opcode != OPCODE_CONTINUATION { return Some(Err(Error::UnexpectedOpcode)); } msg.frame_payload_read = 0; } else { if fi.opcode == OPCODE_CONTINUATION { return Some(Err(Error::UnexpectedOpcode)); } if fi.opcode & 0x08 != 0 && (fi.payload_size > CONTROL_FRAME_PAYLOAD_MAX || !fi.fin) { return Some(Err(Error::InvalidControlFrame)); } self.receiving_message = Some(ReceivingMessage { opcode: fi.opcode, frame_payload_read: 0, }); } } let fi = self.receiving_frame.as_ref().unwrap(); let msg = self.receiving_message.as_mut().unwrap(); let buf = rbuf.get_ref(); // control frames must be available in their entirety if fi.opcode & 0x08 != 0 && buf.len() < fi.payload_size { return None; } let left = fi.payload_size - msg.frame_payload_read; if left > 0 && buf.is_empty() { return None; } let size = cmp::min(cmp::min(left, buf.len()), dest.len()); dest[..size].copy_from_slice(&buf[..size]); rbuf.consume(size); if let Some(mask) = fi.mask { apply_mask(&mut dest[..size], mask, msg.frame_payload_read); } let opcode = msg.opcode; let fin = fi.fin; msg.frame_payload_read += size; if msg.frame_payload_read >= fi.payload_size { self.receiving_frame = None; if fin { self.receiving_message = None; if opcode == OPCODE_CLOSE { if self.state == State::Closing { self.state = State::Finished; } else { self.state = State::PeerClosed; } } } } Some(Ok((opcode, size, self.receiving_message.is_none()))) } } #[cfg(test)] mod tests { use super::*; struct MyWriter { data: Vec<u8>, } impl MyWriter { fn new() -> Self { Self { data: Vec::new() } } } impl Write for
MyWriter { fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> { self.data.extend_from_slice(buf.as_ref()); Ok(buf.len()) } fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result<usize, io::Error> { let mut total = 0; for buf in bufs { total += buf.len(); self.data.extend_from_slice(buf.as_ref()); } Ok(total) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } #[test] fn test_header_size() { assert_eq!(header_size(0, false), 2); assert_eq!(header_size(125, false), 2); assert_eq!(header_size(125, true), 6); assert_eq!(header_size(126, false), 4); assert_eq!(header_size(65535, false), 4); assert_eq!(header_size(65535, true), 8); assert_eq!(header_size(65536, false), 10); assert_eq!(header_size(65536, true), HEADER_SIZE_MAX); } #[test] fn test_read_write_header() { let mut buf = [ 0x81, 0x85, 0x01, 0x02, 0x03, 0x04, 0x69, 0x67, 0x6f, 0x68, 0x6e, ]; let r = read_header(&buf); assert!(r.is_ok()); let fi = r.unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, OPCODE_TEXT); assert_eq!(fi.mask, Some([0x01, 0x02, 0x03, 0x04])); assert_eq!(fi.payload_offset, 6); assert_eq!(fi.payload_size, 5); let end = fi.payload_offset + fi.payload_size; let payload = &mut buf[fi.payload_offset..end]; apply_mask(payload, fi.mask.unwrap(), 0); assert_eq!(payload, b"hello"); let payload = b"hello"; let mut buf2 = Vec::new(); buf2.resize(header_size(payload.len(), true) + payload.len(), 0); let r = write_header( true, OPCODE_TEXT, payload.len(), Some([0x01, 0x02, 0x03, 0x04]), &mut buf2, ); assert!(r.is_ok()); let offset = r.unwrap(); assert_eq!(offset, 6); buf2[offset..offset + payload.len()].copy_from_slice(payload); assert_eq!(buf2, buf); } #[test] fn test_apply_mask() { let mut buf = [b'a', b'b', b'c', b'd', b'e']; apply_mask(&mut buf, [0x01, 0x02, 0x03, 0x04], 0); assert_eq!(buf, [0x60, 0x60, 0x60, 0x60, 0x64]); } #[test] fn test_send_frame() { let mut p = Protocol::new(); assert_eq!(p.state(), State::Connected); let mut writer = MyWriter::new(); let size = p .send_frame(&mut writer,
OPCODE_TEXT, &[b"hello"], true, None) .unwrap(); assert_eq!(size, 5); assert_eq!(writer.data, b"\x81\x05hello"); assert_eq!(p.state(), State::Connected); } #[test] fn test_send_message() { let mut p = Protocol::new(); assert_eq!(p.state(), State::Connected); let mut writer = MyWriter::new(); p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content(&mut writer, &[b"hel", b"lo"], true) .unwrap(); assert_eq!(size, 5); assert_eq!(done, true); assert_eq!(writer.data, b"\x81\x05hello"); assert_eq!(p.state(), State::Connected); writer.data.clear(); p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content(&mut writer, &[b"hello"], false) .unwrap(); assert_eq!(size, 5); assert_eq!(done, false); assert_eq!(writer.data, b"\x01\x05hello"); assert_eq!(p.state(), State::Connected); writer.data.clear(); let (size, done) = p.send_message_content(&mut writer, &[b""], true).unwrap(); assert_eq!(size, 0); assert_eq!(done, true); assert_eq!(writer.data, b"\x80\x00"); assert_eq!(p.state(), State::Connected); writer.data.clear(); p.send_message_start(OPCODE_PING, None); let (size, done) = p .send_message_content(&mut writer, &[b"hello"], true) .unwrap(); assert_eq!(size, 5); assert_eq!(done, true); assert_eq!(writer.data, b"\x89\x05hello"); assert_eq!(p.state(), State::Connected); writer.data.clear(); p.send_message_start(OPCODE_PING, None); let r = p.send_message_content(&mut writer, &[b"hello"], false); assert!(r.is_err()); } #[test] fn test_recv_frame() { let mut data = b"\x81\x05hello".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let mut p = Protocol::new(); assert_eq!(p.state(), State::Connected); let frame = p.recv_frame(&mut rbuf).unwrap().unwrap(); assert_eq!(frame.opcode, OPCODE_TEXT); assert_eq!(frame.data, b"hello"); assert_eq!(frame.fin, true); let size = frame.data.len(); rbuf.consume(size); assert_eq!(p.state(), State::Connected); } #[test] fn test_recv_message() { let mut data = b"\x81\x05hello".to_vec(); 
let mut rbuf = io::Cursor::new(&mut data[..]); let mut p = Protocol::new(); assert_eq!(p.state(), State::Connected); let mut dest = [0; 1024]; let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hello"); assert_eq!(end, true); assert_eq!(p.state(), State::Connected); let mut data = b"".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let r = p.recv_message_content(&mut rbuf, &mut dest); assert!(r.is_none()); let mut data = b"\x01\x03hel\x80\x02lo".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hel"); assert_eq!(end, false); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"lo"); assert_eq!(end, true); assert_eq!(p.state(), State::Connected); let mut data = b"\x81\x05hel".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hel"); assert_eq!(end, false); assert!(p.recv_message_content(&mut rbuf, &mut dest).is_none()); let mut data = b"lo".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"lo"); assert_eq!(end, true); assert_eq!(p.state(), State::Connected); let mut data = b"\x01\x03hel\x01\x02lo".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, 
OPCODE_TEXT); assert_eq!(data, b"hel"); assert_eq!(end, false); let r = p.recv_message_content(&mut rbuf, &mut dest).unwrap(); assert!(r.is_err()); } } condure-1.3.1/src/zhttppacket.rs000064400000000000000000001162220000000000000147440ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::arena; use crate::tnetstring; use std::cell::RefCell; use std::fmt; use std::io; use std::mem; use std::str; pub const IDS_MAX: usize = 128; const HEADERS_MAX: usize = 32; const EMPTY_BYTES: &[u8] = b""; const EMPTY_ID: Id = Id { id: EMPTY_BYTES, seq: None, }; pub const EMPTY_HEADER: Header = Header { name: "", value: EMPTY_BYTES, }; const EMPTY_HEADERS: [Header; 0] = [EMPTY_HEADER; 0]; #[derive(Clone, Copy)] pub struct Id<'a> { pub id: &'a [u8], pub seq: Option<u32>, } pub struct Header<'a> { pub name: &'a str, pub value: &'a [u8], } #[derive(Debug, PartialEq)] pub enum ContentType { Text, Binary, } pub struct RequestData<'buf, 'headers> { pub credits: u32, pub more: bool, pub stream: bool, pub max_size: u32, pub timeout: u32, pub method: &'buf str, pub uri: &'buf str, pub headers: &'headers [Header<'buf>], pub content_type: Option<ContentType>, // websocket pub body: &'buf [u8], pub peer_address: &'buf str, pub peer_port: u16, } impl RequestData<'_, '_> { pub fn new() -> Self { Self { credits: 0, more: false, stream: false, max_size: 0, timeout: 0, method: "", uri: "", headers: &EMPTY_HEADERS, body: EMPTY_BYTES, content_type: None,
peer_address: "", peer_port: 0, } } } pub struct ResponseData<'buf, 'headers> { pub credits: u32, pub more: bool, pub code: u32, pub reason: &'buf str, pub headers: &'headers [Header<'buf>], pub content_type: Option<ContentType>, // websocket pub body: &'buf [u8], } pub struct RejectedInfo<'buf, 'headers> { pub code: u32, pub reason: &'buf str, pub headers: &'headers [Header<'buf>], pub body: &'buf [u8], } pub struct ErrorData<'buf, 'headers> { pub condition: &'buf str, pub rejected_info: Option<RejectedInfo<'buf, 'headers>>, // rejected (websocket) } pub struct CreditData { pub credits: u32, } pub struct CloseData<'a> { // code, reason pub status: Option<(u16, &'a str)>, } pub struct PingData<'a> { pub credits: u32, pub body: &'a [u8], } pub struct PongData<'a> { pub credits: u32, pub body: &'a [u8], } pub enum RequestPacket<'buf, 'headers> { Data(RequestData<'buf, 'headers>), Error(ErrorData<'buf, 'headers>), Credit(CreditData), KeepAlive, Cancel, HandoffStart, HandoffProceed, Close(CloseData<'buf>), Ping(PingData<'buf>), Pong(PongData<'buf>), } pub enum ResponsePacket<'buf, 'headers> { Unknown, Data(ResponseData<'buf, 'headers>), Error(ErrorData<'buf, 'headers>), Credit(CreditData), KeepAlive, Cancel, HandoffStart, HandoffProceed, Close(CloseData<'buf>), Ping(PingData<'buf>), Pong(PongData<'buf>), } pub struct Request<'buf, 'ids, 'headers> { pub from: &'buf [u8], pub ids: &'ids [Id<'buf>], pub multi: bool, pub ptype: RequestPacket<'buf, 'headers>, } impl<'buf, 'ids, 'headers> Request<'buf, 'ids, 'headers> { pub fn new_data( from: &'buf [u8], ids: &'ids [Id<'buf>], data: RequestData<'buf, 'headers>, ) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::Data(data), } } pub fn new_error(from: &'buf [u8], ids: &'ids [Id<'buf>], condition: &'buf str) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::Error(ErrorData { condition, rejected_info: None, }), } } pub fn new_credit(from: &'buf [u8], ids: &'ids [Id<'buf>], credits: u32) -> Self { Self { from: from, ids:
ids, multi: false, ptype: RequestPacket::Credit(CreditData { credits }), } } pub fn new_keep_alive(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::KeepAlive, } } pub fn new_cancel(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::Cancel, } } pub fn new_handoff_start(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::HandoffStart, } } pub fn new_handoff_proceed(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::HandoffProceed, } } pub fn new_close( from: &'buf [u8], ids: &'ids [Id<'buf>], status: Option<(u16, &'buf str)>, ) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::Close(CloseData { status }), } } pub fn new_ping(from: &'buf [u8], ids: &'ids [Id<'buf>], body: &'buf [u8]) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::Ping(PingData { credits: 0, body }), } } pub fn new_pong(from: &'buf [u8], ids: &'ids [Id<'buf>], body: &'buf [u8]) -> Self { Self { from: from, ids: ids, multi: false, ptype: RequestPacket::Pong(PongData { credits: 0, body }), } } pub fn serialize(&self, dest: &mut [u8]) -> Result<usize, io::Error> { if dest.is_empty() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } dest[0] = b'T'; let mut cursor = io::Cursor::new(&mut dest[1..]); let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; if !self.from.is_empty() { w.write_string(b"from")?; w.write_string(self.from)?; } if self.ids.len() == 1 { w.write_string(b"id")?; w.write_string(self.ids[0].id)?; if let Some(seq) = self.ids[0].seq { w.write_string(b"seq")?; w.write_int(seq as isize)?; } } else if self.ids.len() > 1 { w.write_string(b"id")?; w.start_array()?; for id in self.ids.iter() { w.start_map()?; w.write_string(b"id")?; w.write_string(id.id)?; if let Some(seq) = id.seq {
w.write_string(b"seq")?; w.write_int(seq as isize)?; } w.end_map()?; } w.end_array()?; } if self.multi { w.write_string(b"ext")?; w.start_map()?; w.write_string(b"multi")?; w.write_bool(true)?; w.end_map()?; } match &self.ptype { RequestPacket::Data(data) => { if !data.method.is_empty() { w.write_string(b"method")?; w.write_string(data.method.as_bytes())?; } if !data.uri.is_empty() { w.write_string(b"uri")?; w.write_string(data.uri.as_bytes())?; } if !data.headers.is_empty() { w.write_string(b"headers")?; w.start_array()?; for h in data.headers.iter() { w.start_array()?; w.write_string(h.name.as_bytes())?; w.write_string(h.value)?; w.end_array()?; } w.end_array()?; } if let Some(ctype) = &data.content_type { w.write_string(b"content-type")?; let s: &[u8] = match ctype { ContentType::Text => b"text", ContentType::Binary => b"binary", }; w.write_string(s)?; } if !data.body.is_empty() { w.write_string(b"body")?; w.write_string(&data.body)?; } if data.credits > 0 { w.write_string(b"credits")?; w.write_int(data.credits as isize)?; } if data.more { w.write_string(b"more")?; w.write_bool(true)?; } if data.stream { w.write_string(b"stream")?; w.write_bool(true)?; } if data.max_size > 0 { w.write_string(b"max_size")?; w.write_int(data.max_size as isize)?; } if data.timeout > 0 { w.write_string(b"timeout")?; w.write_int(data.timeout as isize)?; } if !data.peer_address.is_empty() { w.write_string(b"peer-address")?; w.write_string(data.peer_address.as_bytes())?; w.write_string(b"peer-port")?; w.write_int(data.peer_port as isize)?; } } RequestPacket::Error(data) => { w.write_string(b"type")?; w.write_string(b"error")?; w.write_string(b"condition")?; w.write_string(data.condition.as_bytes())?; } RequestPacket::Credit(data) => { w.write_string(b"type")?; w.write_string(b"credit")?; w.write_string(b"credits")?; w.write_int(data.credits as isize)?; } RequestPacket::KeepAlive => { w.write_string(b"type")?; w.write_string(b"keep-alive")?; } RequestPacket::Cancel => { 
w.write_string(b"type")?; w.write_string(b"cancel")?; } RequestPacket::HandoffStart => { w.write_string(b"type")?; w.write_string(b"handoff-start")?; } RequestPacket::HandoffProceed => { w.write_string(b"type")?; w.write_string(b"handoff-proceed")?; } RequestPacket::Close(data) => { w.write_string(b"type")?; w.write_string(b"close")?; if let Some(status) = data.status { w.write_string(b"code")?; w.write_int(status.0 as isize)?; if !status.1.is_empty() { w.write_string(b"body")?; w.write_string(status.1.as_bytes())?; } } } RequestPacket::Ping(data) => { w.write_string(b"type")?; w.write_string(b"ping")?; if !data.body.is_empty() { w.write_string(b"body")?; w.write_string(&data.body)?; } } RequestPacket::Pong(data) => { w.write_string(b"type")?; w.write_string(b"pong")?; if !data.body.is_empty() { w.write_string(b"body")?; w.write_string(&data.body)?; } } } w.end_map()?; w.flush()?; Ok((cursor.position() as usize) + 1) } } #[derive(Debug, PartialEq)] pub enum ParseError { Unrecognized, TnetParse(tnetstring::ParseError), WrongType(&'static str, tnetstring::FrameType), NotMapOrString(&'static str), NotUtf8(&'static str), NegativeInt(&'static str), TooManyIds, TooManyHeaders, InvalidHeader, NoId, } impl fmt::Display for ParseError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Unrecognized => write!(f, "unrecognized data format"), Self::TnetParse(e) => e.fmt(f), Self::WrongType(field, expected) => write!(f, "{} must be of type {}", field, expected), Self::NotMapOrString(field) => write!(f, "{} must be of type map or string", field), Self::NotUtf8(field) => write!(f, "{} must be a utf-8 string", field), Self::NegativeInt(field) => write!(f, "{} must not be negative", field), Self::TooManyIds => write!(f, "too many ids"), Self::TooManyHeaders => write!(f, "too many headers"), Self::InvalidHeader => write!(f, "header item must have size 2"), Self::NoId => write!(f, "no id"), } } } impl From<tnetstring::ParseError> for ParseError { fn from(e:
tnetstring::ParseError) -> Self { Self::TnetParse(e) } } trait ErrorContext<T> { fn field(self, field: &'static str) -> Result<T, ParseError>; } impl<T> ErrorContext<T> for Result<T, tnetstring::ParseError> { fn field(self, field: &'static str) -> Result<T, ParseError> { match self { Ok(v) => Ok(v), Err(tnetstring::ParseError::WrongType(_, expected)) => { Err(ParseError::WrongType(field, expected)) } Err(e) => Err(e.into()), } } } impl<T> ErrorContext<T> for Result<T, str::Utf8Error> { fn field(self, field: &'static str) -> Result<T, ParseError> { match self { Ok(v) => Ok(v), Err(_) => Err(ParseError::NotUtf8(field)), } } } pub struct ResponseScratch<'a> { pub ids: [Id<'a>; IDS_MAX], pub headers: [Header<'a>; HEADERS_MAX], } impl ResponseScratch<'_> { pub fn new() -> Self { Self { ids: [EMPTY_ID; IDS_MAX], headers: [EMPTY_HEADER; HEADERS_MAX], } } } pub struct Response<'buf, 'ids, 'headers> { pub from: &'buf [u8], pub ids: &'ids [Id<'buf>], pub multi: bool, pub ptype: ResponsePacket<'buf, 'headers>, pub ptype_str: &'buf str, } impl<'buf, 'scratch> Response<'_, '_, '_> { pub fn parse_ids( src: &'buf [u8], scratch: &'scratch mut ResponseScratch<'buf>, ) -> Result<&'scratch [Id<'buf>], ParseError> { if src.is_empty() || src[0] != b'T' { return Err(ParseError::Unrecognized); } let root = tnetstring::parse_map(&src[1..]).field("root")?; let mut ids_pos = 0; for e in root.clone() { let e = e?; match e.key { "id" => { match e.ftype { tnetstring::FrameType::Array => { for idm in tnetstring::parse_array(e.data)? { let idm = idm?; if ids_pos >= scratch.ids.len() { return Err(ParseError::TooManyIds); } for m in tnetstring::parse_map(idm.data)?
{ let m = m?; match m.key { "id" => { let s = tnetstring::parse_string(m.data).field("id")?; scratch.ids[ids_pos].id = s; } _ => {} // skip other fields } } ids_pos += 1; } } tnetstring::FrameType::String => { let s = tnetstring::parse_string(e.data)?; scratch.ids[0].id = s; ids_pos = 1; } _ => { return Err(ParseError::NotMapOrString("id")); } } return Ok(&scratch.ids[..ids_pos]); } _ => {} // skip other fields } } Ok(&scratch.ids[..ids_pos]) } pub fn parse( src: &'buf [u8], scratch: &'scratch mut ResponseScratch<'buf>, ) -> Result<Response<'buf, 'scratch, 'scratch>, ParseError> { if src.is_empty() || src[0] != b'T' { return Err(ParseError::Unrecognized); } let root = tnetstring::parse_map(&src[1..]).field("root")?; // first, read the common fields let mut from = EMPTY_BYTES; let mut ids_pos = 0; let mut multi = false; let mut ptype_str = ""; for e in root.clone() { let e = e?; match e.key { "from" => { from = tnetstring::parse_string(e.data).field("from")?; } "id" => { match e.ftype { tnetstring::FrameType::Array => { for idm in tnetstring::parse_array(e.data)? { let idm = idm?; if ids_pos >= scratch.ids.len() { return Err(ParseError::TooManyIds); } for m in tnetstring::parse_map(idm.data)?
{ let m = m?; match m.key { "id" => { let s = tnetstring::parse_string(m.data).field("id")?; scratch.ids[ids_pos].id = s; } "seq" => { let x = tnetstring::parse_int(m.data).field("seq")?; if x < 0 { return Err(ParseError::NegativeInt("seq")); } scratch.ids[ids_pos].seq = Some(x as u32); } _ => {} // skip unknown fields } } ids_pos += 1; } } tnetstring::FrameType::String => { let s = tnetstring::parse_string(e.data)?; scratch.ids[0].id = s; ids_pos = 1; } _ => { return Err(ParseError::NotMapOrString("id")); } } } "seq" => { let x = tnetstring::parse_int(e.data).field("seq")?; if x < 0 { return Err(ParseError::NegativeInt("seq")); } scratch.ids[0].seq = Some(x as u32); } "type" => { let s = tnetstring::parse_string(e.data).field("type")?; let s = str::from_utf8(s).field("type")?; ptype_str = s; } "ext" => { let ext = tnetstring::parse_map(e.data).field("ext")?; for m in ext { let m = m?; match m.key { "multi" => { let b = tnetstring::parse_bool(m.data).field("multi")?; multi = b; } _ => {} // skip unknown fields } } } _ => {} // skip unknown fields } } let ptype = match ptype_str { // data "" => { let mut credits = 0; let mut more = false; let mut code = 0; let mut reason = ""; let mut headers_pos = 0; let mut content_type = None; let mut body = EMPTY_BYTES; for e in root { let e = e?; match e.key { "credits" => { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } "more" => { let b = tnetstring::parse_bool(e.data).field("more")?; more = b; } "code" => { let x = tnetstring::parse_int(e.data).field("code")?; if x < 0 { return Err(ParseError::NegativeInt("code")); } code = x as u32; } "reason" => { let s = tnetstring::parse_string(e.data).field("reason")?; let s = str::from_utf8(s).field("reason")?; reason = s; } "headers" => { let headers = tnetstring::parse_array(e.data).field("headers")?; for ha in headers { let ha = ha?; if headers_pos >= scratch.headers.len() { return 
Err(ParseError::TooManyHeaders); } let mut hi = tnetstring::parse_array(ha.data).field("header item")?; let name = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let name = tnetstring::parse_string(name.data).field("header name")?; let name = str::from_utf8(name).field("header name")?; let value = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let value = tnetstring::parse_string(value.data).field("header value")?; scratch.headers[headers_pos] = Header { name, value }; headers_pos += 1; } } "content-type" => { let s = tnetstring::parse_string(e.data).field("content-type")?; content_type = Some(match s { b"binary" => ContentType::Binary, _ => ContentType::Text, }); } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } _ => {} // skip unknown fields } } ResponsePacket::Data(ResponseData { credits, more, code, reason, headers: &scratch.headers[..headers_pos], content_type, body, }) } "error" => { let mut condition = ""; let mut code = 0; let mut reason = ""; let mut headers_pos = 0; let mut body = EMPTY_BYTES; for e in root { let e = e?; match e.key { "condition" => { let s = tnetstring::parse_string(e.data).field("condition")?; let s = str::from_utf8(s).field("condition")?; condition = s; } "code" => { let x = tnetstring::parse_int(e.data).field("code")?; if x < 0 { return Err(ParseError::NegativeInt("code")); } code = x as u32; } "reason" => { let s = tnetstring::parse_string(e.data).field("reason")?; let s = str::from_utf8(s).field("reason")?; reason = s; } "headers" => { let headers = tnetstring::parse_array(e.data).field("headers")?; for ha in headers { let ha = ha?; if headers_pos >= scratch.headers.len() { return Err(ParseError::TooManyHeaders); } let mut hi = tnetstring::parse_array(ha.data).field("header item")?; let name = match hi.next() { 
Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let name = tnetstring::parse_string(name.data).field("header name")?; let name = str::from_utf8(name).field("header name")?; let value = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let value = tnetstring::parse_string(value.data).field("header value")?; scratch.headers[headers_pos] = Header { name, value }; headers_pos += 1; } } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } _ => {} // skip unknown fields } } let rejected_info = if condition == "rejected" { Some(RejectedInfo { code, reason, headers: &scratch.headers[..headers_pos], body, }) } else { None }; ResponsePacket::Error(ErrorData { condition, rejected_info, }) } "credit" => { let mut credits = 0; for e in root { let e = e?; match e.key { "credits" => { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } _ => {} // skip unknown fields } } ResponsePacket::Credit(CreditData { credits }) } "keep-alive" => ResponsePacket::KeepAlive, "cancel" => ResponsePacket::Cancel, "handoff-start" => ResponsePacket::HandoffStart, "handoff-proceed" => ResponsePacket::HandoffProceed, "close" => { let mut code = None; let mut reason = ""; for e in root { let e = e?; match e.key { "code" => { let x = tnetstring::parse_int(e.data).field("code")?; if x < 0 { return Err(ParseError::NegativeInt("code")); } code = Some(x as u16); } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; let s = str::from_utf8(s).field("body")?; reason = s; } _ => {} // skip unknown fields } } if let Some(code) = code { ResponsePacket::Close(CloseData { status: Some((code, reason)), }) } else { ResponsePacket::Close(CloseData { status: None }) } } "ping" | "pong" => { let mut credits = 0; let mut body =
EMPTY_BYTES; for e in root { let e = e?; match e.key { "credits" => { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } _ => {} // skip unknown fields } } match ptype_str { "ping" => ResponsePacket::Ping(PingData { credits, body }), "pong" => ResponsePacket::Pong(PongData { credits, body }), _ => unreachable!(), } } _ => ResponsePacket::Unknown, }; Ok(Response { from, ids: &scratch.ids[..ids_pos], multi, ptype, ptype_str, }) } } pub struct OwnedResponse { resp: Response<'static, 'static, 'static>, _scratch: arena::Rc<RefCell<ResponseScratch<'static>>>, _src: arena::Arc<zmq::Message>, } impl OwnedResponse { pub fn parse( src: arena::Arc<zmq::Message>, offset: usize, scratch: arena::Rc<RefCell<ResponseScratch<'static>>>, ) -> Result<Self, ParseError> { let src_ref: &[u8] = &src.get()[offset..]; // safety: Self will take ownership of src, and the bytes referred to // by src_ref are on the heap, and src will not be modified or // dropped until Self is dropped, so the bytes referred to by src_ref // will remain valid for the lifetime of Self let src_ref: &'static [u8] = unsafe { mem::transmute(src_ref) }; // safety: Self will take ownership of scratch, and the location // referred to by scratch_mut is in an arena, and scratch will not // be dropped until Self is dropped, so the location referred to by // scratch_mut will remain valid for the lifetime of Self // // further, it is safe for Response::parse() to write references to // src_ref into scratch_mut, because src_ref and scratch_mut have // the same lifetime let scratch_mut: &'static mut ResponseScratch<'static> = unsafe { scratch.get().as_ptr().as_mut().unwrap() }; let resp = Response::parse(src_ref, scratch_mut)?; Ok(Self { resp, _scratch: scratch, _src: src, }) } pub fn get<'a>(&'a self) -> &'a Response<'a, 'a, 'a> { let resp = &self.resp; // safety: here we simply reduce the inner lifetimes to that of the owning // object, which is fine let resp: &'a
            Response<'a, 'a, 'a> = unsafe { mem::transmute(resp) };

        resp
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;
    use std::sync::Arc;

    #[test]
    fn test_serialize() {
        struct Test {
            name: &'static str,
            req: Request<'static, 'static, 'static>,
            expected: &'static str,
        }

        // data, error, credit, keepalive, cancel, handoffstart/proceed, close, ping, pong
        let tests = [
            Test {
                name: "data",
                req: Request {
                    from: b"client",
                    ids: &[Id {
                        id: b"1",
                        seq: Some(0),
                    }],
                    multi: false,
                    ptype: RequestPacket::Data(RequestData {
                        credits: 0,
                        more: true,
                        stream: false,
                        max_size: 0,
                        timeout: 0,
                        method: "POST",
                        uri: "http://example.com/path",
                        headers: &[Header {
                            name: "Content-Type",
                            value: b"text/plain",
                        }],
                        content_type: None,
                        body: b"hello",
                        peer_address: "",
                        peer_port: 0,
                    }),
                },
                expected: concat!(
                    "T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,3:uri",
                    ",23:http://example.com/path,7:headers,34:30:12:Content-Type,1",
                    "0:text/plain,]]4:body,5:hello,4:more,4:true!}",
                ),
            },
            Test {
                name: "error",
                req: Request {
                    from: b"client",
                    ids: &[Id {
                        id: b"1",
                        seq: Some(0),
                    }],
                    multi: false,
                    ptype: RequestPacket::Error(ErrorData {
                        condition: "bad-request",
                        rejected_info: None,
                    }),
                },
                expected: concat!(
                    "T77:4:from,6:client,2:id,1:1,3:seq,1:0#4:type,5:error,9:condi",
                    "tion,11:bad-request,}",
                ),
            },
        ];

        for test in tests.iter() {
            let mut data = [0; 1024];
            let size = test.req.serialize(&mut data).unwrap();

            assert_eq!(
                str::from_utf8(&data[..size]).unwrap(),
                test.expected,
                "test={}",
                test.name
            );
        }
    }

    #[test]
    fn test_parse() {
        let data = concat!(
            "T208:4:more,4:true!7:headers,34:30:12:Content-Type,10:text/pl",
            "ain,]]12:content-type,6:binary,4:from,6:server,2:id,1:1,6:rea",
            "son,2:OK,7:credits,3:100#9:user-data,12:3:foo,3:bar,}3:seq,1:",
            "0#4:code,3:200#4:body,5:hello,}"
        )
        .as_bytes();

        let mut scratch = ResponseScratch::new();
        let resp = Response::parse(&data, &mut scratch).unwrap();
        assert_eq!(resp.from, b"server");
        assert_eq!(resp.ids.len(), 1);
        assert_eq!(resp.ids[0].id, b"1");
        assert_eq!(resp.ids[0].seq, Some(0));

        let rdata = match resp.ptype {
            ResponsePacket::Data(data) => data,
            _ => panic!("expected data packet"),
        };

        assert_eq!(rdata.credits, 100);
        assert_eq!(rdata.more, true);
        assert_eq!(rdata.code, 200);
        assert_eq!(rdata.reason, "OK");
        assert_eq!(rdata.headers.len(), 1);
        assert_eq!(rdata.headers[0].name, "Content-Type");
        assert_eq!(rdata.headers[0].value, b"text/plain");
        assert_eq!(rdata.body, b"hello");

        let ctype = rdata.content_type.unwrap();
        assert_eq!(ctype, ContentType::Binary);
    }

    #[test]
    fn test_owned_parse() {
        let data = concat!(
            "addr T208:4:more,4:true!7:headers,34:30:12:Content-Type,10:te",
            "xt/plain,]]12:content-type,6:binary,4:from,6:server,2:id,1:1,",
            "6:reason,2:OK,7:credits,3:100#9:user-data,12:3:foo,3:bar,}3:s",
            "eq,1:0#4:code,3:200#4:body,5:hello,}"
        )
        .as_bytes();

        let msg_memory = Arc::new(arena::ArcMemory::new(1));
        let scratch_memory = Rc::new(arena::RcMemory::new(1));

        let msg = arena::Arc::new(zmq::Message::from(data), &msg_memory).unwrap();

        let scratch =
            arena::Rc::new(RefCell::new(ResponseScratch::new()), &scratch_memory).unwrap();

        let resp = OwnedResponse::parse(msg, 5, scratch).unwrap();

        let resp = resp.get();
        assert_eq!(resp.from, b"server");
        assert_eq!(resp.ids.len(), 1);
        assert_eq!(resp.ids[0].id, b"1");
        assert_eq!(resp.ids[0].seq, Some(0));
    }
}
condure-1.3.1/src/zhttpsocket.rs000064400000000000000000001536230000000000000147730ustar 00000000000000
/*
 * Copyright (C) 2020-2021 Fanout, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::arena;
use crate::channel;
use crate::event;
use crate::executor::Executor;
use crate::future::{
    select_9, select_option, select_slice, AsyncReceiver, AsyncSender, AsyncZmqSocket,
    RecvFuture, Select9, ZmqSendFuture, ZmqSendToFuture, REGISTRATIONS_PER_CHANNEL,
    REGISTRATIONS_PER_ZMQSOCKET,
};
use crate::list;
use crate::reactor::Reactor;
use crate::tnetstring;
use crate::zhttppacket::{Id, Response, ResponseScratch};
use crate::zmq::{MultipartHeader, SpecInfo, ZmqSocket};
use arrayvec::{ArrayString, ArrayVec};
use log::{debug, error, log_enabled, trace, warn};
use slab::Slab;
use std::cell::{Cell, RefCell};
use std::fmt;
use std::future::Future;
use std::io;
use std::marker;
use std::pin::Pin;
use std::str;
use std::str::FromStr;
use std::sync::{mpsc, Arc, Mutex};
use std::task::{Context, Poll};
use std::thread;
use std::time::Duration;

const HANDLES_MAX: usize = 1_024;
const STREAM_OUT_STREAM_DELAY: Duration = Duration::from_millis(50);
const LOG_METADATA_MAX: usize = 1_000;
const LOG_CONTENT_MAX: usize = 1_000;
const EXECUTOR_TASKS_MAX: usize = 1;

fn trim(s: &str, max: usize) -> String {
    // NOTE: O(n)
    let char_len = s.chars().count();

    if char_len > max && max >= 7 {
        let dist = max / 2;
        let mut left_end = 0;

        // default to the end of the string, so the right side is empty if
        // the loop below never reaches index (char_len - dist + 3)
        let mut right_start = s.len();

        // NOTE: O(n)
        for (i, (pos, _)) in s.char_indices().enumerate() {
            // dist guaranteed to be < char_len
            if i == dist {
                left_end = pos;
            }

            // (char_len - dist + 3) is <= char_len. it equals char_len only
            // when max == 7, in which case right_start keeps its default
            if i == char_len - dist + 3 {
                right_start = pos;
            }
        }

        let left = &s[..left_end];
        let right = &s[right_start..];

        format!("{}...{}", left, right)
    } else {
        s.to_owned()
    }
}

struct Packet<'a> {
    map_frame: tnetstring::Frame<'a>,
    content_field: Option<&'a str>,
}

impl Packet<'_> {
    fn fmt_metadata(&self, f: &mut dyn io::Write) -> Result<(), io::Error> {
        let it = tnetstring::MapIterator::new(self.map_frame.data);

        write!(f, "{{ ")?;

        let mut first =
            true;

        for mi in it {
            let mi = match mi {
                Ok(mi) => mi,
                Err(_) => return Ok(()),
            };

            if let Some(field) = self.content_field {
                if mi.key == field {
                    continue;
                }
            }

            // can't fail
            let (frame, _) = tnetstring::parse_frame(mi.data).unwrap();

            if !first {
                write!(f, ", ")?;
            }

            first = false;

            write!(f, "\"{}\": {}", mi.key, frame)?;
        }

        write!(f, " }}")
    }

    fn fmt_content(&self, f: &mut dyn io::Write) -> Result<Option<usize>, io::Error> {
        let field = match self.content_field {
            Some(field) => field,
            None => return Ok(None),
        };

        let it = tnetstring::MapIterator::new(self.map_frame.data);

        let mut ptype = &b""[..];
        let mut content = None;

        for mi in it {
            let mi = match mi {
                Ok(mi) => mi,
                Err(_) => return Ok(None),
            };

            if mi.key == "type" {
                ptype = match tnetstring::parse_string(mi.data) {
                    Ok(s) => s,
                    Err(_) => return Ok(None),
                };
            }

            // can't fail
            let (frame, _) = tnetstring::parse_frame(mi.data).unwrap();

            if mi.key == field {
                content = Some(frame);
            }
        }

        // only take content from data packets (ptype empty)
        if ptype.is_empty() {
            if let Some(frame) = content {
                write!(f, "{}", frame)?;

                return Ok(Some(frame.data.len()));
            } else {
                return Ok(Some(0));
            }
        }

        Ok(None)
    }
}

impl fmt::Display for Packet<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut meta = Vec::new();

        if self.fmt_metadata(&mut meta).is_err() {
            return Err(fmt::Error);
        }

        // formatted output is guaranteed to be utf8
        let meta = String::from_utf8(meta).unwrap();

        let meta = trim(&meta, LOG_METADATA_MAX);

        if self.content_field.is_some() {
            let mut content = Vec::new();

            let clen = match self.fmt_content(&mut content) {
                Ok(clen) => clen,
                Err(_) => return Err(fmt::Error),
            };

            if let Some(clen) = clen {
                // formatted output is guaranteed to be utf8
                let content = String::from_utf8(content).unwrap();

                let content = trim(&content, LOG_CONTENT_MAX);

                return write!(f, "{} {} {}", meta, clen, content);
            }
        }

        write!(f, "{}", meta)
    }
}

fn packet_to_string(data: &[u8]) -> String {
    if data.len() == 0 {
        return String::from("");
    }

    if data[0] == b'T' {
        let
            (frame, _) = match tnetstring::parse_frame(&data[1..]) {
            Ok(frame) => frame,
            Err(e) => return format!("", e),
        };

        if frame.ftype != tnetstring::FrameType::Map {
            return format!("");
        }

        let p = Packet {
            map_frame: frame,
            content_field: Some("body"),
        };

        p.to_string()
    } else {
        // maybe it's addr-prefixed

        let mut pos = None;

        for i in 0..data.len() {
            if data[i] == b' ' {
                pos = Some(i);
                break;
            }
        }

        if pos.is_none() {
            return String::from("");
        }

        let pos = pos.unwrap();

        let addr = match str::from_utf8(&data[..pos]) {
            Ok(addr) => addr,
            Err(e) => return format!("", e),
        };

        let payload = &data[(pos + 1)..];

        if payload.len() == 0 {
            return String::from("");
        }

        if payload[0] != b'T' {
            return String::from("");
        }

        let payload = &data[(pos + 2)..];

        let (frame, _) = match tnetstring::parse_frame(payload) {
            Ok(frame) => frame,
            Err(e) => return format!("", e),
        };

        if frame.ftype != tnetstring::FrameType::Map {
            return format!("");
        }

        let p = Packet {
            map_frame: frame,
            content_field: Some("body"),
        };

        format!("{} {}", addr, p)
    }
}

struct ClientReqSockets {
    sock: AsyncZmqSocket,
}

struct ClientStreamSockets {
    out: AsyncZmqSocket,
    out_stream: AsyncZmqSocket,
    in_: AsyncZmqSocket,
}

struct ReqPipeEnd {
    sender: channel::Sender<arena::Arc<zmq::Message>>,
    receiver: channel::Receiver<zmq::Message>,
}

struct StreamPipeEnd {
    sender: channel::Sender<arena::Arc<zmq::Message>>,
    receiver_any: channel::Receiver<zmq::Message>,
    receiver_addr: channel::Receiver<(ArrayVec<[u8; 64]>, zmq::Message)>,
}

struct AsyncReqPipeEnd {
    sender: AsyncSender<arena::Arc<zmq::Message>>,
    receiver: AsyncReceiver<zmq::Message>,
}

struct AsyncStreamPipeEnd {
    sender: AsyncSender<arena::Arc<zmq::Message>>,
    receiver_any: AsyncReceiver<zmq::Message>,
    receiver_addr: AsyncReceiver<(ArrayVec<[u8; 64]>, zmq::Message)>,
}

enum ControlRequest {
    Stop,
    SetClientReq(Vec<SpecInfo>),
    SetClientStream(Vec<SpecInfo>, Vec<SpecInfo>, Vec<SpecInfo>),
    AddClientReqHandle(ReqPipeEnd, ArrayString<[u8; 8]>),
    AddClientStreamHandle(StreamPipeEnd, ArrayString<[u8; 8]>),
}

type ControlResponse = Result<(), String>;

struct ReqPipe {
    pe: AsyncReqPipeEnd,
    filter: ArrayString<[u8; 8]>,
    valid: Cell<bool>,
}

struct StreamPipe {
    pe: AsyncStreamPipeEnd,
    filter:
        ArrayString<[u8; 8]>,
    valid: Cell<bool>,
}

struct RecvWrapperFuture<'a, T> {
    fut: RecvFuture<'a, T>,
    nkey: usize,
}

impl<T> Future for RecvWrapperFuture<'_, T> {
    type Output = (usize, Result<T, mpsc::RecvError>);

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let nkey = self.nkey;

        let fut = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.fut) };

        match fut.poll(cx) {
            Poll::Ready(result) => match result {
                Ok(value) => Poll::Ready((nkey, Ok(value))),
                Err(mpsc::RecvError) => Poll::Ready((nkey, Err(mpsc::RecvError))),
            },
            Poll::Pending => Poll::Pending,
        }
    }
}

struct RecvScratch<T> {
    tasks: arena::ReusableVec,
    slice_scratch: Vec<usize>,
    _marker: marker::PhantomData<T>,
}

impl<T> RecvScratch<T> {
    fn new(capacity: usize) -> Self {
        Self {
            tasks: arena::ReusableVec::new::<RecvWrapperFuture<T>>(capacity),
            slice_scratch: Vec::with_capacity(capacity),
            _marker: marker::PhantomData,
        }
    }

    fn get<'a>(
        &mut self,
    ) -> (
        arena::ReusableVecHandle<RecvWrapperFuture<'a, T>>,
        &mut Vec<usize>,
    ) {
        (self.tasks.get_as_new(), &mut self.slice_scratch)
    }
}

struct ReqHandles {
    nodes: Slab<list::Node<ReqPipe>>,
    list: list::List,
    recv_scratch: RefCell<RecvScratch<zmq::Message>>,
    need_cleanup: Cell<bool>,
}

impl ReqHandles {
    fn new(capacity: usize) -> Self {
        Self {
            nodes: Slab::with_capacity(capacity),
            list: list::List::default(),
            recv_scratch: RefCell::new(RecvScratch::new(capacity)),
            need_cleanup: Cell::new(false),
        }
    }

    fn len(&self) -> usize {
        self.nodes.len()
    }

    fn add(&mut self, pe: AsyncReqPipeEnd, filter: ArrayString<[u8; 8]>) {
        assert!(self.nodes.len() < self.nodes.capacity());

        let key = self.nodes.insert(list::Node::new(ReqPipe {
            pe,
            filter,
            valid: Cell::new(true),
        }));

        self.list.push_back(&mut self.nodes, key);
    }

    async fn recv(&self) -> zmq::Message {
        let mut scratch = self.recv_scratch.borrow_mut();

        let (mut tasks, slice_scratch) = scratch.get();

        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &self.nodes[nkey];
            let p = &n.value;

            if p.valid.get() {
                assert!(tasks.len() < tasks.capacity());

                tasks.push(RecvWrapperFuture {
                    fut: p.pe.receiver.recv(),
                    nkey,
                });
            }

            next = n.next;
        }

        loop {
            match
                select_slice(&mut tasks, slice_scratch).await {
                (_, (_, Ok(msg))) => return msg,
                (pos, (nkey, Err(mpsc::RecvError))) => {
                    tasks.remove(pos);

                    let p = &self.nodes[nkey].value;
                    p.valid.set(false);

                    self.need_cleanup.set(true);
                }
            }
        }
    }

    async fn send(&self, msg: &arena::Arc<zmq::Message>, ids: &[Id<'_>]) {
        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &self.nodes[nkey];
            let p = &n.value;

            let mut do_send = false;

            for id in ids.iter() {
                if id.id.starts_with(p.filter.as_bytes()) {
                    do_send = true;
                    break;
                }
            }

            if p.valid.get() && do_send {
                // blocking send. handle is expected to read as fast as possible
                // without downstream backpressure
                match p.pe.sender.send(arena::Arc::clone(msg)).await {
                    Ok(_) => {}
                    Err(_) => {
                        p.valid.set(false);
                        self.need_cleanup.set(true);
                    }
                }
            }

            next = n.next;
        }
    }

    fn need_cleanup(&self) -> bool {
        self.need_cleanup.get()
    }

    fn cleanup<F>(&mut self, f: F)
    where
        F: Fn(&ReqPipe),
    {
        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &mut self.nodes[nkey];
            let p = &mut n.value;

            next = n.next;

            if !p.valid.get() {
                f(p);

                self.list.remove(&mut self.nodes, nkey);
                self.nodes.remove(nkey);
            }
        }

        self.need_cleanup.set(false);
    }
}

struct StreamHandles {
    nodes: Slab<list::Node<StreamPipe>>,
    list: list::List,
    recv_any_scratch: RefCell<RecvScratch<zmq::Message>>,
    recv_addr_scratch: RefCell<RecvScratch<(ArrayVec<[u8; 64]>, zmq::Message)>>,
    need_cleanup: Cell<bool>,
}

impl StreamHandles {
    fn new(capacity: usize) -> Self {
        Self {
            nodes: Slab::with_capacity(capacity),
            list: list::List::default(),
            recv_any_scratch: RefCell::new(RecvScratch::new(capacity)),
            recv_addr_scratch: RefCell::new(RecvScratch::new(capacity)),
            need_cleanup: Cell::new(false),
        }
    }

    fn len(&self) -> usize {
        self.nodes.len()
    }

    fn add(&mut self, pe: AsyncStreamPipeEnd, filter: ArrayString<[u8; 8]>) {
        assert!(self.nodes.len() < self.nodes.capacity());

        let key = self.nodes.insert(list::Node::new(StreamPipe {
            pe,
            filter,
            valid: Cell::new(true),
        }));

        self.list.push_back(&mut self.nodes, key);
    }

    async fn recv_any(&self) -> zmq::Message {
        let mut scratch =
            self.recv_any_scratch.borrow_mut();

        let (mut tasks, slice_scratch) = scratch.get();

        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &self.nodes[nkey];
            let p = &n.value;

            if p.valid.get() {
                assert!(tasks.len() < tasks.capacity());

                tasks.push(RecvWrapperFuture {
                    fut: p.pe.receiver_any.recv(),
                    nkey,
                });
            }

            next = n.next;
        }

        loop {
            match select_slice(&mut tasks, slice_scratch).await {
                (_, (_, Ok(msg))) => return msg,
                (pos, (nkey, Err(mpsc::RecvError))) => {
                    tasks.remove(pos);

                    let p = &self.nodes[nkey].value;
                    p.valid.set(false);

                    self.need_cleanup.set(true);
                }
            }
        }
    }

    async fn recv_addr(&self) -> (ArrayVec<[u8; 64]>, zmq::Message) {
        let mut scratch = self.recv_addr_scratch.borrow_mut();

        let (mut tasks, slice_scratch) = scratch.get();

        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &self.nodes[nkey];
            let p = &n.value;

            if p.valid.get() {
                assert!(tasks.len() < tasks.capacity());

                tasks.push(RecvWrapperFuture {
                    fut: p.pe.receiver_addr.recv(),
                    nkey,
                });
            }

            next = n.next;
        }

        loop {
            match select_slice(&mut tasks, slice_scratch).await {
                (_, (_, Ok(ret))) => return ret,
                (pos, (nkey, Err(mpsc::RecvError))) => {
                    tasks.remove(pos);

                    let p = &self.nodes[nkey].value;
                    p.valid.set(false);

                    self.need_cleanup.set(true);
                }
            }
        }
    }

    async fn send(&self, msg: &arena::Arc<zmq::Message>, ids: &[Id<'_>]) {
        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &self.nodes[nkey];
            let p = &n.value;

            let mut do_send = false;

            for id in ids.iter() {
                if id.id.starts_with(p.filter.as_bytes()) {
                    do_send = true;
                    break;
                }
            }

            if p.valid.get() && do_send {
                // blocking send.
                // handle is expected to read as fast as possible
                // without downstream backpressure
                match p.pe.sender.send(arena::Arc::clone(msg)).await {
                    Ok(_) => {}
                    Err(_) => {
                        p.valid.set(false);
                        self.need_cleanup.set(true);
                    }
                }
            }

            next = n.next;
        }
    }

    fn need_cleanup(&self) -> bool {
        self.need_cleanup.get()
    }

    fn cleanup<F>(&mut self, f: F)
    where
        F: Fn(&StreamPipe),
    {
        let mut next = self.list.head;

        while let Some(nkey) = next {
            let n = &mut self.nodes[nkey];
            let p = &mut n.value;

            next = n.next;

            if !p.valid.get() {
                f(p);

                self.list.remove(&mut self.nodes, nkey);
                self.nodes.remove(nkey);
            }
        }

        self.need_cleanup.set(false);
    }
}

pub struct SocketManager {
    handle_bound: usize,
    thread: Option<thread::JoinHandle<()>>,
    control_pipe: Mutex<(
        channel::Sender<ControlRequest>,
        channel::Receiver<ControlResponse>,
    )>,
}

impl SocketManager {
    // retained_max is the maximum number of received messages that the user
    // will keep around at any moment. for example, if the user plans to
    // set up 4 handles on the manager and read 1 message at a time from
    // each of the handles (i.e. process and drop a message before reading
    // the next), then the value here should be 4, because there would be
    // no more than 4 dequeued messages alive at any one time. this number
this number // is needed to help size the internal arena pub fn new( ctx: Arc, instance_id: &str, retained_max: usize, hwm: usize, handle_bound: usize, ) -> Self { let (s1, r1) = channel::channel(1); let (s2, r2) = channel::channel(1); let instance_id = String::from(instance_id); let thread = thread::spawn(move || { debug!("manager thread start"); // 2 control channels, 3 channels per handle, 4 zmq sockets let channels = 2 + (HANDLES_MAX * 3); let zmqsockets = 4; let registrations_max = (channels * REGISTRATIONS_PER_CHANNEL) + (zmqsockets * REGISTRATIONS_PER_ZMQSOCKET); let reactor = Reactor::new(registrations_max); let executor = Executor::new(EXECUTOR_TASKS_MAX); executor .spawn(Self::run(ctx, s1, r2, instance_id, retained_max, hwm)) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); debug!("manager thread end"); }); Self { handle_bound, thread: Some(thread), control_pipe: Mutex::new((s2, r1)), } } pub fn set_client_req_specs(&mut self, specs: &[SpecInfo]) -> Result<(), String> { self.control_req(ControlRequest::SetClientReq(specs.to_vec())) } pub fn set_client_stream_specs( &mut self, out_specs: &[SpecInfo], out_stream_specs: &[SpecInfo], in_specs: &[SpecInfo], ) -> Result<(), String> { self.control_req(ControlRequest::SetClientStream( out_specs.to_vec(), out_stream_specs.to_vec(), in_specs.to_vec(), )) } pub fn client_req_handle(&self, id_prefix: &[u8]) -> ClientReqHandle { let (s1, r1) = channel::channel(self.handle_bound); let (s2, r2) = channel::channel(self.handle_bound); let pe = ReqPipeEnd { sender: s1, receiver: r2, }; let prefix = ArrayString::from_str(str::from_utf8(id_prefix).unwrap()).unwrap(); self.control_send(ControlRequest::AddClientReqHandle(pe, prefix)); ClientReqHandle { sender: s2, receiver: r1, } } pub fn client_stream_handle(&self, id_prefix: &[u8]) -> ClientStreamHandle { let (s1, r1) = channel::channel(self.handle_bound); let (s2, r2) = channel::channel(self.handle_bound); let (s3, r3) = 
            channel::channel(self.handle_bound);

        let pe = StreamPipeEnd {
            sender: s1,
            receiver_any: r2,
            receiver_addr: r3,
        };

        let prefix = ArrayString::from_str(str::from_utf8(id_prefix).unwrap()).unwrap();

        self.control_send(ControlRequest::AddClientStreamHandle(pe, prefix));

        ClientStreamHandle {
            sender_any: s2,
            sender_addr: s3,
            receiver: r1,
        }
    }

    fn control_send(&self, req: ControlRequest) {
        let pipe = self.control_pipe.lock().unwrap();

        // NOTE: this will block if queue is full
        pipe.0.send(req).unwrap();
    }

    fn control_req(&self, req: ControlRequest) -> Result<(), String> {
        let pipe = self.control_pipe.lock().unwrap();

        // NOTE: this is a blocking exchange
        pipe.0.send(req).unwrap();

        pipe.1.recv().unwrap()
    }

    async fn run(
        ctx: Arc<zmq::Context>,
        control_sender: channel::Sender<ControlResponse>,
        control_receiver: channel::Receiver<ControlRequest>,
        instance_id: String,
        retained_max: usize,
        hwm: usize,
    ) {
        let control_sender = AsyncSender::new(control_sender);
        let control_receiver = AsyncReceiver::new(control_receiver);

        // the messages arena needs to fit the max number of potential incoming messages that
        // still need to be processed.
        // this is the entire channel queue for every handle, plus
        // the maximum number of messages the user might retain, plus 1 extra for the next
        // message we are preparing to send to the handles
        let arena_size = (HANDLES_MAX * hwm) + retained_max + 1;

        let messages_memory = Arc::new(arena::SyncMemory::new(arena_size));

        let client_req = ClientReqSockets {
            sock: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::DEALER)),
        };

        let client_stream = ClientStreamSockets {
            out: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::PUSH)),
            out_stream: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER)),
            in_: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::SUB)),
        };

        client_req
            .sock
            .inner()
            .inner()
            .set_sndhwm(hwm as i32)
            .unwrap();
        client_req
            .sock
            .inner()
            .inner()
            .set_rcvhwm(hwm as i32)
            .unwrap();

        client_stream
            .out
            .inner()
            .inner()
            .set_sndhwm(hwm as i32)
            .unwrap();
        client_stream
            .out_stream
            .inner()
            .inner()
            .set_sndhwm(hwm as i32)
            .unwrap();
        client_stream
            .in_
            .inner()
            .inner()
            .set_rcvhwm(hwm as i32)
            .unwrap();

        client_stream
            .out_stream
            .inner()
            .inner()
            .set_router_mandatory(true)
            .unwrap();

        // a ROUTER socket may still be writable after returning EAGAIN, which
        // could mean that a different peer than the one we tried to write to
        // is writable.
        // there's no way to know when the desired peer will be
        // writable, so we'll keep trying again after a delay
        client_stream
            .out_stream
            .set_retry_timeout(Some(STREAM_OUT_STREAM_DELAY));

        let sub = format!("{} ", instance_id);
        client_stream
            .in_
            .inner()
            .inner()
            .set_subscribe(sub.as_bytes())
            .unwrap();

        let mut req_handles = ReqHandles::new(HANDLES_MAX);
        let mut stream_handles = StreamHandles::new(HANDLES_MAX);

        let mut req_send: Option<ZmqSendToFuture> = None;
        let mut stream_out_send: Option<ZmqSendFuture> = None;
        let mut stream_out_stream_send: Option<ZmqSendToFuture> = None;

        loop {
            let req_handles_recv = if req_send.is_none() {
                Some(req_handles.recv())
            } else {
                None
            };

            let stream_handles_recv_any = if stream_out_send.is_none() {
                Some(stream_handles.recv_any())
            } else {
                None
            };

            let stream_handles_recv_addr = if stream_out_stream_send.is_none() {
                Some(stream_handles.recv_addr())
            } else {
                None
            };

            let result = select_9(
                control_receiver.recv(),
                select_option(req_handles_recv),
                select_option(req_send.as_mut()),
                client_req.sock.recv_routed(),
                select_option(stream_handles_recv_any),
                select_option(stream_out_send.as_mut()),
                select_option(stream_handles_recv_addr),
                select_option(stream_out_stream_send.as_mut()),
                client_stream.in_.recv(),
            )
            .await;

            match result {
                // control_receiver.recv
                Select9::R1(result) => match result {
                    Ok(req) => match req {
                        ControlRequest::Stop => break,
                        ControlRequest::SetClientReq(specs) => {
                            debug!("applying req specs: {:?}", specs);

                            let result = Self::apply_req_specs(&client_req, &specs);

                            control_sender
                                .send(result)
                                .await
                                .expect("failed to send control response");
                        }
                        ControlRequest::SetClientStream(out_specs, out_stream_specs, in_specs) => {
                            debug!(
                                "applying stream specs: {:?} {:?} {:?}",
                                out_specs, out_stream_specs, in_specs
                            );

                            let result = Self::apply_stream_specs(
                                &client_stream,
                                &out_specs,
                                &out_stream_specs,
                                &in_specs,
                            );

                            control_sender
                                .send(result)
                                .await
                                .expect("failed to send control response");
                        }
                        ControlRequest::AddClientReqHandle(pe, filter) => {
                            debug!("adding req handle: filter=[{}]", filter);

                            if req_handles.len() + stream_handles.len() < HANDLES_MAX {
                                req_handles.add(
                                    AsyncReqPipeEnd {
                                        sender: AsyncSender::new(pe.sender),
                                        receiver: AsyncReceiver::new(pe.receiver),
                                    },
                                    filter,
                                );
                            } else {
                                error!("cannot add more than {} handles", HANDLES_MAX);
                            }
                        }
                        ControlRequest::AddClientStreamHandle(pe, filter) => {
                            debug!("adding stream handle: filter=[{}]", filter);

                            if req_handles.len() + stream_handles.len() < HANDLES_MAX {
                                stream_handles.add(
                                    AsyncStreamPipeEnd {
                                        sender: AsyncSender::new(pe.sender),
                                        receiver_any: AsyncReceiver::new(pe.receiver_any),
                                        receiver_addr: AsyncReceiver::new(pe.receiver_addr),
                                    },
                                    filter,
                                );
                            } else {
                                error!("cannot add more than {} handles", HANDLES_MAX);
                            }
                        }
                    },
                    Err(e) => error!("control recv: {}", e),
                },
                // req_handles_recv
                Select9::R2(msg) => {
                    if log_enabled!(log::Level::Trace) {
                        trace!("OUT req {}", packet_to_string(&msg));
                    }

                    let h = MultipartHeader::new();

                    req_send = Some(client_req.sock.send_to(h, msg));
                }
                // req_send
                Select9::R3(result) => {
                    if let Err(e) = result {
                        error!("req zmq send: {}", e);
                    }

                    req_send = None;
                }
                // client_req.sock.recv_routed
                Select9::R4(result) => match result {
                    Ok(msg) => {
                        if log_enabled!(log::Level::Trace) {
                            trace!("IN req {}", packet_to_string(&msg));
                        }

                        Self::handle_req_message(msg, &messages_memory, &mut req_handles).await;
                    }
                    Err(e) => error!("req zmq recv: {}", e),
                },
                // stream_handles_recv_any
                Select9::R5(msg) => {
                    if log_enabled!(log::Level::Trace) {
                        trace!("OUT stream {}", packet_to_string(&msg));
                    }

                    stream_out_send = Some(client_stream.out.send(msg));
                }
                // stream_out_send
                Select9::R6(result) => {
                    if let Err(e) = result {
                        error!("stream zmq send: {}", e);
                    }

                    stream_out_send = None;
                }
                // stream_handles_recv_addr
                Select9::R7((addr, msg)) => {
                    let mut h = MultipartHeader::new();
                    h.push(zmq::Message::from(addr.as_ref()));

                    if log_enabled!(log::Level::Trace) {
                        trace!("OUT stream to {}", packet_to_string(&msg));
                    }

                    stream_out_stream_send = Some(client_stream.out_stream.send_to(h, msg));
                }
                // stream_out_stream_send
                Select9::R8(result) => {
                    if let Err(e) = result {
                        error!("stream zmq send to: {}", e);
                    }

                    stream_out_stream_send = None;
                }
                // client_stream.in_.recv
                Select9::R9(result) => match result {
                    Ok(msg) => {
                        if log_enabled!(log::Level::Trace) {
                            trace!("IN stream {}", packet_to_string(&msg));
                        }

                        Self::handle_stream_message(
                            msg,
                            &messages_memory,
                            &instance_id,
                            &mut stream_handles,
                        )
                        .await;
                    }
                    Err(e) => error!("stream zmq recv: {}", e),
                },
            }

            if req_handles.need_cleanup() {
                req_handles.cleanup(|p| {
                    debug!("req handle disconnected: filter=[{}]", p.filter);
                });
            }

            if stream_handles.need_cleanup() {
                stream_handles.cleanup(|p| {
                    debug!("stream handle disconnected: filter=[{}]", p.filter);
                });
            }
        }
    }

    fn apply_req_specs(client_req: &ClientReqSockets, specs: &[SpecInfo]) -> Result<(), String> {
        if let Err(e) = client_req.sock.inner().apply_specs(&specs) {
            return Err(e.to_string());
        }

        Ok(())
    }

    fn apply_stream_specs(
        client_stream: &ClientStreamSockets,
        out_specs: &[SpecInfo],
        out_stream_specs: &[SpecInfo],
        in_specs: &[SpecInfo],
    ) -> Result<(), String> {
        if let Err(e) = client_stream.out.inner().apply_specs(&out_specs) {
            return Err(e.to_string());
        }

        if let Err(e) = client_stream
            .out_stream
            .inner()
            .apply_specs(&out_stream_specs)
        {
            return Err(e.to_string());
        }

        if let Err(e) = client_stream.in_.inner().apply_specs(&in_specs) {
            return Err(e.to_string());
        }

        Ok(())
    }

    async fn handle_req_message(
        msg: zmq::Message,
        messages_memory: &Arc<arena::SyncMemory<zmq::Message>>,
        handles: &mut ReqHandles,
    ) {
        let msg = arena::Arc::new(msg, messages_memory).unwrap();

        let mut scratch = ResponseScratch::new();

        let ids = match Response::parse_ids(msg.get(), &mut scratch) {
            Ok(ids) => ids,
            Err(e) => {
                warn!("unable to determine packet id(s): {}", e);
                return;
            }
        };

        handles.send(&msg, ids).await;
    }

    async fn handle_stream_message(
        msg: zmq::Message,
        messages_memory: &Arc<arena::SyncMemory<zmq::Message>>,
        instance_id: &str,
        handles: &mut StreamHandles,
    ) {
        let msg = arena::Arc::new(msg, messages_memory).unwrap();

        let buf = msg.get();

        let mut
            pos = None;

        for (i, b) in buf.iter().enumerate() {
            if *b == b' ' {
                pos = Some(i);
                break;
            }
        }

        let pos = match pos {
            Some(pos) => pos,
            None => {
                warn!("unable to determine packet address");
                return;
            }
        };

        let addr = &buf[..pos];
        if addr != instance_id.as_bytes() {
            warn!("packet not for us");
            return;
        }

        let buf = &buf[pos + 1..];

        let mut scratch = ResponseScratch::new();

        let ids = match Response::parse_ids(&buf, &mut scratch) {
            Ok(ids) => ids,
            Err(e) => {
                warn!("unable to determine packet id(s): {}", e);
                return;
            }
        };

        handles.send(&msg, ids).await;
    }
}

impl Drop for SocketManager {
    fn drop(&mut self) {
        self.control_send(ControlRequest::Stop);

        let thread = self.thread.take().unwrap();
        thread.join().unwrap();
    }
}

#[derive(Debug)]
pub enum SendError {
    Full(zmq::Message),
    Io(io::Error),
}

pub struct ClientReqHandle {
    sender: channel::Sender<zmq::Message>,
    receiver: channel::Receiver<arena::Arc<zmq::Message>>,
}

impl ClientReqHandle {
    pub fn get_read_registration(&self) -> &event::Registration {
        self.receiver.get_read_registration()
    }

    pub fn get_write_registration(&self) -> &event::Registration {
        self.sender.get_write_registration()
    }

    pub fn recv(&self) -> Result<arena::Arc<zmq::Message>, io::Error> {
        match self.receiver.try_recv() {
            Ok(msg) => Ok(msg),
            Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
            Err(mpsc::TryRecvError::Disconnected) => {
                Err(io::Error::from(io::ErrorKind::BrokenPipe))
            }
        }
    }

    pub fn send(&self, msg: zmq::Message) -> Result<(), SendError> {
        match self.sender.try_send(msg) {
            Ok(_) => Ok(()),
            Err(mpsc::TrySendError::Full(msg)) => Err(SendError::Full(msg)),
            Err(mpsc::TrySendError::Disconnected(_)) => {
                Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe)))
            }
        }
    }
}

pub struct AsyncClientReqHandle {
    sender: AsyncSender<zmq::Message>,
    receiver: AsyncReceiver<arena::Arc<zmq::Message>>,
}

impl AsyncClientReqHandle {
    pub fn new(h: ClientReqHandle) -> Self {
        Self {
            sender: AsyncSender::new(h.sender),
            receiver: AsyncReceiver::new(h.receiver),
        }
    }

    pub async fn recv(&self) -> Result<arena::Arc<zmq::Message>, io::Error> {
        match self.receiver.recv().await
        {
            Ok(msg) => Ok(msg),
            Err(mpsc::RecvError) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
        }
    }

    pub async fn send(&self, msg: zmq::Message) -> Result<(), io::Error> {
        match self.sender.send(msg).await {
            Ok(_) => Ok(()),
            Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
        }
    }
}

pub struct ClientStreamHandle {
    sender_any: channel::Sender<zmq::Message>,
    sender_addr: channel::Sender<(ArrayVec<[u8; 64]>, zmq::Message)>,
    receiver: channel::Receiver<arena::Arc<zmq::Message>>,
}

impl ClientStreamHandle {
    pub fn get_read_registration(&self) -> &event::Registration {
        self.receiver.get_read_registration()
    }

    pub fn get_write_any_registration(&self) -> &event::Registration {
        self.sender_any.get_write_registration()
    }

    pub fn get_write_addr_registration(&self) -> &event::Registration {
        self.sender_addr.get_write_registration()
    }

    pub fn recv(&self) -> Result<arena::Arc<zmq::Message>, io::Error> {
        match self.receiver.try_recv() {
            Ok(msg) => Ok(msg),
            Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
            Err(mpsc::TryRecvError::Disconnected) => {
                Err(io::Error::from(io::ErrorKind::BrokenPipe))
            }
        }
    }

    pub fn send_to_any(&self, msg: zmq::Message) -> Result<(), SendError> {
        match self.sender_any.try_send(msg) {
            Ok(_) => Ok(()),
            Err(mpsc::TrySendError::Full(msg)) => Err(SendError::Full(msg)),
            Err(mpsc::TrySendError::Disconnected(_)) => {
                Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe)))
            }
        }
    }

    pub fn send_to_addr(&self, addr: &[u8], msg: zmq::Message) -> Result<(), SendError> {
        let mut a = ArrayVec::new();
        if a.try_extend_from_slice(addr).is_err() {
            return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput)));
        }

        match self.sender_addr.try_send((a, msg)) {
            Ok(_) => Ok(()),
            Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)),
            Err(mpsc::TrySendError::Disconnected(_)) => {
                Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe)))
            }
        }
    }
}

pub struct AsyncClientStreamHandle {
    sender_any: AsyncSender<zmq::Message>,
    sender_addr: AsyncSender<(ArrayVec<[u8; 64]>,
        zmq::Message)>,
    receiver: AsyncReceiver<arena::Arc<zmq::Message>>,
}

impl AsyncClientStreamHandle {
    pub fn new(h: ClientStreamHandle) -> Self {
        Self {
            sender_any: AsyncSender::new(h.sender_any),
            sender_addr: AsyncSender::new(h.sender_addr),
            receiver: AsyncReceiver::new(h.receiver),
        }
    }

    pub async fn recv(&self) -> Result<arena::Arc<zmq::Message>, io::Error> {
        match self.receiver.recv().await {
            Ok(msg) => Ok(msg),
            Err(mpsc::RecvError) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
        }
    }

    pub async fn send_to_any(&self, msg: zmq::Message) -> Result<(), io::Error> {
        match self.sender_any.send(msg).await {
            Ok(_) => Ok(()),
            Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
        }
    }

    pub async fn send_to_addr(
        &self,
        addr: ArrayVec<[u8; 64]>,
        msg: zmq::Message,
    ) -> Result<(), io::Error> {
        match self.sender_addr.send((addr, msg)).await {
            Ok(_) => Ok(()),
            Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::event;
    use crate::zhttppacket::ResponsePacket;
    use std::mem;

    fn wait_readable(poller: &mut event::Poller, token: mio::Token) {
        loop {
            poller.poll(None).unwrap();

            for event in poller.iter_events() {
                if event.token() == token && event.is_readable() {
                    return;
                }
            }
        }
    }

    fn wait_writable(poller: &mut event::Poller, token: mio::Token) {
        loop {
            poller.poll(None).unwrap();

            for event in poller.iter_events() {
                if event.token() == token && event.is_writable() {
                    return;
                }
            }
        }
    }

    #[test]
    fn test_send_flow() {
        let zmq_context = Arc::new(zmq::Context::new());

        let mut zsockman = SocketManager::new(Arc::clone(&zmq_context), "test", 1, 1, 1);

        zsockman
            .set_client_stream_specs(
                &vec![SpecInfo {
                    spec: String::from("inproc://flow-test-out"),
                    bind: true,
                    ipc_file_mode: 0,
                }],
                &vec![SpecInfo {
                    spec: String::from("inproc://flow-test-out-stream"),
                    bind: true,
                    ipc_file_mode: 0,
                }],
                &vec![SpecInfo {
                    spec: String::from("inproc://flow-test-in"),
                    bind: true,
                    ipc_file_mode: 0,
                }],
            )
            .unwrap();

        // connect an out-stream receiver.
the other sockets we'll leave alone let in_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); in_stream_sock .set_identity("test-handler".as_bytes()) .unwrap(); in_stream_sock.set_rcvhwm(1).unwrap(); in_stream_sock .connect("inproc://flow-test-out-stream") .unwrap(); let h = zsockman.client_stream_handle(b"a-"); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h.get_write_addr_registration(), mio::Token(1), mio::Interest::WRITABLE, ) .unwrap(); // write four times, which will all succeed eventually. after this // we'll have filled the handle, the manager's temporary variable, // and the HWMs of both the sending and receiving zmq sockets for i in 1..=4 { loop { match h.send_to_addr( "test-handler".as_bytes(), zmq::Message::from(format!("{}", i).into_bytes()), ) { Ok(()) => break, Err(SendError::Full(_)) => wait_writable(&mut poller, mio::Token(1)), Err(SendError::Io(e)) => panic!("{:?}", e), } } } // once we were able to write a fourth time, this means the manager // has started processing the third message. let's wait a short bit // for the manager to attempt to send the third message to the zmq // socket and fail with EAGAIN thread::sleep(Duration::from_millis(10)); // fifth write will fail. 
there's no room let e = h .send_to_addr( "test-handler".as_bytes(), zmq::Message::from("5".as_bytes()), ) .unwrap_err(); let msg = match e { SendError::Full(msg) => msg, _ => panic!("unexpected error"), }; assert_eq!(str::from_utf8(&msg).unwrap(), "5"); // blocking read from the zmq socket so another message can flow let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"1"); // fifth write will now succeed, eventually loop { match h.send_to_addr( "test-handler".as_bytes(), zmq::Message::from("5".as_bytes()), ) { Ok(()) => break, Err(SendError::Full(_)) => wait_writable(&mut poller, mio::Token(1)), Err(SendError::Io(e)) => panic!("{:?}", e), } } // read the rest of the messages for i in 2..=5 { let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], format!("{}", i).as_bytes()); } } #[test] fn test_req() { let zmq_context = Arc::new(zmq::Context::new()); let mut zsockman = SocketManager::new(Arc::clone(&zmq_context), "test", 1, 100, 100); zsockman .set_client_req_specs(&vec![SpecInfo { spec: String::from("inproc://test-req"), bind: true, ipc_file_mode: 0, }]) .unwrap(); let h1 = zsockman.client_req_handle(b"a-"); let h2 = zsockman.client_req_handle(b"b-"); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h1.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); let rep_sock = zmq_context.socket(zmq::REP).unwrap(); rep_sock.connect("inproc://test-req").unwrap(); h1.send(zmq::Message::from("hello a".as_bytes())).unwrap(); let parts = rep_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello a"); rep_sock .send("T26:2:id,3:a-1,4:body,5:world,}".as_bytes(), 0) .unwrap(); let msg; loop { match h1.recv() { Ok(m) => { msg = m; 
break; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: {}", e), }; } let msg = msg.get(); let mut scratch = ResponseScratch::new(); let resp = Response::parse(&msg, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); h2.send(zmq::Message::from("hello b".as_bytes())).unwrap(); let parts = rep_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello b"); rep_sock .send("T26:2:id,3:b-1,4:body,5:world,}".as_bytes(), 0) .unwrap(); let msg; loop { match h2.recv() { Ok(m) => { msg = m; break; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(2)); continue; } Err(e) => panic!("recv: {}", e), }; } let msg = msg.get(); let mut scratch = ResponseScratch::new(); let resp = Response::parse(&msg, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); mem::drop(h1); mem::drop(h2); mem::drop(zsockman); } #[test] fn test_stream() { let zmq_context = Arc::new(zmq::Context::new()); let mut zsockman = SocketManager::new(Arc::clone(&zmq_context), "test", 1, 100, 100); zsockman .set_client_stream_specs( &vec![SpecInfo { spec: String::from("inproc://test-out"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://test-out-stream"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://test-in"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); let h1 = zsockman.client_stream_handle(b"a-"); let h2 = zsockman.client_stream_handle(b"b-"); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h1.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_registration(), mio::Token(2), 
mio::Interest::READABLE, ) .unwrap(); let in_sock = zmq_context.socket(zmq::PULL).unwrap(); in_sock.connect("inproc://test-out").unwrap(); let in_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); in_stream_sock .set_identity("test-handler".as_bytes()) .unwrap(); in_stream_sock.connect("inproc://test-out-stream").unwrap(); let out_sock = zmq_context.socket(zmq::XPUB).unwrap(); out_sock.connect("inproc://test-in").unwrap(); // ensure zsockman is subscribed let msg = out_sock.recv_msg(0).unwrap(); assert_eq!(&msg[..], b"\x01test "); h1.send_to_any(zmq::Message::from("hello a".as_bytes())) .unwrap(); let parts = in_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello a"); out_sock .send( "test T49:4:from,12:test-handler,2:id,3:a-1,4:body,5:world,}".as_bytes(), 0, ) .unwrap(); let msg; loop { match h1.recv() { Ok(m) => { msg = m; break; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: {}", e), }; } let msg = msg.get(); let buf = &msg; let mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = pos.unwrap(); let buf = &buf[pos + 1..]; let mut scratch = ResponseScratch::new(); let resp = Response::parse(buf, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); h2.send_to_any(zmq::Message::from("hello b".as_bytes())) .unwrap(); let parts = in_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello b"); out_sock .send( "test T49:4:from,12:test-handler,2:id,3:b-1,4:body,5:world,}".as_bytes(), 0, ) .unwrap(); let msg; loop { match h2.recv() { Ok(m) => { msg = m; break; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(2)); continue; } Err(e) => panic!("recv: {}", e), }; } let msg = msg.get(); let buf = &msg; let 
mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = pos.unwrap(); let buf = &buf[pos + 1..]; let mut scratch = ResponseScratch::new(); let resp = Response::parse(buf, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); h1.send_to_addr( "test-handler".as_bytes(), zmq::Message::from("hello a".as_bytes()), ) .unwrap(); let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"hello a"); h2.send_to_addr( "test-handler".as_bytes(), zmq::Message::from("hello b".as_bytes()), ) .unwrap(); let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"hello b"); mem::drop(h1); mem::drop(h2); mem::drop(zsockman); } } condure-1.3.1/src/zmq.rs000064400000000000000000000324400000000000000132110ustar 00000000000000/* * Copyright (C) 2020-2021 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use std::cell::Cell; use std::cell::RefCell; use std::fmt; use std::fs; use std::io; use std::os::unix::fs::PermissionsExt; const MULTIPART_HEADERS_MAX: usize = 1; fn trim_prefix<'a>(s: &'a str, prefix: &str) -> Result<&'a str, ()> { if s.starts_with(prefix) { Ok(&s[prefix.len()..]) } else { Err(()) } } #[derive(Clone)] pub struct SpecInfo { pub spec: String, pub bind: bool, pub ipc_file_mode: usize, } impl fmt::Display for SpecInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.bind { write!(f, "bind:{}", self.spec) } else { write!(f, "connect:{}", self.spec) } } } impl fmt::Debug for SpecInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self) } } #[derive(Debug)] pub enum ZmqSocketError { Connect(String, zmq::Error), Bind(String, zmq::Error), SetMode(String, io::Error), } impl ToString for ZmqSocketError { fn to_string(&self) -> String { match self { ZmqSocketError::Connect(spec, e) => format!("connect {}: {}", spec, e), ZmqSocketError::Bind(spec, e) => format!("bind {}: {}", spec, e), ZmqSocketError::SetMode(spec, e) => format!("set mode {}: {}", spec, e), } } } #[derive(Clone)] struct ActiveSpec { pub spec: SpecInfo, pub endpoint: String, } fn unbind(sock: &zmq::Socket, endpoint: &str) -> zmq::Result<()> { // NOTE: use zmq_unbind instead when it becomes available in rust-zmq sock.disconnect(endpoint) } fn setup_spec(sock: &zmq::Socket, spec: &SpecInfo) -> Result<String, ZmqSocketError> { if spec.bind { match sock.bind(&spec.spec) { Ok(_) => { let endpoint = sock.get_last_endpoint().unwrap().unwrap(); if let Ok(path) = trim_prefix(&spec.spec, "ipc://") { if spec.ipc_file_mode > 0 { let perms = fs::Permissions::from_mode(spec.ipc_file_mode as u32); if let Err(e) = fs::set_permissions(path, perms) { // if setting perms fails, undo the bind unbind(sock, &endpoint).unwrap(); return Err(ZmqSocketError::SetMode(spec.spec.clone(), e)); } } } Ok(endpoint) } Err(e) => Err(ZmqSocketError::Bind(spec.spec.clone(), e)), } } else { match
sock.connect(&spec.spec) { Ok(_) => Ok(spec.spec.clone()), Err(e) => Err(ZmqSocketError::Connect(spec.spec.clone(), e)), } } } fn unsetup_spec(sock: &zmq::Socket, spec: &ActiveSpec) { if spec.spec.bind { unbind(sock, &spec.endpoint).unwrap(); if let Ok(path) = trim_prefix(&spec.endpoint, "ipc://") { if fs::remove_file(path).is_err() { // oh well, we tried } } } else { sock.disconnect(&spec.endpoint).unwrap(); } } pub struct MultipartHeader { parts: [Option<zmq::Message>; MULTIPART_HEADERS_MAX], len: usize, } impl MultipartHeader { pub fn new() -> Self { Self { parts: [None; MULTIPART_HEADERS_MAX], len: 0, } } pub fn push(&mut self, msg: zmq::Message) { self.parts[self.len] = Some(msg); self.len += 1; } } pub struct ZmqSocket { inner: zmq::Socket, events: Cell<zmq::PollEvents>, specs: RefCell<Vec<ActiveSpec>>, } impl ZmqSocket { pub fn new(ctx: &zmq::Context, socket_type: zmq::SocketType) -> Self { Self { inner: ctx.socket(socket_type).unwrap(), events: Cell::new(zmq::PollEvents::empty()), specs: RefCell::new(Vec::new()), } } pub fn inner(&self) -> &zmq::Socket { &self.inner } pub fn update_events(&self) { loop { match self.inner.get_events() { Ok(events) => { self.events.set(events); break; } Err(zmq::Error::EINTR) => continue, Err(e) => panic!("get events error: {}", e), } } } pub fn events(&self) -> zmq::PollEvents { self.events.get() } pub fn send(&self, msg: zmq::Message, flags: i32) -> Result<(), zmq::Error> { let flags = flags & zmq::DONTWAIT; if let Err(e) = self.inner.send(msg, flags) { self.update_events(); return Err(e); } self.update_events(); Ok(()) } pub fn send_to( &self, header: &MultipartHeader, content: zmq::Message, flags: i32, ) -> Result<(), zmq::Error> { if header.len + 2 > 8 { panic!("cannot send more than 8 parts") } let mut headers: [&[u8]; 8] = [b""; 8]; for i in 0..header.len { headers[i] = &header.parts[i].as_ref().unwrap(); } let headers = &headers[..header.len]; let flags = flags & zmq::DONTWAIT; if let Err(e) = self.inner.send_multipart(headers, flags | zmq::SNDMORE) {
self.update_events(); return Err(e); } if let Err(e) = self.inner.send(zmq::Message::new(), flags | zmq::SNDMORE) { self.update_events(); return Err(e); } self.send(content, flags) } pub fn recv(&self, flags: i32) -> Result<zmq::Message, zmq::Error> { let flags = flags & zmq::DONTWAIT; // get the first part let msg = match self.inner.recv_msg(flags) { Ok(msg) => msg, Err(e) => { self.update_events(); return Err(e); } }; let flags = 0; // eat the rest of the parts while self.inner.get_rcvmore().unwrap() { self.inner.recv_msg(flags).unwrap(); } self.update_events(); Ok(msg) } pub fn recv_routed(&self, flags: i32) -> Result<zmq::Message, zmq::Error> { let flags = flags & zmq::DONTWAIT; loop { // eat parts until we reach the separator match self.inner.recv_msg(flags) { Ok(msg) => { if msg.is_empty() { break; } } Err(e) => { self.update_events(); return Err(e); } } } let flags = 0; // if we get here, we've read the separator. content parts should follow if !self.inner.get_rcvmore().unwrap() { return Err(zmq::Error::EINVAL); } // get the first part of the content let msg = match self.inner.recv_msg(flags) { Ok(msg) => msg, Err(e) => { self.update_events(); return Err(e); } }; // eat the rest of the parts while self.inner.get_rcvmore().unwrap() { self.inner.recv_msg(flags).unwrap(); } self.update_events(); Ok(msg) } pub fn apply_specs(&self, new_specs: &[SpecInfo]) -> Result<(), ZmqSocketError> { let mut specs = self.specs.borrow_mut(); let mut to_remove = Vec::new(); for cur in specs.iter() { let mut found = false; for new in new_specs.iter() { if cur.spec.spec == new.spec && cur.spec.bind == new.bind { found = true; break; } } if !found { to_remove.push(cur.clone()); } } let mut to_add = Vec::new(); let mut to_update = Vec::new(); for new in new_specs.iter() { let mut found = None; for (ci, cur) in specs.iter().enumerate() { if new.spec == cur.spec.spec && new.bind == cur.spec.bind { found = Some(ci); break; } } match found { Some(ci) => { if new.ipc_file_mode != specs[ci].spec.ipc_file_mode {
to_update.push(new.clone()); } } None => { to_add.push(new.clone()); } } } let mut added = Vec::new(); // add specs we don't have. on fail, undo them for spec in to_add.iter() { match setup_spec(&self.inner, spec) { Ok(endpoint) => { added.push(ActiveSpec { spec: spec.clone(), endpoint, }); } Err(e) => { // undo previous adds for spec in added.iter().rev() { unsetup_spec(&self.inner, spec); } return Err(e); } } } // update ipc file mode let mut prev_perms = Vec::new(); for spec in to_update.iter() { let mut err = None; if let Ok(path) = trim_prefix(&spec.spec, "ipc://") { if spec.ipc_file_mode > 0 { match fs::metadata(path) { Ok(meta) => { let perms = fs::Permissions::from_mode(spec.ipc_file_mode as u32); match fs::set_permissions(path, perms) { Ok(_) => { prev_perms.push((String::from(path), meta.permissions())); } Err(e) => { err = Some(ZmqSocketError::SetMode(spec.spec.clone(), e)); } } } Err(e) => { err = Some(ZmqSocketError::SetMode(spec.spec.clone(), e)); } } } } if let Some(err) = err { // undo previous perms changes for (path, perms) in prev_perms { if fs::set_permissions(path, perms).is_err() { // oh well, we tried } } // undo previous adds for spec in added.iter().rev() { unsetup_spec(&self.inner, spec); } return Err(err); } } for spec in to_remove.iter() { unsetup_spec(&self.inner, spec); } // move current specs aside let prev_specs = std::mem::replace(&mut *specs, Vec::new()); // recompute current specs for new in new_specs { let mut s = None; // is it one we added?
for spec in added.iter() { if new.spec == spec.spec.spec && new.bind == spec.spec.bind { s = Some(spec.clone()); break; } } // else, it must be one we had already if s.is_none() { for spec in prev_specs.iter() { if new.spec == spec.spec.spec && new.bind == spec.spec.bind { s = Some(spec.clone()); break; } } } assert!(s.is_some()); specs.push(s.unwrap()); } Ok(()) } } impl Drop for ZmqSocket { fn drop(&mut self) { let specs = self.specs.borrow(); for spec in specs.iter() { unsetup_spec(&self.inner, spec); } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_send_after_disconnect() { let zmq_context = zmq::Context::new(); let s = ZmqSocket::new(&zmq_context, zmq::REQ); s.apply_specs(&[SpecInfo { spec: String::from("inproc://send-test"), bind: true, ipc_file_mode: 0, }]) .unwrap(); assert_eq!(s.events().contains(zmq::POLLOUT), false); let r = ZmqSocket::new(&zmq_context, zmq::REP); r.apply_specs(&[SpecInfo { spec: String::from("inproc://send-test"), bind: false, ipc_file_mode: 0, }]) .unwrap(); s.update_events(); assert_eq!(s.events().contains(zmq::POLLOUT), true); drop(r); assert_eq!( s.send((&b"test"[..]).into(), zmq::DONTWAIT), Err(zmq::Error::EAGAIN) ); assert_eq!(s.events().contains(zmq::POLLOUT), false); } }
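// Illustrative sketch (hypothetical helpers, not part of condure's API): a
// minimal model of the frame layout that ZmqSocket::send_to writes and that
// ZmqSocket::recv_routed reads, using plain byte vectors instead of
// zmq::Message so it runs without libzmq. Assumed layout, taken from the code
// above: zero or more header frames, one empty separator frame, then content.
// recv_routed skips everything up to the first empty frame, returns the first
// content frame, and discards any frames after it.

/// Builds the frames `send_to` would emit for `headers` followed by `content`.
/// Hypothetical helper for illustration only.
#[allow(dead_code)]
fn build_routed_frames(headers: &[&[u8]], content: &[u8]) -> Vec<Vec<u8>> {
    // send_to panics if headers + separator + content would exceed 8 parts
    assert!(headers.len() + 2 <= 8, "cannot send more than 8 parts");
    let mut frames: Vec<Vec<u8>> = headers.iter().map(|h| h.to_vec()).collect();
    frames.push(Vec::new()); // empty separator frame
    frames.push(content.to_vec());
    frames
}

/// Returns the first frame after the first empty separator, mirroring what
/// `recv_routed` hands back. Returns `None` when there is no separator, or
/// when nothing follows it (the case where `recv_routed` returns EINVAL).
/// Hypothetical helper for illustration only.
#[allow(dead_code)]
fn first_routed_content(frames: &[Vec<u8>]) -> Option<&[u8]> {
    let sep = frames.iter().position(|f| f.is_empty())?;
    frames.get(sep + 1).map(|f| f.as_slice())
}

// Usage: build_routed_frames(&[&b"test-handler"[..]], b"hello") yields three
// frames (header, empty separator, content), the same shape the ROUTER peer in
// test_send_flow asserts on with parts.len() == 3 and parts[1].is_empty().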