sqlx-mysql-0.8.3/.cargo_vcs_info.json0000644000000001500000000000100131720ustar { "git": { "sha1": "28cfdbb40c4fe535721c9ee5e1583409e0cac27e" }, "path_in_vcs": "sqlx-mysql" }sqlx-mysql-0.8.3/Cargo.toml0000644000000101510000000000100111720ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2021" name = "sqlx-mysql" version = "0.8.3" authors = [ "Ryan Leckey ", "Austin Bonander ", "Chloe Ross ", "Daniel Akhterov ", ] description = "MySQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." documentation = "https://docs.rs/sqlx" license = "MIT OR Apache-2.0" repository = "https://github.com/launchbadge/sqlx" [dependencies.atoi] version = "2.0" [dependencies.base64] version = "0.22.0" features = ["std"] default-features = false [dependencies.bigdecimal] version = "0.4.0" optional = true [dependencies.bitflags] version = "2" features = ["serde"] default-features = false [dependencies.byteorder] version = "1.4.3" features = ["std"] default-features = false [dependencies.bytes] version = "1.1.0" [dependencies.chrono] version = "0.4.34" features = [ "std", "clock", ] optional = true default-features = false [dependencies.crc] version = "3.0.0" [dependencies.digest] version = "0.10.0" features = ["std"] default-features = false [dependencies.dotenvy] version = "0.15.5" [dependencies.either] version = "1.6.1" [dependencies.futures-channel] version = "0.3.19" features = [ "sink", "alloc", "std", ] default-features = false [dependencies.futures-core] version = "0.3.19" default-features = false 
[dependencies.futures-io] version = "0.3.24" [dependencies.futures-util] version = "0.3.19" features = [ "alloc", "sink", "io", ] default-features = false [dependencies.generic-array] version = "0.14.4" default-features = false [dependencies.hex] version = "0.4.3" [dependencies.hkdf] version = "0.12.0" [dependencies.hmac] version = "0.12.0" default-features = false [dependencies.itoa] version = "1.0.1" [dependencies.log] version = "0.4.18" [dependencies.md-5] version = "0.10.0" default-features = false [dependencies.memchr] version = "2.4.1" default-features = false [dependencies.once_cell] version = "1.9.0" [dependencies.percent-encoding] version = "2.1.0" [dependencies.rand] version = "0.8.4" features = [ "std", "std_rng", ] default-features = false [dependencies.rsa] version = "0.9" [dependencies.rust_decimal] version = "1.26.1" features = ["std"] optional = true default-features = false [dependencies.serde] version = "1.0.144" optional = true [dependencies.sha1] version = "0.10.1" default-features = false [dependencies.sha2] version = "0.10.0" default-features = false [dependencies.smallvec] version = "1.7.0" [dependencies.sqlx-core] version = "=0.8.3" [dependencies.stringprep] version = "0.1.2" [dependencies.thiserror] version = "2.0.0" [dependencies.time] version = "0.3.36" features = [ "formatting", "parsing", "macros", ] optional = true [dependencies.tracing] version = "0.1.37" features = ["log"] [dependencies.uuid] version = "1.1.2" optional = true [dependencies.whoami] version = "1.2.1" default-features = false [dev-dependencies.sqlx] version = "=0.8.3" features = ["mysql"] default-features = false [features] any = ["sqlx-core/any"] bigdecimal = [ "dep:bigdecimal", "sqlx-core/bigdecimal", ] chrono = [ "dep:chrono", "sqlx-core/chrono", ] json = [ "sqlx-core/json", "serde", ] migrate = ["sqlx-core/migrate"] offline = [ "sqlx-core/offline", "serde/derive", ] rust_decimal = [ "dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal", ] time = [ 
"dep:time", "sqlx-core/time", ] uuid = [ "dep:uuid", "sqlx-core/uuid", ] [lints.clippy] cast_possible_truncation = "deny" cast_possible_wrap = "deny" cast_sign_loss = "deny" disallowed_methods = "deny" sqlx-mysql-0.8.3/Cargo.toml.orig000064400000000000000000000053711046102023000146630ustar 00000000000000[package] name = "sqlx-mysql" documentation = "https://docs.rs/sqlx" description = "MySQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." version.workspace = true license.workspace = true edition.workspace = true authors.workspace = true repository.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] json = ["sqlx-core/json", "serde"] any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive"] migrate = ["sqlx-core/migrate"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] chrono = ["dep:chrono", "sqlx-core/chrono"] rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"] time = ["dep:time", "sqlx-core/time"] uuid = ["dep:uuid", "sqlx-core/uuid"] [dependencies] sqlx-core = { workspace = true } # Futures crates futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] } futures-core = { version = "0.3.19", default-features = false } futures-io = "0.3.24" futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } # Cryptographic Primitives crc = "3.0.0" digest = { version = "0.10.0", default-features = false, features = ["std"] } hkdf = "0.12.0" hmac = { version = "0.12.0", default-features = false } md-5 = { version = "0.10.0", default-features = false } rand = { version = "0.8.4", default-features = false, features = ["std", "std_rng"] } rsa = "0.9" sha1 = { version = "0.10.1", default-features = false } sha2 = { version = "0.10.0", default-features = false } # Type Integrations (versions inherited 
from `[workspace.dependencies]`) bigdecimal = { workspace = true, optional = true } chrono = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } time = { workspace = true, optional = true } uuid = { workspace = true, optional = true } # Misc atoi = "2.0" base64 = { version = "0.22.0", default-features = false, features = ["std"] } bitflags = { version = "2", default-features = false, features = ["serde"] } byteorder = { version = "1.4.3", default-features = false, features = ["std"] } bytes = "1.1.0" dotenvy = "0.15.5" either = "1.6.1" generic-array = { version = "0.14.4", default-features = false } hex = "0.4.3" itoa = "1.0.1" log = "0.4.18" memchr = { version = "2.4.1", default-features = false } once_cell = "1.9.0" percent-encoding = "2.1.0" smallvec = "1.7.0" stringprep = "0.1.2" thiserror = "2.0.0" tracing = { version = "0.1.37", features = ["log"] } whoami = { version = "1.2.1", default-features = false } serde = { version = "1.0.144", optional = true } [dev-dependencies] sqlx = { workspace = true, features = ["mysql"] } [lints] workspace = true sqlx-mysql-0.8.3/LICENSE-APACHE000064400000000000000000000240031046102023000137110ustar 00000000000000Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright 2020 LaunchBadge, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.sqlx-mysql-0.8.3/LICENSE-MIT000064400000000000000000000020441046102023000134220ustar 00000000000000Copyright (c) 2020 LaunchBadge, LLC Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
sqlx-mysql-0.8.3/src/any.rs000064400000000000000000000157171046102023000137250ustar 00000000000000use crate::protocol::text::ColumnType; use crate::{ MySql, MySqlColumn, MySqlConnectOptions, MySqlConnection, MySqlQueryResult, MySqlRow, MySqlTransactionManager, MySqlTypeInfo, }; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, }; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; use std::future; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); impl AnyConnectionBackend for MySqlConnection { fn name(&self) -> &str { ::NAME } fn close(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { Connection::close(*self) } fn close_hard(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { Connection::close_hard(*self) } fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::ping(self) } fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { MySqlTransactionManager::begin(self) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { MySqlTransactionManager::commit(self) } fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { MySqlTransactionManager::rollback(self) } fn start_rollback(&mut self) { MySqlTransactionManager::start_rollback(self) } fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::flush(self) } fn should_flush(&self) -> bool { Connection::should_flush(self) } #[cfg(feature = "migrate")] fn as_migrate( &mut self, ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> { 
Ok(self) } fn fetch_many<'q>( &'q mut self, query: &'q str, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { let persistent = persistent && arguments.is_some(); let arguments = match arguments.as_ref().map(AnyArguments::convert_to).transpose() { Ok(arguments) => arguments, Err(error) => { return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() } }; Box::pin( self.run(query, arguments, persistent) .try_flatten_stream() .map(|res| { Ok(match res? { Either::Left(result) => Either::Left(map_result(result)), Either::Right(row) => Either::Right(AnyRow::try_from(&row)?), }) }), ) } fn fetch_optional<'q>( &'q mut self, query: &'q str, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { let persistent = persistent && arguments.is_some(); let arguments = arguments .as_ref() .map(AnyArguments::convert_to) .transpose() .map_err(sqlx_core::Error::Encode); Box::pin(async move { let arguments = arguments?; let stream = self.run(query, arguments, persistent).await?; futures_util::pin_mut!(stream); while let Some(result) = stream.try_next().await? 
{ if let Either::Right(row) = result { return Ok(Some(AnyRow::try_from(&row)?)); } } Ok(None) }) } fn prepare_with<'c, 'q: 'c>( &'c mut self, sql: &'q str, _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result>> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; AnyStatement::try_from_statement( sql, &statement, statement.metadata.column_names.clone(), ) }) } fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; describe.try_into_any() }) } } impl<'a> TryFrom<&'a MySqlTypeInfo> for AnyTypeInfo { type Error = sqlx_core::Error; fn try_from(type_info: &'a MySqlTypeInfo) -> Result { Ok(AnyTypeInfo { kind: match &type_info.r#type { ColumnType::Null => AnyTypeInfoKind::Null, ColumnType::Short => AnyTypeInfoKind::SmallInt, ColumnType::Long => AnyTypeInfoKind::Integer, ColumnType::LongLong => AnyTypeInfoKind::BigInt, ColumnType::Float => AnyTypeInfoKind::Real, ColumnType::Double => AnyTypeInfoKind::Double, ColumnType::Blob | ColumnType::TinyBlob | ColumnType::MediumBlob | ColumnType::LongBlob => AnyTypeInfoKind::Blob, ColumnType::String | ColumnType::VarString | ColumnType::VarChar => { AnyTypeInfoKind::Text } _ => { return Err(sqlx_core::Error::AnyDriverError( format!("Any driver does not support MySql type {type_info:?}").into(), )) } }, }) } } impl<'a> TryFrom<&'a MySqlColumn> for AnyColumn { type Error = sqlx_core::Error; fn try_from(column: &'a MySqlColumn) -> Result { let type_info = AnyTypeInfo::try_from(&column.type_info)?; Ok(AnyColumn { ordinal: column.ordinal, name: column.name.clone(), type_info, }) } } impl<'a> TryFrom<&'a MySqlRow> for AnyRow { type Error = sqlx_core::Error; fn try_from(row: &'a MySqlRow) -> Result { AnyRow::map_from(row, row.column_names.clone()) } } impl<'a> TryFrom<&'a AnyConnectOptions> for MySqlConnectOptions { type Error = sqlx_core::Error; fn try_from(any_opts: &'a AnyConnectOptions) -> 
Result { let mut opts = Self::parse_from_url(&any_opts.database_url)?; opts.log_settings = any_opts.log_settings.clone(); Ok(opts) } } fn map_result(result: MySqlQueryResult) -> AnyQueryResult { AnyQueryResult { rows_affected: result.rows_affected, // Don't expect this to be a problem #[allow(clippy::cast_possible_wrap)] last_insert_id: Some(result.last_insert_id as i64), } } sqlx-mysql-0.8.3/src/arguments.rs000064400000000000000000000052671046102023000151420ustar 00000000000000use crate::encode::{Encode, IsNull}; use crate::types::Type; use crate::{MySql, MySqlTypeInfo}; pub(crate) use sqlx_core::arguments::*; use sqlx_core::error::BoxDynError; use std::ops::Deref; /// Implementation of [`Arguments`] for MySQL. #[derive(Debug, Default, Clone)] pub struct MySqlArguments { pub(crate) values: Vec, pub(crate) types: Vec, pub(crate) null_bitmap: NullBitMap, } impl MySqlArguments { pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, MySql> + Type, { let ty = value.produces().unwrap_or_else(T::type_info); let value_length_before_encoding = self.values.len(); let is_null = match value.encode(&mut self.values) { Ok(is_null) => is_null, Err(error) => { // reset the value buffer to its previous value if encoding failed so we don't leave a half-encoded value behind self.values.truncate(value_length_before_encoding); return Err(error); } }; self.types.push(ty); self.null_bitmap.push(is_null); Ok(()) } } impl<'q> Arguments<'q> for MySqlArguments { type Database = MySql; fn reserve(&mut self, len: usize, size: usize) { self.types.reserve(len); self.values.reserve(size); } fn add(&mut self, value: T) -> Result<(), BoxDynError> where T: Encode<'q, Self::Database> + Type, { self.add(value) } fn len(&self) -> usize { self.types.len() } } #[derive(Debug, Default, Clone)] pub(crate) struct NullBitMap { bytes: Vec, length: usize, } impl NullBitMap { fn push(&mut self, is_null: IsNull) { let byte_index = self.length / (u8::BITS as usize); let 
bit_offset = self.length % (u8::BITS as usize); if bit_offset == 0 { self.bytes.push(0); } self.bytes[byte_index] |= u8::from(is_null.is_null()) << bit_offset; self.length += 1; } } impl Deref for NullBitMap { type Target = [u8]; fn deref(&self) -> &Self::Target { &self.bytes } } #[cfg(test)] mod test { use super::*; #[test] fn null_bit_map_should_push_is_null() { let mut bit_map = NullBitMap::default(); bit_map.push(IsNull::Yes); bit_map.push(IsNull::No); bit_map.push(IsNull::Yes); bit_map.push(IsNull::No); bit_map.push(IsNull::Yes); bit_map.push(IsNull::No); bit_map.push(IsNull::Yes); bit_map.push(IsNull::No); bit_map.push(IsNull::Yes); assert_eq!([0b01010101, 0b1].as_slice(), bit_map.deref()); } } sqlx-mysql-0.8.3/src/collation.rs000064400000000000000000001204301046102023000151070ustar 00000000000000use crate::error::Error; use std::str::FromStr; #[allow(non_camel_case_types)] #[derive(Copy, Clone)] pub(crate) enum CharSet { armscii8, ascii, big5, binary, cp1250, cp1251, cp1256, cp1257, cp850, cp852, cp866, cp932, dec8, eucjpms, euckr, gb18030, gb2312, gbk, geostd8, greek, hebrew, hp8, keybcs2, koi8r, koi8u, latin1, latin2, latin5, latin7, macce, macroman, sjis, swe7, tis620, ucs2, ujis, utf16, utf16le, utf32, utf8, utf8mb4, } impl CharSet { pub(crate) fn as_str(&self) -> &'static str { match self { CharSet::armscii8 => "armscii8", CharSet::ascii => "ascii", CharSet::big5 => "big5", CharSet::binary => "binary", CharSet::cp1250 => "cp1250", CharSet::cp1251 => "cp1251", CharSet::cp1256 => "cp1256", CharSet::cp1257 => "cp1257", CharSet::cp850 => "cp850", CharSet::cp852 => "cp852", CharSet::cp866 => "cp866", CharSet::cp932 => "cp932", CharSet::dec8 => "dec8", CharSet::eucjpms => "eucjpms", CharSet::euckr => "euckr", CharSet::gb18030 => "gb18030", CharSet::gb2312 => "gb2312", CharSet::gbk => "gbk", CharSet::geostd8 => "geostd8", CharSet::greek => "greek", CharSet::hebrew => "hebrew", CharSet::hp8 => "hp8", CharSet::keybcs2 => "keybcs2", CharSet::koi8r => "koi8r", 
CharSet::koi8u => "koi8u", CharSet::latin1 => "latin1", CharSet::latin2 => "latin2", CharSet::latin5 => "latin5", CharSet::latin7 => "latin7", CharSet::macce => "macce", CharSet::macroman => "macroman", CharSet::sjis => "sjis", CharSet::swe7 => "swe7", CharSet::tis620 => "tis620", CharSet::ucs2 => "ucs2", CharSet::ujis => "ujis", CharSet::utf16 => "utf16", CharSet::utf16le => "utf16le", CharSet::utf32 => "utf32", CharSet::utf8 => "utf8", CharSet::utf8mb4 => "utf8mb4", } } pub(crate) fn default_collation(&self) -> Collation { match self { CharSet::armscii8 => Collation::armscii8_general_ci, CharSet::ascii => Collation::ascii_general_ci, CharSet::big5 => Collation::big5_chinese_ci, CharSet::binary => Collation::binary, CharSet::cp1250 => Collation::cp1250_general_ci, CharSet::cp1251 => Collation::cp1251_general_ci, CharSet::cp1256 => Collation::cp1256_general_ci, CharSet::cp1257 => Collation::cp1257_general_ci, CharSet::cp850 => Collation::cp850_general_ci, CharSet::cp852 => Collation::cp852_general_ci, CharSet::cp866 => Collation::cp866_general_ci, CharSet::cp932 => Collation::cp932_japanese_ci, CharSet::dec8 => Collation::dec8_swedish_ci, CharSet::eucjpms => Collation::eucjpms_japanese_ci, CharSet::euckr => Collation::euckr_korean_ci, CharSet::gb18030 => Collation::gb18030_chinese_ci, CharSet::gb2312 => Collation::gb2312_chinese_ci, CharSet::gbk => Collation::gbk_chinese_ci, CharSet::geostd8 => Collation::geostd8_general_ci, CharSet::greek => Collation::greek_general_ci, CharSet::hebrew => Collation::hebrew_general_ci, CharSet::hp8 => Collation::hp8_english_ci, CharSet::keybcs2 => Collation::keybcs2_general_ci, CharSet::koi8r => Collation::koi8r_general_ci, CharSet::koi8u => Collation::koi8u_general_ci, CharSet::latin1 => Collation::latin1_swedish_ci, CharSet::latin2 => Collation::latin2_general_ci, CharSet::latin5 => Collation::latin5_turkish_ci, CharSet::latin7 => Collation::latin7_general_ci, CharSet::macce => Collation::macce_general_ci, CharSet::macroman => 
Collation::macroman_general_ci, CharSet::sjis => Collation::sjis_japanese_ci, CharSet::swe7 => Collation::swe7_swedish_ci, CharSet::tis620 => Collation::tis620_thai_ci, CharSet::ucs2 => Collation::ucs2_general_ci, CharSet::ujis => Collation::ujis_japanese_ci, CharSet::utf16 => Collation::utf16_general_ci, CharSet::utf16le => Collation::utf16le_general_ci, CharSet::utf32 => Collation::utf32_general_ci, CharSet::utf8 => Collation::utf8_unicode_ci, CharSet::utf8mb4 => Collation::utf8mb4_unicode_ci, } } }

/// Parses a MySQL character set name. The match is exact and case-sensitive;
/// unknown names produce an `Error::Configuration`.
impl FromStr for CharSet {
    type Err = Error;

    fn from_str(char_set: &str) -> Result<Self, Self::Err> {
        let parsed = match char_set {
            "armscii8" => CharSet::armscii8,
            "ascii" => CharSet::ascii,
            "big5" => CharSet::big5,
            "binary" => CharSet::binary,
            "cp1250" => CharSet::cp1250,
            "cp1251" => CharSet::cp1251,
            "cp1256" => CharSet::cp1256,
            "cp1257" => CharSet::cp1257,
            "cp850" => CharSet::cp850,
            "cp852" => CharSet::cp852,
            "cp866" => CharSet::cp866,
            "cp932" => CharSet::cp932,
            "dec8" => CharSet::dec8,
            "eucjpms" => CharSet::eucjpms,
            "euckr" => CharSet::euckr,
            "gb18030" => CharSet::gb18030,
            "gb2312" => CharSet::gb2312,
            "gbk" => CharSet::gbk,
            "geostd8" => CharSet::geostd8,
            "greek" => CharSet::greek,
            "hebrew" => CharSet::hebrew,
            "hp8" => CharSet::hp8,
            "keybcs2" => CharSet::keybcs2,
            "koi8r" => CharSet::koi8r,
            "koi8u" => CharSet::koi8u,
            "latin1" => CharSet::latin1,
            "latin2" => CharSet::latin2,
            "latin5" => CharSet::latin5,
            "latin7" => CharSet::latin7,
            "macce" => CharSet::macce,
            "macroman" => CharSet::macroman,
            "sjis" => CharSet::sjis,
            "swe7" => CharSet::swe7,
            "tis620" => CharSet::tis620,
            "ucs2" => CharSet::ucs2,
            "ujis" => CharSet::ujis,
            "utf16" => CharSet::utf16,
            "utf16le" => CharSet::utf16le,
            "utf32" => CharSet::utf32,
            "utf8" => CharSet::utf8,
            "utf8mb4" => CharSet::utf8mb4,
            unknown => {
                let message = format!("unsupported MySQL charset: {unknown}");
                return Err(Error::Configuration(message.into()));
            }
        };

        Ok(parsed)
    }
}

/// Generates the `Collation` enum, its `as_str` method and its `FromStr` impl
/// from one `name = id` table, so the three can never list different
/// collations. The textual name of each collation is the identifier itself.
macro_rules! define_collations {
    ($($name:ident = $id:literal),* $(,)?) => {
        /// A MySQL collation whose discriminant is its numeric collation ID.
        ///
        /// The handshake packet has only 1 byte for `collation_id`, so
        /// collations with an ID greater than 255 cannot be used here.
        #[derive(Copy, Clone)]
        #[allow(non_camel_case_types)]
        #[repr(u8)]
        pub(crate) enum Collation {
            $($name = $id,)*
        }

        impl Collation {
            /// Returns the collation name as MySQL spells it.
            pub(crate) fn as_str(&self) -> &'static str {
                match self {
                    $(Collation::$name => stringify!($name),)*
                }
            }
        }

        impl FromStr for Collation {
            type Err = Error;

            /// Looks up a collation by its exact (case-sensitive) name;
            /// unknown names produce an `Error::Configuration`.
            fn from_str(collation: &str) -> Result<Self, Self::Err> {
                $(
                    if collation == stringify!($name) {
                        return Ok(Collation::$name);
                    }
                )*

                Err(Error::Configuration(
                    format!("unsupported MySQL collation: {}", collation).into(),
                ))
            }
        }
    };
}

define_collations! {
    armscii8_bin = 64,
    armscii8_general_ci = 32,
    ascii_bin = 65,
    ascii_general_ci = 11,
    big5_bin = 84,
    big5_chinese_ci = 1,
    binary = 63,
    cp1250_bin = 66,
    cp1250_croatian_ci = 44,
    cp1250_czech_cs = 34,
    cp1250_general_ci = 26,
    cp1250_polish_ci = 99,
    cp1251_bin = 50,
    cp1251_bulgarian_ci = 14,
    cp1251_general_ci = 51,
    cp1251_general_cs = 52,
    cp1251_ukrainian_ci = 23,
    cp1256_bin = 67,
    cp1256_general_ci = 57,
    cp1257_bin = 58,
    cp1257_general_ci = 59,
    cp1257_lithuanian_ci = 29,
    cp850_bin = 80,
    cp850_general_ci = 4,
    cp852_bin = 81,
    cp852_general_ci = 40,
    cp866_bin = 68,
    cp866_general_ci = 36,
    cp932_bin = 96,
    cp932_japanese_ci = 95,
    dec8_bin = 69,
    dec8_swedish_ci = 3,
    eucjpms_bin = 98,
    eucjpms_japanese_ci = 97,
    euckr_bin = 85,
    euckr_korean_ci = 19,
    gb18030_bin = 249,
    gb18030_chinese_ci = 248,
    gb18030_unicode_520_ci = 250,
    gb2312_bin = 86,
    gb2312_chinese_ci = 24,
    gbk_bin = 87,
    gbk_chinese_ci = 28,
    geostd8_bin = 93,
    geostd8_general_ci = 92,
    greek_bin = 70,
    greek_general_ci = 25,
    hebrew_bin = 71,
    hebrew_general_ci = 16,
    hp8_bin = 72,
    hp8_english_ci = 6,
    keybcs2_bin = 73,
    keybcs2_general_ci = 37,
    koi8r_bin = 74,
    koi8r_general_ci = 7,
    koi8u_bin = 75,
    koi8u_general_ci = 22,
    latin1_bin = 47,
    latin1_danish_ci = 15,
    latin1_general_ci = 48,
    latin1_general_cs = 49,
    latin1_german1_ci = 5,
    latin1_german2_ci = 31,
    latin1_spanish_ci = 94,
    latin1_swedish_ci = 8,
    latin2_bin = 77,
    latin2_croatian_ci = 27,
    latin2_czech_cs = 2,
    latin2_general_ci = 9,
    latin2_hungarian_ci = 21,
    latin5_bin = 78,
    latin5_turkish_ci = 30,
    latin7_bin = 79,
    latin7_estonian_cs = 20,
    latin7_general_ci = 41,
    latin7_general_cs = 42,
    macce_bin = 43,
    macce_general_ci = 38,
    macroman_bin = 53,
    macroman_general_ci = 39,
    sjis_bin = 88,
    sjis_japanese_ci = 13,
    swe7_bin = 82,
    swe7_swedish_ci = 10,
    tis620_bin = 89,
    tis620_thai_ci = 18,
    ucs2_bin = 90,
    ucs2_croatian_ci = 149,
    ucs2_czech_ci = 138,
    ucs2_danish_ci = 139,
    ucs2_esperanto_ci = 145,
    ucs2_estonian_ci = 134,
    ucs2_general_ci = 35,
    ucs2_general_mysql500_ci = 159,
    ucs2_german2_ci = 148,
    ucs2_hungarian_ci = 146,
    ucs2_icelandic_ci = 129,
    ucs2_latvian_ci = 130,
    ucs2_lithuanian_ci = 140,
    ucs2_persian_ci = 144,
    ucs2_polish_ci = 133,
    ucs2_roman_ci = 143,
    ucs2_romanian_ci = 131,
    ucs2_sinhala_ci = 147,
    ucs2_slovak_ci = 141,
    ucs2_slovenian_ci = 132,
    ucs2_spanish_ci = 135,
    ucs2_spanish2_ci = 142,
    ucs2_swedish_ci = 136,
    ucs2_turkish_ci = 137,
    ucs2_unicode_520_ci = 150,
    ucs2_unicode_ci = 128,
    ucs2_vietnamese_ci = 151,
    ujis_bin = 91,
    ujis_japanese_ci = 12,
    utf16_bin = 55,
    utf16_croatian_ci = 122,
    utf16_czech_ci = 111,
    utf16_danish_ci = 112,
    utf16_esperanto_ci = 118,
    utf16_estonian_ci = 107,
    utf16_general_ci = 54,
    utf16_german2_ci = 121,
    utf16_hungarian_ci = 119,
    utf16_icelandic_ci = 102,
    utf16_latvian_ci = 103,
    utf16_lithuanian_ci = 113,
    utf16_persian_ci = 117,
    utf16_polish_ci = 106,
    utf16_roman_ci = 116,
    utf16_romanian_ci = 104,
    utf16_sinhala_ci = 120,
    utf16_slovak_ci = 114,
    utf16_slovenian_ci = 105,
    utf16_spanish_ci = 108,
    utf16_spanish2_ci = 115,
    utf16_swedish_ci = 109,
    utf16_turkish_ci = 110,
    utf16_unicode_520_ci = 123,
    utf16_unicode_ci = 101,
    utf16_vietnamese_ci = 124,
    utf16le_bin = 62,
    utf16le_general_ci = 56,
    utf32_bin = 61,
    utf32_croatian_ci = 181,
    utf32_czech_ci = 170,
    utf32_danish_ci = 171,
    utf32_esperanto_ci = 177,
    utf32_estonian_ci = 166,
    utf32_general_ci = 60,
    utf32_german2_ci = 180,
    utf32_hungarian_ci = 178,
    utf32_icelandic_ci = 161,
    utf32_latvian_ci = 162,
    utf32_lithuanian_ci = 172,
    utf32_persian_ci = 176,
    utf32_polish_ci = 165,
    utf32_roman_ci = 175,
    utf32_romanian_ci = 163,
    utf32_sinhala_ci = 179,
    utf32_slovak_ci = 173,
    utf32_slovenian_ci = 164,
    utf32_spanish_ci = 167,
    utf32_spanish2_ci = 174,
    utf32_swedish_ci = 168,
    utf32_turkish_ci = 169,
    utf32_unicode_520_ci = 182,
    utf32_unicode_ci = 160,
    utf32_vietnamese_ci = 183,
    utf8_bin = 83,
    utf8_croatian_ci = 213,
    utf8_czech_ci = 202,
    utf8_danish_ci = 203,
    utf8_esperanto_ci = 209,
    utf8_estonian_ci = 198,
    utf8_general_ci = 33,
    utf8_general_mysql500_ci = 223,
    utf8_german2_ci = 212,
    utf8_hungarian_ci = 210,
    utf8_icelandic_ci = 193,
    utf8_latvian_ci = 194,
    utf8_lithuanian_ci = 204,
    utf8_persian_ci = 208,
    utf8_polish_ci = 197,
    utf8_roman_ci = 207,
    utf8_romanian_ci = 195,
    utf8_sinhala_ci = 211,
    utf8_slovak_ci = 205,
    utf8_slovenian_ci = 196,
    utf8_spanish_ci = 199,
    utf8_spanish2_ci = 206,
    utf8_swedish_ci = 200,
    utf8_tolower_ci = 76,
    utf8_turkish_ci = 201,
    utf8_unicode_520_ci = 214,
    utf8_unicode_ci = 192,
    utf8_vietnamese_ci = 215,
    utf8mb4_0900_ai_ci = 255,
    utf8mb4_bin = 46,
    utf8mb4_croatian_ci = 245,
    utf8mb4_czech_ci = 234,
    utf8mb4_danish_ci = 235,
    utf8mb4_esperanto_ci = 241,
    utf8mb4_estonian_ci = 230,
    utf8mb4_general_ci = 45,
    utf8mb4_german2_ci = 244,
    utf8mb4_hungarian_ci = 242,
    utf8mb4_icelandic_ci = 225,
    utf8mb4_latvian_ci = 226,
    utf8mb4_lithuanian_ci = 236,
    utf8mb4_persian_ci = 240,
    utf8mb4_polish_ci = 229,
    utf8mb4_roman_ci = 239,
    utf8mb4_romanian_ci = 227,
    utf8mb4_sinhala_ci = 243,
    utf8mb4_slovak_ci = 237,
    utf8mb4_slovenian_ci = 228,
    utf8mb4_spanish_ci = 231,
    utf8mb4_spanish2_ci = 238,
    utf8mb4_swedish_ci = 232,
    utf8mb4_turkish_ci = 233,
    utf8mb4_unicode_520_ci = 246,
    utf8mb4_unicode_ci = 224,
    utf8mb4_vietnamese_ci = 247,
}
sqlx-mysql-0.8.3/src/column.rs000064400000000000000000000013131046102023000144160ustar 00000000000000use crate::ext::ustr::UStr; use crate::protocol::text::ColumnFlags; use crate::{MySql, MySqlTypeInfo}; pub(crate) use sqlx_core::column::*; #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct MySqlColumn { pub(crate) ordinal: usize, pub(crate) name: UStr, pub(crate) type_info: MySqlTypeInfo, #[cfg_attr(feature = "offline", serde(skip))] pub(crate) flags: Option, } impl Column for MySqlColumn { type Database = MySql; fn ordinal(&self) -> usize { self.ordinal } fn name(&self) -> &str { &self.name } fn type_info(&self) -> &MySqlTypeInfo { &self.type_info } } sqlx-mysql-0.8.3/src/connection/auth.rs000064400000000000000000000130601046102023000162230ustar 00000000000000use bytes::buf::Chain; use bytes::Bytes; use digest::{Digest, OutputSizeUser}; use generic_array::GenericArray; use rand::thread_rng; use rsa::{pkcs8::DecodePublicKey, Oaep, RsaPublicKey}; use sha1::Sha1; use sha2::Sha256; use crate::connection::stream::MySqlStream; use crate::error::Error; use crate::protocol::auth::AuthPlugin; use crate::protocol::Packet; impl AuthPlugin { pub(super) async fn scramble( self, stream: &mut MySqlStream, password: &str, nonce: &Chain, ) -> Result, Error> { match self { // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()), AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()), // https://mariadb.com/kb/en/sha256_password-plugin/ AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await, AuthPlugin::MySqlClearPassword => { let mut pw_bytes = password.as_bytes().to_owned(); pw_bytes.push(0); // null terminate Ok(pw_bytes) } } } pub(super) async fn handle( self, stream: &mut MySqlStream, packet: Packet, password: &str, nonce: &Chain, ) -> Result { match self { 
AuthPlugin::CachingSha2Password if packet[0] == 0x01 => { match packet[1] { // AUTH_OK 0x03 => Ok(true), // AUTH_CONTINUE 0x04 => { let payload = encrypt_rsa(stream, 0x02, password, nonce).await?; stream.write_packet(&*payload)?; stream.flush().await?; Ok(false) } v => { Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (AUTH_OK) or 0x04 (AUTH_CONTINUE)", v)) } } } _ => Err(err_protocol!( "unexpected packet 0x{:02x} for auth plugin '{}' during authentication", packet[0], self.name() )), } } } fn scramble_sha1( password: &str, nonce: &Chain, ) -> GenericArray::OutputSize> { // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) ) // https://mariadb.com/kb/en/connection/#mysql_native_password-plugin let mut ctx = Sha1::new(); ctx.update(password); let mut pw_hash = ctx.finalize_reset(); ctx.update(pw_hash); let pw_hash_hash = ctx.finalize_reset(); ctx.update(nonce.first_ref()); ctx.update(nonce.last_ref()); ctx.update(pw_hash_hash); let pw_seed_hash_hash = ctx.finalize(); xor_eq(&mut pw_hash, &pw_seed_hash_hash); pw_hash } fn scramble_sha256( password: &str, nonce: &Chain, ) -> GenericArray::OutputSize> { // XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password)))) // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password let mut ctx = Sha256::new(); ctx.update(password); let mut pw_hash = ctx.finalize_reset(); ctx.update(pw_hash); let pw_hash_hash = ctx.finalize_reset(); ctx.update(nonce.first_ref()); ctx.update(nonce.last_ref()); ctx.update(pw_hash_hash); let pw_seed_hash_hash = ctx.finalize(); xor_eq(&mut pw_hash, &pw_seed_hash_hash); pw_hash } async fn encrypt_rsa<'s>( stream: &'s mut MySqlStream, public_key_request_id: u8, password: &'s str, nonce: &'s Chain, ) -> Result, Error> { // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ if stream.is_tls { // If in a TLS stream, send the password directly in clear text return Ok(to_asciz(password)); } // 
client sends a public key request stream.write_packet(&[public_key_request_id][..])?; stream.flush().await?; // server sends a public key response let packet = stream.recv_packet().await?; let rsa_pub_key = &packet[1..]; // xor the password with the given nonce let mut pass = to_asciz(password); let (a, b) = (nonce.first_ref(), nonce.last_ref()); let mut nonce = Vec::with_capacity(a.len() + b.len()); nonce.extend_from_slice(a); nonce.extend_from_slice(b); xor_eq(&mut pass, &nonce); // client sends an RSA encrypted password let pkey = parse_rsa_pub_key(rsa_pub_key)?; let padding = Oaep::new::(); pkey.encrypt(&mut thread_rng(), padding, &pass[..]) .map_err(Error::protocol) } // XOR(x, y) // If len(y) < len(x), wrap around inside y fn xor_eq(x: &mut [u8], y: &[u8]) { let y_len = y.len(); for i in 0..x.len() { x[i] ^= y[i % y_len]; } } fn to_asciz(s: &str) -> Vec { let mut z = String::with_capacity(s.len() + 1); z.push_str(s); z.push('\0'); z.into_bytes() } // https://docs.rs/rsa/0.3.0/rsa/struct.RSAPublicKey.html?search=#example-1 fn parse_rsa_pub_key(key: &[u8]) -> Result { let pem = std::str::from_utf8(key).map_err(Error::protocol)?; // This takes advantage of the knowledge that we know // we are receiving a PKCS#8 RSA Public Key at all // times from MySQL RsaPublicKey::from_public_key_pem(pem).map_err(Error::protocol) } sqlx-mysql-0.8.3/src/connection/establish.rs000064400000000000000000000137601046102023000172470ustar 00000000000000use bytes::buf::Buf; use bytes::Bytes; use crate::collation::{CharSet, Collation}; use crate::common::StatementCache; use crate::connection::{tls, MySqlConnectionInner, MySqlStream, MAX_PACKET_SIZE}; use crate::error::Error; use crate::net::{Socket, WithSocket}; use crate::protocol::connect::{ AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse, }; use crate::protocol::Capabilities; use crate::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; impl MySqlConnection { pub(crate) async fn establish(options: 
&MySqlConnectOptions) -> Result { let do_handshake = DoHandshake::new(options)?; let handshake = match &options.socket { Some(path) => crate::net::connect_uds(path, do_handshake).await?, None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, }; let stream = handshake?; Ok(Self { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), }) } } struct DoHandshake<'a> { options: &'a MySqlConnectOptions, charset: CharSet, collation: Collation, } impl<'a> DoHandshake<'a> { fn new(options: &'a MySqlConnectOptions) -> Result { let charset: CharSet = options.charset.parse()?; let collation: Collation = options .collation .as_deref() .map(|collation| collation.parse()) .transpose()? .unwrap_or_else(|| charset.default_collation()); if options.enable_cleartext_plugin && matches!( options.ssl_mode, MySqlSslMode::Disabled | MySqlSslMode::Preferred ) { log::warn!("Security warning: sending cleartext passwords without requiring SSL"); } Ok(Self { options, charset, collation, }) } async fn do_handshake(self, socket: S) -> Result { let DoHandshake { options, charset, collation, } = self; let mut stream = MySqlStream::with_socket(charset, collation, options, socket); // https://dev.mysql.com/doc/internals/en/connection-phase.html // https://mariadb.com/kb/en/connection/ let handshake: Handshake = stream.recv_packet().await?.decode()?; let mut plugin = handshake.auth_plugin; let nonce = handshake.auth_plugin_data; // FIXME: server version parse is a bit ugly // expecting MAJOR.MINOR.PATCH let mut server_version = handshake.server_version.split('.'); let server_version_major: u16 = server_version .next() .unwrap_or_default() .parse() .unwrap_or(0); let server_version_minor: u16 = server_version .next() .unwrap_or_default() .parse() .unwrap_or(0); let server_version_patch: u16 = server_version .next() .unwrap_or_default() 
.parse() .unwrap_or(0); stream.server_version = ( server_version_major, server_version_minor, server_version_patch, ); stream.capabilities &= handshake.server_capabilities; stream.capabilities |= Capabilities::PROTOCOL_41; let mut stream = tls::maybe_upgrade(stream, self.options).await?; let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) { Some(plugin.scramble(&mut stream, password, &nonce).await?) } else { None }; stream.write_packet(HandshakeResponse { collation: stream.collation as u8, max_packet_size: MAX_PACKET_SIZE, username: &options.username, database: options.database.as_deref(), auth_plugin: plugin, auth_response: auth_response.as_deref(), })?; stream.flush().await?; loop { let packet = stream.recv_packet().await?; match packet[0] { 0x00 => { let _ok = packet.ok()?; break; } 0xfe => { let switch: AuthSwitchRequest = packet.decode_with(self.options.enable_cleartext_plugin)?; plugin = Some(switch.plugin); let nonce = switch.data.chain(Bytes::new()); let response = switch .plugin .scramble( &mut stream, options.password.as_deref().unwrap_or_default(), &nonce, ) .await?; stream.write_packet(AuthSwitchResponse(response))?; stream.flush().await?; } id => { if let (Some(plugin), Some(password)) = (plugin, &options.password) { if plugin.handle(&mut stream, packet, password, &nonce).await? 
{ // plugin signaled authentication is ok break; } // plugin signaled to continue authentication } else { return Err(err_protocol!( "unexpected packet 0x{:02x} during authentication", id )); } } } } Ok(stream) } } impl<'a> WithSocket for DoHandshake<'a> { type Output = Result; async fn with_socket(self, socket: S) -> Self::Output { self.do_handshake(socket).await } } sqlx-mysql-0.8.3/src/connection/executor.rs000064400000000000000000000334531046102023000171300ustar 00000000000000use super::MySqlStream; use crate::connection::stream::Waiting; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; use crate::io::MySqlBufExt; use crate::logger::QueryLogger; use crate::protocol::response::Status; use crate::protocol::statement::{ BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose, }; use crate::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow}; use crate::statement::{MySqlStatement, MySqlStatementMetadata}; use crate::HashMap; use crate::{ MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlQueryResult, MySqlRow, MySqlTypeInfo, MySqlValueFormat, }; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use std::{borrow::Cow, sync::Arc}; impl MySqlConnection { async fn prepare_statement<'c>( &mut self, sql: &str, ) -> Result<(u32, MySqlStatementMetadata), Error> { // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html // https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK self.inner .stream .send_packet(Prepare { query: sql }) .await?; let ok: PrepareOk = self.inner.stream.recv().await?; // the parameter definitions are very unreliable so we skip over them // as we have little use if ok.params > 0 { for _ in 0..ok.params { let _def: ColumnDefinition = self.inner.stream.recv().await?; } 
self.inner.stream.maybe_recv_eof().await?; } // the column definitions are berefit the type information from the // to-be-bound parameters; we will receive the output column definitions // once more on execute so we wait for that let mut columns = Vec::new(); let column_names = if ok.columns > 0 { recv_result_metadata(&mut self.inner.stream, ok.columns as usize, &mut columns).await? } else { Default::default() }; let id = ok.statement_id; let metadata = MySqlStatementMetadata { parameters: ok.params as usize, columns: Arc::new(columns), column_names: Arc::new(column_names), }; Ok((id, metadata)) } async fn get_or_prepare_statement<'c>( &mut self, sql: &str, ) -> Result<(u32, MySqlStatementMetadata), Error> { if let Some(statement) = self.inner.cache_statement.get_mut(sql) { // is internally reference-counted return Ok((*statement).clone()); } let (id, metadata) = self.prepare_statement(sql).await?; // in case of the cache being full, close the least recently used statement if let Some((id, _)) = self .inner .cache_statement .insert(sql, (id, metadata.clone())) { self.inner .stream .send_packet(StmtClose { statement: id }) .await?; } Ok((id, metadata)) } #[allow(clippy::needless_lifetimes)] pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, sql: &'q str, arguments: Option, persistent: bool, ) -> Result, Error>> + 'e, Error> { let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); self.inner.stream.wait_until_ready().await?; self.inner.stream.waiting.push_back(Waiting::Result); Ok(Box::pin(try_stream! 
{ // make a slot for the shared column data // as long as a reference to a row is not held past one iteration, this enables us // to re-use this memory freely between result sets let mut columns = Arc::new(Vec::new()); let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments { if persistent && self.inner.cache_statement.is_enabled() { let (id, metadata) = self .get_or_prepare_statement(sql) .await?; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html self.inner.stream .send_packet(StatementExecute { statement: id, arguments: &arguments, }) .await?; (metadata.column_names, MySqlValueFormat::Binary, false) } else { let (id, metadata) = self .prepare_statement(sql) .await?; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html self.inner.stream .send_packet(StatementExecute { statement: id, arguments: &arguments, }) .await?; self.inner.stream.send_packet(StmtClose { statement: id }).await?; (metadata.column_names, MySqlValueFormat::Binary, false) } } else { // https://dev.mysql.com/doc/internals/en/com-query.html self.inner.stream.send_packet(Query(sql)).await?; (Arc::default(), MySqlValueFormat::Text, true) }; loop { // query response is a meta-packet which may be one of: // Ok, Err, ResultSet, or (unhandled) LocalInfileRequest let mut packet = self.inner.stream.recv_packet().await?; if packet[0] == 0x00 || packet[0] == 0xff { // first packet in a query response is OK or ERR // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; let rows_affected = ok.affected_rows; logger.increase_rows_affected(rows_affected); let done = MySqlQueryResult { rows_affected, last_insert_id: ok.last_insert_id, }; r#yield!(Either::Left(done)); if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one continue; } self.inner.stream.waiting.pop_front(); return Ok(()); } // otherwise, this first packet is the start of the result-set 
metadata, *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row; let num_columns = packet.get_uint_lenenc(); // column count let num_columns = usize::try_from(num_columns) .map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?; if needs_metadata { column_names = Arc::new(recv_result_metadata(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?); } else { // next time we hit here, it'll be a new result set and we'll need the // full metadata needs_metadata = true; recv_result_columns(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?; } // finally, there will be none or many result-rows loop { let packet = self.inner.stream.recv_packet().await?; if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.inner.stream.capabilities)?; r#yield!(Either::Left(MySqlQueryResult { rows_affected: 0, last_insert_id: 0, })); if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Result; break; } self.inner.stream.waiting.pop_front(); return Ok(()); } let row = match format { MySqlValueFormat::Binary => packet.decode_with::(&columns)?.0, MySqlValueFormat::Text => packet.decode_with::(&columns)?.0, }; let v = Either::Right(MySqlRow { row, format, columns: Arc::clone(&columns), column_names: Arc::clone(&column_names), }); logger.increment_rows_returned(); r#yield!(v); } } })) } } impl<'c> Executor<'c> for &'c mut MySqlConnection { type Database = MySql; fn fetch_many<'e, 'q, E>( self, mut query: E, ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, 'q: 'e, E: 'q, { let sql = query.sql(); let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(try_stream! { let arguments = arguments?; let s = self.run(sql, arguments, persistent).await?; pin_mut!(s); while let Some(v) = s.try_next().await? 
{ r#yield!(v); } Ok(()) }) } fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, 'q: 'e, E: 'q, { let mut s = self.fetch_many(query); Box::pin(async move { while let Some(v) = s.try_next().await? { if let Either::Right(r) = v { return Ok(Some(r)); } } Ok(None) }) } fn prepare_with<'e, 'q: 'e>( self, sql: &'q str, _parameters: &'e [MySqlTypeInfo], ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { Box::pin(async move { self.inner.stream.wait_until_ready().await?; let metadata = if self.inner.cache_statement.is_enabled() { self.get_or_prepare_statement(sql).await?.1 } else { let (id, metadata) = self.prepare_statement(sql).await?; self.inner .stream .send_packet(StmtClose { statement: id }) .await?; metadata }; Ok(MySqlStatement { sql: Cow::Borrowed(sql), // metadata has internal Arcs for expensive data structures metadata: metadata.clone(), }) }) } #[doc(hidden)] fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { Box::pin(async move { self.inner.stream.wait_until_ready().await?; let (id, metadata) = self.prepare_statement(sql).await?; self.inner .stream .send_packet(StmtClose { statement: id }) .await?; let columns = (*metadata.columns).clone(); let nullable = columns .iter() .map(|col| { col.flags .map(|flags| !flags.contains(ColumnFlags::NOT_NULL)) }) .collect(); Ok(Describe { parameters: Some(Either::Right(metadata.parameters)), columns, nullable, }) }) } } async fn recv_result_columns( stream: &mut MySqlStream, num_columns: usize, columns: &mut Vec, ) -> Result<(), Error> { columns.clear(); columns.reserve(num_columns); for ordinal in 0..num_columns { columns.push(recv_next_result_column(&stream.recv().await?, ordinal)?); } if num_columns > 0 { stream.maybe_recv_eof().await?; } Ok(()) } fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result { // if the alias is empty, use the alias // only then use the name let name = match 
(def.name()?, def.alias()?) { (_, alias) if !alias.is_empty() => UStr::new(alias), (name, _) => UStr::new(name), }; let type_info = MySqlTypeInfo::from_column(def); Ok(MySqlColumn { name, type_info, ordinal, flags: Some(def.flags), }) } async fn recv_result_metadata( stream: &mut MySqlStream, num_columns: usize, columns: &mut Vec, ) -> Result, Error> { // the result-set metadata is primarily a listing of each output // column in the result-set let mut column_names = HashMap::with_capacity(num_columns); columns.clear(); columns.reserve(num_columns); for ordinal in 0..num_columns { let def: ColumnDefinition = stream.recv().await?; let column = recv_next_result_column(&def, ordinal)?; column_names.insert(column.name.clone(), ordinal); columns.push(column); } stream.maybe_recv_eof().await?; Ok(column_names) } sqlx-mysql-0.8.3/src/connection/mod.rs000064400000000000000000000060771046102023000160530ustar 00000000000000use std::fmt::{self, Debug, Formatter}; use futures_core::future::BoxFuture; use futures_util::FutureExt; pub(crate) use sqlx_core::connection::*; pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; use crate::error::Error; use crate::protocol::statement::StmtClose; use crate::protocol::text::{Ping, Quit}; use crate::statement::MySqlStatementMetadata; use crate::transaction::Transaction; use crate::{MySql, MySqlConnectOptions}; mod auth; mod establish; mod executor; mod stream; mod tls; const MAX_PACKET_SIZE: u32 = 1024; /// A connection to a MySQL database. 
pub struct MySqlConnection { pub(crate) inner: Box, } pub(crate) struct MySqlConnectionInner { // underlying TCP stream, // wrapped in a potentially TLS stream, // wrapped in a buffered stream pub(crate) stream: MySqlStream, // transaction status pub(crate) transaction_depth: usize, // cache by query string to the statement id and metadata cache_statement: StatementCache<(u32, MySqlStatementMetadata)>, log_settings: LogSettings, } impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() } } impl Connection for MySqlConnection { type Database = MySql; type Options = MySqlConnectOptions; fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { Box::pin(async move { self.inner.stream.send_packet(Quit).await?; self.inner.stream.shutdown().await?; Ok(()) }) } fn close_hard(mut self) -> BoxFuture<'static, Result<(), Error>> { Box::pin(async move { self.inner.stream.shutdown().await?; Ok(()) }) } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { self.inner.stream.wait_until_ready().await?; self.inner.stream.send_packet(Ping).await?; self.inner.stream.recv_ok().await?; Ok(()) }) } #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { self.inner.stream.wait_until_ready().boxed() } fn cached_statements_size(&self) -> usize { self.inner.cache_statement.len() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { while let Some((statement_id, _)) = self.inner.cache_statement.remove_lru() { self.inner .stream .send_packet(StmtClose { statement: statement_id, }) .await?; } Ok(()) }) } #[doc(hidden)] fn should_flush(&self) -> bool { !self.inner.stream.write_buffer().is_empty() } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> where Self: Sized, { Transaction::begin(self) } fn shrink_buffers(&mut self) { self.inner.stream.shrink_buffers(); } } 
sqlx-mysql-0.8.3/src/connection/stream.rs000064400000000000000000000167071046102023000165700ustar 00000000000000use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; use bytes::{Buf, Bytes, BytesMut}; use crate::collation::{CharSet, Collation}; use crate::error::Error; use crate::io::MySqlBufExt; use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; use crate::{MySqlConnectOptions, MySqlDatabaseError}; pub struct MySqlStream> { // Wrapping the socket in `Box` allows us to unsize in-place. pub(crate) socket: BufferedSocket, pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, pub(crate) waiting: VecDeque, pub(crate) charset: CharSet, pub(crate) collation: Collation, pub(crate) is_tls: bool, } #[derive(Debug, PartialEq, Eq)] pub(crate) enum Waiting { // waiting for a result set Result, // waiting for a row within a result set Row, } impl MySqlStream { pub(crate) fn with_socket( charset: CharSet, collation: Collation, options: &MySqlConnectOptions, socket: S, ) -> Self { let mut capabilities = Capabilities::PROTOCOL_41 | Capabilities::IGNORE_SPACE | Capabilities::DEPRECATE_EOF | Capabilities::FOUND_ROWS | Capabilities::TRANSACTIONS | Capabilities::SECURE_CONNECTION | Capabilities::PLUGIN_AUTH_LENENC_DATA | Capabilities::MULTI_STATEMENTS | Capabilities::MULTI_RESULTS | Capabilities::PLUGIN_AUTH | Capabilities::PS_MULTI_RESULTS | Capabilities::SSL; if options.database.is_some() { capabilities |= Capabilities::CONNECT_WITH_DB; } Self { waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, collation, charset, socket: BufferedSocket::new(socket), is_tls: false, } } pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { if !self.socket.write_buffer().is_empty() { self.socket.flush().await?; } while 
!self.waiting.is_empty() { while self.waiting.front() == Some(&Waiting::Row) { let packet = self.recv_packet().await?; if !packet.is_empty() && packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.capabilities)?; if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { *self.waiting.front_mut().unwrap() = Waiting::Result; } else { self.waiting.pop_front(); }; } } while self.waiting.front() == Some(&Waiting::Result) { let packet = self.recv_packet().await?; if !packet.is_empty() && (packet[0] == 0x00 || packet[0] == 0xff) { let ok = packet.ok()?; if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { self.waiting.pop_front(); } } else { *self.waiting.front_mut().unwrap() = Waiting::Row; self.skip_result_metadata(packet).await?; } } } Ok(()) } pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where T: ProtocolEncode<'en, Capabilities>, { self.sequence_id = 0; self.write_packet(payload)?; self.flush().await?; Ok(()) } pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where T: ProtocolEncode<'en, Capabilities>, { self.socket .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) } async fn recv_packet_part(&mut self) -> Result { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet let mut header: Bytes = self.socket.read(4).await?; // cannot overflow #[allow(clippy::cast_possible_truncation)] let packet_size = header.get_uint_le(3) as usize; let sequence_id = header.get_u8(); self.sequence_id = sequence_id.wrapping_add(1); let payload: Bytes = self.socket.read(packet_size).await?; // TODO: packet compression Ok(payload) } // receive the next packet from the database server // may block (async) on more data from the server pub(crate) async fn recv_packet(&mut self) -> Result, Error> { let payload = self.recv_packet_part().await?; let payload = if payload.len() < 0xFF_FF_FF 
{ payload } else { let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2); final_payload.extend_from_slice(&payload); drop(payload); // we don't need the allocation anymore let mut last_read = 0xFF_FF_FF; while last_read == 0xFF_FF_FF { let part = self.recv_packet_part().await?; last_read = part.len(); final_payload.extend_from_slice(&part); } final_payload.into() }; if payload .first() .ok_or(err_protocol!("Packet empty"))? .eq(&0xff) { self.waiting.pop_front(); // instead of letting this packet be looked at everywhere, we check here // and emit a proper Error return Err( MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(), ); } Ok(Packet(payload)) } pub(crate) async fn recv<'de, T>(&mut self) -> Result where T: ProtocolDecode<'de, Capabilities>, { self.recv_packet().await?.decode_with(self.capabilities) } pub(crate) async fn recv_ok(&mut self) -> Result { self.recv_packet().await?.ok() } pub(crate) async fn maybe_recv_eof(&mut self) -> Result, Error> { if self.capabilities.contains(Capabilities::DEPRECATE_EOF) { Ok(None) } else { self.recv().await.map(Some) } } async fn skip_result_metadata(&mut self, mut packet: Packet) -> Result<(), Error> { let num_columns: u64 = packet.get_uint_lenenc(); // column count for _ in 0..num_columns { let _ = self.recv_packet().await?; } self.maybe_recv_eof().await?; Ok(()) } pub fn boxed_socket(self) -> MySqlStream { MySqlStream { socket: self.socket.boxed(), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, waiting: self.waiting, charset: self.charset, collation: self.collation, is_tls: self.is_tls, } } } impl Deref for MySqlStream { type Target = BufferedSocket; fn deref(&self) -> &Self::Target { &self.socket } } impl DerefMut for MySqlStream { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.socket } } sqlx-mysql-0.8.3/src/connection/tls.rs000064400000000000000000000070101046102023000160620ustar 00000000000000use 
crate::collation::{CharSet, Collation}; use crate::connection::{MySqlStream, Waiting}; use crate::error::Error; use crate::net::tls::TlsConfig; use crate::net::{tls, BufferedSocket, Socket, WithSocket}; use crate::protocol::connect::SslRequest; use crate::protocol::Capabilities; use crate::{MySqlConnectOptions, MySqlSslMode}; use std::collections::VecDeque; struct MapStream { server_version: (u16, u16, u16), capabilities: Capabilities, sequence_id: u8, waiting: VecDeque, charset: CharSet, collation: Collation, } pub(super) async fn maybe_upgrade( mut stream: MySqlStream, options: &MySqlConnectOptions, ) -> Result { let server_supports_tls = stream.capabilities.contains(Capabilities::SSL); if matches!(options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() { // remove the SSL capability if SSL has been explicitly disabled stream.capabilities.remove(Capabilities::SSL); } // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS match options.ssl_mode { MySqlSslMode::Disabled => return Ok(stream.boxed_socket()), MySqlSslMode::Preferred => { if !tls::available() { // Client doesn't support TLS tracing::debug!("not performing TLS upgrade: TLS support not compiled in"); return Ok(stream.boxed_socket()); } if !server_supports_tls { // Server doesn't support TLS tracing::debug!("not performing TLS upgrade: unsupported by server"); return Ok(stream.boxed_socket()); } } MySqlSslMode::Required | MySqlSslMode::VerifyIdentity | MySqlSslMode::VerifyCa => { tls::error_if_unavailable()?; if !server_supports_tls { // upgrade failed, die return Err(Error::Tls("server does not support TLS".into())); } } } let tls_config = TlsConfig { accept_invalid_certs: !matches!( options.ssl_mode, MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity ), accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity), hostname: &options.host, root_cert_path: options.ssl_ca.as_ref(), client_cert_path: options.ssl_client_cert.as_ref(), 
client_key_path: options.ssl_client_key.as_ref(), }; // Request TLS upgrade stream.write_packet(SslRequest { max_packet_size: super::MAX_PACKET_SIZE, collation: stream.collation as u8, })?; stream.flush().await?; tls::handshake( stream.socket.into_inner(), tls_config, MapStream { server_version: stream.server_version, capabilities: stream.capabilities, sequence_id: stream.sequence_id, waiting: stream.waiting, charset: stream.charset, collation: stream.collation, }, ) .await } impl WithSocket for MapStream { type Output = MySqlStream; async fn with_socket(self, socket: S) -> Self::Output { MySqlStream { socket: BufferedSocket::new(Box::new(socket)), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, waiting: self.waiting, charset: self.charset, collation: self.collation, is_tls: true, } } } sqlx-mysql-0.8.3/src/database.rs000064400000000000000000000016561046102023000146770ustar 00000000000000use crate::value::{MySqlValue, MySqlValueRef}; use crate::{ MySqlArguments, MySqlColumn, MySqlConnection, MySqlQueryResult, MySqlRow, MySqlStatement, MySqlTransactionManager, MySqlTypeInfo, }; pub(crate) use sqlx_core::database::{Database, HasStatementCache}; /// MySQL database driver. 
#[derive(Debug)] pub struct MySql; impl Database for MySql { type Connection = MySqlConnection; type TransactionManager = MySqlTransactionManager; type Row = MySqlRow; type QueryResult = MySqlQueryResult; type Column = MySqlColumn; type TypeInfo = MySqlTypeInfo; type Value = MySqlValue; type ValueRef<'r> = MySqlValueRef<'r>; type Arguments<'q> = MySqlArguments; type ArgumentBuffer<'q> = Vec; type Statement<'q> = MySqlStatement<'q>; const NAME: &'static str = "MySQL"; const URL_SCHEMES: &'static [&'static str] = &["mysql", "mariadb"]; } impl HasStatementCache for MySql {} sqlx-mysql-0.8.3/src/error.rs000064400000000000000000000154641046102023000142660ustar 00000000000000use std::error::Error as StdError; use std::fmt::{self, Debug, Display, Formatter}; use crate::protocol::response::ErrPacket; use std::borrow::Cow; pub(crate) use sqlx_core::error::*; /// An error returned from the MySQL database. pub struct MySqlDatabaseError(pub(super) ErrPacket); impl MySqlDatabaseError { /// The [SQLSTATE](https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html) code for this error. pub fn code(&self) -> Option<&str> { self.0.sql_state.as_deref() } /// The [number](https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html) /// for this error. /// /// MySQL tends to use SQLSTATE as a general error category, and the error number as a more /// granular indication of the error. pub fn number(&self) -> u16 { self.0.error_code } /// The human-readable error message. 
pub fn message(&self) -> &str { &self.0.error_message } } impl Debug for MySqlDatabaseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlDatabaseError") .field("code", &self.code()) .field("number", &self.number()) .field("message", &self.message()) .finish() } } impl Display for MySqlDatabaseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { if let Some(code) = &self.code() { write!(f, "{} ({}): {}", self.number(), code, self.message()) } else { write!(f, "{}: {}", self.number(), self.message()) } } } impl StdError for MySqlDatabaseError {} impl DatabaseError for MySqlDatabaseError { #[inline] fn message(&self) -> &str { self.message() } #[inline] fn code(&self) -> Option> { self.code().map(Cow::Borrowed) } #[doc(hidden)] fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) { self } #[doc(hidden)] fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) { self } #[doc(hidden)] fn into_error(self: Box) -> Box { self } fn kind(&self) -> ErrorKind { match self.number() { error_codes::ER_DUP_KEY | error_codes::ER_DUP_ENTRY | error_codes::ER_DUP_UNIQUE | error_codes::ER_DUP_ENTRY_WITH_KEY_NAME | error_codes::ER_DUP_UNKNOWN_IN_INDEX => ErrorKind::UniqueViolation, error_codes::ER_NO_REFERENCED_ROW | error_codes::ER_NO_REFERENCED_ROW_2 | error_codes::ER_ROW_IS_REFERENCED | error_codes::ER_ROW_IS_REFERENCED_2 | error_codes::ER_FK_COLUMN_NOT_NULL | error_codes::ER_FK_CANNOT_DELETE_PARENT => ErrorKind::ForeignKeyViolation, error_codes::ER_BAD_NULL_ERROR | error_codes::ER_NO_DEFAULT_FOR_FIELD => { ErrorKind::NotNullViolation } error_codes::ER_CHECK_CONSTRAINT_VIOLATED => ErrorKind::CheckViolation, // https://mariadb.com/kb/en/e4025/ error_codes::mariadb::ER_CONSTRAINT_FAILED // MySQL uses this code for a completely different error, // but we can differentiate by SQLSTATE: // { ErrorKind::CheckViolation } _ => ErrorKind::Other, } } } /// The MySQL server uses SQLSTATEs as a generic error category, /// and 
returns a `error_code` instead within the error packet. /// /// For reference: . pub(crate) mod error_codes { /// Caused when a DDL operation creates duplicated keys. pub const ER_DUP_KEY: u16 = 1022; /// Caused when a DML operation tries create a duplicated entry for a key, /// be it a unique or primary one. pub const ER_DUP_ENTRY: u16 = 1062; /// Similar to `ER_DUP_ENTRY`, but only present in NDB clusters. /// /// See: . pub const ER_DUP_UNIQUE: u16 = 1169; /// Similar to `ER_DUP_ENTRY`, but with a formatted string message. /// /// See: . pub const ER_DUP_ENTRY_WITH_KEY_NAME: u16 = 1586; /// Caused when a DDL operation to add a unique index fails, /// because duplicate items were created by concurrent DML operations. /// When this happens, the key is unknown, so the server can't use `ER_DUP_KEY`. /// /// For example: an `INSERT` operation creates duplicate `name` fields when `ALTER`ing a table and making `name` unique. pub const ER_DUP_UNKNOWN_IN_INDEX: u16 = 1859; /// Caused when inserting an entry with a column with a value that does not reference a foreign row. pub const ER_NO_REFERENCED_ROW: u16 = 1216; /// Caused when deleting a row that is referenced in other tables. pub const ER_ROW_IS_REFERENCED: u16 = 1217; /// Caused when deleting a row that is referenced in other tables. /// This differs from `ER_ROW_IS_REFERENCED` in that the error message contains the affected constraint. pub const ER_ROW_IS_REFERENCED_2: u16 = 1451; /// Caused when inserting an entry with a column with a value that does not reference a foreign row. /// This differs from `ER_NO_REFERENCED_ROW` in that the error message contains the affected constraint. pub const ER_NO_REFERENCED_ROW_2: u16 = 1452; /// Caused when creating a FK with `ON DELETE SET NULL` or `ON UPDATE SET NULL` to a column that is `NOT NULL`, or vice-versa. pub const ER_FK_COLUMN_NOT_NULL: u16 = 1830; /// Removed in 5.7.3. 
pub const ER_FK_CANNOT_DELETE_PARENT: u16 = 1834; /// Caused when inserting a NULL value to a column marked as NOT NULL. pub const ER_BAD_NULL_ERROR: u16 = 1048; /// Caused when inserting a DEFAULT value to a column marked as NOT NULL, which also doesn't have a default value set. pub const ER_NO_DEFAULT_FOR_FIELD: u16 = 1364; /// Caused when a check constraint is violated. /// /// Only available after 8.0.16. pub const ER_CHECK_CONSTRAINT_VIOLATED: u16 = 3819; pub(crate) mod mariadb { /// Error code emitted by MariaDB for constraint errors: /// /// MySQL emits this code for a completely different error: /// /// /// You also check that SQLSTATE is `23000`. pub const ER_CONSTRAINT_FAILED: u16 = 4025; } } sqlx-mysql-0.8.3/src/io/buf.rs000064400000000000000000000027131046102023000143110ustar 00000000000000use bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::BufExt; pub trait MySqlBufExt: Buf { // Read a length-encoded integer. // NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL. // NOTE: 0xff is only returned during a result set to indicate ERR. // fn get_uint_lenenc(&mut self) -> u64; // Read a length-encoded string. #[allow(dead_code)] fn get_str_lenenc(&mut self) -> Result; // Read a length-encoded byte sequence. 
fn get_bytes_lenenc(&mut self) -> Result; } impl MySqlBufExt for Bytes { fn get_uint_lenenc(&mut self) -> u64 { match self.get_u8() { 0xfc => u64::from(self.get_u16_le()), 0xfd => self.get_uint_le(3), 0xfe => self.get_u64_le(), v => u64::from(v), } } fn get_str_lenenc(&mut self) -> Result { let size = self.get_uint_lenenc(); let size = usize::try_from(size) .map_err(|_| err_protocol!("string length overflows usize: {size}"))?; self.get_str(size) } fn get_bytes_lenenc(&mut self) -> Result { let size = self.get_uint_lenenc(); let size = usize::try_from(size) .map_err(|_| err_protocol!("string length overflows usize: {size}"))?; Ok(self.split_to(size)) } } sqlx-mysql-0.8.3/src/io/buf_mut.rs000064400000000000000000000060461046102023000152010ustar 00000000000000use bytes::BufMut; pub trait MySqlBufMutExt: BufMut { fn put_uint_lenenc(&mut self, v: u64); fn put_str_lenenc(&mut self, v: &str); fn put_bytes_lenenc(&mut self, v: &[u8]); } impl MySqlBufMutExt for Vec { fn put_uint_lenenc(&mut self, v: u64) { // https://dev.mysql.com/doc/internals/en/integer.html // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers let encoded_le = v.to_le_bytes(); match v { 0..=250 => self.push(encoded_le[0]), 251..=0xFF_FF => { self.push(0xfc); self.extend_from_slice(&encoded_le[..2]); } 0x1_00_00..=0xFF_FF_FF => { self.push(0xfd); self.extend_from_slice(&encoded_le[..3]); } _ => { self.push(0xfe); self.extend_from_slice(&encoded_le); } } } fn put_str_lenenc(&mut self, v: &str) { self.put_bytes_lenenc(v.as_bytes()); } fn put_bytes_lenenc(&mut self, v: &[u8]) { self.put_uint_lenenc(v.len() as u64); self.extend(v); } } #[test] fn test_encodes_int_lenenc_u8() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFA as u64); assert_eq!(&buf[..], b"\xFA"); } #[test] fn test_encodes_int_lenenc_u16() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(std::u16::MAX as u64); assert_eq!(&buf[..], b"\xFC\xFF\xFF"); } #[test] fn 
test_encodes_int_lenenc_u24() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFF_FF_FF as u64); assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); } #[test] fn test_encodes_int_lenenc_u64() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(std::u64::MAX); assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); } #[test] fn test_encodes_int_lenenc_fb() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFB as u64); assert_eq!(&buf[..], b"\xFC\xFB\x00"); } #[test] fn test_encodes_int_lenenc_fc() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFC as u64); assert_eq!(&buf[..], b"\xFC\xFC\x00"); } #[test] fn test_encodes_int_lenenc_fd() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFD as u64); assert_eq!(&buf[..], b"\xFC\xFD\x00"); } #[test] fn test_encodes_int_lenenc_fe() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFE as u64); assert_eq!(&buf[..], b"\xFC\xFE\x00"); } #[test] fn test_encodes_int_lenenc_ff() { let mut buf = Vec::with_capacity(1024); buf.put_uint_lenenc(0xFF as u64); assert_eq!(&buf[..], b"\xFC\xFF\x00"); } #[test] fn test_encodes_string_lenenc() { let mut buf = Vec::with_capacity(1024); buf.put_str_lenenc("random_string"); assert_eq!(&buf[..], b"\x0Drandom_string"); } #[test] fn test_encodes_byte_lenenc() { let mut buf = Vec::with_capacity(1024); buf.put_bytes_lenenc(b"random_string"); assert_eq!(&buf[..], b"\x0Drandom_string"); } sqlx-mysql-0.8.3/src/io/mod.rs000064400000000000000000000001641046102023000143120ustar 00000000000000mod buf; mod buf_mut; pub use buf::MySqlBufExt; pub use buf_mut::MySqlBufMutExt; pub(crate) use sqlx_core::io::*; sqlx-mysql-0.8.3/src/lib.rs000064400000000000000000000037711046102023000137010ustar 00000000000000//! **MySQL** database driver. 
#![deny(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss
)]

// Makes the macros exported by `sqlx_core` available everywhere in this crate.
#[macro_use]
extern crate sqlx_core;

use crate::executor::Executor;

pub(crate) use sqlx_core::driver_prelude::*;

#[cfg(feature = "any")]
pub mod any;

mod arguments;
mod collation;
mod column;
mod connection;
mod database;
mod error;
mod io;
mod options;
mod protocol;
mod query_result;
mod row;
mod statement;
mod transaction;
mod type_checking;
mod type_info;
pub mod types;
mod value;

#[cfg(feature = "migrate")]
mod migrate;
#[cfg(feature = "migrate")]
mod testing;

pub use self::{
    arguments::MySqlArguments,
    column::MySqlColumn,
    connection::MySqlConnection,
    database::MySql,
    error::MySqlDatabaseError,
    options::{MySqlConnectOptions, MySqlSslMode},
    query_result::MySqlQueryResult,
    row::MySqlRow,
    statement::MySqlStatement,
    transaction::MySqlTransactionManager,
    type_info::MySqlTypeInfo,
    value::{MySqlValue, MySqlValueFormat, MySqlValueRef},
};

/// A [`Pool`][crate::pool::Pool] of MySQL connections.
pub type MySqlPool = crate::pool::Pool<MySql>;

/// [`PoolOptions`][crate::pool::PoolOptions] for building a MySQL connection pool.
pub type MySqlPoolOptions = crate::pool::PoolOptions<MySql>;

/// Shorthand for any [`Executor<'_, Database = MySql>`][Executor].
pub trait MySqlExecutor<'c>: Executor<'c, Database = MySql> {}

// Blanket impl: every executor whose database is MySQL is automatically a `MySqlExecutor`.
impl<'c, E> MySqlExecutor<'c> for E where E: Executor<'c, Database = MySql> {}

/// An alias for [`Transaction`][crate::transaction::Transaction], specialized for MySQL.
pub type MySqlTransaction<'c> = crate::transaction::Transaction<'c, MySql>; // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(MySqlArguments); impl_acquire!(MySql, MySqlConnection); impl_column_index_for_row!(MySqlRow); impl_column_index_for_statement!(MySqlStatement); // required because some databases have a different handling of NULL impl_encode_for_option!(MySql); sqlx-mysql-0.8.3/src/migrate.rs000064400000000000000000000233331046102023000145570ustar 00000000000000use std::str::FromStr; use std::time::Duration; use std::time::Instant; use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::*; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; use crate::{MySql, MySqlConnectOptions, MySqlConnection}; fn parse_for_maintenance(url: &str) -> Result<(MySqlConnectOptions, String), Error> { let mut options = MySqlConnectOptions::from_str(url)?; let database = if let Some(database) = &options.database { database.to_owned() } else { return Err(Error::Configuration( "DATABASE_URL does not specify a database".into(), )); }; // switch us to database for create/drop commands options.database = None; Ok((options, database)) } impl MigrateDatabase for MySql { fn create_database(url: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let _ = conn .execute(&*format!("CREATE DATABASE `{database}`")) .await?; Ok(()) }) } fn database_exists(url: &str) -> BoxFuture<'_, Result> { Box::pin(async move { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let exists: bool = query_scalar( "select exists(SELECT 1 from INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?)", ) .bind(database) .fetch_one(&mut conn) .await?; Ok(exists) 
}) } fn drop_database(url: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let _ = conn .execute(&*format!("DROP DATABASE IF EXISTS `{database}`")) .await?; Ok(()) }) } } impl Migrate for MySqlConnection { fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { // language=MySQL self.execute( r#" CREATE TABLE IF NOT EXISTS _sqlx_migrations ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, success BOOLEAN NOT NULL, checksum BLOB NOT NULL, execution_time BIGINT NOT NULL ); "#, ) .await?; Ok(()) }) } fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQL let row: Option<(i64,)> = query_as( "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", ) .fetch_optional(self) .await?; Ok(row.map(|r| r.0)) }) } fn list_applied_migrations( &mut self, ) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQL let rows: Vec<(i64, Vec)> = query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") .fetch_all(self) .await?; let migrations = rows .into_iter() .map(|(version, checksum)| AppliedMigration { version, checksum: checksum.into(), }) .collect(); Ok(migrations) }) } fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { let database_name = current_database(self).await?; let lock_id = generate_lock_id(&database_name); // create an application lock over the database // this function will not return until the lock is acquired // https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS // https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS-TABLE // language=MySQL let _ = query("SELECT GET_LOCK(?, -1)") .bind(lock_id) .execute(self) .await?; Ok(()) }) 
} fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { let database_name = current_database(self).await?; let lock_id = generate_lock_id(&database_name); // language=MySQL let _ = query("SELECT RELEASE_LOCK(?)") .bind(lock_id) .execute(self) .await?; Ok(()) }) } fn apply<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, ) -> BoxFuture<'m, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. // The `execution_time` however can only be measured for the whole transaction. This value _only_ exists for // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1 // and update it once the actual transaction completed. let mut tx = self.begin().await?; let start = Instant::now(); // For MySQL we cannot really isolate migrations due to implicit commits caused by table modification, see // https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html // // To somewhat try to detect this, we first insert the migration into the migration table with // `success=FALSE` and later modify the flag. // // language=MySQL let _ = query( r#" INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) VALUES ( ?, ?, FALSE, ?, -1 ) "#, ) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) .execute(&mut *tx) .await?; let _ = tx .execute(&*migration.sql) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=MySQL let _ = query( r#" UPDATE _sqlx_migrations SET success = TRUE WHERE version = ? "#, ) .bind(migration.version) .execute(&mut *tx) .await?; tx.commit().await?; // Update `elapsed_time`. // NOTE: The process may disconnect/die at this point, so the elapsed time value might be lost. 
We accept // this small risk since this value is not super important. let elapsed = start.elapsed(); #[allow(clippy::cast_possible_truncation)] let _ = query( r#" UPDATE _sqlx_migrations SET execution_time = ? WHERE version = ? "#, ) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) .await?; Ok(elapsed) }) } fn revert<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, ) -> BoxFuture<'m, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. let mut tx = self.begin().await?; let start = Instant::now(); // For MySQL we cannot really isolate migrations due to implicit commits caused by table modification, see // https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html // // To somewhat try to detect this, we first insert the migration into the migration table with // `success=FALSE` and later remove the migration altogether. // // language=MySQL let _ = query( r#" UPDATE _sqlx_migrations SET success = FALSE WHERE version = ? "#, ) .bind(migration.version) .execute(&mut *tx) .await?; tx.execute(&*migration.sql).await?; // language=SQL let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = ?"#) .bind(migration.version) .execute(&mut *tx) .await?; tx.commit().await?; let elapsed = start.elapsed(); Ok(elapsed) }) } } async fn current_database(conn: &mut MySqlConnection) -> Result { // language=MySQL Ok(query_scalar("SELECT DATABASE()").fetch_one(conn).await?) 
} // inspired from rails: https://github.com/rails/rails/blob/6e49cc77ab3d16c06e12f93158eaf3e507d4120e/activerecord/lib/active_record/migration.rb#L1308 fn generate_lock_id(database_name: &str) -> String { const CRC_IEEE: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); // 0x3d32ad9e chosen by fair dice roll format!( "{:x}", 0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64) ) } sqlx-mysql-0.8.3/src/options/connect.rs000064400000000000000000000061751046102023000162600ustar 00000000000000use crate::connection::ConnectOptions; use crate::error::Error; use crate::executor::Executor; use crate::{MySqlConnectOptions, MySqlConnection}; use futures_core::future::BoxFuture; use log::LevelFilter; use sqlx_core::Url; use std::time::Duration; impl ConnectOptions for MySqlConnectOptions { type Connection = MySqlConnection; fn from_url(url: &Url) -> Result { Self::parse_from_url(url) } fn to_url_lossy(&self) -> Url { self.build_url() } fn connect(&self) -> BoxFuture<'_, Result> where Self::Connection: Sized, { Box::pin(async move { let mut conn = MySqlConnection::establish(self).await?; // After the connection is established, we initialize by configuring a few // connection parameters // https://mariadb.com/kb/en/sql-mode/ // PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator. // This means that "A" || "B" can be used in place of CONCAT("A", "B"). // NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is // not available, a warning is given and the default storage // engine is used instead. // NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust. // NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust. 
// -- // Setting the time zone allows us to assume that the output // from a TIMESTAMP field is UTC // -- // https://mathiasbynens.be/notes/mysql-utf8mb4 let mut sql_mode = Vec::new(); if self.pipes_as_concat { sql_mode.push(r#"PIPES_AS_CONCAT"#); } if self.no_engine_substitution { sql_mode.push(r#"NO_ENGINE_SUBSTITUTION"#); } let mut options = Vec::new(); if !sql_mode.is_empty() { options.push(format!( r#"sql_mode=(SELECT CONCAT(@@sql_mode, ',{}'))"#, sql_mode.join(",") )); } if let Some(timezone) = &self.timezone { options.push(format!(r#"time_zone='{}'"#, timezone)); } if self.set_names { options.push(format!( r#"NAMES {} COLLATE {}"#, conn.inner.stream.charset.as_str(), conn.inner.stream.collation.as_str() )) } if !options.is_empty() { conn.execute(&*format!(r#"SET {};"#, options.join(","))) .await?; } Ok(conn) }) } fn log_statements(mut self, level: LevelFilter) -> Self { self.log_settings.log_statements(level); self } fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self { self.log_settings.log_slow_statements(level, duration); self } } sqlx-mysql-0.8.3/src/options/mod.rs000064400000000000000000000426671046102023000154140ustar 00000000000000use std::path::{Path, PathBuf}; mod connect; mod parse; mod ssl_mode; use crate::{connection::LogSettings, net::tls::CertificateInput}; pub use ssl_mode::MySqlSslMode; /// Options and flags which can be used to configure a MySQL connection. /// /// A value of `MySqlConnectOptions` can be parsed from a connection URL, /// as described by [MySQL](https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-reference-jdbc-url-format.html). /// /// The generic format of the connection URL: /// /// ```text /// mysql://[host][/database][?properties] /// ``` /// /// This type also implements [`FromStr`][std::str::FromStr] so you can parse it from a string /// containing a connection URL and then further adjust options if necessary (see example below). 
/// /// ## Properties /// /// |Parameter|Default|Description| /// |---------|-------|-----------| /// | `ssl-mode` | `PREFERRED` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`MySqlSslMode`]. | /// | `ssl-ca` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | /// | `socket` | `None` | Path to the unix domain socket, which will be used instead of TCP if set. | /// /// # Example /// /// ```rust,no_run /// # async fn example() -> sqlx::Result<()> { /// use sqlx::{Connection, ConnectOptions}; /// use sqlx::mysql::{MySqlConnectOptions, MySqlConnection, MySqlPool, MySqlSslMode}; /// /// // URL connection string /// let conn = MySqlConnection::connect("mysql://root:password@localhost/db").await?; /// /// // Manually-constructed options /// let conn = MySqlConnectOptions::new() /// .host("localhost") /// .username("root") /// .password("password") /// .database("db") /// .connect().await?; /// /// // Modifying options parsed from a string /// let mut opts: MySqlConnectOptions = "mysql://root:password@localhost/db".parse()?; /// /// // Change the log verbosity level for queries. /// // Information about SQL queries is logged at `DEBUG` level by default. 
/// opts = opts.log_statements(log::LevelFilter::Trace); /// /// let pool = MySqlPool::connect_with(opts).await?; /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] pub struct MySqlConnectOptions { pub(crate) host: String, pub(crate) port: u16, pub(crate) socket: Option, pub(crate) username: String, pub(crate) password: Option, pub(crate) database: Option, pub(crate) ssl_mode: MySqlSslMode, pub(crate) ssl_ca: Option, pub(crate) ssl_client_cert: Option, pub(crate) ssl_client_key: Option, pub(crate) statement_cache_capacity: usize, pub(crate) charset: String, pub(crate) collation: Option, pub(crate) log_settings: LogSettings, pub(crate) pipes_as_concat: bool, pub(crate) enable_cleartext_plugin: bool, pub(crate) no_engine_substitution: bool, pub(crate) timezone: Option, pub(crate) set_names: bool, } impl Default for MySqlConnectOptions { fn default() -> Self { Self::new() } } impl MySqlConnectOptions { /// Creates a new, default set of options ready for configuration pub fn new() -> Self { Self { port: 3306, host: String::from("localhost"), socket: None, username: String::from("root"), password: None, database: None, charset: String::from("utf8mb4"), collation: None, ssl_mode: MySqlSslMode::Preferred, ssl_ca: None, ssl_client_cert: None, ssl_client_key: None, statement_cache_capacity: 100, log_settings: Default::default(), pipes_as_concat: true, enable_cleartext_plugin: false, no_engine_substitution: true, timezone: Some(String::from("+00:00")), set_names: true, } } /// Sets the name of the host to connect to. /// /// The default behavior when the host is not specified, /// is to connect to localhost. pub fn host(mut self, host: &str) -> Self { host.clone_into(&mut self.host); self } /// Sets the port to connect to at the server host. /// /// The default port for MySQL is `3306`. pub fn port(mut self, port: u16) -> Self { self.port = port; self } /// Pass a path to a Unix socket. This changes the connection stream from /// TCP to UDS. 
/// /// By default set to `None`. pub fn socket(mut self, path: impl AsRef) -> Self { self.socket = Some(path.as_ref().to_path_buf()); self } /// Sets the username to connect as. pub fn username(mut self, username: &str) -> Self { username.clone_into(&mut self.username); self } /// Sets the password to connect with. pub fn password(mut self, password: &str) -> Self { self.password = Some(password.to_owned()); self } /// Sets the database name. pub fn database(mut self, database: &str) -> Self { self.database = Some(database.to_owned()); self } /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated /// with the server. /// /// By default, the SSL mode is [`Preferred`](MySqlSslMode::Preferred), and the client will /// first attempt an SSL connection but fallback to a non-SSL connection on failure. /// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::Required); /// ``` pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self { self.ssl_mode = mode; self } /// Sets the name of a file containing a list of trusted SSL Certificate Authorities. /// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::VerifyCa) /// .ssl_ca("path/to/ca.crt"); /// ``` pub fn ssl_ca(mut self, file_name: impl AsRef) -> Self { self.ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned())); self } /// Sets PEM encoded list of trusted SSL Certificate Authorities. 
/// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::VerifyCa) /// .ssl_ca_from_pem(vec![]); /// ``` pub fn ssl_ca_from_pem(mut self, pem_certificate: Vec) -> Self { self.ssl_ca = Some(CertificateInput::Inline(pem_certificate)); self } /// Sets the name of a file containing SSL client certificate. /// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::VerifyCa) /// .ssl_client_cert("path/to/client.crt"); /// ``` pub fn ssl_client_cert(mut self, cert: impl AsRef) -> Self { self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); self } /// Sets the SSL client certificate as a PEM-encoded byte slice. /// /// This should be an ASCII-encoded blob that starts with `-----BEGIN CERTIFICATE-----`. /// /// # Example /// Note: embedding SSL certificates and keys in the binary is not advised. /// This is for illustration purposes only. /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// /// const CERT: &[u8] = b"\ /// -----BEGIN CERTIFICATE----- /// /// -----END CERTIFICATE-----"; /// /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::VerifyCa) /// .ssl_client_cert_from_pem(CERT); /// ``` pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self { self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec())); self } /// Sets the name of a file containing SSL client key. 
/// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::VerifyCa) /// .ssl_client_key("path/to/client.key"); /// ``` pub fn ssl_client_key(mut self, key: impl AsRef) -> Self { self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf())); self } /// Sets the SSL client key as a PEM-encoded byte slice. /// /// This should be an ASCII-encoded blob that starts with `-----BEGIN PRIVATE KEY-----`. /// /// # Example /// Note: embedding SSL certificates and keys in the binary is not advised. /// This is for illustration purposes only. /// /// ```rust /// # use sqlx_mysql::{MySqlSslMode, MySqlConnectOptions}; /// /// const KEY: &[u8] = b"\ /// -----BEGIN PRIVATE KEY----- /// /// -----END PRIVATE KEY-----"; /// /// let options = MySqlConnectOptions::new() /// .ssl_mode(MySqlSslMode::VerifyCa) /// .ssl_client_key_from_pem(KEY); /// ``` pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self { self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec())); self } /// Sets the capacity of the connection's statement cache in a number of stored /// distinct statements. Caching is handled using LRU, meaning when the /// amount of queries hits the defined limit, the oldest statement will get /// dropped. /// /// The default cache capacity is 100 statements. pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { self.statement_cache_capacity = capacity; self } /// Sets the character set for the connection. /// /// The default character set is `utf8mb4`. This is supported from MySQL 5.5.3. /// If you need to connect to an older version, we recommend you to change this to `utf8`. pub fn charset(mut self, charset: &str) -> Self { charset.clone_into(&mut self.charset); self } /// Sets the collation for the connection. /// /// The default collation is derived from the `charset`. 
Normally, you should only have to set /// the `charset`. pub fn collation(mut self, collation: &str) -> Self { self.collation = Some(collation.to_owned()); self } /// Sets the flag that enables or disables the `PIPES_AS_CONCAT` connection setting /// /// The default value is set to true, but some MySql databases such as PlanetScale /// error out with this connection setting so it needs to be set false in such /// cases. pub fn pipes_as_concat(mut self, flag_val: bool) -> Self { self.pipes_as_concat = flag_val; self } /// Enables mysql_clear_password plugin support. /// /// Security Note: /// Sending passwords as cleartext may be a security problem in some /// configurations. Without additional defensive configuration like /// ssl-mode=VERIFY_IDENTITY, an attacker can compromise a router /// and trick the application into divulging its credentials. /// /// It is strongly recommended to set `.ssl_mode` to `Required`, /// `VerifyCa`, or `VerifyIdentity` when enabling cleartext plugin. pub fn enable_cleartext_plugin(mut self, flag_val: bool) -> Self { self.enable_cleartext_plugin = flag_val; self } #[deprecated = "renamed to .no_engine_substitution()"] pub fn no_engine_subsitution(self, flag_val: bool) -> Self { self.no_engine_substitution(flag_val) } /// Flag that enables or disables the `NO_ENGINE_SUBSTITUTION` sql_mode setting after /// connection. /// /// If not set, if the available storage engine specified by a `CREATE TABLE` is not available, /// a warning is given and the default storage engine is used instead. /// /// By default, this is `true` (`NO_ENGINE_SUBSTITUTION` is passed, forbidding engine /// substitution). /// /// pub fn no_engine_substitution(mut self, flag_val: bool) -> Self { self.no_engine_substitution = flag_val; self } /// If `Some`, sets the `time_zone` option to the given string after connecting to the database. /// /// If `None`, no `time_zone` parameter is sent; the server timezone will be used instead. 
/// /// Defaults to `Some(String::from("+00:00"))` to ensure all timestamps are in UTC. /// /// ### Warning /// Changing this setting from its default will apply an unexpected skew to any /// `time::OffsetDateTime` or `chrono::DateTime` value, whether passed as a parameter or /// decoded as a result. `TIMESTAMP` values are not encoded with their UTC offset in the MySQL /// protocol, so encoding and decoding of these types assumes the server timezone is *always* /// UTC. /// /// If you are changing this option, ensure your application only uses /// `time::PrimitiveDateTime` or `chrono::NaiveDateTime` and that it does not assume these /// timestamps can be placed on a real timeline without applying the proper offset. pub fn timezone(mut self, value: impl Into>) -> Self { self.timezone = value.into(); self } /// If enabled, `SET NAMES '{charset}' COLLATE '{collation}'` is passed with the values of /// [`.charset()`] and [`.collation()`] after connecting to the database. /// /// This ensures the connection uses the specified character set and collation. /// /// Enabled by default. /// /// ### Warning /// If this is disabled and the default charset is not binary-compatible with UTF-8, query /// strings, column names and string values will likely not decode (or encode) correctly, which /// may result in unexpected errors or garbage outputs at runtime. /// /// For proper functioning, you *must* ensure the server is using a binary-compatible charset, /// such as ASCII or Latin-1 (ISO 8859-1), and that you do not pass any strings containing /// codepoints not supported by said charset. /// /// Instead of disabling this, you may also consider setting [`.charset()`] to a charset that /// is supported by your MySQL or MariaDB server version and compatible with UTF-8. pub fn set_names(mut self, flag_val: bool) -> Self { self.set_names = flag_val; self } } impl MySqlConnectOptions { /// Get the current host. 
/// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new() /// .host("127.0.0.1"); /// assert_eq!(options.get_host(), "127.0.0.1"); /// ``` pub fn get_host(&self) -> &str { &self.host } /// Get the server's port. /// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new() /// .port(6543); /// assert_eq!(options.get_port(), 6543); /// ``` pub fn get_port(&self) -> u16 { self.port } /// Get the socket path. /// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new() /// .socket("/tmp"); /// assert!(options.get_socket().is_some()); /// ``` pub fn get_socket(&self) -> Option<&PathBuf> { self.socket.as_ref() } /// Get the server's port. /// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new() /// .username("foo"); /// assert_eq!(options.get_username(), "foo"); /// ``` pub fn get_username(&self) -> &str { &self.username } /// Get the current database name. /// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new() /// .database("postgres"); /// assert!(options.get_database().is_some()); /// ``` pub fn get_database(&self) -> Option<&str> { self.database.as_deref() } /// Get the SSL mode. /// /// # Example /// /// ```rust /// # use sqlx_mysql::{MySqlConnectOptions, MySqlSslMode}; /// let options = MySqlConnectOptions::new(); /// assert!(matches!(options.get_ssl_mode(), MySqlSslMode::Preferred)); /// ``` pub fn get_ssl_mode(&self) -> MySqlSslMode { self.ssl_mode } /// Get the server charset. /// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new(); /// assert_eq!(options.get_charset(), "utf8mb4"); /// ``` pub fn get_charset(&self) -> &str { &self.charset } /// Get the server collation. 
/// /// # Example /// /// ```rust /// # use sqlx_mysql::MySqlConnectOptions; /// let options = MySqlConnectOptions::new() /// .collation("collation"); /// assert!(options.get_collation().is_some()); /// ``` pub fn get_collation(&self) -> Option<&str> { self.collation.as_deref() } } sqlx-mysql-0.8.3/src/options/parse.rs000064400000000000000000000137061046102023000157370ustar 00000000000000use std::str::FromStr; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; use crate::{error::Error, MySqlSslMode}; use super::MySqlConnectOptions; impl MySqlConnectOptions { pub(crate) fn parse_from_url(url: &Url) -> Result { let mut options = Self::new(); if let Some(host) = url.host_str() { options = options.host(host); } if let Some(port) = url.port() { options = options.port(port); } let username = url.username(); if !username.is_empty() { options = options.username( &percent_decode_str(username) .decode_utf8() .map_err(Error::config)?, ); } if let Some(password) = url.password() { options = options.password( &percent_decode_str(password) .decode_utf8() .map_err(Error::config)?, ); } let path = url.path().trim_start_matches('/'); if !path.is_empty() { options = options.database( &percent_decode_str(path) .decode_utf8() .map_err(Error::config)?, ); } for (key, value) in url.query_pairs().into_iter() { match &*key { "sslmode" | "ssl-mode" => { options = options.ssl_mode(value.parse().map_err(Error::config)?); } "sslca" | "ssl-ca" => { options = options.ssl_ca(&*value); } "charset" => { options = options.charset(&value); } "collation" => { options = options.collation(&value); } "sslcert" | "ssl-cert" => options = options.ssl_client_cert(&*value), "sslkey" | "ssl-key" => options = options.ssl_client_key(&*value), "statement-cache-capacity" => { options = options.statement_cache_capacity(value.parse().map_err(Error::config)?); } "socket" => { options = options.socket(&*value); } "timezone" | "time-zone" => { options = 
options.timezone(Some(value.to_string())); } _ => {} } } Ok(options) } pub(crate) fn build_url(&self) -> Url { let mut url = Url::parse(&format!( "mysql://{}@{}:{}", self.username, self.host, self.port )) .expect("BUG: generated un-parseable URL"); if let Some(password) = &self.password { let password = utf8_percent_encode(password, NON_ALPHANUMERIC).to_string(); let _ = url.set_password(Some(&password)); } if let Some(database) = &self.database { url.set_path(database); } let ssl_mode = match self.ssl_mode { MySqlSslMode::Disabled => "DISABLED", MySqlSslMode::Preferred => "PREFERRED", MySqlSslMode::Required => "REQUIRED", MySqlSslMode::VerifyCa => "VERIFY_CA", MySqlSslMode::VerifyIdentity => "VERIFY_IDENTITY", }; url.query_pairs_mut().append_pair("ssl-mode", ssl_mode); if let Some(ssl_ca) = &self.ssl_ca { url.query_pairs_mut() .append_pair("ssl-ca", &ssl_ca.to_string()); } url.query_pairs_mut().append_pair("charset", &self.charset); if let Some(collation) = &self.collation { url.query_pairs_mut().append_pair("charset", collation); } if let Some(ssl_client_cert) = &self.ssl_client_cert { url.query_pairs_mut() .append_pair("ssl-cert", &ssl_client_cert.to_string()); } if let Some(ssl_client_key) = &self.ssl_client_key { url.query_pairs_mut() .append_pair("ssl-key", &ssl_client_key.to_string()); } url.query_pairs_mut().append_pair( "statement-cache-capacity", &self.statement_cache_capacity.to_string(), ); if let Some(socket) = &self.socket { url.query_pairs_mut() .append_pair("socket", &socket.to_string_lossy()); } url } } impl FromStr for MySqlConnectOptions { type Err = Error; fn from_str(s: &str) -> Result { let url: Url = s.parse().map_err(Error::config)?; Self::parse_from_url(&url) } } #[test] fn it_parses_username_with_at_sign_correctly() { let url = "mysql://user@hostname:password@hostname:5432/database"; let opts = MySqlConnectOptions::from_str(url).unwrap(); assert_eq!("user@hostname", &opts.username); } #[test] fn 
it_parses_password_with_non_ascii_chars_correctly() { let url = "mysql://username:p@ssw0rd@hostname:5432/database"; let opts = MySqlConnectOptions::from_str(url).unwrap(); assert_eq!(Some("p@ssw0rd".into()), opts.password); } #[test] fn it_returns_the_parsed_url() { let url = "mysql://username:p@ssw0rd@hostname:3306/database"; let opts = MySqlConnectOptions::from_str(url).unwrap(); let mut expected_url = Url::parse(url).unwrap(); // MySqlConnectOptions defaults let query_string = "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; expected_url.set_query(Some(query_string)); assert_eq!(expected_url, opts.build_url()); } #[test] fn it_parses_timezone() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?timezone=%2B08:00" .parse() .unwrap(); assert_eq!(opts.timezone.as_deref(), Some("+08:00")); let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?time-zone=%2B08:00" .parse() .unwrap(); assert_eq!(opts.timezone.as_deref(), Some("+08:00")); } sqlx-mysql-0.8.3/src/options/ssl_mode.rs000064400000000000000000000036431046102023000164310ustar 00000000000000use crate::error::Error; use std::str::FromStr; /// Options for controlling the desired security state of the connection to the MySQL server. /// /// It is used by the [`ssl_mode`](super::MySqlConnectOptions::ssl_mode) method. #[derive(Debug, Clone, Copy, Default)] pub enum MySqlSslMode { /// Establish an unencrypted connection. Disabled, /// Establish an encrypted connection if the server supports encrypted connections, falling /// back to an unencrypted connection if an encrypted connection cannot be established. /// /// This is the default if `ssl_mode` is not specified. #[default] Preferred, /// Establish an encrypted connection if the server supports encrypted connections. /// The connection attempt fails if an encrypted connection cannot be established. 
Required, /// Like `Required`, but additionally verify the server Certificate Authority (CA) /// certificate against the configured CA certificates. The connection attempt fails /// if no valid matching CA certificates are found. VerifyCa, /// Like `VerifyCa`, but additionally perform host name identity verification by /// checking the host name the client uses for connecting to the server against the /// identity in the certificate that the server sends to the client. VerifyIdentity, } impl FromStr for MySqlSslMode { type Err = Error; fn from_str(s: &str) -> Result { Ok(match &*s.to_ascii_lowercase() { "disabled" => MySqlSslMode::Disabled, "preferred" => MySqlSslMode::Preferred, "required" => MySqlSslMode::Required, "verify_ca" => MySqlSslMode::VerifyCa, "verify_identity" => MySqlSslMode::VerifyIdentity, _ => { return Err(Error::Configuration( format!("unknown value {s:?} for `ssl_mode`").into(), )); } }) } } sqlx-mysql-0.8.3/src/protocol/auth.rs000064400000000000000000000022601046102023000157250ustar 00000000000000use std::str::FromStr; use crate::error::Error; #[derive(Debug, Copy, Clone)] // These have all the same suffix but they match the auth plugin names. 
#[allow(clippy::enum_variant_names)] pub enum AuthPlugin { MySqlNativePassword, CachingSha2Password, Sha256Password, MySqlClearPassword, } impl AuthPlugin { pub(crate) fn name(self) -> &'static str { match self { AuthPlugin::MySqlNativePassword => "mysql_native_password", AuthPlugin::CachingSha2Password => "caching_sha2_password", AuthPlugin::Sha256Password => "sha256_password", AuthPlugin::MySqlClearPassword => "mysql_clear_password", } } } impl FromStr for AuthPlugin { type Err = Error; fn from_str(s: &str) -> Result { match s { "mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword), "caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password), "sha256_password" => Ok(AuthPlugin::Sha256Password), "mysql_clear_password" => Ok(AuthPlugin::MySqlClearPassword), _ => Err(err_protocol!("unknown authentication plugin: {}", s)), } } } sqlx-mysql-0.8.3/src/protocol/capabilities.rs000064400000000000000000000073161046102023000174240ustar 00000000000000// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html // https://mariadb.com/kb/en/library/connection/#capabilities // // MySQL defines the capabilities flags as fitting in an `int<4>` but MariaDB // extends this with more bits sent in a separate field. // For simplicity, we've chosen to combine these into one type. bitflags::bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Capabilities: u64 { // [MariaDB] MySQL compatibility const MYSQL = 1; // [*] Send found rows instead of affected rows in EOF_Packet. const FOUND_ROWS = 2; // Get all column flags. const LONG_FLAG = 4; // [*] Database (schema) name can be specified on connect in Handshake Response Packet. const CONNECT_WITH_DB = 8; // Don't allow database.table.column const NO_SCHEMA = 16; // [*] Compression protocol supported const COMPRESS = 32; // Special handling of ODBC behavior. 
const ODBC = 64; // Can use LOAD DATA LOCAL const LOCAL_FILES = 128; // [*] Ignore spaces before '(' const IGNORE_SPACE = 256; // [*] New 4.1+ protocol const PROTOCOL_41 = 512; // This is an interactive client const INTERACTIVE = 1024; // Use SSL encryption for this session const SSL = 2048; // Client knows about transactions const TRANSACTIONS = 8192; // 4.1+ authentication const SECURE_CONNECTION = 1 << 15; // Enable/disable multi-statement support for COM_QUERY *and* COM_STMT_PREPARE const MULTI_STATEMENTS = 1 << 16; // Enable/disable multi-results for COM_QUERY const MULTI_RESULTS = 1 << 17; // Enable/disable multi-results for COM_STMT_PREPARE const PS_MULTI_RESULTS = 1 << 18; // Client supports plugin authentication const PLUGIN_AUTH = 1 << 19; // Client supports connection attributes const CONNECT_ATTRS = 1 << 20; // Enable authentication response packet to be larger than 255 bytes. const PLUGIN_AUTH_LENENC_DATA = 1 << 21; // Don't close the connection for a user account with expired password. const CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22; // Capable of handling server state change information. const SESSION_TRACK = 1 << 23; // Client no longer needs EOF_Packet and will use OK_Packet instead. 
const DEPRECATE_EOF = 1 << 24; // Support ZSTD protocol compression const ZSTD_COMPRESSION_ALGORITHM = 1 << 26; // Verify server certificate const SSL_VERIFY_SERVER_CERT = 1 << 30; // The client can handle optional metadata information in the resultset const OPTIONAL_RESULTSET_METADATA = 1 << 25; // Don't reset the options after an unsuccessful connect const REMEMBER_OPTIONS = 1 << 31; // Extended capabilities (MariaDB only, as of writing) // Client support progress indicator (since 10.2) const MARIADB_CLIENT_PROGRESS = 1 << 32; // Permit COM_MULTI protocol const MARIADB_CLIENT_MULTI = 1 << 33; // Permit bulk insert const MARIADB_CLIENT_STMT_BULK_OPERATIONS = 1 << 34; // Add extended metadata information const MARIADB_CLIENT_EXTENDED_TYPE_INFO = 1 << 35; // Permit skipping metadata const MARIADB_CLIENT_CACHE_METADATA = 1 << 36; // when enabled, indicate that Bulk command can use STMT_BULK_FLAG_SEND_UNIT_RESULTS flag // that permit to return a result-set of all affected rows and auto-increment values const MARIADB_CLIENT_BULK_UNIT_RESULTS = 1 << 37; } } sqlx-mysql-0.8.3/src/protocol/connect/auth_switch.rs000064400000000000000000000067101046102023000207430ustar 00000000000000use bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::ProtocolEncode; use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::auth::AuthPlugin; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html #[derive(Debug)] pub struct AuthSwitchRequest { pub plugin: AuthPlugin, pub data: Bytes, } impl ProtocolDecode<'_, bool> for AuthSwitchRequest { fn decode_with(mut buf: Bytes, enable_cleartext_plugin: bool) -> Result { let header = buf.get_u8(); if header != 0xfe { return Err(err_protocol!( "expected 0xfe (AUTH_SWITCH) but found 0x{:x}", header )); } let plugin = buf.get_str_nul()?.parse()?; if matches!(plugin, AuthPlugin::MySqlClearPassword) && !enable_cleartext_plugin { return 
Err(err_protocol!("mysql_cleartext_plugin disabled")); } if matches!(plugin, AuthPlugin::MySqlClearPassword) && buf.is_empty() { // Contrary to the MySQL protocol, AWS Aurora with IAM sends // no data. That is fine because the mysql_clear_password says to // ignore any data sent. // See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_authentication_methods_clear_text_password.html return Ok(Self { plugin, data: Bytes::new(), }); } // See: https://github.com/mysql/mysql-server/blob/ea7d2e2d16ac03afdd9cb72a972a95981107bf51/sql/auth/sha2_password.cc#L942 if buf.len() != 21 { return Err(err_protocol!( "expected 21 bytes but found {} bytes", buf.len() )); } let data = buf.get_bytes(20); buf.advance(1); // NUL-terminator Ok(Self { plugin, data }) } } #[derive(Debug)] pub struct AuthSwitchResponse(pub Vec); impl ProtocolEncode<'_, Capabilities> for AuthSwitchResponse { fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), Error> { buf.extend_from_slice(&self.0); Ok(()) } } #[test] fn test_decode_auth_switch_packet_data() { const AUTH_SWITCH_NO_DATA: &[u8] = b"\xfecaching_sha2_password\x00abcdefghijabcdefghij\x00"; let p = AuthSwitchRequest::decode_with(AUTH_SWITCH_NO_DATA.into(), true).unwrap(); assert!(matches!(p.plugin, AuthPlugin::CachingSha2Password)); assert_eq!(p.data, &b"abcdefghijabcdefghij"[..]); } #[test] fn test_decode_auth_switch_cleartext_disabled() { const AUTH_SWITCH_CLEARTEXT: &[u8] = b"\xfemysql_clear_password\x00abcdefghijabcdefghij\x00"; let e = AuthSwitchRequest::decode_with(AUTH_SWITCH_CLEARTEXT.into(), false).unwrap_err(); let e_str = e.to_string(); let expected = "encountered unexpected or invalid data: mysql_cleartext_plugin disabled"; assert!( // Don't want to assert the full string since it contains the module path now. 
e_str.starts_with(expected), "expected error string to start with {expected:?}, got {e_str:?}" ); } #[test] fn test_decode_auth_switch_packet_no_data() { const AUTH_SWITCH_NO_DATA: &[u8] = b"\xfemysql_clear_password\x00"; let p = AuthSwitchRequest::decode_with(AUTH_SWITCH_NO_DATA.into(), true).unwrap(); assert!(matches!(p.plugin, AuthPlugin::MySqlClearPassword)); assert_eq!(p.data, Bytes::new()); } sqlx-mysql-0.8.3/src/protocol/connect/handshake.rs000064400000000000000000000157051046102023000203530ustar 00000000000000use bytes::buf::Chain; use bytes::{Buf, Bytes}; use std::cmp; use crate::error::Error; use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::auth::AuthPlugin; use crate::protocol::response::Status; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake // https://mariadb.com/kb/en/connection/#initial-handshake-packet #[derive(Debug)] pub(crate) struct Handshake { #[allow(unused)] pub(crate) protocol_version: u8, pub(crate) server_version: String, #[allow(unused)] pub(crate) connection_id: u32, pub(crate) server_capabilities: Capabilities, #[allow(unused)] pub(crate) server_default_collation: u8, #[allow(unused)] pub(crate) status: Status, pub(crate) auth_plugin: Option, pub(crate) auth_plugin_data: Chain, } impl ProtocolDecode<'_> for Handshake { fn decode_with(mut buf: Bytes, _: ()) -> Result { let protocol_version = buf.get_u8(); // int<1> let server_version = buf.get_str_nul()?; // string let connection_id = buf.get_u32_le(); // int<4> let auth_plugin_data_1 = buf.get_bytes(8); // string<8> buf.advance(1); // reserved: string<1> let capabilities_1 = buf.get_u16_le(); // int<2> let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into()); let collation = buf.get_u8(); // int<1> let status = Status::from_bits_truncate(buf.get_u16_le()); let capabilities_2 = buf.get_u16_le(); // int<2> capabilities |= 
Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into()); let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) { buf.get_u8() } else { buf.advance(1); // int<1> 0 }; buf.advance(6); // reserved: string<6> if capabilities.contains(Capabilities::MYSQL) { buf.advance(4); // reserved: string<4> } else { let capabilities_3 = buf.get_u32_le(); // int<4> capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32); } let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) { let len = cmp::max(auth_plugin_data_len.saturating_sub(9), 12); let v = buf.get_bytes(len as usize); buf.advance(1); // NUL-terminator v } else { Bytes::new() }; let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) { Some(buf.get_str_nul()?.parse()?) } else { None }; Ok(Self { protocol_version, server_version, connection_id, server_default_collation: collation, status, server_capabilities: capabilities, auth_plugin, auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2), }) } } #[test] fn test_decode_handshake_mysql_8_0_18() { const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. 
\x00caching_sha2_password\x00"; let p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); assert_eq!(p.protocol_version, 10); assert_eq!( p.server_capabilities, Capabilities::MYSQL | Capabilities::FOUND_ROWS | Capabilities::LONG_FLAG | Capabilities::CONNECT_WITH_DB | Capabilities::NO_SCHEMA | Capabilities::COMPRESS | Capabilities::ODBC | Capabilities::LOCAL_FILES | Capabilities::IGNORE_SPACE | Capabilities::PROTOCOL_41 | Capabilities::INTERACTIVE | Capabilities::SSL | Capabilities::TRANSACTIONS | Capabilities::SECURE_CONNECTION | Capabilities::MULTI_STATEMENTS | Capabilities::MULTI_RESULTS | Capabilities::PS_MULTI_RESULTS | Capabilities::PLUGIN_AUTH | Capabilities::CONNECT_ATTRS | Capabilities::PLUGIN_AUTH_LENENC_DATA | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS | Capabilities::SESSION_TRACK | Capabilities::DEPRECATE_EOF | Capabilities::ZSTD_COMPRESSION_ALGORITHM | Capabilities::SSL_VERIFY_SERVER_CERT | Capabilities::OPTIONAL_RESULTSET_METADATA | Capabilities::REMEMBER_OPTIONS, ); assert_eq!(p.server_default_collation, 255); assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); assert!(matches!( p.auth_plugin, Some(AuthPlugin::CachingSha2Password) )); assert_eq!( &*p.auth_plugin_data.into_iter().collect::>(), &[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,] ); } #[test] fn test_decode_handshake_mariadb_10_4_7() { const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\">(), &[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,] ); } sqlx-mysql-0.8.3/src/protocol/connect/handshake_response.rs000064400000000000000000000050751046102023000222700ustar 00000000000000use crate::io::MySqlBufMutExt; use crate::io::{BufMutExt, ProtocolEncode}; use crate::protocol::auth::AuthPlugin; use crate::protocol::connect::ssl_request::SslRequest; use 
crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
// https://mariadb.com/kb/en/connection/#client-handshake-response

/// Client reply to the server's initial handshake (`HandshakeResponse41`).
#[derive(Debug)]
pub struct HandshakeResponse<'a> {
    pub database: Option<&'a str>,

    /// Max size of a command packet that the client wants to send to the server
    pub max_packet_size: u32,

    /// Default collation for the connection
    pub collation: u8,

    /// Name of the SQL account which client wants to log in
    pub username: &'a str,

    /// Authentication method used by the client
    pub auth_plugin: Option<AuthPlugin>,

    /// Opaque authentication response
    pub auth_response: Option<&'a [u8]>,
}

impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> {
    /// Append the encoded packet payload to `buf`.
    ///
    /// # Errors
    /// Fails if the auth response exceeds 255 bytes while only `SECURE_CONNECTION`
    /// (single-byte length prefix) is available.
    fn encode_with(
        &self,
        buf: &mut Vec<u8>,
        mut context: Capabilities,
    ) -> Result<(), crate::Error> {
        if self.auth_plugin.is_none() {
            // ensure PLUGIN_AUTH is set *only* if we have a defined plugin
            context.remove(Capabilities::PLUGIN_AUTH);
        }

        // NOTE: Half of this packet is identical to the SSL Request packet
        SslRequest {
            max_packet_size: self.max_packet_size,
            collation: self.collation,
        }
        .encode_with(buf, context)?;

        buf.put_str_nul(self.username);

        if context.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
            buf.put_bytes_lenenc(self.auth_response.unwrap_or_default());
        } else if context.contains(Capabilities::SECURE_CONNECTION) {
            let response = self.auth_response.unwrap_or_default();

            // The length prefix is a single byte here, so longer responses cannot be encoded.
            let response_len = u8::try_from(response.len())
                .map_err(|_| err_protocol!("auth_response.len() too long: {}", response.len()))?;

            buf.push(response_len);
            buf.extend_from_slice(response);
        } else {
            buf.push(0);
        }

        if context.contains(Capabilities::CONNECT_WITH_DB) {
            if let Some(database) = self.database {
                buf.put_str_nul(database);
            } else {
                buf.push(0);
            }
        }

        if context.contains(Capabilities::PLUGIN_AUTH) {
            if let Some(plugin) = &self.auth_plugin {
                buf.put_str_nul(plugin.name());
            } else {
                buf.push(0);
            }
        }

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/connect/mod.rs000064400000000000000000000005571046102023000172030ustar 00000000000000
//! Connection Phase

mod auth_switch;
mod handshake;
mod handshake_response;
mod ssl_request;

pub(crate) use auth_switch::{AuthSwitchRequest, AuthSwitchResponse};
pub(crate) use handshake::Handshake;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use ssl_request::SslRequest;
sqlx-mysql-0.8.3/src/protocol/connect/ssl_request.rs000064400000000000000000000022131046102023000207640ustar 00000000000000
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest

/// Request to upgrade the connection to TLS; also the common prefix of `HandshakeResponse`.
#[derive(Debug)]
pub struct SslRequest {
    pub max_packet_size: u32,
    pub collation: u8,
}

impl ProtocolEncode<'_, Capabilities> for SslRequest {
    fn encode_with(&self, buf: &mut Vec<u8>, context: Capabilities) -> Result<(), crate::Error> {
        // Only the low 32 bits are sent here; the MariaDB extension bits go in the reserved
        // field below, so the truncation is intended.
        #[allow(clippy::cast_possible_truncation)]
        buf.extend(&(context.bits() as u32).to_le_bytes());
        buf.extend(&self.max_packet_size.to_le_bytes());
        buf.push(self.collation);

        // reserved: string<19>
        buf.extend(&[0_u8; 19]);

        if context.contains(Capabilities::MYSQL) {
            // reserved: string<4>
            buf.extend(&[0_u8; 4]);
        } else {
            // extended client capabilities (MariaDB-specified): int<4>
            buf.extend(&((context.bits() >> 32) as u32).to_le_bytes());
        }

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/mod.rs000064400000000000000000000004001046102023000155350ustar 00000000000000
pub(crate) mod auth;
mod capabilities;
pub(crate) mod connect;
mod packet;
pub(crate) mod response;
mod row;
pub(crate) mod statement;
pub(crate) mod text;

pub(crate) use capabilities::Capabilities;
pub(crate) use packet::Packet;
pub(crate) use row::Row;
sqlx-mysql-0.8.3/src/protocol/packet.rs000064400000000000000000000064301046102023000162360ustar
00000000000000use std::cmp::min; use std::ops::{Deref, DerefMut}; use bytes::Bytes; use crate::error::Error; use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::protocol::response::{EofPacket, OkPacket}; use crate::protocol::Capabilities; #[derive(Debug)] pub struct Packet(pub(crate) T); impl<'en, 'stream, T> ProtocolEncode<'stream, (Capabilities, &'stream mut u8)> for Packet where T: ProtocolEncode<'en, Capabilities>, { fn encode_with( &self, buf: &mut Vec, (capabilities, sequence_id): (Capabilities, &'stream mut u8), ) -> Result<(), Error> { let mut next_header = |len: u32| { let mut buf = len.to_le_bytes(); buf[3] = *sequence_id; *sequence_id = sequence_id.wrapping_add(1); buf }; // reserve space to write the prefixed length let offset = buf.len(); buf.extend(&[0_u8; 4]); // encode the payload self.0.encode_with(buf, capabilities)?; // determine the length of the encoded payload // and write to our reserved space let len = buf.len() - offset - 4; let header = &mut buf[offset..]; // // `min(.., 0xFF_FF_FF)` cannot overflow #[allow(clippy::cast_possible_truncation)] header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32)); // add more packets if we need to split the data if len >= 0xFF_FF_FF { let rest = buf.split_off(offset + 4 + 0xFF_FF_FF); let mut chunks = rest.chunks_exact(0xFF_FF_FF); for chunk in chunks.by_ref() { buf.reserve(chunk.len() + 4); // `chunk.len() == 0xFF_FF_FF` #[allow(clippy::cast_possible_truncation)] buf.extend(&next_header(chunk.len() as u32)); buf.extend(chunk); } // this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF let remainder = chunks.remainder(); buf.reserve(remainder.len() + 4); // `remainder.len() < 0xFF_FF_FF` #[allow(clippy::cast_possible_truncation)] buf.extend(&next_header(remainder.len() as u32)); buf.extend(remainder); } Ok(()) } } impl Packet { pub(crate) fn decode<'de, T>(self) -> Result where T: ProtocolDecode<'de, ()>, { self.decode_with(()) } pub(crate) fn 
decode_with<'de, T, C>(self, context: C) -> Result where T: ProtocolDecode<'de, C>, { T::decode_with(self.0, context) } pub(crate) fn ok(self) -> Result { self.decode() } pub(crate) fn eof(self, capabilities: Capabilities) -> Result { if capabilities.contains(Capabilities::DEPRECATE_EOF) { let ok = self.ok()?; Ok(EofPacket { warnings: ok.warnings, status: ok.status, }) } else { self.decode_with(capabilities) } } } impl Deref for Packet { type Target = Bytes; fn deref(&self) -> &Bytes { &self.0 } } impl DerefMut for Packet { fn deref_mut(&mut self) -> &mut Bytes { &mut self.0 } } sqlx-mysql-0.8.3/src/protocol/response/eof.rs000064400000000000000000000017131046102023000173750ustar 00000000000000use bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::ProtocolDecode; use crate::protocol::response::Status; use crate::protocol::Capabilities; /// Marks the end of a result set, returning status and warnings. /// /// # Note /// /// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL /// prior MySQL versions. #[derive(Debug)] pub struct EofPacket { #[allow(dead_code)] pub warnings: u16, pub status: Status, } impl ProtocolDecode<'_, Capabilities> for EofPacket { fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { let header = buf.get_u8(); if header != 0xfe { return Err(err_protocol!( "expected 0xfe (EOF_Packet) but found 0x{:x}", header )); } let warnings = buf.get_u16_le(); let status = Status::from_bits_truncate(buf.get_u16_le()); Ok(Self { status, warnings }) } } sqlx-mysql-0.8.3/src/protocol/response/err.rs000064400000000000000000000041031046102023000174100ustar 00000000000000use bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::{BufExt, ProtocolDecode}; use crate::protocol::Capabilities; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html // https://mariadb.com/kb/en/err_packet/ /// Indicates that an error occurred. 
#[derive(Debug)] pub struct ErrPacket { pub error_code: u16, pub sql_state: Option, pub error_message: String, } impl ProtocolDecode<'_, Capabilities> for ErrPacket { fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result { let header = buf.get_u8(); if header != 0xff { return Err(err_protocol!( "expected 0xff (ERR_Packet) but found 0x{:x}", header )); } let error_code = buf.get_u16_le(); let mut sql_state = None; if capabilities.contains(Capabilities::PROTOCOL_41) { // If the next byte is '#' then we have a SQL STATE if buf.starts_with(b"#") { buf.advance(1); sql_state = Some(buf.get_str(5)?); } } let error_message = buf.get_str(buf.len())?; Ok(Self { error_code, sql_state, error_message, }) } } #[test] fn test_decode_err_packet_out_of_order() { const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; let p = ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap(); assert_eq!(&p.error_message, "Got packets out of order"); assert_eq!(p.error_code, 1156); assert_eq!(p.sql_state, None); } #[test] fn test_decode_err_packet_unknown_database() { const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; let p = ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap(); assert_eq!(p.error_code, 1049); assert_eq!(p.sql_state.as_deref(), Some("42000")); assert_eq!(&p.error_message, "Unknown database \'unknown\'"); } sqlx-mysql-0.8.3/src/protocol/response/mod.rs000064400000000000000000000004551046102023000174050ustar 00000000000000//! Generic Response Packets //! //! //! 

mod eof;
mod err;
mod ok;
mod status;

pub use eof::EofPacket;
pub use err::ErrPacket;
pub use ok::OkPacket;
pub use status::Status;
sqlx-mysql-0.8.3/src/protocol/response/ok.rs000064400000000000000000000026541046102023000172420ustar 00000000000000
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::response::Status;

/// Indicates successful completion of a previous command sent by the client.
#[derive(Debug)]
pub struct OkPacket {
    pub affected_rows: u64,
    pub last_insert_id: u64,
    pub status: Status,
    pub warnings: u16,
}

impl ProtocolDecode<'_> for OkPacket {
    fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
        // 0xfe is also accepted: it is what replaces EOF_Packet under DEPRECATE_EOF.
        let header = buf.get_u8();
        if header != 0 && header != 0xfe {
            return Err(err_protocol!(
                "expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}",
                header
            ));
        }

        let affected_rows = buf.get_uint_lenenc();
        let last_insert_id = buf.get_uint_lenenc();
        let status = Status::from_bits_truncate(buf.get_u16_le());
        let warnings = buf.get_u16_le();

        Ok(Self {
            affected_rows,
            last_insert_id,
            status,
            warnings,
        })
    }
}

#[test]
fn test_decode_ok_packet() {
    const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00";

    let p = OkPacket::decode(DATA.into()).unwrap();

    assert_eq!(p.affected_rows, 0);
    assert_eq!(p.last_insert_id, 0);
    assert_eq!(p.warnings, 0);
    assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
    assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
}
sqlx-mysql-0.8.3/src/protocol/response/status.rs000064400000000000000000000041361046102023000201510ustar 00000000000000
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad
// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status
bitflags::bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct Status: u16 {
        // Is raised when a multi-statement transaction has been started, either explicitly,
        // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first
        // transactional statement, when autocommit=off.
        const SERVER_STATUS_IN_TRANS = 1;

        // Autocommit mode is set
        const SERVER_STATUS_AUTOCOMMIT = 2;

        // Multi query - next query exists.
        const SERVER_MORE_RESULTS_EXISTS = 8;

        const SERVER_QUERY_NO_GOOD_INDEX_USED = 16;

        const SERVER_QUERY_NO_INDEX_USED = 32;

        // When using COM_STMT_FETCH, indicate that current cursor still has result
        const SERVER_STATUS_CURSOR_EXISTS = 64;

        // When using COM_STMT_FETCH, indicate that current cursor has finished to send results
        const SERVER_STATUS_LAST_ROW_SENT = 128;

        // Database has been dropped
        const SERVER_STATUS_DB_DROPPED = 1 << 8;

        // Current escape mode is "no backslash escape"
        const SERVER_STATUS_NO_BACKSLASH_ESCAPES = 1 << 9;

        // A DDL change did have an impact on an existing PREPARE (an automatic
        // re-prepare has been executed)
        const SERVER_STATUS_METADATA_CHANGED = 1 << 10;

        // Last statement took more than the time value specified
        // in server variable long_query_time.
        const SERVER_QUERY_WAS_SLOW = 1 << 11;

        // This result-set contain stored procedure output parameter.
        const SERVER_PS_OUT_PARAMS = 1 << 12;

        // Current transaction is a read-only transaction.
        const SERVER_STATUS_IN_TRANS_READONLY = 1 << 13;

        // This status flag, when on, implies that one of the state information has changed
        // on the server because of the execution of the last statement.
        const SERVER_SESSION_STATE_CHANGED = 1 << 14;
    }
}
sqlx-mysql-0.8.3/src/protocol/row.rs000064400000000000000000000004701046102023000155740ustar 00000000000000
use std::ops::Range;

use bytes::Bytes;

/// A decoded row: one shared buffer plus the byte range of each column (`None` for NULL).
#[derive(Debug)]
pub(crate) struct Row {
    pub(crate) storage: Bytes,
    pub(crate) values: Vec<Option<Range<usize>>>,
}

impl Row {
    /// Raw bytes of column `index`, or `None` if the value is NULL.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds for `values`.
    pub(crate) fn get(&self, index: usize) -> Option<&[u8]> {
        self.values[index]
            .as_ref()
            .map(|range| &self.storage[range.clone()])
    }
}
sqlx-mysql-0.8.3/src/protocol/statement/execute.rs000064400000000000000000000022571046102023000204400ustar 00000000000000
use crate::io::ProtocolEncode;
use crate::protocol::text::ColumnFlags;
use crate::protocol::Capabilities;
use crate::MySqlArguments;

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html

#[derive(Debug)]
pub struct Execute<'q> {
    pub statement: u32,
    pub arguments: &'q MySqlArguments,
}

impl<'q> ProtocolEncode<'_, Capabilities> for Execute<'q> {
    fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
        buf.push(0x17); // COM_STMT_EXECUTE
        buf.extend(&self.statement.to_le_bytes());
        buf.push(0); // NO_CURSOR
        buf.extend(&1_u32.to_le_bytes()); // iterations (always 1): int<4>

        if !self.arguments.types.is_empty() {
            buf.extend_from_slice(&self.arguments.null_bitmap);
            buf.push(1); // send type to server

            for ty in &self.arguments.types {
                buf.push(ty.r#type as u8);

                // High bit of the type flag byte marks the parameter as unsigned.
                buf.push(if ty.flags.contains(ColumnFlags::UNSIGNED) {
                    0x80
                } else {
                    0
                });
            }

            buf.extend(&*self.arguments.values);
        }

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/statement/mod.rs000064400000000000000000000003611046102023000175470ustar 00000000000000
mod execute;
mod prepare;
mod prepare_ok;
mod row;
mod stmt_close;

pub(crate) use execute::Execute;
pub(crate) use prepare::Prepare;
pub(crate) use prepare_ok::PrepareOk;
pub(crate) use row::BinaryRow;
pub(crate) use stmt_close::StmtClose;
sqlx-mysql-0.8.3/src/protocol/statement/prepare.rs000064400000000000000000000007221046102023000204270ustar 00000000000000
use
crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE
/// `COM_STMT_PREPARE`: asks the server to prepare `query`.
pub struct Prepare<'a> {
    pub query: &'a str,
}

impl ProtocolEncode<'_, Capabilities> for Prepare<'_> {
    fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> {
        buf.push(0x16); // COM_STMT_PREPARE
        // The query text runs to the end of the packet; it is not length-prefixed.
        buf.extend(self.query.as_bytes());

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/statement/prepare_ok.rs000064400000000000000000000024301046102023000211160ustar 00000000000000use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::ProtocolDecode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
/// Successful response to `COM_STMT_PREPARE`.
#[derive(Debug)]
pub(crate) struct PrepareOk {
    pub(crate) statement_id: u32,
    pub(crate) columns: u16,
    pub(crate) params: u16,
    #[allow(unused)]
    pub(crate) warnings: u16,
}

impl ProtocolDecode<'_, Capabilities> for PrepareOk {
    fn decode_with(buf: Bytes, _: Capabilities) -> Result {
        // status (1) + statement_id (4) + columns (2) + params (2) + reserved (1) + warnings (2)
        const SIZE: usize = 12;

        let mut slice = buf.get(..SIZE).ok_or_else(|| {
            err_protocol!("PrepareOk expected 12 bytes but got {} bytes", buf.len())
        })?;

        let status = slice.get_u8();
        if status != 0x00 {
            return Err(err_protocol!(
                "expected 0x00 (COM_STMT_PREPARE_OK) but found 0x{:02x}",
                status
            ));
        }

        let statement_id = slice.get_u32_le();
        let columns = slice.get_u16_le();
        let params = slice.get_u16_le();

        slice.advance(1); // reserved: string<1>

        let warnings = slice.get_u16_le();

        Ok(Self {
            statement_id,
            columns,
            params,
            warnings,
        })
    }
}
sqlx-mysql-0.8.3/src/protocol/statement/row.rs000064400000000000000000000075271046102023000176120ustar 00000000000000use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::text::ColumnType;
use crate::protocol::Row;
use crate::MySqlColumn;

// https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html#packet-ProtocolBinary::ResultsetRow
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
/// A result-set row encoded in the binary protocol.
#[derive(Debug)]
pub(crate) struct BinaryRow(pub(crate) Row);

impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow {
    /// Decodes one binary row, using `columns` to determine how each value is sized.
    ///
    /// # Errors
    /// Returns a protocol error if the header byte is not `0x00`, if a column length is
    /// out of range, or if the packet is too short for the NULL bitmap or for any value
    /// it claims to contain.
    fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result {
        let header = buf.get_u8();
        if header != 0 {
            return Err(err_protocol!(
                "expected 0x00 (ROW) but found 0x{:02x}",
                header
            ));
        }

        let storage = buf.clone();
        let offset = buf.len();

        // The first two bits of the NULL bitmap are reserved, so it needs
        // ceil((columns + 2) / 8) bytes.
        let null_bitmap_len = (columns.len() + 9) / 8;
        if buf.len() < null_bitmap_len {
            return Err(err_protocol!(
                "BinaryRow expected a {null_bitmap_len}-byte NULL bitmap but only {} bytes remain",
                buf.len()
            ));
        }
        let null_bitmap = buf.get_bytes(null_bitmap_len);

        let mut values = Vec::with_capacity(columns.len());

        for (column_idx, column) in columns.iter().enumerate() {
            // NOTE: the column index starts at the 3rd bit
            let column_null_idx = column_idx + 2;
            let byte_idx = column_null_idx / 8;
            let bit_idx = column_null_idx % 8;
            let is_null = null_bitmap[byte_idx] & (1u8 << bit_idx) != 0;

            if is_null {
                values.push(None);
                continue;
            }

            // NOTE: MySQL will never generate NULL types for non-NULL values
            let type_info = &column.type_info;

            // Unlike Postgres, MySQL does not length-prefix every value in a binary row.
            // Values are *either* fixed-length or length-prefixed,
            // so we need to inspect the type code to be sure.
            let size: usize = match type_info.r#type {
                // All fixed-length types.
                ColumnType::LongLong => 8,
                ColumnType::Long | ColumnType::Int24 => 4,
                ColumnType::Short | ColumnType::Year => 2,
                ColumnType::Tiny => 1,
                ColumnType::Float => 4,
                ColumnType::Double => 8,

                // Blobs and strings are prefixed with their length,
                // which is itself a length-encoded integer.
                ColumnType::String
                | ColumnType::VarChar
                | ColumnType::VarString
                | ColumnType::Enum
                | ColumnType::Set
                | ColumnType::LongBlob
                | ColumnType::MediumBlob
                | ColumnType::Blob
                | ColumnType::TinyBlob
                | ColumnType::Geometry
                | ColumnType::Bit
                | ColumnType::Decimal
                | ColumnType::Json
                | ColumnType::NewDecimal => {
                    let size = buf.get_uint_lenenc();
                    usize::try_from(size)
                        .map_err(|_| err_protocol!("BLOB length out of range: {size}"))?
                }

                // Like strings and blobs, these values are variable-length.
                // Unlike strings and blobs, however, they exclusively use one byte for length.
                ColumnType::Time
                | ColumnType::Timestamp
                | ColumnType::Date
                | ColumnType::Datetime => {
                    // Leave the length byte on the front of the value because decoding uses it.
                    let len = *buf.first().ok_or_else(|| {
                        err_protocol!("BinaryRow ended before the length byte of a temporal value")
                    })?;
                    usize::from(len) + 1
                }

                // NOTE: MySQL will never generate NULL types for non-NULL values
                ColumnType::Null => unreachable!(),
            };

            // Reject truncated packets here instead of panicking in `advance` below
            // (or later, when the value's range is sliced out of `storage`).
            if size > buf.len() {
                return Err(err_protocol!(
                    "BinaryRow value of {size} bytes exceeds the {} bytes remaining",
                    buf.len()
                ));
            }

            let offset = offset - buf.len();

            values.push(Some(offset..(offset + size)));

            buf.advance(size);
        }

        Ok(BinaryRow(Row { values, storage }))
    }
}
sqlx-mysql-0.8.3/src/protocol/statement/stmt_close.rs000064400000000000000000000007131046102023000211450ustar 00000000000000use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-stmt-close.html
/// `COM_STMT_CLOSE` for the prepared statement with id `statement`.
#[derive(Debug)]
pub struct StmtClose {
    pub statement: u32,
}

impl ProtocolEncode<'_, Capabilities> for StmtClose {
    fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> {
        buf.push(0x19); // COM_STMT_CLOSE
        buf.extend(&self.statement.to_le_bytes());

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/text/column.rs000064400000000000000000000201501046102023000172430ustar 00000000000000use std::str::from_utf8;

use bitflags::bitflags;
use bytes::{Buf, Bytes};

use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
bitflags!
{
    // Column definition flags (a `u16` bitmask) decoded from the column definition packet.
    #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub(crate) struct ColumnFlags: u16 {
        /// Field can't be `NULL`.
        const NOT_NULL = 1;

        /// Field is part of a primary key.
        const PRIMARY_KEY = 2;

        /// Field is part of a unique key.
        const UNIQUE_KEY = 4;

        /// Field is part of a multi-part unique or primary key.
        const MULTIPLE_KEY = 8;

        /// Field is a blob.
        const BLOB = 16;

        /// Field is unsigned.
        const UNSIGNED = 32;

        /// Field is zero filled.
        const ZEROFILL = 64;

        /// Field is binary.
        const BINARY = 128;

        /// Field is an enumeration.
        const ENUM = 256;

        /// Field is an auto-increment field.
        const AUTO_INCREMENT = 512;

        /// Field is a timestamp.
        const TIMESTAMP = 1024;

        /// Field is a set.
        const SET = 2048;

        /// Field does not have a default value.
        const NO_DEFAULT_VALUE = 4096;

        /// Field is set to NOW on UPDATE.
        const ON_UPDATE_NOW = 8192;

        /// Field is a number.
        const NUM = 32768;
    }
}

// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
/// Column type code as it appears on the wire; each discriminant is the protocol byte.
///
/// The internal codes 0x0e and 0x11..=0x13 have no variant (see `try_from_u16`).
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ColumnType {
    Decimal = 0x00,
    Tiny = 0x01,
    Short = 0x02,
    Long = 0x03,
    Float = 0x04,
    Double = 0x05,
    Null = 0x06,
    Timestamp = 0x07,
    LongLong = 0x08,
    Int24 = 0x09,
    Date = 0x0a,
    Time = 0x0b,
    Datetime = 0x0c,
    Year = 0x0d,
    VarChar = 0x0f,
    Bit = 0x10,
    Json = 0xf5,
    NewDecimal = 0xf6,
    Enum = 0xf7,
    Set = 0xf8,
    TinyBlob = 0xf9,
    MediumBlob = 0xfa,
    LongBlob = 0xfb,
    Blob = 0xfc,
    VarString = 0xfd,
    String = 0xfe,
    Geometry = 0xff,
}

// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
// https://mariadb.com/kb/en/resultset/#column-definition-packet
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
/// Metadata for one result-set column, as decoded from a column definition packet.
#[derive(Debug)]
pub(crate) struct ColumnDefinition {
    #[allow(unused)]
    catalog: Bytes,
    #[allow(unused)]
    schema: Bytes,
    #[allow(unused)]
    table_alias: Bytes,
    #[allow(unused)]
    table: Bytes,
    alias: Bytes,
    name: Bytes,
    #[allow(unused)]
    pub(crate) collation: u16,
    pub(crate) max_size: u32,
    pub(crate) r#type: ColumnType,
    pub(crate) flags: ColumnFlags,
    #[allow(unused)]
    decimals: u8,
}

impl ColumnDefinition {
    // NOTE: strings in-protocol are transmitted according to the client character set
    // as this is UTF-8, all these strings should be UTF-8
    // Both accessors return a protocol error if the bytes are not valid UTF-8.
    pub(crate) fn name(&self) -> Result<&str, Error> {
        from_utf8(&self.name).map_err(Error::protocol)
    }

    pub(crate) fn alias(&self) -> Result<&str, Error> {
        from_utf8(&self.alias).map_err(Error::protocol)
    }
}

impl ProtocolDecode<'_, Capabilities> for ColumnDefinition {
    // Wire order: six length-encoded strings (catalog, schema, table alias, table, alias,
    // name), a length-encoded length of the fixed fields, then collation (u16 LE),
    // max size (u32 LE), type code (u8), flags (u16 LE) and decimals (u8).
    // Unknown type codes are rejected; unknown flag bits are silently dropped.
    fn decode_with(mut buf: Bytes, _: Capabilities) -> Result {
        let catalog = buf.get_bytes_lenenc()?;
        let schema = buf.get_bytes_lenenc()?;
        let table_alias = buf.get_bytes_lenenc()?;
        let table = buf.get_bytes_lenenc()?;
        let alias = buf.get_bytes_lenenc()?;
        let name = buf.get_bytes_lenenc()?;
        let _next_len = buf.get_uint_lenenc(); // always 0x0c
        let collation = buf.get_u16_le();
        let max_size = buf.get_u32_le();
        let type_id = buf.get_u8();
        let flags = buf.get_u16_le();
        let decimals = buf.get_u8();

        Ok(Self {
            catalog,
            schema,
            table_alias,
            table,
            alias,
            name,
            collation,
            max_size,
            r#type: ColumnType::try_from_u16(type_id)?,
            flags: ColumnFlags::from_bits_truncate(flags),
            decimals,
        })
    }
}

impl ColumnType {
    /// Returns the SQL type name for this column type.
    ///
    /// `flags` refines the name (BINARY, UNSIGNED and ENUM are consulted), and a `Tiny`
    /// column with `max_size == Some(1)` is reported as `BOOLEAN`. Guarded arms must
    /// stay ahead of the unguarded arm for the same variant.
    pub(crate) fn name(self, flags: ColumnFlags, max_size: Option) -> &'static str {
        let is_binary = flags.contains(ColumnFlags::BINARY);
        let is_unsigned = flags.contains(ColumnFlags::UNSIGNED);
        let is_enum = flags.contains(ColumnFlags::ENUM);

        match self {
            ColumnType::Tiny if max_size == Some(1) => "BOOLEAN",
            ColumnType::Tiny if is_unsigned => "TINYINT UNSIGNED",
            ColumnType::Short if is_unsigned => "SMALLINT UNSIGNED",
            ColumnType::Long if is_unsigned => "INT UNSIGNED",
            ColumnType::Int24 if is_unsigned => "MEDIUMINT UNSIGNED",
            ColumnType::LongLong if is_unsigned => "BIGINT UNSIGNED",
            ColumnType::Tiny => "TINYINT",
            ColumnType::Short => "SMALLINT",
            ColumnType::Long => "INT",
            ColumnType::Int24 => "MEDIUMINT",
            ColumnType::LongLong => "BIGINT",
            ColumnType::Float => "FLOAT",
            ColumnType::Double => "DOUBLE",
            ColumnType::Null => "NULL",
            ColumnType::Timestamp => "TIMESTAMP",
            ColumnType::Date => "DATE",
            ColumnType::Time => "TIME",
            ColumnType::Datetime => "DATETIME",
            ColumnType::Year => "YEAR",
            ColumnType::Bit => "BIT",
            ColumnType::Enum => "ENUM",
            ColumnType::Set => "SET",
            ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL",
            ColumnType::Geometry => "GEOMETRY",
            ColumnType::Json => "JSON",

            ColumnType::String if is_binary => "BINARY",
            ColumnType::String if is_enum => "ENUM",
            ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY",

            ColumnType::String => "CHAR",
            ColumnType::VarChar | ColumnType::VarString => "VARCHAR",

            ColumnType::TinyBlob if is_binary => "TINYBLOB",
            ColumnType::TinyBlob => "TINYTEXT",

            ColumnType::Blob if is_binary => "BLOB",
            ColumnType::Blob => "TEXT",

            ColumnType::MediumBlob if is_binary => "MEDIUMBLOB",
            ColumnType::MediumBlob => "MEDIUMTEXT",

            ColumnType::LongBlob if is_binary => "LONGBLOB",
            ColumnType::LongBlob => "LONGTEXT",
        }
    }

    /// Maps a protocol type byte to a `ColumnType` (despite the name, the input is a `u8`).
    ///
    /// # Errors
    /// Returns a protocol error for any code without a variant, including the internal
    /// codes 0x0e and 0x11..=0x13 listed as comments below.
    pub(crate) fn try_from_u16(id: u8) -> Result {
        Ok(match id {
            0x00 => ColumnType::Decimal,
            0x01 => ColumnType::Tiny,
            0x02 => ColumnType::Short,
            0x03 => ColumnType::Long,
            0x04 => ColumnType::Float,
            0x05 => ColumnType::Double,
            0x06 => ColumnType::Null,
            0x07 => ColumnType::Timestamp,
            0x08 => ColumnType::LongLong,
            0x09 => ColumnType::Int24,
            0x0a => ColumnType::Date,
            0x0b => ColumnType::Time,
            0x0c => ColumnType::Datetime,
            0x0d => ColumnType::Year,
            // [internal] 0x0e => ColumnType::NewDate,
            0x0f => ColumnType::VarChar,
            0x10 => ColumnType::Bit,
            // [internal] 0x11 => ColumnType::Timestamp2,
            // [internal] 0x12 => ColumnType::Datetime2,
            // [internal] 0x13 => ColumnType::Time2,
            0xf5 => ColumnType::Json,
            0xf6 => ColumnType::NewDecimal,
            0xf7 => ColumnType::Enum,
            0xf8 => ColumnType::Set,
            0xf9 => ColumnType::TinyBlob,
            0xfa => ColumnType::MediumBlob,
            0xfb => ColumnType::LongBlob,
            0xfc => ColumnType::Blob,
            0xfd => ColumnType::VarString,
            0xfe => ColumnType::String,
            0xff => ColumnType::Geometry,

            _ => {
                return Err(err_protocol!("unknown column type 0x{:02x}", id));
            }
        })
    }
}
sqlx-mysql-0.8.3/src/protocol/text/mod.rs000064400000000000000000000003511046102023000165260ustar 00000000000000mod column;
mod ping;
mod query;
mod quit;
mod row;

pub(crate) use column::{ColumnDefinition, ColumnFlags, ColumnType};
pub(crate) use ping::Ping;
pub(crate) use query::Query;
pub(crate) use quit::Quit;
pub(crate) use row::TextRow;
sqlx-mysql-0.8.3/src/protocol/text/ping.rs000064400000000000000000000005561046102023000167130ustar 00000000000000use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-ping.html
/// `COM_PING`: a single command byte with no payload.
#[derive(Debug)]
pub(crate) struct Ping;

impl ProtocolEncode<'_, Capabilities> for Ping {
    fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> {
        buf.push(0x0e); // COM_PING

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/text/query.rs000064400000000000000000000006651046102023000171240ustar 00000000000000use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-query.html
/// `COM_QUERY`: the command byte followed by the SQL text (not length-prefixed).
#[derive(Debug)]
pub(crate) struct Query<'q>(pub(crate) &'q str);

impl ProtocolEncode<'_, Capabilities> for Query<'_> {
    fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> {
        buf.push(0x03); // COM_QUERY
        buf.extend(self.0.as_bytes());

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/text/quit.rs000064400000000000000000000005561046102023000167400ustar 00000000000000use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;

// https://dev.mysql.com/doc/internals/en/com-quit.html
/// `COM_QUIT`: a single command byte with no payload.
#[derive(Debug)]
pub(crate) struct Quit;

impl ProtocolEncode<'_, Capabilities> for Quit {
    fn encode_with(&self, buf: &mut Vec, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x01); // COM_QUIT

        Ok(())
    }
}
sqlx-mysql-0.8.3/src/protocol/text/row.rs000064400000000000000000000021301046102023000165530ustar 00000000000000use bytes::{Buf, Bytes};

use crate::column::MySqlColumn;
use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::Row;

/// A result-set row in the text protocol, where every non-NULL value is length-prefixed.
#[derive(Debug)]
pub(crate) struct TextRow(pub(crate) Row);

impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow {
    /// Decodes one text-protocol row holding one value per entry in `columns`.
    ///
    /// # Errors
    /// Returns a protocol error if the packet ends before every column has been read,
    /// if a length does not fit in `usize`, or if a value's length prefix exceeds the
    /// bytes remaining in the packet.
    fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result {
        let storage = buf.clone();
        let offset = buf.len();

        let mut values = Vec::with_capacity(columns.len());

        for _ in columns {
            // A truncated packet would otherwise panic on indexing / `advance`.
            let first = match buf.first() {
                Some(&byte) => byte,
                None => {
                    return Err(err_protocol!(
                        "TextRow ended after {} of {} columns",
                        values.len(),
                        columns.len()
                    ));
                }
            };

            if first == 0xfb {
                // NULL is sent as 0xfb
                values.push(None);
                buf.advance(1);
            } else {
                let size = buf.get_uint_lenenc();
                let size = usize::try_from(size)
                    .map_err(|_| err_protocol!("TextRow length out of range: {size}"))?;

                if size > buf.len() {
                    return Err(err_protocol!(
                        "TextRow value of {size} bytes exceeds the {} bytes remaining",
                        buf.len()
                    ));
                }

                let offset = offset - buf.len();

                values.push(Some(offset..(offset + size)));

                buf.advance(size);
            }
        }

        Ok(TextRow(Row { values, storage }))
    }
}
sqlx-mysql-0.8.3/src/query_result.rs000064400000000000000000000017711046102023000156740ustar 00000000000000use std::iter::{Extend, IntoIterator};

/// Outcome of executing a statement: the affected row count and the last insert id.
#[derive(Debug, Default)]
pub struct MySqlQueryResult {
    pub(super) rows_affected: u64,
    pub(super) last_insert_id: u64,
}

impl MySqlQueryResult {
    /// The last insert id reported by the server.
    pub fn last_insert_id(&self) -> u64 {
        self.last_insert_id
    }

    /// The number of affected rows reported by the server.
    pub fn rows_affected(&self) -> u64 {
        self.rows_affected
    }
}

// Merging results sums `rows_affected` and keeps the `last_insert_id` of the final
// element yielded by the iterator.
impl Extend for MySqlQueryResult {
    fn extend>(&mut self, iter: T) {
        for elem in iter {
            self.rows_affected += elem.rows_affected;
            self.last_insert_id = elem.last_insert_id;
        }
    }
}

#[cfg(feature = "any")]
/// This conversion attempts to save last_insert_id by converting to i64;
/// it becomes `None` if the value does not fit.
impl From for sqlx_core::any::AnyQueryResult {
    fn from(done: MySqlQueryResult) -> Self {
        sqlx_core::any::AnyQueryResult {
            rows_affected: done.rows_affected(),
            last_insert_id: done.last_insert_id().try_into().ok(),
        }
    }
}
sqlx-mysql-0.8.3/src/row.rs000064400000000000000000000024331046102023000137340ustar 00000000000000use std::sync::Arc;

pub(crate) use sqlx_core::row::*;

use crate::column::ColumnIndex;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::HashMap;
use crate::{protocol, MySql, MySqlColumn, MySqlValueFormat, MySqlValueRef};

/// Implementation of [`Row`] for MySQL.
#[derive(Debug)]
pub struct MySqlRow {
    pub(crate) row: protocol::Row,
    pub(crate) format: MySqlValueFormat,
    pub(crate) columns: Arc>,
    pub(crate) column_names: Arc>,
}

impl Row for MySqlRow {
    type Database = MySql;

    fn columns(&self) -> &[MySqlColumn] {
        &self.columns
    }

    fn try_get_raw(&self, index: I) -> Result, Error>
    where
        I: ColumnIndex,
    {
        // Resolve `index` to a column position first; this is where lookup errors surface.
        let index = index.index(self)?;
        let column = &self.columns[index];
        let value = self.row.get(index);

        Ok(MySqlValueRef {
            format: self.format,
            row: Some(&self.row.storage),
            type_info: column.type_info.clone(),
            value,
        })
    }
}

// Looks a column up by name; fails with `Error::ColumnNotFound` if no column has that name.
impl ColumnIndex for &'_ str {
    fn index(&self, row: &MySqlRow) -> Result {
        row.column_names
            .get(*self)
            .ok_or_else(|| Error::ColumnNotFound((*self).into()))
            .copied()
    }
}
sqlx-mysql-0.8.3/src/statement.rs000064400000000000000000000030211046102023000151230ustar 00000000000000use super::MySqlColumn;
use crate::column::ColumnIndex;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::HashMap;
use crate::{MySql, MySqlArguments, MySqlTypeInfo};
use either::Either;
use std::borrow::Cow;
use std::sync::Arc;

pub(crate) use sqlx_core::statement::*;

/// A MySQL statement: its SQL text plus column/parameter metadata.
#[derive(Debug, Clone)]
pub struct MySqlStatement<'q> {
    pub(crate) sql: Cow<'q, str>,
    pub(crate) metadata: MySqlStatementMetadata,
}

/// Result-set columns, a column-name-to-position lookup table, and the parameter count.
#[derive(Debug, Default, Clone)]
pub(crate) struct MySqlStatementMetadata {
    pub(crate) columns: Arc>,
    pub(crate) column_names: Arc>,
    pub(crate) parameters: usize,
}

impl<'q> Statement<'q> for MySqlStatement<'q> {
    type Database = MySql;

    fn to_owned(&self) -> MySqlStatement<'static> {
        MySqlStatement::<'static> {
            sql: Cow::Owned(self.sql.clone().into_owned()),
            metadata: self.metadata.clone(),
        }
    }

    fn sql(&self) -> &str {
        &self.sql
    }

    // The metadata records only a parameter count (no types), so it is reported via
    // `Either::Right`.
    fn parameters(&self) -> Option> {
        Some(Either::Right(self.metadata.parameters))
    }

    fn columns(&self) -> &[MySqlColumn] {
        &self.metadata.columns
    }

    impl_statement_query!(MySqlArguments);
}

// Looks a column up by name; fails with `Error::ColumnNotFound` if no column has that name.
impl ColumnIndex> for &'_ str {
    fn index(&self, statement: &MySqlStatement<'_>) -> Result {
        statement
            .metadata
            .column_names
            .get(*self)
            .ok_or_else(|| Error::ColumnNotFound((*self).into()))
            .copied()
    }
}
sqlx-mysql-0.8.3/src/testing/mod.rs000064400000000000000000000172411046102023000153640ustar 00000000000000use std::fmt::Write;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, SystemTime};

use futures_core::future::BoxFuture;

use once_cell::sync::OnceCell;

use crate::connection::Connection;
use crate::error::Error;
use crate::executor::Executor;
use crate::pool::{Pool, PoolOptions};
use crate::query::query;
use crate::query_builder::QueryBuilder;
use crate::query_scalar::query_scalar;
use crate::{MySql, MySqlConnectOptions, MySqlConnection};

pub(crate) use sqlx_core::testing::*;

// Using a blocking `OnceCell` here because the critical sections are short.
static MASTER_POOL: OnceCell> = OnceCell::new();
// Automatically delete any databases created before the start of the test binary.
static DO_CLEANUP: AtomicBool = AtomicBool::new(true); impl TestSupport for MySql { fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { Box::pin(async move { test_context(args).await }) } fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let mut conn = MASTER_POOL .get() .expect("cleanup_test() invoked outside `#[sqlx::test]") .acquire() .await?; let db_id = db_id(db_name); conn.execute(&format!("drop database if exists {db_name};")[..]) .await?; query("delete from _sqlx_test_databases where db_id = ?") .bind(db_id) .execute(&mut *conn) .await?; Ok(()) }) } fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>> { Box::pin(async move { let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); let mut conn = MySqlConnection::connect(&url).await?; let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap(); let num_deleted = do_cleanup(&mut conn, now).await?; let _ = conn.close().await; Ok(Some(num_deleted)) }) } fn snapshot( _conn: &mut Self::Connection, ) -> BoxFuture<'_, Result, Error>> { // TODO: I want to get the testing feature out the door so this will have to wait, // but I'm keeping the code around for now because I plan to come back to it. todo!() } } async fn test_context(args: &TestArgs) -> Result, Error> { let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); let master_opts = MySqlConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL"); let pool = PoolOptions::new() // MySql's normal connection limit is 150 plus 1 superuser connection // We don't want to use the whole cap and there may be fuzziness here due to // concurrently running tests anyway. .max_connections(20) // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. 
.after_release(|_conn, _| Box::pin(async move { Ok(false) })) .connect_lazy_with(master_opts); let master_pool = match MASTER_POOL.try_insert(pool) { Ok(inserted) => inserted, Err((existing, pool)) => { // Sanity checks. assert_eq!( existing.connect_options().host, pool.connect_options().host, "DATABASE_URL changed at runtime, host differs" ); assert_eq!( existing.connect_options().database, pool.connect_options().database, "DATABASE_URL changed at runtime, database differs" ); existing } }; let mut conn = master_pool.acquire().await?; // language=MySQL conn.execute( r#" create table if not exists _sqlx_test_databases ( db_id bigint unsigned primary key auto_increment, test_path text not null, created_at timestamp not null default current_timestamp ); "#, ) .await?; // Record the current time _before_ we acquire the `DO_CLEANUP` permit. This // prevents the first test thread from accidentally deleting new test dbs // created by other test threads if we're a bit slow. let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap(); // Only run cleanup if the test binary just started. if DO_CLEANUP.swap(false, Ordering::SeqCst) { do_cleanup(&mut conn, now).await?; } query("insert into _sqlx_test_databases(test_path) values (?)") .bind(args.test_path) .execute(&mut *conn) .await?; // MySQL doesn't have `INSERT ... RETURNING` let new_db_id: u64 = query_scalar("select last_insert_id()") .fetch_one(&mut *conn) .await?; let new_db_name = db_name(new_db_id); conn.execute(&format!("create database {new_db_name}")[..]) .await?; eprintln!("created database {new_db_name}"); Ok(TestContext { pool_opts: PoolOptions::new() // Don't allow a single test to take all the connections. // Most tests shouldn't require more than 5 connections concurrently, // or else they're likely doing too much in one test. .max_connections(5) // Close connections ASAP if left in the idle queue. 
.idle_timeout(Some(Duration::from_secs(1))) .parent(master_pool.clone()), connect_opts: master_pool .connect_options() .deref() .clone() .database(&new_db_name), db_name: new_db_name, }) } async fn do_cleanup(conn: &mut MySqlConnection, created_before: Duration) -> Result { // since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads let created_before_as_secs = created_before.as_secs() - 2; let delete_db_ids: Vec = query_scalar( "select db_id from _sqlx_test_databases \ where created_at < from_unixtime(?)", ) .bind(created_before_as_secs) .fetch_all(&mut *conn) .await?; if delete_db_ids.is_empty() { return Ok(0); } let mut deleted_db_ids = Vec::with_capacity(delete_db_ids.len()); let mut command = String::new(); for db_id in delete_db_ids { command.clear(); let db_name = db_name(db_id); writeln!(command, "drop database if exists {db_name}").ok(); match conn.execute(&*command).await { Ok(_deleted) => { deleted_db_ids.push(db_id); } // Assume a database error just means the DB is still in use. 
Err(Error::Database(dbe)) => { eprintln!("could not clean test database {db_id:?}: {dbe}") } // Bubble up other errors Err(e) => return Err(e), } } let mut query = QueryBuilder::new("delete from _sqlx_test_databases where db_id in ("); let mut separated = query.separated(","); for db_id in &deleted_db_ids { separated.push_bind(db_id); } query.push(")").build().execute(&mut *conn).await?; Ok(deleted_db_ids.len()) } fn db_name(id: u64) -> String { format!("_sqlx_test_database_{id}") } fn db_id(name: &str) -> u64 { name.trim_start_matches("_sqlx_test_database_") .parse() .unwrap_or_else(|_1| panic!("failed to parse ID from database name {name:?}")) } #[test] fn test_db_name_id() { assert_eq!(db_name(12345), "_sqlx_test_database_12345"); assert_eq!(db_id("_sqlx_test_database_12345"), 12345); } sqlx-mysql-0.8.3/src/transaction.rs000064400000000000000000000037231046102023000154550ustar 00000000000000use futures_core::future::BoxFuture; use crate::connection::Waiting; use crate::error::Error; use crate::executor::Executor; use crate::protocol::text::Query; use crate::{MySql, MySqlConnection}; pub(crate) use sqlx_core::transaction::*; /// Implementation of [`TransactionManager`] for MySQL. 
pub struct MySqlTransactionManager;

impl TransactionManager for MySqlTransactionManager {
    type Database = MySql;

    // Executes `begin_ansi_transaction_sql(depth)` (from sqlx-core; presumably BEGIN at
    // depth 0 and a SAVEPOINT when nested — confirm there) and increments
    // `transaction_depth` only after it succeeds.
    fn begin(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(async move {
            let depth = conn.inner.transaction_depth;

            conn.execute(&*begin_ansi_transaction_sql(depth)).await?;
            conn.inner.transaction_depth = depth + 1;

            Ok(())
        })
    }

    // Commits the innermost level; a no-op when no transaction is open (depth 0).
    fn commit(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(async move {
            let depth = conn.inner.transaction_depth;

            if depth > 0 {
                conn.execute(&*commit_ansi_transaction_sql(depth)).await?;
                conn.inner.transaction_depth = depth - 1;
            }

            Ok(())
        })
    }

    // Rolls back the innermost level; a no-op when no transaction is open (depth 0).
    fn rollback(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(async move {
            let depth = conn.inner.transaction_depth;

            if depth > 0 {
                conn.execute(&*rollback_ansi_transaction_sql(depth)).await?;
                conn.inner.transaction_depth = depth - 1;
            }

            Ok(())
        })
    }

    // Synchronous variant: queues the ROLLBACK packet without awaiting its response.
    // It records `Waiting::Result` so the pending reply is expected later, resets the
    // packet sequence id for the new command, and decrements the depth immediately.
    fn start_rollback(conn: &mut MySqlConnection) {
        let depth = conn.inner.transaction_depth;

        if depth > 0 {
            conn.inner.stream.waiting.push_back(Waiting::Result);
            conn.inner.stream.sequence_id = 0;
            conn.inner
                .stream
                .write_packet(Query(&rollback_ansi_transaction_sql(depth)))
                .expect("BUG: unexpected error queueing ROLLBACK");

            conn.inner.transaction_depth = depth - 1;
        }
    }
}
sqlx-mysql-0.8.3/src/type_checking.rs000064400000000000000000000032331046102023000157400ustar 00000000000000// Type mappings used by the macros and `Debug` impls.
#[allow(unused_imports)]
use sqlx_core as sqlx;

use crate::MySql;

impl_type_checking!(
    MySql {
        u8,
        u16,
        u32,
        u64,
        i8,
        i16,
        i32,
        i64,
        f32,
        f64,

        // ordering is important here as otherwise we might infer strings to be binary

        // CHAR, VAR_CHAR, TEXT
        String,

        // BINARY, VAR_BINARY, BLOB
        Vec,

        // Types from third-party crates need to be referenced at a known path
        // for the macros to work, but we don't want to require the user to add extra dependencies.
        // When both `chrono` and `time` are enabled, only the `time` types are listed.
        #[cfg(all(feature = "chrono", not(feature = "time")))]
        sqlx::types::chrono::NaiveTime,

        #[cfg(all(feature = "chrono", not(feature = "time")))]
        sqlx::types::chrono::NaiveDate,

        #[cfg(all(feature = "chrono", not(feature = "time")))]
        sqlx::types::chrono::NaiveDateTime,

        #[cfg(all(feature = "chrono", not(feature = "time")))]
        sqlx::types::chrono::DateTime,

        #[cfg(feature = "time")]
        sqlx::types::time::Time,

        #[cfg(feature = "time")]
        sqlx::types::time::Date,

        #[cfg(feature = "time")]
        sqlx::types::time::PrimitiveDateTime,

        #[cfg(feature = "time")]
        sqlx::types::time::OffsetDateTime,

        #[cfg(feature = "bigdecimal")]
        sqlx::types::BigDecimal,

        #[cfg(feature = "rust_decimal")]
        sqlx::types::Decimal,

        #[cfg(feature = "json")]
        sqlx::types::JsonValue,
    },
    ParamChecking::Weak,
    feature-types: info => info.__type_feature_gate(),
);
sqlx-mysql-0.8.3/src/type_info.rs000064400000000000000000000065321046102023000151250ustar 00000000000000use std::fmt::{self, Display, Formatter};

pub(crate) use sqlx_core::type_info::*;

use crate::protocol::text::{ColumnDefinition, ColumnFlags, ColumnType};

/// Type information for a MySql type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct MySqlTypeInfo {
    pub(crate) r#type: ColumnType,
    pub(crate) flags: ColumnFlags,

    // [max_size] for integer types, this is (M) in BIT(M) or TINYINT(M)
    #[cfg_attr(feature = "offline", serde(default))]
    pub(crate) max_size: Option,
}

impl MySqlTypeInfo {
    /// Type info for `ty` with only the `BINARY` flag set and no size.
    pub(crate) const fn binary(ty: ColumnType) -> Self {
        Self {
            r#type: ty,
            flags: ColumnFlags::BINARY,
            max_size: None,
        }
    }

    #[doc(hidden)]
    pub const fn __enum() -> Self {
        // Newer versions of MySQL seem to expect that a parameter binding of `MYSQL_TYPE_ENUM`
        // means that the value is encoded as an integer.
        //
        // For "strong" enums inputted as strings, we need to specify this type instead
        // for wider compatibility. This works on all covered versions of MySQL and MariaDB.
        //
        // Annoyingly, MySQL's developer documentation doesn't really explain this anywhere;
        // this had to be determined experimentally.
        Self {
            r#type: ColumnType::String,
            flags: ColumnFlags::ENUM,
            max_size: None,
        }
    }

    // Names the cargo feature that provides Rust support for this type. It is consumed by
    // the `feature-types` hook of `impl_type_checking!`. Date/time types always report
    // "time", even though `chrono` types are also mapped when only `chrono` is enabled.
    #[doc(hidden)]
    pub fn __type_feature_gate(&self) -> Option<&'static str> {
        match self.r#type {
            ColumnType::Date | ColumnType::Time | ColumnType::Timestamp | ColumnType::Datetime => {
                Some("time")
            }

            ColumnType::Json => Some("json"),
            ColumnType::NewDecimal => Some("bigdecimal"),

            _ => None,
        }
    }

    /// Builds type info from a column definition received from the server.
    pub(crate) fn from_column(column: &ColumnDefinition) -> Self {
        Self {
            r#type: column.r#type,
            flags: column.flags,
            max_size: Some(column.max_size),
        }
    }
}

impl Display for MySqlTypeInfo {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}

impl TypeInfo for MySqlTypeInfo {
    fn is_null(&self) -> bool {
        matches!(self.r#type, ColumnType::Null)
    }

    fn name(&self) -> &str {
        self.r#type.name(self.flags, self.max_size)
    }
}

// Equality never looks at `max_size`. Integer types compare only signedness, the
// string/blob/enum types compare the full flag set, and all other types compare only
// the type code.
impl PartialEq for MySqlTypeInfo {
    fn eq(&self, other: &MySqlTypeInfo) -> bool {
        if self.r#type != other.r#type {
            return false;
        }

        match self.r#type {
            ColumnType::Tiny
            | ColumnType::Short
            | ColumnType::Long
            | ColumnType::Int24
            | ColumnType::LongLong => {
                return self.flags.contains(ColumnFlags::UNSIGNED)
                    == other.flags.contains(ColumnFlags::UNSIGNED);
            }

            // for string types, check that our charset matches
            ColumnType::VarChar
            | ColumnType::Blob
            | ColumnType::TinyBlob
            | ColumnType::MediumBlob
            | ColumnType::LongBlob
            | ColumnType::String
            | ColumnType::VarString
            | ColumnType::Enum => {
                return self.flags == other.flags;
            }

            _ => {}
        }

        true
    }
}

impl Eq for MySqlTypeInfo {}
sqlx-mysql-0.8.3/src/types/bigdecimal.rs000064400000000000000000000015711046102023000163530ustar 00000000000000use bigdecimal::BigDecimal;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::io::MySqlBufMutExt;
use crate::protocol::text::ColumnType;
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};

impl
Type for BigDecimal { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::NewDecimal) } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!(ty.r#type, ColumnType::Decimal | ColumnType::NewDecimal) } } impl Encode<'_, MySql> for BigDecimal { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); Ok(IsNull::No) } } impl Decode<'_, MySql> for BigDecimal { fn decode(value: MySqlValueRef<'_>) -> Result { Ok(value.as_str()?.parse()?) } } sqlx-mysql-0.8.3/src/types/bool.rs000064400000000000000000000022741046102023000152270ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{ protocol::text::{ColumnFlags, ColumnType}, MySql, MySqlTypeInfo, MySqlValueRef, }; impl Type for bool { fn type_info() -> MySqlTypeInfo { // MySQL has no actual `BOOLEAN` type, the type is an alias of `TINYINT(1)` MySqlTypeInfo { flags: ColumnFlags::BINARY | ColumnFlags::UNSIGNED, max_size: Some(1), r#type: ColumnType::Tiny, } } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!( ty.r#type, ColumnType::Tiny | ColumnType::Short | ColumnType::Long | ColumnType::Int24 | ColumnType::LongLong | ColumnType::Bit ) } } impl Encode<'_, MySql> for bool { fn encode_by_ref(&self, buf: &mut Vec) -> Result { >::encode(*self as i8, buf) } } impl Decode<'_, MySql> for bool { fn decode(value: MySqlValueRef<'_>) -> Result { Ok(>::decode(value)? 
!= 0) } } sqlx-mysql-0.8.3/src/types/bytes.rs000064400000000000000000000043651046102023000154250ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::io::MySqlBufMutExt; use crate::protocol::text::ColumnType; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; impl Type for [u8] { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Blob) } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!( ty.r#type, ColumnType::VarChar | ColumnType::Blob | ColumnType::TinyBlob | ColumnType::MediumBlob | ColumnType::LongBlob | ColumnType::String | ColumnType::VarString | ColumnType::Enum ) } } impl Encode<'_, MySql> for &'_ [u8] { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_bytes_lenenc(self); Ok(IsNull::No) } } impl<'r> Decode<'r, MySql> for &'r [u8] { fn decode(value: MySqlValueRef<'r>) -> Result { value.as_bytes() } } impl Type for Box<[u8]> { fn type_info() -> MySqlTypeInfo { <&[u8] as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&[u8] as Type>::compatible(ty) } } impl Encode<'_, MySql> for Box<[u8]> { fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&[u8] as Encode>::encode(self.as_ref(), buf) } } impl<'r> Decode<'r, MySql> for Box<[u8]> { fn decode(value: MySqlValueRef<'r>) -> Result { <&[u8] as Decode>::decode(value).map(Box::from) } } impl Type for Vec { fn type_info() -> MySqlTypeInfo { <[u8] as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&[u8] as Type>::compatible(ty) } } impl Encode<'_, MySql> for Vec { fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&[u8] as Encode>::encode(&**self, buf) } } impl Decode<'_, MySql> for Vec { fn decode(value: MySqlValueRef<'_>) -> Result { <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) } } sqlx-mysql-0.8.3/src/types/chrono.rs000064400000000000000000000251641046102023000155670ustar 00000000000000use bytes::Buf; use chrono::{ DateTime, Datelike, Local, 
NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc, }; use sqlx_core::database::Database; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::protocol::text::ColumnType; use crate::type_info::MySqlTypeInfo; use crate::types::{MySqlTime, MySqlTimeSign, Type}; use crate::{MySql, MySqlValueFormat, MySqlValueRef}; impl Type for DateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Timestamp) } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) } } /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl Encode<'_, MySql> for DateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { Encode::::encode(self.naive_utc(), buf) } } /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl<'r> Decode<'r, MySql> for DateTime { fn decode(value: MySqlValueRef<'r>) -> Result { let naive: NaiveDateTime = Decode::::decode(value)?; Ok(Utc.from_utc_datetime(&naive)) } } impl Type for DateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Timestamp) } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) } } /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl Encode<'_, MySql> for DateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { Encode::::encode(self.naive_utc(), buf) } } /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). 
impl<'r> Decode<'r, MySql> for DateTime { fn decode(value: MySqlValueRef<'r>) -> Result { Ok( as Decode<'r, MySql>>::decode(value)?.with_timezone(&Local)) } } impl Type for NaiveTime { fn type_info() -> MySqlTypeInfo { MySqlTime::type_info() } } impl Encode<'_, MySql> for NaiveTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = naive_time_encoded_len(self); buf.push(len); // NaiveTime is not negative buf.push(0); // Number of days in the interval; always 0 for time-of-day values. // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding buf.extend_from_slice(&[0_u8; 4]); encode_time(self, len > 8, buf); Ok(IsNull::No) } fn size_hint(&self) -> usize { naive_time_encoded_len(self) as usize + 1 // plus length byte } } /// Decode from a `TIME` value. /// /// ### Errors /// Returns an error if the `TIME` value is negative or exceeds `23:59:59.999999`. impl<'r> Decode<'r, MySql> for NaiveTime { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { // Covers most possible failure modes. MySqlTime::decode(value)?.try_into() } // Retaining this parsing for now as it allows us to cross-check our impl. 
MySqlValueFormat::Text => { let s = value.as_str()?; NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into) } } } } impl TryFrom for NaiveTime { type Error = BoxDynError; fn try_from(time: MySqlTime) -> Result { NaiveTime::from_hms_micro_opt( time.hours(), time.minutes() as u32, time.seconds() as u32, time.microseconds(), ) .ok_or_else(|| format!("Cannot convert `MySqlTime` value to `NaiveTime`: {time}").into()) } } impl From for chrono::TimeDelta { fn from(time: MySqlTime) -> Self { chrono::TimeDelta::new(time.whole_seconds_signed(), time.subsec_nanos()) .expect("BUG: chrono::TimeDelta should have a greater range than MySqlTime") } } impl TryFrom for MySqlTime { type Error = BoxDynError; fn try_from(value: chrono::TimeDelta) -> Result { let sign = if value < chrono::TimeDelta::zero() { MySqlTimeSign::Negative } else { MySqlTimeSign::Positive }; Ok( // `std::time::Duration` has a greater positive range than `TimeDelta` // which makes it a great intermediate if you ignore the sign. MySqlTime::try_from(value.abs().to_std()?)?.with_sign(sign), ) } } impl Type for chrono::TimeDelta { fn type_info() -> MySqlTypeInfo { MySqlTime::type_info() } } impl<'r> Decode<'r, MySql> for chrono::TimeDelta { fn decode(value: ::ValueRef<'r>) -> Result { Ok(MySqlTime::decode(value)?.into()) } } impl Type for NaiveDate { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Date) } } impl Encode<'_, MySql> for NaiveDate { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(4); encode_date(self, buf)?; Ok(IsNull::No) } fn size_hint(&self) -> usize { 5 } } impl<'r> Decode<'r, MySql> for NaiveDate { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; // Row decoding should have left the length prefix. 
if buf.is_empty() { return Err("empty buffer".into()); } decode_date(&buf[1..])?.ok_or_else(|| UnexpectedNullError.into()) } MySqlValueFormat::Text => { let s = value.as_str()?; NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Into::into) } } } } impl Type for NaiveDateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Datetime) } } impl Encode<'_, MySql> for NaiveDateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = naive_dt_encoded_len(self); buf.push(len); encode_date(&self.date(), buf)?; if len > 4 { encode_time(&self.time(), len > 7, buf); } Ok(IsNull::No) } fn size_hint(&self) -> usize { naive_dt_encoded_len(self) as usize + 1 // plus length byte } } impl<'r> Decode<'r, MySql> for NaiveDateTime { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; if buf.is_empty() { return Err("empty buffer".into()); } let len = buf[0]; let date = decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?; let dt = if len > 4 { date.and_time(decode_time(len - 4, &buf[5..])?) } else { date.and_hms_opt(0, 0, 0) .expect("expected `NaiveDate::and_hms_opt(0, 0, 0)` to be valid") }; Ok(dt) } MySqlValueFormat::Text => { let s = value.as_str()?; NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Into::into) } } } } fn encode_date(date: &NaiveDate, buf: &mut Vec) -> Result<(), BoxDynError> { // MySQL supports years 1000 - 9999 let year = u16::try_from(date.year()) .map_err(|_| format!("NaiveDateTime out of range for Mysql: {date}"))?; buf.extend_from_slice(&year.to_le_bytes()); // `NaiveDate` guarantees the ranges of these values #[allow(clippy::cast_possible_truncation)] { buf.push(date.month() as u8); buf.push(date.day() as u8); } Ok(()) } fn decode_date(mut buf: &[u8]) -> Result, BoxDynError> { match buf.len() { // MySQL specifies that if there are no bytes, this is all zeros 0 => Ok(None), 4.. 
=> { let year = buf.get_u16_le() as i32; let month = buf[0] as u32; let day = buf[1] as u32; let date = NaiveDate::from_ymd_opt(year, month, day) .ok_or_else(|| format!("server returned invalid date: {year}/{month}/{day}"))?; Ok(Some(date)) } len => Err(format!("expected at least 4 bytes for date, got {len}").into()), } } fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec) { // `NaiveTime` API guarantees the ranges of these values #[allow(clippy::cast_possible_truncation)] { buf.push(time.hour() as u8); buf.push(time.minute() as u8); buf.push(time.second() as u8); } if include_micros { buf.extend((time.nanosecond() / 1000).to_le_bytes()); } } fn decode_time(len: u8, mut buf: &[u8]) -> Result { let hour = buf.get_u8(); let minute = buf.get_u8(); let seconds = buf.get_u8(); let micros = if len > 3 { // microseconds : int buf.get_uint_le(buf.len()) } else { 0 }; let micros = u32::try_from(micros) .map_err(|_| format!("server returned microseconds out of range: {micros}"))?; NaiveTime::from_hms_micro_opt(hour as u32, minute as u32, seconds as u32, micros) .ok_or_else(|| format!("server returned invalid time: {hour:02}:{minute:02}:{seconds:02}; micros: {micros}").into()) } #[inline(always)] fn naive_dt_encoded_len(time: &NaiveDateTime) -> u8 { // to save space the packet can be compressed: match ( time.hour(), time.minute(), time.second(), #[allow(deprecated)] time.timestamp_subsec_nanos(), ) { // if hour, minutes, seconds and micro_seconds are all 0, // length is 4 and no other field is sent (0, 0, 0, 0) => 4, // if micro_seconds is 0, length is 7 // and micro_seconds is not sent (_, _, _, 0) => 7, // otherwise length is 11 (_, _, _, _) => 11, } } #[inline(always)] fn naive_time_encoded_len(time: &NaiveTime) -> u8 { if time.nanosecond() == 0 { // if micro_seconds is 0, length is 8 and micro_seconds is not sent 8 } else { // otherwise length is 12 12 } } sqlx-mysql-0.8.3/src/types/float.rs000064400000000000000000000072621046102023000154030ustar 
00000000000000use byteorder::{ByteOrder, LittleEndian}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::protocol::text::ColumnType; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; fn real_compatible(ty: &MySqlTypeInfo) -> bool { // NOTE: `DECIMAL` is explicitly excluded because floating-point numbers have different semantics. matches!(ty.r#type, ColumnType::Float | ColumnType::Double) } impl Type for f32 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Float) } fn compatible(ty: &MySqlTypeInfo) -> bool { real_compatible(ty) } } impl Type for f64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Double) } fn compatible(ty: &MySqlTypeInfo) -> bool { real_compatible(ty) } } impl Encode<'_, MySql> for f32 { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); Ok(IsNull::No) } } impl Encode<'_, MySql> for f64 { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); Ok(IsNull::No) } } impl Decode<'_, MySql> for f32 { fn decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; match buf.len() { // These functions panic if `buf` is not exactly the right size. 4 => LittleEndian::read_f32(buf), // MySQL can return 8-byte DOUBLE values for a FLOAT // We take and truncate to f32 as that's the same behavior as *in* MySQL, #[allow(clippy::cast_possible_truncation)] 8 => LittleEndian::read_f64(buf) as f32, other => { // Users may try to decode a DECIMAL as floating point; // inform them why that's a bad idea. 
return Err(format!( "expected a FLOAT as 4 or 8 bytes, got {other} bytes; \ note that decoding DECIMAL as `f32` is not supported \ due to differing semantics" ) .into()); } } } MySqlValueFormat::Text => value.as_str()?.parse()?, }) } } impl Decode<'_, MySql> for f64 { fn decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; // The `read_*` functions panic if `buf` is not exactly the right size. match buf.len() { // Allow implicit widening here 4 => LittleEndian::read_f32(buf) as f64, 8 => LittleEndian::read_f64(buf), other => { // Users may try to decode a DECIMAL as floating point; // inform them why that's a bad idea. return Err(format!( "expected a DOUBLE as 4 or 8 bytes, got {other} bytes; \ note that decoding DECIMAL as `f64` is not supported \ due to differing semantics" ) .into()); } } } MySqlValueFormat::Text => value.as_str()?.parse()?, }) } } sqlx-mysql-0.8.3/src/types/inet.rs000064400000000000000000000046341046102023000152350ustar 00000000000000use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::io::MySqlBufMutExt; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; impl Type for Ipv4Addr { fn type_info() -> MySqlTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Encode<'_, MySql> for Ipv4Addr { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); Ok(IsNull::No) } } impl Decode<'_, MySql> for Ipv4Addr { fn decode(value: MySqlValueRef<'_>) -> Result { // delegate to the &str type to decode from MySQL let text = <&str as Decode>::decode(value)?; // parse a Ipv4Addr from the text text.parse().map_err(Into::into) } } impl Type for Ipv6Addr { fn type_info() -> MySqlTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&str as 
Type>::compatible(ty) } } impl Encode<'_, MySql> for Ipv6Addr { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); Ok(IsNull::No) } } impl Decode<'_, MySql> for Ipv6Addr { fn decode(value: MySqlValueRef<'_>) -> Result { // delegate to the &str type to decode from MySQL let text = <&str as Decode>::decode(value)?; // parse a Ipv6Addr from the text text.parse().map_err(Into::into) } } impl Type for IpAddr { fn type_info() -> MySqlTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Encode<'_, MySql> for IpAddr { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); Ok(IsNull::No) } } impl Decode<'_, MySql> for IpAddr { fn decode(value: MySqlValueRef<'_>) -> Result { // delegate to the &str type to decode from MySQL let text = <&str as Decode>::decode(value)?; // parse a IpAddr from the text text.parse().map_err(Into::into) } } sqlx-mysql-0.8.3/src/types/int.rs000064400000000000000000000066641046102023000150750ustar 00000000000000use byteorder::{ByteOrder, LittleEndian}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::protocol::text::{ColumnFlags, ColumnType}; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; fn int_compatible(ty: &MySqlTypeInfo) -> bool { matches!( ty.r#type, ColumnType::Tiny | ColumnType::Short | ColumnType::Long | ColumnType::Int24 | ColumnType::LongLong ) && !ty.flags.contains(ColumnFlags::UNSIGNED) } impl Type for i8 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Tiny) } fn compatible(ty: &MySqlTypeInfo) -> bool { int_compatible(ty) } } impl Type for i16 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Short) } fn compatible(ty: &MySqlTypeInfo) -> bool { int_compatible(ty) } } impl Type for i32 { fn type_info() -> MySqlTypeInfo { 
MySqlTypeInfo::binary(ColumnType::Long) } fn compatible(ty: &MySqlTypeInfo) -> bool { int_compatible(ty) } } impl Type for i64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::LongLong) } fn compatible(ty: &MySqlTypeInfo) -> bool { int_compatible(ty) } } impl Encode<'_, MySql> for i8 { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); Ok(IsNull::No) } } impl Encode<'_, MySql> for i16 { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); Ok(IsNull::No) } } impl Encode<'_, MySql> for i32 { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); Ok(IsNull::No) } } impl Encode<'_, MySql> for i64 { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.extend(&self.to_le_bytes()); Ok(IsNull::No) } } fn int_decode(value: MySqlValueRef<'_>) -> Result { Ok(match value.format() { MySqlValueFormat::Text => value.as_str()?.parse()?, MySqlValueFormat::Binary => { let buf = value.as_bytes()?; // Check conditions that could cause `read_int()` to panic. 
if buf.is_empty() { return Err("empty buffer".into()); } if buf.len() > 8 { return Err(format!( "expected no more than 8 bytes for integer value, got {}", buf.len() ) .into()); } LittleEndian::read_int(buf, buf.len()) } }) } impl Decode<'_, MySql> for i8 { fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for i16 { fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for i32 { fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value)?.try_into().map_err(Into::into) } } impl Decode<'_, MySql> for i64 { fn decode(value: MySqlValueRef<'_>) -> Result { int_decode(value) } } sqlx-mysql-0.8.3/src/types/json.rs000064400000000000000000000051421046102023000152420ustar 00000000000000use serde::{Deserialize, Serialize}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::protocol::text::ColumnType; use crate::types::{Json, Type}; use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; impl Type for Json { fn type_info() -> MySqlTypeInfo { // MySql uses the `CHAR` type to pass JSON data from and to the client // NOTE: This is forwards-compatible with MySQL v8+ as CHAR is a common transmission format // and has nothing to do with the native storage ability of MySQL v8+ MySqlTypeInfo::binary(ColumnType::String) } fn compatible(ty: &MySqlTypeInfo) -> bool { ty.r#type == ColumnType::Json || <&str as Type>::compatible(ty) || <&[u8] as Type>::compatible(ty) } } impl Encode<'_, MySql> for Json where T: Serialize, { fn encode_by_ref(&self, buf: &mut Vec) -> Result { // Encode JSON as a length-prefixed string. // // The previous implementation encoded into an intermediate buffer to get the final length. // This is because the length prefix for the string is itself length-encoded, so we have // to know the length first before we can start encoding in the buffer... or do we? 
// // The docs suggest that the integer length-encoding doesn't actually enforce a range on // the value itself as long as it fits in the chosen encoding, so why not just choose // the full length encoding to begin with? Then we can just reserve the space up-front // and encode directly into the buffer. // // If someone is storing a JSON value it's likely large enough that the overhead of using // the full-length integer encoding doesn't really matter. And if it's so large it overflows // a `u64` then the process is likely to run OOM during the encoding process first anyway. let lenenc_start = buf.len(); buf.extend_from_slice(&[0u8; 9]); let encode_start = buf.len(); self.encode_to(buf)?; let encoded_len = (buf.len() - encode_start) as u64; // This prefix indicates that the following 8 bytes are a little-endian integer. buf[lenenc_start] = 0xFE; buf[lenenc_start + 1..][..8].copy_from_slice(&encoded_len.to_le_bytes()); Ok(IsNull::No) } } impl<'r, T> Decode<'r, MySql> for Json where T: 'r + Deserialize<'r>, { fn decode(value: MySqlValueRef<'r>) -> Result { Json::decode_from_string(value.as_str()?) } } sqlx-mysql-0.8.3/src/types/mod.rs000064400000000000000000000237601046102023000150560ustar 00000000000000//! Conversions between Rust and **MySQL/MariaDB** types. //! //! # Types //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `bool` | TINYINT(1), BOOLEAN, BOOL (see below) | //! | `i8` | TINYINT | //! | `i16` | SMALLINT | //! | `i32` | INT | //! | `i64` | BIGINT | //! | `u8` | TINYINT UNSIGNED | //! | `u16` | SMALLINT UNSIGNED | //! | `u32` | INT UNSIGNED | //! | `u64` | BIGINT UNSIGNED | //! | `f32` | FLOAT | //! | `f64` | DOUBLE | //! | `&str`, [`String`] | VARCHAR, CHAR, TEXT | //! | `&[u8]`, `Vec` | VARBINARY, BINARY, BLOB | //! | `IpAddr` | VARCHAR, TEXT | //! | `Ipv4Addr` | INET4 (MariaDB-only), VARCHAR, TEXT | //! 
| `Ipv6Addr` | INET6 (MariaDB-only), VARCHAR, TEXT | //! | [`MySqlTime`] | TIME (encode and decode full range) | //! | [`Duration`][std::time::Duration] | TIME (for decoding positive values only) | //! //! ##### Note: `BOOLEAN`/`BOOL` Type //! MySQL and MariaDB treat `BOOLEAN` as an alias of the `TINYINT` type: //! //! * [Using Data Types from Other Database Engines (MySQL)](https://dev.mysql.com/doc/refman/8.0/en/other-vendor-data-types.html) //! * [BOOLEAN (MariaDB)](https://mariadb.com/kb/en/boolean/) //! //! For the most part, you can simply use the Rust type `bool` when encoding or decoding a value //! using the dynamic query interface, or passing a boolean as a parameter to the query macros //! (`query!()` _et al._). //! //! However, because the MySQL wire protocol does not distinguish between `TINYINT` and `BOOLEAN`, //! the query macros cannot know that a `TINYINT` column is semantically a boolean. //! By default, they will map a `TINYINT` column as `i8` instead, as that is the safer assumption. //! //! Thus, you must use the type override syntax in the query to tell the macros you are expecting //! a `bool` column. See the docs for `query!()` and `query_as!()` for details on this syntax. //! //! ### NOTE: MySQL's `TIME` type is signed //! MySQL's `TIME` type can be used as either a time-of-day value, or a signed interval. //! Thus, it may take on negative values. //! //! Decoding a [`std::time::Duration`] returns an error if the `TIME` value is negative. //! //! ### [`chrono`](https://crates.io/crates/chrono) //! //! Requires the `chrono` Cargo feature flag. //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `chrono::DateTime` | TIMESTAMP | //! | `chrono::DateTime` | TIMESTAMP | //! | `chrono::NaiveDateTime` | DATETIME | //! | `chrono::NaiveDate` | DATE | //! | `chrono::NaiveTime` | TIME (time-of-day only) | //! 
| `chrono::TimeDelta` | TIME (decodes full range; see note for encoding) | //! //! ### NOTE: MySQL's `TIME` type is dual-purpose //! MySQL's `TIME` type can be used as either a time-of-day value, or an interval. //! However, `chrono::NaiveTime` is designed only to represent a time-of-day. //! //! Decoding a `TIME` value as `chrono::NaiveTime` will return an error if the value is out of range. //! //! The [`MySqlTime`] type supports the full range and it also implements `TryInto`. //! //! Decoding a `chrono::TimeDelta` also supports the full range. //! //! To encode a `chrono::TimeDelta`, convert it to [`MySqlTime`] first using `TryFrom`/`TryInto`. //! //! ### [`time`](https://crates.io/crates/time) //! //! Requires the `time` Cargo feature flag. //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `time::PrimitiveDateTime` | DATETIME | //! | `time::OffsetDateTime` | TIMESTAMP | //! | `time::Date` | DATE | //! | `time::Time` | TIME (time-of-day only) | //! | `time::Duration` | TIME (decodes full range; see note for encoding) | //! //! ### NOTE: MySQL's `TIME` type is dual-purpose //! MySQL's `TIME` type can be used as either a time-of-day value, or an interval. //! However, `time::Time` is designed only to represent a time-of-day. //! //! Decoding a `TIME` value as `time::Time` will return an error if the value is out of range. //! //! The [`MySqlTime`] type supports the full range, and it also implements `TryInto`. //! //! Decoding a `time::Duration` also supports the full range. //! //! To encode a `time::Duration`, convert it to [`MySqlTime`] first using `TryFrom`/`TryInto`. //! //! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) //! Requires the `bigdecimal` Cargo feature flag. //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! 
| `bigdecimal::BigDecimal` | DECIMAL | //! //! ### [`decimal`](https://crates.io/crates/rust_decimal) //! Requires the `decimal` Cargo feature flag. //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `rust_decimal::Decimal` | DECIMAL | //! //! ### [`uuid`](https://crates.io/crates/uuid) //! //! Requires the `uuid` Cargo feature flag. //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `uuid::Uuid` | BINARY(16) (see note) | //! | `uuid::fmt::Hyphenated` | CHAR(36), VARCHAR, TEXT, UUID (MariaDB-only) | //! | `uuid::fmt::Simple` | CHAR(32), VARCHAR, TEXT | //! //! #### Note: `Uuid` uses binary format //! //! MySQL does not have a native datatype for UUIDs. //! The `UUID()` function returns a 36-character `TEXT` value, //! which encourages storing UUIDs as text. //! //! MariaDB's `UUID` type stores and retrieves as text, though it has a better representation //! for index sorting (see [MariaDB manual: UUID data-type][mariadb-uuid] for details). //! //! As an opinionated library, SQLx chose to map `uuid::Uuid` to/from binary format by default //! (16 bytes, the raw value of a UUID; SQL type `BINARY(16)`). //! This saves 20 bytes over the text format for each value. //! //! The `impl Decode for Uuid` does not support the text format, and will return an error. //! //! If you want to use the text format compatible with the `UUID()` function, //! use [`uuid::fmt::Hyphenated`][::uuid::fmt::Hyphenated] in the place of `Uuid`. //! //! The MySQL official blog has an article showing how to support both binary and text format UUIDs //! by storing the binary and adding a generated column for the text format, though this is rather //! verbose and fiddly: //! //! [mariadb-uuid]: https://mariadb.com/kb/en/uuid-data-type/ //! //! ### [`json`](https://crates.io/crates/serde_json) //! 
//! Requires the `json` Cargo feature flag. //! //! | Rust type | MySQL/MariaDB type(s) | //! |---------------------------------------|------------------------------------------------------| //! | [`Json`] | JSON | //! | `serde_json::JsonValue` | JSON | //! | `&serde_json::value::RawValue` | JSON | //! //! # Nullable //! //! In addition, `Option` is supported where `T` implements `Type`. An `Option` represents //! a potentially `NULL` value from MySQL/MariaDB. pub(crate) use sqlx_core::types::*; pub use mysql_time::{MySqlTime, MySqlTimeError, MySqlTimeSign}; mod bool; mod bytes; mod float; mod inet; mod int; mod mysql_time; mod str; mod text; mod uint; #[cfg(feature = "json")] mod json; #[cfg(feature = "bigdecimal")] mod bigdecimal; #[cfg(feature = "rust_decimal")] mod rust_decimal; #[cfg(feature = "chrono")] mod chrono; #[cfg(feature = "time")] mod time; #[cfg(feature = "uuid")] mod uuid; sqlx-mysql-0.8.3/src/types/mysql_time.rs000064400000000000000000000555201046102023000164610ustar 00000000000000//! The [`MysqlTime`] type. use crate::protocol::text::ColumnType; use crate::{MySql, MySqlTypeInfo, MySqlValueFormat}; use bytes::{Buf, BufMut}; use sqlx_core::database::Database; use sqlx_core::decode::Decode; use sqlx_core::encode::{Encode, IsNull}; use sqlx_core::error::BoxDynError; use sqlx_core::types::Type; use std::cmp::Ordering; use std::fmt::{Debug, Display, Formatter, Write}; use std::time::Duration; // Similar to `PgInterval` /// Container for a MySQL `TIME` value, which may be an interval or a time-of-day. /// /// Allowed range is `-838:59:59.0` to `838:59:59.0`. /// /// If this value is used for a time-of-day, the range should be `00:00:00.0` to `23:59:59.999999`. /// You can use [`Self::is_valid_time_of_day()`] to check this easily. 
/// /// * [MySQL Manual 13.2.3: The TIME Type](https://dev.mysql.com/doc/refman/8.3/en/time.html) /// * [MariaDB Manual: TIME](https://mariadb.com/kb/en/time/) #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct MySqlTime { pub(crate) sign: MySqlTimeSign, pub(crate) magnitude: TimeMagnitude, } // By using a subcontainer for the actual time magnitude, // we can still use a derived `Ord` implementation and just flip the comparison for negative values. #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub(crate) struct TimeMagnitude { pub(crate) hours: u32, pub(crate) minutes: u8, pub(crate) seconds: u8, pub(crate) microseconds: u32, } const MAGNITUDE_ZERO: TimeMagnitude = TimeMagnitude { hours: 0, minutes: 0, seconds: 0, microseconds: 0, }; /// Maximum magnitude (positive or negative). const MAGNITUDE_MAX: TimeMagnitude = TimeMagnitude { hours: MySqlTime::HOURS_MAX, minutes: 59, seconds: 59, // Surprisingly this is not 999_999 which is why `MySqlTimeError::SubsecondExcess`. microseconds: 0, }; /// The sign for a [`MySqlTime`] type. #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] pub enum MySqlTimeSign { // The protocol actually specifies negative as 1 and positive as 0, // but by specifying variants this way we can derive `Ord` and it works as expected. /// The interval is negative (invalid for time-of-day values). Negative, /// The interval is positive, or represents a time-of-day. Positive, } /// Errors returned by [`MySqlTime::new()`]. #[derive(Debug, thiserror::Error)] pub enum MySqlTimeError { /// A field of [`MySqlTime`] exceeded its max range. #[error("`MySqlTime` field `{field}` cannot exceed {max}, got {value}")] FieldRange { field: &'static str, max: u32, value: u64, }, /// Error returned for time magnitudes (positive or negative) between `838:59:59.0` and `839:00:00.0`. /// /// Other range errors should be covered by [`Self::FieldRange`] for the `hours` field. 
/// /// For applications which can tolerate rounding, a valid truncated value is provided. #[error( "`MySqlTime` cannot exceed +/-838:59:59.000000; got {sign}838:59:59.{microseconds:06}" )] SubsecondExcess { /// The sign of the magnitude. sign: MySqlTimeSign, /// The number of microseconds over the maximum. microseconds: u32, /// The truncated value, /// either [`MySqlTime::MIN`] if negative or [`MySqlTime::MAX`] if positive. truncated: MySqlTime, }, /// MySQL coerces `-00:00:00` to `00:00:00` but this API considers that an error. /// /// For applications which can tolerate coercion, you can convert this error to [`MySqlTime::ZERO`]. #[error("attempted to construct a `MySqlTime` value of negative zero")] NegativeZero, } impl MySqlTime { /// The `MySqlTime` value corresponding to `TIME '0:00:00.0'` (zero). pub const ZERO: Self = MySqlTime { sign: MySqlTimeSign::Positive, magnitude: MAGNITUDE_ZERO, }; /// The `MySqlTime` value corresponding to `TIME '838:59:59.0'` (max value). pub const MAX: Self = MySqlTime { sign: MySqlTimeSign::Positive, magnitude: MAGNITUDE_MAX, }; /// The `MySqlTime` value corresponding to `TIME '-838:59:59.0'` (min value). pub const MIN: Self = MySqlTime { sign: MySqlTimeSign::Negative, // Same magnitude, opposite sign. magnitude: MAGNITUDE_MAX, }; // The maximums for the other values are self-evident, but not necessarily this one. pub(crate) const HOURS_MAX: u32 = 838; /// Construct a [`MySqlTime`] that is valid for use as a `TIME` value. /// /// ### Errors /// * [`MySqlTimeError::NegativeZero`] if all fields are 0 but `sign` is [`MySqlTimeSign::Negative`]. /// * [`MySqlTimeError::FieldRange`] if any field is out of range: /// * `hours > 838` /// * `minutes > 59` /// * `seconds > 59` /// * `microseconds > 999_999` /// * [`MySqlTimeError::SubsecondExcess`] if the magnitude is less than one second over the maximum. /// * Durations 839 hours or greater are covered by `FieldRange`. 
pub fn new( sign: MySqlTimeSign, hours: u32, minutes: u8, seconds: u8, microseconds: u32, ) -> Result { macro_rules! check_fields { ($($name:ident: $max:expr),+ $(,)?) => { $( if $name > $max { return Err(MySqlTimeError::FieldRange { field: stringify!($name), max: $max as u32, value: $name as u64 }) } )+ } } check_fields!( hours: Self::HOURS_MAX, minutes: 59, seconds: 59, microseconds: 999_999 ); let values = TimeMagnitude { hours, minutes, seconds, microseconds, }; if sign.is_negative() && values == MAGNITUDE_ZERO { return Err(MySqlTimeError::NegativeZero); } // This is only `true` if less than 1 second over the maximum magnitude if values > MAGNITUDE_MAX { return Err(MySqlTimeError::SubsecondExcess { sign, microseconds, truncated: if sign.is_positive() { Self::MAX } else { Self::MIN }, }); } Ok(Self { sign, magnitude: values, }) } /// Update the `sign` of this value. pub fn with_sign(self, sign: MySqlTimeSign) -> Self { Self { sign, ..self } } /// Return the sign (positive or negative) for this TIME value. pub fn sign(&self) -> MySqlTimeSign { self.sign } /// Returns `true` if `self` is zero (equal to [`Self::ZERO`]). pub fn is_zero(&self) -> bool { self == &Self::ZERO } /// Returns `true` if `self` is positive or zero, `false` if negative. pub fn is_positive(&self) -> bool { self.sign.is_positive() } /// Returns `true` if `self` is negative, `false` if positive or zero. pub fn is_negative(&self) -> bool { self.sign.is_positive() } /// Returns `true` if this interval is a valid time-of-day. /// /// If `true`, the sign is positive and `hours` is not greater than 23. pub fn is_valid_time_of_day(&self) -> bool { self.sign.is_positive() && self.hours() < 24 } /// Get the total number of hours in this interval, from 0 to 838. /// /// If this value represents a time-of-day, the range is 0 to 23. pub fn hours(&self) -> u32 { self.magnitude.hours } /// Get the number of minutes in this interval, from 0 to 59. 
pub fn minutes(&self) -> u8 { self.magnitude.minutes } /// Get the number of seconds in this interval, from 0 to 59. pub fn seconds(&self) -> u8 { self.magnitude.seconds } /// Get the number of seconds in this interval, from 0 to 999,999. pub fn microseconds(&self) -> u32 { self.magnitude.microseconds } /// Convert this TIME value to a [`std::time::Duration`]. /// /// Returns `None` if this value is negative (cannot be represented). pub fn to_duration(&self) -> Option { self.is_positive() .then(|| Duration::new(self.whole_seconds() as u64, self.subsec_nanos())) } /// Get the whole number of seconds (`seconds + (minutes * 60) + (hours * 3600)`) in this time. /// /// Sign is ignored. pub(crate) fn whole_seconds(&self) -> u32 { // If `hours` does not exceed 838 then this cannot overflow. self.hours() * 3600 + self.minutes() as u32 * 60 + self.seconds() as u32 } #[cfg_attr(not(any(feature = "time", feature = "chrono")), allow(dead_code))] pub(crate) fn whole_seconds_signed(&self) -> i64 { self.whole_seconds() as i64 * self.sign.signum() as i64 } pub(crate) fn subsec_nanos(&self) -> u32 { self.microseconds() * 1000 } fn encoded_len(&self) -> u8 { if self.is_zero() { 0 } else if self.microseconds() == 0 { 8 } else { 12 } } } impl PartialOrd for MySqlTime { fn partial_cmp(&self, other: &MySqlTime) -> Option { Some(self.cmp(other)) } } impl Ord for MySqlTime { fn cmp(&self, other: &Self) -> Ordering { // If the sides have different signs, we just need to compare those. 
if self.sign != other.sign { return self.sign.cmp(&other.sign); } // We've checked that both sides have the same sign match self.sign { MySqlTimeSign::Positive => self.magnitude.cmp(&other.magnitude), // Reverse the comparison for negative values (smaller negative magnitude = greater) MySqlTimeSign::Negative => other.magnitude.cmp(&self.magnitude), } } } impl Display for MySqlTime { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let TimeMagnitude { hours, minutes, seconds, microseconds, } = self.magnitude; // Obeys the `+` flag. Display::fmt(&self.sign(), f)?; write!(f, "{hours}:{minutes:02}:{seconds:02}")?; // Write microseconds if not zero or a nonzero precision was explicitly requested. if f.precision().map_or(microseconds != 0, |it| it != 0) { f.write_char('.')?; let mut remaining_precision = f.precision(); let mut remainder = microseconds; let mut power_of_10 = 10u32.pow(5); // Write digits from most-significant to least, up to the requested precision. while remainder > 0 && remaining_precision != Some(0) { let digit = remainder / power_of_10; // 1 % 1 = 0 remainder %= power_of_10; power_of_10 /= 10; write!(f, "{digit}")?; if let Some(remaining_precision) = &mut remaining_precision { *remaining_precision = remaining_precision.saturating_sub(1); } } // If any requested precision remains, pad with zeroes. if let Some(precision) = remaining_precision.filter(|it| *it != 0) { write!(f, "{:0precision$}", 0)?; } } Ok(()) } } impl Type for MySqlTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Time) } } impl<'r> Decode<'r, MySql> for MySqlTime { fn decode(value: ::ValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { let mut buf = value.as_bytes()?; // Row decoding should have left the length byte on the front. 
if buf.is_empty() { return Err("empty buffer".into()); } let length = buf.get_u8(); // MySQL specifies that if all fields are 0 then the length is 0 and no further data is sent // https://dev.mysql.com/doc/internals/en/binary-protocol-value.html if length == 0 { return Ok(Self::ZERO); } if !matches!(buf.len(), 8 | 12) { return Err(format!( "expected 8 or 12 bytes for TIME value, got {}", buf.len() ) .into()); } let sign = MySqlTimeSign::from_byte(buf.get_u8())?; // The wire protocol includes days but the text format doesn't. Isn't that crazy? let days = buf.get_u32_le(); let hours = buf.get_u8(); let minutes = buf.get_u8(); let seconds = buf.get_u8(); let microseconds = if !buf.is_empty() { buf.get_u32_le() } else { 0 }; let whole_hours = days .checked_mul(24) .and_then(|days_to_hours| days_to_hours.checked_add(hours as u32)) .ok_or("overflow calculating whole hours from `days * 24 + hours`")?; Ok(Self::new( sign, whole_hours, minutes, seconds, microseconds, )?) } MySqlValueFormat::Text => parse(value.as_str()?), } } } impl<'q> Encode<'q, MySql> for MySqlTime { fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<'q>, ) -> Result { if self.is_zero() { buf.put_u8(0); return Ok(IsNull::No); } buf.put_u8(self.encoded_len()); buf.put_u8(self.sign.to_byte()); let TimeMagnitude { hours: whole_hours, minutes, seconds, microseconds, } = self.magnitude; let days = whole_hours / 24; let hours = (whole_hours % 24) as u8; buf.put_u32_le(days); buf.put_u8(hours); buf.put_u8(minutes); buf.put_u8(seconds); if microseconds != 0 { buf.put_u32_le(microseconds); } Ok(IsNull::No) } fn size_hint(&self) -> usize { self.encoded_len() as usize + 1 } } /// Convert [`MySqlTime`] from [`std::time::Duration`]. /// /// ### Note: Precision Truncation /// [`Duration`] supports nanosecond precision, but MySQL `TIME` values only support microsecond /// precision. /// /// For simplicity, higher precision values are truncated when converting. 
/// If you prefer another rounding mode instead, you should apply that to the `Duration` first. /// /// See also: [MySQL Manual, section 13.2.6: Fractional Seconds in Time Values](https://dev.mysql.com/doc/refman/8.3/en/fractional-seconds.html) /// /// ### Errors: /// Returns [`MySqlTimeError::FieldRange`] if the given duration is longer than `838:59:59.999999`. /// impl TryFrom for MySqlTime { type Error = MySqlTimeError; fn try_from(value: Duration) -> Result { let hours = value.as_secs() / 3600; let rem_seconds = value.as_secs() % 3600; let minutes = (rem_seconds / 60) as u8; let seconds = (rem_seconds % 60) as u8; // Simply divides by 1000 let microseconds = value.subsec_micros(); Self::new( MySqlTimeSign::Positive, hours.try_into().map_err(|_| MySqlTimeError::FieldRange { field: "hours", max: Self::HOURS_MAX, value: hours, })?, minutes, seconds, microseconds, ) } } impl MySqlTimeSign { fn from_byte(b: u8) -> Result { match b { 0 => Ok(Self::Positive), 1 => Ok(Self::Negative), other => Err(format!("expected 0 or 1 for TIME sign byte, got {other}").into()), } } fn to_byte(self) -> u8 { match self { // We can't use `#[repr(u8)]` because this is opposite of the ordering we want from `Ord` Self::Negative => 1, Self::Positive => 0, } } fn signum(&self) -> i32 { match self { Self::Negative => -1, Self::Positive => 1, } } /// Returns `true` if positive, `false` if negative. pub fn is_positive(&self) -> bool { matches!(self, Self::Positive) } /// Returns `true` if negative, `false` if positive. 
pub fn is_negative(&self) -> bool { matches!(self, Self::Negative) } } impl Display for MySqlTimeSign { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Positive if f.sign_plus() => f.write_char('+'), Self::Negative => f.write_char('-'), _ => Ok(()), } } } impl Type for Duration { fn type_info() -> MySqlTypeInfo { MySqlTime::type_info() } } impl<'r> Decode<'r, MySql> for Duration { fn decode(value: ::ValueRef<'r>) -> Result { let time = MySqlTime::decode(value)?; time.to_duration().ok_or_else(|| { format!("`std::time::Duration` can only decode positive TIME values; got {time}").into() }) } } // Not exposing this as a `FromStr` impl currently because `MySqlTime` is not designed to be // a general interchange type. fn parse(text: &str) -> Result { let mut segments = text.split(':'); let hours = segments .next() .ok_or("expected hours segment, got nothing")?; let minutes = segments .next() .ok_or("expected minutes segment, got nothing")?; let seconds = segments .next() .ok_or("expected seconds segment, got nothing")?; // Include the sign in parsing for convenience; // the allowed range of whole hours is much smaller than `i32`'s positive range. 
let hours: i32 = hours .parse() .map_err(|e| format!("error parsing hours from {text:?} (segment {hours:?}): {e}"))?; let sign = if hours.is_negative() { MySqlTimeSign::Negative } else { MySqlTimeSign::Positive }; let hours = hours.unsigned_abs(); let minutes: u8 = minutes .parse() .map_err(|e| format!("error parsing minutes from {text:?} (segment {minutes:?}): {e}"))?; let (seconds, microseconds): (u8, u32) = if let Some((seconds, microseconds)) = seconds.split_once('.') { ( seconds.parse().map_err(|e| { format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}") })?, parse_microseconds(microseconds).map_err(|e| { format!("error parsing microseconds from {text:?} (segment {microseconds:?}): {e}") })?, ) } else { ( seconds.parse().map_err(|e| { format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}") })?, 0, ) }; Ok(MySqlTime::new(sign, hours, minutes, seconds, microseconds)?) } /// Parse microseconds from a fractional seconds string. fn parse_microseconds(micros: &str) -> Result { const EXPECTED_DIGITS: usize = 6; match micros.len() { 0 => Err("empty string".into()), len @ ..=EXPECTED_DIGITS => { // Fewer than 6 digits, multiply to the correct magnitude let micros: u32 = micros.parse()?; // cast cannot overflow #[allow(clippy::cast_possible_truncation)] Ok(micros * 10u32.pow((EXPECTED_DIGITS - len) as u32)) } // More digits than expected, truncate _ => Ok(micros[..EXPECTED_DIGITS].parse()?), } } #[cfg(test)] mod tests { use super::MySqlTime; use crate::types::MySqlTimeSign; use super::parse_microseconds; #[test] fn test_display() { assert_eq!(MySqlTime::ZERO.to_string(), "0:00:00"); assert_eq!(format!("{:.0}", MySqlTime::ZERO), "0:00:00"); assert_eq!(format!("{:.3}", MySqlTime::ZERO), "0:00:00.000"); assert_eq!(format!("{:.6}", MySqlTime::ZERO), "0:00:00.000000"); assert_eq!(format!("{:.9}", MySqlTime::ZERO), "0:00:00.000000000"); assert_eq!(format!("{:.0}", MySqlTime::MAX), "838:59:59"); assert_eq!(format!("{:.3}", MySqlTime::MAX), 
"838:59:59.000"); assert_eq!(format!("{:.6}", MySqlTime::MAX), "838:59:59.000000"); assert_eq!(format!("{:.9}", MySqlTime::MAX), "838:59:59.000000000"); assert_eq!(format!("{:+.0}", MySqlTime::MAX), "+838:59:59"); assert_eq!(format!("{:+.3}", MySqlTime::MAX), "+838:59:59.000"); assert_eq!(format!("{:+.6}", MySqlTime::MAX), "+838:59:59.000000"); assert_eq!(format!("{:+.9}", MySqlTime::MAX), "+838:59:59.000000000"); assert_eq!(format!("{:.0}", MySqlTime::MIN), "-838:59:59"); assert_eq!(format!("{:.3}", MySqlTime::MIN), "-838:59:59.000"); assert_eq!(format!("{:.6}", MySqlTime::MIN), "-838:59:59.000000"); assert_eq!(format!("{:.9}", MySqlTime::MIN), "-838:59:59.000000000"); let positive = MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890011).unwrap(); assert_eq!(positive.to_string(), "123:45:56.890011"); assert_eq!(format!("{positive:.0}"), "123:45:56"); assert_eq!(format!("{positive:.3}"), "123:45:56.890"); assert_eq!(format!("{positive:.6}"), "123:45:56.890011"); assert_eq!(format!("{positive:.9}"), "123:45:56.890011000"); assert_eq!(format!("{positive:+.0}"), "+123:45:56"); assert_eq!(format!("{positive:+.3}"), "+123:45:56.890"); assert_eq!(format!("{positive:+.6}"), "+123:45:56.890011"); assert_eq!(format!("{positive:+.9}"), "+123:45:56.890011000"); let negative = MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890011).unwrap(); assert_eq!(negative.to_string(), "-123:45:56.890011"); assert_eq!(format!("{negative:.0}"), "-123:45:56"); assert_eq!(format!("{negative:.3}"), "-123:45:56.890"); assert_eq!(format!("{negative:.6}"), "-123:45:56.890011"); assert_eq!(format!("{negative:.9}"), "-123:45:56.890011000"); } #[test] fn test_parse_microseconds() { assert_eq!(parse_microseconds("010").unwrap(), 10_000); assert_eq!(parse_microseconds("0100000000").unwrap(), 10_000); assert_eq!(parse_microseconds("890").unwrap(), 890_000); assert_eq!(parse_microseconds("0890").unwrap(), 89_000); assert_eq!( // Case in point about not exposing this: // we always truncate 
excess precision because it's simpler than rounding // and MySQL should never return a higher precision. parse_microseconds("123456789").unwrap(), 123456, ); } } sqlx-mysql-0.8.3/src/types/rust_decimal.rs000064400000000000000000000015571046102023000167520ustar 00000000000000use rust_decimal::Decimal; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::io::MySqlBufMutExt; use crate::protocol::text::ColumnType; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; impl Type for Decimal { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::NewDecimal) } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!(ty.r#type, ColumnType::Decimal | ColumnType::NewDecimal) } } impl Encode<'_, MySql> for Decimal { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(&self.to_string()); Ok(IsNull::No) } } impl Decode<'_, MySql> for Decimal { fn decode(value: MySqlValueRef<'_>) -> Result { Ok(value.as_str()?.parse()?) } } sqlx-mysql-0.8.3/src/types/str.rs000064400000000000000000000062401046102023000151010ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::io::MySqlBufMutExt; use crate::protocol::text::{ColumnFlags, ColumnType}; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; use std::borrow::Cow; impl Type for str { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo { r#type: ColumnType::VarString, // VARCHAR flags: ColumnFlags::empty(), max_size: None, } } fn compatible(ty: &MySqlTypeInfo) -> bool { // TODO: Support more collations being returned from SQL? 
matches!( ty.r#type, ColumnType::VarChar | ColumnType::Blob | ColumnType::TinyBlob | ColumnType::MediumBlob | ColumnType::LongBlob | ColumnType::String | ColumnType::VarString | ColumnType::Enum ) && !ty.flags.contains(ColumnFlags::BINARY) } } impl Encode<'_, MySql> for &'_ str { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.put_str_lenenc(self); Ok(IsNull::No) } } impl<'r> Decode<'r, MySql> for &'r str { fn decode(value: MySqlValueRef<'r>) -> Result { value.as_str() } } impl Type for Box { fn type_info() -> MySqlTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Encode<'_, MySql> for Box { fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&str as Encode>::encode(&**self, buf) } } impl<'r> Decode<'r, MySql> for Box { fn decode(value: MySqlValueRef<'r>) -> Result { <&str as Decode>::decode(value).map(Box::from) } } impl Type for String { fn type_info() -> MySqlTypeInfo { >::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { >::compatible(ty) } } impl Encode<'_, MySql> for String { fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&str as Encode>::encode(&**self, buf) } } impl Decode<'_, MySql> for String { fn decode(value: MySqlValueRef<'_>) -> Result { <&str as Decode>::decode(value).map(ToOwned::to_owned) } } impl Type for Cow<'_, str> { fn type_info() -> MySqlTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Encode<'_, MySql> for Cow<'_, str> { fn encode_by_ref(&self, buf: &mut Vec) -> Result { match self { Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), } } } impl<'r> Decode<'r, MySql> for Cow<'r, str> { fn decode(value: MySqlValueRef<'r>) -> Result { value.as_str().map(Cow::Borrowed) } } sqlx-mysql-0.8.3/src/types/text.rs000064400000000000000000000035311046102023000152550ustar 00000000000000use crate::{MySql, MySqlTypeInfo, 
MySqlValueRef}; use sqlx_core::decode::Decode; use sqlx_core::encode::{Encode, IsNull}; use sqlx_core::error::BoxDynError; use sqlx_core::types::{Text, Type}; use std::fmt::Display; use std::str::FromStr; impl Type for Text { fn type_info() -> MySqlTypeInfo { >::type_info() } fn compatible(ty: &MySqlTypeInfo) -> bool { >::compatible(ty) } } impl<'q, T> Encode<'q, MySql> for Text where T: Display, { fn encode_by_ref(&self, buf: &mut Vec) -> Result { // We can't really do the trick like with Postgres where we reserve the space for the // length up-front and then overwrite it later, because MySQL appears to enforce that // length-encoded integers use the smallest encoding for the value: // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le // // So we'd have to reserve space for the max-width encoding, format into the buffer, // then figure out how many bytes our length-encoded integer needs to be and move the // value bytes down to use up the empty space. // // Copying from a completely separate buffer instead is easier. It may or may not be faster // or slower depending on a ton of different variables, but I don't currently have the time // to implement both approaches and compare their performance. 
Encode::::encode(self.0.to_string(), buf) } } impl<'r, T> Decode<'r, MySql> for Text where T: FromStr, BoxDynError: From<::Err>, { fn decode(value: MySqlValueRef<'r>) -> Result { let s: &str = Decode::::decode(value)?; Ok(Self(s.parse()?)) } } sqlx-mysql-0.8.3/src/types/time.rs000064400000000000000000000231201046102023000152230ustar 00000000000000use byteorder::{ByteOrder, LittleEndian}; use bytes::Buf; use sqlx_core::database::Database; use time::macros::format_description; use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::protocol::text::ColumnType; use crate::type_info::MySqlTypeInfo; use crate::types::{MySqlTime, MySqlTimeSign, Type}; use crate::{MySql, MySqlValueFormat, MySqlValueRef}; impl Type for OffsetDateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Timestamp) } fn compatible(ty: &MySqlTypeInfo) -> bool { matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) } } impl Encode<'_, MySql> for OffsetDateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { let utc_dt = self.to_offset(UtcOffset::UTC); let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); Encode::::encode(primitive_dt, buf) } } impl<'r> Decode<'r, MySql> for OffsetDateTime { fn decode(value: MySqlValueRef<'r>) -> Result { let primitive: PrimitiveDateTime = Decode::::decode(value)?; Ok(primitive.assume_utc()) } } impl Type for Time { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Time) } } impl Encode<'_, MySql> for Time { fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = time_encoded_len(self); buf.push(len); // sign byte: Time is never negative buf.push(0); // Number of days in the interval; always 0 for time-of-day values. 
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding buf.extend_from_slice(&[0_u8; 4]); encode_time(self, len > 8, buf); Ok(IsNull::No) } fn size_hint(&self) -> usize { time_encoded_len(self) as usize + 1 // plus length byte } } impl<'r> Decode<'r, MySql> for Time { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { // Should never panic. MySqlTime::decode(value)?.try_into() } // Retaining this parsing for now as it allows us to cross-check our impl. MySqlValueFormat::Text => Time::parse( value.as_str()?, &format_description!("[hour]:[minute]:[second].[subsecond]"), ) .map_err(Into::into), } } } impl TryFrom for Time { type Error = BoxDynError; fn try_from(time: MySqlTime) -> Result { if !time.is_valid_time_of_day() { return Err(format!("MySqlTime value out of range for `time::Time`: {time}").into()); } #[allow(clippy::cast_possible_truncation)] Ok(Time::from_hms_micro( // `is_valid_time_of_day()` ensures this won't overflow time.hours() as u8, time.minutes(), time.seconds(), time.microseconds(), )?) } } impl From for time::Duration { fn from(time: MySqlTime) -> Self { // `subsec_nanos()` is guaranteed to be between 0 and 10^9 #[allow(clippy::cast_possible_wrap)] time::Duration::new(time.whole_seconds_signed(), time.subsec_nanos() as i32) } } impl TryFrom for MySqlTime { type Error = BoxDynError; fn try_from(value: time::Duration) -> Result { let sign = if value.is_negative() { MySqlTimeSign::Negative } else { MySqlTimeSign::Positive }; // Similar to `TryFrom`, use `std::time::Duration` as an intermediate. 
Ok(MySqlTime::try_from(std::time::Duration::try_from(value.abs())?)?.with_sign(sign)) } } impl Type for time::Duration { fn type_info() -> MySqlTypeInfo { MySqlTime::type_info() } } impl<'r> Decode<'r, MySql> for time::Duration { fn decode(value: ::ValueRef<'r>) -> Result { Ok(MySqlTime::decode(value)?.into()) } } impl Type for Date { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Date) } } impl Encode<'_, MySql> for Date { fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(4); encode_date(self, buf)?; Ok(IsNull::No) } fn size_hint(&self) -> usize { 5 } } impl<'r> Decode<'r, MySql> for Date { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { let buf = value.as_bytes()?; // Row decoding should leave the length byte on the front. if buf.is_empty() { return Err("empty buffer".into()); } Ok(decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?) } MySqlValueFormat::Text => { let s = value.as_str()?; Date::parse(s, &format_description!("[year]-[month]-[day]")).map_err(Into::into) } } } } impl Type for PrimitiveDateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Datetime) } } impl Encode<'_, MySql> for PrimitiveDateTime { fn encode_by_ref(&self, buf: &mut Vec) -> Result { let len = primitive_dt_encoded_len(self); buf.push(len); encode_date(&self.date(), buf)?; if len > 4 { encode_time(&self.time(), len > 7, buf); } Ok(IsNull::No) } fn size_hint(&self) -> usize { primitive_dt_encoded_len(self) as usize + 1 // plus length byte } } impl<'r> Decode<'r, MySql> for PrimitiveDateTime { fn decode(value: MySqlValueRef<'r>) -> Result { match value.format() { MySqlValueFormat::Binary => { let mut buf = value.as_bytes()?; if buf.is_empty() { return Err("empty buffer".into()); } let len = buf.get_u8(); let date = decode_date(buf)?.ok_or(UnexpectedNullError)?; let dt = if len > 4 { date.with_time(decode_time(&buf[4..])?) 
} else {
                    date.midnight()
                };

                Ok(dt)
            }

            MySqlValueFormat::Text => {
                let s = value.as_str()?;

                // If there are no nanoseconds parse without them
                if s.contains('.') {
                    PrimitiveDateTime::parse(
                        s,
                        &format_description!(
                            "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]"
                        ),
                    )
                    .map_err(Into::into)
                } else {
                    PrimitiveDateTime::parse(
                        s,
                        &format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"),
                    )
                    .map_err(Into::into)
                }
            }
        }
    }
}

/// Appends `date` to `buf` as a little-endian `u16` year followed by one byte
/// each for the month and the day.
///
/// # Errors
/// Returns an error if the year is negative or does not fit in a `u16`.
fn encode_date(date: &Date, buf: &mut Vec<u8>) -> Result<(), BoxDynError> {
    // MySQL documents DATE support for years 1000 - 9999; only the `u16`
    // range is enforced here.
    let year =
        u16::try_from(date.year()).map_err(|_| format!("Date out of range for Mysql: {date}"))?;

    buf.extend_from_slice(&year.to_le_bytes());
    buf.push(date.month().into());
    buf.push(date.day());

    Ok(())
}

/// Decodes a date laid out as a little-endian `u16` year, then one byte each
/// for month and day (the layout written by [`encode_date`]).
///
/// An empty buffer denotes a zero date and yields `Ok(None)`.
///
/// # Errors
/// Returns an error if the buffer is non-empty but shorter than 4 bytes, or if
/// the components do not form a valid calendar date.
fn decode_date(buf: &[u8]) -> Result<Option<Date>, BoxDynError> {
    if buf.is_empty() {
        // zero buffer means a zero date (null)
        return Ok(None);
    }

    // `read_u16` and the `buf[2]` / `buf[3]` indexing below would panic on a
    // truncated buffer; report it as a decode error instead.
    if buf.len() < 4 {
        return Err(format!(
            "expected at least 4 bytes for a MySQL date, got {}",
            buf.len()
        )
        .into());
    }

    Date::from_calendar_date(
        LittleEndian::read_u16(buf) as i32,
        time::Month::try_from(buf[2])?,
        buf[3],
    )
    .map_err(Into::into)
    .map(Some)
}

/// Appends `time` to `buf` as hour, minute and second bytes, optionally
/// followed by the microseconds as a little-endian `u32`.
fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
    buf.push(time.hour());
    buf.push(time.minute());
    buf.push(time.second());

    if include_micros {
        buf.extend(&(time.nanosecond() / 1000).to_le_bytes());
    }
}

/// Decodes hour, minute and second bytes, optionally followed by a
/// little-endian microseconds field occupying the rest of the buffer.
///
/// # Errors
/// Returns an error if fewer than 3 bytes are present, if the microseconds
/// field is wider than 8 bytes or exceeds `u32`, or if the components do not
/// form a valid [`Time`].
fn decode_time(mut buf: &[u8]) -> Result<Time, BoxDynError> {
    // `get_u8` panics when no data remains; reject truncated input instead.
    if buf.len() < 3 {
        return Err(format!(
            "expected at least 3 bytes for a MySQL time, got {}",
            buf.len()
        )
        .into());
    }

    let hour = buf.get_u8();
    let minute = buf.get_u8();
    let seconds = buf.get_u8();

    let micros = if !buf.is_empty() {
        // `get_uint_le` panics when asked for more than 8 bytes.
        if buf.len() > 8 {
            return Err(format!(
                "expected no more than 8 bytes for MySQL time microseconds, got {}",
                buf.len()
            )
            .into());
        }

        // microseconds : int
        buf.get_uint_le(buf.len())
    } else {
        0
    };

    let micros = u32::try_from(micros)
        .map_err(|_| format!("MySQL returned microseconds out of range: {micros}"))?;

    Time::from_hms_micro(hour, minute, seconds, micros)
        .map_err(|e| format!("Time out of range for MySQL: {e}").into())
}

/// Returns the encoded length of a date-time: 4, 7 or 11 bytes depending on
/// which trailing time components are zero.
#[inline(always)]
fn primitive_dt_encoded_len(time: &PrimitiveDateTime) -> u8 {
    // to save space the packet can be compressed:
    match (time.hour(), time.minute(), time.second(), time.nanosecond()) {
        // if hour, minutes, seconds and micro_seconds are all 0,
        // length is 4 and no other field is sent
        (0, 0, 0, 0) => 4,
        // if micro_seconds is 0, length is 7
        // and micro_seconds is not sent
        (_, _, _, 0) => 7,
        // otherwise length is 11
        (_, _, _, _) => 11,
    }
}

/// Returns the encoded length of a time: 8 bytes without a sub-second part,
/// 12 bytes with one.
#[inline(always)]
fn time_encoded_len(time: &Time) -> u8 {
    if time.nanosecond() == 0 {
        // if micro_seconds is 0, length is 8 and micro_seconds is not sent
        8
    } else {
        // otherwise length is 12
        12
    }
}
sqlx-mysql-0.8.3/src/types/uint.rs000064400000000000000000000077371046102023000152640ustar 00000000000000
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::protocol::text::{ColumnFlags, ColumnType};
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
use byteorder::{ByteOrder, LittleEndian};

/// Builds the type info for an unsigned integer column of type `ty`
/// (flags `BINARY | UNSIGNED`, no maximum size).
fn uint_type_info(ty: ColumnType) -> MySqlTypeInfo {
    MySqlTypeInfo {
        r#type: ty,
        flags: ColumnFlags::BINARY | ColumnFlags::UNSIGNED,
        max_size: None,
    }
}

/// Returns `true` for integer-like column types (including `YEAR` and `BIT`)
/// that carry the `UNSIGNED` flag.
fn uint_compatible(ty: &MySqlTypeInfo) -> bool {
    matches!(
        ty.r#type,
        ColumnType::Tiny
            | ColumnType::Short
            | ColumnType::Long
            | ColumnType::Int24
            | ColumnType::LongLong
            | ColumnType::Year
            | ColumnType::Bit
    ) && ty.flags.contains(ColumnFlags::UNSIGNED)
}

impl Type<MySql> for u8 {
    fn type_info() -> MySqlTypeInfo {
        uint_type_info(ColumnType::Tiny)
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        uint_compatible(ty)
    }
}

impl Type<MySql> for u16 {
    fn type_info() -> MySqlTypeInfo {
        uint_type_info(ColumnType::Short)
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        uint_compatible(ty)
    }
}

impl Type<MySql> for u32 {
    fn type_info() -> MySqlTypeInfo {
        uint_type_info(ColumnType::Long)
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        uint_compatible(ty)
    }
}

impl Type<MySql> for u64 {
    fn type_info() -> MySqlTypeInfo {
        uint_type_info(ColumnType::LongLong)
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        uint_compatible(ty)
    }
}

// Unsigned integers are written as their little-endian bytes.

impl Encode<'_, MySql> for u8 {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.extend(&self.to_le_bytes());

        Ok(IsNull::No)
    }
}

impl Encode<'_, MySql> for u16 {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.extend(&self.to_le_bytes());

        Ok(IsNull::No)
    }
}

impl Encode<'_, MySql> for u32 {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.extend(&self.to_le_bytes());

        Ok(IsNull::No)
    }
}

impl Encode<'_, MySql> for u64 {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.extend(&self.to_le_bytes());

        Ok(IsNull::No)
    }
}

/// Decodes any unsigned-compatible column into a `u64`.
///
/// `BIT` columns are read from their raw bytes with the first byte most
/// significant, whatever the value format. Otherwise text values are parsed
/// and binary values are read as a little-endian integer of 1 to 8 bytes.
fn uint_decode(value: MySqlValueRef<'_>) -> Result<u64, BoxDynError> {
    if value.type_info.r#type == ColumnType::Bit {
        // NOTE: Regardless of the value format, there is raw binary data here

        let buf = value.as_bytes()?;
        let mut value: u64 = 0;

        for b in buf {
            value = (*b as u64) | (value << 8);
        }

        return Ok(value);
    }

    Ok(match value.format() {
        MySqlValueFormat::Text => value.as_str()?.parse()?,

        MySqlValueFormat::Binary => {
            let buf = value.as_bytes()?;

            // Check conditions that could cause `read_uint()` to panic.
            if buf.is_empty() {
                return Err("empty buffer".into());
            }

            if buf.len() > 8 {
                return Err(format!(
                    "expected no more than 8 bytes for unsigned integer value, got {}",
                    buf.len()
                )
                .into());
            }

            LittleEndian::read_uint(buf, buf.len())
        }
    })
}

// Narrower types decode through `u64` and fail if the value does not fit.

impl Decode<'_, MySql> for u8 {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        uint_decode(value)?.try_into().map_err(Into::into)
    }
}

impl Decode<'_, MySql> for u16 {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        uint_decode(value)?.try_into().map_err(Into::into)
    }
}

impl Decode<'_, MySql> for u32 {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        uint_decode(value)?.try_into().map_err(Into::into)
    }
}

impl Decode<'_, MySql> for u64 {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        uint_decode(value)
    }
}
sqlx-mysql-0.8.3/src/types/uuid.rs000064400000000000000000000055521046102023000152440ustar 00000000000000
use uuid::{
    fmt::{Hyphenated, Simple},
    Uuid,
};

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::io::MySqlBufMutExt;
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};

// `Uuid` maps to the same SQL type as `&[u8]`: its 16 raw bytes.
impl Type<MySql> for Uuid {
    fn type_info() -> MySqlTypeInfo {
        <&[u8] as Type<MySql>>::type_info()
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        <&[u8] as Type<MySql>>::compatible(ty)
    }
}

impl Encode<'_, MySql> for Uuid {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.put_bytes_lenenc(self.as_bytes());

        Ok(IsNull::No)
    }
}

impl Decode<'_, MySql> for Uuid {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        // delegate to the &[u8] type to decode from MySQL
        let bytes = <&[u8] as Decode<MySql>>::decode(value)?;

        if bytes.len() != 16 {
            return Err(format!(
                "Expected 16 bytes, got {}; `Uuid` uses binary format for MySQL/MariaDB. \
                 For text-formatted UUIDs, use `uuid::fmt::Hyphenated` instead of `Uuid`.",
                bytes.len(),
            )
            .into());
        }

        // construct a Uuid from the returned bytes
        Uuid::from_slice(bytes).map_err(Into::into)
    }
}

// `Hyphenated` maps to the same SQL type as `&str`: the UUID's text form.
impl Type<MySql> for Hyphenated {
    fn type_info() -> MySqlTypeInfo {
        <&str as Type<MySql>>::type_info()
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        <&str as Type<MySql>>::compatible(ty)
    }
}

impl Encode<'_, MySql> for Hyphenated {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.put_str_lenenc(&self.to_string());

        Ok(IsNull::No)
    }
}

impl Decode<'_, MySql> for Hyphenated {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        // delegate to the &str type to decode from MySQL
        let text = <&str as Decode<MySql>>::decode(value)?;

        // parse a UUID from the text
        Uuid::parse_str(text)
            .map_err(Into::into)
            .map(|u| u.hyphenated())
    }
}

// `Simple` maps to the same SQL type as `&str`: the UUID's text form.
impl Type<MySql> for Simple {
    fn type_info() -> MySqlTypeInfo {
        <&str as Type<MySql>>::type_info()
    }

    fn compatible(ty: &MySqlTypeInfo) -> bool {
        <&str as Type<MySql>>::compatible(ty)
    }
}

impl Encode<'_, MySql> for Simple {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
        buf.put_str_lenenc(&self.to_string());

        Ok(IsNull::No)
    }
}

impl Decode<'_, MySql> for Simple {
    fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
        // delegate to the &str type to decode from MySQL
        let text = <&str as Decode<MySql>>::decode(value)?;

        // parse a UUID from the text
        Uuid::parse_str(text)
            .map_err(Into::into)
            .map(|u| u.simple())
    }
}
sqlx-mysql-0.8.3/src/value.rs000064400000000000000000000052621046102023000142440ustar 00000000000000use std::borrow::Cow; use std::str::from_utf8; use bytes::Bytes; pub(crate) use sqlx_core::value::*; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::protocol::text::ColumnType; use crate::{MySql, MySqlTypeInfo}; #[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum MySqlValueFormat { Text, Binary, } /// Implementation of [`Value`] for MySQL. #[derive(Clone)] pub struct MySqlValue { value: Option, type_info: MySqlTypeInfo, format: MySqlValueFormat, } /// Implementation of [`ValueRef`] for MySQL. #[derive(Clone)] pub struct MySqlValueRef<'r> { pub(crate) value: Option<&'r [u8]>, pub(crate) row: Option<&'r Bytes>, pub(crate) type_info: MySqlTypeInfo, pub(crate) format: MySqlValueFormat, } impl<'r> MySqlValueRef<'r> { pub(crate) fn format(&self) -> MySqlValueFormat { self.format } pub(crate) fn as_bytes(&self) -> Result<&'r [u8], BoxDynError> { match &self.value { Some(v) => Ok(v), None => Err(UnexpectedNullError.into()), } } pub(crate) fn as_str(&self) -> Result<&'r str, BoxDynError> { Ok(from_utf8(self.as_bytes()?)?) 
} } impl Value for MySqlValue { type Database = MySql; fn as_ref(&self) -> MySqlValueRef<'_> { MySqlValueRef { value: self.value.as_deref(), row: None, type_info: self.type_info.clone(), format: self.format, } } fn type_info(&self) -> Cow<'_, MySqlTypeInfo> { Cow::Borrowed(&self.type_info) } fn is_null(&self) -> bool { is_null(self.value.as_deref(), &self.type_info) } } impl<'r> ValueRef<'r> for MySqlValueRef<'r> { type Database = MySql; fn to_owned(&self) -> MySqlValue { let value = match (self.row, self.value) { (Some(row), Some(value)) => Some(row.slice_ref(value)), (None, Some(value)) => Some(Bytes::copy_from_slice(value)), _ => None, }; MySqlValue { value, format: self.format, type_info: self.type_info.clone(), } } fn type_info(&self) -> Cow<'_, MySqlTypeInfo> { Cow::Borrowed(&self.type_info) } #[inline] fn is_null(&self) -> bool { is_null(self.value, &self.type_info) } } fn is_null(value: Option<&[u8]>, ty: &MySqlTypeInfo) -> bool { if let Some(value) = value { // zero dates and date times should be treated the same as NULL if matches!( ty.r#type, ColumnType::Date | ColumnType::Timestamp | ColumnType::Datetime ) && value.starts_with(b"\0") { return true; } } value.is_none() }